Skip to content

Commit 51a855c

Browse files
authored
Merge branch 'main' into layerwise-upcasting
2 parents c64fa22 + 940b8e0 commit 51a855c

File tree

12 files changed

+177
-164
lines changed

12 files changed

+177
-164
lines changed

examples/textual_inversion/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ import torch
109109
model_id = "path-to-your-trained-model"
110110
pipe = StableDiffusionPipeline.from_pretrained(model_id,torch_dtype=torch.float16).to("cuda")
111111

112+
repo_id_embeds = "path-to-your-learned-embeds"
113+
pipe.load_textual_inversion(repo_id_embeds)
114+
112115
prompt = "A <cat-toy> backpack"
113116

114117
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -89,49 +89,44 @@
8989
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
9090

9191

92-
def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
92+
def is_safetensors_compatible(filenames, passed_components=None) -> bool:
9393
"""
9494
Checking for safetensors compatibility:
95-
- By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
96-
files to know which safetensors files are needed.
97-
- The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.
95+
- The model is safetensors compatible only if there is a safetensors file for each model component present in
96+
filenames.
9897
9998
Converting default pytorch serialized filenames to safetensors serialized filenames:
10099
- For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
101100
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
102101
extension is replaced with ".safetensors"
103102
"""
104-
pt_filenames = []
105-
106-
sf_filenames = set()
107-
108103
passed_components = passed_components or []
109104

105+
# extract all components of the pipeline and their associated files
106+
components = {}
110107
for filename in filenames:
111-
_, extension = os.path.splitext(filename)
108+
if not len(filename.split("/")) == 2:
109+
continue
112110

113-
if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components:
111+
component, component_filename = filename.split("/")
112+
if component in passed_components:
114113
continue
115114

116-
if extension == ".bin":
117-
pt_filenames.append(os.path.normpath(filename))
118-
elif extension == ".safetensors":
119-
sf_filenames.add(os.path.normpath(filename))
115+
components.setdefault(component, [])
116+
components[component].append(component_filename)
120117

121-
for filename in pt_filenames:
122-
# filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extension = '.bam'
123-
path, filename = os.path.split(filename)
124-
filename, extension = os.path.splitext(filename)
118+
# iterate over all files of a component
119+
# check if safetensor files exist for that component
120+
# if variant is provided check if the variant of the safetensors exists
121+
for component, component_filenames in components.items():
122+
matches = []
123+
for component_filename in component_filenames:
124+
filename, extension = os.path.splitext(component_filename)
125125

126-
if filename.startswith("pytorch_model"):
127-
filename = filename.replace("pytorch_model", "model")
128-
else:
129-
filename = filename
126+
match_exists = extension == ".safetensors"
127+
matches.append(match_exists)
130128

131-
expected_sf_filename = os.path.normpath(os.path.join(path, filename))
132-
expected_sf_filename = f"{expected_sf_filename}.safetensors"
133-
if expected_sf_filename not in sf_filenames:
134-
logger.warning(f"{expected_sf_filename} not found")
129+
if not any(matches):
135130
return False
136131

137132
return True

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1416,18 +1416,14 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
14161416
if (
14171417
use_safetensors
14181418
and not allow_pickle
1419-
and not is_safetensors_compatible(
1420-
model_filenames, variant=variant, passed_components=passed_components
1421-
)
1419+
and not is_safetensors_compatible(model_filenames, passed_components=passed_components)
14221420
):
14231421
raise EnvironmentError(
14241422
f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
14251423
)
14261424
if from_flax:
14271425
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
1428-
elif use_safetensors and is_safetensors_compatible(
1429-
model_filenames, variant=variant, passed_components=passed_components
1430-
):
1426+
elif use_safetensors and is_safetensors_compatible(model_filenames, passed_components=passed_components):
14311427
ignore_patterns = ["*.bin", "*.msgpack"]
14321428

14331429
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx

tests/lora/test_lora_layers_sd3.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
@require_peft_backend
3333
class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
3434
pipeline_class = StableDiffusion3Pipeline
35-
scheduler_cls = FlowMatchEulerDiscreteScheduler()
35+
scheduler_cls = FlowMatchEulerDiscreteScheduler
3636
scheduler_kwargs = {}
3737
uses_flow_matching = True
3838
transformer_kwargs = {
@@ -80,8 +80,7 @@ def test_sd3_lora(self):
8080
Related PR: https://github.com/huggingface/diffusers/pull/8584
8181
"""
8282
components = self.get_dummy_components()
83-
84-
pipe = self.pipeline_class(**components)
83+
pipe = self.pipeline_class(**components[0])
8584
pipe = pipe.to(torch_device)
8685
pipe.set_progress_bar_config(disable=None)
8786

tests/lora/test_lora_layers_sdxl.py

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -124,71 +124,6 @@ def tearDown(self):
124124
gc.collect()
125125
torch.cuda.empty_cache()
126126

127-
def test_sdxl_0_9_lora_one(self):
128-
generator = torch.Generator().manual_seed(0)
129-
130-
pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9")
131-
lora_model_id = "hf-internal-testing/sdxl-0.9-daiton-lora"
132-
lora_filename = "daiton-xl-lora-test.safetensors"
133-
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
134-
pipe.enable_model_cpu_offload()
135-
136-
images = pipe(
137-
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
138-
).images
139-
140-
images = images[0, -3:, -3:, -1].flatten()
141-
expected = np.array([0.3838, 0.3482, 0.3588, 0.3162, 0.319, 0.3369, 0.338, 0.3366, 0.3213])
142-
143-
max_diff = numpy_cosine_similarity_distance(expected, images)
144-
assert max_diff < 1e-3
145-
pipe.unload_lora_weights()
146-
release_memory(pipe)
147-
148-
def test_sdxl_0_9_lora_two(self):
149-
generator = torch.Generator().manual_seed(0)
150-
151-
pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9")
152-
lora_model_id = "hf-internal-testing/sdxl-0.9-costumes-lora"
153-
lora_filename = "saijo.safetensors"
154-
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
155-
pipe.enable_model_cpu_offload()
156-
157-
images = pipe(
158-
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
159-
).images
160-
161-
images = images[0, -3:, -3:, -1].flatten()
162-
expected = np.array([0.3137, 0.3269, 0.3355, 0.255, 0.2577, 0.2563, 0.2679, 0.2758, 0.2626])
163-
164-
max_diff = numpy_cosine_similarity_distance(expected, images)
165-
assert max_diff < 1e-3
166-
167-
pipe.unload_lora_weights()
168-
release_memory(pipe)
169-
170-
def test_sdxl_0_9_lora_three(self):
171-
generator = torch.Generator().manual_seed(0)
172-
173-
pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9")
174-
lora_model_id = "hf-internal-testing/sdxl-0.9-kamepan-lora"
175-
lora_filename = "kame_sdxl_v2-000020-16rank.safetensors"
176-
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
177-
pipe.enable_model_cpu_offload()
178-
179-
images = pipe(
180-
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
181-
).images
182-
183-
images = images[0, -3:, -3:, -1].flatten()
184-
expected = np.array([0.4015, 0.3761, 0.3616, 0.3745, 0.3462, 0.3337, 0.3564, 0.3649, 0.3468])
185-
186-
max_diff = numpy_cosine_similarity_distance(expected, images)
187-
assert max_diff < 5e-3
188-
189-
pipe.unload_lora_weights()
190-
release_memory(pipe)
191-
192127
def test_sdxl_1_0_lora(self):
193128
generator = torch.Generator("cpu").manual_seed(0)
194129

tests/models/transformers/test_models_transformer_aura_flow.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
enable_full_determinism()
2727

2828

29-
class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
29+
class AuraFlowTransformerTests(ModelTesterMixin, unittest.TestCase):
3030
model_class = AuraFlowTransformer2DModel
3131
main_input_name = "hidden_states"
3232
# We override the items here because the transformer under consideration is small.
@@ -73,3 +73,7 @@ def prepare_init_args_and_inputs_for_common(self):
7373
}
7474
inputs_dict = self.dummy_input
7575
return init_dict, inputs_dict
76+
77+
@unittest.skip("AuraFlowTransformer2DModel uses its own dedicated attention processor. This test does not apply")
78+
def test_set_attn_processor_for_determinism(self):
79+
pass

tests/models/transformers/test_models_transformer_sd3.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,7 @@ def prepare_init_args_and_inputs_for_common(self):
7676
}
7777
inputs_dict = self.dummy_input
7878
return init_dict, inputs_dict
79+
80+
@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
81+
def test_set_attn_processor_for_determinism(self):
82+
pass

tests/pipelines/aura_flow/test_pipeline_aura_flow.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,7 @@ def test_fused_qkv_projections(self):
163163
assert np.allclose(
164164
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
165165
), "Original outputs should match when fused QKV projections are disabled."
166+
167+
@unittest.skip("xformers attention processor does not exist for AuraFlow")
168+
def test_xformers_attention_forwardGenerator_pass(self):
169+
pass

tests/pipelines/lumina/test_lumina_nextdit.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,10 @@ def test_lumina_prompt_embeds(self):
119119
max_diff = np.abs(output_with_prompt - output_with_embeds).max()
120120
assert max_diff < 1e-4
121121

122+
@unittest.skip("xformers attention processor does not exist for Lumina")
123+
def test_xformers_attention_forwardGenerator_pass(self):
124+
pass
125+
122126

123127
@slow
124128
@require_torch_gpu

tests/pipelines/test_pipeline_utils.py

Lines changed: 46 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -68,25 +68,21 @@ def test_all_is_compatible_variant(self):
6868
"unet/diffusion_pytorch_model.fp16.bin",
6969
"unet/diffusion_pytorch_model.fp16.safetensors",
7070
]
71-
variant = "fp16"
72-
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
71+
self.assertTrue(is_safetensors_compatible(filenames))
7372

7473
def test_diffusers_model_is_compatible_variant(self):
7574
filenames = [
7675
"unet/diffusion_pytorch_model.fp16.bin",
7776
"unet/diffusion_pytorch_model.fp16.safetensors",
7877
]
79-
variant = "fp16"
80-
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
78+
self.assertTrue(is_safetensors_compatible(filenames))
8179

82-
def test_diffusers_model_is_compatible_variant_partial(self):
83-
# pass variant but use the non-variant filenames
80+
def test_diffusers_model_is_compatible_variant_mixed(self):
8481
filenames = [
8582
"unet/diffusion_pytorch_model.bin",
86-
"unet/diffusion_pytorch_model.safetensors",
83+
"unet/diffusion_pytorch_model.fp16.safetensors",
8784
]
88-
variant = "fp16"
89-
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
85+
self.assertTrue(is_safetensors_compatible(filenames))
9086

9187
def test_diffusers_model_is_not_compatible_variant(self):
9288
filenames = [
@@ -99,25 +95,14 @@ def test_diffusers_model_is_not_compatible_variant(self):
9995
"unet/diffusion_pytorch_model.fp16.bin",
10096
# Removed: 'unet/diffusion_pytorch_model.fp16.safetensors',
10197
]
102-
variant = "fp16"
103-
self.assertFalse(is_safetensors_compatible(filenames, variant=variant))
98+
self.assertFalse(is_safetensors_compatible(filenames))
10499

105100
def test_transformer_model_is_compatible_variant(self):
106101
filenames = [
107102
"text_encoder/pytorch_model.fp16.bin",
108103
"text_encoder/model.fp16.safetensors",
109104
]
110-
variant = "fp16"
111-
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
112-
113-
def test_transformer_model_is_compatible_variant_partial(self):
114-
# pass variant but use the non-variant filenames
115-
filenames = [
116-
"text_encoder/pytorch_model.bin",
117-
"text_encoder/model.safetensors",
118-
]
119-
variant = "fp16"
120-
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
105+
self.assertTrue(is_safetensors_compatible(filenames))
121106

122107
def test_transformer_model_is_not_compatible_variant(self):
123108
filenames = [
@@ -126,9 +111,45 @@ def test_transformer_model_is_not_compatible_variant(self):
126111
"vae/diffusion_pytorch_model.fp16.bin",
127112
"vae/diffusion_pytorch_model.fp16.safetensors",
128113
"text_encoder/pytorch_model.fp16.bin",
129-
# 'text_encoder/model.fp16.safetensors',
130114
"unet/diffusion_pytorch_model.fp16.bin",
131115
"unet/diffusion_pytorch_model.fp16.safetensors",
132116
]
133-
variant = "fp16"
134-
self.assertFalse(is_safetensors_compatible(filenames, variant=variant))
117+
self.assertFalse(is_safetensors_compatible(filenames))
118+
119+
def test_transformers_is_compatible_sharded(self):
120+
filenames = [
121+
"text_encoder/pytorch_model.bin",
122+
"text_encoder/model-00001-of-00002.safetensors",
123+
"text_encoder/model-00002-of-00002.safetensors",
124+
]
125+
self.assertTrue(is_safetensors_compatible(filenames))
126+
127+
def test_transformers_is_compatible_variant_sharded(self):
128+
filenames = [
129+
"text_encoder/pytorch_model.bin",
130+
"text_encoder/model.fp16-00001-of-00002.safetensors",
131+
"text_encoder/model.fp16-00001-of-00002.safetensors",
132+
]
133+
self.assertTrue(is_safetensors_compatible(filenames))
134+
135+
def test_diffusers_is_compatible_sharded(self):
136+
filenames = [
137+
"unet/diffusion_pytorch_model.bin",
138+
"unet/diffusion_pytorch_model-00001-of-00002.safetensors",
139+
"unet/diffusion_pytorch_model-00002-of-00002.safetensors",
140+
]
141+
self.assertTrue(is_safetensors_compatible(filenames))
142+
143+
def test_diffusers_is_compatible_variant_sharded(self):
144+
filenames = [
145+
"unet/diffusion_pytorch_model.bin",
146+
"unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
147+
"unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
148+
]
149+
self.assertTrue(is_safetensors_compatible(filenames))
150+
151+
def test_diffusers_is_compatible_only_variants(self):
152+
filenames = [
153+
"unet/diffusion_pytorch_model.fp16.safetensors",
154+
]
155+
self.assertTrue(is_safetensors_compatible(filenames))

0 commit comments

Comments
 (0)