Skip to content

Commit f2e521c

Browse files
[Dtype] Align dtype casting behavior with Transformers and Accelerate (#1725)
* [Dtype] Align automatic dtype * up * up * fix * re-add accelerate
1 parent debc74f commit f2e521c

File tree

5 files changed

+17
-20
lines changed

5 files changed

+17
-20
lines changed

.github/workflows/nightly_tests.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ jobs:
6262
run: |
6363
python -m pip install -e .[quality,test]
6464
python -m pip install -U git+https://github.com/huggingface/transformers
65+
python -m pip install git+https://github.com/huggingface/accelerate
6566
6667
- name: Environment
6768
run: |
@@ -134,6 +135,7 @@ jobs:
134135
${CONDA_RUN} python -m pip install --upgrade pip
135136
${CONDA_RUN} python -m pip install -e .[quality,test]
136137
${CONDA_RUN} python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
138+
${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate
137139
138140
- name: Environment
139141
shell: arch -arch arm64 bash {0}
@@ -157,4 +159,4 @@ jobs:
157159
uses: actions/upload-artifact@v2
158160
with:
159161
name: torch_mps_test_reports
160-
path: reports
162+
path: reports

.github/workflows/pr_tests.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ jobs:
6060
apt-get update && apt-get install libsndfile1-dev -y
6161
python -m pip install -e .[quality,test]
6262
python -m pip install -U git+https://github.com/huggingface/transformers
63+
python -m pip install git+https://github.com/huggingface/accelerate
6364
6465
- name: Environment
6566
run: |
@@ -126,6 +127,7 @@ jobs:
126127
${CONDA_RUN} python -m pip install --upgrade pip
127128
${CONDA_RUN} python -m pip install -e .[quality,test]
128129
${CONDA_RUN} python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
130+
${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate
129131
${CONDA_RUN} python -m pip install -U git+https://github.com/huggingface/transformers
130132
131133
- name: Environment

.github/workflows/push_tests.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ jobs:
6262
run: |
6363
python -m pip install -e .[quality,test]
6464
python -m pip install -U git+https://github.com/huggingface/transformers
65+
python -m pip install git+https://github.com/huggingface/accelerate
6566
6667
- name: Environment
6768
run: |
@@ -130,6 +131,7 @@ jobs:
130131
- name: Install dependencies
131132
run: |
132133
python -m pip install -e .[quality,test,training]
134+
python -m pip install git+https://github.com/huggingface/accelerate
133135
python -m pip install -U git+https://github.com/huggingface/transformers
134136
135137
- name: Environment
@@ -151,4 +153,4 @@ jobs:
151153
uses: actions/upload-artifact@v2
152154
with:
153155
name: examples_test_reports
154-
path: reports
156+
path: reports

src/diffusers/modeling_utils.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
import inspect
1718
import os
1819
from functools import partial
1920
from typing import Callable, List, Optional, Tuple, Union
@@ -489,11 +490,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
489490
state_dict = load_state_dict(model_file)
490491
# move the params from meta device to cpu
491492
for param_name, param in state_dict.items():
492-
set_module_tensor_to_device(model, param_name, param_device, value=param)
493+
accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
494+
if accepts_dtype:
495+
set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype)
496+
else:
497+
set_module_tensor_to_device(model, param_name, param_device, value=param)
493498
else: # else let accelerate handle loading and dispatching.
494499
# Load weights and dispatch according to the device_map
495500
# by default the device_map is None and the weights are loaded on the CPU
496-
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
501+
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype)
497502

498503
loading_info = {
499504
"missing_keys": [],
@@ -519,20 +524,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
519524
model = cls.from_config(config, **unused_kwargs)
520525

521526
state_dict = load_state_dict(model_file)
522-
dtype = set(v.dtype for v in state_dict.values())
523-
524-
if len(dtype) > 1 and torch.float32 not in dtype:
525-
raise ValueError(
526-
f"The weights of the model file {model_file} have a mixture of incompatible dtypes {dtype}. Please"
527-
f" make sure that {model_file} weights have only one dtype."
528-
)
529-
elif len(dtype) > 1 and torch.float32 in dtype:
530-
dtype = torch.float32
531-
else:
532-
dtype = dtype.pop()
533-
534-
# move model to correct dtype
535-
model = model.to(dtype)
536527

537528
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
538529
model,

tests/test_modeling_common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ def test_from_save_pretrained_dtype(self):
7070
with tempfile.TemporaryDirectory() as tmpdirname:
7171
model.to(dtype)
7272
model.save_pretrained(tmpdirname)
73-
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True)
73+
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
7474
assert new_model.dtype == dtype
75-
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False)
75+
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype)
7676
assert new_model.dtype == dtype
7777

7878
def test_determinism(self):

0 commit comments

Comments
 (0)