#11085 added a test that checks a `torch.compile`d model for graph breaks and recompilation issues.
We should add this test to the most impactful models to make sure our code is `torch.compile`-friendly and can benefit from compilation. So far, we run it for `FluxTransformer2DModel`. Below are some models where I think this test should be added:
- `HiDreamImageTransformer2DModel` (currently doesn't have a test class, unlike other models in `src/diffusers/models`)
- `HunyuanVideoTransformer3DTests`
- `WanTransformer3DTests`
- `UNet2DConditionModelTests`
- `LTXVideoTransformer3DModel`
Steps to contribute
- Refer to [tests] add tests to check for graph breaks and recompilation in pipelines during torch.compile() #11085 to understand the changes needed in the test classes of the respective models.
- To be specific, always try to just add `TorchCompileTesterMixin` to the respective model's test file.
- Make the changes.
- Run the test locally with `RUN_SLOW=1 RUN_COMPILE=1 pytest tests/<PATH_TO_TEST_FILE> -k "test_torch_compile_recompilation_and_graph_break"`. Make sure the machine has a GPU. Please confirm that the test passes when you open the PR; if it doesn't pass, let us know.
- Open the PR and mention this issue ([tests] help us test torch.compile() for impactful models #11430). Tag @sayakpaul and @DN6 for a review. Don't hesitate to ask for guidance or help if needed.
@DN6, are there any model classes I'm missing?