|
28 | 28 | enable_full_determinism()
|
29 | 29 |
|
30 | 30 |
|
31 |
| -class HunyuanVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): |
| 31 | +class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): |
32 | 32 | model_class = HunyuanVideoTransformer3DModel
|
33 | 33 | main_input_name = "hidden_states"
|
34 | 34 | uses_custom_attn_processor = True
|
@@ -93,7 +93,14 @@ def test_gradient_checkpointing_is_applied(self):
|
93 | 93 | super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
94 | 94 |
|
95 | 95 |
|
96 |
| -class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): |
| 96 | +class HunyuanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): |
| 97 | + model_class = HunyuanVideoTransformer3DModel |
| 98 | + |
| 99 | + def prepare_init_args_and_inputs_for_common(self): |
| 100 | + return HunyuanVideoTransformer3DTests().prepare_init_args_and_inputs_for_common() |
| 101 | + |
| 102 | + |
| 103 | +class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): |
97 | 104 | model_class = HunyuanVideoTransformer3DModel
|
98 | 105 | main_input_name = "hidden_states"
|
99 | 106 | uses_custom_attn_processor = True
|
@@ -161,7 +168,14 @@ def test_gradient_checkpointing_is_applied(self):
|
161 | 168 | super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
162 | 169 |
|
163 | 170 |
|
164 |
| -class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): |
| 171 | +class HunyuanSkyreelsImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase): |
| 172 | + model_class = HunyuanVideoTransformer3DModel |
| 173 | + |
| 174 | + def prepare_init_args_and_inputs_for_common(self): |
| 175 | + return HunyuanSkyreelsImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common() |
| 176 | + |
| 177 | + |
| 178 | +class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): |
165 | 179 | model_class = HunyuanVideoTransformer3DModel
|
166 | 180 | main_input_name = "hidden_states"
|
167 | 181 | uses_custom_attn_processor = True
|
@@ -227,9 +241,14 @@ def test_gradient_checkpointing_is_applied(self):
|
227 | 241 | super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
228 | 242 |
|
229 | 243 |
|
230 |
| -class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests( |
231 |
| - ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase |
232 |
| -): |
| 244 | +class HunyuanImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase): |
| 245 | + model_class = HunyuanVideoTransformer3DModel |
| 246 | + |
| 247 | + def prepare_init_args_and_inputs_for_common(self): |
| 248 | + return HunyuanVideoImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common() |
| 249 | + |
| 250 | + |
| 251 | +class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): |
233 | 252 | model_class = HunyuanVideoTransformer3DModel
|
234 | 253 | main_input_name = "hidden_states"
|
235 | 254 | uses_custom_attn_processor = True
|
@@ -295,3 +314,10 @@ def test_output(self):
|
295 | 314 | def test_gradient_checkpointing_is_applied(self):
|
296 | 315 | expected_set = {"HunyuanVideoTransformer3DModel"}
|
297 | 316 | super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
| 317 | + |
| 318 | + |
| 319 | +class HunyuanVideoTokenReplaceCompileTests(TorchCompileTesterMixin, unittest.TestCase): |
| 320 | + model_class = HunyuanVideoTransformer3DModel |
| 321 | + |
| 322 | + def prepare_init_args_and_inputs_for_common(self): |
| 323 | + return HunyuanVideoTokenReplaceImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common() |
0 commit comments