[tests] add tests to check for graph breaks and recompilation in pipelines during torch.compile() #11085


Merged: 22 commits, Apr 28, 2025.

Commits
de30cba  test for better torch.compile stuff. (sayakpaul, Mar 17, 2025)
f389a4d  fixes (sayakpaul, Mar 17, 2025)
6b05db6  Merge branch 'main' into test-better-torch-compile (sayakpaul, Mar 18, 2025)
e5543dc  Merge branch 'main' into test-better-torch-compile (sayakpaul, Mar 20, 2025)
6791037  recompilation and graph break. (sayakpaul, Mar 21, 2025)
abd1f6c  Merge branch 'main' into test-better-torch-compile (sayakpaul, Mar 21, 2025)
1f797b4  Merge branch 'main' into test-better-torch-compile (sayakpaul, Mar 27, 2025)
d669340  Merge branch 'main' into test-better-torch-compile (sayakpaul, Apr 9, 2025)
c49a855  Merge branch 'main' into test-better-torch-compile (sayakpaul, Apr 9, 2025)
c060ba0  Merge branch 'main' into test-better-torch-compile (sayakpaul, Apr 14, 2025)
e75a9de  Merge branch 'main' into test-better-torch-compile (sayakpaul, Apr 14, 2025)
c7f153a  clear compilation cache. (sayakpaul, Apr 14, 2025)
c74c9a8  Merge branch 'main' into test-better-torch-compile (sayakpaul, Apr 14, 2025)
1a934b2  Merge branch 'main' into test-better-torch-compile (sayakpaul, Apr 14, 2025)
e0566e6  change to modeling level test. (sayakpaul, Apr 14, 2025)
38c1d0d  Merge branch 'main' into test-better-torch-compile (sayakpaul, Apr 15, 2025)
87d957d  allow running compilation tests during nightlies. (sayakpaul, Apr 15, 2025)
a8184ef  Merge branch 'main' into test-better-torch-compile (sayakpaul, Apr 15, 2025)
fae8b6c  Merge branch 'main' into test-better-torch-compile (sayakpaul, Apr 18, 2025)
1749955  Merge branch 'main' into test-better-torch-compile (sayakpaul, Apr 21, 2025)
a07c63b  Merge branch 'main' into test-better-torch-compile (sayakpaul, Apr 25, 2025)
f71c8f6  Merge branch 'main' into test-better-torch-compile (sayakpaul, Apr 27, 2025)
49 changes: 49 additions & 0 deletions .github/workflows/nightly_tests.yml
@@ -180,7 +180,56 @@
pip install slack_sdk tabulate
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY

run_torch_compile_tests:
  name: PyTorch Compile CUDA tests

  runs-on:
    group: aws-g4dn-2xlarge

  container:
    image: diffusers/diffusers-pytorch-compile-cuda
    options: --gpus 0 --shm-size "16gb" --ipc host

  steps:
    - name: Checkout diffusers
      uses: actions/checkout@v3
      with:
        fetch-depth: 2

    - name: NVIDIA-SMI
      run: |
        nvidia-smi
    - name: Install dependencies
      run: |
        python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
        python -m uv pip install -e [quality,test,training]
    - name: Environment
      run: |
        python utils/print_env.py
    - name: Run torch compile tests on GPU
      env:
        HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
        RUN_COMPILE: yes
      run: |
        python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
    - name: Failure short reports
      if: ${{ failure() }}
      run: cat reports/tests_torch_compile_cuda_failures_short.txt

    - name: Test suite reports artifacts
      if: ${{ always() }}
      uses: actions/upload-artifact@v4
      with:
        name: torch_compile_test_reports
        path: reports

    - name: Generate Report and Notify Channel
      if: always()
      run: |
        pip install slack_sdk tabulate
        python utils/log_reports.py >> $GITHUB_STEP_SUMMARY

run_big_gpu_torch_tests:

Check warning (Code scanning / CodeQL, Medium): Workflow does not contain permissions. Actions job or workflow does not limit the permissions of the GITHUB_TOKEN. Consider setting an explicit permissions block, using the following as a minimal starting point: {contents: read}

  name: Torch tests on big GPU
  strategy:
    fail-fast: false
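The remediation the CodeQL warning spells out can be added at the top level of the workflow file. A minimal sketch of that fix, using exactly the starting point the warning suggests (its placement at workflow level, applying to all jobs unless overridden, is standard GitHub Actions behavior, not something this diff contains):

```yaml
# Limit the GITHUB_TOKEN to read-only access to repository contents
# for every job in this workflow, per the CodeQL recommendation.
permissions:
  contents: read
```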
2 changes: 1 addition & 1 deletion .github/workflows/release_tests_fast.yml
@@ -335,7 +335,7 @@ jobs:
  - name: Environment
    run: |
      python utils/print_env.py
- - name: Run example tests on GPU
+ - name: Run torch compile tests on GPU
    env:
      HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
      RUN_COMPILE: yes
31 changes: 31 additions & 0 deletions tests/models/test_modeling_common.py
@@ -1714,6 +1714,37 @@ def test_push_to_hub_library_name(self):
delete_repo(self.repo_id, token=TOKEN)


class TorchCompileTesterMixin:
    def setUp(self):
        # clean up the VRAM before each test
        super().setUp()
        torch._dynamo.reset()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        # clean up the VRAM after each test in case of CUDA runtime errors
        super().tearDown()
        torch._dynamo.reset()
        gc.collect()
        backend_empty_cache(torch_device)

    @require_torch_gpu
    @require_torch_2
    @is_torch_compile
    @slow
    def test_torch_compile_recompilation_and_graph_break(self):
        torch._dynamo.reset()
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict).to(torch_device)
        model = torch.compile(model, fullgraph=True)

        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
            _ = model(**inputs_dict)
            _ = model(**inputs_dict)


@slow
@require_torch_2
@require_torch_accelerator
4 changes: 2 additions & 2 deletions tests/models/transformers/test_models_transformer_flux.py
@@ -22,7 +22,7 @@
from diffusers.models.embeddings import ImageProjection
from diffusers.utils.testing_utils import enable_full_determinism, torch_device

- from ..test_modeling_common import ModelTesterMixin
+ from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin


enable_full_determinism()
@@ -78,7 +78,7 @@ def create_flux_ip_adapter_state_dict(model):
    return ip_state_dict


- class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
+ class FluxTransformerTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
    model_class = FluxTransformer2DModel
    main_input_name = "hidden_states"
    # We override the items here because the transformer under consideration is small.
2 changes: 2 additions & 0 deletions tests/pipelines/test_pipelines_common.py
@@ -1111,12 +1111,14 @@ def callback_cfg_params(self) -> frozenset:
    def setUp(self):
        # clean up the VRAM before each test
        super().setUp()
+       torch._dynamo.reset()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        # clean up the VRAM after each test in case of CUDA runtime errors
        super().tearDown()
+       torch._dynamo.reset()
        gc.collect()
        backend_empty_cache(torch_device)
