Skip to content

Commit 7c6e9ef

Browse files
authored
[tests] Fix how compiler mixin classes are used (#11680)
* fix how compiler tester mixins are used. * propagate * more
1 parent f46abfe commit 7c6e9ef

File tree

5 files changed

+78
-14
lines changed

5 files changed

+78
-14
lines changed

tests/models/transformers/test_models_transformer_flux.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,7 @@ def create_flux_ip_adapter_state_dict(model):
7878
return ip_state_dict
7979

8080

81-
class FluxTransformerTests(
82-
ModelTesterMixin, TorchCompileTesterMixin, LoraHotSwappingForModelTesterMixin, unittest.TestCase
83-
):
81+
class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
8482
model_class = FluxTransformer2DModel
8583
main_input_name = "hidden_states"
8684
# We override the items here because the transformer under consideration is small.
@@ -169,3 +167,17 @@ def test_deprecated_inputs_img_txt_ids_3d(self):
169167
def test_gradient_checkpointing_is_applied(self):
170168
expected_set = {"FluxTransformer2DModel"}
171169
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
170+
171+
172+
class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
173+
model_class = FluxTransformer2DModel
174+
175+
def prepare_init_args_and_inputs_for_common(self):
176+
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
177+
178+
179+
class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
180+
model_class = FluxTransformer2DModel
181+
182+
def prepare_init_args_and_inputs_for_common(self):
183+
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()

tests/models/transformers/test_models_transformer_hunyuan_video.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
enable_full_determinism()
2929

3030

31-
class HunyuanVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
31+
class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
3232
model_class = HunyuanVideoTransformer3DModel
3333
main_input_name = "hidden_states"
3434
uses_custom_attn_processor = True
@@ -93,7 +93,14 @@ def test_gradient_checkpointing_is_applied(self):
9393
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
9494

9595

96-
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
96+
class HunyuanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
97+
model_class = HunyuanVideoTransformer3DModel
98+
99+
def prepare_init_args_and_inputs_for_common(self):
100+
return HunyuanVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
101+
102+
103+
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
97104
model_class = HunyuanVideoTransformer3DModel
98105
main_input_name = "hidden_states"
99106
uses_custom_attn_processor = True
@@ -161,7 +168,14 @@ def test_gradient_checkpointing_is_applied(self):
161168
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
162169

163170

164-
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
171+
class HunyuanSkyreelsImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase):
172+
model_class = HunyuanVideoTransformer3DModel
173+
174+
def prepare_init_args_and_inputs_for_common(self):
175+
return HunyuanSkyreelsImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
176+
177+
178+
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
165179
model_class = HunyuanVideoTransformer3DModel
166180
main_input_name = "hidden_states"
167181
uses_custom_attn_processor = True
@@ -227,9 +241,14 @@ def test_gradient_checkpointing_is_applied(self):
227241
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
228242

229243

230-
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(
231-
ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase
232-
):
244+
class HunyuanImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase):
245+
model_class = HunyuanVideoTransformer3DModel
246+
247+
def prepare_init_args_and_inputs_for_common(self):
248+
return HunyuanVideoImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
249+
250+
251+
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
233252
model_class = HunyuanVideoTransformer3DModel
234253
main_input_name = "hidden_states"
235254
uses_custom_attn_processor = True
@@ -295,3 +314,10 @@ def test_output(self):
295314
def test_gradient_checkpointing_is_applied(self):
296315
expected_set = {"HunyuanVideoTransformer3DModel"}
297316
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
317+
318+
319+
class HunyuanVideoTokenReplaceCompileTests(TorchCompileTesterMixin, unittest.TestCase):
320+
model_class = HunyuanVideoTransformer3DModel
321+
322+
def prepare_init_args_and_inputs_for_common(self):
323+
return HunyuanVideoTokenReplaceImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()

tests/models/transformers/test_models_transformer_ltx.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
enable_full_determinism()
2727

2828

29-
class LTXTransformerTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
29+
class LTXTransformerTests(ModelTesterMixin, unittest.TestCase):
3030
model_class = LTXVideoTransformer3DModel
3131
main_input_name = "hidden_states"
3232
uses_custom_attn_processor = True
@@ -81,3 +81,10 @@ def prepare_init_args_and_inputs_for_common(self):
8181
def test_gradient_checkpointing_is_applied(self):
8282
expected_set = {"LTXVideoTransformer3DModel"}
8383
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
84+
85+
86+
class LTXTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
87+
model_class = LTXVideoTransformer3DModel
88+
89+
def prepare_init_args_and_inputs_for_common(self):
90+
return LTXTransformerTests().prepare_init_args_and_inputs_for_common()

tests/models/transformers/test_models_transformer_wan.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
enable_full_determinism()
2929

3030

31-
class WanTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
31+
class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
3232
model_class = WanTransformer3DModel
3333
main_input_name = "hidden_states"
3434
uses_custom_attn_processor = True
@@ -82,3 +82,10 @@ def prepare_init_args_and_inputs_for_common(self):
8282
def test_gradient_checkpointing_is_applied(self):
8383
expected_set = {"WanTransformer3DModel"}
8484
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
85+
86+
87+
class WanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
88+
model_class = WanTransformer3DModel
89+
90+
def prepare_init_args_and_inputs_for_common(self):
91+
return WanTransformer3DTests().prepare_init_args_and_inputs_for_common()

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -355,9 +355,7 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
355355
return custom_diffusion_attn_procs
356356

357357

358-
class UNet2DConditionModelTests(
359-
ModelTesterMixin, TorchCompileTesterMixin, LoraHotSwappingForModelTesterMixin, UNetTesterMixin, unittest.TestCase
360-
):
358+
class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
361359
model_class = UNet2DConditionModel
362360
main_input_name = "sample"
363361
# We override the items here because the unet under consideration is small.
@@ -1147,6 +1145,20 @@ def test_save_attn_procs_raise_warning(self):
11471145
assert "Using the `save_attn_procs()` method has been deprecated" in warning_message
11481146

11491147

1148+
class UNet2DConditionModelCompileTests(TorchCompileTesterMixin, unittest.TestCase):
1149+
model_class = UNet2DConditionModel
1150+
1151+
def prepare_init_args_and_inputs_for_common(self):
1152+
return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()
1153+
1154+
1155+
class UNet2DConditionModelLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
1156+
model_class = UNet2DConditionModel
1157+
1158+
def prepare_init_args_and_inputs_for_common(self):
1159+
return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()
1160+
1161+
11501162
@slow
11511163
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
11521164
def get_file_format(self, seed, shape):

0 commit comments

Comments
 (0)