From 316ff46b7648bfa24525ac02c284afcf440404aa Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 10 Mar 2025 08:29:51 +0530 Subject: [PATCH 01/19] feat: pipeline-level quant config. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: SunMarc condition better. support mapping. improvements. [Quantization] Add Quanto backend (#10756) * update * updaet * update * update * update * update * update * update * update * update * update * update * Update docs/source/en/quantization/quanto.md Co-authored-by: Sayak Paul * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * Update src/diffusers/quantizers/quanto/utils.py Co-authored-by: Sayak Paul * update * update --------- Co-authored-by: Sayak Paul [Single File] Add single file loading for SANA Transformer (#10947) * added support for from_single_file * added diffusers mapping script * added testcase * bug fix * updated tests * corrected code quality * corrected code quality --------- Co-authored-by: Dhruv Nair [LoRA] Improve warning messages when LoRA loading becomes a no-op (#10187) * updates * updates * updates * updates * notebooks revert * fix-copies. * seeing * fix * revert * fixes * fixes * fixes * remove print * fix * conflicts ii. * updates * fixes * better filtering of prefix. --------- Co-authored-by: hlky [LoRA] CogView4 (#10981) * update * make fix-copies * update [Tests] improve quantization tests by additionally measuring the inference memory savings (#11021) * memory usage tests * fixes * gguf [`Research Project`] Add AnyText: Multilingual Visual Text Generation And Editing (#8998) * Add initial template * Second template * feat: Add TextEmbeddingModule to AnyTextPipeline * feat: Add AuxiliaryLatentModule template to AnyTextPipeline * Add bert tokenizer from the anytext repo for now * feat: Update AnyTextPipeline's modify_prompt method This commit adds improvements to the modify_prompt method in the AnyTextPipeline class. The method now handles special characters and replaces selected string prompts with a placeholder. Additionally, it includes a check for Chinese text and translation using the trans_pipe. * Fill in the `forward` pass of `AuxiliaryLatentModule` * `make style && make quality` * `chore: Update bert_tokenizer.py with a TODO comment suggesting the use of the transformers library` * Update error handling to raise and logging * Add `create_glyph_lines` function into `TextEmbeddingModule` * make style * Up * Up * Up * Up * Remove several comments * refactor: Remove ControlNetConditioningEmbedding and update code accordingly * Up * Up * up * refactor: Update AnyTextPipeline to include new optional parameters * up * feat: Add OCR model and its components * chore: Update `TextEmbeddingModule` to include OCR model components and dependencies * chore: Update `AuxiliaryLatentModule` to include VAE model and its dependencies for masked image in the editing task * `make style` * refactor: Update `AnyTextPipeline`'s docstring * Update `AuxiliaryLatentModule` to include info dictionary so that text processing is done once * simplify * `make style` * Converting `TextEmbeddingModule` to ordinary `encode_prompt()` function * Simplify for now * `make style` * Up * feat: Add scripts to convert AnyText controlnet to diffusers * `make style` * Fix: Move glyph rendering to `TextEmbeddingModule` from `AuxiliaryLatentModule` * make style * Up * Simplify * Up * feat: Add safetensors module for loading model file * Fix device issues * Up * Up * refactor: Simplify * refactor: Simplify code for loading models and handling data types * `make style` * refactor: Update to() method in FrozenCLIPEmbedderT3 and TextEmbeddingModule * refactor: Update dtype in embedding_manager.py to match proj.weight * Up * Add attribution and adaptation information to pipeline_anytext.py * Update usage example * Will refactor `controlnet_cond_embedding` initialization * Add `AnyTextControlNetConditioningEmbedding` template * Refactor organization * style * style * Move custom blocks from `AuxiliaryLatentModule` to `AnyTextControlNetConditioningEmbedding` * Follow one-file policy * style * [Docs] Update README and pipeline_anytext.py to use AnyTextControlNetModel * [Docs] Update import statement for AnyTextControlNetModel in pipeline_anytext.py * [Fix] Update import path for ControlNetModel, ControlNetOutput in anytext_controlnet.py * Refactor AnyTextControlNet to use configurable conditioning embedding channels * Complete control net conditioning embedding in AnyTextControlNetModel * up * [FIX] Ensure embeddings use correct device in AnyTextControlNetModel * up * up * style * [UPDATE] Revise README and example code for AnyTextPipeline integration with DiffusionPipeline * [UPDATE] Update example code in anytext.py to use correct font file and improve clarity * down * [UPDATE] Refactor BasicTokenizer usage to a new Checker class for text processing * update pillow * [UPDATE] Remove commented-out code and unnecessary docstring in anytext.py and anytext_controlnet.py for improved clarity * [REMOVE] Delete frozen_clip_embedder_t3.py as it is in the anytext.py file * [UPDATE] Replace edict with dict for configuration in anytext.py and RecModel.py for consistency * ๐Ÿ†™ * style * [UPDATE] Revise README.md for clarity, remove unused imports in anytext.py, and add author credits in anytext_controlnet.py * style * Update examples/research_projects/anytext/README.md Co-authored-by: Aryan * Remove commented-out image preparation code in AnyTextPipeline * Remove unnecessary blank line in README.md [Quantization] Allow loading TorchAO serialized Tensor objects with torch>=2.6 (#11018) * update * update * update * update * update * update * update * update * update fix: mixture tiling sdxl pipeline - adjust gerating time_ids & embeddings (#11012) small fix on generating time_ids & embeddings [LoRA] support wan i2v loras from the world. (#11025) * support wan i2v loras from the world. * remove copied from. * upates * add lora. Fix SD3 IPAdapter feature extractor (#11027) chore: fix help messages in advanced diffusion examples (#10923) Fix missing **kwargs in lora_pipeline.py (#11011) * Update lora_pipeline.py * Apply style fixes * fix-copies --------- Co-authored-by: hlky Co-authored-by: github-actions[bot] Fix for multi-GPU WAN inference (#10997) Ensure that hidden_state and shift/scale are on the same device when running with multiple GPUs Co-authored-by: Jimmy <39@๐Ÿ‡บ๐Ÿ‡ธ.com> [Refactor] Clean up import utils boilerplate (#11026) * update * update * update Use `output_size` in `repeat_interleave` (#11030) [hybrid inference ๐Ÿฏ๐Ÿ] Add VAE encode (#11017) * [hybrid inference ๐Ÿฏ๐Ÿ] Add VAE encode * _toctree: add vae encode * Add endpoints, tests * vae_encode docs * vae encode benchmarks * api reference * changelog * Update docs/source/en/hybrid_inference/overview.md Co-authored-by: Sayak Paul * update --------- Co-authored-by: Sayak Paul Wan Pipeline scaling fix, type hint warning, multi generator fix (#11007) * Wan Pipeline scaling fix, type hint warning, multi generator fix * Apply suggestions from code review [LoRA] change to warning from info when notifying the users about a LoRA no-op (#11044) * move to warning. * test related changes. Rename Lumina(2)Text2ImgPipeline -> Lumina(2)Pipeline (#10827) * Rename Lumina(2)Text2ImgPipeline -> Lumina(2)Pipeline --------- Co-authored-by: YiYi Xu making ```formatted_images``` initialization compact (#10801) compact writing Co-authored-by: Sayak Paul Co-authored-by: YiYi Xu Fix aclnnRepeatInterleaveIntWithDim error on NPU for get_1d_rotary_pos_embed (#10820) * get_1d_rotary_pos_embed support npu * Update src/diffusers/models/embeddings.py --------- Co-authored-by: Kai zheng Co-authored-by: hlky Co-authored-by: YiYi Xu [Tests] restrict memory tests for quanto for certain schemes. (#11052) * restrict memory tests for quanto for certain schemes. * Apply suggestions from code review Co-authored-by: Dhruv Nair * fixes * style --------- Co-authored-by: Dhruv Nair [LoRA] feat: support non-diffusers wan t2v loras. (#11059) feat: support non-diffusers wan t2v loras. [examples/controlnet/train_controlnet_sd3.py] Fixes #11050 - Cast prompt_embeds and pooled_prompt_embeds to weight_dtype to prevent dtype mismatch (#11051) Fix: dtype mismatch of prompt embeddings in sd3 controlnet training Co-authored-by: Andreas Jรถrg Co-authored-by: Sayak Paul reverts accidental change that removes attn_mask in attn. Improves flโ€ฆ (#11065) reverts accidental change that removes attn_mask in attn. Improves flux ptxla by using flash block sizes. Moves encoding outside the for loop. Co-authored-by: Juan Acevedo Fix deterministic issue when getting pipeline dtype and device (#10696) Co-authored-by: Dhruv Nair [Tests] add requires peft decorator. (#11037) * add requires peft decorator. * install peft conditionally. * conditional deps. Co-authored-by: DN6 --------- Co-authored-by: DN6 CogView4 Control Block (#10809) * cogview4 control training --------- Co-authored-by: OleehyO Co-authored-by: yiyixuxu [CI] pin transformers version for benchmarking. (#11067) pin transformers version for benchmarking. updates Fix Wan I2V Quality (#11087) * fix_wan_i2v_quality * Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py Co-authored-by: YiYi Xu * Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py Co-authored-by: YiYi Xu * Update pipeline_wan_i2v.py --------- Co-authored-by: YiYi Xu Co-authored-by: hlky LTX 0.9.5 (#10968) * update --------- Co-authored-by: YiYi Xu Co-authored-by: hlky make PR GPU tests conditioned on styling. (#11099) Group offloading improvements (#11094) update Fix pipeline_flux_controlnet.py (#11095) * Fix pipeline_flux_controlnet.py * Fix style update readme instructions. (#11096) Co-authored-by: Juan Acevedo Resolve stride mismatch in UNet's ResNet to support Torch DDP (#11098) Modify UNet's ResNet implementation to resolve stride mismatch in Torch's DDP Fix Group offloading behaviour when using streams (#11097) * update * update Quality options in `export_to_video` (#11090) * Quality options in `export_to_video` * make style improve more. add placeholders for docstrings. formatting. smol fix. solidify validation and annotation --- .github/workflows/benchmark.yml | 1 + .github/workflows/nightly_tests.yml | 9 + .github/workflows/pr_tests_gpu.yml | 44 + docs/source/en/_toctree.yml | 4 + docs/source/en/api/pipelines/ltx_video.md | 6 + docs/source/en/api/pipelines/lumina.md | 14 +- docs/source/en/api/pipelines/lumina2.md | 12 +- docs/source/en/api/pipelines/wan.md | 4 + docs/source/en/api/quantization.md | 5 + .../en/hybrid_inference/api_reference.md | 4 + docs/source/en/hybrid_inference/overview.md | 10 +- docs/source/en/hybrid_inference/vae_encode.md | 183 ++ docs/source/en/quantization/overview.md | 1 + docs/source/en/quantization/quanto.md | 148 ++ docs/source/en/quantization/torchao.md | 2 +- .../README_flux.md | 4 +- .../train_dreambooth_lora_flux_advanced.py | 4 +- .../train_dreambooth_lora_sd15_advanced.py | 2 +- .../train_dreambooth_lora_sdxl_advanced.py | 4 +- examples/cogview4-control/README.md | 201 ++ examples/cogview4-control/requirements.txt | 6 + .../train_control_cogview4.py | 1242 +++++++++ examples/community/mixture_tiling_sdxl.py | 44 +- examples/controlnet/train_controlnet.py | 4 +- examples/controlnet/train_controlnet_flux.py | 4 +- examples/controlnet/train_controlnet_sd3.py | 4 +- examples/controlnet/train_controlnet_sdxl.py | 4 +- examples/research_projects/anytext/README.md | 32 + examples/research_projects/anytext/anytext.py | 2360 +++++++++++++++++ .../anytext/anytext_controlnet.py | 463 ++++ .../anytext/ocr_recog/RNN.py | 209 ++ .../anytext/ocr_recog/RecCTCHead.py | 45 + .../anytext/ocr_recog/RecModel.py | 49 + .../anytext/ocr_recog/RecMv1_enhance.py | 197 ++ .../anytext/ocr_recog/RecSVTR.py | 570 ++++ .../anytext/ocr_recog/common.py | 74 + .../anytext/ocr_recog/en_dict.txt | 95 + .../controlnet/train_controlnet_webdataset.py | 4 +- .../pixart/train_pixart_controlnet_hf.py | 4 +- .../pytorch_xla/inference/flux/README.md | 168 +- .../inference/flux/flux_inference.py | 28 +- .../t2i_adapter/train_t2i_adapter_sdxl.py | 4 +- scripts/convert_cogview4_to_diffusers.py | 15 +- .../convert_cogview4_to_diffusers_megatron.py | 66 +- scripts/convert_ltx_to_diffusers.py | 104 +- scripts/convert_lumina_to_diffusers.py | 4 +- setup.py | 9 + src/diffusers/__init__.py | 96 +- src/diffusers/dependency_versions_table.py | 4 + src/diffusers/hooks/group_offloading.py | 57 +- src/diffusers/loaders/__init__.py | 2 + src/diffusers/loaders/ip_adapter.py | 4 +- src/diffusers/loaders/lora_base.py | 164 +- .../loaders/lora_conversion_utils.py | 53 + src/diffusers/loaders/lora_pipeline.py | 653 +++-- src/diffusers/loaders/peft.py | 15 +- src/diffusers/loaders/single_file_model.py | 5 + src/diffusers/loaders/single_file_utils.py | 115 + src/diffusers/models/attention_processor.py | 18 +- .../models/autoencoders/autoencoder_dc.py | 6 +- .../autoencoders/autoencoder_kl_allegro.py | 2 +- .../models/autoencoders/autoencoder_kl_ltx.py | 237 +- .../autoencoders/autoencoder_kl_mochi.py | 4 +- .../controlnets/controlnet_sparsectrl.py | 2 +- src/diffusers/models/embeddings.py | 11 +- src/diffusers/models/model_loading_utils.py | 7 +- src/diffusers/models/resnet.py | 2 +- .../transformers/latte_transformer_3d.py | 18 +- .../models/transformers/prior_transformer.py | 6 +- .../models/transformers/sana_transformer.py | 4 +- .../transformers/transformer_cogview4.py | 58 +- .../models/transformers/transformer_ltx.py | 62 +- .../models/transformers/transformer_wan.py | 8 + .../models/unets/unet_3d_condition.py | 6 +- src/diffusers/models/unets/unet_i2vgen_xl.py | 4 +- .../models/unets/unet_motion_model.py | 7 +- .../unets/unet_spatio_temporal_condition.py | 6 +- src/diffusers/pipelines/__init__.py | 16 +- src/diffusers/pipelines/auto_pipeline.py | 11 +- src/diffusers/pipelines/cogview4/__init__.py | 2 + .../pipelines/cogview4/pipeline_cogview4.py | 21 +- .../cogview4/pipeline_cogview4_control.py | 727 +++++ .../flux/pipeline_flux_controlnet.py | 1 + src/diffusers/pipelines/ltx/__init__.py | 2 + src/diffusers/pipelines/ltx/pipeline_ltx.py | 3 +- .../pipelines/ltx/pipeline_ltx_condition.py | 1174 ++++++++ .../pipelines/ltx/pipeline_ltx_image2video.py | 3 +- src/diffusers/pipelines/lumina/__init__.py | 4 +- .../pipelines/lumina/pipeline_lumina.py | 29 +- src/diffusers/pipelines/lumina2/__init__.py | 4 +- .../pipelines/lumina2/pipeline_lumina2.py | 27 +- .../pipelines/pipeline_loading_utils.py | 14 + src/diffusers/pipelines/pipeline_utils.py | 13 +- .../pipelines/wan/pipeline_wan_i2v.py | 17 +- src/diffusers/quantizers/__init__.py | 158 ++ src/diffusers/quantizers/auto.py | 4 + .../quantizers/bitsandbytes/bnb_quantizer.py | 2 + .../quantizers/gguf/gguf_quantizer.py | 1 + .../quantizers/quantization_config.py | 36 + src/diffusers/quantizers/quanto/__init__.py | 1 + .../quantizers/quanto/quanto_quantizer.py | 177 ++ src/diffusers/quantizers/quanto/utils.py | 60 + .../quantizers/torchao/torchao_quantizer.py | 47 +- .../scheduling_flow_match_euler_discrete.py | 23 +- src/diffusers/utils/__init__.py | 3 + src/diffusers/utils/constants.py | 11 + .../utils/dummy_bitsandbytes_objects.py | 17 + src/diffusers/utils/dummy_gguf_objects.py | 17 + .../utils/dummy_optimum_quanto_objects.py | 17 + .../dummy_torch_and_transformers_objects.py | 60 + src/diffusers/utils/dummy_torchao_objects.py | 17 + src/diffusers/utils/export_utils.py | 31 +- src/diffusers/utils/import_utils.py | 328 +-- src/diffusers/utils/remote_utils.py | 103 +- src/diffusers/utils/testing_utils.py | 16 + tests/lora/test_lora_layers_cogview4.py | 174 ++ tests/lora/test_lora_layers_flux.py | 7 +- tests/lora/utils.py | 44 + tests/pipelines/ltx/test_ltx_condition.py | 284 ++ tests/pipelines/lumina/test_lumina_nextdit.py | 22 +- .../lumina2/test_pipeline_lumina2.py | 12 +- tests/pipelines/test_pipeline_utils.py | 103 +- tests/quantization/__init__.py | 0 tests/quantization/bnb/test_4bit.py | 59 +- tests/quantization/bnb/test_mixed_int8.py | 55 +- tests/quantization/quanto/__init__.py | 0 tests/quantization/quanto/test_quanto.py | 328 +++ tests/quantization/torchao/__init__.py | 0 tests/quantization/torchao/test_torchao.py | 38 +- tests/quantization/utils.py | 38 + tests/remote/test_remote_decode.py | 31 +- tests/remote/test_remote_encode.py | 224 ++ tests/single_file/test_sana_transformer.py | 61 + 133 files changed, 11902 insertions(+), 881 deletions(-) create mode 100644 docs/source/en/hybrid_inference/vae_encode.md create mode 100644 docs/source/en/quantization/quanto.md create mode 100644 examples/cogview4-control/README.md create mode 100644 examples/cogview4-control/requirements.txt create mode 100644 examples/cogview4-control/train_control_cogview4.py create mode 100644 examples/research_projects/anytext/README.md create mode 100644 examples/research_projects/anytext/anytext.py create mode 100644 examples/research_projects/anytext/anytext_controlnet.py create mode 100755 examples/research_projects/anytext/ocr_recog/RNN.py create mode 100755 examples/research_projects/anytext/ocr_recog/RecCTCHead.py create mode 100755 examples/research_projects/anytext/ocr_recog/RecModel.py create mode 100644 examples/research_projects/anytext/ocr_recog/RecMv1_enhance.py create mode 100644 examples/research_projects/anytext/ocr_recog/RecSVTR.py create mode 100644 examples/research_projects/anytext/ocr_recog/common.py create mode 100644 examples/research_projects/anytext/ocr_recog/en_dict.txt create mode 100644 src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py create mode 100644 src/diffusers/pipelines/ltx/pipeline_ltx_condition.py create mode 100644 src/diffusers/quantizers/quanto/__init__.py create mode 100644 src/diffusers/quantizers/quanto/quanto_quantizer.py create mode 100644 src/diffusers/quantizers/quanto/utils.py create mode 100644 src/diffusers/utils/dummy_bitsandbytes_objects.py create mode 100644 src/diffusers/utils/dummy_gguf_objects.py create mode 100644 src/diffusers/utils/dummy_optimum_quanto_objects.py create mode 100644 src/diffusers/utils/dummy_torchao_objects.py create mode 100644 tests/lora/test_lora_layers_cogview4.py create mode 100644 tests/pipelines/ltx/test_ltx_condition.py create mode 100644 tests/quantization/__init__.py create mode 100644 tests/quantization/quanto/__init__.py create mode 100644 tests/quantization/quanto/test_quanto.py create mode 100644 tests/quantization/torchao/__init__.py create mode 100644 tests/quantization/utils.py create mode 100644 tests/remote/test_remote_encode.py create mode 100644 tests/single_file/test_sana_transformer.py diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index d311c1c73f11..ff915e046946 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -38,6 +38,7 @@ jobs: python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" python -m uv pip install -e [quality,test] python -m uv pip install pandas peft + python -m uv pip uninstall transformers && python -m uv pip install transformers==4.48.0 - name: Environment run: | python utils/print_env.py diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index a40be8558499..2b39eea2fe5d 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -414,10 +414,16 @@ jobs: config: - backend: "bitsandbytes" test_location: "bnb" + additional_deps: ["peft"] - backend: "gguf" test_location: "gguf" + additional_deps: [] - backend: "torchao" test_location: "torchao" + additional_deps: [] + - backend: "optimum_quanto" + test_location: "quanto" + additional_deps: [] runs-on: group: aws-g6e-xlarge-plus container: @@ -435,6 +441,9 @@ jobs: python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" python -m uv pip install -e [quality,test] python -m uv pip install -U ${{ matrix.config.backend }} + if [ "${{ join(matrix.config.additional_deps, ' ') }}" != "" ]; then + python -m uv pip install ${{ join(matrix.config.additional_deps, ' ') }} + fi python -m uv pip install pytest-reportlog - name: Environment run: | diff --git a/.github/workflows/pr_tests_gpu.yml b/.github/workflows/pr_tests_gpu.yml index 82f824c8f192..d86eccc28bb5 100644 --- a/.github/workflows/pr_tests_gpu.yml +++ b/.github/workflows/pr_tests_gpu.yml @@ -28,7 +28,51 @@ env: PIPELINE_USAGE_CUTOFF: 1000000000 # set high cutoff so that only always-test pipelines run jobs: + check_code_quality: + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.8" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install .[quality] + - name: Check quality + run: make quality + - name: Check if failure + if: ${{ failure() }} + run: | + echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY + + check_repository_consistency: + needs: check_code_quality + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.8" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install .[quality] + - name: Check repo consistency + run: | + python utils/check_copies.py + python utils/check_dummies.py + python utils/check_support_list.py + make deps_table_check_updated + - name: Check if failure + if: ${{ failure() }} + run: | + echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY + setup_torch_cuda_pipeline_matrix: + needs: [check_code_quality, check_repository_consistency] name: Setup Torch Pipelines CUDA Slow Tests Matrix runs-on: group: aws-general-8-plus diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 9438fe1a55e1..d1805ff605d8 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -81,6 +81,8 @@ title: Overview - local: hybrid_inference/vae_decode title: VAE Decode + - local: hybrid_inference/vae_encode + title: VAE Encode - local: hybrid_inference/api_reference title: API Reference title: Hybrid Inference @@ -173,6 +175,8 @@ title: gguf - local: quantization/torchao title: torchao + - local: quantization/quanto + title: quanto title: Quantization Methods - sections: - local: optimization/fp16 diff --git a/docs/source/en/api/pipelines/ltx_video.md b/docs/source/en/api/pipelines/ltx_video.md index f31c621293fc..4bc22c0f9f6c 100644 --- a/docs/source/en/api/pipelines/ltx_video.md +++ b/docs/source/en/api/pipelines/ltx_video.md @@ -196,6 +196,12 @@ export_to_video(video, "ship.mp4", fps=24) - all - __call__ +## LTXConditionPipeline + +[[autodoc]] LTXConditionPipeline + - all + - __call__ + ## LTXPipelineOutput [[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput diff --git a/docs/source/en/api/pipelines/lumina.md b/docs/source/en/api/pipelines/lumina.md index 1967e85f173a..ce5cf8b103cc 100644 --- a/docs/source/en/api/pipelines/lumina.md +++ b/docs/source/en/api/pipelines/lumina.md @@ -58,10 +58,10 @@ Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fa First, load the pipeline: ```python -from diffusers import LuminaText2ImgPipeline +from diffusers import LuminaPipeline import torch -pipeline = LuminaText2ImgPipeline.from_pretrained( +pipeline = LuminaPipeline.from_pretrained( "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16 ).to("cuda") ``` @@ -86,11 +86,11 @@ image = pipeline(prompt="Upper body of a young woman in a Victorian-era outfit w Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model. -Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`LuminaText2ImgPipeline`] for inference with bitsandbytes. +Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`LuminaPipeline`] for inference with bitsandbytes. ```py import torch -from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, Transformer2DModel, LuminaText2ImgPipeline +from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, Transformer2DModel, LuminaPipeline from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel quant_config = BitsAndBytesConfig(load_in_8bit=True) @@ -109,7 +109,7 @@ transformer_8bit = Transformer2DModel.from_pretrained( torch_dtype=torch.float16, ) -pipeline = LuminaText2ImgPipeline.from_pretrained( +pipeline = LuminaPipeline.from_pretrained( "Alpha-VLLM/Lumina-Next-SFT-diffusers", text_encoder=text_encoder_8bit, transformer=transformer_8bit, @@ -122,9 +122,9 @@ image = pipeline(prompt).images[0] image.save("lumina.png") ``` -## LuminaText2ImgPipeline +## LuminaPipeline -[[autodoc]] LuminaText2ImgPipeline +[[autodoc]] LuminaPipeline - all - __call__ diff --git a/docs/source/en/api/pipelines/lumina2.md b/docs/source/en/api/pipelines/lumina2.md index cf04bc17e3ef..57f0e8e2105d 100644 --- a/docs/source/en/api/pipelines/lumina2.md +++ b/docs/source/en/api/pipelines/lumina2.md @@ -36,14 +36,14 @@ Single file loading for Lumina Image 2.0 is available for the `Lumina2Transforme ```python import torch -from diffusers import Lumina2Transformer2DModel, Lumina2Text2ImgPipeline +from diffusers import Lumina2Transformer2DModel, Lumina2Pipeline ckpt_path = "https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0/blob/main/consolidated.00-of-01.pth" transformer = Lumina2Transformer2DModel.from_single_file( ckpt_path, torch_dtype=torch.bfloat16 ) -pipe = Lumina2Text2ImgPipeline.from_pretrained( +pipe = Lumina2Pipeline.from_pretrained( "Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16 ) pipe.enable_model_cpu_offload() @@ -60,7 +60,7 @@ image.save("lumina-single-file.png") GGUF Quantized checkpoints for the `Lumina2Transformer2DModel` can be loaded via `from_single_file` with the `GGUFQuantizationConfig` ```python -from diffusers import Lumina2Transformer2DModel, Lumina2Text2ImgPipeline, GGUFQuantizationConfig +from diffusers import Lumina2Transformer2DModel, Lumina2Pipeline, GGUFQuantizationConfig ckpt_path = "https://huggingface.co/calcuis/lumina-gguf/blob/main/lumina2-q4_0.gguf" transformer = Lumina2Transformer2DModel.from_single_file( @@ -69,7 +69,7 @@ transformer = Lumina2Transformer2DModel.from_single_file( torch_dtype=torch.bfloat16, ) -pipe = Lumina2Text2ImgPipeline.from_pretrained( +pipe = Lumina2Pipeline.from_pretrained( "Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16 ) pipe.enable_model_cpu_offload() @@ -80,8 +80,8 @@ image = pipe( image.save("lumina-gguf.png") ``` -## Lumina2Text2ImgPipeline +## Lumina2Pipeline -[[autodoc]] Lumina2Text2ImgPipeline +[[autodoc]] Lumina2Pipeline - all - __call__ diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md index b16bf92a6370..a35b73cb8a2e 100644 --- a/docs/source/en/api/pipelines/wan.md +++ b/docs/source/en/api/pipelines/wan.md @@ -14,6 +14,10 @@ # Wan +
+ LoRA +
+ [Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team. diff --git a/docs/source/en/api/quantization.md b/docs/source/en/api/quantization.md index 168a9a03473f..2c728cff3c07 100644 --- a/docs/source/en/api/quantization.md +++ b/docs/source/en/api/quantization.md @@ -31,6 +31,11 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui ## GGUFQuantizationConfig [[autodoc]] GGUFQuantizationConfig + +## QuantoConfig + +[[autodoc]] QuantoConfig + ## TorchAoConfig [[autodoc]] TorchAoConfig diff --git a/docs/source/en/hybrid_inference/api_reference.md b/docs/source/en/hybrid_inference/api_reference.md index aa0a5e5ae58f..865aaba5ebb6 100644 --- a/docs/source/en/hybrid_inference/api_reference.md +++ b/docs/source/en/hybrid_inference/api_reference.md @@ -3,3 +3,7 @@ ## Remote Decode [[autodoc]] utils.remote_utils.remote_decode + +## Remote Encode + +[[autodoc]] utils.remote_utils.remote_encode diff --git a/docs/source/en/hybrid_inference/overview.md b/docs/source/en/hybrid_inference/overview.md index 9bbe245901df..b44393c77cbd 100644 --- a/docs/source/en/hybrid_inference/overview.md +++ b/docs/source/en/hybrid_inference/overview.md @@ -36,7 +36,7 @@ Hybrid Inference offers a fast and simple way to offload local generation requir ## Available Models * **VAE Decode ๐Ÿ–ผ๏ธ:** Quickly decode latent representations into high-quality images without compromising performance or workflow speed. -* **VAE Encode ๐Ÿ”ข (coming soon):** Efficiently encode images into latent representations for generation and training. +* **VAE Encode ๐Ÿ”ข:** Efficiently encode images into latent representations for generation and training. * **Text Encoders ๐Ÿ“ƒ (coming soon):** Compute text embeddings for your prompts quickly and accurately, ensuring a smooth and high-quality workflow. --- @@ -46,9 +46,15 @@ Hybrid Inference offers a fast and simple way to offload local generation requir * **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference. * **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference. +## Changelog + +- March 10 2025: Added VAE encode +- March 2 2025: Initial release with VAE decoding + ## Contents -The documentation is organized into two sections: +The documentation is organized into three sections: * **VAE Decode** Learn the basics of how to use VAE Decode with Hybrid Inference. +* **VAE Encode** Learn the basics of how to use VAE Encode with Hybrid Inference. * **API Reference** Dive into task-specific settings and parameters. diff --git a/docs/source/en/hybrid_inference/vae_encode.md b/docs/source/en/hybrid_inference/vae_encode.md new file mode 100644 index 000000000000..dd285fa25c03 --- /dev/null +++ b/docs/source/en/hybrid_inference/vae_encode.md @@ -0,0 +1,183 @@ +# Getting Started: VAE Encode with Hybrid Inference + +VAE encode is used for training, image-to-image and image-to-video - turning into images or videos into latent representations. + +## Memory + +These tables demonstrate the VRAM requirements for VAE encode with SD v1 and SD XL on different GPUs. + +For the majority of these GPUs the memory usage % dictates other models (text encoders, UNet/Transformer) must be offloaded, or tiled encoding has to be used which increases time taken and impacts quality. + +
SD v1.5 + +| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) | +|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:| +| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 | +| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 | +| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 | +| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 | +| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 | +| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 | +| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 | +| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 | +| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 | +| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 | +| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 | +| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 | +| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 | +| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 | +| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 | +| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 | +| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 | +| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 | +| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 | +| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 | + + +
+ +
SDXL + +| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) | +|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:| +| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 | +| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 | +| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 | +| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 | +| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 | +| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 | +| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 | +| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 | +| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 | +| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 | +| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 | +| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 | +| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 | +| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 | +| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 | +| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 | +| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 | +| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 | +| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 | +| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 | + +
+ +## Available VAEs + +| | **Endpoint** | **Model** | +|:-:|:-----------:|:--------:| +| **Stable Diffusion v1** | [https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud](https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) | +| **Stable Diffusion XL** | [https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud](https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) | +| **Flux** | [https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud](https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) | + + +> [!TIP] +> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml). + + +## Code + +> [!TIP] +> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main` + + +A helper method simplifies interacting with Hybrid Inference. + +```python +from diffusers.utils.remote_utils import remote_encode +``` + +### Basic example + +Let's encode an image, then decode it to demonstrate. + +
+ +
+ +
Code + +```python +from diffusers.utils import load_image +from diffusers.utils.remote_utils import remote_decode + +image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true") + +latent = remote_encode( + endpoint="https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/", + scaling_factor=0.3611, + shift_factor=0.1159, +) + +decoded = remote_decode( + endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/", + tensor=latent, + scaling_factor=0.3611, + shift_factor=0.1159, +) +``` + +
+ +
+ +
+ + +### Generation + +Now let's look at a generation example, we'll encode the image, generate then remotely decode too! + +
Code + +```python +import torch +from diffusers import StableDiffusionImg2ImgPipeline +from diffusers.utils import load_image +from diffusers.utils.remote_utils import remote_decode, remote_encode + +pipe = StableDiffusionImg2ImgPipeline.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", + torch_dtype=torch.float16, + variant="fp16", + vae=None, +).to("cuda") + +init_image = load_image( + "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" +) +init_image = init_image.resize((768, 512)) + +init_latent = remote_encode( + endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/", + image=init_image, + scaling_factor=0.18215, +) + +prompt = "A fantasy landscape, trending on artstation" +latent = pipe( + prompt=prompt, + image=init_latent, + strength=0.75, + output_type="latent", +).images + +image = remote_decode( + endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/", + tensor=latent, + scaling_factor=0.18215, +) +image.save("fantasy_landscape.jpg") +``` + +
+ +
+ +
+ +## Integrations + +* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference. +* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference. diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index 794098e210a6..93323f86c7fc 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -36,5 +36,6 @@ Diffusers currently supports the following quantization methods. - [BitsandBytes](./bitsandbytes) - [TorchAO](./torchao) - [GGUF](./gguf) +- [Quanto](./quanto.md) [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques. diff --git a/docs/source/en/quantization/quanto.md b/docs/source/en/quantization/quanto.md new file mode 100644 index 000000000000..d322d76be267 --- /dev/null +++ b/docs/source/en/quantization/quanto.md @@ -0,0 +1,148 @@ + + +# Quanto + +[Quanto](https://github.com/huggingface/optimum-quanto) is a PyTorch quantization backend for [Optimum](https://huggingface.co/docs/optimum/en/index). It has been designed with versatility and simplicity in mind: + +- All features are available in eager mode (works with non-traceable models) +- Supports quantization aware training +- Quantized models are compatible with `torch.compile` +- Quantized models are Device agnostic (e.g CUDA,XPU,MPS,CPU) + +In order to use the Quanto backend, you will first need to install `optimum-quanto>=0.2.6` and `accelerate` + +```shell +pip install optimum-quanto accelerate +``` + +Now you can quantize a model by passing the `QuantoConfig` object to the `from_pretrained()` method. Although the Quanto library does allow quantizing `nn.Conv2d` and `nn.LayerNorm` modules, currently, Diffusers only supports quantizing the weights in the `nn.Linear` layers of a model. The following snippet demonstrates how to apply `float8` quantization with Quanto. + +```python +import torch +from diffusers import FluxTransformer2DModel, QuantoConfig + +model_id = "black-forest-labs/FLUX.1-dev" +quantization_config = QuantoConfig(weights_dtype="float8") +transformer = FluxTransformer2DModel.from_pretrained( + model_id, + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, +) + +pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch_dtype) +pipe.to("cuda") + +prompt = "A cat holding a sign that says hello world" +image = pipe( + prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512 +).images[0] +image.save("output.png") +``` + +## Skipping Quantization on specific modules + +It is possible to skip applying quantization on certain modules using the `modules_to_not_convert` argument in the `QuantoConfig`. Please ensure that the modules passed in to this argument match the keys of the modules in the `state_dict` + +```python +import torch +from diffusers import FluxTransformer2DModel, QuantoConfig + +model_id = "black-forest-labs/FLUX.1-dev" +quantization_config = QuantoConfig(weights_dtype="float8", modules_to_not_convert=["proj_out"]) +transformer = FluxTransformer2DModel.from_pretrained( + model_id, + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, +) +``` + +## Using `from_single_file` with the Quanto Backend + +`QuantoConfig` is compatible with `~FromOriginalModelMixin.from_single_file`. + +```python +import torch +from diffusers import FluxTransformer2DModel, QuantoConfig + +ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors" +quantization_config = QuantoConfig(weights_dtype="float8") +transformer = FluxTransformer2DModel.from_single_file(ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16) +``` + +## Saving Quantized models + +Diffusers supports serializing Quanto models using the `~ModelMixin.save_pretrained` method. + +The serialization and loading requirements are different for models quantized directly with the Quanto library and models quantized +with Diffusers using Quanto as the backend. It is currently not possible to load models quantized directly with Quanto into Diffusers using `~ModelMixin.from_pretrained` + +```python +import torch +from diffusers import FluxTransformer2DModel, QuantoConfig + +model_id = "black-forest-labs/FLUX.1-dev" +quantization_config = QuantoConfig(weights_dtype="float8") +transformer = FluxTransformer2DModel.from_pretrained( + model_id, + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, +) +# save quantized model to reuse +transformer.save_pretrained("") + +# you can reload your quantized model with +model = FluxTransformer2DModel.from_pretrained("") +``` + +## Using `torch.compile` with Quanto + +Currently the Quanto backend supports `torch.compile` for the following quantization types: + +- `int8` weights + +```python +import torch +from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig + +model_id = "black-forest-labs/FLUX.1-dev" +quantization_config = QuantoConfig(weights_dtype="int8") +transformer = FluxTransformer2DModel.from_pretrained( + model_id, + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, +) +transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True) + +pipe = FluxPipeline.from_pretrained( + model_id, transformer=transformer, torch_dtype=torch_dtype +) +pipe.to("cuda") +images = pipe("A cat holding a sign that says hello").images[0] +images.save("flux-quanto-compile.png") +``` + +## Supported Quantization Types + +### Weights + +- float8 +- int8 +- int4 +- int2 + + diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md index c056876c2f09..19a8970fa9df 100644 --- a/docs/source/en/quantization/torchao.md +++ b/docs/source/en/quantization/torchao.md @@ -126,7 +126,7 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0] image.save("output.png") ``` -Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source. +If you are using `torch<=2.6.0`, some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source. ```python import torch diff --git a/examples/advanced_diffusion_training/README_flux.md b/examples/advanced_diffusion_training/README_flux.md index 1f83235ad50a..f2a571d5eae4 100644 --- a/examples/advanced_diffusion_training/README_flux.md +++ b/examples/advanced_diffusion_training/README_flux.md @@ -79,13 +79,13 @@ This command will prompt you for a token. Copy-paste yours from your [settings/t ### Target Modules When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them. More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore -applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma seperated string +applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma separated string the exact modules for LoRA training. Here are some examples of target modules you can provide: - for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"` - to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"` - to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"` > [!NOTE] -> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma seperated string: +> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma separated string: > **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k` > **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k` > [!NOTE] diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 7cb0d666fe69..b8194507d822 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -378,7 +378,7 @@ def parse_args(input_args=None): default=None, help="the concept to use to initialize the new inserted tokens when training with " "--train_text_encoder_ti = True. By default, new tokens () are initialized with random value. " - "Alternatively, you could specify a different word/words whos value will be used as the starting point for the new inserted tokens. " + "Alternatively, you could specify a different word/words whose value will be used as the starting point for the new inserted tokens. " "--num_new_tokens_per_abstraction is ignored when initializer_concept is provided", ) parser.add_argument( @@ -662,7 +662,7 @@ def parse_args(input_args=None): type=str, default=None, help=( - "The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. " + "The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. " 'E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md' ), ) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py index 41ab1eb660d7..8cd1d777c00c 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py @@ -662,7 +662,7 @@ def parse_args(input_args=None): action="store_true", default=False, help=( - "Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. " + "Whether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. " "Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`" ), ) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 5ec028026364..38b6e8dab209 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -773,7 +773,7 @@ def parse_args(input_args=None): action="store_true", default=False, help=( - "Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. " + "Whether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. " "Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`" ), ) @@ -1875,7 +1875,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers, clip_skip): # pack the statically computed variables appropriately here. This is so that we don't # have to pass them to the dataloader. - # if --train_text_encoder_ti we need add_special_tokens to be True fo textual inversion + # if --train_text_encoder_ti we need add_special_tokens to be True for textual inversion add_special_tokens = True if args.train_text_encoder_ti else False if not train_dataset.custom_instance_prompts: diff --git a/examples/cogview4-control/README.md b/examples/cogview4-control/README.md new file mode 100644 index 000000000000..746a99a1a41b --- /dev/null +++ b/examples/cogview4-control/README.md @@ -0,0 +1,201 @@ +# Training CogView4 Control + +This (experimental) example shows how to train Control LoRAs with [CogView4](https://huggingface.co/THUDM/CogView4-6B) by conditioning it with additional structural controls (like depth maps, poses, etc.). We provide a script for full fine-tuning, too, refer to [this section](#full-fine-tuning). To know more about CogView4 Control family, refer to the following resources: + +To incorporate additional condition latents, we expand the input features of CogView-4 from 64 to 128. The first 64 channels correspond to the original input latents to be denoised, while the latter 64 channels correspond to control latents. This expansion happens on the `patch_embed` layer, where the combined latents are projected to the expected feature dimension of rest of the network. Inference is performed using the `CogView4ControlPipeline`. + +> [!NOTE] +> **Gated model** +> +> As the model is gated, before using it with diffusers you first need to go to the [CogView4 Hugging Face page](https://huggingface.co/THUDM/CogView4-6B), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows youโ€™ve accepted the gate. Use the command below to log in: + +```bash +huggingface-cli login +``` + +The example command below shows how to launch fine-tuning for pose conditions. The dataset ([`raulc0399/open_pose_controlnet`](https://huggingface.co/datasets/raulc0399/open_pose_controlnet)) being used here already has the pose conditions of the original images, so we don't have to compute them. + +```bash +accelerate launch train_control_lora_cogview4.py \ + --pretrained_model_name_or_path="THUDM/CogView4-6B" \ + --dataset_name="raulc0399/open_pose_controlnet" \ + --output_dir="pose-control-lora" \ + --mixed_precision="bf16" \ + --train_batch_size=1 \ + --rank=64 \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --use_8bit_adam \ + --learning_rate=1e-4 \ + --report_to="wandb" \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=5000 \ + --validation_image="openpose.png" \ + --validation_prompt="A couple, 4k photo, highly detailed" \ + --offload \ + --seed="0" \ + --push_to_hub +``` + +`openpose.png` comes from [here](https://huggingface.co/Adapter/t2iadapter/resolve/main/openpose.png). + +You need to install `diffusers` from the branch of [this PR](https://github.com/huggingface/diffusers/pull/9999). When it's merged, you should install `diffusers` from the `main`. + +The training script exposes additional CLI args that might be useful to experiment with: + +* `use_lora_bias`: When set, additionally trains the biases of the `lora_B` layer. +* `train_norm_layers`: When set, additionally trains the normalization scales. Takes care of saving and loading. +* `lora_layers`: Specify the layers you want to apply LoRA to. If you specify "all-linear", all the linear layers will be LoRA-attached. + +### Training with DeepSpeed + +It's possible to train with [DeepSpeed](https://github.com/microsoft/DeepSpeed), specifically leveraging the Zero2 system optimization. To use it, save the following config to an YAML file (feel free to modify as needed): + +```yaml +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + gradient_clipping: 1.0 + offload_optimizer_device: cpu + offload_param_device: cpu + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false +``` + +And then while launching training, pass the config file: + +```bash +accelerate launch --config_file=CONFIG_FILE.yaml ... +``` + +### Inference + +The pose images in our dataset were computed using the [`controlnet_aux`](https://github.com/huggingface/controlnet_aux) library. Let's install it first: + +```bash +pip install controlnet_aux +``` + +And then we are ready: + +```py +from controlnet_aux import OpenposeDetector +from diffusers import CogView4ControlPipeline +from diffusers.utils import load_image +from PIL import Image +import numpy as np +import torch + +pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16).to("cuda") +pipe.load_lora_weights("...") # change this. + +open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators") + +# prepare pose condition. +url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg" +image = load_image(url) +image = open_pose(image, detect_resolution=512, image_resolution=1024) +image = np.array(image)[:, :, ::-1] +image = Image.fromarray(np.uint8(image)) + +prompt = "A couple, 4k photo, highly detailed" + +gen_images = pipe( + prompt=prompt, + control_image=image, + num_inference_steps=50, + joint_attention_kwargs={"scale": 0.9}, + guidance_scale=25., +).images[0] +gen_images.save("output.png") +``` + +## Full fine-tuning + +We provide a non-LoRA version of the training script `train_control_cogview4.py`. Here is an example command: + +```bash +accelerate launch --config_file=accelerate_ds2.yaml train_control_cogview4.py \ + --pretrained_model_name_or_path="THUDM/CogView4-6B" \ + --dataset_name="raulc0399/open_pose_controlnet" \ + --output_dir="pose-control" \ + --mixed_precision="bf16" \ + --train_batch_size=2 \ + --dataloader_num_workers=4 \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --use_8bit_adam \ + --proportion_empty_prompts=0.2 \ + --learning_rate=5e-5 \ + --adam_weight_decay=1e-4 \ + --report_to="wandb" \ + --lr_scheduler="cosine" \ + --lr_warmup_steps=1000 \ + --checkpointing_steps=1000 \ + --max_train_steps=10000 \ + --validation_steps=200 \ + --validation_image "2_pose_1024.jpg" "3_pose_1024.jpg" \ + --validation_prompt "two friends sitting by each other enjoying a day at the park, full hd, cinematic" "person enjoying a day at the park, full hd, cinematic" \ + --offload \ + --seed="0" \ + --push_to_hub +``` + +Change the `validation_image` and `validation_prompt` as needed. + +For inference, this time, we will run: + +```py +from controlnet_aux import OpenposeDetector +from diffusers import CogView4ControlPipeline, CogView4Transformer2DModel +from diffusers.utils import load_image +from PIL import Image +import numpy as np +import torch + +transformer = CogView4Transformer2DModel.from_pretrained("...") # change this. +pipe = CogView4ControlPipeline.from_pretrained( + "THUDM/CogView4-6B", transformer=transformer, torch_dtype=torch.bfloat16 +).to("cuda") + +open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators") + +# prepare pose condition. +url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg" +image = load_image(url) +image = open_pose(image, detect_resolution=512, image_resolution=1024) +image = np.array(image)[:, :, ::-1] +image = Image.fromarray(np.uint8(image)) + +prompt = "A couple, 4k photo, highly detailed" + +gen_images = pipe( + prompt=prompt, + control_image=image, + num_inference_steps=50, + guidance_scale=25., +).images[0] +gen_images.save("output.png") +``` + +## Things to note + +* The scripts provided in this directory are experimental and educational. This means we may have to tweak things around to get good results on a given condition. We believe this is best done with the community ๐Ÿค— +* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used if `--offload` is specified. +* We can extract LoRAs from the fully fine-tuned model. While we currently don't provide any utilities for that, users are welcome to refer to [this script](https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/master/control_lora_create.py) that provides a similar functionality. \ No newline at end of file diff --git a/examples/cogview4-control/requirements.txt b/examples/cogview4-control/requirements.txt new file mode 100644 index 000000000000..6c5ec2e03f9a --- /dev/null +++ b/examples/cogview4-control/requirements.txt @@ -0,0 +1,6 @@ +transformers==4.47.0 +wandb +torch +torchvision +accelerate==1.2.0 +peft>=0.14.0 diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py new file mode 100644 index 000000000000..506ca0225bf7 --- /dev/null +++ b/examples/cogview4-control/train_control_cogview4.py @@ -0,0 +1,1242 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import copy +import logging +import math +import os +import random +import shutil +from contextlib import nullcontext +from pathlib import Path + +import accelerate +import numpy as np +import torch +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm + +import diffusers +from diffusers import ( + AutoencoderKL, + CogView4ControlPipeline, + CogView4Transformer2DModel, + FlowMatchEulerDiscreteScheduler, +) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import ( + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, + free_memory, +) +from diffusers.utils import check_min_version, is_wandb_available, load_image, make_image_grid +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.33.0.dev0") + +logger = get_logger(__name__) + +NORM_LAYER_PREFIXES = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"] + + +def encode_images(pixels: torch.Tensor, vae: torch.nn.Module, weight_dtype): + pixel_latents = vae.encode(pixels.to(vae.dtype)).latent_dist.sample() + pixel_latents = (pixel_latents - vae.config.shift_factor) * vae.config.scaling_factor + return pixel_latents.to(weight_dtype) + + +def log_validation(cogview4_transformer, args, accelerator, weight_dtype, step, is_final_validation=False): + logger.info("Running validation... ") + + if not is_final_validation: + cogview4_transformer = accelerator.unwrap_model(cogview4_transformer) + pipeline = CogView4ControlPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=cogview4_transformer, + torch_dtype=weight_dtype, + ) + else: + transformer = CogView4Transformer2DModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype) + pipeline = CogView4ControlPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=transformer, + torch_dtype=weight_dtype, + ) + + pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + if len(args.validation_image) == len(args.validation_prompt): + validation_images = args.validation_image + validation_prompts = args.validation_prompt + elif len(args.validation_image) == 1: + validation_images = args.validation_image * len(args.validation_prompt) + validation_prompts = args.validation_prompt + elif len(args.validation_prompt) == 1: + validation_images = args.validation_image + validation_prompts = args.validation_prompt * len(args.validation_image) + else: + raise ValueError( + "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" + ) + + image_logs = [] + if is_final_validation or torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type, weight_dtype) + + for validation_prompt, validation_image in zip(validation_prompts, validation_images): + validation_image = load_image(validation_image) + # maybe need to inference on 1024 to get a good image + validation_image = validation_image.resize((args.resolution, args.resolution)) + + images = [] + + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt=validation_prompt, + control_image=validation_image, + num_inference_steps=50, + guidance_scale=args.guidance_scale, + max_sequence_length=args.max_sequence_length, + generator=generator, + height=args.resolution, + width=args.resolution, + ).images[0] + image = image.resize((args.resolution, args.resolution)) + images.append(image) + image_logs.append( + {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} + ) + + tracker_key = "test" if is_final_validation else "validation" + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + formatted_images = [] + formatted_images.append(np.asarray(validation_image)) + for image in images: + formatted_images.append(np.asarray(image)) + formatted_images = np.stack(formatted_images) + tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") + + elif tracker.name == "wandb": + formatted_images = [] + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + formatted_images.append(wandb.Image(validation_image, caption="Conditioning")) + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({tracker_key: formatted_images}) + else: + logger.warning(f"image logging not implemented for {tracker.name}") + + del pipeline + free_memory() + return image_logs + + +def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): + img_str = "" + if image_logs is not None: + img_str = "You can find some example images below.\n\n" + for i, log in enumerate(image_logs): + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + validation_image.save(os.path.join(repo_folder, "image_control.png")) + img_str += f"prompt: {validation_prompt}\n" + images = [validation_image] + images + make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) + img_str += f"![images_{i})](./images_{i}.png)\n" + + model_description = f""" +# cogview4-control-{repo_id} + +These are Control weights trained on {base_model} with new type of conditioning. +{img_str} + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogView4-6b/blob/main/LICENSE.md) +""" + + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + model_description=model_description, + inference=True, + ) + + tags = [ + "cogview4", + "cogview4-diffusers", + "text-to-image", + "diffusers", + "control", + "diffusers-training", + ] + model_card = populate_model_card(model_card, tags=tags) + + model_card.save(os.path.join(repo_folder, "README.md")) + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a CogView4 Control training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="cogview4-control", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=1024, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--max_sequence_length", type=int, default=128, help="The maximum sequence length for the prompt." + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " + "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." + "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." + "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" + "instructions." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that ๐Ÿค— Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing the target image." + ) + parser.add_argument( + "--conditioning_image_column", + type=str, + default="conditioning_image", + help="The column of the dataset containing the control conditioning image.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log somple dataset samples.") + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + nargs="+", + help=( + "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." + " Provide either a matching number of `--validation_image`s, a single `--validation_image`" + " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." + ), + ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + nargs="+", + help=( + "A set of paths to the control conditioning image be evaluated every `--validation_steps`" + " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" + " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" + " `--validation_image` that will be used with all `--validation_prompt`s." + ), + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=1, + help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." + ), + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="cogview4_train_control", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + parser.add_argument( + "--jsonl_for_train", + type=str, + default=None, + help="Path to the jsonl file containing the training data.", + ) + parser.add_argument( + "--only_target_transformer_blocks", + action="store_true", + help="If we should only target the transformer blocks to train along with the input layer (`x_embedder`).", + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the guidance scale used for transformer.", + ) + + parser.add_argument( + "--upcast_before_saving", + action="store_true", + help=( + "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). " + "Defaults to precision dtype used for training to save memory" + ), + ) + + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--offload", + action="store_true", + help="Whether to offload the VAE and the text encoders to CPU when they are not used.", + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.jsonl_for_train is None: + raise ValueError("Specify either `--dataset_name` or `--jsonl_for_train`") + + if args.dataset_name is not None and args.jsonl_for_train is not None: + raise ValueError("Specify only one of `--dataset_name` or `--jsonl_for_train`") + + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + + if args.validation_prompt is not None and args.validation_image is None: + raise ValueError("`--validation_image` must be set if `--validation_prompt` is set") + + if args.validation_prompt is None and args.validation_image is not None: + raise ValueError("`--validation_prompt` must be set if `--validation_image` is set") + + if ( + args.validation_image is not None + and args.validation_prompt is not None + and len(args.validation_image) != 1 + and len(args.validation_prompt) != 1 + and len(args.validation_image) != len(args.validation_prompt) + ): + raise ValueError( + "Must provide either 1 `--validation_image`, 1 `--validation_prompt`," + " or the same number of `--validation_prompt`s and `--validation_image`s" + ) + + if args.resolution % 8 != 0: + raise ValueError( + "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the cogview4 transformer." + ) + + return args + + +def get_train_dataset(args, accelerator): + dataset = None + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + if args.jsonl_for_train is not None: + # load from json + dataset = load_dataset("json", data_files=args.jsonl_for_train, cache_dir=args.cache_dir) + dataset = dataset.flatten_indices() + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if args.caption_column is None: + caption_column = column_names[1] + logger.info(f"caption column defaulting to {caption_column}") + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if args.conditioning_image_column is None: + conditioning_image_column = column_names[2] + logger.info(f"conditioning image column defaulting to {conditioning_image_column}") + else: + conditioning_image_column = args.conditioning_image_column + if conditioning_image_column not in column_names: + raise ValueError( + f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + with accelerator.main_process_first(): + train_dataset = dataset["train"].shuffle(seed=args.seed) + if args.max_train_samples is not None: + train_dataset = train_dataset.select(range(args.max_train_samples)) + return train_dataset + + +def prepare_train_dataset(dataset, accelerator): + image_transforms = transforms.Compose( + [ + transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR), + transforms.ToTensor(), + transforms.Lambda(lambda x: x * 2 - 1), + ] + ) + + def preprocess_train(examples): + images = [ + (image.convert("RGB") if not isinstance(image, str) else Image.open(image).convert("RGB")) + for image in examples[args.image_column] + ] + images = [image_transforms(image) for image in images] + + conditioning_images = [ + (image.convert("RGB") if not isinstance(image, str) else Image.open(image).convert("RGB")) + for image in examples[args.conditioning_image_column] + ] + conditioning_images = [image_transforms(image) for image in conditioning_images] + examples["pixel_values"] = images + examples["conditioning_pixel_values"] = conditioning_images + + is_caption_list = isinstance(examples[args.caption_column][0], list) + if is_caption_list: + examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]] + else: + examples["captions"] = list(examples[args.caption_column]) + + return examples + + with accelerator.main_process_first(): + dataset = dataset.with_transform(preprocess_train) + + return dataset + + +def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples]) + conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float() + captions = [example["captions"] for example in examples] + return {"pixel_values": pixel_values, "conditioning_pixel_values": conditioning_pixel_values, "captions": captions} + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + logging_out_dir = Path(args.output_dir, args.logging_dir) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=str(logging_out_dir)) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Disable AMP for MPS. A technique for accelerating machine learning computations on iOS and macOS devices. + if torch.backends.mps.is_available(): + logger.info("MPS is enabled. Disabling AMP.") + accelerator.native_amp = False + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + # DEBUG, INFO, WARNING, ERROR, CRITICAL + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load models. We will load the text encoders later in a pipeline to compute + # embeddings. + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + cogview4_transformer = CogView4Transformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + ) + logger.info("All models loaded successfully") + + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="scheduler", + ) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + if not args.only_target_transformer_blocks: + cogview4_transformer.requires_grad_(True) + vae.requires_grad_(False) + + # cast down and move to the CPU + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # let's not move the VAE to the GPU yet. + vae.to(dtype=torch.float32) # keep the VAE in float32. + + # enable image inputs + with torch.no_grad(): + patch_size = cogview4_transformer.config.patch_size + initial_input_channels = cogview4_transformer.config.in_channels * patch_size**2 + new_linear = torch.nn.Linear( + cogview4_transformer.patch_embed.proj.in_features * 2, + cogview4_transformer.patch_embed.proj.out_features, + bias=cogview4_transformer.patch_embed.proj.bias is not None, + dtype=cogview4_transformer.dtype, + device=cogview4_transformer.device, + ) + new_linear.weight.zero_() + new_linear.weight[:, :initial_input_channels].copy_(cogview4_transformer.patch_embed.proj.weight) + if cogview4_transformer.patch_embed.proj.bias is not None: + new_linear.bias.copy_(cogview4_transformer.patch_embed.proj.bias) + cogview4_transformer.patch_embed.proj = new_linear + + assert torch.all(cogview4_transformer.patch_embed.proj.weight[:, initial_input_channels:].data == 0) + cogview4_transformer.register_to_config( + in_channels=cogview4_transformer.config.in_channels * 2, out_channels=cogview4_transformer.config.in_channels + ) + + if args.only_target_transformer_blocks: + cogview4_transformer.patch_embed.proj.requires_grad_(True) + for name, module in cogview4_transformer.named_modules(): + if "transformer_blocks" in name: + module.requires_grad_(True) + else: + module.requirs_grad_(False) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + for model in models: + if isinstance(unwrap_model(model), type(unwrap_model(cogview4_transformer))): + model = unwrap_model(model) + model.save_pretrained(os.path.join(output_dir, "transformer")) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + + def load_model_hook(models, input_dir): + transformer_ = None + + if not accelerator.distributed_type == DistributedType.DEEPSPEED: + while len(models) > 0: + model = models.pop() + + if isinstance(unwrap_model(model), type(unwrap_model(cogview4_transformer))): + transformer_ = model # noqa: F841 + else: + raise ValueError(f"unexpected save model: {unwrap_model(model).__class__}") + + else: + transformer_ = CogView4Transformer2DModel.from_pretrained(input_dir, subfolder="transformer") # noqa: F841 + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.gradient_checkpointing: + cogview4_transformer.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimization parameters + optimizer = optimizer_class( + cogview4_transformer.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Prepare dataset and dataloader. + train_dataset = get_train_dataset(args, accelerator) + train_dataset = prepare_train_dataset(train_dataset, accelerator) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation. + if args.max_train_steps is None: + len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes) + num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps) + num_training_steps_for_scheduler = ( + args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes + ) + else: + num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + # Prepare everything with our `accelerator`. + cogview4_transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + cogview4_transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes: + logger.warning( + f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match " + f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. " + f"This inconsistency may result in the learning rate scheduler not functioning properly." + ) + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + + # tensorboard cannot handle list types for config + tracker_config.pop("validation_prompt") + tracker_config.pop("validation_image") + + accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Create a pipeline for text encoding. We will move this pipeline to GPU/CPU as needed. + text_encoding_pipeline = CogView4ControlPipeline.from_pretrained( + args.pretrained_model_name_or_path, transformer=None, vae=None, torch_dtype=weight_dtype + ) + tokenizer = text_encoding_pipeline.tokenizer + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.") + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + logger.info(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples: + logger.info("Logging some dataset samples.") + formatted_images = [] + formatted_control_images = [] + all_prompts = [] + for i, batch in enumerate(train_dataloader): + images = (batch["pixel_values"] + 1) / 2 + control_images = (batch["conditioning_pixel_values"] + 1) / 2 + prompts = batch["captions"] + + if len(formatted_images) > 10: + break + + for img, control_img, prompt in zip(images, control_images, prompts): + formatted_images.append(img) + formatted_control_images.append(control_img) + all_prompts.append(prompt) + + logged_artifacts = [] + for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts): + logged_artifacts.append(wandb.Image(control_img, caption="Conditioning")) + logged_artifacts.append(wandb.Image(img, caption=prompt)) + + wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"] + wandb_tracker[0].log({"dataset_samples": logged_artifacts}) + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + cogview4_transformer.train() + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(cogview4_transformer): + # Convert images to latent space + # vae encode + prompts = batch["captions"] + attention_mask = tokenizer( + prompts, + padding="longest", # not use max length + max_length=args.max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ).attention_mask.float() + + pixel_latents = encode_images(batch["pixel_values"], vae.to(accelerator.device), weight_dtype) + control_latents = encode_images( + batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype + ) + if args.offload: + vae.cpu() + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + bsz = pixel_latents.shape[0] + noise = torch.randn_like(pixel_latents, device=accelerator.device, dtype=weight_dtype) + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + + # Add noise according for cogview4 + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device) + sigmas = noise_scheduler_copy.sigmas[indices].to(device=pixel_latents.device) + captions = batch["captions"] + image_seq_lens = torch.tensor( + pixel_latents.shape[2] * pixel_latents.shape[3] // patch_size**2, + dtype=pixel_latents.dtype, + device=pixel_latents.device, + ) # H * W / VAE patch_size + mu = torch.sqrt(image_seq_lens / 256) + mu = mu * 0.75 + 0.25 + scale_factors = mu / (mu + (1 / sigmas - 1) ** 1.0).to( + dtype=pixel_latents.dtype, device=pixel_latents.device + ) + scale_factors = scale_factors.view(len(batch["captions"]), 1, 1, 1) + noisy_model_input = (1.0 - scale_factors) * pixel_latents + scale_factors * noise + concatenated_noisy_model_input = torch.cat([noisy_model_input, control_latents], dim=1) + text_encoding_pipeline = text_encoding_pipeline.to("cuda") + + with torch.no_grad(): + ( + prompt_embeds, + pooled_prompt_embeds, + ) = text_encoding_pipeline.encode_prompt(captions, "") + original_size = (args.resolution, args.resolution) + original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=prompt_embeds.device) + + target_size = (args.resolution, args.resolution) + target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=prompt_embeds.device) + + target_size = target_size.repeat(len(batch["captions"]), 1) + original_size = original_size.repeat(len(batch["captions"]), 1) + crops_coords_top_left = torch.tensor([(0, 0)], dtype=prompt_embeds.dtype, device=prompt_embeds.device) + crops_coords_top_left = crops_coords_top_left.repeat(len(batch["captions"]), 1) + + # this could be optimized by not having to do any text encoding and just + # doing zeros on specified shapes for `prompt_embeds` and `pooled_prompt_embeds` + if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts: + # Here, we directly pass 16 pad tokens from pooled_prompt_embeds to prompt_embeds. + prompt_embeds = pooled_prompt_embeds + if args.offload: + text_encoding_pipeline = text_encoding_pipeline.to("cpu") + # Predict. + noise_pred_cond = cogview4_transformer( + hidden_states=concatenated_noisy_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timesteps, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + return_dict=False, + attention_mask=attention_mask, + )[0] + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + # flow-matching loss + target = noise - pixel_latents + + weighting = weighting.view(len(batch["captions"]), 1, 1, 1) + loss = torch.mean( + (weighting.float() * (noise_pred_cond.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + accelerator.backward(loss) + + if accelerator.sync_gradients: + params_to_clip = cogview4_transformer.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues. + if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if args.validation_prompt is not None and global_step % args.validation_steps == 0: + image_logs = log_validation( + cogview4_transformer=cogview4_transformer, + args=args, + accelerator=accelerator, + weight_dtype=weight_dtype, + step=global_step, + ) + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + cogview4_transformer = unwrap_model(cogview4_transformer) + if args.upcast_before_saving: + cogview4_transformer.to(torch.float32) + cogview4_transformer.save_pretrained(args.output_dir) + + del cogview4_transformer + del text_encoding_pipeline + del vae + free_memory() + + # Run a final round of validation. + image_logs = None + if args.validation_prompt is not None: + image_logs = log_validation( + cogview4_transformer=None, + args=args, + accelerator=accelerator, + weight_dtype=weight_dtype, + step=global_step, + is_final_validation=True, + ) + + if args.push_to_hub: + save_model_card( + repo_id, + image_logs=image_logs, + base_model=args.pretrained_model_name_or_path, + repo_folder=args.output_dir, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*", "checkpoint-*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/community/mixture_tiling_sdxl.py b/examples/community/mixture_tiling_sdxl.py index f7b971bae841..bd56ddb3d61d 100644 --- a/examples/community/mixture_tiling_sdxl.py +++ b/examples/community/mixture_tiling_sdxl.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. +# Copyright 2025 The DEVAIEXP Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -1070,32 +1070,32 @@ def __call__( text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) else: text_encoder_projection_dim = self.text_encoder_2.config.projection_dim - add_time_ids = self._get_add_time_ids( - original_size, - crops_coords_top_left[row][col], - target_size, + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left[row][col], + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left[row][col], + negative_target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) - if negative_original_size is not None and negative_target_size is not None: - negative_add_time_ids = self._get_add_time_ids( - negative_original_size, - negative_crops_coords_top_left[row][col], - negative_target_size, - dtype=prompt_embeds.dtype, - text_encoder_projection_dim=text_encoder_projection_dim, - ) - else: - negative_add_time_ids = add_time_ids + else: + negative_add_time_ids = add_time_ids - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) - add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) - prompt_embeds = prompt_embeds.to(device) - add_text_embeds = add_text_embeds.to(device) - add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) addition_embed_type_row.append((prompt_embeds, add_text_embeds, add_time_ids)) embeddings_and_added_time.append(addition_embed_type_row) diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index 65d6c14c5efc..aa235ad65bfe 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -152,9 +152,7 @@ def log_validation( validation_prompt = log["validation_prompt"] validation_image = log["validation_image"] - formatted_images = [] - - formatted_images.append(np.asarray(validation_image)) + formatted_images = [np.asarray(validation_image)] for image in images: formatted_images.append(np.asarray(image)) diff --git a/examples/controlnet/train_controlnet_flux.py b/examples/controlnet/train_controlnet_flux.py index 7f93477fc5b7..a41615c7b546 100644 --- a/examples/controlnet/train_controlnet_flux.py +++ b/examples/controlnet/train_controlnet_flux.py @@ -166,9 +166,7 @@ def log_validation( validation_prompt = log["validation_prompt"] validation_image = log["validation_image"] - formatted_images = [] - - formatted_images.append(np.asarray(validation_image)) + formatted_images = [np.asarray(validation_image)] for image in images: formatted_images.append(np.asarray(image)) diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py index f4aadc2577f7..ffe460d72de8 100644 --- a/examples/controlnet/train_controlnet_sd3.py +++ b/examples/controlnet/train_controlnet_sd3.py @@ -1283,8 +1283,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise # Get the text embedding for conditioning - prompt_embeds = batch["prompt_embeds"] - pooled_prompt_embeds = batch["pooled_prompt_embeds"] + prompt_embeds = batch["prompt_embeds"].to(dtype=weight_dtype) + pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(dtype=weight_dtype) # controlnet(s) inference controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index b2d950e09ac1..17f313752989 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -157,9 +157,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, validation_prompt = log["validation_prompt"] validation_image = log["validation_image"] - formatted_images = [] - - formatted_images.append(np.asarray(validation_image)) + formatted_images = [np.asarray(validation_image)] for image in images: formatted_images.append(np.asarray(image)) diff --git a/examples/research_projects/anytext/README.md b/examples/research_projects/anytext/README.md new file mode 100644 index 000000000000..f5f4fe59ddfd --- /dev/null +++ b/examples/research_projects/anytext/README.md @@ -0,0 +1,32 @@ +# AnyTextPipeline Pipeline + +Project page: https://aigcdesigngroup.github.io/homepage_anytext + +"AnyText comprises a diffusion pipeline with two primary elements: an auxiliary latent module and a text embedding module. The former uses inputs like text glyph, position, and masked image to generate latent features for text generation or editing. The latter employs an OCR model for encoding stroke data as embeddings, which blend with image caption embeddings from the tokenizer to generate texts that seamlessly integrate with the background. We employed text-control diffusion loss and text perceptual loss for training to further enhance writing accuracy." + +Each text line that needs to be generated should be enclosed in double quotes. For any usage questions, please refer to the [paper](https://arxiv.org/abs/2311.03054). + + +```py +import torch +from diffusers import DiffusionPipeline +from anytext_controlnet import AnyTextControlNetModel +from diffusers.utils import load_image + +# I chose a font file shared by an HF staff: +# !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf + +anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16, + variant="fp16",) +pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial-unicode-ms.ttf", + controlnet=anytext_controlnet, torch_dtype=torch.float16, + trust_remote_code=False, # One needs to give permission to run this pipeline's code + ).to("cuda") + +# generate image +prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream' +draw_pos = load_image("https://raw.githubusercontent.com/tyxsspa/AnyText/refs/heads/main/example_images/gen9.png") +image = pipe(prompt, num_inference_steps=20, mode="generate", draw_pos=draw_pos, + ).images[0] +image +``` diff --git a/examples/research_projects/anytext/anytext.py b/examples/research_projects/anytext/anytext.py new file mode 100644 index 000000000000..518452f97942 --- /dev/null +++ b/examples/research_projects/anytext/anytext.py @@ -0,0 +1,2360 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# Copyright (c) Alibaba, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Based on [AnyText: Multilingual Visual Text Generation And Editing](https://huggingface.co/papers/2311.03054). +# Authors: Yuxiang Tuo, Wangmeng Xiang, Jun-Yan He, Yifeng Geng, Xuansong Xie +# Code: https://github.com/tyxsspa/AnyText with Apache-2.0 license +# +# Adapted to Diffusers by [M. Tolga Cangรถz](https://github.com/tolgacangoz). + + +import inspect +import math +import os +import re +import sys +import unicodedata +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import cv2 +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from huggingface_hub import hf_hub_download +from ocr_recog.RecModel import RecModel +from PIL import Image, ImageDraw, ImageFont +from safetensors.torch import load_file +from skimage.transform._geometric import _umeyama as get_sym_mat +from torch import nn +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask + +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.models.modeling_utils import ModelMixin +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.constants import HF_MODULES_CACHE +from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor + + +class Checker: + def __init__(self): + pass + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) + or (cp >= 0x20000 and cp <= 0x2A6DF) + or (cp >= 0x2A700 and cp <= 0x2B73F) + or (cp >= 0x2B740 and cp <= 0x2B81F) + or (cp >= 0x2B820 and cp <= 0x2CEAF) + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) + ): + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or self._is_control(char): + continue + if self._is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_control(self, char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat in ("Cc", "Cf"): + return True + return False + + def _is_whitespace(self, char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically control characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +checker = Checker() + + +PLACE_HOLDER = "*" +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import DiffusionPipeline + >>> from anytext_controlnet import AnyTextControlNetModel + >>> from diffusers.utils import load_image + + >>> # I chose a font file shared by an HF staff: + >>> !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf + + >>> anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16, + ... variant="fp16",) + >>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial-unicode-ms.ttf", + ... controlnet=anytext_controlnet, torch_dtype=torch.float16, + ... trust_remote_code=False, # One needs to give permission to run this pipeline's code + ... ).to("cuda") + + + >>> # generate image + >>> prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream' + >>> draw_pos = load_image("https://raw.githubusercontent.com/tyxsspa/AnyText/refs/heads/main/example_images/gen9.png") + >>> image = pipe(prompt, num_inference_steps=20, mode="generate", draw_pos=draw_pos, + ... ).images[0] + >>> image + ``` +""" + + +def get_clip_token_for_string(tokenizer, string): + batch_encoding = tokenizer( + string, + truncation=True, + max_length=77, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"] + assert ( + torch.count_nonzero(tokens - 49407) == 2 + ), f"String '{string}' maps to more than a single token. Please use another string" + return tokens[0, 1] + + +def get_recog_emb(encoder, img_list): + _img_list = [(img.repeat(1, 3, 1, 1) * 255)[0] for img in img_list] + encoder.predictor.eval() + _, preds_neck = encoder.pred_imglist(_img_list, show_debug=False) + return preds_neck + + +class EmbeddingManager(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + embedder, + placeholder_string="*", + use_fp16=False, + token_dim=768, + get_recog_emb=None, + ): + super().__init__() + get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer) + + self.proj = nn.Linear(40 * 64, token_dim) + proj_dir = hf_hub_download( + repo_id="tolgacangoz/anytext", + filename="text_embedding_module/proj.safetensors", + cache_dir=HF_MODULES_CACHE, + ) + self.proj.load_state_dict(load_file(proj_dir, device=str(embedder.device))) + if use_fp16: + self.proj = self.proj.to(dtype=torch.float16) + + self.placeholder_token = get_token_for_string(placeholder_string) + + @torch.no_grad() + def encode_text(self, text_info): + if self.config.get_recog_emb is None: + self.config.get_recog_emb = partial(get_recog_emb, self.recog) + + gline_list = [] + for i in range(len(text_info["n_lines"])): # sample index in a batch + n_lines = text_info["n_lines"][i] + for j in range(n_lines): # line + gline_list += [text_info["gly_line"][j][i : i + 1]] + + if len(gline_list) > 0: + recog_emb = self.config.get_recog_emb(gline_list) + enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1).to(self.proj.weight.dtype)) + + self.text_embs_all = [] + n_idx = 0 + for i in range(len(text_info["n_lines"])): # sample index in a batch + n_lines = text_info["n_lines"][i] + text_embs = [] + for j in range(n_lines): # line + text_embs += [enc_glyph[n_idx : n_idx + 1]] + n_idx += 1 + self.text_embs_all += [text_embs] + + @torch.no_grad() + def forward( + self, + tokenized_text, + embedded_text, + ): + b, device = tokenized_text.shape[0], tokenized_text.device + for i in range(b): + idx = tokenized_text[i] == self.placeholder_token.to(device) + if sum(idx) > 0: + if i >= len(self.text_embs_all): + print("truncation for log images...") + break + text_emb = torch.cat(self.text_embs_all[i], dim=0) + if sum(idx) != len(text_emb): + print("truncation for long caption...") + text_emb = text_emb.to(embedded_text.device) + embedded_text[i][idx] = text_emb[: sum(idx)] + return embedded_text + + def embedding_parameters(self): + return self.parameters() + + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + + +def min_bounding_rect(img): + ret, thresh = cv2.threshold(img, 127, 255, 0) + contours, hierarchy = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + if len(contours) == 0: + print("Bad contours, using fake bbox...") + return np.array([[0, 0], [100, 0], [100, 100], [0, 100]]) + max_contour = max(contours, key=cv2.contourArea) + rect = cv2.minAreaRect(max_contour) + box = cv2.boxPoints(rect) + box = np.int0(box) + # sort + x_sorted = sorted(box, key=lambda x: x[0]) + left = x_sorted[:2] + right = x_sorted[2:] + left = sorted(left, key=lambda x: x[1]) + (tl, bl) = left + right = sorted(right, key=lambda x: x[1]) + (tr, br) = right + if tl[1] > bl[1]: + (tl, bl) = (bl, tl) + if tr[1] > br[1]: + (tr, br) = (br, tr) + return np.array([tl, tr, br, bl]) + + +def adjust_image(box, img): + pts1 = np.float32([box[0], box[1], box[2], box[3]]) + width = max(np.linalg.norm(pts1[0] - pts1[1]), np.linalg.norm(pts1[2] - pts1[3])) + height = max(np.linalg.norm(pts1[0] - pts1[3]), np.linalg.norm(pts1[1] - pts1[2])) + pts2 = np.float32([[0, 0], [width, 0], [width, height], [0, height]]) + # get transform matrix + M = get_sym_mat(pts1, pts2, estimate_scale=True) + C, H, W = img.shape + T = np.array([[2 / W, 0, -1], [0, 2 / H, -1], [0, 0, 1]]) + theta = np.linalg.inv(T @ M @ np.linalg.inv(T)) + theta = torch.from_numpy(theta[:2, :]).unsqueeze(0).type(torch.float32).to(img.device) + grid = F.affine_grid(theta, torch.Size([1, C, H, W]), align_corners=True) + result = F.grid_sample(img.unsqueeze(0), grid, align_corners=True) + result = torch.clamp(result.squeeze(0), 0, 255) + # crop + result = result[:, : int(height), : int(width)] + return result + + +def crop_image(src_img, mask): + box = min_bounding_rect(mask) + result = adjust_image(box, src_img) + if len(result.shape) == 2: + result = torch.stack([result] * 3, axis=-1) + return result + + +def create_predictor(model_lang="ch", device="cpu", use_fp16=False): + model_dir = hf_hub_download( + repo_id="tolgacangoz/anytext", + filename="text_embedding_module/OCR/ppv3_rec.pth", + cache_dir=HF_MODULES_CACHE, + ) + if not os.path.exists(model_dir): + raise ValueError("not find model file path {}".format(model_dir)) + + if model_lang == "ch": + n_class = 6625 + elif model_lang == "en": + n_class = 97 + else: + raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}") + rec_config = { + "in_channels": 3, + "backbone": {"type": "MobileNetV1Enhance", "scale": 0.5, "last_conv_stride": [1, 2], "last_pool_type": "avg"}, + "neck": { + "type": "SequenceEncoder", + "encoder_type": "svtr", + "dims": 64, + "depth": 2, + "hidden_dims": 120, + "use_guide": True, + }, + "head": {"type": "CTCHead", "fc_decay": 0.00001, "out_channels": n_class, "return_feats": True}, + } + + rec_model = RecModel(rec_config) + state_dict = torch.load(model_dir, map_location=device) + rec_model.load_state_dict(state_dict) + return rec_model + + +def _check_image_file(path): + img_end = ("tiff", "tif", "bmp", "rgb", "jpg", "png", "jpeg") + return path.lower().endswith(tuple(img_end)) + + +def get_image_file_list(img_file): + imgs_lists = [] + if img_file is None or not os.path.exists(img_file): + raise Exception("not found any img file in {}".format(img_file)) + if os.path.isfile(img_file) and _check_image_file(img_file): + imgs_lists.append(img_file) + elif os.path.isdir(img_file): + for single_file in os.listdir(img_file): + file_path = os.path.join(img_file, single_file) + if os.path.isfile(file_path) and _check_image_file(file_path): + imgs_lists.append(file_path) + if len(imgs_lists) == 0: + raise Exception("not found any img file in {}".format(img_file)) + imgs_lists = sorted(imgs_lists) + return imgs_lists + + +class TextRecognizer(object): + def __init__(self, args, predictor): + self.rec_image_shape = [int(v) for v in args["rec_image_shape"].split(",")] + self.rec_batch_num = args["rec_batch_num"] + self.predictor = predictor + self.chars = self.get_char_dict(args["rec_char_dict_path"]) + self.char2id = {x: i for i, x in enumerate(self.chars)} + self.is_onnx = not isinstance(self.predictor, torch.nn.Module) + self.use_fp16 = args["use_fp16"] + + # img: CHW + def resize_norm_img(self, img, max_wh_ratio): + imgC, imgH, imgW = self.rec_image_shape + assert imgC == img.shape[0] + imgW = int((imgH * max_wh_ratio)) + + h, w = img.shape[1:] + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = torch.nn.functional.interpolate( + img.unsqueeze(0), + size=(imgH, resized_w), + mode="bilinear", + align_corners=True, + ) + resized_image /= 255.0 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = torch.zeros((imgC, imgH, imgW), dtype=torch.float32).to(img.device) + padding_im[:, :, 0:resized_w] = resized_image[0] + return padding_im + + # img_list: list of tensors with shape chw 0-255 + def pred_imglist(self, img_list, show_debug=False): + img_num = len(img_list) + assert img_num > 0 + # Calculate the aspect ratio of all text bars + width_list = [] + for img in img_list: + width_list.append(img.shape[2] / float(img.shape[1])) + # Sorting can speed up the recognition process + indices = torch.from_numpy(np.argsort(np.array(width_list))) + batch_num = self.rec_batch_num + preds_all = [None] * img_num + preds_neck_all = [None] * img_num + for beg_img_no in range(0, img_num, batch_num): + end_img_no = min(img_num, beg_img_no + batch_num) + norm_img_batch = [] + + imgC, imgH, imgW = self.rec_image_shape[:3] + max_wh_ratio = imgW / imgH + for ino in range(beg_img_no, end_img_no): + h, w = img_list[indices[ino]].shape[1:] + if h > w * 1.2: + img = img_list[indices[ino]] + img = torch.transpose(img, 1, 2).flip(dims=[1]) + img_list[indices[ino]] = img + h, w = img.shape[1:] + # wh_ratio = w * 1.0 / h + # max_wh_ratio = max(max_wh_ratio, wh_ratio) # comment to not use different ratio + for ino in range(beg_img_no, end_img_no): + norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio) + if self.use_fp16: + norm_img = norm_img.half() + norm_img = norm_img.unsqueeze(0) + norm_img_batch.append(norm_img) + norm_img_batch = torch.cat(norm_img_batch, dim=0) + if show_debug: + for i in range(len(norm_img_batch)): + _img = norm_img_batch[i].permute(1, 2, 0).detach().cpu().numpy() + _img = (_img + 0.5) * 255 + _img = _img[:, :, ::-1] + file_name = f"{indices[beg_img_no + i]}" + if os.path.exists(file_name + ".jpg"): + file_name += "_2" # ori image + cv2.imwrite(file_name + ".jpg", _img) + if self.is_onnx: + input_dict = {} + input_dict[self.predictor.get_inputs()[0].name] = norm_img_batch.detach().cpu().numpy() + outputs = self.predictor.run(None, input_dict) + preds = {} + preds["ctc"] = torch.from_numpy(outputs[0]) + preds["ctc_neck"] = [torch.zeros(1)] * img_num + else: + preds = self.predictor(norm_img_batch.to(next(self.predictor.parameters()).device)) + for rno in range(preds["ctc"].shape[0]): + preds_all[indices[beg_img_no + rno]] = preds["ctc"][rno] + preds_neck_all[indices[beg_img_no + rno]] = preds["ctc_neck"][rno] + + return torch.stack(preds_all, dim=0), torch.stack(preds_neck_all, dim=0) + + def get_char_dict(self, character_dict_path): + character_str = [] + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode("utf-8").strip("\n").strip("\r\n") + character_str.append(line) + dict_character = list(character_str) + dict_character = ["sos"] + dict_character + [" "] # eos is space + return dict_character + + def get_text(self, order): + char_list = [self.chars[text_id] for text_id in order] + return "".join(char_list) + + def decode(self, mat): + text_index = mat.detach().cpu().numpy().argmax(axis=1) + ignored_tokens = [0] + selection = np.ones(len(text_index), dtype=bool) + selection[1:] = text_index[1:] != text_index[:-1] + for ignored_token in ignored_tokens: + selection &= text_index != ignored_token + return text_index[selection], np.where(selection)[0] + + def get_ctcloss(self, preds, gt_text, weight): + if not isinstance(weight, torch.Tensor): + weight = torch.tensor(weight).to(preds.device) + ctc_loss = torch.nn.CTCLoss(reduction="none") + log_probs = preds.log_softmax(dim=2).permute(1, 0, 2) # NTC-->TNC + targets = [] + target_lengths = [] + for t in gt_text: + targets += [self.char2id.get(i, len(self.chars) - 1) for i in t] + target_lengths += [len(t)] + targets = torch.tensor(targets).to(preds.device) + target_lengths = torch.tensor(target_lengths).to(preds.device) + input_lengths = torch.tensor([log_probs.shape[0]] * (log_probs.shape[1])).to(preds.device) + loss = ctc_loss(log_probs, targets, input_lengths, target_lengths) + loss = loss / input_lengths * weight + return loss + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class FrozenCLIPEmbedderT3(AbstractEncoder, ModelMixin, ConfigMixin): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + + @register_to_config + def __init__( + self, + device="cpu", + max_length=77, + freeze=True, + use_fp16=False, + variant: Optional[str] = None, + ): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained("tolgacangoz/anytext", subfolder="tokenizer") + self.transformer = CLIPTextModel.from_pretrained( + "tolgacangoz/anytext", + subfolder="text_encoder", + torch_dtype=torch.float16 if use_fp16 else torch.float32, + variant="fp16" if use_fp16 else None, + ) + + if freeze: + self.freeze() + + def embedding_forward( + self, + input_ids=None, + position_ids=None, + inputs_embeds=None, + embedding_manager=None, + ): + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + if embedding_manager is not None: + inputs_embeds = embedding_manager(input_ids, inputs_embeds) + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + return embeddings + + self.transformer.text_model.embeddings.forward = embedding_forward.__get__( + self.transformer.text_model.embeddings + ) + + def encoder_forward( + self, + inputs_embeds, + attention_mask=None, + causal_attention_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + return hidden_states + + self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder) + + def text_encoder_forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + embedding_manager=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if input_ids is None: + raise ValueError("You have to specify either input_ids") + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + hidden_states = self.embeddings( + input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager + ) + # CLIP's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = _create_4d_causal_attention_mask( + input_shape, hidden_states.dtype, device=hidden_states.device + ) + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + last_hidden_state = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = self.final_layer_norm(last_hidden_state) + return last_hidden_state + + self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model) + + def transformer_forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + embedding_manager=None, + ): + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + embedding_manager=embedding_manager, + ) + + self.transformer.forward = transformer_forward.__get__(self.transformer) + + def freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text, **kwargs): + batch_encoding = self.tokenizer( + text, + truncation=False, + max_length=self.config.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="longest", + return_tensors="pt", + ) + input_ids = batch_encoding["input_ids"] + tokens_list = self.split_chunks(input_ids) + z_list = [] + for tokens in tokens_list: + tokens = tokens.to(self.device) + _z = self.transformer(input_ids=tokens, **kwargs) + z_list += [_z] + return torch.cat(z_list, dim=1) + + def encode(self, text, **kwargs): + return self(text, **kwargs) + + def split_chunks(self, input_ids, chunk_size=75): + tokens_list = [] + bs, n = input_ids.shape + id_start = input_ids[:, 0].unsqueeze(1) # dim --> [bs, 1] + id_end = input_ids[:, -1].unsqueeze(1) + if n == 2: # empty caption + tokens_list.append(torch.cat((id_start,) + (id_end,) * (chunk_size + 1), dim=1)) + + trimmed_encoding = input_ids[:, 1:-1] + num_full_groups = (n - 2) // chunk_size + + for i in range(num_full_groups): + group = trimmed_encoding[:, i * chunk_size : (i + 1) * chunk_size] + group_pad = torch.cat((id_start, group, id_end), dim=1) + tokens_list.append(group_pad) + + remaining_columns = (n - 2) % chunk_size + if remaining_columns > 0: + remaining_group = trimmed_encoding[:, -remaining_columns:] + padding_columns = chunk_size - remaining_group.shape[1] + padding = id_end.expand(bs, padding_columns) + remaining_group_pad = torch.cat((id_start, remaining_group, padding, id_end), dim=1) + tokens_list.append(remaining_group_pad) + return tokens_list + + +class TextEmbeddingModule(ModelMixin, ConfigMixin): + @register_to_config + def __init__(self, font_path, use_fp16=False, device="cpu"): + super().__init__() + font = ImageFont.truetype(font_path, 60) + + self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=device, use_fp16=use_fp16) + self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=use_fp16) + self.text_predictor = create_predictor(device=device, use_fp16=use_fp16).eval() + args = { + "rec_image_shape": "3, 48, 320", + "rec_batch_num": 6, + "rec_char_dict_path": hf_hub_download( + repo_id="tolgacangoz/anytext", + filename="text_embedding_module/OCR/ppocr_keys_v1.txt", + cache_dir=HF_MODULES_CACHE, + ), + "use_fp16": use_fp16, + } + self.embedding_manager.recog = TextRecognizer(args, self.text_predictor) + + self.register_to_config(font=font) + + @torch.no_grad() + def forward( + self, + prompt, + texts, + negative_prompt, + num_images_per_prompt, + mode, + draw_pos, + sort_priority="โ†•", + max_chars=77, + revise_pos=False, + h=512, + w=512, + ): + if prompt is None and texts is None: + raise ValueError("Prompt or texts must be provided!") + # preprocess pos_imgs(if numpy, make sure it's white pos in black bg) + if draw_pos is None: + pos_imgs = np.zeros((w, h, 1)) + if isinstance(draw_pos, PIL.Image.Image): + pos_imgs = np.array(draw_pos)[..., ::-1] + pos_imgs = 255 - pos_imgs + elif isinstance(draw_pos, str): + draw_pos = cv2.imread(draw_pos)[..., ::-1] + if draw_pos is None: + raise ValueError(f"Can't read draw_pos image from {draw_pos}!") + pos_imgs = 255 - draw_pos + elif isinstance(draw_pos, torch.Tensor): + pos_imgs = draw_pos.cpu().numpy() + else: + if not isinstance(draw_pos, np.ndarray): + raise ValueError(f"Unknown format of draw_pos: {type(draw_pos)}") + if mode == "edit": + pos_imgs = cv2.resize(pos_imgs, (w, h)) + pos_imgs = pos_imgs[..., 0:1] + pos_imgs = cv2.convertScaleAbs(pos_imgs) + _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY) + # separate pos_imgs + pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority) + if len(pos_imgs) == 0: + pos_imgs = [np.zeros((h, w, 1))] + n_lines = len(texts) + if len(pos_imgs) < n_lines: + if n_lines == 1 and texts[0] == " ": + pass # text-to-image without text + else: + raise ValueError( + f"Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!" + ) + elif len(pos_imgs) > n_lines: + str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt." + logger.warning(str_warning) + # get pre_pos, poly_list, hint that needed for anytext + pre_pos = [] + poly_list = [] + for input_pos in pos_imgs: + if input_pos.mean() != 0: + input_pos = input_pos[..., np.newaxis] if len(input_pos.shape) == 2 else input_pos + poly, pos_img = self.find_polygon(input_pos) + pre_pos += [pos_img / 255.0] + poly_list += [poly] + else: + pre_pos += [np.zeros((h, w, 1))] + poly_list += [None] + np_hint = np.sum(pre_pos, axis=0).clip(0, 1) + # prepare info dict + text_info = {} + text_info["glyphs"] = [] + text_info["gly_line"] = [] + text_info["positions"] = [] + text_info["n_lines"] = [len(texts)] * num_images_per_prompt + for i in range(len(texts)): + text = texts[i] + if len(text) > max_chars: + str_warning = f'"{text}" length > max_chars: {max_chars}, will be cut off...' + logger.warning(str_warning) + text = text[:max_chars] + gly_scale = 2 + if pre_pos[i].mean() != 0: + gly_line = self.draw_glyph(self.config.font, text) + glyphs = self.draw_glyph2( + self.config.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False + ) + if revise_pos: + resize_gly = cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0])) + new_pos = cv2.morphologyEx( + (resize_gly * 255).astype(np.uint8), + cv2.MORPH_CLOSE, + kernel=np.ones((resize_gly.shape[0] // 10, resize_gly.shape[1] // 10), dtype=np.uint8), + iterations=1, + ) + new_pos = new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos + contours, _ = cv2.findContours(new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + if len(contours) != 1: + str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..." + logger.warning(str_warning) + else: + rect = cv2.minAreaRect(contours[0]) + poly = np.int0(cv2.boxPoints(rect)) + pre_pos[i] = cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0 + else: + glyphs = np.zeros((h * gly_scale, w * gly_scale, 1)) + gly_line = np.zeros((80, 512, 1)) + pos = pre_pos[i] + text_info["glyphs"] += [self.arr2tensor(glyphs, num_images_per_prompt)] + text_info["gly_line"] += [self.arr2tensor(gly_line, num_images_per_prompt)] + text_info["positions"] += [self.arr2tensor(pos, num_images_per_prompt)] + + self.embedding_manager.encode_text(text_info) + prompt_embeds = self.frozen_CLIP_embedder_t3.encode([prompt], embedding_manager=self.embedding_manager) + + self.embedding_manager.encode_text(text_info) + negative_prompt_embeds = self.frozen_CLIP_embedder_t3.encode( + [negative_prompt or ""], embedding_manager=self.embedding_manager + ) + + return prompt_embeds, negative_prompt_embeds, text_info, np_hint + + def arr2tensor(self, arr, bs): + arr = np.transpose(arr, (2, 0, 1)) + _arr = torch.from_numpy(arr.copy()).float().cpu() + if self.config.use_fp16: + _arr = _arr.half() + _arr = torch.stack([_arr for _ in range(bs)], dim=0) + return _arr + + def separate_pos_imgs(self, img, sort_priority, gap=102): + num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img) + components = [] + for label in range(1, num_labels): + component = np.zeros_like(img) + component[labels == label] = 255 + components.append((component, centroids[label])) + if sort_priority == "โ†•": + fir, sec = 1, 0 # top-down first + elif sort_priority == "โ†”": + fir, sec = 0, 1 # left-right first + else: + raise ValueError(f"Unknown sort_priority: {sort_priority}") + components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap)) + sorted_components = [c[0] for c in components] + return sorted_components + + def find_polygon(self, image, min_rect=False): + contours, hierarchy = cv2.findContours(image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + max_contour = max(contours, key=cv2.contourArea) # get contour with max area + if min_rect: + # get minimum enclosing rectangle + rect = cv2.minAreaRect(max_contour) + poly = np.int0(cv2.boxPoints(rect)) + else: + # get approximate polygon + epsilon = 0.01 * cv2.arcLength(max_contour, True) + poly = cv2.approxPolyDP(max_contour, epsilon, True) + n, _, xy = poly.shape + poly = poly.reshape(n, xy) + cv2.drawContours(image, [poly], -1, 255, -1) + return poly, image + + def draw_glyph(self, font, text): + g_size = 50 + W, H = (512, 80) + new_font = font.font_variant(size=g_size) + img = Image.new(mode="1", size=(W, H), color=0) + draw = ImageDraw.Draw(img) + left, top, right, bottom = new_font.getbbox(text) + text_width = max(right - left, 5) + text_height = max(bottom - top, 5) + ratio = min(W * 0.9 / text_width, H * 0.9 / text_height) + new_font = font.font_variant(size=int(g_size * ratio)) + + left, top, right, bottom = new_font.getbbox(text) + text_width = right - left + text_height = bottom - top + x = (img.width - text_width) // 2 + y = (img.height - text_height) // 2 - top // 2 + draw.text((x, y), text, font=new_font, fill="white") + img = np.expand_dims(np.array(img), axis=2).astype(np.float64) + return img + + def draw_glyph2(self, font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True): + enlarge_polygon = polygon * scale + rect = cv2.minAreaRect(enlarge_polygon) + box = cv2.boxPoints(rect) + box = np.int0(box) + w, h = rect[1] + angle = rect[2] + if angle < -45: + angle += 90 + angle = -angle + if w < h: + angle += 90 + + vert = False + if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng: + _w = max(box[:, 0]) - min(box[:, 0]) + _h = max(box[:, 1]) - min(box[:, 1]) + if _h >= _w: + vert = True + angle = 0 + + img = np.zeros((height * scale, width * scale, 3), np.uint8) + img = Image.fromarray(img) + + # infer font size + image4ratio = Image.new("RGB", img.size, "white") + draw = ImageDraw.Draw(image4ratio) + _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font) + text_w = min(w, h) * (_tw / _th) + if text_w <= max(w, h): + # add space + if len(text) > 1 and not vert and add_space: + for i in range(1, 100): + text_space = self.insert_spaces(text, i) + _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font) + if min(w, h) * (_tw2 / _th2) > max(w, h): + break + text = self.insert_spaces(text, i - 1) + font_size = min(w, h) * 0.80 + else: + shrink = 0.75 if vert else 0.85 + font_size = min(w, h) / (text_w / max(w, h)) * shrink + new_font = font.font_variant(size=int(font_size)) + + left, top, right, bottom = new_font.getbbox(text) + text_width = right - left + text_height = bottom - top + + layer = Image.new("RGBA", img.size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(layer) + if not vert: + draw.text( + (rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top), + text, + font=new_font, + fill=(255, 255, 255, 255), + ) + else: + x_s = min(box[:, 0]) + _w // 2 - text_height // 2 + y_s = min(box[:, 1]) + for c in text: + draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255)) + _, _t, _, _b = new_font.getbbox(c) + y_s += _b + + rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1])) + + x_offset = int((img.width - rotated_layer.width) / 2) + y_offset = int((img.height - rotated_layer.height) / 2) + img.paste(rotated_layer, (x_offset, y_offset), rotated_layer) + img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64) + return img + + def insert_spaces(self, string, nSpace): + if nSpace == 0: + return string + new_string = "" + for char in string: + new_string += char + " " * nSpace + return new_string[:-nSpace] + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class AuxiliaryLatentModule(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + vae, + device="cpu", + ): + super().__init__() + + @torch.no_grad() + def forward( + self, + text_info, + mode, + draw_pos, + ori_image, + num_images_per_prompt, + np_hint, + h=512, + w=512, + ): + if mode == "generate": + edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image + elif mode == "edit": + if draw_pos is None or ori_image is None: + raise ValueError("Reference image and position image are needed for text editing!") + if isinstance(ori_image, str): + ori_image = cv2.imread(ori_image)[..., ::-1] + if ori_image is None: + raise ValueError(f"Can't read ori_image image from {ori_image}!") + elif isinstance(ori_image, torch.Tensor): + ori_image = ori_image.cpu().numpy() + else: + if not isinstance(ori_image, np.ndarray): + raise ValueError(f"Unknown format of ori_image: {type(ori_image)}") + edit_image = ori_image.clip(1, 255) # for mask reason + edit_image = self.check_channels(edit_image) + edit_image = self.resize_image( + edit_image, max_length=768 + ) # make w h multiple of 64, resize if w or h > max_length + + # get masked_x + masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint) + masked_img = np.transpose(masked_img, (2, 0, 1)) + device = next(self.config.vae.parameters()).device + dtype = next(self.config.vae.parameters()).dtype + masked_img = torch.from_numpy(masked_img.copy()).float().to(device) + if dtype == torch.float16: + masked_img = masked_img.half() + masked_x = ( + retrieve_latents(self.config.vae.encode(masked_img[None, ...])) * self.config.vae.config.scaling_factor + ).detach() + if dtype == torch.float16: + masked_x = masked_x.half() + text_info["masked_x"] = torch.cat([masked_x for _ in range(num_images_per_prompt)], dim=0) + + glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True) + positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True) + + return glyphs, positions, text_info + + def check_channels(self, image): + channels = image.shape[2] if len(image.shape) == 3 else 1 + if channels == 1: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) + elif channels > 3: + image = image[:, :, :3] + return image + + def resize_image(self, img, max_length=768): + height, width = img.shape[:2] + max_dimension = max(height, width) + + if max_dimension > max_length: + scale_factor = max_length / max_dimension + new_width = int(round(width * scale_factor)) + new_height = int(round(height * scale_factor)) + new_size = (new_width, new_height) + img = cv2.resize(img, new_size) + height, width = img.shape[:2] + img = cv2.resize(img, (width - (width % 64), height - (height % 64))) + return img + + def insert_spaces(self, string, nSpace): + if nSpace == 0: + return string + new_string = "" + for char in string: + new_string += char + " " * nSpace + return new_string[:-nSpace] + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AnyTextPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + font_path: str = None, + text_embedding_module: Optional[TextEmbeddingModule] = None, + auxiliary_latent_module: Optional[AuxiliaryLatentModule] = None, + trust_remote_code: bool = False, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + if font_path is None: + raise ValueError("font_path is required!") + + text_embedding_module = TextEmbeddingModule(font_path=font_path, use_fp16=unet.dtype == torch.float16) + auxiliary_latent_module = AuxiliaryLatentModule(vae=vae) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + text_embedding_module=text_embedding_module, + auxiliary_latent_module=auxiliary_latent_module, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def modify_prompt(self, prompt): + prompt = prompt.replace("โ€œ", '"') + prompt = prompt.replace("โ€", '"') + p = '"(.*?)"' + strs = re.findall(p, prompt) + if len(strs) == 0: + strs = [" "] + else: + for s in strs: + prompt = prompt.replace(f'"{s}"', f" {PLACE_HOLDER} ", 1) + if self.is_chinese(prompt): + if self.trans_pipe is None: + return None, None + old_prompt = prompt + prompt = self.trans_pipe(input=prompt + " .")["translation"][:-1] + print(f"Translate: {old_prompt} --> {prompt}") + return prompt, strs + + def is_chinese(self, text): + text = checker._clean_text(text) + for char in text: + cp = ord(char) + if checker._is_chinese_char(cp): + return True + return False + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (ฮท) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to ฮท in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + # image, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + print(controlnet_conditioning_scale) + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError( + "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. " + "The conditioning scale must be fixed across the batch." + ) + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + mode: Optional[str] = "generate", + draw_pos: Optional[Union[str, torch.Tensor]] = None, + ori_image: Optional[Union[str, torch.Tensor]] = None, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single + ControlNet, each will be paired with each prompt in the `prompt` list. This also applies to multiple + ControlNets, where a list of image lists can be passed to batch for each prompt and each ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (ฮท) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + # image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + prompt, texts = self.modify_prompt(prompt) + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + draw_pos = draw_pos.to(device=device) if isinstance(draw_pos, torch.Tensor) else draw_pos + prompt_embeds, negative_prompt_embeds, text_info, np_hint = self.text_embedding_module( + prompt, + texts, + negative_prompt, + num_images_per_prompt, + mode, + draw_pos, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 3.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + guided_hint = self.auxiliary_latent_module( + text_info=text_info, + mode=mode, + draw_pos=draw_pos, + ori_image=ori_image, + num_images_per_prompt=num_images_per_prompt, + np_hint=np_hint, + ) + height, width = 512, 512 + else: + assert False + + # 5. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 7.2 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input.to(self.controlnet.dtype), + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=guided_hint, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Inferred ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.text_embedding_module.to(*args, **kwargs) + self.auxiliary_latent_module.to(*args, **kwargs) + return self diff --git a/examples/research_projects/anytext/anytext_controlnet.py b/examples/research_projects/anytext/anytext_controlnet.py new file mode 100644 index 000000000000..5965ceed1370 --- /dev/null +++ b/examples/research_projects/anytext/anytext_controlnet.py @@ -0,0 +1,463 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Based on [AnyText: Multilingual Visual Text Generation And Editing](https://huggingface.co/papers/2311.03054). +# Authors: Yuxiang Tuo, Wangmeng Xiang, Jun-Yan He, Yifeng Geng, Xuansong Xie +# Code: https://github.com/tyxsspa/AnyText with Apache-2.0 license +# +# Adapted to Diffusers by [M. Tolga Cangรถz](https://github.com/tolgacangoz). + + +from typing import Any, Dict, Optional, Tuple, Union + +import torch +from torch import nn + +from diffusers.configuration_utils import register_to_config +from diffusers.models.controlnets.controlnet import ( + ControlNetModel, + ControlNetOutput, +) +from diffusers.utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class AnyTextControlNetConditioningEmbedding(nn.Module): + """ + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 ร— 512 images into smaller 64 ร— 64 โ€œlatent imagesโ€ for stabilized + training. This requires ControlNets to convert image-based conditions to 64 ร— 64 feature space to match the + convolution size. We use a tiny network E(ยท) of four convolution layers with 4 ร— 4 kernels and 2 ร— 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + """ + + def __init__( + self, + conditioning_embedding_channels: int, + glyph_channels=1, + position_channels=1, + ): + super().__init__() + + self.glyph_block = nn.Sequential( + nn.Conv2d(glyph_channels, 8, 3, padding=1), + nn.SiLU(), + nn.Conv2d(8, 8, 3, padding=1), + nn.SiLU(), + nn.Conv2d(8, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 32, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(32, 32, 3, padding=1), + nn.SiLU(), + nn.Conv2d(32, 96, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(96, 96, 3, padding=1), + nn.SiLU(), + nn.Conv2d(96, 256, 3, padding=1, stride=2), + nn.SiLU(), + ) + + self.position_block = nn.Sequential( + nn.Conv2d(position_channels, 8, 3, padding=1), + nn.SiLU(), + nn.Conv2d(8, 8, 3, padding=1), + nn.SiLU(), + nn.Conv2d(8, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 32, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(32, 32, 3, padding=1), + nn.SiLU(), + nn.Conv2d(32, 64, 3, padding=1, stride=2), + nn.SiLU(), + ) + + self.fuse_block = nn.Conv2d(256 + 64 + 4, conditioning_embedding_channels, 3, padding=1) + + def forward(self, glyphs, positions, text_info): + glyph_embedding = self.glyph_block(glyphs.to(self.glyph_block[0].weight.device)) + position_embedding = self.position_block(positions.to(self.position_block[0].weight.device)) + guided_hint = self.fuse_block(torch.cat([glyph_embedding, position_embedding, text_info["masked_x"]], dim=1)) + + return guided_hint + + +class AnyTextControlNetModel(ControlNetModel): + """ + A AnyTextControlNetModel model. + + Args: + in_channels (`int`, defaults to 4): + The number of channels in the input sample. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, defaults to 0): + The frequency shift to apply to the time embedding. + down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): + block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, defaults to 2): + The number of layers per block. + downsample_padding (`int`, defaults to 1): + The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, defaults to 1): + The scale factor to use for the mid block. + act_fn (`str`, defaults to "silu"): + The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the normalization. If None, normalization and activation layers is skipped + in post-processing. + norm_eps (`float`, defaults to 1e-5): + The epsilon to use for the normalization. + cross_attention_dim (`int`, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): + The dimension of the attention heads. + use_linear_projection (`bool`, defaults to `False`): + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + num_class_embeds (`int`, *optional*, defaults to 0): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + upcast_attention (`bool`, defaults to `False`): + resnet_time_scale_shift (`str`, defaults to `"default"`): + Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. + projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`): + The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when + `class_embed_type="projection"`. + controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `conditioning_embedding` layer. + global_pool_conditions (`bool`, defaults to `False`): + TODO(Patrick) - unused parameter. + addition_embed_type_num_heads (`int`, defaults to 64): + The number of heads to use for the `TextTimeEmbedding` layer. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 4, + conditioning_channels: int = 1, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str, ...] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int, ...]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + projection_class_embeddings_input_dim: Optional[int] = None, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + global_pool_conditions: bool = False, + addition_embed_type_num_heads: int = 64, + ): + super().__init__( + in_channels, + conditioning_channels, + flip_sin_to_cos, + freq_shift, + down_block_types, + mid_block_type, + only_cross_attention, + block_out_channels, + layers_per_block, + downsample_padding, + mid_block_scale_factor, + act_fn, + norm_num_groups, + norm_eps, + cross_attention_dim, + transformer_layers_per_block, + encoder_hid_dim, + encoder_hid_dim_type, + attention_head_dim, + num_attention_heads, + use_linear_projection, + class_embed_type, + addition_embed_type, + addition_time_embed_dim, + num_class_embeds, + upcast_attention, + resnet_time_scale_shift, + projection_class_embeddings_input_dim, + controlnet_conditioning_channel_order, + conditioning_embedding_out_channels, + global_pool_conditions, + addition_embed_type_num_heads, + ) + + # control net conditioning embedding + self.controlnet_cond_embedding = AnyTextControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + glyph_channels=conditioning_channels, + position_channels=conditioning_channels, + ) + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: torch.Tensor, + conditioning_scale: float = 1.0, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]: + """ + The [`~PromptDiffusionControlNetModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor. + timestep (`Union[torch.Tensor, float, int]`): + The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states. + #controlnet_cond (`torch.Tensor`): + # The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): + Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the + timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep + embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): + A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + guess_mode (`bool`, defaults to `False`): + In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if + you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + + Returns: + [`~models.controlnet.ControlNetOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is + returned where the first element is the sample tensor. + """ + # check channel order + channel_order = self.config.controlnet_conditioning_channel_order + + if channel_order == "rgb": + # in rgb order by default + ... + # elif channel_order == "bgr": + # controlnet_cond = torch.flip(controlnet_cond, dims=[1]) + else: + raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + if self.config.addition_embed_type is not None: + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + + elif self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb + aug_emb if aug_emb is not None else emb + + # 2. pre-process + sample = self.conv_in(sample) + + controlnet_cond = self.controlnet_cond_embedding(*controlnet_cond) + sample = sample + controlnet_cond + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = self.mid_block(sample, emb) + + # 5. Control net blocks + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(sample) + + # 6. scaling + if guess_mode and not self.config.global_pool_conditions: + scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 + scales = scales * conditioning_scale + down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] + mid_block_res_sample = mid_block_res_sample * scales[-1] # last one + else: + down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + mid_block_res_sample = mid_block_res_sample * conditioning_scale + + if self.config.global_pool_conditions: + down_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples + ] + mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return ControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) + + +# Copied from diffusers.models.controlnet.zero_module +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/examples/research_projects/anytext/ocr_recog/RNN.py b/examples/research_projects/anytext/ocr_recog/RNN.py new file mode 100755 index 000000000000..aec796d987c0 --- /dev/null +++ b/examples/research_projects/anytext/ocr_recog/RNN.py @@ -0,0 +1,209 @@ +import torch +from torch import nn + +from .RecSVTR import Block + + +class Swish(nn.Module): + def __int__(self): + super(Swish, self).__int__() + + def forward(self, x): + return x * torch.sigmoid(x) + + +class Im2Im(nn.Module): + def __init__(self, in_channels, **kwargs): + super().__init__() + self.out_channels = in_channels + + def forward(self, x): + return x + + +class Im2Seq(nn.Module): + def __init__(self, in_channels, **kwargs): + super().__init__() + self.out_channels = in_channels + + def forward(self, x): + B, C, H, W = x.shape + # assert H == 1 + x = x.reshape(B, C, H * W) + x = x.permute((0, 2, 1)) + return x + + +class EncoderWithRNN(nn.Module): + def __init__(self, in_channels, **kwargs): + super(EncoderWithRNN, self).__init__() + hidden_size = kwargs.get("hidden_size", 256) + self.out_channels = hidden_size * 2 + self.lstm = nn.LSTM(in_channels, hidden_size, bidirectional=True, num_layers=2, batch_first=True) + + def forward(self, x): + self.lstm.flatten_parameters() + x, _ = self.lstm(x) + return x + + +class SequenceEncoder(nn.Module): + def __init__(self, in_channels, encoder_type="rnn", **kwargs): + super(SequenceEncoder, self).__init__() + self.encoder_reshape = Im2Seq(in_channels) + self.out_channels = self.encoder_reshape.out_channels + self.encoder_type = encoder_type + if encoder_type == "reshape": + self.only_reshape = True + else: + support_encoder_dict = {"reshape": Im2Seq, "rnn": EncoderWithRNN, "svtr": EncoderWithSVTR} + assert encoder_type in support_encoder_dict, "{} must in {}".format( + encoder_type, support_encoder_dict.keys() + ) + + self.encoder = support_encoder_dict[encoder_type](self.encoder_reshape.out_channels, **kwargs) + self.out_channels = self.encoder.out_channels + self.only_reshape = False + + def forward(self, x): + if self.encoder_type != "svtr": + x = self.encoder_reshape(x) + if not self.only_reshape: + x = self.encoder(x) + return x + else: + x = self.encoder(x) + x = self.encoder_reshape(x) + return x + + +class ConvBNLayer(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias_attr=False, groups=1, act=nn.GELU + ): + super().__init__() + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + # weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()), + bias=bias_attr, + ) + self.norm = nn.BatchNorm2d(out_channels) + self.act = Swish() + + def forward(self, inputs): + out = self.conv(inputs) + out = self.norm(out) + out = self.act(out) + return out + + +class EncoderWithSVTR(nn.Module): + def __init__( + self, + in_channels, + dims=64, # XS + depth=2, + hidden_dims=120, + use_guide=False, + num_heads=8, + qkv_bias=True, + mlp_ratio=2.0, + drop_rate=0.1, + attn_drop_rate=0.1, + drop_path=0.0, + qk_scale=None, + ): + super(EncoderWithSVTR, self).__init__() + self.depth = depth + self.use_guide = use_guide + self.conv1 = ConvBNLayer(in_channels, in_channels // 8, padding=1, act="swish") + self.conv2 = ConvBNLayer(in_channels // 8, hidden_dims, kernel_size=1, act="swish") + + self.svtr_block = nn.ModuleList( + [ + Block( + dim=hidden_dims, + num_heads=num_heads, + mixer="Global", + HW=None, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer="swish", + attn_drop=attn_drop_rate, + drop_path=drop_path, + norm_layer="nn.LayerNorm", + epsilon=1e-05, + prenorm=False, + ) + for i in range(depth) + ] + ) + self.norm = nn.LayerNorm(hidden_dims, eps=1e-6) + self.conv3 = ConvBNLayer(hidden_dims, in_channels, kernel_size=1, act="swish") + # last conv-nxn, the input is concat of input tensor and conv3 output tensor + self.conv4 = ConvBNLayer(2 * in_channels, in_channels // 8, padding=1, act="swish") + + self.conv1x1 = ConvBNLayer(in_channels // 8, dims, kernel_size=1, act="swish") + self.out_channels = dims + self.apply(self._init_weights) + + def _init_weights(self, m): + # weight initialization + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out") + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.ConvTranspose2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out") + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + + def forward(self, x): + # for use guide + if self.use_guide: + z = x.clone() + z.stop_gradient = True + else: + z = x + # for short cut + h = z + # reduce dim + z = self.conv1(z) + z = self.conv2(z) + # SVTR global block + B, C, H, W = z.shape + z = z.flatten(2).permute(0, 2, 1) + + for blk in self.svtr_block: + z = blk(z) + + z = self.norm(z) + # last stage + z = z.reshape([-1, H, W, C]).permute(0, 3, 1, 2) + z = self.conv3(z) + z = torch.cat((h, z), dim=1) + z = self.conv1x1(self.conv4(z)) + + return z + + +if __name__ == "__main__": + svtrRNN = EncoderWithSVTR(56) + print(svtrRNN) diff --git a/examples/research_projects/anytext/ocr_recog/RecCTCHead.py b/examples/research_projects/anytext/ocr_recog/RecCTCHead.py new file mode 100755 index 000000000000..c066c6202b19 --- /dev/null +++ b/examples/research_projects/anytext/ocr_recog/RecCTCHead.py @@ -0,0 +1,45 @@ +from torch import nn + + +class CTCHead(nn.Module): + def __init__( + self, in_channels, out_channels=6625, fc_decay=0.0004, mid_channels=None, return_feats=False, **kwargs + ): + super(CTCHead, self).__init__() + if mid_channels is None: + self.fc = nn.Linear( + in_channels, + out_channels, + bias=True, + ) + else: + self.fc1 = nn.Linear( + in_channels, + mid_channels, + bias=True, + ) + self.fc2 = nn.Linear( + mid_channels, + out_channels, + bias=True, + ) + + self.out_channels = out_channels + self.mid_channels = mid_channels + self.return_feats = return_feats + + def forward(self, x, labels=None): + if self.mid_channels is None: + predicts = self.fc(x) + else: + x = self.fc1(x) + predicts = self.fc2(x) + + if self.return_feats: + result = {} + result["ctc"] = predicts + result["ctc_neck"] = x + else: + result = predicts + + return result diff --git a/examples/research_projects/anytext/ocr_recog/RecModel.py b/examples/research_projects/anytext/ocr_recog/RecModel.py new file mode 100755 index 000000000000..872ccade69e0 --- /dev/null +++ b/examples/research_projects/anytext/ocr_recog/RecModel.py @@ -0,0 +1,49 @@ +from torch import nn + +from .RecCTCHead import CTCHead +from .RecMv1_enhance import MobileNetV1Enhance +from .RNN import Im2Im, Im2Seq, SequenceEncoder + + +backbone_dict = {"MobileNetV1Enhance": MobileNetV1Enhance} +neck_dict = {"SequenceEncoder": SequenceEncoder, "Im2Seq": Im2Seq, "None": Im2Im} +head_dict = {"CTCHead": CTCHead} + + +class RecModel(nn.Module): + def __init__(self, config): + super().__init__() + assert "in_channels" in config, "in_channels must in model config" + backbone_type = config["backbone"].pop("type") + assert backbone_type in backbone_dict, f"backbone.type must in {backbone_dict}" + self.backbone = backbone_dict[backbone_type](config["in_channels"], **config["backbone"]) + + neck_type = config["neck"].pop("type") + assert neck_type in neck_dict, f"neck.type must in {neck_dict}" + self.neck = neck_dict[neck_type](self.backbone.out_channels, **config["neck"]) + + head_type = config["head"].pop("type") + assert head_type in head_dict, f"head.type must in {head_dict}" + self.head = head_dict[head_type](self.neck.out_channels, **config["head"]) + + self.name = f"RecModel_{backbone_type}_{neck_type}_{head_type}" + + def load_3rd_state_dict(self, _3rd_name, _state): + self.backbone.load_3rd_state_dict(_3rd_name, _state) + self.neck.load_3rd_state_dict(_3rd_name, _state) + self.head.load_3rd_state_dict(_3rd_name, _state) + + def forward(self, x): + import torch + + x = x.to(torch.float32) + x = self.backbone(x) + x = self.neck(x) + x = self.head(x) + return x + + def encode(self, x): + x = self.backbone(x) + x = self.neck(x) + x = self.head.ctc_encoder(x) + return x diff --git a/examples/research_projects/anytext/ocr_recog/RecMv1_enhance.py b/examples/research_projects/anytext/ocr_recog/RecMv1_enhance.py new file mode 100644 index 000000000000..df41519b2713 --- /dev/null +++ b/examples/research_projects/anytext/ocr_recog/RecMv1_enhance.py @@ -0,0 +1,197 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .common import Activation + + +class ConvBNLayer(nn.Module): + def __init__( + self, num_channels, filter_size, num_filters, stride, padding, channels=None, num_groups=1, act="hard_swish" + ): + super(ConvBNLayer, self).__init__() + self.act = act + self._conv = nn.Conv2d( + in_channels=num_channels, + out_channels=num_filters, + kernel_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + bias=False, + ) + + self._batch_norm = nn.BatchNorm2d( + num_filters, + ) + if self.act is not None: + self._act = Activation(act_type=act, inplace=True) + + def forward(self, inputs): + y = self._conv(inputs) + y = self._batch_norm(y) + if self.act is not None: + y = self._act(y) + return y + + +class DepthwiseSeparable(nn.Module): + def __init__( + self, num_channels, num_filters1, num_filters2, num_groups, stride, scale, dw_size=3, padding=1, use_se=False + ): + super(DepthwiseSeparable, self).__init__() + self.use_se = use_se + self._depthwise_conv = ConvBNLayer( + num_channels=num_channels, + num_filters=int(num_filters1 * scale), + filter_size=dw_size, + stride=stride, + padding=padding, + num_groups=int(num_groups * scale), + ) + if use_se: + self._se = SEModule(int(num_filters1 * scale)) + self._pointwise_conv = ConvBNLayer( + num_channels=int(num_filters1 * scale), + filter_size=1, + num_filters=int(num_filters2 * scale), + stride=1, + padding=0, + ) + + def forward(self, inputs): + y = self._depthwise_conv(inputs) + if self.use_se: + y = self._se(y) + y = self._pointwise_conv(y) + return y + + +class MobileNetV1Enhance(nn.Module): + def __init__(self, in_channels=3, scale=0.5, last_conv_stride=1, last_pool_type="max", **kwargs): + super().__init__() + self.scale = scale + self.block_list = [] + + self.conv1 = ConvBNLayer( + num_channels=in_channels, filter_size=3, channels=3, num_filters=int(32 * scale), stride=2, padding=1 + ) + + conv2_1 = DepthwiseSeparable( + num_channels=int(32 * scale), num_filters1=32, num_filters2=64, num_groups=32, stride=1, scale=scale + ) + self.block_list.append(conv2_1) + + conv2_2 = DepthwiseSeparable( + num_channels=int(64 * scale), num_filters1=64, num_filters2=128, num_groups=64, stride=1, scale=scale + ) + self.block_list.append(conv2_2) + + conv3_1 = DepthwiseSeparable( + num_channels=int(128 * scale), num_filters1=128, num_filters2=128, num_groups=128, stride=1, scale=scale + ) + self.block_list.append(conv3_1) + + conv3_2 = DepthwiseSeparable( + num_channels=int(128 * scale), + num_filters1=128, + num_filters2=256, + num_groups=128, + stride=(2, 1), + scale=scale, + ) + self.block_list.append(conv3_2) + + conv4_1 = DepthwiseSeparable( + num_channels=int(256 * scale), num_filters1=256, num_filters2=256, num_groups=256, stride=1, scale=scale + ) + self.block_list.append(conv4_1) + + conv4_2 = DepthwiseSeparable( + num_channels=int(256 * scale), + num_filters1=256, + num_filters2=512, + num_groups=256, + stride=(2, 1), + scale=scale, + ) + self.block_list.append(conv4_2) + + for _ in range(5): + conv5 = DepthwiseSeparable( + num_channels=int(512 * scale), + num_filters1=512, + num_filters2=512, + num_groups=512, + stride=1, + dw_size=5, + padding=2, + scale=scale, + use_se=False, + ) + self.block_list.append(conv5) + + conv5_6 = DepthwiseSeparable( + num_channels=int(512 * scale), + num_filters1=512, + num_filters2=1024, + num_groups=512, + stride=(2, 1), + dw_size=5, + padding=2, + scale=scale, + use_se=True, + ) + self.block_list.append(conv5_6) + + conv6 = DepthwiseSeparable( + num_channels=int(1024 * scale), + num_filters1=1024, + num_filters2=1024, + num_groups=1024, + stride=last_conv_stride, + dw_size=5, + padding=2, + use_se=True, + scale=scale, + ) + self.block_list.append(conv6) + + self.block_list = nn.Sequential(*self.block_list) + if last_pool_type == "avg": + self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) + else: + self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + self.out_channels = int(1024 * scale) + + def forward(self, inputs): + y = self.conv1(inputs) + y = self.block_list(y) + y = self.pool(y) + return y + + +def hardsigmoid(x): + return F.relu6(x + 3.0, inplace=True) / 6.0 + + +class SEModule(nn.Module): + def __init__(self, channel, reduction=4): + super(SEModule, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv1 = nn.Conv2d( + in_channels=channel, out_channels=channel // reduction, kernel_size=1, stride=1, padding=0, bias=True + ) + self.conv2 = nn.Conv2d( + in_channels=channel // reduction, out_channels=channel, kernel_size=1, stride=1, padding=0, bias=True + ) + + def forward(self, inputs): + outputs = self.avg_pool(inputs) + outputs = self.conv1(outputs) + outputs = F.relu(outputs) + outputs = self.conv2(outputs) + outputs = hardsigmoid(outputs) + x = torch.mul(inputs, outputs) + + return x diff --git a/examples/research_projects/anytext/ocr_recog/RecSVTR.py b/examples/research_projects/anytext/ocr_recog/RecSVTR.py new file mode 100644 index 000000000000..590a96995b26 --- /dev/null +++ b/examples/research_projects/anytext/ocr_recog/RecSVTR.py @@ -0,0 +1,570 @@ +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional +from torch.nn.init import ones_, trunc_normal_, zeros_ + + +def drop_path(x, drop_prob=0.0, training=False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = torch.tensor(1 - drop_prob) + shape = (x.size()[0],) + (1,) * (x.ndim - 1) + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype) + random_tensor = torch.floor(random_tensor) # binarize + output = x.divide(keep_prob) * random_tensor + return output + + +class Swish(nn.Module): + def __int__(self): + super(Swish, self).__int__() + + def forward(self, x): + return x * torch.sigmoid(x) + + +class ConvBNLayer(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias_attr=False, groups=1, act=nn.GELU + ): + super().__init__() + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + # weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()), + bias=bias_attr, + ) + self.norm = nn.BatchNorm2d(out_channels) + self.act = act() + + def forward(self, inputs): + out = self.conv(inputs) + out = self.norm(out) + out = self.act(out) + return out + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Identity(nn.Module): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, input): + return input + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + if isinstance(act_layer, str): + self.act = Swish() + else: + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class ConvMixer(nn.Module): + def __init__( + self, + dim, + num_heads=8, + HW=(8, 25), + local_k=(3, 3), + ): + super().__init__() + self.HW = HW + self.dim = dim + self.local_mixer = nn.Conv2d( + dim, + dim, + local_k, + 1, + (local_k[0] // 2, local_k[1] // 2), + groups=num_heads, + # weight_attr=ParamAttr(initializer=KaimingNormal()) + ) + + def forward(self, x): + h = self.HW[0] + w = self.HW[1] + x = x.transpose([0, 2, 1]).reshape([0, self.dim, h, w]) + x = self.local_mixer(x) + x = x.flatten(2).transpose([0, 2, 1]) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + mixer="Global", + HW=(8, 25), + local_k=(7, 11), + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.HW = HW + if HW is not None: + H = HW[0] + W = HW[1] + self.N = H * W + self.C = dim + if mixer == "Local" and HW is not None: + hk = local_k[0] + wk = local_k[1] + mask = torch.ones([H * W, H + hk - 1, W + wk - 1]) + for h in range(0, H): + for w in range(0, W): + mask[h * W + w, h : h + hk, w : w + wk] = 0.0 + mask_paddle = mask[:, hk // 2 : H + hk // 2, wk // 2 : W + wk // 2].flatten(1) + mask_inf = torch.full([H * W, H * W], fill_value=float("-inf")) + mask = torch.where(mask_paddle < 1, mask_paddle, mask_inf) + self.mask = mask[None, None, :] + # self.mask = mask.unsqueeze([0, 1]) + self.mixer = mixer + + def forward(self, x): + if self.HW is not None: + N = self.N + C = self.C + else: + _, N, C = x.shape + qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C // self.num_heads)).permute((2, 0, 3, 1, 4)) + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + + attn = q.matmul(k.permute((0, 1, 3, 2))) + if self.mixer == "Local": + attn += self.mask + attn = functional.softmax(attn, dim=-1) + attn = self.attn_drop(attn) + + x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C)) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mixer="Global", + local_mixer=(7, 11), + HW=(8, 25), + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer="nn.LayerNorm", + epsilon=1e-6, + prenorm=True, + ): + super().__init__() + if isinstance(norm_layer, str): + self.norm1 = eval(norm_layer)(dim, eps=epsilon) + else: + self.norm1 = norm_layer(dim) + if mixer == "Global" or mixer == "Local": + self.mixer = Attention( + dim, + num_heads=num_heads, + mixer=mixer, + HW=HW, + local_k=local_mixer, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + elif mixer == "Conv": + self.mixer = ConvMixer(dim, num_heads=num_heads, HW=HW, local_k=local_mixer) + else: + raise TypeError("The mixer must be one of [Global, Local, Conv]") + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity() + if isinstance(norm_layer, str): + self.norm2 = eval(norm_layer)(dim, eps=epsilon) + else: + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp_ratio = mlp_ratio + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.prenorm = prenorm + + def forward(self, x): + if self.prenorm: + x = self.norm1(x + self.drop_path(self.mixer(x))) + x = self.norm2(x + self.drop_path(self.mlp(x))) + else: + x = x + self.drop_path(self.mixer(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding""" + + def __init__(self, img_size=(32, 100), in_channels=3, embed_dim=768, sub_num=2): + super().__init__() + num_patches = (img_size[1] // (2**sub_num)) * (img_size[0] // (2**sub_num)) + self.img_size = img_size + self.num_patches = num_patches + self.embed_dim = embed_dim + self.norm = None + if sub_num == 2: + self.proj = nn.Sequential( + ConvBNLayer( + in_channels=in_channels, + out_channels=embed_dim // 2, + kernel_size=3, + stride=2, + padding=1, + act=nn.GELU, + bias_attr=False, + ), + ConvBNLayer( + in_channels=embed_dim // 2, + out_channels=embed_dim, + kernel_size=3, + stride=2, + padding=1, + act=nn.GELU, + bias_attr=False, + ), + ) + if sub_num == 3: + self.proj = nn.Sequential( + ConvBNLayer( + in_channels=in_channels, + out_channels=embed_dim // 4, + kernel_size=3, + stride=2, + padding=1, + act=nn.GELU, + bias_attr=False, + ), + ConvBNLayer( + in_channels=embed_dim // 4, + out_channels=embed_dim // 2, + kernel_size=3, + stride=2, + padding=1, + act=nn.GELU, + bias_attr=False, + ), + ConvBNLayer( + in_channels=embed_dim // 2, + out_channels=embed_dim, + kernel_size=3, + stride=2, + padding=1, + act=nn.GELU, + bias_attr=False, + ), + ) + + def forward(self, x): + B, C, H, W = x.shape + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).permute(0, 2, 1) + return x + + +class SubSample(nn.Module): + def __init__(self, in_channels, out_channels, types="Pool", stride=(2, 1), sub_norm="nn.LayerNorm", act=None): + super().__init__() + self.types = types + if types == "Pool": + self.avgpool = nn.AvgPool2d(kernel_size=(3, 5), stride=stride, padding=(1, 2)) + self.maxpool = nn.MaxPool2d(kernel_size=(3, 5), stride=stride, padding=(1, 2)) + self.proj = nn.Linear(in_channels, out_channels) + else: + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + # weight_attr=ParamAttr(initializer=KaimingNormal()) + ) + self.norm = eval(sub_norm)(out_channels) + if act is not None: + self.act = act() + else: + self.act = None + + def forward(self, x): + if self.types == "Pool": + x1 = self.avgpool(x) + x2 = self.maxpool(x) + x = (x1 + x2) * 0.5 + out = self.proj(x.flatten(2).permute((0, 2, 1))) + else: + x = self.conv(x) + out = x.flatten(2).permute((0, 2, 1)) + out = self.norm(out) + if self.act is not None: + out = self.act(out) + + return out + + +class SVTRNet(nn.Module): + def __init__( + self, + img_size=[48, 100], + in_channels=3, + embed_dim=[64, 128, 256], + depth=[3, 6, 3], + num_heads=[2, 4, 8], + mixer=["Local"] * 6 + ["Global"] * 6, # Local atten, Global atten, Conv + local_mixer=[[7, 11], [7, 11], [7, 11]], + patch_merging="Conv", # Conv, Pool, None + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + last_drop=0.1, + attn_drop_rate=0.0, + drop_path_rate=0.1, + norm_layer="nn.LayerNorm", + sub_norm="nn.LayerNorm", + epsilon=1e-6, + out_channels=192, + out_char_num=25, + block_unit="Block", + act="nn.GELU", + last_stage=True, + sub_num=2, + prenorm=True, + use_lenhead=False, + **kwargs, + ): + super().__init__() + self.img_size = img_size + self.embed_dim = embed_dim + self.out_channels = out_channels + self.prenorm = prenorm + patch_merging = None if patch_merging != "Conv" and patch_merging != "Pool" else patch_merging + self.patch_embed = PatchEmbed( + img_size=img_size, in_channels=in_channels, embed_dim=embed_dim[0], sub_num=sub_num + ) + num_patches = self.patch_embed.num_patches + self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)] + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0])) + # self.pos_embed = self.create_parameter( + # shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_) + + # self.add_parameter("pos_embed", self.pos_embed) + + self.pos_drop = nn.Dropout(p=drop_rate) + Block_unit = eval(block_unit) + + dpr = np.linspace(0, drop_path_rate, sum(depth)) + self.blocks1 = nn.ModuleList( + [ + Block_unit( + dim=embed_dim[0], + num_heads=num_heads[0], + mixer=mixer[0 : depth[0]][i], + HW=self.HW, + local_mixer=local_mixer[0], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=eval(act), + attn_drop=attn_drop_rate, + drop_path=dpr[0 : depth[0]][i], + norm_layer=norm_layer, + epsilon=epsilon, + prenorm=prenorm, + ) + for i in range(depth[0]) + ] + ) + if patch_merging is not None: + self.sub_sample1 = SubSample( + embed_dim[0], embed_dim[1], sub_norm=sub_norm, stride=[2, 1], types=patch_merging + ) + HW = [self.HW[0] // 2, self.HW[1]] + else: + HW = self.HW + self.patch_merging = patch_merging + self.blocks2 = nn.ModuleList( + [ + Block_unit( + dim=embed_dim[1], + num_heads=num_heads[1], + mixer=mixer[depth[0] : depth[0] + depth[1]][i], + HW=HW, + local_mixer=local_mixer[1], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=eval(act), + attn_drop=attn_drop_rate, + drop_path=dpr[depth[0] : depth[0] + depth[1]][i], + norm_layer=norm_layer, + epsilon=epsilon, + prenorm=prenorm, + ) + for i in range(depth[1]) + ] + ) + if patch_merging is not None: + self.sub_sample2 = SubSample( + embed_dim[1], embed_dim[2], sub_norm=sub_norm, stride=[2, 1], types=patch_merging + ) + HW = [self.HW[0] // 4, self.HW[1]] + else: + HW = self.HW + self.blocks3 = nn.ModuleList( + [ + Block_unit( + dim=embed_dim[2], + num_heads=num_heads[2], + mixer=mixer[depth[0] + depth[1] :][i], + HW=HW, + local_mixer=local_mixer[2], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=eval(act), + attn_drop=attn_drop_rate, + drop_path=dpr[depth[0] + depth[1] :][i], + norm_layer=norm_layer, + epsilon=epsilon, + prenorm=prenorm, + ) + for i in range(depth[2]) + ] + ) + self.last_stage = last_stage + if last_stage: + self.avg_pool = nn.AdaptiveAvgPool2d((1, out_char_num)) + self.last_conv = nn.Conv2d( + in_channels=embed_dim[2], + out_channels=self.out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.hardswish = nn.Hardswish() + self.dropout = nn.Dropout(p=last_drop) + if not prenorm: + self.norm = eval(norm_layer)(embed_dim[-1], epsilon=epsilon) + self.use_lenhead = use_lenhead + if use_lenhead: + self.len_conv = nn.Linear(embed_dim[2], self.out_channels) + self.hardswish_len = nn.Hardswish() + self.dropout_len = nn.Dropout(p=last_drop) + + trunc_normal_(self.pos_embed, std=0.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + zeros_(m.bias) + ones_(m.weight) + + def forward_features(self, x): + x = self.patch_embed(x) + x = x + self.pos_embed + x = self.pos_drop(x) + for blk in self.blocks1: + x = blk(x) + if self.patch_merging is not None: + x = self.sub_sample1(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[0], self.HW[0], self.HW[1]])) + for blk in self.blocks2: + x = blk(x) + if self.patch_merging is not None: + x = self.sub_sample2(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1]])) + for blk in self.blocks3: + x = blk(x) + if not self.prenorm: + x = self.norm(x) + return x + + def forward(self, x): + x = self.forward_features(x) + if self.use_lenhead: + len_x = self.len_conv(x.mean(1)) + len_x = self.dropout_len(self.hardswish_len(len_x)) + if self.last_stage: + if self.patch_merging is not None: + h = self.HW[0] // 4 + else: + h = self.HW[0] + x = self.avg_pool(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[2], h, self.HW[1]])) + x = self.last_conv(x) + x = self.hardswish(x) + x = self.dropout(x) + if self.use_lenhead: + return x, len_x + return x + + +if __name__ == "__main__": + a = torch.rand(1, 3, 48, 100) + svtr = SVTRNet() + + out = svtr(a) + print(svtr) + print(out.size()) diff --git a/examples/research_projects/anytext/ocr_recog/common.py b/examples/research_projects/anytext/ocr_recog/common.py new file mode 100644 index 000000000000..207a95b17d0e --- /dev/null +++ b/examples/research_projects/anytext/ocr_recog/common.py @@ -0,0 +1,74 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Hswish(nn.Module): + def __init__(self, inplace=True): + super(Hswish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0 + + +# out = max(0, min(1, slop*x+offset)) +# paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None) +class Hsigmoid(nn.Module): + def __init__(self, inplace=True): + super(Hsigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + # torch: F.relu6(x + 3., inplace=self.inplace) / 6. + # paddle: F.relu6(1.2 * x + 3., inplace=self.inplace) / 6. + return F.relu6(1.2 * x + 3.0, inplace=self.inplace) / 6.0 + + +class GELU(nn.Module): + def __init__(self, inplace=True): + super(GELU, self).__init__() + self.inplace = inplace + + def forward(self, x): + return torch.nn.functional.gelu(x) + + +class Swish(nn.Module): + def __init__(self, inplace=True): + super(Swish, self).__init__() + self.inplace = inplace + + def forward(self, x): + if self.inplace: + x.mul_(torch.sigmoid(x)) + return x + else: + return x * torch.sigmoid(x) + + +class Activation(nn.Module): + def __init__(self, act_type, inplace=True): + super(Activation, self).__init__() + act_type = act_type.lower() + if act_type == "relu": + self.act = nn.ReLU(inplace=inplace) + elif act_type == "relu6": + self.act = nn.ReLU6(inplace=inplace) + elif act_type == "sigmoid": + raise NotImplementedError + elif act_type == "hard_sigmoid": + self.act = Hsigmoid(inplace) + elif act_type == "hard_swish": + self.act = Hswish(inplace=inplace) + elif act_type == "leakyrelu": + self.act = nn.LeakyReLU(inplace=inplace) + elif act_type == "gelu": + self.act = GELU(inplace=inplace) + elif act_type == "swish": + self.act = Swish(inplace=inplace) + else: + raise NotImplementedError + + def forward(self, inputs): + return self.act(inputs) diff --git a/examples/research_projects/anytext/ocr_recog/en_dict.txt b/examples/research_projects/anytext/ocr_recog/en_dict.txt new file mode 100644 index 000000000000..7677d31b9d3f --- /dev/null +++ b/examples/research_projects/anytext/ocr_recog/en_dict.txt @@ -0,0 +1,95 @@ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +: +; +< += +> +? +@ +A +B +C +D +E +F +G +H +I +J +K +L +M +N +O +P +Q +R +S +T +U +V +W +X +Y +Z +[ +\ +] +^ +_ +` +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +{ +| +} +~ +! +" +# +$ +% +& +' +( +) +* ++ +, +- +. +/ + diff --git a/examples/research_projects/controlnet/train_controlnet_webdataset.py b/examples/research_projects/controlnet/train_controlnet_webdataset.py index 765bb495062e..829b0031156e 100644 --- a/examples/research_projects/controlnet/train_controlnet_webdataset.py +++ b/examples/research_projects/controlnet/train_controlnet_webdataset.py @@ -381,9 +381,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step) validation_prompt = log["validation_prompt"] validation_image = log["validation_image"] - formatted_images = [] - - formatted_images.append(np.asarray(validation_image)) + formatted_images = [np.asarray(validation_image)] for image in images: formatted_images.append(np.asarray(image)) diff --git a/examples/research_projects/pixart/train_pixart_controlnet_hf.py b/examples/research_projects/pixart/train_pixart_controlnet_hf.py index 995a20dfa28e..67ec30da0ece 100644 --- a/examples/research_projects/pixart/train_pixart_controlnet_hf.py +++ b/examples/research_projects/pixart/train_pixart_controlnet_hf.py @@ -164,9 +164,7 @@ def log_validation( validation_prompt = log["validation_prompt"] validation_image = log["validation_image"] - formatted_images = [] - - formatted_images.append(np.asarray(validation_image)) + formatted_images = [np.asarray(validation_image)] for image in images: formatted_images.append(np.asarray(image)) diff --git a/examples/research_projects/pytorch_xla/inference/flux/README.md b/examples/research_projects/pytorch_xla/inference/flux/README.md index dd7e23c57049..9d482e6805a3 100644 --- a/examples/research_projects/pytorch_xla/inference/flux/README.md +++ b/examples/research_projects/pytorch_xla/inference/flux/README.md @@ -1,8 +1,6 @@ # Generating images using Flux and PyTorch/XLA -The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation. - -It has been tested on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested. +The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation using custom flash block sizes for better performance on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested. ## Create TPU @@ -23,20 +21,23 @@ Verify that PyTorch and PyTorch/XLA were installed correctly: python3 -c "import torch; import torch_xla;" ``` -Install dependencies +Clone the diffusers repo and install dependencies ```bash +git clone https://github.com/huggingface/diffusers.git +cd diffusers pip install transformers accelerate sentencepiece structlog -pushd ../../.. pip install . -popd +cd examples/research_projects/pytorch_xla/inference/flux/ ``` ## Run the inference job ### Authenticate -Run the following command to authenticate your token in order to download Flux weights. +**Gated Model** + +As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows youโ€™ve accepted the gate. Use the command below to log in: ```bash huggingface-cli login @@ -50,51 +51,116 @@ python flux_inference.py The script loads the text encoders onto the CPU and the Flux transformer and VAE models onto the TPU. The first time the script runs, the compilation time is longer, while the cache stores the compiled programs. On subsequent runs, compilation is much faster and the subsequent passes being the fastest. -On a Trillium v6e-4, you should expect ~9 sec / 4 images or 2.25 sec / image (as devices run generation in parallel): +On a Trillium v6e-4, you should expect ~6 sec / 4 images or 1.5 sec / image (as devices run generation in parallel): ```bash WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU. -Loading checkpoint shards: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 2/2 [00:00<00:00, 7.01it/s] -Loading pipeline components...: 40%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ– | 2/5 [00:00<00:00, 3.78it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers -Loading pipeline components...: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 5/5 [00:00<00:00, 6.72it/s] -2025-01-10 00:51:25 [info ] loading flux from black-forest-labs/FLUX.1-dev -2025-01-10 00:51:25 [info ] loading flux from black-forest-labs/FLUX.1-dev -2025-01-10 00:51:26 [info ] loading flux from black-forest-labs/FLUX.1-dev -2025-01-10 00:51:26 [info ] loading flux from black-forest-labs/FLUX.1-dev -Loading pipeline components...: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 3/3 [00:00<00:00, 4.29it/s] -Loading pipeline components...: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 3/3 [00:00<00:00, 3.26it/s] -Loading pipeline components...: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 3/3 [00:00<00:00, 3.27it/s] -Loading pipeline components...: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 3/3 [00:00<00:00, 3.25it/s] -2025-01-10 00:51:34 [info ] starting compilation run... -2025-01-10 00:51:35 [info ] starting compilation run... -2025-01-10 00:51:37 [info ] starting compilation run... -2025-01-10 00:51:37 [info ] starting compilation run... -2025-01-10 00:52:52 [info ] compilation took 78.5155531649998 sec. -2025-01-10 00:52:53 [info ] starting inference run... -2025-01-10 00:52:57 [info ] compilation took 79.52986721400157 sec. -2025-01-10 00:52:57 [info ] compilation took 81.91776501700042 sec. -2025-01-10 00:52:57 [info ] compilation took 80.24951512600092 sec. -2025-01-10 00:52:57 [info ] starting inference run... -2025-01-10 00:52:57 [info ] starting inference run... -2025-01-10 00:52:58 [info ] starting inference run... -2025-01-10 00:53:22 [info ] inference time: 25.112665320000815 -2025-01-10 00:53:30 [info ] inference time: 7.7019307739992655 -2025-01-10 00:53:38 [info ] inference time: 7.693858365000779 -2025-01-10 00:53:46 [info ] inference time: 7.690621814001133 -2025-01-10 00:53:53 [info ] inference time: 7.679490454000188 -2025-01-10 00:54:01 [info ] inference time: 7.68949568500102 -2025-01-10 00:54:09 [info ] inference time: 7.686633744000574 -2025-01-10 00:54:16 [info ] inference time: 7.696786873999372 -2025-01-10 00:54:24 [info ] inference time: 7.691988694999964 -2025-01-10 00:54:32 [info ] inference time: 7.700649563999832 -2025-01-10 00:54:39 [info ] inference time: 7.684993574001055 -2025-01-10 00:54:47 [info ] inference time: 7.68343457499941 -2025-01-10 00:54:55 [info ] inference time: 7.667921153999487 -2025-01-10 00:55:02 [info ] inference time: 7.683585194001353 -2025-01-10 00:55:06 [info ] avg. inference over 15 iterations took 8.61202360273334 sec. -2025-01-10 00:55:07 [info ] avg. inference over 15 iterations took 8.952725123600006 sec. -2025-01-10 00:55:10 [info ] inference time: 7.673799695001435 -2025-01-10 00:55:10 [info ] avg. inference over 15 iterations took 8.849190365400379 sec. -2025-01-10 00:55:10 [info ] saved metric information as /tmp/metrics_report.txt -2025-01-10 00:55:12 [info ] avg. inference over 15 iterations took 8.940161458400205 sec. +Loading checkpoint shards: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 2/2 [00:00<00:00, 7.06it/s] +Loading pipeline components...: 60%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ | 3/5 [00:00<00:00, 6.80it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers +Loading pipeline components...: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 5/5 [00:00<00:00, 6.28it/s] +2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev +2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev +Loading pipeline components...: 0%| | 0/3 [00:00 Dict[str, Any]: state_dict = saved_dict @@ -104,12 +126,16 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: def convert_transformer( ckpt_path: str, dtype: torch.dtype, + version: str = "0.9.0", ): PREFIX_KEY = "model.diffusion_model." original_state_dict = get_state_dict(load_file(ckpt_path)) + config = {} + if version == "0.9.5": + config["_use_causal_rope_fix"] = True with init_empty_weights(): - transformer = LTXVideoTransformer3DModel() + transformer = LTXVideoTransformer3DModel(**config) for key in list(original_state_dict.keys()): new_key = key[:] @@ -161,12 +187,19 @@ def get_vae_config(version: str) -> Dict[str, Any]: "out_channels": 3, "latent_channels": 128, "block_out_channels": (128, 256, 512, 512), + "down_block_types": ( + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + ), "decoder_block_out_channels": (128, 256, 512, 512), "layers_per_block": (4, 3, 3, 3, 4), "decoder_layers_per_block": (4, 3, 3, 3, 4), "spatio_temporal_scaling": (True, True, True, False), "decoder_spatio_temporal_scaling": (True, True, True, False), "decoder_inject_noise": (False, False, False, False, False), + "downsample_type": ("conv", "conv", "conv", "conv"), "upsample_residual": (False, False, False, False), "upsample_factor": (1, 1, 1, 1), "patch_size": 4, @@ -183,12 +216,19 @@ def get_vae_config(version: str) -> Dict[str, Any]: "out_channels": 3, "latent_channels": 128, "block_out_channels": (128, 256, 512, 512), + "down_block_types": ( + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + ), "decoder_block_out_channels": (256, 512, 1024), "layers_per_block": (4, 3, 3, 3, 4), "decoder_layers_per_block": (5, 6, 7, 8), "spatio_temporal_scaling": (True, True, True, False), "decoder_spatio_temporal_scaling": (True, True, True), "decoder_inject_noise": (True, True, True, False), + "downsample_type": ("conv", "conv", "conv", "conv"), "upsample_residual": (True, True, True), "upsample_factor": (2, 2, 2), "timestep_conditioning": True, @@ -200,7 +240,38 @@ def get_vae_config(version: str) -> Dict[str, Any]: "decoder_causal": False, } VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT) - VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP) + elif version == "0.9.5": + config = { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "block_out_channels": (128, 256, 512, 1024, 2048), + "down_block_types": ( + "LTXVideo095DownBlock3D", + "LTXVideo095DownBlock3D", + "LTXVideo095DownBlock3D", + "LTXVideo095DownBlock3D", + ), + "decoder_block_out_channels": (256, 512, 1024), + "layers_per_block": (4, 6, 6, 2, 2), + "decoder_layers_per_block": (5, 5, 5, 5), + "spatio_temporal_scaling": (True, True, True, True), + "decoder_spatio_temporal_scaling": (True, True, True), + "decoder_inject_noise": (False, False, False, False), + "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + "timestep_conditioning": True, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-6, + "scaling_factor": 1.0, + "encoder_causal": True, + "decoder_causal": False, + "spatial_compression_ratio": 32, + "temporal_compression_ratio": 8, + } + VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT) return config @@ -223,7 +294,7 @@ def get_args(): parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.") parser.add_argument( - "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model" + "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model" ) return parser.parse_args() @@ -277,14 +348,17 @@ def get_args(): for param in text_encoder.parameters(): param.data = param.data.contiguous() - scheduler = FlowMatchEulerDiscreteScheduler( - use_dynamic_shifting=True, - base_shift=0.95, - max_shift=2.05, - base_image_seq_len=1024, - max_image_seq_len=4096, - shift_terminal=0.1, - ) + if args.version == "0.9.5": + scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False) + else: + scheduler = FlowMatchEulerDiscreteScheduler( + use_dynamic_shifting=True, + base_shift=0.95, + max_shift=2.05, + base_image_seq_len=1024, + max_image_seq_len=4096, + shift_terminal=0.1, + ) pipe = LTXPipeline( scheduler=scheduler, diff --git a/scripts/convert_lumina_to_diffusers.py b/scripts/convert_lumina_to_diffusers.py index a12625d1376f..c14aad3c6bf2 100644 --- a/scripts/convert_lumina_to_diffusers.py +++ b/scripts/convert_lumina_to_diffusers.py @@ -5,7 +5,7 @@ from safetensors.torch import load_file from transformers import AutoModel, AutoTokenizer -from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaPipeline def main(args): @@ -115,7 +115,7 @@ def main(args): tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b") text_encoder = AutoModel.from_pretrained("google/gemma-2b") - pipeline = LuminaText2ImgPipeline( + pipeline = LuminaPipeline( tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler ) pipeline.save_pretrained(args.dump_path) diff --git a/setup.py b/setup.py index 93945ae040dd..fdc166a81ecf 100644 --- a/setup.py +++ b/setup.py @@ -128,6 +128,10 @@ "GitPython<3.1.19", "scipy", "onnx", + "optimum_quanto>=0.2.6", + "gguf>=0.10.0", + "torchao>=0.7.0", + "bitsandbytes>=0.43.3", "regex!=2019.12.17", "requests", "tensorboard", @@ -235,6 +239,11 @@ def run(self): ) extras["torch"] = deps_list("torch", "accelerate") +extras["bitsandbytes"] = deps_list("bitsandbytes", "accelerate") +extras["gguf"] = deps_list("gguf", "accelerate") +extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate") +extras["torchao"] = deps_list("torchao", "accelerate") + if os.name == "nt": # windows extras["flax"] = [] # jax is not supported on windows else: diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index d5cfad915e3c..ad658f1b14ff 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -6,14 +6,19 @@ DIFFUSERS_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, + is_accelerate_available, + is_bitsandbytes_available, is_flax_available, + is_gguf_available, is_k_diffusion_available, is_librosa_available, is_note_seq_available, is_onnx_available, + is_optimum_quanto_available, is_scipy_available, is_sentencepiece_available, is_torch_available, + is_torchao_available, is_torchsde_available, is_transformers_available, ) @@ -32,7 +37,7 @@ "loaders": ["FromOriginalModelMixin"], "models": [], "pipelines": [], - "quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"], + "quantizers.quantization_config": [], "schedulers": [], "utils": [ "OptionalDependencyNotAvailable", @@ -54,6 +59,54 @@ ], } +try: + if not is_torch_available() and not is_accelerate_available() and not is_bitsandbytes_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_bitsandbytes_objects + + _import_structure["utils.dummy_bitsandbytes_objects"] = [ + name for name in dir(dummy_bitsandbytes_objects) if not name.startswith("_") + ] +else: + _import_structure["quantizers.quantization_config"].append("BitsAndBytesConfig") + +try: + if not is_torch_available() and not is_accelerate_available() and not is_gguf_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_gguf_objects + + _import_structure["utils.dummy_gguf_objects"] = [ + name for name in dir(dummy_gguf_objects) if not name.startswith("_") + ] +else: + _import_structure["quantizers.quantization_config"].append("GGUFQuantizationConfig") + +try: + if not is_torch_available() and not is_accelerate_available() and not is_torchao_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_torchao_objects + + _import_structure["utils.dummy_torchao_objects"] = [ + name for name in dir(dummy_torchao_objects) if not name.startswith("_") + ] +else: + _import_structure["quantizers.quantization_config"].append("TorchAoConfig") + +try: + if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_optimum_quanto_objects + + _import_structure["utils.dummy_optimum_quanto_objects"] = [ + name for name in dir(dummy_optimum_quanto_objects) if not name.startswith("_") + ] +else: + _import_structure["quantizers.quantization_config"].append("QuantoConfig") + try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() @@ -292,6 +345,7 @@ "CogVideoXPipeline", "CogVideoXVideoToVideoPipeline", "CogView3PlusPipeline", + "CogView4ControlPipeline", "CogView4Pipeline", "ConsisIDPipeline", "CycleDiffusionPipeline", @@ -348,9 +402,12 @@ "LDMTextToImagePipeline", "LEditsPPPipelineStableDiffusion", "LEditsPPPipelineStableDiffusionXL", + "LTXConditionPipeline", "LTXImageToVideoPipeline", "LTXPipeline", + "Lumina2Pipeline", "Lumina2Text2ImgPipeline", + "LuminaPipeline", "LuminaText2ImgPipeline", "MarigoldDepthPipeline", "MarigoldIntrinsicsPipeline", @@ -599,7 +656,38 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .configuration_utils import ConfigMixin - from .quantizers.quantization_config import BitsAndBytesConfig, GGUFQuantizationConfig, TorchAoConfig + + try: + if not is_bitsandbytes_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_bitsandbytes_objects import * + else: + from .quantizers.quantization_config import BitsAndBytesConfig + + try: + if not is_gguf_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_gguf_objects import * + else: + from .quantizers.quantization_config import GGUFQuantizationConfig + + try: + if not is_torchao_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_torchao_objects import * + else: + from .quantizers.quantization_config import TorchAoConfig + + try: + if not is_optimum_quanto_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_optimum_quanto_objects import * + else: + from .quantizers.quantization_config import QuantoConfig try: if not is_onnx_available(): @@ -803,6 +891,7 @@ CogVideoXPipeline, CogVideoXVideoToVideoPipeline, CogView3PlusPipeline, + CogView4ControlPipeline, CogView4Pipeline, ConsisIDPipeline, CycleDiffusionPipeline, @@ -859,9 +948,12 @@ LDMTextToImagePipeline, LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, + LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline, + Lumina2Pipeline, Lumina2Text2ImgPipeline, + LuminaPipeline, LuminaText2ImgPipeline, MarigoldDepthPipeline, MarigoldIntrinsicsPipeline, diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 17d5da60347d..8ec95ed6fc8d 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -35,6 +35,10 @@ "GitPython": "GitPython<3.1.19", "scipy": "scipy", "onnx": "onnx", + "optimum_quanto": "optimum_quanto>=0.2.6", + "gguf": "gguf>=0.10.0", + "torchao": "torchao>=0.7.0", + "bitsandbytes": "bitsandbytes>=0.43.3", "regex": "regex!=2019.12.17", "requests": "requests", "tensorboard": "tensorboard", diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index c389c5dc9826..e4b9ed9307ea 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -83,7 +83,10 @@ def onload_(self): with context: for group_module in self.modules: - group_module.to(self.onload_device, non_blocking=self.non_blocking) + for param in group_module.parameters(): + param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking) + for buffer in group_module.buffers(): + buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking) if self.parameters is not None: for param in self.parameters: param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking) @@ -98,6 +101,12 @@ def offload_(self): for group_module in self.modules: for param in group_module.parameters(): param.data = self.cpu_param_dict[param] + if self.parameters is not None: + for param in self.parameters: + param.data = self.cpu_param_dict[param] + if self.buffers is not None: + for buffer in self.buffers: + buffer.data = self.cpu_param_dict[buffer] else: for group_module in self.modules: group_module.to(self.offload_device, non_blocking=self.non_blocking) @@ -172,6 +181,13 @@ def __init__(self): self._layer_execution_tracker_module_names = set() def initialize_hook(self, module): + def make_execution_order_update_callback(current_name, current_submodule): + def callback(): + logger.debug(f"Adding {current_name} to the execution order") + self.execution_order.append((current_name, current_submodule)) + + return callback + # To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any # of the groups), we add a layer execution tracker hook that will be used to determine the order in which the # layers are executed during the forward pass. @@ -183,14 +199,8 @@ def initialize_hook(self, module): group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING) if group_offloading_hook is not None: - - def make_execution_order_update_callback(current_name, current_submodule): - def callback(): - logger.debug(f"Adding {current_name} to the execution order") - self.execution_order.append((current_name, current_submodule)) - - return callback - + # For the first forward pass, we have to load in a blocking manner + group_offloading_hook.group.non_blocking = False layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule)) registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER) self._layer_execution_tracker_module_names.add(name) @@ -220,6 +230,7 @@ def post_forward(self, module, output): # Remove the layer execution tracker hooks from the submodules base_module_registry = module._diffusers_hook registries = [submodule._diffusers_hook for _, submodule in self.execution_order] + group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries] for i in range(num_executed): registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False) @@ -227,8 +238,13 @@ def post_forward(self, module, output): # Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False) - # Apply lazy prefetching by setting required attributes - group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries] + # LazyPrefetchGroupOffloadingHook is only used with streams, so we know that non_blocking should be True. + # We disable non_blocking for the first forward pass, but need to enable it for the subsequent passes to + # see the benefits of prefetching. + for hook in group_offloading_hooks: + hook.group.non_blocking = True + + # Set required attributes for prefetching if num_executed > 0: base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING) base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group @@ -387,9 +403,7 @@ def _apply_group_offloading_block_level( # Create a pinned CPU parameter dict for async data transfer if streams are to be used cpu_param_dict = None if stream is not None: - for param in module.parameters(): - param.data = param.data.cpu().pin_memory() - cpu_param_dict = {param: param.data for param in module.parameters()} + cpu_param_dict = _get_pinned_cpu_param_dict(module) # Create module groups for ModuleList and Sequential blocks modules_with_group_offloading = set() @@ -486,9 +500,7 @@ def _apply_group_offloading_leaf_level( # Create a pinned CPU parameter dict for async data transfer if streams are to be used cpu_param_dict = None if stream is not None: - for param in module.parameters(): - param.data = param.data.cpu().pin_memory() - cpu_param_dict = {param: param.data for param in module.parameters()} + cpu_param_dict = _get_pinned_cpu_param_dict(module) # Create module groups for leaf modules and apply group offloading hooks modules_with_group_offloading = set() @@ -604,6 +616,17 @@ def _apply_lazy_group_offloading_hook( registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING) +def _get_pinned_cpu_param_dict(module: torch.nn.Module) -> Dict[torch.nn.Parameter, torch.Tensor]: + cpu_param_dict = {} + for param in module.parameters(): + param.data = param.data.cpu().pin_memory() + cpu_param_dict[param] = param.data + for buffer in module.buffers(): + buffer.data = buffer.data.cpu().pin_memory() + cpu_param_dict[buffer] = buffer.data + return cpu_param_dict + + def _gather_parameters_with_no_group_offloading_parent( module: torch.nn.Module, modules_with_group_offloading: Set[str] ) -> List[torch.nn.Parameter]: diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 86ffffd7d5df..3ba1bfacf3dd 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -70,6 +70,7 @@ def text_encoder_attn_modules(text_encoder): "LoraLoaderMixin", "FluxLoraLoaderMixin", "CogVideoXLoraLoaderMixin", + "CogView4LoraLoaderMixin", "Mochi1LoraLoaderMixin", "HunyuanVideoLoraLoaderMixin", "SanaLoraLoaderMixin", @@ -103,6 +104,7 @@ def text_encoder_attn_modules(text_encoder): from .lora_pipeline import ( AmusedLoraLoaderMixin, CogVideoXLoraLoaderMixin, + CogView4LoraLoaderMixin, FluxLoraLoaderMixin, HunyuanVideoLoraLoaderMixin, LoraLoaderMixin, diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index ac0a3c635332..21a1a70ff79b 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -804,9 +804,7 @@ def load_ip_adapter( } self.register_modules( - feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs).to( - self.device, dtype=self.dtype - ), + feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs), image_encoder=SiglipVisionModel.from_pretrained( image_encoder_subfolder, torch_dtype=self.dtype, **kwargs ).to(self.device), diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 50b6448ecdca..17ed8c5444fc 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -339,93 +339,97 @@ def _load_lora_into_text_encoder( # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as # their prefixes. - keys = list(state_dict.keys()) prefix = text_encoder_name if prefix is None else prefix - # Safe prefix to check with. - if any(text_encoder_name in key for key in keys): - # Load the layers corresponding to text encoder and make necessary adjustments. - text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] - text_encoder_lora_state_dict = { - k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys - } - - if len(text_encoder_lora_state_dict) > 0: - logger.info(f"Loading {prefix}.") - rank = {} - text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) - - # convert state dict - text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) - - for name, _ in text_encoder_attn_modules(text_encoder): - for module in ("out_proj", "q_proj", "k_proj", "v_proj"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - for name, _ in text_encoder_mlp_modules(text_encoder): - for module in ("fc1", "fc2"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - if network_alphas is not None: - alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix] - network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} - - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) - - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"]: - if is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<", "0.9.0"): - lora_config_kwargs.pop("use_dora") - - if "lora_bias" in lora_config_kwargs: - if lora_config_kwargs["lora_bias"]: - if is_peft_version("<=", "0.13.2"): - raise ValueError( - "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<=", "0.13.2"): - lora_config_kwargs.pop("lora_bias") - - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(text_encoder) + # Load the layers corresponding to text encoder and make necessary adjustments. + if prefix is not None: + state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} + + if len(state_dict) > 0: + logger.info(f"Loading {prefix}.") + rank = {} + state_dict = convert_state_dict_to_diffusers(state_dict) + + # convert state dict + state_dict = convert_state_dict_to_peft(state_dict) + + for name, _ in text_encoder_attn_modules(text_encoder): + for module in ("out_proj", "q_proj", "k_proj", "v_proj"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in state_dict: + continue + rank[rank_key] = state_dict[rank_key].shape[1] + + for name, _ in text_encoder_mlp_modules(text_encoder): + for module in ("fc1", "fc2"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in state_dict: + continue + rank[rank_key] = state_dict[rank_key].shape[1] + + if network_alphas is not None: + alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix] + network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} + + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False) + + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<", "0.9.0"): + lora_config_kwargs.pop("use_dora") + + if "lora_bias" in lora_config_kwargs: + if lora_config_kwargs["lora_bias"]: + if is_peft_version("<=", "0.13.2"): + raise ValueError( + "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<=", "0.13.2"): + lora_config_kwargs.pop("lora_bias") - is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline) + lora_config = LoraConfig(**lora_config_kwargs) - # inject LoRA layers and load the state dict - # in transformers we automatically check whether the adapter name is already in use or not - text_encoder.load_adapter( - adapter_name=adapter_name, - adapter_state_dict=text_encoder_lora_state_dict, - peft_config=lora_config, - **peft_kwargs, - ) + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(text_encoder) - # scale LoRA layers with `lora_scale` - scale_lora_layers(text_encoder, weight=lora_scale) + is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline) - text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) + # inject LoRA layers and load the state dict + # in transformers we automatically check whether the adapter name is already in use or not + text_encoder.load_adapter( + adapter_name=adapter_name, + adapter_state_dict=state_dict, + peft_config=lora_config, + **peft_kwargs, + ) - # Offload back. - if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> + # scale LoRA layers with `lora_scale` + scale_lora_layers(text_encoder, weight=lora_scale) + + text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + + if prefix is not None and not state_dict: + logger.warning( + f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. " + "This is safe to ignore if LoRA state dict didn't originally have any " + f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` " + "to resolve the warning. Otherwise, open an issue if you think it's unexpected: " + "https://github.com/huggingface/diffusers/issues/new" + ) def _func_optionally_disable_offloading(_pipeline): diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 4be6971755d2..20fcb61f3b80 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -1348,3 +1348,56 @@ def process_block(prefix, index, convert_norm): converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) return converted_state_dict + + +def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): + converted_state_dict = {} + original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()} + + num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict}) + is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict) + + for i in range(num_blocks): + # Self-attention + for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]): + converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop( + f"blocks.{i}.self_attn.{o}.lora_A.weight" + ) + converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop( + f"blocks.{i}.self_attn.{o}.lora_B.weight" + ) + + # Cross-attention + for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]): + converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop( + f"blocks.{i}.cross_attn.{o}.lora_A.weight" + ) + converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop( + f"blocks.{i}.cross_attn.{o}.lora_B.weight" + ) + + if is_i2v_lora: + for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]): + converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop( + f"blocks.{i}.cross_attn.{o}.lora_A.weight" + ) + converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop( + f"blocks.{i}.cross_attn.{o}.lora_B.weight" + ) + + # FFN + for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]): + converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop( + f"blocks.{i}.{o}.lora_A.weight" + ) + converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop( + f"blocks.{i}.{o}.lora_B.weight" + ) + + if len(original_state_dict) > 0: + raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}") + + for key in list(converted_state_dict.keys()): + converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) + + return converted_state_dict diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index e48725b01ca2..160793ba1b58 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -42,6 +42,7 @@ _convert_kohya_flux_lora_to_diffusers, _convert_non_diffusers_lora_to_diffusers, _convert_non_diffusers_lumina2_lora_to_diffusers, + _convert_non_diffusers_wan_lora_to_diffusers, _convert_xlabs_flux_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers, ) @@ -298,19 +299,15 @@ def load_lora_into_unet( # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as # their prefixes. - keys = list(state_dict.keys()) - only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys) - if not only_text_encoder: - # Load the layers corresponding to UNet. - logger.info(f"Loading {cls.unet_name}.") - unet.load_lora_adapter( - state_dict, - prefix=cls.unet_name, - network_alphas=network_alphas, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + logger.info(f"Loading {cls.unet_name}.") + unet.load_lora_adapter( + state_dict, + prefix=cls.unet_name, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod def load_lora_into_text_encoder( @@ -455,7 +452,11 @@ def fuse_lora( ``` """ super().fuse_lora( - components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, ) def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs): @@ -476,7 +477,7 @@ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the LoRA parameters then it won't have any effect. """ - super().unfuse_lora(components=components) + super().unfuse_lora(components=components, **kwargs) class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin): @@ -559,31 +560,26 @@ def load_lora_weights( _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, ) - text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} - if len(text_encoder_state_dict) > 0: - self.load_lora_into_text_encoder( - text_encoder_state_dict, - network_alphas=network_alphas, - text_encoder=self.text_encoder, - prefix="text_encoder", - lora_scale=self.lora_scale, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - - text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} - if len(text_encoder_2_state_dict) > 0: - self.load_lora_into_text_encoder( - text_encoder_2_state_dict, - network_alphas=network_alphas, - text_encoder=self.text_encoder_2, - prefix="text_encoder_2", - lora_scale=self.lora_scale, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + self.load_lora_into_text_encoder( + state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder, + prefix=self.text_encoder_name, + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + self.load_lora_into_text_encoder( + state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder_2, + prefix=f"{self.text_encoder_name}_2", + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod @validate_hf_hub_args @@ -738,19 +734,15 @@ def load_lora_into_unet( # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as # their prefixes. - keys = list(state_dict.keys()) - only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys) - if not only_text_encoder: - # Load the layers corresponding to UNet. - logger.info(f"Loading {cls.unet_name}.") - unet.load_lora_adapter( - state_dict, - prefix=cls.unet_name, - network_alphas=network_alphas, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + logger.info(f"Loading {cls.unet_name}.") + unet.load_lora_adapter( + state_dict, + prefix=cls.unet_name, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder @@ -904,7 +896,11 @@ def fuse_lora( ``` """ super().fuse_lora( - components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, ) def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs): @@ -925,7 +921,7 @@ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_enc Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the LoRA parameters then it won't have any effect. """ - super().unfuse_lora(components=components) + super().unfuse_lora(components=components, **kwargs) class SD3LoraLoaderMixin(LoraBaseMixin): @@ -1085,43 +1081,33 @@ def load_lora_weights( if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") - transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k} - if len(transformer_state_dict) > 0: - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) - if not hasattr(self, "transformer") - else self.transformer, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - - text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} - if len(text_encoder_state_dict) > 0: - self.load_lora_into_text_encoder( - text_encoder_state_dict, - network_alphas=None, - text_encoder=self.text_encoder, - prefix="text_encoder", - lora_scale=self.lora_scale, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - - text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} - if len(text_encoder_2_state_dict) > 0: - self.load_lora_into_text_encoder( - text_encoder_2_state_dict, - network_alphas=None, - text_encoder=self.text_encoder_2, - prefix="text_encoder_2", - lora_scale=self.lora_scale, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + self.load_lora_into_text_encoder( + state_dict, + network_alphas=None, + text_encoder=self.text_encoder, + prefix=self.text_encoder_name, + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + self.load_lora_into_text_encoder( + state_dict, + network_alphas=None, + text_encoder=self.text_encoder_2, + prefix=f"{self.text_encoder_name}_2", + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod def load_lora_into_transformer( @@ -1313,7 +1299,11 @@ def fuse_lora( ``` """ super().fuse_lora( - components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, ) # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer @@ -1335,7 +1325,7 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the LoRA parameters then it won't have any effect. """ - super().unfuse_lora(components=components) + super().unfuse_lora(components=components, **kwargs) class FluxLoraLoaderMixin(LoraBaseMixin): @@ -1541,18 +1531,23 @@ def load_lora_weights( raise ValueError("Invalid LoRA checkpoint.") transformer_lora_state_dict = { - k: state_dict.pop(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k + k: state_dict.get(k) + for k in list(state_dict.keys()) + if k.startswith(f"{self.transformer_name}.") and "lora" in k } transformer_norm_state_dict = { k: state_dict.pop(k) for k in list(state_dict.keys()) - if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys) + if k.startswith(f"{self.transformer_name}.") + and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys) } transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer - has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_( - transformer, transformer_lora_state_dict, transformer_norm_state_dict - ) + has_param_with_expanded_shape = False + if len(transformer_lora_state_dict) > 0: + has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_( + transformer, transformer_lora_state_dict, transformer_norm_state_dict + ) if has_param_with_expanded_shape: logger.info( @@ -1560,19 +1555,21 @@ def load_lora_weights( "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. " "To get a comprehensive list of parameter names that were modified, enable debug logging." ) - transformer_lora_state_dict = self._maybe_expand_lora_state_dict( - transformer=transformer, lora_state_dict=transformer_lora_state_dict - ) - if len(transformer_lora_state_dict) > 0: - self.load_lora_into_transformer( - transformer_lora_state_dict, - network_alphas=network_alphas, - transformer=transformer, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, + transformer_lora_state_dict = self._maybe_expand_lora_state_dict( + transformer=transformer, lora_state_dict=transformer_lora_state_dict ) + for k in transformer_lora_state_dict: + state_dict.update({k: transformer_lora_state_dict[k]}) + + self.load_lora_into_transformer( + state_dict, + network_alphas=network_alphas, + transformer=transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) if len(transformer_norm_state_dict) > 0: transformer._transformer_norm_layers = self._load_norm_into_transformer( @@ -1581,18 +1578,16 @@ def load_lora_weights( discard_original_layers=False, ) - text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} - if len(text_encoder_state_dict) > 0: - self.load_lora_into_text_encoder( - text_encoder_state_dict, - network_alphas=network_alphas, - text_encoder=self.text_encoder, - prefix="text_encoder", - lora_scale=self.lora_scale, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + self.load_lora_into_text_encoder( + state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder, + prefix=self.text_encoder_name, + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod def load_lora_into_transformer( @@ -1625,17 +1620,14 @@ def load_lora_into_transformer( ) # Load the layers corresponding to transformer. - keys = list(state_dict.keys()) - transformer_present = any(key.startswith(cls.transformer_name) for key in keys) - if transformer_present: - logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=network_alphas, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod def _load_norm_into_transformer( @@ -1849,7 +1841,11 @@ def fuse_lora( ) super().fuse_lora( - components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, ) def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): @@ -1870,7 +1866,7 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], * if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers: transformer.load_state_dict(transformer._transformer_norm_layers, strict=False) - super().unfuse_lora(components=components) + super().unfuse_lora(components=components, **kwargs) # We override this here account for `_transformer_norm_layers` and `_overwritten_params`. def unload_lora_weights(self, reset_to_overwritten_params=False): @@ -2174,17 +2170,14 @@ def load_lora_into_transformer( ) # Load the layers corresponding to transformer. - keys = list(state_dict.keys()) - transformer_present = any(key.startswith(cls.transformer_name) for key in keys) - if transformer_present: - logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=network_alphas, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder @@ -2572,7 +2565,11 @@ def fuse_lora( ``` """ super().fuse_lora( - components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, ) def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): @@ -2590,7 +2587,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. """ - super().unfuse_lora(components=components) + super().unfuse_lora(components=components, **kwargs) class Mochi1LoraLoaderMixin(LoraBaseMixin): @@ -2876,7 +2873,11 @@ def fuse_lora( ``` """ super().fuse_lora( - components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora @@ -2895,7 +2896,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. """ - super().unfuse_lora(components=components) + super().unfuse_lora(components=components, **kwargs) class LTXVideoLoraLoaderMixin(LoraBaseMixin): @@ -3181,7 +3182,11 @@ def fuse_lora( ``` """ super().fuse_lora( - components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora @@ -3200,7 +3205,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. """ - super().unfuse_lora(components=components) + super().unfuse_lora(components=components, **kwargs) class SanaLoraLoaderMixin(LoraBaseMixin): @@ -3486,7 +3491,11 @@ def fuse_lora( ``` """ super().fuse_lora( - components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora @@ -3505,7 +3514,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. """ - super().unfuse_lora(components=components) + super().unfuse_lora(components=components, **kwargs) class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): @@ -3794,7 +3803,11 @@ def fuse_lora( ``` """ super().fuse_lora( - components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora @@ -3813,7 +3826,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. """ - super().unfuse_lora(components=components) + super().unfuse_lora(components=components, **kwargs) class Lumina2LoraLoaderMixin(LoraBaseMixin): @@ -4103,7 +4116,11 @@ def fuse_lora( ``` """ super().fuse_lora( - components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, ) # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora @@ -4122,7 +4139,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. """ - super().unfuse_lora(components=components) + super().unfuse_lora(components=components, **kwargs) class WanLoraLoaderMixin(LoraBaseMixin): @@ -4135,7 +4152,6 @@ class WanLoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], @@ -4222,6 +4238,8 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + if any(k.startswith("diffusion_model.") for k in state_dict): + state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -4408,7 +4426,320 @@ def fuse_lora( ``` """ super().fuse_lora( - components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + """ + super().unfuse_lora(components=components, **kwargs) + + +class CogView4LoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`CogView4Pipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + return state_dict + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + def load_lora_weights( + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See + [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel + def load_lora_into_transformer( + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + ): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`CogView4Transformer2DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not transformer_lora_layers: + raise ValueError("You must pass `transformer_lora_layers`.") + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora + def fuse_lora( + self, + components: List[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora @@ -4427,7 +4758,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. """ - super().unfuse_lora(components=components) + super().unfuse_lora(components=components, **kwargs) class LoraLoaderMixin(StableDiffusionLoraLoaderMixin): diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index aaa2fd4108b1..74e51445cc1e 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -54,6 +54,7 @@ "SanaTransformer2DModel": lambda model_cls, weights: weights, "Lumina2Transformer2DModel": lambda model_cls, weights: weights, "WanTransformer3DModel": lambda model_cls, weights: weights, + "CogView4Transformer2DModel": lambda model_cls, weights: weights, } @@ -235,10 +236,7 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans raise ValueError("`network_alphas` cannot be None when `prefix` is None.") if prefix is not None: - keys = list(state_dict.keys()) - model_keys = [k for k in keys if k.startswith(f"{prefix}.")] - if len(model_keys) > 0: - state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys} + state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} if len(state_dict) > 0: if adapter_name in getattr(self, "peft_config", {}): @@ -355,6 +353,15 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans _pipeline.enable_sequential_cpu_offload() # Unsafe code /> + if prefix is not None and not state_dict: + logger.warning( + f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. " + "This is safe to ignore if LoRA state dict didn't originally have any " + f"{self.__class__.__name__} related params. You can also try specifying `prefix=None` " + "to resolve the warning. Otherwise, open an issue if you think it's unexpected: " + "https://github.com/huggingface/diffusers/issues/new" + ) + def save_lora_adapter( self, save_directory, diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index b7d61b3e8ff4..f72a0dd369f2 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -37,6 +37,7 @@ convert_ltx_vae_checkpoint_to_diffusers, convert_lumina2_to_diffusers, convert_mochi_transformer_checkpoint_to_diffusers, + convert_sana_transformer_to_diffusers, convert_sd3_transformer_checkpoint_to_diffusers, convert_stable_cascade_unet_single_file_to_diffusers, convert_wan_transformer_to_diffusers, @@ -119,6 +120,10 @@ "checkpoint_mapping_fn": convert_lumina2_to_diffusers, "default_subfolder": "transformer", }, + "SanaTransformer2DModel": { + "checkpoint_mapping_fn": convert_sana_transformer_to_diffusers, + "default_subfolder": "transformer", + }, "WanTransformer3DModel": { "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers, "default_subfolder": "transformer", diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 8ee7e14cb101..42aee4a84822 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -117,6 +117,12 @@ "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias", "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight", "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"], + "sana": [ + "blocks.0.cross_attn.q_linear.weight", + "blocks.0.cross_attn.q_linear.bias", + "blocks.0.cross_attn.kv_linear.weight", + "blocks.0.cross_attn.kv_linear.bias", + ], "wan": ["model.diffusion_model.head.modulation", "head.modulation"], "wan_vae": "decoder.middle.0.residual.0.gamma", } @@ -178,6 +184,7 @@ "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"}, "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"}, "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"}, + "sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"}, "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"}, "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"}, "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"}, @@ -669,6 +676,9 @@ def infer_diffusers_model_type(checkpoint): elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]): model_type = "lumina2" + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sana"]): + model_type = "sana" + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]): if "model.diffusion_model.patch_embedding.weight" in checkpoint: target_key = "model.diffusion_model.patch_embedding.weight" @@ -2897,6 +2907,111 @@ def convert_lumina_attn_to_diffusers(tensor, diffusers_key): return converted_state_dict +def convert_sana_transformer_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + keys = list(checkpoint.keys()) + for k in keys: + if "model.diffusion_model." in k: + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) + + num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1 # noqa: C401 + + # Positional and patch embeddings. + checkpoint.pop("pos_embed") + converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight") + converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias") + + # Timestep embeddings. + converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop( + "t_embedder.mlp.0.weight" + ) + converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias") + converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop( + "t_embedder.mlp.2.weight" + ) + converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") + converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight") + converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias") + + # Caption Projection. + checkpoint.pop("y_embedder.y_embedding") + converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight") + converted_state_dict["caption_projection.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias") + converted_state_dict["caption_projection.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight") + converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias") + converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight") + + for i in range(num_layers): + converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop( + f"blocks.{i}.scale_shift_table" + ) + + # Self-Attention + sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0) + converted_state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = torch.cat([sample_q]) + converted_state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = torch.cat([sample_k]) + converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v]) + + # Output Projections + converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop( + f"blocks.{i}.attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop( + f"blocks.{i}.attn.proj.bias" + ) + + # Cross-Attention + converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop( + f"blocks.{i}.cross_attn.q_linear.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop( + f"blocks.{i}.cross_attn.q_linear.bias" + ) + + linear_sample_k, linear_sample_v = torch.chunk( + checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0 + ) + linear_sample_k_bias, linear_sample_v_bias = torch.chunk( + checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0 + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k + converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v + converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias + converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias + + # Output Projections + converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop( + f"blocks.{i}.cross_attn.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop( + f"blocks.{i}.cross_attn.proj.bias" + ) + + # MLP + converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop( + f"blocks.{i}.mlp.inverted_conv.conv.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop( + f"blocks.{i}.mlp.inverted_conv.conv.bias" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop( + f"blocks.{i}.mlp.depth_conv.conv.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop( + f"blocks.{i}.mlp.depth_conv.conv.bias" + ) + converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop( + f"blocks.{i}.mlp.point_conv.conv.weight" + ) + + # Final layer + converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") + converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table") + + return converted_state_dict + + def convert_wan_transformer_to_diffusers(checkpoint, **kwargs): converted_state_dict = {} diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index b45cb2a7950d..21d17d6acdab 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -741,10 +741,14 @@ def prepare_attention_mask( if out_dim == 3: if attention_mask.shape[0] < batch_size * head_size: - attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + attention_mask = attention_mask.repeat_interleave( + head_size, dim=0, output_size=attention_mask.shape[0] * head_size + ) elif out_dim == 4: attention_mask = attention_mask.unsqueeze(1) - attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + attention_mask = attention_mask.repeat_interleave( + head_size, dim=1, output_size=attention_mask.shape[1] * head_size + ) return attention_mask @@ -2335,7 +2339,9 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -3704,8 +3710,10 @@ def __call__( if kv_heads != attn.heads: # if GQA or MQA, repeat the key/value heads to reach the number of query heads. heads_per_kv_head = attn.heads // kv_heads - key = torch.repeat_interleave(key, heads_per_kv_head, dim=1) - value = torch.repeat_interleave(value, heads_per_kv_head, dim=1) + key = torch.repeat_interleave(key, heads_per_kv_head, dim=1, output_size=key.shape[1] * heads_per_kv_head) + value = torch.repeat_interleave( + value, heads_per_kv_head, dim=1, output_size=value.shape[1] * heads_per_kv_head + ) if attn.norm_q is not None: query = attn.norm_q(query) diff --git a/src/diffusers/models/autoencoders/autoencoder_dc.py b/src/diffusers/models/autoencoders/autoencoder_dc.py index 1e6a26dddca8..9146aa5c7c6c 100644 --- a/src/diffusers/models/autoencoders/autoencoder_dc.py +++ b/src/diffusers/models/autoencoders/autoencoder_dc.py @@ -190,7 +190,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: x = F.pixel_shuffle(x, self.factor) if self.shortcut: - y = hidden_states.repeat_interleave(self.repeats, dim=1) + y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats) y = F.pixel_shuffle(y, self.factor) hidden_states = x + y else: @@ -361,7 +361,9 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.in_shortcut: - x = hidden_states.repeat_interleave(self.in_shortcut_repeats, dim=1) + x = hidden_states.repeat_interleave( + self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats + ) hidden_states = self.conv_in(hidden_states) + x else: hidden_states = self.conv_in(hidden_states) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py index f79aabe91dd3..a76277366c09 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py @@ -103,7 +103,7 @@ def forward(self, hidden_states: torch.Tensor, batch_size: int) -> torch.Tensor: if self.down_sample: identity = hidden_states[:, :, ::2] elif self.up_sample: - identity = hidden_states.repeat_interleave(2, dim=2) + identity = hidden_states.repeat_interleave(2, dim=2, output_size=hidden_states.shape[2] * 2) else: identity = hidden_states diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index 75709ca10dfe..2b2f77a5509d 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -196,6 +196,55 @@ def forward( return hidden_states +class LTXVideoDownsampler3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + stride: Union[int, Tuple[int, int, int]] = 1, + is_causal: bool = True, + padding_mode: str = "zeros", + ) -> None: + super().__init__() + + self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + self.group_size = (in_channels * stride[0] * stride[1] * stride[2]) // out_channels + + out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2]) + + self.conv = LTXVideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + is_causal=is_causal, + padding_mode=padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2) + + residual = ( + hidden_states.unflatten(4, (-1, self.stride[2])) + .unflatten(3, (-1, self.stride[1])) + .unflatten(2, (-1, self.stride[0])) + ) + residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) + residual = residual.unflatten(1, (-1, self.group_size)) + residual = residual.mean(dim=2) + + hidden_states = self.conv(hidden_states) + hidden_states = ( + hidden_states.unflatten(4, (-1, self.stride[2])) + .unflatten(3, (-1, self.stride[1])) + .unflatten(2, (-1, self.stride[0])) + ) + hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) + hidden_states = hidden_states + residual + + return hidden_states + + class LTXVideoUpsampler3d(nn.Module): def __init__( self, @@ -204,6 +253,7 @@ def __init__( is_causal: bool = True, residual: bool = False, upscale_factor: int = 1, + padding_mode: str = "zeros", ) -> None: super().__init__() @@ -219,6 +269,7 @@ def __init__( kernel_size=3, stride=1, is_causal=is_causal, + padding_mode=padding_mode, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -352,6 +403,118 @@ def forward( return hidden_states +class LTXVideo095DownBlock3D(nn.Module): + r""" + Down block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + spatio_temporal_scale (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + Whether or not to downsample across temporal dimension. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + spatio_temporal_scale: bool = True, + is_causal: bool = True, + downsample_type: str = "conv", + ): + super().__init__() + + out_channels = out_channels or in_channels + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTXVideoResnetBlock3d( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.downsamplers = None + if spatio_temporal_scale: + self.downsamplers = nn.ModuleList() + + if downsample_type == "conv": + self.downsamplers.append( + LTXVideoCausalConv3d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=(2, 2, 2), + is_causal=is_causal, + ) + ) + elif downsample_type == "spatial": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal + ) + ) + elif downsample_type == "temporal": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), is_causal=is_causal + ) + ) + elif downsample_type == "spatiotemporal": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal + ) + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + r"""Forward method of the `LTXDownBlock3D` class.""" + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + else: + hidden_states = resnet(hidden_states, temb, generator) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + # Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d class LTXVideoMidBlock3d(nn.Module): r""" @@ -593,8 +756,15 @@ def __init__( in_channels: int = 3, out_channels: int = 128, block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + down_block_types: Tuple[str, ...] = ( + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + ), spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False), layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4), + downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"), patch_size: int = 4, patch_size_t: int = 1, resnet_norm_eps: float = 1e-6, @@ -617,20 +787,37 @@ def __init__( ) # down blocks - num_block_out_channels = len(block_out_channels) + is_ltx_095 = down_block_types[-1] == "LTXVideo095DownBlock3D" + num_block_out_channels = len(block_out_channels) - (1 if is_ltx_095 else 0) self.down_blocks = nn.ModuleList([]) for i in range(num_block_out_channels): input_channel = output_channel - output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i] - - down_block = LTXVideoDownBlock3D( - in_channels=input_channel, - out_channels=output_channel, - num_layers=layers_per_block[i], - resnet_eps=resnet_norm_eps, - spatio_temporal_scale=spatio_temporal_scaling[i], - is_causal=is_causal, - ) + if not is_ltx_095: + output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i] + else: + output_channel = block_out_channels[i + 1] + + if down_block_types[i] == "LTXVideoDownBlock3D": + down_block = LTXVideoDownBlock3D( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + is_causal=is_causal, + ) + elif down_block_types[i] == "LTXVideo095DownBlock3D": + down_block = LTXVideo095DownBlock3D( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + is_causal=is_causal, + downsample_type=downsample_type[i], + ) + else: + raise ValueError(f"Unknown down block type: {down_block_types[i]}") self.down_blocks.append(down_block) @@ -794,7 +981,9 @@ def __init__( # timestep embedding self.time_embedder = None self.scale_shift_table = None + self.timestep_scale_multiplier = None if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32)) self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0) self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5) @@ -803,6 +992,9 @@ def __init__( def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: hidden_states = self.conv_in(hidden_states) + if self.timestep_scale_multiplier is not None: + temb = temb * self.timestep_scale_multiplier + if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb) @@ -891,12 +1083,19 @@ def __init__( out_channels: int = 3, latent_channels: int = 128, block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + down_block_types: Tuple[str, ...] = ( + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + ), decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4), decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4), spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False), decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False), decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False), + downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"), upsample_residual: Tuple[bool, ...] = (False, False, False, False), upsample_factor: Tuple[int, ...] = (1, 1, 1, 1), timestep_conditioning: bool = False, @@ -906,6 +1105,8 @@ def __init__( scaling_factor: float = 1.0, encoder_causal: bool = True, decoder_causal: bool = False, + spatial_compression_ratio: int = None, + temporal_compression_ratio: int = None, ) -> None: super().__init__() @@ -913,8 +1114,10 @@ def __init__( in_channels=in_channels, out_channels=latent_channels, block_out_channels=block_out_channels, + down_block_types=down_block_types, spatio_temporal_scaling=spatio_temporal_scaling, layers_per_block=layers_per_block, + downsample_type=downsample_type, patch_size=patch_size, patch_size_t=patch_size_t, resnet_norm_eps=resnet_norm_eps, @@ -941,8 +1144,16 @@ def __init__( self.register_buffer("latents_mean", latents_mean, persistent=True) self.register_buffer("latents_std", latents_std, persistent=True) - self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling) - self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling) + self.spatial_compression_ratio = ( + patch_size * 2 ** sum(spatio_temporal_scaling) + if spatial_compression_ratio is None + else spatial_compression_ratio + ) + self.temporal_compression_ratio = ( + patch_size_t * 2 ** sum(spatio_temporal_scaling) + if temporal_compression_ratio is None + else temporal_compression_ratio + ) # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension # to perform decoding of a single video latent at a time. diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py index cd3eff73ed64..d69ec6252b00 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py @@ -426,7 +426,9 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: w = w.repeat(num_channels)[None, :, None, None, None] # [1, num_channels * num_freqs, 1, 1, 1] # Interleaved repeat of input channels to match w - h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W] + h = inputs.repeat_interleave( + num_freqs, dim=1, output_size=inputs.shape[1] * num_freqs + ) # [B, C * num_freqs, T, H, W] # Scale channels by frequency. h = w * h diff --git a/src/diffusers/models/controlnets/controlnet_sparsectrl.py b/src/diffusers/models/controlnets/controlnet_sparsectrl.py index 4edc91cacaa7..25348ce606d6 100644 --- a/src/diffusers/models/controlnets/controlnet_sparsectrl.py +++ b/src/diffusers/models/controlnets/controlnet_sparsectrl.py @@ -687,7 +687,7 @@ def forward( t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) - emb = emb.repeat_interleave(sample_num_frames, dim=0) + emb = emb.repeat_interleave(sample_num_frames, dim=0, output_size=emb.shape[0] * sample_num_frames) # 2. pre-process batch_size, channels, num_frames, height, width = sample.shape diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 04a0b273f1fa..006ea8b4013f 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -139,7 +139,9 @@ def get_3d_sincos_pos_embed( # 3. Concat pos_embed_spatial = pos_embed_spatial[None, :, :] - pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0) # [T, H*W, D // 4 * 3] + pos_embed_spatial = pos_embed_spatial.repeat_interleave( + temporal_size, dim=0, output_size=pos_embed_spatial.shape[0] * temporal_size + ) # [T, H*W, D // 4 * 3] pos_embed_temporal = pos_embed_temporal[:, None, :] pos_embed_temporal = pos_embed_temporal.repeat_interleave( @@ -1152,10 +1154,13 @@ def get_1d_rotary_pos_embed( / linear_factor ) # [D/2] freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] + is_npu = freqs.device.type == "npu" + if is_npu: + freqs = freqs.float() if use_real and repeat_interleave_real: # flux, hunyuan-dit, cogvideox - freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] - freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] + freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D] return freqs_cos, freqs_sin elif use_real: # stable audio, allegro diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index f019a3cc67a6..741f7075d76d 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -245,6 +245,9 @@ def load_model_dict_into_meta( ): param = param.to(torch.float32) set_module_kwargs["dtype"] = torch.float32 + # For quantizers have save weights using torch.float8_e4m3fn + elif hf_quantizer is not None and param.dtype == getattr(torch, "float8_e4m3fn", None): + pass else: param = param.to(dtype) set_module_kwargs["dtype"] = dtype @@ -292,7 +295,9 @@ def load_model_dict_into_meta( elif is_quantized and ( hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device) ): - hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys) + hf_quantizer.create_quantized_param( + model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype + ) else: set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 00b55cd9c9d6..260b4b8929b0 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -366,7 +366,7 @@ def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwarg hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor) + input_tensor = self.conv_shortcut(input_tensor.contiguous()) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index 4fe1d99cb6ee..4b359021f29d 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -227,13 +227,17 @@ def forward( # Prepare text embeddings for spatial block # batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size encoder_hidden_states = self.caption_projection(encoder_hidden_states) # 3 120 1152 - encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view( - -1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1] - ) + encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave( + num_frame, dim=0, output_size=encoder_hidden_states.shape[0] * num_frame + ).view(-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]) # Prepare timesteps for spatial and temporal block - timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1]) - timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1]) + timestep_spatial = timestep.repeat_interleave( + num_frame, dim=0, output_size=timestep.shape[0] * num_frame + ).view(-1, timestep.shape[-1]) + timestep_temp = timestep.repeat_interleave( + num_patches, dim=0, output_size=timestep.shape[0] * num_patches + ).view(-1, timestep.shape[-1]) # Spatial and temporal transformer blocks for i, (spatial_block, temp_block) in enumerate( @@ -299,7 +303,9 @@ def forward( ).permute(0, 2, 1, 3) hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1]) - embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1]) + embedded_timestep = embedded_timestep.repeat_interleave( + num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame + ).view(-1, embedded_timestep.shape[-1]) shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) hidden_states = self.norm_out(hidden_states) # Modulation diff --git a/src/diffusers/models/transformers/prior_transformer.py b/src/diffusers/models/transformers/prior_transformer.py index fdb67384ff5e..24d4e4d3d76f 100644 --- a/src/diffusers/models/transformers/prior_transformer.py +++ b/src/diffusers/models/transformers/prior_transformer.py @@ -353,7 +353,11 @@ def forward( attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0) attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype) - attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0) + attention_mask = attention_mask.repeat_interleave( + self.config.num_attention_heads, + dim=0, + output_size=attention_mask.shape[0] * self.config.num_attention_heads, + ) if self.norm_in is not None: hidden_states = self.norm_in(hidden_states) diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index cface676b409..b8cc96d3532c 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -18,7 +18,7 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention_processor import ( Attention, @@ -195,7 +195,7 @@ def forward( return hidden_states -class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): +class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): r""" A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models. diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index db261ca1ea4b..41c4cbbf97c7 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -12,20 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...models.attention import FeedForward -from ...models.attention_processor import Attention -from ...models.modeling_utils import ModelMixin -from ...models.normalization import AdaLayerNormContinuous -from ...utils import logging +from ...loaders import PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ..attention import FeedForward +from ..attention_processor import Attention +from ..cache_utils import CacheMixin from ..embeddings import CogView3CombinedTimestepSizeEmbeddings from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -125,7 +127,8 @@ def __call__( attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: - text_seq_length = encoder_hidden_states.size(1) + batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape + batch_size, image_seq_length, embed_dim = hidden_states.shape hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) # 1. QKV projections @@ -155,6 +158,15 @@ def __call__( ) # 4. Attention + if attention_mask is not None: + text_attention_mask = attention_mask.float().to(query.device) + actual_text_seq_length = text_attention_mask.size(1) + new_attention_mask = torch.zeros((batch_size, text_seq_length + image_seq_length), device=query.device) + new_attention_mask[:, :actual_text_seq_length] = text_attention_mask + new_attention_mask = new_attention_mask.unsqueeze(2) + attention_mask_matrix = new_attention_mask @ new_attention_mask.transpose(1, 2) + attention_mask = (attention_mask_matrix > 0).unsqueeze(1).to(query.dtype) + hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) @@ -202,6 +214,8 @@ def forward( encoder_hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> torch.Tensor: # 1. Timestep conditioning ( @@ -222,6 +236,8 @@ def forward( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, + **kwargs, ) hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1) encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1) @@ -288,7 +304,7 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens return (freqs.cos(), freqs.sin()) -class CogView4Transformer2DModel(ModelMixin, ConfigMixin): +class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin): r""" Args: patch_size (`int`, defaults to `2`): @@ -383,8 +399,26 @@ def forward( original_size: torch.Tensor, target_size: torch.Tensor, crop_coords: torch.Tensor, + attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> Union[torch.Tensor, Transformer2DModelOutput]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + batch_size, num_channels, height, width = hidden_states.shape # 1. RoPE @@ -404,11 +438,11 @@ def forward( for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( - block, hidden_states, encoder_hidden_states, temb, image_rotary_emb + block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs ) else: hidden_states, encoder_hidden_states = block( - hidden_states, encoder_hidden_states, temb, image_rotary_emb + hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs ) # 4. Output norm & projection @@ -419,6 +453,10 @@ def forward( hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p) output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index f5dc63f49562..c1f2df587927 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -14,7 +14,7 @@ # limitations under the License. import math -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn @@ -113,20 +113,19 @@ def __init__( self.patch_size_t = patch_size_t self.theta = theta - def forward( + def _prepare_video_coords( self, - hidden_states: torch.Tensor, + batch_size: int, num_frames: int, height: int, width: int, - rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - batch_size = hidden_states.size(0) - + rope_interpolation_scale: Tuple[torch.Tensor, float, float], + device: torch.device, + ) -> torch.Tensor: # Always compute rope in fp32 - grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device) - grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device) - grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device) + grid_h = torch.arange(height, dtype=torch.float32, device=device) + grid_w = torch.arange(width, dtype=torch.float32, device=device) + grid_f = torch.arange(num_frames, dtype=torch.float32, device=device) grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij") grid = torch.stack(grid, dim=0) grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) @@ -138,6 +137,38 @@ def forward( grid = grid.flatten(2, 4).transpose(1, 2) + return grid + + def forward( + self, + hidden_states: torch.Tensor, + num_frames: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None, + video_coords: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = hidden_states.size(0) + + if video_coords is None: + grid = self._prepare_video_coords( + batch_size, + num_frames, + height, + width, + rope_interpolation_scale=rope_interpolation_scale, + device=hidden_states.device, + ) + else: + grid = torch.stack( + [ + video_coords[:, 0] / self.base_num_frames, + video_coords[:, 1] / self.base_height, + video_coords[:, 2] / self.base_width, + ], + dim=-1, + ) + start = 1.0 end = self.theta freqs = self.theta ** torch.linspace( @@ -367,10 +398,11 @@ def forward( encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, encoder_attention_mask: torch.Tensor, - num_frames: int, - height: int, - width: int, - rope_interpolation_scale: Optional[Tuple[float, float, float]] = None, + num_frames: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None, + video_coords: Optional[torch.Tensor] = None, attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> torch.Tensor: @@ -389,7 +421,7 @@ def forward( "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." ) - image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale) + image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale, video_coords) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 66cdda388c06..4eb4add37601 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -441,6 +441,14 @@ def forward( # 5. Output norm, projection & unpatchify shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) + + # Move the shift and scale tensors to the same device as hidden_states. + # When using multi-GPU inference via accelerate these will be on the + # first device rather than the last device, which hidden_states ends up + # on. + shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) hidden_states = self.proj_out(hidden_states) diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py index 845d93b9db09..a148cf6cbe06 100644 --- a/src/diffusers/models/unets/unet_3d_condition.py +++ b/src/diffusers/models/unets/unet_3d_condition.py @@ -638,8 +638,10 @@ def forward( t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb, timestep_cond) - emb = emb.repeat_interleave(repeats=num_frames, dim=0) - encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) + emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames) + encoder_hidden_states = encoder_hidden_states.repeat_interleave( + num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames + ) # 2. pre-process sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py index f0eca75de169..c275e16744f4 100644 --- a/src/diffusers/models/unets/unet_i2vgen_xl.py +++ b/src/diffusers/models/unets/unet_i2vgen_xl.py @@ -592,7 +592,7 @@ def forward( # 3. time + FPS embeddings. emb = t_emb + fps_emb - emb = emb.repeat_interleave(repeats=num_frames, dim=0) + emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames) # 4. context embeddings. # The context embeddings consist of both text embeddings from the input prompt @@ -620,7 +620,7 @@ def forward( image_emb = self.context_embedding(image_embeddings) image_emb = image_emb.view(-1, self.config.in_channels, self.config.cross_attention_dim) context_emb = torch.cat([context_emb, image_emb], dim=1) - context_emb = context_emb.repeat_interleave(repeats=num_frames, dim=0) + context_emb = context_emb.repeat_interleave(num_frames, dim=0, output_size=context_emb.shape[0] * num_frames) image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape( image_latents.shape[0] * image_latents.shape[2], diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 21e4db23a166..bd83024c9b7c 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -2059,7 +2059,7 @@ def forward( aug_emb = self.add_embedding(add_embeds) emb = emb if aug_emb is None else emb + aug_emb - emb = emb.repeat_interleave(repeats=num_frames, dim=0) + emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames) if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": if "image_embeds" not in added_cond_kwargs: @@ -2068,7 +2068,10 @@ def forward( ) image_embeds = added_cond_kwargs.get("image_embeds") image_embeds = self.encoder_hid_proj(image_embeds) - image_embeds = [image_embed.repeat_interleave(repeats=num_frames, dim=0) for image_embed in image_embeds] + image_embeds = [ + image_embed.repeat_interleave(num_frames, dim=0, output_size=image_embed.shape[0] * num_frames) + for image_embed in image_embeds + ] encoder_hidden_states = (encoder_hidden_states, image_embeds) # 2. pre-process diff --git a/src/diffusers/models/unets/unet_spatio_temporal_condition.py b/src/diffusers/models/unets/unet_spatio_temporal_condition.py index db4ace9656a3..059a6e807c8e 100644 --- a/src/diffusers/models/unets/unet_spatio_temporal_condition.py +++ b/src/diffusers/models/unets/unet_spatio_temporal_condition.py @@ -431,9 +431,11 @@ def forward( sample = sample.flatten(0, 1) # Repeat the embeddings num_video_frames times # emb: [batch, channels] -> [batch * frames, channels] - emb = emb.repeat_interleave(num_frames, dim=0) + emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames) # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels] - encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0) + encoder_hidden_states = encoder_hidden_states.repeat_interleave( + num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames + ) # 2. pre-process sample = self.conv_in(sample) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 8b76e109e754..6b714d31c0e3 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -154,7 +154,7 @@ "CogVideoXFunControlPipeline", ] _import_structure["cogview3"] = ["CogView3PlusPipeline"] - _import_structure["cogview4"] = ["CogView4Pipeline"] + _import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"] _import_structure["consisid"] = ["ConsisIDPipeline"] _import_structure["controlnet"].extend( [ @@ -264,9 +264,9 @@ ] ) _import_structure["latte"] = ["LattePipeline"] - _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"] - _import_structure["lumina"] = ["LuminaText2ImgPipeline"] - _import_structure["lumina2"] = ["Lumina2Text2ImgPipeline"] + _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline", "LTXConditionPipeline"] + _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] + _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["marigold"].extend( [ "MarigoldDepthPipeline", @@ -511,7 +511,7 @@ CogVideoXVideoToVideoPipeline, ) from .cogview3 import CogView3PlusPipeline - from .cogview4 import CogView4Pipeline + from .cogview4 import CogView4ControlPipeline, CogView4Pipeline from .consisid import ConsisIDPipeline from .controlnet import ( BlipDiffusionControlNetPipeline, @@ -618,9 +618,9 @@ LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, ) - from .ltx import LTXImageToVideoPipeline, LTXPipeline - from .lumina import LuminaText2ImgPipeline - from .lumina2 import Lumina2Text2ImgPipeline + from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline + from .lumina import LuminaPipeline, LuminaText2ImgPipeline + from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline from .marigold import ( MarigoldDepthPipeline, MarigoldIntrinsicsPipeline, diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 4f760ee09add..6a5f6098b6fb 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -22,7 +22,7 @@ from ..utils import is_sentencepiece_available from .aura_flow import AuraFlowPipeline from .cogview3 import CogView3PlusPipeline -from .cogview4 import CogView4Pipeline +from .cogview4 import CogView4ControlPipeline, CogView4Pipeline from .controlnet import ( StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, @@ -69,8 +69,8 @@ ) from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline -from .lumina import LuminaText2ImgPipeline -from .lumina2 import Lumina2Text2ImgPipeline +from .lumina import LuminaPipeline +from .lumina2 import Lumina2Pipeline from .pag import ( HunyuanDiTPAGPipeline, PixArtSigmaPAGPipeline, @@ -141,10 +141,11 @@ ("flux", FluxPipeline), ("flux-control", FluxControlPipeline), ("flux-controlnet", FluxControlNetPipeline), - ("lumina", LuminaText2ImgPipeline), - ("lumina2", Lumina2Text2ImgPipeline), + ("lumina", LuminaPipeline), + ("lumina2", Lumina2Pipeline), ("cogview3", CogView3PlusPipeline), ("cogview4", CogView4Pipeline), + ("cogview4-control", CogView4ControlPipeline), ] ) diff --git a/src/diffusers/pipelines/cogview4/__init__.py b/src/diffusers/pipelines/cogview4/__init__.py index 5a535b3feb4b..6a365e17fee7 100644 --- a/src/diffusers/pipelines/cogview4/__init__.py +++ b/src/diffusers/pipelines/cogview4/__init__.py @@ -23,6 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_cogview4"] = ["CogView4Pipeline"] + _import_structure["pipeline_cogview4_control"] = ["CogView4ControlPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): @@ -31,6 +32,7 @@ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .pipeline_cogview4 import CogView4Pipeline + from .pipeline_cogview4_control import CogView4ControlPipeline else: import sys diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py index 6005c419b5c2..c27a1a19774d 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py @@ -14,7 +14,7 @@ # limitations under the License. import inspect -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -22,6 +22,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import VaeImageProcessor +from ...loaders import CogView4LoraLoaderMixin from ...models import AutoencoderKL, CogView4Transformer2DModel from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -133,7 +134,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class CogView4Pipeline(DiffusionPipeline): +class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin): r""" Pipeline for text-to-image generation using CogView4. @@ -388,6 +389,14 @@ def do_classifier_free_guidance(self): def num_timesteps(self): return self._num_timesteps + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt @@ -413,6 +422,7 @@ def __call__( crops_coords_top_left: Tuple[int, int] = (0, 0), output_type: str = "pil", return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -526,6 +536,8 @@ def __call__( negative_prompt_embeds, ) self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None self._interrupt = False # Default call parameters @@ -603,6 +615,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = latents.to(transformer_dtype) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML @@ -615,6 +628,7 @@ def __call__( original_size=original_size, target_size=target_size, crop_coords=crops_coords_top_left, + attention_kwargs=attention_kwargs, return_dict=False, )[0] @@ -627,6 +641,7 @@ def __call__( original_size=original_size, target_size=target_size, crop_coords=crops_coords_top_left, + attention_kwargs=attention_kwargs, return_dict=False, )[0] @@ -652,6 +667,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + if not output_type == "latent": latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor image = self.vae.decode(latents, return_dict=False, generator=generator)[0] diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py new file mode 100644 index 000000000000..b22705ed05c9 --- /dev/null +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py @@ -0,0 +1,727 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import AutoTokenizer, GlmModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models import AutoencoderKL, CogView4Transformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import CogView4PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogView4ControlPipeline + + >>> pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B-Control", torch_dtype=torch.bfloat16) + >>> control_image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ... ) + >>> prompt = "A bird in space" + >>> image = pipe(prompt, control_image=control_image, height=1024, width=1024, guidance_scale=3.5).images[0] + >>> image.save("cogview4-control.png") + ``` +""" + + +# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + base_shift: float = 0.25, + max_shift: float = 0.75, +) -> float: + m = (image_seq_len / base_seq_len) ** 0.5 + mu = m * max_shift + base_shift + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class CogView4ControlPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using CogView4. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`GLMModel`]): + Frozen text-encoder. CogView4 uses [glm-4-9b-hf](https://huggingface.co/THUDM/glm-4-9b-hf). + tokenizer (`PreTrainedTokenizer`): + Tokenizer of class + [PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer). + transformer ([`CogView4Transformer2DModel`]): + A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: GlmModel, + vae: AutoencoderKL, + transformer: CogView4Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline._get_glm_embeds + def _get_glm_embeds( + self, + prompt: Union[str, List[str]] = None, + max_sequence_length: int = 1024, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer( + prompt, + padding="longest", # not use max length + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + current_length = text_input_ids.shape[1] + pad_length = (16 - (current_length % 16)) % 16 + if pad_length > 0: + pad_ids = torch.full( + (text_input_ids.shape[0], pad_length), + fill_value=self.tokenizer.pad_token_id, + dtype=text_input_ids.dtype, + device=text_input_ids.device, + ) + text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1) + prompt_embeds = self.text_encoder( + text_input_ids.to(self.text_encoder.device), output_hidden_states=True + ).hidden_states[-2] + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + return prompt_embeds + + # Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 1024, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of images that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + max_sequence_length (`int`, defaults to `1024`): + Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_glm_embeds(prompt, max_sequence_length, device, dtype) + + seq_len = prompt_embeds.size(1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_glm_embeds(negative_prompt, max_sequence_length, device, dtype) + + seq_len = negative_prompt_embeds.size(1) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + if latents is not None: + return latents.to(device) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0, output_size=image.shape[0] * repeat_by) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + control_image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 1024, + ) -> Union[CogView4PipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. If not provided, it is set to 1024. + width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. If not provided it is set to 1024. + num_inference_steps (`int`, *optional*, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to `1`): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `224`): + Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. + + Examples: + + Returns: + [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] or `tuple`: + [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = (height, width) + + # Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + self.do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + # Prepare latents + latent_channels = self.transformer.config.in_channels // 2 + + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image.shape[-2:] + + vae_shift_factor = 0 + + control_image = self.vae.encode(control_image).latent_dist.sample() + control_image = (control_image - vae_shift_factor) * self.vae.config.scaling_factor + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # Prepare additional timestep conditions + original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=device) + target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device) + crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device) + + original_size = original_size.repeat(batch_size * num_images_per_prompt, 1) + target_size = target_size.repeat(batch_size * num_images_per_prompt, 1) + crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1) + + # Prepare timesteps + image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // ( + self.transformer.config.patch_size**2 + ) + + timesteps = ( + np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps) + if timesteps is None + else np.array(timesteps) + ) + timesteps = timesteps.astype(np.int64).astype(np.float32) + sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("base_shift", 0.25), + self.scheduler.config.get("max_shift", 0.75), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu + ) + self._num_timesteps = len(timesteps) + # Denoising loop + transformer_dtype = self.transformer.dtype + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents, control_image], dim=1).to(transformer_dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + + noise_pred_cond = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=negative_prompt_embeds, + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = noise_pred_cond + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + image = self.vae.decode(latents, return_dict=False, generator=generator)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return CogView4PipelineOutput(images=image) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index eee41b9af4d1..f3f1d90204d6 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -63,6 +63,7 @@ >>> from diffusers import FluxControlNetPipeline >>> from diffusers import FluxControlNetModel + >>> base_model = "black-forest-labs/FLUX.1-dev" >>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny" >>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16) >>> pipe = FluxControlNetPipeline.from_pretrained( diff --git a/src/diffusers/pipelines/ltx/__init__.py b/src/diffusers/pipelines/ltx/__init__.py index 20cc1c216522..199e730d9b4d 100644 --- a/src/diffusers/pipelines/ltx/__init__.py +++ b/src/diffusers/pipelines/ltx/__init__.py @@ -23,6 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_ltx"] = ["LTXPipeline"] + _import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"] _import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -34,6 +35,7 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_ltx import LTXPipeline + from .pipeline_ltx_condition import LTXConditionPipeline from .pipeline_ltx_image2video import LTXImageToVideoPipeline else: diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 866be61810a9..f7b0811d1a22 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -694,9 +694,8 @@ def __call__( self._num_timesteps = len(timesteps) # 6. Prepare micro-conditions - latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio rope_interpolation_scale = ( - 1 / latent_frame_rate, + self.vae_temporal_compression_ratio / frame_rate, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio, ) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py new file mode 100644 index 000000000000..e7f3666cb2c7 --- /dev/null +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -0,0 +1,1174 @@ +# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import PIL.Image +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin +from ...models.autoencoders import AutoencoderKLLTXVideo +from ...models.transformers import LTXVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import LTXPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXConditionPipeline, LTXVideoCondition + >>> from diffusers.utils import export_to_video, load_video, load_image + + >>> pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> # Load input image and video + >>> video = load_video( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4" + ... ) + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input.jpg" + ... ) + + >>> # Create conditioning objects + >>> condition1 = LTXVideoCondition( + ... image=image, + ... frame_index=0, + ... ) + >>> condition2 = LTXVideoCondition( + ... video=video, + ... frame_index=80, + ... ) + + >>> prompt = "The video depicts a long, straight highway stretching into the distance, flanked by metal guardrails. The road is divided into multiple lanes, with a few vehicles visible in the far distance. The surrounding landscape features dry, grassy fields on one side and rolling hills on the other. The sky is mostly clear with a few scattered clouds, suggesting a bright, sunny day. And then the camera switch to a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region." + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> # Generate video + >>> generator = torch.Generator("cuda").manual_seed(0) + >>> video = pipe( + ... conditions=[condition1, condition2], + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=768, + ... height=512, + ... num_frames=161, + ... num_inference_steps=40, + ... generator=generator, + ... ).frames[0] + + >>> export_to_video(video, "output.mp4", fps=24) + ``` +""" + + +@dataclass +class LTXVideoCondition: + """ + Defines a single frame-conditioning item for LTX Video - a single frame or a sequence of frames. + + Attributes: + image (`PIL.Image.Image`): + The image to condition the video on. + video (`List[PIL.Image.Image]`): + The video to condition the video on. + frame_index (`int`): + The frame index at which the image or video will conditionally effect the video generation. + strength (`float`, defaults to `1.0`): + The strength of the conditioning effect. A value of `1.0` means the conditioning effect is fully applied. + """ + + image: Optional[PIL.Image.Image] = None + video: Optional[List[PIL.Image.Image]] = None + frame_index: int = 0 + strength: float = 1.0 + + +# from LTX-Video/ltx_video/schedulers/rf.py +def linear_quadratic_schedule(num_steps, threshold_noise=0.025, linear_steps=None): + if linear_steps is None: + linear_steps = num_steps // 2 + if num_steps < 2: + return torch.tensor([1.0]) + linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)] + threshold_noise_step_diff = linear_steps - threshold_noise * num_steps + quadratic_steps = num_steps - linear_steps + quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2) + const = quadratic_coef * (linear_steps**2) + quadratic_sigma_schedule = [ + quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps) + ] + sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0] + sigma_schedule = [1.0 - x for x in sigma_schedule] + return torch.tensor(sigma_schedule[:-1]) + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin): + r""" + Pipeline for image-to-video generation. + + Reference: https://github.com/Lightricks/LTX-Video + + Args: + transformer ([`LTXVideoTransformer3DModel`]): + Conditional Transformer architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLLTXVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTXVideo, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: LTXVideoTransformer3DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128 + ) + + self.default_height = 512 + self.default_width = 704 + self.default_frames = 121 + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 256, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 256, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + def check_inputs( + self, + prompt, + conditions, + image, + video, + frame_index, + strength, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + if conditions is not None and (image is not None or video is not None): + raise ValueError("If `conditions` is provided, `image` and `video` must not be provided.") + + if conditions is None and (image is None and video is None): + raise ValueError("If `conditions` is not provided, `image` or `video` must be provided.") + + if conditions is None: + if isinstance(image, list) and isinstance(frame_index, list) and len(image) != len(frame_index): + raise ValueError( + "If `conditions` is not provided, `image` and `frame_index` must be of the same length." + ) + elif isinstance(image, list) and isinstance(strength, list) and len(image) != len(strength): + raise ValueError("If `conditions` is not provided, `image` and `strength` must be of the same length.") + elif isinstance(video, list) and isinstance(frame_index, list) and len(video) != len(frame_index): + raise ValueError( + "If `conditions` is not provided, `video` and `frame_index` must be of the same length." + ) + elif isinstance(video, list) and isinstance(strength, list) and len(video) != len(strength): + raise ValueError("If `conditions` is not provided, `video` and `strength` must be of the same length.") + + @staticmethod + def _prepare_video_ids( + batch_size: int, + num_frames: int, + height: int, + width: int, + patch_size: int = 1, + patch_size_t: int = 1, + device: torch.device = None, + ) -> torch.Tensor: + latent_sample_coords = torch.meshgrid( + torch.arange(0, num_frames, patch_size_t, device=device), + torch.arange(0, height, patch_size, device=device), + torch.arange(0, width, patch_size, device=device), + indexing="ij", + ) + latent_sample_coords = torch.stack(latent_sample_coords, dim=0) + latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_coords = latent_coords.reshape(batch_size, -1, num_frames * height * width) + + return latent_coords + + @staticmethod + def _scale_video_ids( + video_ids: torch.Tensor, + scale_factor: int = 32, + scale_factor_t: int = 8, + frame_index: int = 0, + device: torch.device = None, + ) -> torch.Tensor: + scaled_latent_coords = ( + video_ids + * torch.tensor([scale_factor_t, scale_factor, scale_factor], device=video_ids.device)[None, :, None] + ) + scaled_latent_coords[:, 0] = (scaled_latent_coords[:, 0] + 1 - scale_factor_t).clamp(min=0) + scaled_latent_coords[:, 0] += frame_index + + return scaled_latent_coords + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. + batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + def trim_conditioning_sequence(self, start_frame: int, sequence_num_frames: int, target_num_frames: int): + """ + Trim a conditioning sequence to the allowed number of frames. + + Args: + start_frame (int): The target frame number of the first frame in the sequence. + sequence_num_frames (int): The number of frames in the sequence. + target_num_frames (int): The target number of frames in the generated video. + Returns: + int: updated sequence length + """ + scale_factor = self.vae_temporal_compression_ratio + num_frames = min(sequence_num_frames, target_num_frames - start_frame) + # Trim down to a multiple of temporal_scale_factor frames plus 1 + num_frames = (num_frames - 1) // scale_factor * scale_factor + 1 + return num_frames + + @staticmethod + def add_noise_to_image_conditioning_latents( + t: float, + init_latents: torch.Tensor, + latents: torch.Tensor, + noise_scale: float, + conditioning_mask: torch.Tensor, + generator, + eps=1e-6, + ): + """ + Add timestep-dependent noise to the hard-conditioning latents. This helps with motion continuity, especially + when conditioned on a single frame. + """ + noise = randn_tensor( + latents.shape, + generator=generator, + device=latents.device, + dtype=latents.dtype, + ) + # Add noise only to hard-conditioning latents (conditioning_mask = 1.0) + need_to_noise = (conditioning_mask > 1.0 - eps).unsqueeze(-1) + noised_latents = init_latents + noise_scale * noise * (t**2) + latents = torch.where(need_to_noise, noised_latents, latents) + return latents + + def prepare_latents( + self, + conditions: List[torch.Tensor], + condition_strength: List[float], + condition_frame_index: List[int], + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 704, + num_frames: int = 161, + num_prefix_latent_frames: int = 2, + generator: Optional[torch.Generator] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + num_latent_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32) + + extra_conditioning_latents = [] + extra_conditioning_video_ids = [] + extra_conditioning_mask = [] + extra_conditioning_num_latents = 0 + for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index): + condition_latents = retrieve_latents(self.vae.encode(data), generator=generator) + condition_latents = self._normalize_latents( + condition_latents, self.vae.latents_mean, self.vae.latents_std + ).to(device, dtype=dtype) + + num_data_frames = data.size(2) + num_cond_frames = condition_latents.size(2) + + if frame_index == 0: + latents[:, :, :num_cond_frames] = torch.lerp( + latents[:, :, :num_cond_frames], condition_latents, strength + ) + condition_latent_frames_mask[:, :num_cond_frames] = strength + + else: + if num_data_frames > 1: + if num_cond_frames < num_prefix_latent_frames: + raise ValueError( + f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}." + ) + + if num_cond_frames > num_prefix_latent_frames: + start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames + end_frame = start_frame + num_cond_frames - num_prefix_latent_frames + latents[:, :, start_frame:end_frame] = torch.lerp( + latents[:, :, start_frame:end_frame], + condition_latents[:, :, num_prefix_latent_frames:], + strength, + ) + condition_latent_frames_mask[:, start_frame:end_frame] = strength + condition_latents = condition_latents[:, :, :num_prefix_latent_frames] + + noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype) + condition_latents = torch.lerp(noise, condition_latents, strength) + + condition_video_ids = self._prepare_video_ids( + batch_size, + condition_latents.size(2), + latent_height, + latent_width, + patch_size=self.transformer_spatial_patch_size, + patch_size_t=self.transformer_temporal_patch_size, + device=device, + ) + condition_video_ids = self._scale_video_ids( + condition_video_ids, + scale_factor=self.vae_spatial_compression_ratio, + scale_factor_t=self.vae_temporal_compression_ratio, + frame_index=frame_index, + device=device, + ) + condition_latents = self._pack_latents( + condition_latents, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + condition_conditioning_mask = torch.full( + condition_latents.shape[:2], strength, device=device, dtype=dtype + ) + + extra_conditioning_latents.append(condition_latents) + extra_conditioning_video_ids.append(condition_video_ids) + extra_conditioning_mask.append(condition_conditioning_mask) + extra_conditioning_num_latents += condition_latents.size(1) + + video_ids = self._prepare_video_ids( + batch_size, + num_latent_frames, + latent_height, + latent_width, + patch_size_t=self.transformer_temporal_patch_size, + patch_size=self.transformer_spatial_patch_size, + device=device, + ) + conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0]) + video_ids = self._scale_video_ids( + video_ids, + scale_factor=self.vae_spatial_compression_ratio, + scale_factor_t=self.vae_temporal_compression_ratio, + frame_index=0, + device=device, + ) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + + if len(extra_conditioning_latents) > 0: + latents = torch.cat([*extra_conditioning_latents, latents], dim=1) + video_ids = torch.cat([*extra_conditioning_video_ids, video_ids], dim=2) + conditioning_mask = torch.cat([*extra_conditioning_mask, conditioning_mask], dim=1) + + return latents, conditioning_mask, video_ids, extra_conditioning_num_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + conditions: Union[LTXVideoCondition, List[LTXVideoCondition]] = None, + image: Union[PipelineImageInput, List[PipelineImageInput]] = None, + video: List[PipelineImageInput] = None, + frame_index: Union[int, List[int]] = 0, + strength: Union[float, List[float]] = 1.0, + prompt: Union[str, List[str]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 704, + num_frames: int = 161, + frame_rate: int = 25, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 3, + image_cond_noise_scale: float = 0.15, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + decode_timestep: Union[float, List[float]] = 0.0, + decode_noise_scale: Optional[Union[float, List[float]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 256, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + conditions (`List[LTXVideoCondition], *optional*`): + The list of frame-conditioning items for the video generation.If not provided, conditions will be + created using `image`, `video`, `frame_index` and `strength`. + image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*): + The image or images to condition the video generation. If not provided, one has to pass `video` or + `conditions`. + video (`List[PipelineImageInput]`, *optional*): + The video to condition the video generation. If not provided, one has to pass `image` or `conditions`. + frame_index (`int` or `List[int]`, *optional*): + The frame index or frame indices at which the image or video will conditionally effect the video + generation. If not provided, one has to pass `conditions`. + strength (`float` or `List[float]`, *optional*): + The strength or strengths of the conditioning effect. If not provided, one has to pass `conditions`. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, defaults to `512`): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, defaults to `704`): + The width in pixels of the generated image. This is set to 848 by default for the best results. + num_frames (`int`, defaults to `161`): + The number of video frames to generate + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, defaults to `3 `): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + decode_timestep (`float`, defaults to `0.0`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `None`): + The interpolation factor between random noise and denoised latents at the decode timestep. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to `128 `): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if latents is not None: + raise ValueError("Passing latents is not yet supported.") + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + conditions=conditions, + image=image, + video=video, + frame_index=frame_index, + strength=strength, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if conditions is not None: + if not isinstance(conditions, list): + conditions = [conditions] + + strength = [condition.strength for condition in conditions] + frame_index = [condition.frame_index for condition in conditions] + image = [condition.image for condition in conditions] + video = [condition.video for condition in conditions] + else: + if not isinstance(image, list): + image = [image] + num_conditions = 1 + elif isinstance(image, list): + num_conditions = len(image) + if not isinstance(video, list): + video = [video] + num_conditions = 1 + elif isinstance(video, list): + num_conditions = len(video) + + if not isinstance(frame_index, list): + frame_index = [frame_index] * num_conditions + if not isinstance(strength, list): + strength = [strength] * num_conditions + + device = self._execution_device + + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + vae_dtype = self.vae.dtype + + conditioning_tensors = [] + for condition_image, condition_video, condition_frame_index, condition_strength in zip( + image, video, frame_index, strength + ): + if condition_image is not None: + condition_tensor = ( + self.video_processor.preprocess(condition_image, height, width) + .unsqueeze(2) + .to(device, dtype=vae_dtype) + ) + elif condition_video is not None: + condition_tensor = self.video_processor.preprocess_video(condition_video, height, width) + num_frames_input = condition_tensor.size(2) + num_frames_output = self.trim_conditioning_sequence( + condition_frame_index, num_frames_input, num_frames + ) + condition_tensor = condition_tensor[:, :, :num_frames_output] + condition_tensor = condition_tensor.to(device, dtype=vae_dtype) + else: + raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.") + + if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1: + raise ValueError( + f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) " + f"but got {condition_tensor.size(2)} frames." + ) + conditioning_tensors.append(condition_tensor) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents, conditioning_mask, video_coords, extra_conditioning_num_latents = self.prepare_latents( + conditioning_tensors, + strength, + frame_index, + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames=num_frames, + generator=generator, + device=device, + dtype=torch.float32, + ) + + video_coords = video_coords.float() + video_coords[:, 0] = video_coords[:, 0] * (1.0 / frame_rate) + + init_latents = latents.clone() + + if self.do_classifier_free_guidance: + video_coords = torch.cat([video_coords, video_coords], dim=0) + + # 5. Prepare timesteps + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + sigmas = linear_quadratic_schedule(num_inference_steps) + timesteps = sigmas * 1000 + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps=timesteps, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 7. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + if image_cond_noise_scale > 0: + # Add timestep-dependent noise to the hard-conditioning latents + # This helps with motion continuity, especially when conditioned on a single frame + latents = self.add_noise_to_image_conditioning_latents( + t / 1000.0, + init_latents, + latents, + image_cond_noise_scale, + conditioning_mask, + generator, + ) + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + conditioning_mask_model_input = ( + torch.cat([conditioning_mask, conditioning_mask]) + if self.do_classifier_free_guidance + else conditioning_mask + ) + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float() + timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + video_coords=video_coords, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + timestep, _ = timestep.chunk(2) + + denoised_latents = self.scheduler.step( + -noise_pred, t, latents, per_token_timesteps=timestep, return_dict=False + )[0] + tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1) + latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + latents = latents[:, extra_conditioning_num_latents:] + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + + if output_type == "latent": + video = latents + else: + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return LTXPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 0577a56ec13d..6c4214fe1b26 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -764,9 +764,8 @@ def __call__( self._num_timesteps = len(timesteps) # 6. Prepare micro-conditions - latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio rope_interpolation_scale = ( - 1 / latent_frame_rate, + self.vae_temporal_compression_ratio / frame_rate, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio, ) diff --git a/src/diffusers/pipelines/lumina/__init__.py b/src/diffusers/pipelines/lumina/__init__.py index ca1396359721..a19dc7e94641 100644 --- a/src/diffusers/pipelines/lumina/__init__.py +++ b/src/diffusers/pipelines/lumina/__init__.py @@ -22,7 +22,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["pipeline_lumina"] = ["LuminaText2ImgPipeline"] + _import_structure["pipeline_lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -32,7 +32,7 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: - from .pipeline_lumina import LuminaText2ImgPipeline + from .pipeline_lumina import LuminaPipeline, LuminaText2ImgPipeline else: import sys diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index b50079532f94..816213f105cb 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -30,6 +30,7 @@ from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( BACKENDS_MAPPING, + deprecate, is_bs4_available, is_ftfy_available, is_torch_xla_available, @@ -60,11 +61,9 @@ Examples: ```py >>> import torch - >>> from diffusers import LuminaText2ImgPipeline + >>> from diffusers import LuminaPipeline - >>> pipe = LuminaText2ImgPipeline.from_pretrained( - ... "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16 - ... ) + >>> pipe = LuminaPipeline.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16) >>> # Enable memory optimizations. >>> pipe.enable_model_cpu_offload() @@ -134,7 +133,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class LuminaText2ImgPipeline(DiffusionPipeline): +class LuminaPipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation using Lumina-T2I. @@ -932,3 +931,23 @@ def __call__( return (image,) return ImagePipelineOutput(images=image) + + +class LuminaText2ImgPipeline(LuminaPipeline): + def __init__( + self, + transformer: LuminaNextDiT2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: GemmaPreTrainedModel, + tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], + ): + deprecation_message = "`LuminaText2ImgPipeline` has been renamed to `LuminaPipeline` and will be removed in a future version. Please use `LuminaPipeline` instead." + deprecate("diffusers.pipelines.lumina.pipeline_lumina.LuminaText2ImgPipeline", "0.34", deprecation_message) + super().__init__( + transformer=transformer, + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + ) diff --git a/src/diffusers/pipelines/lumina2/__init__.py b/src/diffusers/pipelines/lumina2/__init__.py index 0e51a768a785..b1d6bfeb0d58 100644 --- a/src/diffusers/pipelines/lumina2/__init__.py +++ b/src/diffusers/pipelines/lumina2/__init__.py @@ -22,7 +22,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["pipeline_lumina2"] = ["Lumina2Text2ImgPipeline"] + _import_structure["pipeline_lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -32,7 +32,7 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: - from .pipeline_lumina2 import Lumina2Text2ImgPipeline + from .pipeline_lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline else: import sys diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py index 514192cb70c7..e0905a2f131f 100644 --- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py +++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py @@ -25,6 +25,7 @@ from ...models.transformers.transformer_lumina2 import Lumina2Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( + deprecate, is_torch_xla_available, logging, replace_example_docstring, @@ -47,9 +48,9 @@ Examples: ```py >>> import torch - >>> from diffusers import Lumina2Text2ImgPipeline + >>> from diffusers import Lumina2Pipeline - >>> pipe = Lumina2Text2ImgPipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16) + >>> pipe = Lumina2Pipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16) >>> # Enable memory optimizations. >>> pipe.enable_model_cpu_offload() @@ -133,7 +134,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class Lumina2Text2ImgPipeline(DiffusionPipeline, Lumina2LoraLoaderMixin): +class Lumina2Pipeline(DiffusionPipeline, Lumina2LoraLoaderMixin): r""" Pipeline for text-to-image generation using Lumina-T2I. @@ -767,3 +768,23 @@ def __call__( return (image,) return ImagePipelineOutput(images=image) + + +class Lumina2Text2ImgPipeline(Lumina2Pipeline): + def __init__( + self, + transformer: Lumina2Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: Gemma2PreTrainedModel, + tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], + ): + deprecation_message = "`Lumina2Text2ImgPipeline` has been renamed to `Lumina2Pipeline` and will be removed in a future version. Please use `Lumina2Pipeline` instead." + deprecate("diffusers.pipelines.lumina2.pipeline_lumina2.Lumina2Text2ImgPipeline", "0.34", deprecation_message) + super().__init__( + transformer=transformer, + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + ) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 07da8b5e2e2e..e80325ed42b0 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -667,9 +667,12 @@ def load_sub_model( use_safetensors: bool, dduf_entries: Optional[Dict[str, DDUFEntry]], provider_options: Any, + quantization_config: Optional[Any] = None, ): """Helper method to load the module `name` from `library_name` and `class_name`""" + from ..quantizers import PipelineQuantizationConfig + # retrieve class candidates class_obj, class_candidates = get_class_obj_and_candidates( @@ -761,6 +764,17 @@ def load_sub_model( else: loading_kwargs["low_cpu_mem_usage"] = False + if ( + quantization_config is not None + and isinstance(quantization_config, PipelineQuantizationConfig) + and issubclass(class_obj, torch.nn.Module) + ): + model_quant_config = quantization_config._resolve_quant_config( + is_diffusers=is_diffusers_model, module_name=name + ) + if model_quant_config is not None: + loading_kwargs["quantization_config"] = model_quant_config + # check if the module is in a subdirectory if dduf_entries: loading_kwargs["dduf_entries"] = dduf_entries diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index cb60350be1b0..040eb8e8c74f 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -702,6 +702,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P use_safetensors = kwargs.pop("use_safetensors", None) use_onnx = kwargs.pop("use_onnx", None) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) + quantization_config = kwargs.pop("quantization_config", None) if not isinstance(torch_dtype, torch.dtype): torch_dtype = torch.float32 @@ -874,6 +875,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P } init_kwargs = {**init_kwargs, **passed_pipe_kwargs} + # TODO: add checking for quantization_config `mapping` i.e., if the modules specified there actually exist. + ######################### + # remove `null` components def load_module(name, value): if value[0] is None: @@ -973,6 +977,7 @@ def load_module(name, value): use_safetensors=use_safetensors, dduf_entries=dduf_entries, provider_options=provider_options, + quantization_config=quantization_config, ) logger.info( f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}." @@ -1610,7 +1615,7 @@ def _get_signature_keys(cls, obj): expected_modules.add(name) optional_parameters.remove(name) - return expected_modules, optional_parameters + return sorted(expected_modules), sorted(optional_parameters) @classmethod def _get_signature_types(cls): @@ -1652,10 +1657,12 @@ def components(self) -> Dict[str, Any]: k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters } - if set(components.keys()) != expected_modules: + actual = sorted(set(components.keys())) + expected = sorted(expected_modules) + if actual != expected: raise ValueError( f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected" - f" {expected_modules} to be defined, but {components.keys()} are defined." + f" {expected} to be defined, but {actual} are defined." ) return components diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index 863178e7c434..e5699718ea71 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -108,6 +108,7 @@ def prompt_clean(text): return text +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" ): @@ -385,13 +386,6 @@ def prepare_latents( ) video_condition = video_condition.to(device=device, dtype=dtype) - if isinstance(generator, list): - latent_condition = [retrieve_latents(self.vae.encode(video_condition), g) for g in generator] - latents = latent_condition = torch.cat(latent_condition) - else: - latent_condition = retrieve_latents(self.vae.encode(video_condition), generator) - latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) - latents_mean = ( torch.tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) @@ -401,6 +395,15 @@ def prepare_latents( latents.device, latents.dtype ) + if isinstance(generator, list): + latent_condition = [ + retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator + ] + latent_condition = torch.cat(latent_condition) + else: + latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") + latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + latent_condition = (latent_condition - latents_mean) * latents_std mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py index 4c8483a3d6ee..975bf00afac2 100644 --- a/src/diffusers/quantizers/__init__.py +++ b/src/diffusers/quantizers/__init__.py @@ -12,5 +12,163 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect +from typing import Dict, List, Optional + +from ..utils import is_transformers_available, logging from .auto import DiffusersAutoQuantizer from .base import DiffusersQuantizer + + +logger = logging.get_logger(__name__) + + +class PipelineQuantizationConfig: + """TODO""" + + def __init__( + self, + quant_backend: str = None, + quant_kwargs: Dict[str, str] = None, + modules_to_quantize: Optional[List[str]] = None, + quant_mapping: Dict[str,] = None, + ): + self.quant_backend = quant_backend + # Initialize kwargs to be {} to set to the defaults. + self.quant_kwargs = quant_kwargs or {} + self.modules_to_quantize = modules_to_quantize + self.quant_mapping = quant_mapping + + self.post_init() + + def post_init(self): + quant_mapping = self.quant_mapping + self.is_granular = True if quant_mapping is not None else False + + self._validate_init_args() + + def _validate_init_args(self): + if self.quant_backend and self.quant_mapping: + raise ValueError("Both `quant_backend` and `quant_mapping` cannot be set.") + + if not self.quant_mapping and not self.quant_backend: + raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.") + + if not self.quant_kwargs and not self.quant_mapping: + raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.") + + if self.quant_backend is not None: + self._validate_init_kwargs_in_backends() + + if self.quant_mapping is not None: + self._validate_quant_mapping_args() + + def _validate_init_kwargs_in_backends(self): + quant_backend = self.quant_backend + + self._check_backend_availability(quant_backend) + + quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() + + if quant_config_mapping_transformers is not None: + if quant_backend not in quant_config_mapping_transformers: + raise ValueError( + f"`{quant_backend=}` is not available in `transformers`, available ones are: {list(quant_config_mapping_transformers.keys())}." + ) + init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__) + init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"} + else: + init_kwargs_transformers = None + + if quant_backend not in quant_config_mapping_diffusers: + raise ValueError( + f"`{quant_backend=}` is not available in `diffusers`, available ones are: {list(quant_config_mapping_diffusers.keys())}." + ) + init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__) + init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"} + + if init_kwargs_transformers != init_kwargs_diffusers: + raise ValueError( + "The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. " + f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class." + ) + + def _validate_quant_mapping_args(self): + quant_mapping = self.quant_mapping + quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() + + available_configs_transformers = ( + list(quant_config_mapping_transformers.values()) if quant_config_mapping_transformers else None + ) + available_configs_diffusers = list(quant_config_mapping_diffusers.values()) + + for module_name, config in quant_mapping.items(): + if config not in available_configs_diffusers or ( + available_configs_transformers and config not in available_configs_transformers + ): + msg = f"Provided config for {module_name=} could not be found. Available ones for `diffusers` are: {available_configs_diffusers}.)" + if available_configs_transformers is not None: + msg += f" Available ones for `diffusers` are: {available_configs_transformers}." + raise ValueError(msg) + + def _check_backend_availability(self, quant_backend: str): + quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() + + available_backends_transformers = ( + list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None + ) + available_backends_diffusers = list(quant_config_mapping_diffusers.keys()) + + if ( + available_backends_transformers and quant_backend not in available_backends_transformers + ) or quant_backend not in quant_config_mapping_diffusers: + error_message = f"Provided quant_backend={quant_backend} was not found." + if available_backends_transformers: + error_message += f"\nAvailable ones (transformers): {available_backends_transformers}." + error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}." + raise ValueError(error_message) + + def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None): + quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() + + quant_mapping = self.quant_mapping + modules_to_quantize = self.modules_to_quantize + + # Granular case + if self.is_granular and module_name in quant_mapping: + logger.debug(f"Initializing quantization config class for {module_name}.") + config = quant_mapping[module_name] + return config + + # Global config case + else: + should_quantize = False + # Only quantize the modules requested for. + if modules_to_quantize and module_name in modules_to_quantize: + should_quantize = True + # No specification for `modules_to_quantize` means all modules should be quantized. + elif not self.is_granular and not modules_to_quantize: + should_quantize = True + + if should_quantize: + logger.debug(f"Initializing quantization config class for {module_name}.") + mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers + quant_config_cls = mapping_to_use[self.quant_backend] + # If `quant_kwargs` is None we default to initializing with the defaults of `quant_config_cls`. + quant_kwargs = self.quant_kwargs or {} + return quant_config_cls(**quant_kwargs) + + # Fallback: no applicable configuration found. + return None + + def _get_quant_config_list(self): + if is_transformers_available(): + from transformers.quantizers.auto import ( + AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers, + ) + else: + quant_config_mapping_transformers = None + + from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers + + return quant_config_mapping_transformers, quant_config_mapping_diffusers diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py index d9874cc282ae..ce214ae7bc17 100644 --- a/src/diffusers/quantizers/auto.py +++ b/src/diffusers/quantizers/auto.py @@ -26,8 +26,10 @@ GGUFQuantizationConfig, QuantizationConfigMixin, QuantizationMethod, + QuantoConfig, TorchAoConfig, ) +from .quanto import QuantoQuantizer from .torchao import TorchAoHfQuantizer @@ -35,6 +37,7 @@ "bitsandbytes_4bit": BnB4BitDiffusersQuantizer, "bitsandbytes_8bit": BnB8BitDiffusersQuantizer, "gguf": GGUFQuantizer, + "quanto": QuantoQuantizer, "torchao": TorchAoHfQuantizer, } @@ -42,6 +45,7 @@ "bitsandbytes_4bit": BitsAndBytesConfig, "bitsandbytes_8bit": BitsAndBytesConfig, "gguf": GGUFQuantizationConfig, + "quanto": QuantoConfig, "torchao": TorchAoConfig, } diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index ada75588a42a..f4aa1504534c 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -135,6 +135,7 @@ def create_quantized_param( target_device: "torch.device", state_dict: Dict[str, Any], unexpected_keys: Optional[List[str]] = None, + **kwargs, ): import bitsandbytes as bnb @@ -445,6 +446,7 @@ def create_quantized_param( target_device: "torch.device", state_dict: Dict[str, Any], unexpected_keys: Optional[List[str]] = None, + **kwargs, ): import bitsandbytes as bnb diff --git a/src/diffusers/quantizers/gguf/gguf_quantizer.py b/src/diffusers/quantizers/gguf/gguf_quantizer.py index 0c760e277ce4..6da69c7bd60c 100644 --- a/src/diffusers/quantizers/gguf/gguf_quantizer.py +++ b/src/diffusers/quantizers/gguf/gguf_quantizer.py @@ -108,6 +108,7 @@ def create_quantized_param( target_device: "torch.device", state_dict: Optional[Dict[str, Any]] = None, unexpected_keys: Optional[List[str]] = None, + **kwargs, ): module, tensor_name = get_module_from_name(model, param_name) if tensor_name not in module._parameters and tensor_name not in module._buffers: diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 4fac8dd3829f..0bc433be0ff3 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -45,6 +45,7 @@ class QuantizationMethod(str, Enum): BITS_AND_BYTES = "bitsandbytes" GGUF = "gguf" TORCHAO = "torchao" + QUANTO = "quanto" if is_torchao_available(): @@ -686,3 +687,38 @@ def __repr__(self): return ( f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True, cls=TorchAoJSONEncoder)}\n" ) + + +@dataclass +class QuantoConfig(QuantizationConfigMixin): + """ + This is a wrapper class about all possible attributes and features that you can play with a model that has been + loaded using `quanto`. + + Args: + weights_dtype (`str`, *optional*, defaults to `"int8"`): + The target dtype for the weights after quantization. Supported values are ("float8","int8","int4","int2") + modules_to_not_convert (`list`, *optional*, default to `None`): + The list of modules to not quantize, useful for quantizing models that explicitly require to have some + modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers). + """ + + def __init__( + self, + weights_dtype: str = "int8", + modules_to_not_convert: Optional[List[str]] = None, + **kwargs, + ): + self.quant_method = QuantizationMethod.QUANTO + self.weights_dtype = weights_dtype + self.modules_to_not_convert = modules_to_not_convert + + self.post_init() + + def post_init(self): + r""" + Safety checker that arguments are correct + """ + accepted_weights = ["float8", "int8", "int4", "int2"] + if self.weights_dtype not in accepted_weights: + raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}") diff --git a/src/diffusers/quantizers/quanto/__init__.py b/src/diffusers/quantizers/quanto/__init__.py new file mode 100644 index 000000000000..a4e8a1f41a1e --- /dev/null +++ b/src/diffusers/quantizers/quanto/__init__.py @@ -0,0 +1 @@ +from .quanto_quantizer import QuantoQuantizer diff --git a/src/diffusers/quantizers/quanto/quanto_quantizer.py b/src/diffusers/quantizers/quanto/quanto_quantizer.py new file mode 100644 index 000000000000..0120163804c9 --- /dev/null +++ b/src/diffusers/quantizers/quanto/quanto_quantizer.py @@ -0,0 +1,177 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Union + +from diffusers.utils.import_utils import is_optimum_quanto_version + +from ...utils import ( + get_module_from_name, + is_accelerate_available, + is_accelerate_version, + is_optimum_quanto_available, + is_torch_available, + logging, +) +from ..base import DiffusersQuantizer + + +if TYPE_CHECKING: + from ...models.modeling_utils import ModelMixin + + +if is_torch_available(): + import torch + +if is_accelerate_available(): + from accelerate.utils import CustomDtype, set_module_tensor_to_device + +if is_optimum_quanto_available(): + from .utils import _replace_with_quanto_layers + +logger = logging.get_logger(__name__) + + +class QuantoQuantizer(DiffusersQuantizer): + r""" + Diffusers Quantizer for Optimum Quanto + """ + + use_keep_in_fp32_modules = True + requires_calibration = False + required_packages = ["quanto", "accelerate"] + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + + def validate_environment(self, *args, **kwargs): + if not is_optimum_quanto_available(): + raise ImportError( + "Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)" + ) + if not is_optimum_quanto_version(">=", "0.2.6"): + raise ImportError( + "Loading an optimum-quanto quantized model requires `optimum-quanto>=0.2.6`. " + "Please upgrade your installation with `pip install --upgrade optimum-quanto" + ) + + if not is_accelerate_available(): + raise ImportError( + "Loading an optimum-quanto quantized model requires accelerate library (`pip install accelerate`)" + ) + + device_map = kwargs.get("device_map", None) + if isinstance(device_map, dict) and len(device_map.keys()) > 1: + raise ValueError( + "`device_map` for multi-GPU inference or CPU/disk offload is currently not supported with Diffusers and the Quanto backend" + ) + + def check_if_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ): + # Quanto imports diffusers internally. This is here to prevent circular imports + from optimum.quanto import QModuleMixin, QTensor + from optimum.quanto.tensor.packed import PackedTensor + + module, tensor_name = get_module_from_name(model, param_name) + if self.pre_quantized and any(isinstance(module, t) for t in [QTensor, PackedTensor]): + return True + elif isinstance(module, QModuleMixin) and "weight" in tensor_name: + return not module.frozen + + return False + + def create_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + *args, + **kwargs, + ): + """ + Create the quantized parameter by calling .freeze() after setting it to the module. + """ + + dtype = kwargs.get("dtype", torch.float32) + module, tensor_name = get_module_from_name(model, param_name) + if self.pre_quantized: + setattr(module, tensor_name, param_value) + else: + set_module_tensor_to_device(model, param_name, target_device, param_value, dtype) + module.freeze() + module.weight.requires_grad = False + + def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: + max_memory = {key: val * 0.90 for key, val in max_memory.items()} + return max_memory + + def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": + if is_accelerate_version(">=", "0.27.0"): + mapping = { + "int8": torch.int8, + "float8": CustomDtype.FP8, + "int4": CustomDtype.INT4, + "int2": CustomDtype.INT2, + } + target_dtype = mapping[self.quantization_config.weights_dtype] + + return target_dtype + + def update_torch_dtype(self, torch_dtype: "torch.dtype" = None) -> "torch.dtype": + if torch_dtype is None: + logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.") + torch_dtype = torch.float32 + return torch_dtype + + def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]: + # Quanto imports diffusers internally. This is here to prevent circular imports + from optimum.quanto import QModuleMixin + + not_missing_keys = [] + for name, module in model.named_modules(): + if isinstance(module, QModuleMixin): + for missing in missing_keys: + if ( + (name in missing or name in f"{prefix}.{missing}") + and not missing.endswith(".weight") + and not missing.endswith(".bias") + ): + not_missing_keys.append(missing) + return [k for k in missing_keys if k not in not_missing_keys] + + def _process_model_before_weight_loading( + self, + model: "ModelMixin", + device_map, + keep_in_fp32_modules: List[str] = [], + **kwargs, + ): + self.modules_to_not_convert = self.quantization_config.modules_to_not_convert + + if not isinstance(self.modules_to_not_convert, list): + self.modules_to_not_convert = [self.modules_to_not_convert] + + self.modules_to_not_convert.extend(keep_in_fp32_modules) + + model = _replace_with_quanto_layers( + model, + modules_to_not_convert=self.modules_to_not_convert, + quantization_config=self.quantization_config, + pre_quantized=self.pre_quantized, + ) + model.config.quantization_config = self.quantization_config + + def _process_model_after_weight_loading(self, model, **kwargs): + return model + + @property + def is_trainable(self): + return True + + @property + def is_serializable(self): + return True diff --git a/src/diffusers/quantizers/quanto/utils.py b/src/diffusers/quantizers/quanto/utils.py new file mode 100644 index 000000000000..6f41fd36b43a --- /dev/null +++ b/src/diffusers/quantizers/quanto/utils.py @@ -0,0 +1,60 @@ +import torch.nn as nn + +from ...utils import is_accelerate_available, logging + + +logger = logging.get_logger(__name__) + +if is_accelerate_available(): + from accelerate import init_empty_weights + + +def _replace_with_quanto_layers(model, quantization_config, modules_to_not_convert: list, pre_quantized=False): + # Quanto imports diffusers internally. These are placed here to avoid circular imports + from optimum.quanto import QLinear, freeze, qfloat8, qint2, qint4, qint8 + + def _get_weight_type(dtype: str): + return {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}[dtype] + + def _replace_layers(model, quantization_config, modules_to_not_convert): + has_children = list(model.children()) + if not has_children: + return model + + for name, module in model.named_children(): + _replace_layers(module, quantization_config, modules_to_not_convert) + + if name in modules_to_not_convert: + continue + + if isinstance(module, nn.Linear): + with init_empty_weights(): + qlinear = QLinear( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + dtype=module.weight.dtype, + weights=_get_weight_type(quantization_config.weights_dtype), + ) + model._modules[name] = qlinear + model._modules[name].source_cls = type(module) + model._modules[name].requires_grad_(False) + + return model + + model = _replace_layers(model, quantization_config, modules_to_not_convert) + has_been_replaced = any(isinstance(replaced_module, QLinear) for _, replaced_module in model.named_modules()) + + if not has_been_replaced: + logger.warning( + f"{model.__class__.__name__} does not appear to have any `nn.Linear` modules. Quantization will not be applied." + " Please check your model architecture, or submit an issue on Github if you think this is a bug." + " https://github.com/huggingface/diffusers/issues/new" + ) + + # We need to freeze the pre_quantized model in order for the loaded state_dict and model state dict + # to match when trying to load weights with load_model_dict_into_meta + if pre_quantized: + freeze(model) + + return model diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index e86ce2f64278..f9fb217ed6bd 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -23,7 +23,14 @@ from packaging import version -from ...utils import get_module_from_name, is_torch_available, is_torch_version, is_torchao_available, logging +from ...utils import ( + get_module_from_name, + is_torch_available, + is_torch_version, + is_torchao_available, + is_torchao_version, + logging, +) from ..base import DiffusersQuantizer @@ -62,6 +69,43 @@ from torchao.quantization import quantize_ +def _update_torch_safe_globals(): + safe_globals = [ + (torch.uint1, "torch.uint1"), + (torch.uint2, "torch.uint2"), + (torch.uint3, "torch.uint3"), + (torch.uint4, "torch.uint4"), + (torch.uint5, "torch.uint5"), + (torch.uint6, "torch.uint6"), + (torch.uint7, "torch.uint7"), + ] + try: + from torchao.dtypes import NF4Tensor + from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl + from torchao.dtypes.uintx.uint4_layout import UInt4Tensor + from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor + + safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor]) + + except (ImportError, ModuleNotFoundError) as e: + logger.warning( + "Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`" + ) + logger.debug(e) + + finally: + torch.serialization.add_safe_globals(safe_globals=safe_globals) + + +if ( + is_torch_available() + and is_torch_version(">=", "2.6.0") + and is_torchao_available() + and is_torchao_version(">=", "0.7.0") +): + _update_torch_safe_globals() + + logger = logging.get_logger(__name__) @@ -215,6 +259,7 @@ def create_quantized_param( target_device: "torch.device", state_dict: Dict[str, Any], unexpected_keys: List[str], + **kwargs, ): r""" Each nn.Linear layer that needs to be quantized is processsed here. First, we set the value the weight tensor, diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index e3bff7582cd9..cbb27e5fad63 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -377,6 +377,7 @@ def step( s_tmax: float = float("inf"), s_noise: float = 1.0, generator: Optional[torch.Generator] = None, + per_token_timesteps: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]: """ @@ -397,6 +398,8 @@ def step( Scaling factor for noise added to the sample. generator (`torch.Generator`, *optional*): A random number generator. + per_token_timesteps (`torch.Tensor`, *optional*): + The timesteps for each token in the sample. return_dict (`bool`): Whether or not to return a [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple. @@ -427,16 +430,26 @@ def step( # Upcast to avoid precision issues when computing prev_sample sample = sample.to(torch.float32) - sigma = self.sigmas[self.step_index] - sigma_next = self.sigmas[self.step_index + 1] + if per_token_timesteps is not None: + per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps - prev_sample = sample + (sigma_next - sigma) * model_output + sigmas = self.sigmas[:, None, None] + lower_mask = sigmas < per_token_sigmas[None] - 1e-6 + lower_sigmas = lower_mask * sigmas + lower_sigmas, _ = lower_sigmas.max(dim=0) + dt = (per_token_sigmas - lower_sigmas)[..., None] + else: + sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] + dt = sigma_next - sigma - # Cast sample back to model compatible dtype - prev_sample = prev_sample.to(model_output.dtype) + prev_sample = sample + dt * model_output # upon completion increase step index by one self._step_index += 1 + if per_token_timesteps is None: + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) if not return_dict: return (prev_sample,) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 6702ea2efbc8..50a470772772 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -79,6 +79,8 @@ is_matplotlib_available, is_note_seq_available, is_onnx_available, + is_optimum_quanto_available, + is_optimum_quanto_version, is_peft_available, is_peft_version, is_safetensors_available, @@ -92,6 +94,7 @@ is_torch_xla_available, is_torch_xla_version, is_torchao_available, + is_torchao_version, is_torchsde_available, is_torchvision_available, is_transformers_available, diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index 3f88f347710f..fa12318f4714 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -56,3 +56,14 @@ if USE_PEFT_BACKEND and _CHECK_PEFT: dep_version_check("peft") + + +DECODE_ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/" +DECODE_ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/" +DECODE_ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/" +DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/" + + +ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/" +ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/" +ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/" diff --git a/src/diffusers/utils/dummy_bitsandbytes_objects.py b/src/diffusers/utils/dummy_bitsandbytes_objects.py new file mode 100644 index 000000000000..2dc589428de9 --- /dev/null +++ b/src/diffusers/utils/dummy_bitsandbytes_objects.py @@ -0,0 +1,17 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class BitsAndBytesConfig(metaclass=DummyObject): + _backends = ["bitsandbytes"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["bitsandbytes"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["bitsandbytes"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["bitsandbytes"]) diff --git a/src/diffusers/utils/dummy_gguf_objects.py b/src/diffusers/utils/dummy_gguf_objects.py new file mode 100644 index 000000000000..4a6d9a060a13 --- /dev/null +++ b/src/diffusers/utils/dummy_gguf_objects.py @@ -0,0 +1,17 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class GGUFQuantizationConfig(metaclass=DummyObject): + _backends = ["gguf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["gguf"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["gguf"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["gguf"]) diff --git a/src/diffusers/utils/dummy_optimum_quanto_objects.py b/src/diffusers/utils/dummy_optimum_quanto_objects.py new file mode 100644 index 000000000000..44f8eaffc246 --- /dev/null +++ b/src/diffusers/utils/dummy_optimum_quanto_objects.py @@ -0,0 +1,17 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class QuantoConfig(metaclass=DummyObject): + _backends = ["optimum_quanto"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["optimum_quanto"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["optimum_quanto"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["optimum_quanto"]) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index ded30d16cf93..0c916bbbc1bc 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -362,6 +362,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class CogView4ControlPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class CogView4Pipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -1202,6 +1217,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LTXConditionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LTXImageToVideoPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -1232,6 +1262,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class Lumina2Pipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class Lumina2Text2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -1247,6 +1292,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LuminaPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LuminaText2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/src/diffusers/utils/dummy_torchao_objects.py b/src/diffusers/utils/dummy_torchao_objects.py new file mode 100644 index 000000000000..16f0f6a55f64 --- /dev/null +++ b/src/diffusers/utils/dummy_torchao_objects.py @@ -0,0 +1,17 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class TorchAoConfig(metaclass=DummyObject): + _backends = ["torchao"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torchao"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torchao"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torchao"]) diff --git a/src/diffusers/utils/export_utils.py b/src/diffusers/utils/export_utils.py index 00805433ceba..30d2c8bebd8e 100644 --- a/src/diffusers/utils/export_utils.py +++ b/src/diffusers/utils/export_utils.py @@ -3,7 +3,7 @@ import struct import tempfile from contextlib import contextmanager -from typing import List, Union +from typing import List, Optional, Union import numpy as np import PIL.Image @@ -139,8 +139,31 @@ def _legacy_export_to_video( def export_to_video( - video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 10 + video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], + output_video_path: str = None, + fps: int = 10, + quality: float = 5.0, + bitrate: Optional[int] = None, + macro_block_size: Optional[int] = 16, ) -> str: + """ + quality: + Video output quality. Default is 5. Uses variable bit rate. Highest quality is 10, lowest is 0. Set to None to + prevent variable bitrate flags to FFMPEG so you can manually specify them using output_params instead. + Specifying a fixed bitrate using `bitrate` disables this parameter. + + bitrate: + Set a constant bitrate for the video encoding. Default is None causing `quality` parameter to be used instead. + Better quality videos with smaller file sizes will result from using the `quality` variable bitrate parameter + rather than specifiying a fixed bitrate with this parameter. + + macro_block_size: + Size constraint for video. Width and height, must be divisible by this number. If not divisible by this number + imageio will tell ffmpeg to scale the image up to the next closest size divisible by this number. Most codecs + are compatible with a macroblock size of 16 (default), some can go smaller (4, 8). To disable this automatic + feature set it to None or 1, however be warned many players can't decode videos that are odd in size and some + codecs will produce poor results or fail. See https://en.wikipedia.org/wiki/Macroblock. + """ # TODO: Dhruv. Remove by Diffusers release 0.33.0 # Added to prevent breaking existing code if not is_imageio_available(): @@ -177,7 +200,9 @@ def export_to_video( elif isinstance(video_frames[0], PIL.Image.Image): video_frames = [np.array(frame) for frame in video_frames] - with imageio.get_writer(output_video_path, fps=fps) as writer: + with imageio.get_writer( + output_video_path, fps=fps, quality=quality, bitrate=bitrate, macro_block_size=macro_block_size + ) as writer: for frame in video_frames: writer.append_data(frame) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index ae1b9cae6edc..98b9c75451c8 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -25,7 +25,6 @@ from typing import Any, Union from huggingface_hub.utils import is_jinja_available # noqa: F401 -from packaging import version from packaging.version import Version, parse from . import logging @@ -52,36 +51,30 @@ STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} -_torch_version = "N/A" -if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: - _torch_available = importlib.util.find_spec("torch") is not None - if _torch_available: +_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ) + + +def _is_package_available(pkg_name: str): + pkg_exists = importlib.util.find_spec(pkg_name) is not None + pkg_version = "N/A" + + if pkg_exists: try: - _torch_version = importlib_metadata.version("torch") - logger.info(f"PyTorch version {_torch_version} available.") - except importlib_metadata.PackageNotFoundError: - _torch_available = False + pkg_version = importlib_metadata.version(pkg_name) + logger.debug(f"Successfully imported {pkg_name} version {pkg_version}") + except (ImportError, importlib_metadata.PackageNotFoundError): + pkg_exists = False + + return pkg_exists, pkg_version + + +if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: + _torch_available, _torch_version = _is_package_available("torch") + else: logger.info("Disabling PyTorch because USE_TORCH is set") _torch_available = False -_torch_xla_available = importlib.util.find_spec("torch_xla") is not None -if _torch_xla_available: - try: - _torch_xla_version = importlib_metadata.version("torch_xla") - logger.info(f"PyTorch XLA version {_torch_xla_version} available.") - except ImportError: - _torch_xla_available = False - -# check whether torch_npu is available -_torch_npu_available = importlib.util.find_spec("torch_npu") is not None -if _torch_npu_available: - try: - _torch_npu_version = importlib_metadata.version("torch_npu") - logger.info(f"torch_npu version {_torch_npu_version} available.") - except ImportError: - _torch_npu_available = False - _jax_version = "N/A" _flax_version = "N/A" if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: @@ -97,47 +90,12 @@ _flax_available = False if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES: - _safetensors_available = importlib.util.find_spec("safetensors") is not None - if _safetensors_available: - try: - _safetensors_version = importlib_metadata.version("safetensors") - logger.info(f"Safetensors version {_safetensors_version} available.") - except importlib_metadata.PackageNotFoundError: - _safetensors_available = False + _safetensors_available, _safetensors_version = _is_package_available("safetensors") + else: logger.info("Disabling Safetensors because USE_TF is set") _safetensors_available = False -_transformers_available = importlib.util.find_spec("transformers") is not None -try: - _transformers_version = importlib_metadata.version("transformers") - logger.debug(f"Successfully imported transformers version {_transformers_version}") -except importlib_metadata.PackageNotFoundError: - _transformers_available = False - -_hf_hub_available = importlib.util.find_spec("huggingface_hub") is not None -try: - _hf_hub_version = importlib_metadata.version("huggingface_hub") - logger.debug(f"Successfully imported huggingface_hub version {_hf_hub_version}") -except importlib_metadata.PackageNotFoundError: - _hf_hub_available = False - - -_inflect_available = importlib.util.find_spec("inflect") is not None -try: - _inflect_version = importlib_metadata.version("inflect") - logger.debug(f"Successfully imported inflect version {_inflect_version}") -except importlib_metadata.PackageNotFoundError: - _inflect_available = False - - -_unidecode_available = importlib.util.find_spec("unidecode") is not None -try: - _unidecode_version = importlib_metadata.version("unidecode") - logger.debug(f"Successfully imported unidecode version {_unidecode_version}") -except importlib_metadata.PackageNotFoundError: - _unidecode_available = False - _onnxruntime_version = "N/A" _onnx_available = importlib.util.find_spec("onnxruntime") is not None if _onnx_available: @@ -186,85 +144,6 @@ except importlib_metadata.PackageNotFoundError: _opencv_available = False -_scipy_available = importlib.util.find_spec("scipy") is not None -try: - _scipy_version = importlib_metadata.version("scipy") - logger.debug(f"Successfully imported scipy version {_scipy_version}") -except importlib_metadata.PackageNotFoundError: - _scipy_available = False - -_librosa_available = importlib.util.find_spec("librosa") is not None -try: - _librosa_version = importlib_metadata.version("librosa") - logger.debug(f"Successfully imported librosa version {_librosa_version}") -except importlib_metadata.PackageNotFoundError: - _librosa_available = False - -_accelerate_available = importlib.util.find_spec("accelerate") is not None -try: - _accelerate_version = importlib_metadata.version("accelerate") - logger.debug(f"Successfully imported accelerate version {_accelerate_version}") -except importlib_metadata.PackageNotFoundError: - _accelerate_available = False - -_xformers_available = importlib.util.find_spec("xformers") is not None -try: - _xformers_version = importlib_metadata.version("xformers") - if _torch_available: - _torch_version = importlib_metadata.version("torch") - if version.Version(_torch_version) < version.Version("1.12"): - raise ValueError("xformers is installed in your environment and requires PyTorch >= 1.12") - - logger.debug(f"Successfully imported xformers version {_xformers_version}") -except importlib_metadata.PackageNotFoundError: - _xformers_available = False - -_k_diffusion_available = importlib.util.find_spec("k_diffusion") is not None -try: - _k_diffusion_version = importlib_metadata.version("k_diffusion") - logger.debug(f"Successfully imported k-diffusion version {_k_diffusion_version}") -except importlib_metadata.PackageNotFoundError: - _k_diffusion_available = False - -_note_seq_available = importlib.util.find_spec("note_seq") is not None -try: - _note_seq_version = importlib_metadata.version("note_seq") - logger.debug(f"Successfully imported note-seq version {_note_seq_version}") -except importlib_metadata.PackageNotFoundError: - _note_seq_available = False - -_wandb_available = importlib.util.find_spec("wandb") is not None -try: - _wandb_version = importlib_metadata.version("wandb") - logger.debug(f"Successfully imported wandb version {_wandb_version }") -except importlib_metadata.PackageNotFoundError: - _wandb_available = False - - -_tensorboard_available = importlib.util.find_spec("tensorboard") -try: - _tensorboard_version = importlib_metadata.version("tensorboard") - logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}") -except importlib_metadata.PackageNotFoundError: - _tensorboard_available = False - - -_compel_available = importlib.util.find_spec("compel") -try: - _compel_version = importlib_metadata.version("compel") - logger.debug(f"Successfully imported compel version {_compel_version}") -except importlib_metadata.PackageNotFoundError: - _compel_available = False - - -_ftfy_available = importlib.util.find_spec("ftfy") is not None -try: - _ftfy_version = importlib_metadata.version("ftfy") - logger.debug(f"Successfully imported ftfy version {_ftfy_version}") -except importlib_metadata.PackageNotFoundError: - _ftfy_available = False - - _bs4_available = importlib.util.find_spec("bs4") is not None try: # importlib metadata under different name @@ -273,13 +152,6 @@ except importlib_metadata.PackageNotFoundError: _bs4_available = False -_torchsde_available = importlib.util.find_spec("torchsde") is not None -try: - _torchsde_version = importlib_metadata.version("torchsde") - logger.debug(f"Successfully imported torchsde version {_torchsde_version}") -except importlib_metadata.PackageNotFoundError: - _torchsde_available = False - _invisible_watermark_available = importlib.util.find_spec("imwatermark") is not None try: _invisible_watermark_version = importlib_metadata.version("invisible-watermark") @@ -287,82 +159,42 @@ except importlib_metadata.PackageNotFoundError: _invisible_watermark_available = False - -_peft_available = importlib.util.find_spec("peft") is not None -try: - _peft_version = importlib_metadata.version("peft") - logger.debug(f"Successfully imported peft version {_peft_version}") -except importlib_metadata.PackageNotFoundError: - _peft_available = False - -_torchvision_available = importlib.util.find_spec("torchvision") is not None -try: - _torchvision_version = importlib_metadata.version("torchvision") - logger.debug(f"Successfully imported torchvision version {_torchvision_version}") -except importlib_metadata.PackageNotFoundError: - _torchvision_available = False - -_sentencepiece_available = importlib.util.find_spec("sentencepiece") is not None -try: - _sentencepiece_version = importlib_metadata.version("sentencepiece") - logger.info(f"Successfully imported sentencepiece version {_sentencepiece_version}") -except importlib_metadata.PackageNotFoundError: - _sentencepiece_available = False - -_matplotlib_available = importlib.util.find_spec("matplotlib") is not None -try: - _matplotlib_version = importlib_metadata.version("matplotlib") - logger.debug(f"Successfully imported matplotlib version {_matplotlib_version}") -except importlib_metadata.PackageNotFoundError: - _matplotlib_available = False - -_timm_available = importlib.util.find_spec("timm") is not None -if _timm_available: - try: - _timm_version = importlib_metadata.version("timm") - logger.info(f"Timm version {_timm_version} available.") - except importlib_metadata.PackageNotFoundError: - _timm_available = False - - -def is_timm_available(): - return _timm_available - - -_bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None -try: - _bitsandbytes_version = importlib_metadata.version("bitsandbytes") - logger.debug(f"Successfully imported bitsandbytes version {_bitsandbytes_version}") -except importlib_metadata.PackageNotFoundError: - _bitsandbytes_available = False - -_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ) - -_imageio_available = importlib.util.find_spec("imageio") is not None -if _imageio_available: - try: - _imageio_version = importlib_metadata.version("imageio") - logger.debug(f"Successfully imported imageio version {_imageio_version}") - - except importlib_metadata.PackageNotFoundError: - _imageio_available = False - -_is_gguf_available = importlib.util.find_spec("gguf") is not None -if _is_gguf_available: +_torch_xla_available, _torch_xla_version = _is_package_available("torch_xla") +_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu") +_transformers_available, _transformers_version = _is_package_available("transformers") +_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub") +_inflect_available, _inflect_version = _is_package_available("inflect") +_unidecode_available, _unidecode_version = _is_package_available("unidecode") +_k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion") +_note_seq_available, _note_seq_version = _is_package_available("note_seq") +_wandb_available, _wandb_version = _is_package_available("wandb") +_tensorboard_available, _tensorboard_version = _is_package_available("tensorboard") +_compel_available, _compel_version = _is_package_available("compel") +_sentencepiece_available, _sentencepiece_version = _is_package_available("sentencepiece") +_torchsde_available, _torchsde_version = _is_package_available("torchsde") +_peft_available, _peft_version = _is_package_available("peft") +_torchvision_available, _torchvision_version = _is_package_available("torchvision") +_matplotlib_available, _matplotlib_version = _is_package_available("matplotlib") +_timm_available, _timm_version = _is_package_available("timm") +_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes") +_imageio_available, _imageio_version = _is_package_available("imageio") +_ftfy_available, _ftfy_version = _is_package_available("ftfy") +_scipy_available, _scipy_version = _is_package_available("scipy") +_librosa_available, _librosa_version = _is_package_available("librosa") +_accelerate_available, _accelerate_version = _is_package_available("accelerate") +_xformers_available, _xformers_version = _is_package_available("xformers") +_gguf_available, _gguf_version = _is_package_available("gguf") +_torchao_available, _torchao_version = _is_package_available("torchao") +_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes") +_torchao_available, _torchao_version = _is_package_available("torchao") + +_optimum_quanto_available = importlib.util.find_spec("optimum") is not None +if _optimum_quanto_available: try: - _gguf_version = importlib_metadata.version("gguf") - logger.debug(f"Successfully import gguf version {_gguf_version}") + _optimum_quanto_version = importlib_metadata.version("optimum_quanto") + logger.debug(f"Successfully import optimum-quanto version {_optimum_quanto_version}") except importlib_metadata.PackageNotFoundError: - _is_gguf_available = False - - -_is_torchao_available = importlib.util.find_spec("torchao") is not None -if _is_torchao_available: - try: - _torchao_version = importlib_metadata.version("torchao") - logger.debug(f"Successfully import torchao version {_torchao_version}") - except importlib_metadata.PackageNotFoundError: - _is_torchao_available = False + _optimum_quanto_available = False def is_torch_available(): @@ -486,11 +318,19 @@ def is_imageio_available(): def is_gguf_available(): - return _is_gguf_available + return _gguf_available def is_torchao_available(): - return _is_torchao_available + return _torchao_available + + +def is_optimum_quanto_available(): + return _optimum_quanto_available + + +def is_timm_available(): + return _timm_available # docstyle-ignore @@ -636,6 +476,11 @@ def is_torchao_available(): torchao` """ +QUANTO_IMPORT_ERROR = """ +{0} requires the optimum-quanto library but it was not found in your environment. You can install it with pip: `pip +install optimum-quanto` +""" + BACKENDS_MAPPING = OrderedDict( [ ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), @@ -663,6 +508,7 @@ def is_torchao_available(): ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)), ("gguf", (is_gguf_available, GGUF_IMPORT_ERROR)), ("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)), + ("quanto", (is_optimum_quanto_available, QUANTO_IMPORT_ERROR)), ] ) @@ -844,11 +690,26 @@ def is_gguf_version(operation: str, version: str): version (`str`): A version string """ - if not _is_gguf_available: + if not _gguf_available: return False return compare_versions(parse(_gguf_version), operation, version) +def is_torchao_version(operation: str, version: str): + """ + Compares the current torchao version to a given reference with an operation. + + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _torchao_available: + return False + return compare_versions(parse(_torchao_version), operation, version) + + def is_k_diffusion_version(operation: str, version: str): """ Compares the current k-diffusion version to a given reference with an operation. @@ -864,6 +725,21 @@ def is_k_diffusion_version(operation: str, version: str): return compare_versions(parse(_k_diffusion_version), operation, version) +def is_optimum_quanto_version(operation: str, version: str): + """ + Compares the current Accelerate version to a given reference with an operation. + + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _optimum_quanto_available: + return False + return compare_versions(parse(_optimum_quanto_version), operation, version) + + def get_objects_from_module(module): """ Returns a dict of object names and values in a module, while skipping private/internal objects diff --git a/src/diffusers/utils/remote_utils.py b/src/diffusers/utils/remote_utils.py index 12bcc94af74f..fbce33d97f54 100644 --- a/src/diffusers/utils/remote_utils.py +++ b/src/diffusers/utils/remote_utils.py @@ -55,7 +55,7 @@ def detect_image_type(data: bytes) -> str: return "unknown" -def check_inputs( +def check_inputs_decode( endpoint: str, tensor: "torch.Tensor", processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None, @@ -89,7 +89,7 @@ def check_inputs( ) -def postprocess( +def postprocess_decode( response: requests.Response, processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None, output_type: Literal["mp4", "pil", "pt"] = "pil", @@ -142,7 +142,7 @@ def postprocess( return output -def prepare( +def prepare_decode( tensor: "torch.Tensor", processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None, do_scaling: bool = True, @@ -293,7 +293,7 @@ def remote_decode( standard_warn=False, ) output_tensor_type = "binary" - check_inputs( + check_inputs_decode( endpoint, tensor, processor, @@ -309,7 +309,7 @@ def remote_decode( height, width, ) - kwargs = prepare( + kwargs = prepare_decode( tensor=tensor, processor=processor, do_scaling=do_scaling, @@ -324,7 +324,7 @@ def remote_decode( response = requests.post(endpoint, **kwargs) if not response.ok: raise RuntimeError(response.json()) - output = postprocess( + output = postprocess_decode( response=response, processor=processor, output_type=output_type, @@ -332,3 +332,94 @@ def remote_decode( partial_postprocess=partial_postprocess, ) return output + + +def check_inputs_encode( + endpoint: str, + image: Union["torch.Tensor", Image.Image], + scaling_factor: Optional[float] = None, + shift_factor: Optional[float] = None, +): + pass + + +def postprocess_encode( + response: requests.Response, +): + output_tensor = response.content + parameters = response.headers + shape = json.loads(parameters["shape"]) + dtype = parameters["dtype"] + torch_dtype = DTYPE_MAP[dtype] + output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape) + return output_tensor + + +def prepare_encode( + image: Union["torch.Tensor", Image.Image], + scaling_factor: Optional[float] = None, + shift_factor: Optional[float] = None, +): + headers = {} + parameters = {} + if scaling_factor is not None: + parameters["scaling_factor"] = scaling_factor + if shift_factor is not None: + parameters["shift_factor"] = shift_factor + if isinstance(image, torch.Tensor): + data = safetensors.torch._tobytes(image, "tensor") + parameters["shape"] = list(image.shape) + parameters["dtype"] = str(image.dtype).split(".")[-1] + else: + buffer = io.BytesIO() + image.save(buffer, format="PNG") + data = buffer.getvalue() + return {"data": data, "params": parameters, "headers": headers} + + +def remote_encode( + endpoint: str, + image: Union["torch.Tensor", Image.Image], + scaling_factor: Optional[float] = None, + shift_factor: Optional[float] = None, +) -> "torch.Tensor": + """ + Hugging Face Hybrid Inference that allow running VAE encode remotely. + + Args: + endpoint (`str`): + Endpoint for Remote Decode. + image (`torch.Tensor` or `PIL.Image.Image`): + Image to be encoded. + scaling_factor (`float`, *optional*): + Scaling is applied when passed e.g. [`latents * self.vae.config.scaling_factor`]. + - SD v1: 0.18215 + - SD XL: 0.13025 + - Flux: 0.3611 + If `None`, input must be passed with scaling applied. + shift_factor (`float`, *optional*): + Shift is applied when passed e.g. `latents - self.vae.config.shift_factor`. + - Flux: 0.1159 + If `None`, input must be passed with scaling applied. + + Returns: + output (`torch.Tensor`). + """ + check_inputs_encode( + endpoint, + image, + scaling_factor, + shift_factor, + ) + kwargs = prepare_encode( + image=image, + scaling_factor=scaling_factor, + shift_factor=shift_factor, + ) + response = requests.post(endpoint, **kwargs) + if not response.ok: + raise RuntimeError(response.json()) + output = postprocess_encode( + response=response, + ) + return output diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 7eda13716025..2a3feae967d7 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -101,6 +101,8 @@ mps_backend_registered = hasattr(torch.backends, "mps") torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device + from .torch_utils import get_torch_cuda_device_capability + def torch_all_close(a, b, *args, **kwargs): if not is_torch_available(): @@ -282,6 +284,20 @@ def require_torch_gpu(test_case): ) +def require_torch_cuda_compatibility(expected_compute_capability): + def decorator(test_case): + if not torch.cuda.is_available(): + return unittest.skip(test_case) + else: + current_compute_capability = get_torch_cuda_device_capability() + return unittest.skipUnless( + float(current_compute_capability) == float(expected_compute_capability), + "Test not supported for this compute capability.", + ) + + return decorator + + # These decorators are for accelerator-specific behaviours that are not GPU-specific def require_torch_accelerator(test_case): """Decorator marking a test that requires an accelerator backend and PyTorch.""" diff --git a/tests/lora/test_lora_layers_cogview4.py b/tests/lora/test_lora_layers_cogview4.py new file mode 100644 index 000000000000..178de2069b7e --- /dev/null +++ b/tests/lora/test_lora_layers_cogview4.py @@ -0,0 +1,174 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import tempfile +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, GlmModel + +from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler +from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device + + +sys.path.append(".") + +from utils import PeftLoraLoaderMixinTests # noqa: E402 + + +class TokenizerWrapper: + @staticmethod + def from_pretrained(*args, **kwargs): + return AutoTokenizer.from_pretrained( + "hf-internal-testing/tiny-random-cogview4", subfolder="tokenizer", trust_remote_code=True + ) + + +@require_peft_backend +@skip_mps +class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): + pipeline_class = CogView4Pipeline + scheduler_cls = FlowMatchEulerDiscreteScheduler + scheduler_classes = [FlowMatchEulerDiscreteScheduler] + scheduler_kwargs = {} + + transformer_kwargs = { + "patch_size": 2, + "in_channels": 4, + "num_layers": 2, + "attention_head_dim": 4, + "num_attention_heads": 4, + "out_channels": 4, + "text_embed_dim": 32, + "time_embed_dim": 8, + "condition_dim": 4, + } + transformer_cls = CogView4Transformer2DModel + vae_kwargs = { + "block_out_channels": [32, 64], + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], + "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], + "latent_channels": 4, + "sample_size": 128, + } + vae_cls = AutoencoderKL + tokenizer_cls, tokenizer_id, tokenizer_subfolder = ( + TokenizerWrapper, + "hf-internal-testing/tiny-random-cogview4", + "tokenizer", + ) + text_encoder_cls, text_encoder_id, text_encoder_subfolder = ( + GlmModel, + "hf-internal-testing/tiny-random-cogview4", + "text_encoder", + ) + + @property + def output_shape(self): + return (1, 32, 32, 3) + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 16 + num_channels = 4 + sizes = (4, 4) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "", + "num_inference_steps": 1, + "guidance_scale": 6.0, + "height": 32, + "width": 32, + "max_sequence_length": sequence_length, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs + + def test_simple_inference_with_text_lora_denoiser_fused_multi(self): + super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) + + def test_simple_inference_with_text_denoiser_lora_unfused(self): + super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) + + def test_simple_inference_save_pretrained(self): + """ + Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained + """ + for scheduler_cls in self.scheduler_classes: + components, _, _ = self.get_dummy_components(scheduler_cls) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue(output_no_lora.shape == self.output_shape) + + images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + + pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname) + pipe_from_pretrained.to(torch_device) + + images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0] + + self.assertTrue( + np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3), + "Loading from saved checkpoints should give same results.", + ) + + @unittest.skip("Not supported in CogView4.") + def test_simple_inference_with_text_denoiser_block_scale(self): + pass + + @unittest.skip("Not supported in CogView4.") + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + pass + + @unittest.skip("Not supported in CogView4.") + def test_modify_padding_mode(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in CogView4.") + def test_simple_inference_with_partial_text_lora(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in CogView4.") + def test_simple_inference_with_text_lora(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in CogView4.") + def test_simple_inference_with_text_lora_and_scale(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in CogView4.") + def test_simple_inference_with_text_lora_fused(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in CogView4.") + def test_simple_inference_with_text_lora_save_load(self): + pass diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 06bbcc62a0d5..860aa6511689 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -371,9 +371,8 @@ def test_with_norm_in_state_dict(self): lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue( - cap_logger.out.startswith( - "The provided state dict contains normalization layers in addition to LoRA layers" - ) + "The provided state dict contains normalization layers in addition to LoRA layers" + in cap_logger.out ) self.assertTrue(len(pipe.transformer._transformer_norm_layers) > 0) @@ -392,7 +391,7 @@ def test_with_norm_in_state_dict(self): pipe.load_lora_weights(norm_state_dict) self.assertTrue( - cap_logger.out.startswith("Unsupported keys found in state dict when trying to load normalization layers") + "Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out ) def test_lora_parameter_expanded_shapes(self): diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 17f6c9ccdf98..8cdb43c9d085 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1948,6 +1948,50 @@ def set_pad_mode(network, mode="circular"): _, _, inputs = self.get_dummy_inputs() _ = pipe(**inputs)[0] + def test_logs_info_when_no_lora_keys_found(self): + scheduler_cls = self.scheduler_classes[0] + # Skip text encoder check for now as that is handled with `transformers`. + components, _, _ = self.get_dummy_components(scheduler_cls) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] + + no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)} + logger = logging.get_logger("diffusers.loaders.peft") + logger.setLevel(logging.WARNING) + + with CaptureLogger(logger) as cap_logger: + pipe.load_lora_weights(no_op_state_dict) + out_after_lora_attempt = pipe(**inputs, generator=torch.manual_seed(0))[0] + + denoiser = getattr(pipe, "unet") if self.unet_kwargs is not None else getattr(pipe, "transformer") + self.assertTrue(cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}")) + self.assertTrue(np.allclose(original_out, out_after_lora_attempt, atol=1e-5, rtol=1e-5)) + + # test only for text encoder + for lora_module in self.pipeline_class._lora_loadable_modules: + if "text_encoder" in lora_module: + text_encoder = getattr(pipe, lora_module) + if lora_module == "text_encoder": + prefix = "text_encoder" + elif lora_module == "text_encoder_2": + prefix = "text_encoder_2" + + logger = logging.get_logger("diffusers.loaders.lora_base") + logger.setLevel(logging.WARNING) + + with CaptureLogger(logger) as cap_logger: + self.pipeline_class.load_lora_into_text_encoder( + no_op_state_dict, network_alphas=None, text_encoder=text_encoder, prefix=prefix + ) + + self.assertTrue( + cap_logger.out.startswith(f"No LoRA keys associated to {text_encoder.__class__.__name__}") + ) + def test_set_adapters_match_attention_kwargs(self): """Test to check if outputs after `set_adapters()` and attention kwargs match.""" call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys() diff --git a/tests/pipelines/ltx/test_ltx_condition.py b/tests/pipelines/ltx/test_ltx_condition.py new file mode 100644 index 000000000000..dbb9a740b433 --- /dev/null +++ b/tests/pipelines/ltx/test_ltx_condition.py @@ -0,0 +1,284 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import ( + AutoencoderKLLTXVideo, + FlowMatchEulerDiscreteScheduler, + LTXConditionPipeline, + LTXVideoTransformer3DModel, +) +from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class LTXConditionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = LTXConditionPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"}) + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = LTXVideoTransformer3DModel( + in_channels=8, + out_channels=8, + patch_size=1, + patch_size_t=1, + num_attention_heads=4, + attention_head_dim=8, + cross_attention_dim=32, + num_layers=1, + caption_channels=32, + ) + + torch.manual_seed(0) + vae = AutoencoderKLLTXVideo( + in_channels=3, + out_channels=3, + latent_channels=8, + block_out_channels=(8, 8, 8, 8), + decoder_block_out_channels=(8, 8, 8, 8), + layers_per_block=(1, 1, 1, 1, 1), + decoder_layers_per_block=(1, 1, 1, 1, 1), + spatio_temporal_scaling=(True, True, False, False), + decoder_spatio_temporal_scaling=(True, True, False, False), + decoder_inject_noise=(False, False, False, False, False), + upsample_residual=(False, False, False, False), + upsample_factor=(1, 1, 1, 1), + timestep_conditioning=False, + patch_size=1, + patch_size_t=1, + encoder_causal=True, + decoder_causal=False, + ) + vae.use_framewise_encoding = False + vae.use_framewise_decoding = False + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler() + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0, use_conditions=False): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + image = torch.randn((1, 3, 32, 32), generator=generator, device=device) + if use_conditions: + conditions = LTXVideoCondition( + image=image, + ) + else: + conditions = None + + inputs = { + "conditions": conditions, + "image": None if use_conditions else image, + "prompt": "dance monkey", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 3.0, + "height": 32, + "width": 32, + # 8 * k + 1 is the recommendation + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs2 = self.get_dummy_inputs(device, use_conditions=True) + video = pipe(**inputs).frames + generated_video = video[0] + video2 = pipe(**inputs2).frames + generated_video2 = video2[0] + + self.assertEqual(generated_video.shape, (9, 3, 32, 32)) + + max_diff = np.abs(generated_video - generated_video2).max() + self.assertLessEqual(max_diff, 1e-3) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_vae_tiling(self, expected_diff_max: float = 0.2): + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_sample_stride_height=64, + tile_sample_stride_width=64, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) diff --git a/tests/pipelines/lumina/test_lumina_nextdit.py b/tests/pipelines/lumina/test_lumina_nextdit.py index 034a0185d338..0c1fe8eb2fcd 100644 --- a/tests/pipelines/lumina/test_lumina_nextdit.py +++ b/tests/pipelines/lumina/test_lumina_nextdit.py @@ -5,7 +5,13 @@ import torch from transformers import AutoTokenizer, GemmaConfig, GemmaForCausalLM -from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + LuminaNextDiT2DModel, + LuminaPipeline, + LuminaText2ImgPipeline, +) from diffusers.utils.testing_utils import ( backend_empty_cache, numpy_cosine_similarity_distance, @@ -17,8 +23,8 @@ from ..test_pipelines_common import PipelineTesterMixin -class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterMixin): - pipeline_class = LuminaText2ImgPipeline +class LuminaPipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = LuminaPipeline params = frozenset( [ "prompt", @@ -99,11 +105,17 @@ def get_dummy_inputs(self, device, seed=0): def test_xformers_attention_forwardGenerator_pass(self): pass + def test_deprecation_raises_warning(self): + with self.assertWarns(FutureWarning) as warning: + _ = LuminaText2ImgPipeline(**self.get_dummy_components()).to(torch_device) + warning_message = str(warning.warnings[0].message) + assert "renamed to `LuminaPipeline`" in warning_message + @slow @require_torch_accelerator -class LuminaText2ImgPipelineSlowTests(unittest.TestCase): - pipeline_class = LuminaText2ImgPipeline +class LuminaPipelineSlowTests(unittest.TestCase): + pipeline_class = LuminaPipeline repo_id = "Alpha-VLLM/Lumina-Next-SFT-diffusers" def setUp(self): diff --git a/tests/pipelines/lumina2/test_pipeline_lumina2.py b/tests/pipelines/lumina2/test_pipeline_lumina2.py index aa0571559b45..33fc870bcd34 100644 --- a/tests/pipelines/lumina2/test_pipeline_lumina2.py +++ b/tests/pipelines/lumina2/test_pipeline_lumina2.py @@ -6,15 +6,17 @@ from diffusers import ( AutoencoderKL, FlowMatchEulerDiscreteScheduler, + Lumina2Pipeline, Lumina2Text2ImgPipeline, Lumina2Transformer2DModel, ) +from diffusers.utils.testing_utils import torch_device from ..test_pipelines_common import PipelineTesterMixin -class Lumina2Text2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterMixin): - pipeline_class = Lumina2Text2ImgPipeline +class Lumina2PipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = Lumina2Pipeline params = frozenset( [ "prompt", @@ -115,3 +117,9 @@ def get_dummy_inputs(self, device, seed=0): "output_type": "np", } return inputs + + def test_deprecation_raises_warning(self): + with self.assertWarns(FutureWarning) as warning: + _ = Lumina2Text2ImgPipeline(**self.get_dummy_components()).to(torch_device) + warning_message = str(warning.warnings[0].message) + assert "renamed to `Lumina2Pipeline`" in warning_message diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index 964b55fde651..423c2b8ab146 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -19,7 +19,7 @@ UNet2DConditionModel, ) from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings -from diffusers.utils.testing_utils import torch_device +from diffusers.utils.testing_utils import require_torch_gpu, torch_device class IsSafetensorsCompatibleTests(unittest.TestCase): @@ -826,3 +826,104 @@ def test_video_to_video(self): with io.StringIO() as stderr, contextlib.redirect_stderr(stderr): _ = pipe(**inputs) self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled") + + +@require_torch_gpu +class PipelineDeviceAndDtypeStabilityTests(unittest.TestCase): + expected_pipe_device = torch.device("cuda:0") + expected_pipe_dtype = torch.float64 + + def get_dummy_components_image_generation(self): + cross_attention_dim = 8 + + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(4, 8), + layers_per_block=1, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=cross_attention_dim, + norm_num_groups=2, + ) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[4, 8], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + norm_num_groups=2, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=cross_attention_dim, + intermediate_size=16, + layer_norm_eps=1e-05, + num_attention_heads=2, + num_hidden_layers=2, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + "image_encoder": None, + } + return components + + def test_deterministic_device(self): + components = self.get_dummy_components_image_generation() + + pipe = StableDiffusionPipeline(**components) + pipe.to(device=torch_device, dtype=torch.float32) + + pipe.unet.to(device="cpu") + pipe.vae.to(device="cuda") + pipe.text_encoder.to(device="cuda:0") + + pipe_device = pipe.device + + self.assertEqual( + self.expected_pipe_device, + pipe_device, + f"Wrong expected device. Expected {self.expected_pipe_device}. Got {pipe_device}.", + ) + + def test_deterministic_dtype(self): + components = self.get_dummy_components_image_generation() + + pipe = StableDiffusionPipeline(**components) + pipe.to(device=torch_device, dtype=torch.float32) + + pipe.unet.to(dtype=torch.float16) + pipe.vae.to(dtype=torch.float32) + pipe.text_encoder.to(dtype=torch.float64) + + pipe_dtype = pipe.dtype + + self.assertEqual( + self.expected_pipe_dtype, + pipe_dtype, + f"Wrong expected dtype. Expected {self.expected_pipe_dtype}. Got {pipe_dtype}.", + ) diff --git a/tests/quantization/__init__.py b/tests/quantization/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 6f85e6f38955..a80286fbb8dd 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -33,6 +33,7 @@ numpy_cosine_similarity_distance, require_accelerate, require_bitsandbytes_version_greater, + require_peft_backend, require_torch, require_torch_gpu, require_transformers_version_greater, @@ -54,29 +55,8 @@ def get_some_linear_layer(model): if is_torch_available(): import torch - import torch.nn as nn - class LoRALayer(nn.Module): - """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only - - Taken from - https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77 - """ - - def __init__(self, module: nn.Module, rank: int): - super().__init__() - self.module = module - self.adapter = nn.Sequential( - nn.Linear(module.in_features, rank, bias=False), - nn.Linear(rank, module.out_features, bias=False), - ) - small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5 - nn.init.normal_(self.adapter[0].weight, std=small_std) - nn.init.zeros_(self.adapter[1].weight) - self.adapter.to(module.weight.device) - - def forward(self, input, *args, **kwargs): - return self.module(input, *args, **kwargs) + self.adapter(input) + from ..utils import LoRALayer, get_memory_consumption_stat if is_bitsandbytes_available(): @@ -96,6 +76,8 @@ class Base4bitTests(unittest.TestCase): # This was obtained on audace so the number might slightly change expected_rel_difference = 3.69 + expected_memory_saving_ratio = 0.8 + prompt = "a beautiful sunset amidst the mountains." num_inference_steps = 10 seed = 0 @@ -140,8 +122,10 @@ def setUp(self): ) def tearDown(self): - del self.model_fp16 - del self.model_4bit + if hasattr(self, "model_fp16"): + del self.model_fp16 + if hasattr(self, "model_4bit"): + del self.model_4bit gc.collect() torch.cuda.empty_cache() @@ -180,6 +164,32 @@ def test_memory_footprint(self): linear = get_some_linear_layer(self.model_4bit) self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit) + def test_model_memory_usage(self): + # Delete to not let anything interfere. + del self.model_4bit, self.model_fp16 + + # Re-instantiate. + inputs = self.get_dummy_inputs() + inputs = { + k: v.to(device=torch_device, dtype=torch.float16) for k, v in inputs.items() if not isinstance(v, bool) + } + model_fp16 = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", torch_dtype=torch.float16 + ).to(torch_device) + unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs) + del model_fp16 + + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.float16, + ) + model_4bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=nf4_config, torch_dtype=torch.float16 + ) + quantized_model_memory = get_memory_consumption_stat(model_4bit, inputs) + assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_saving_ratio + def test_original_dtype(self): r""" A simple test to check if the model succesfully stores the original dtype @@ -659,6 +669,7 @@ def test_quality(self): max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) self.assertTrue(max_diff < 1e-3) + @require_peft_backend def test_lora_loading(self): self.pipeline_4bit.load_lora_weights( hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd" diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index 4be420e7dffa..4964f8c9af07 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -60,29 +60,8 @@ def get_some_linear_layer(model): if is_torch_available(): import torch - import torch.nn as nn - class LoRALayer(nn.Module): - """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only - - Taken from - https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_8bit.py#L62C5-L78C77 - """ - - def __init__(self, module: nn.Module, rank: int): - super().__init__() - self.module = module - self.adapter = nn.Sequential( - nn.Linear(module.in_features, rank, bias=False), - nn.Linear(rank, module.out_features, bias=False), - ) - small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5 - nn.init.normal_(self.adapter[0].weight, std=small_std) - nn.init.zeros_(self.adapter[1].weight) - self.adapter.to(module.weight.device) - - def forward(self, input, *args, **kwargs): - return self.module(input, *args, **kwargs) + self.adapter(input) + from ..utils import LoRALayer, get_memory_consumption_stat if is_bitsandbytes_available(): @@ -102,6 +81,8 @@ class Base8bitTests(unittest.TestCase): # This was obtained on audace so the number might slightly change expected_rel_difference = 1.94 + expected_memory_saving_ratio = 0.7 + prompt = "a beautiful sunset amidst the mountains." num_inference_steps = 10 seed = 0 @@ -142,8 +123,10 @@ def setUp(self): ) def tearDown(self): - del self.model_fp16 - del self.model_8bit + if hasattr(self, "model_fp16"): + del self.model_fp16 + if hasattr(self, "model_8bit"): + del self.model_8bit gc.collect() torch.cuda.empty_cache() @@ -182,6 +165,28 @@ def test_memory_footprint(self): linear = get_some_linear_layer(self.model_8bit) self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params) + def test_model_memory_usage(self): + # Delete to not let anything interfere. + del self.model_8bit, self.model_fp16 + + # Re-instantiate. + inputs = self.get_dummy_inputs() + inputs = { + k: v.to(device=torch_device, dtype=torch.float16) for k, v in inputs.items() if not isinstance(v, bool) + } + model_fp16 = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", torch_dtype=torch.float16 + ).to(torch_device) + unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs) + del model_fp16 + + config = BitsAndBytesConfig(load_in_8bit=True) + model_8bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=config, torch_dtype=torch.float16 + ) + quantized_model_memory = get_memory_consumption_stat(model_8bit, inputs) + assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_saving_ratio + def test_original_dtype(self): r""" A simple test to check if the model succesfully stores the original dtype @@ -248,7 +253,7 @@ def test_llm_skip(self): self.assertTrue(linear.weight.dtype == torch.int8) self.assertTrue(isinstance(linear, bnb.nn.Linear8bitLt)) - self.assertTrue(isinstance(model_8bit.proj_out, nn.Linear)) + self.assertTrue(isinstance(model_8bit.proj_out, torch.nn.Linear)) self.assertTrue(model_8bit.proj_out.weight.dtype != torch.int8) def test_config_from_pretrained(self): diff --git a/tests/quantization/quanto/__init__.py b/tests/quantization/quanto/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/quantization/quanto/test_quanto.py b/tests/quantization/quanto/test_quanto.py new file mode 100644 index 000000000000..9eb6958d2183 --- /dev/null +++ b/tests/quantization/quanto/test_quanto.py @@ -0,0 +1,328 @@ +import gc +import tempfile +import unittest + +from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig +from diffusers.models.attention_processor import Attention +from diffusers.utils import is_optimum_quanto_available, is_torch_available +from diffusers.utils.testing_utils import ( + nightly, + numpy_cosine_similarity_distance, + require_accelerate, + require_big_gpu_with_torch_cuda, + require_torch_cuda_compatibility, + torch_device, +) + + +if is_optimum_quanto_available(): + from optimum.quanto import QLinear + +if is_torch_available(): + import torch + + from ..utils import LoRALayer, get_memory_consumption_stat + + +@nightly +@require_big_gpu_with_torch_cuda +@require_accelerate +class QuantoBaseTesterMixin: + model_id = None + pipeline_model_id = None + model_cls = None + torch_dtype = torch.bfloat16 + # the expected reduction in peak memory used compared to an unquantized model expressed as a percentage + expected_memory_reduction = 0.0 + keep_in_fp32_module = "" + modules_to_not_convert = "" + _test_torch_compile = False + + def setUp(self): + torch.cuda.reset_peak_memory_stats() + torch.cuda.empty_cache() + gc.collect() + + def tearDown(self): + torch.cuda.reset_peak_memory_stats() + torch.cuda.empty_cache() + gc.collect() + + def get_dummy_init_kwargs(self): + return {"weights_dtype": "float8"} + + def get_dummy_model_init_kwargs(self): + return { + "pretrained_model_name_or_path": self.model_id, + "torch_dtype": self.torch_dtype, + "quantization_config": QuantoConfig(**self.get_dummy_init_kwargs()), + } + + def test_quanto_layers(self): + model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + assert isinstance(module, QLinear) + + def test_quanto_memory_usage(self): + inputs = self.get_dummy_inputs() + inputs = { + k: v.to(device=torch_device, dtype=torch.bfloat16) for k, v in inputs.items() if not isinstance(v, bool) + } + + unquantized_model = self.model_cls.from_pretrained(self.model_id, torch_dtype=self.torch_dtype) + unquantized_model.to(torch_device) + unquantized_model_memory = get_memory_consumption_stat(unquantized_model, inputs) + + quantized_model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) + quantized_model.to(torch_device) + quantized_model_memory = get_memory_consumption_stat(quantized_model, inputs) + + assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_reduction + + def test_keep_modules_in_fp32(self): + r""" + A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32. + Also ensures if inference works. + """ + _keep_in_fp32_modules = self.model_cls._keep_in_fp32_modules + self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module + + model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) + model.to("cuda") + + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + if name in model._keep_in_fp32_modules: + assert module.weight.dtype == torch.float32 + self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules + + def test_modules_to_not_convert(self): + init_kwargs = self.get_dummy_model_init_kwargs() + + quantization_config_kwargs = self.get_dummy_init_kwargs() + quantization_config_kwargs.update({"modules_to_not_convert": self.modules_to_not_convert}) + quantization_config = QuantoConfig(**quantization_config_kwargs) + + init_kwargs.update({"quantization_config": quantization_config}) + + model = self.model_cls.from_pretrained(**init_kwargs) + model.to("cuda") + + for name, module in model.named_modules(): + if name in self.modules_to_not_convert: + assert not isinstance(module, QLinear) + + def test_dtype_assignment(self): + model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) + + with self.assertRaises(ValueError): + # Tries with a `dtype` + model.to(torch.float16) + + with self.assertRaises(ValueError): + # Tries with a `device` and `dtype` + model.to(device="cuda:0", dtype=torch.float16) + + with self.assertRaises(ValueError): + # Tries with a cast + model.float() + + with self.assertRaises(ValueError): + # Tries with a cast + model.half() + + # This should work + model.to("cuda") + + def test_serialization(self): + model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) + inputs = self.get_dummy_inputs() + + model.to(torch_device) + with torch.no_grad(): + model_output = model(**inputs) + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + saved_model = self.model_cls.from_pretrained( + tmp_dir, + torch_dtype=torch.bfloat16, + ) + + saved_model.to(torch_device) + with torch.no_grad(): + saved_model_output = saved_model(**inputs) + + assert torch.allclose(model_output.sample, saved_model_output.sample, rtol=1e-5, atol=1e-5) + + def test_torch_compile(self): + if not self._test_torch_compile: + return + + model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) + compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True, dynamic=False) + + model.to(torch_device) + with torch.no_grad(): + model_output = model(**self.get_dummy_inputs()).sample + + compiled_model.to(torch_device) + with torch.no_grad(): + compiled_model_output = compiled_model(**self.get_dummy_inputs()).sample + + model_output = model_output.detach().float().cpu().numpy() + compiled_model_output = compiled_model_output.detach().float().cpu().numpy() + + max_diff = numpy_cosine_similarity_distance(model_output.flatten(), compiled_model_output.flatten()) + assert max_diff < 1e-3 + + def test_device_map_error(self): + with self.assertRaises(ValueError): + _ = self.model_cls.from_pretrained( + **self.get_dummy_model_init_kwargs(), device_map={0: "8GB", "cpu": "16GB"} + ) + + +class FluxTransformerQuantoMixin(QuantoBaseTesterMixin): + model_id = "hf-internal-testing/tiny-flux-transformer" + model_cls = FluxTransformer2DModel + pipeline_cls = FluxPipeline + torch_dtype = torch.bfloat16 + keep_in_fp32_module = "proj_out" + modules_to_not_convert = ["proj_out"] + _test_torch_compile = False + + def get_dummy_inputs(self): + return { + "hidden_states": torch.randn((1, 4096, 64), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "encoder_hidden_states": torch.randn( + (1, 512, 4096), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "pooled_projections": torch.randn( + (1, 768), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype), + "img_ids": torch.randn((4096, 3), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "txt_ids": torch.randn((512, 3), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype), + } + + def get_dummy_training_inputs(self, device=None, seed: int = 0): + batch_size = 1 + num_latent_channels = 4 + num_image_channels = 3 + height = width = 4 + sequence_length = 48 + embedding_dim = 32 + + torch.manual_seed(seed) + hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16) + + torch.manual_seed(seed) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to( + device, dtype=torch.bfloat16 + ) + + torch.manual_seed(seed) + pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16) + + torch.manual_seed(seed) + text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16) + + torch.manual_seed(seed) + image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16) + + timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "pooled_projections": pooled_prompt_embeds, + "txt_ids": text_ids, + "img_ids": image_ids, + "timestep": timestep, + } + + def test_model_cpu_offload(self): + init_kwargs = self.get_dummy_init_kwargs() + transformer = self.model_cls.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + quantization_config=QuantoConfig(**init_kwargs), + subfolder="transformer", + torch_dtype=torch.bfloat16, + ) + pipe = self.pipeline_cls.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", transformer=transformer, torch_dtype=torch.bfloat16 + ) + pipe.enable_model_cpu_offload(device=torch_device) + _ = pipe("a cat holding a sign that says hello", num_inference_steps=2) + + def test_training(self): + quantization_config = QuantoConfig(**self.get_dummy_init_kwargs()) + quantized_model = self.model_cls.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ).to(torch_device) + + for param in quantized_model.parameters(): + # freeze the model as only adapter layers will be trained + param.requires_grad = False + if param.ndim == 1: + param.data = param.data.to(torch.float32) + + for _, module in quantized_model.named_modules(): + if isinstance(module, Attention): + module.to_q = LoRALayer(module.to_q, rank=4) + module.to_k = LoRALayer(module.to_k, rank=4) + module.to_v = LoRALayer(module.to_v, rank=4) + + with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16): + inputs = self.get_dummy_training_inputs(torch_device) + output = quantized_model(**inputs)[0] + output.norm().backward() + + for module in quantized_model.modules(): + if isinstance(module, LoRALayer): + self.assertTrue(module.adapter[1].weight.grad is not None) + + +class FluxTransformerFloat8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase): + expected_memory_reduction = 0.6 + + def get_dummy_init_kwargs(self): + return {"weights_dtype": "float8"} + + +class FluxTransformerInt8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase): + expected_memory_reduction = 0.6 + _test_torch_compile = True + + def get_dummy_init_kwargs(self): + return {"weights_dtype": "int8"} + + +@require_torch_cuda_compatibility(8.0) +class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase): + expected_memory_reduction = 0.55 + + def get_dummy_init_kwargs(self): + return {"weights_dtype": "int4"} + + +@require_torch_cuda_compatibility(8.0) +class FluxTransformerInt2WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase): + expected_memory_reduction = 0.65 + + def get_dummy_init_kwargs(self): + return {"weights_dtype": "int2"} diff --git a/tests/quantization/torchao/__init__.py b/tests/quantization/torchao/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index e14a1cc0369e..0e671307dd18 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -50,27 +50,7 @@ import torch import torch.nn as nn - class LoRALayer(nn.Module): - """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only - - Taken from - https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77 - """ - - def __init__(self, module: nn.Module, rank: int): - super().__init__() - self.module = module - self.adapter = nn.Sequential( - nn.Linear(module.in_features, rank, bias=False), - nn.Linear(rank, module.out_features, bias=False), - ) - small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5 - nn.init.normal_(self.adapter[0].weight, std=small_std) - nn.init.zeros_(self.adapter[1].weight) - self.adapter.to(module.weight.device) - - def forward(self, input, *args, **kwargs): - return self.module(input, *args, **kwargs) + self.adapter(input) + from ..utils import LoRALayer, get_memory_consumption_stat if is_torchao_available(): @@ -503,6 +483,22 @@ def test_memory_footprint(self): # there is additional overhead of scales and zero points self.assertTrue(total_bf16 < total_int4wo) + def test_model_memory_usage(self): + model_id = "hf-internal-testing/tiny-flux-pipe" + expected_memory_saving_ratio = 2.0 + + inputs = self.get_dummy_tensor_inputs(device=torch_device) + + transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"] + transformer_bf16.to(torch_device) + unquantized_model_memory = get_memory_consumption_stat(transformer_bf16, inputs) + del transformer_bf16 + + transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"] + transformer_int8wo.to(torch_device) + quantized_model_memory = get_memory_consumption_stat(transformer_int8wo, inputs) + assert unquantized_model_memory / quantized_model_memory >= expected_memory_saving_ratio + def test_wrong_config(self): with self.assertRaises(ValueError): self.get_dummy_components(TorchAoConfig("int42")) diff --git a/tests/quantization/utils.py b/tests/quantization/utils.py new file mode 100644 index 000000000000..04ebf9e159f4 --- /dev/null +++ b/tests/quantization/utils.py @@ -0,0 +1,38 @@ +from diffusers.utils import is_torch_available + + +if is_torch_available(): + import torch + import torch.nn as nn + + class LoRALayer(nn.Module): + """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only + + Taken from + https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77 + """ + + def __init__(self, module: nn.Module, rank: int): + super().__init__() + self.module = module + self.adapter = nn.Sequential( + nn.Linear(module.in_features, rank, bias=False), + nn.Linear(rank, module.out_features, bias=False), + ) + small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5 + nn.init.normal_(self.adapter[0].weight, std=small_std) + nn.init.zeros_(self.adapter[1].weight) + self.adapter.to(module.weight.device) + + def forward(self, input, *args, **kwargs): + return self.module(input, *args, **kwargs) + self.adapter(input) + + @torch.no_grad() + @torch.inference_mode() + def get_memory_consumption_stat(model, inputs): + torch.cuda.reset_peak_memory_stats() + torch.cuda.empty_cache() + + model(**inputs) + max_memory_mem_allocated = torch.cuda.max_memory_allocated() + return max_memory_mem_allocated diff --git a/tests/remote/test_remote_decode.py b/tests/remote/test_remote_decode.py index 11f9c24d16f6..cec96e729a48 100644 --- a/tests/remote/test_remote_decode.py +++ b/tests/remote/test_remote_decode.py @@ -21,7 +21,15 @@ import torch from diffusers.image_processor import VaeImageProcessor -from diffusers.utils.remote_utils import remote_decode +from diffusers.utils.constants import ( + DECODE_ENDPOINT_FLUX, + DECODE_ENDPOINT_HUNYUAN_VIDEO, + DECODE_ENDPOINT_SD_V1, + DECODE_ENDPOINT_SD_XL, +) +from diffusers.utils.remote_utils import ( + remote_decode, +) from diffusers.utils.testing_utils import ( enable_full_determinism, slow, @@ -33,11 +41,6 @@ enable_full_determinism() -ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/" -ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/" -ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/" -ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/" - class RemoteAutoencoderKLMixin: shape: Tuple[int, ...] = None @@ -350,7 +353,7 @@ class RemoteAutoencoderKLSDv1Tests( 512, 512, ) - endpoint = ENDPOINT_SD_V1 + endpoint = DECODE_ENDPOINT_SD_V1 dtype = torch.float16 scaling_factor = 0.18215 shift_factor = None @@ -374,7 +377,7 @@ class RemoteAutoencoderKLSDXLTests( 1024, 1024, ) - endpoint = ENDPOINT_SD_XL + endpoint = DECODE_ENDPOINT_SD_XL dtype = torch.float16 scaling_factor = 0.13025 shift_factor = None @@ -398,7 +401,7 @@ class RemoteAutoencoderKLFluxTests( 1024, 1024, ) - endpoint = ENDPOINT_FLUX + endpoint = DECODE_ENDPOINT_FLUX dtype = torch.bfloat16 scaling_factor = 0.3611 shift_factor = 0.1159 @@ -425,7 +428,7 @@ class RemoteAutoencoderKLFluxPackedTests( ) height = 1024 width = 1024 - endpoint = ENDPOINT_FLUX + endpoint = DECODE_ENDPOINT_FLUX dtype = torch.bfloat16 scaling_factor = 0.3611 shift_factor = 0.1159 @@ -453,7 +456,7 @@ class RemoteAutoencoderKLHunyuanVideoTests( 320, 512, ) - endpoint = ENDPOINT_HUNYUAN_VIDEO + endpoint = DECODE_ENDPOINT_HUNYUAN_VIDEO dtype = torch.float16 scaling_factor = 0.476986 processor_cls = VideoProcessor @@ -504,7 +507,7 @@ class RemoteAutoencoderKLSDv1SlowTests( RemoteAutoencoderKLSlowTestMixin, unittest.TestCase, ): - endpoint = ENDPOINT_SD_V1 + endpoint = DECODE_ENDPOINT_SD_V1 dtype = torch.float16 scaling_factor = 0.18215 shift_factor = None @@ -515,7 +518,7 @@ class RemoteAutoencoderKLSDXLSlowTests( RemoteAutoencoderKLSlowTestMixin, unittest.TestCase, ): - endpoint = ENDPOINT_SD_XL + endpoint = DECODE_ENDPOINT_SD_XL dtype = torch.float16 scaling_factor = 0.13025 shift_factor = None @@ -527,7 +530,7 @@ class RemoteAutoencoderKLFluxSlowTests( unittest.TestCase, ): channels = 16 - endpoint = ENDPOINT_FLUX + endpoint = DECODE_ENDPOINT_FLUX dtype = torch.bfloat16 scaling_factor = 0.3611 shift_factor = 0.1159 diff --git a/tests/remote/test_remote_encode.py b/tests/remote/test_remote_encode.py new file mode 100644 index 000000000000..62ed97ee8f49 --- /dev/null +++ b/tests/remote/test_remote_encode.py @@ -0,0 +1,224 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import PIL.Image +import torch + +from diffusers.utils import load_image +from diffusers.utils.constants import ( + DECODE_ENDPOINT_FLUX, + DECODE_ENDPOINT_SD_V1, + DECODE_ENDPOINT_SD_XL, + ENCODE_ENDPOINT_FLUX, + ENCODE_ENDPOINT_SD_V1, + ENCODE_ENDPOINT_SD_XL, +) +from diffusers.utils.remote_utils import ( + remote_decode, + remote_encode, +) +from diffusers.utils.testing_utils import ( + enable_full_determinism, + slow, +) + + +enable_full_determinism() + +IMAGE = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true" + + +class RemoteAutoencoderKLEncodeMixin: + channels: int = None + endpoint: str = None + decode_endpoint: str = None + dtype: torch.dtype = None + scaling_factor: float = None + shift_factor: float = None + image: PIL.Image.Image = None + + def get_dummy_inputs(self): + if self.image is None: + self.image = load_image(IMAGE) + inputs = { + "endpoint": self.endpoint, + "image": self.image, + "scaling_factor": self.scaling_factor, + "shift_factor": self.shift_factor, + } + return inputs + + def test_image_input(self): + inputs = self.get_dummy_inputs() + height, width = inputs["image"].height, inputs["image"].width + output = remote_encode(**inputs) + self.assertEqual(list(output.shape), [1, self.channels, height // 8, width // 8]) + decoded = remote_decode( + tensor=output, + endpoint=self.decode_endpoint, + scaling_factor=self.scaling_factor, + shift_factor=self.shift_factor, + image_format="png", + ) + self.assertEqual(decoded.height, height) + self.assertEqual(decoded.width, width) + # image_slice = torch.from_numpy(np.array(inputs["image"])[0, -3:, -3:].flatten()) + # decoded_slice = torch.from_numpy(np.array(decoded)[0, -3:, -3:].flatten()) + # TODO: how to test this? encode->decode is lossy. expected slice of encoded latent? + + +class RemoteAutoencoderKLSDv1Tests( + RemoteAutoencoderKLEncodeMixin, + unittest.TestCase, +): + channels = 4 + endpoint = ENCODE_ENDPOINT_SD_V1 + decode_endpoint = DECODE_ENDPOINT_SD_V1 + dtype = torch.float16 + scaling_factor = 0.18215 + shift_factor = None + + +class RemoteAutoencoderKLSDXLTests( + RemoteAutoencoderKLEncodeMixin, + unittest.TestCase, +): + channels = 4 + endpoint = ENCODE_ENDPOINT_SD_XL + decode_endpoint = DECODE_ENDPOINT_SD_XL + dtype = torch.float16 + scaling_factor = 0.13025 + shift_factor = None + + +class RemoteAutoencoderKLFluxTests( + RemoteAutoencoderKLEncodeMixin, + unittest.TestCase, +): + channels = 16 + endpoint = ENCODE_ENDPOINT_FLUX + decode_endpoint = DECODE_ENDPOINT_FLUX + dtype = torch.bfloat16 + scaling_factor = 0.3611 + shift_factor = 0.1159 + + +class RemoteAutoencoderKLEncodeSlowTestMixin: + channels: int = 4 + endpoint: str = None + decode_endpoint: str = None + dtype: torch.dtype = None + scaling_factor: float = None + shift_factor: float = None + image: PIL.Image.Image = None + + def get_dummy_inputs(self): + if self.image is None: + self.image = load_image(IMAGE) + inputs = { + "endpoint": self.endpoint, + "image": self.image, + "scaling_factor": self.scaling_factor, + "shift_factor": self.shift_factor, + } + return inputs + + def test_multi_res(self): + inputs = self.get_dummy_inputs() + for height in { + 320, + 512, + 640, + 704, + 896, + 1024, + 1208, + 1384, + 1536, + 1608, + 1864, + 2048, + }: + for width in { + 320, + 512, + 640, + 704, + 896, + 1024, + 1208, + 1384, + 1536, + 1608, + 1864, + 2048, + }: + inputs["image"] = inputs["image"].resize( + ( + width, + height, + ) + ) + output = remote_encode(**inputs) + self.assertEqual(list(output.shape), [1, self.channels, height // 8, width // 8]) + decoded = remote_decode( + tensor=output, + endpoint=self.decode_endpoint, + scaling_factor=self.scaling_factor, + shift_factor=self.shift_factor, + image_format="png", + ) + self.assertEqual(decoded.height, height) + self.assertEqual(decoded.width, width) + decoded.save(f"test_multi_res_{height}_{width}.png") + + +@slow +class RemoteAutoencoderKLSDv1SlowTests( + RemoteAutoencoderKLEncodeSlowTestMixin, + unittest.TestCase, +): + endpoint = ENCODE_ENDPOINT_SD_V1 + decode_endpoint = DECODE_ENDPOINT_SD_V1 + dtype = torch.float16 + scaling_factor = 0.18215 + shift_factor = None + + +@slow +class RemoteAutoencoderKLSDXLSlowTests( + RemoteAutoencoderKLEncodeSlowTestMixin, + unittest.TestCase, +): + endpoint = ENCODE_ENDPOINT_SD_XL + decode_endpoint = DECODE_ENDPOINT_SD_XL + dtype = torch.float16 + scaling_factor = 0.13025 + shift_factor = None + + +@slow +class RemoteAutoencoderKLFluxSlowTests( + RemoteAutoencoderKLEncodeSlowTestMixin, + unittest.TestCase, +): + channels = 16 + endpoint = ENCODE_ENDPOINT_FLUX + decode_endpoint = DECODE_ENDPOINT_FLUX + dtype = torch.bfloat16 + scaling_factor = 0.3611 + shift_factor = 0.1159 diff --git a/tests/single_file/test_sana_transformer.py b/tests/single_file/test_sana_transformer.py new file mode 100644 index 000000000000..7695e1577711 --- /dev/null +++ b/tests/single_file/test_sana_transformer.py @@ -0,0 +1,61 @@ +import gc +import unittest + +import torch + +from diffusers import ( + SanaTransformer2DModel, +) +from diffusers.utils.testing_utils import ( + backend_empty_cache, + enable_full_determinism, + require_torch_accelerator, + torch_device, +) + + +enable_full_determinism() + + +@require_torch_accelerator +class SanaTransformer2DModelSingleFileTests(unittest.TestCase): + model_class = SanaTransformer2DModel + ckpt_path = ( + "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth" + ) + alternate_keys_ckpt_paths = [ + "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth" + ] + + repo_id = "Efficient-Large-Model/Sana_1600M_1024px_diffusers" + + def setUp(self): + super().setUp() + gc.collect() + backend_empty_cache(torch_device) + + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + + def test_single_file_components(self): + model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer") + model_single_file = self.model_class.from_single_file(self.ckpt_path) + + PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"] + for param_name, param_value in model_single_file.config.items(): + if param_name in PARAMS_TO_IGNORE: + continue + assert ( + model.config[param_name] == param_value + ), f"{param_name} differs between single file loading and pretrained loading" + + def test_checkpoint_loading(self): + for ckpt_path in self.alternate_keys_ckpt_paths: + torch.cuda.empty_cache() + model = self.model_class.from_single_file(ckpt_path) + + del model + gc.collect() + torch.cuda.empty_cache() From eec5b98651755c1c9abb58dc15a5c1fa38ac629e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 20 Mar 2025 20:31:38 +0530 Subject: [PATCH 02/19] Revert "feat: pipeline-level quant config." This reverts commit 316ff46b7648bfa24525ac02c284afcf440404aa. --- .github/workflows/benchmark.yml | 1 - .github/workflows/nightly_tests.yml | 9 - .github/workflows/pr_tests_gpu.yml | 44 - docs/source/en/_toctree.yml | 4 - docs/source/en/api/pipelines/ltx_video.md | 6 - docs/source/en/api/pipelines/lumina.md | 14 +- docs/source/en/api/pipelines/lumina2.md | 12 +- docs/source/en/api/pipelines/wan.md | 4 - docs/source/en/api/quantization.md | 5 - .../en/hybrid_inference/api_reference.md | 4 - docs/source/en/hybrid_inference/overview.md | 10 +- docs/source/en/hybrid_inference/vae_encode.md | 183 -- docs/source/en/quantization/overview.md | 1 - docs/source/en/quantization/quanto.md | 148 -- docs/source/en/quantization/torchao.md | 2 +- .../README_flux.md | 4 +- .../train_dreambooth_lora_flux_advanced.py | 4 +- .../train_dreambooth_lora_sd15_advanced.py | 2 +- .../train_dreambooth_lora_sdxl_advanced.py | 4 +- examples/cogview4-control/README.md | 201 -- examples/cogview4-control/requirements.txt | 6 - .../train_control_cogview4.py | 1242 --------- examples/community/mixture_tiling_sdxl.py | 44 +- examples/controlnet/train_controlnet.py | 4 +- examples/controlnet/train_controlnet_flux.py | 4 +- examples/controlnet/train_controlnet_sd3.py | 4 +- examples/controlnet/train_controlnet_sdxl.py | 4 +- examples/research_projects/anytext/README.md | 32 - examples/research_projects/anytext/anytext.py | 2360 ----------------- .../anytext/anytext_controlnet.py | 463 ---- .../anytext/ocr_recog/RNN.py | 209 -- .../anytext/ocr_recog/RecCTCHead.py | 45 - .../anytext/ocr_recog/RecModel.py | 49 - .../anytext/ocr_recog/RecMv1_enhance.py | 197 -- .../anytext/ocr_recog/RecSVTR.py | 570 ---- .../anytext/ocr_recog/common.py | 74 - .../anytext/ocr_recog/en_dict.txt | 95 - .../controlnet/train_controlnet_webdataset.py | 4 +- .../pixart/train_pixart_controlnet_hf.py | 4 +- .../pytorch_xla/inference/flux/README.md | 168 +- .../inference/flux/flux_inference.py | 28 +- .../t2i_adapter/train_t2i_adapter_sdxl.py | 4 +- scripts/convert_cogview4_to_diffusers.py | 15 +- .../convert_cogview4_to_diffusers_megatron.py | 66 +- scripts/convert_ltx_to_diffusers.py | 104 +- scripts/convert_lumina_to_diffusers.py | 4 +- setup.py | 9 - src/diffusers/__init__.py | 96 +- src/diffusers/dependency_versions_table.py | 4 - src/diffusers/hooks/group_offloading.py | 57 +- src/diffusers/loaders/__init__.py | 2 - src/diffusers/loaders/ip_adapter.py | 4 +- src/diffusers/loaders/lora_base.py | 164 +- .../loaders/lora_conversion_utils.py | 53 - src/diffusers/loaders/lora_pipeline.py | 653 ++--- src/diffusers/loaders/peft.py | 15 +- src/diffusers/loaders/single_file_model.py | 5 - src/diffusers/loaders/single_file_utils.py | 115 - src/diffusers/models/attention_processor.py | 18 +- .../models/autoencoders/autoencoder_dc.py | 6 +- .../autoencoders/autoencoder_kl_allegro.py | 2 +- .../models/autoencoders/autoencoder_kl_ltx.py | 237 +- .../autoencoders/autoencoder_kl_mochi.py | 4 +- .../controlnets/controlnet_sparsectrl.py | 2 +- src/diffusers/models/embeddings.py | 11 +- src/diffusers/models/model_loading_utils.py | 7 +- src/diffusers/models/resnet.py | 2 +- .../transformers/latte_transformer_3d.py | 18 +- .../models/transformers/prior_transformer.py | 6 +- .../models/transformers/sana_transformer.py | 4 +- .../transformers/transformer_cogview4.py | 58 +- .../models/transformers/transformer_ltx.py | 62 +- .../models/transformers/transformer_wan.py | 8 - .../models/unets/unet_3d_condition.py | 6 +- src/diffusers/models/unets/unet_i2vgen_xl.py | 4 +- .../models/unets/unet_motion_model.py | 7 +- .../unets/unet_spatio_temporal_condition.py | 6 +- src/diffusers/pipelines/__init__.py | 16 +- src/diffusers/pipelines/auto_pipeline.py | 11 +- src/diffusers/pipelines/cogview4/__init__.py | 2 - .../pipelines/cogview4/pipeline_cogview4.py | 21 +- .../cogview4/pipeline_cogview4_control.py | 727 ----- .../flux/pipeline_flux_controlnet.py | 1 - src/diffusers/pipelines/ltx/__init__.py | 2 - src/diffusers/pipelines/ltx/pipeline_ltx.py | 3 +- .../pipelines/ltx/pipeline_ltx_condition.py | 1174 -------- .../pipelines/ltx/pipeline_ltx_image2video.py | 3 +- src/diffusers/pipelines/lumina/__init__.py | 4 +- .../pipelines/lumina/pipeline_lumina.py | 29 +- src/diffusers/pipelines/lumina2/__init__.py | 4 +- .../pipelines/lumina2/pipeline_lumina2.py | 27 +- .../pipelines/pipeline_loading_utils.py | 14 - src/diffusers/pipelines/pipeline_utils.py | 13 +- .../pipelines/wan/pipeline_wan_i2v.py | 17 +- src/diffusers/quantizers/__init__.py | 158 -- src/diffusers/quantizers/auto.py | 4 - .../quantizers/bitsandbytes/bnb_quantizer.py | 2 - .../quantizers/gguf/gguf_quantizer.py | 1 - .../quantizers/quantization_config.py | 36 - src/diffusers/quantizers/quanto/__init__.py | 1 - .../quantizers/quanto/quanto_quantizer.py | 177 -- src/diffusers/quantizers/quanto/utils.py | 60 - .../quantizers/torchao/torchao_quantizer.py | 47 +- .../scheduling_flow_match_euler_discrete.py | 23 +- src/diffusers/utils/__init__.py | 3 - src/diffusers/utils/constants.py | 11 - .../utils/dummy_bitsandbytes_objects.py | 17 - src/diffusers/utils/dummy_gguf_objects.py | 17 - .../utils/dummy_optimum_quanto_objects.py | 17 - .../dummy_torch_and_transformers_objects.py | 60 - src/diffusers/utils/dummy_torchao_objects.py | 17 - src/diffusers/utils/export_utils.py | 31 +- src/diffusers/utils/import_utils.py | 328 ++- src/diffusers/utils/remote_utils.py | 103 +- src/diffusers/utils/testing_utils.py | 16 - tests/lora/test_lora_layers_cogview4.py | 174 -- tests/lora/test_lora_layers_flux.py | 7 +- tests/lora/utils.py | 44 - tests/pipelines/ltx/test_ltx_condition.py | 284 -- tests/pipelines/lumina/test_lumina_nextdit.py | 22 +- .../lumina2/test_pipeline_lumina2.py | 12 +- tests/pipelines/test_pipeline_utils.py | 103 +- tests/quantization/__init__.py | 0 tests/quantization/bnb/test_4bit.py | 59 +- tests/quantization/bnb/test_mixed_int8.py | 55 +- tests/quantization/quanto/__init__.py | 0 tests/quantization/quanto/test_quanto.py | 328 --- tests/quantization/torchao/__init__.py | 0 tests/quantization/torchao/test_torchao.py | 38 +- tests/quantization/utils.py | 38 - tests/remote/test_remote_decode.py | 31 +- tests/remote/test_remote_encode.py | 224 -- tests/single_file/test_sana_transformer.py | 61 - 133 files changed, 881 insertions(+), 11902 deletions(-) delete mode 100644 docs/source/en/hybrid_inference/vae_encode.md delete mode 100644 docs/source/en/quantization/quanto.md delete mode 100644 examples/cogview4-control/README.md delete mode 100644 examples/cogview4-control/requirements.txt delete mode 100644 examples/cogview4-control/train_control_cogview4.py delete mode 100644 examples/research_projects/anytext/README.md delete mode 100644 examples/research_projects/anytext/anytext.py delete mode 100644 examples/research_projects/anytext/anytext_controlnet.py delete mode 100755 examples/research_projects/anytext/ocr_recog/RNN.py delete mode 100755 examples/research_projects/anytext/ocr_recog/RecCTCHead.py delete mode 100755 examples/research_projects/anytext/ocr_recog/RecModel.py delete mode 100644 examples/research_projects/anytext/ocr_recog/RecMv1_enhance.py delete mode 100644 examples/research_projects/anytext/ocr_recog/RecSVTR.py delete mode 100644 examples/research_projects/anytext/ocr_recog/common.py delete mode 100644 examples/research_projects/anytext/ocr_recog/en_dict.txt delete mode 100644 src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py delete mode 100644 src/diffusers/pipelines/ltx/pipeline_ltx_condition.py delete mode 100644 src/diffusers/quantizers/quanto/__init__.py delete mode 100644 src/diffusers/quantizers/quanto/quanto_quantizer.py delete mode 100644 src/diffusers/quantizers/quanto/utils.py delete mode 100644 src/diffusers/utils/dummy_bitsandbytes_objects.py delete mode 100644 src/diffusers/utils/dummy_gguf_objects.py delete mode 100644 src/diffusers/utils/dummy_optimum_quanto_objects.py delete mode 100644 src/diffusers/utils/dummy_torchao_objects.py delete mode 100644 tests/lora/test_lora_layers_cogview4.py delete mode 100644 tests/pipelines/ltx/test_ltx_condition.py delete mode 100644 tests/quantization/__init__.py delete mode 100644 tests/quantization/quanto/__init__.py delete mode 100644 tests/quantization/quanto/test_quanto.py delete mode 100644 tests/quantization/torchao/__init__.py delete mode 100644 tests/quantization/utils.py delete mode 100644 tests/remote/test_remote_encode.py delete mode 100644 tests/single_file/test_sana_transformer.py diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index ff915e046946..d311c1c73f11 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -38,7 +38,6 @@ jobs: python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" python -m uv pip install -e [quality,test] python -m uv pip install pandas peft - python -m uv pip uninstall transformers && python -m uv pip install transformers==4.48.0 - name: Environment run: | python utils/print_env.py diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index 2b39eea2fe5d..a40be8558499 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -414,16 +414,10 @@ jobs: config: - backend: "bitsandbytes" test_location: "bnb" - additional_deps: ["peft"] - backend: "gguf" test_location: "gguf" - additional_deps: [] - backend: "torchao" test_location: "torchao" - additional_deps: [] - - backend: "optimum_quanto" - test_location: "quanto" - additional_deps: [] runs-on: group: aws-g6e-xlarge-plus container: @@ -441,9 +435,6 @@ jobs: python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" python -m uv pip install -e [quality,test] python -m uv pip install -U ${{ matrix.config.backend }} - if [ "${{ join(matrix.config.additional_deps, ' ') }}" != "" ]; then - python -m uv pip install ${{ join(matrix.config.additional_deps, ' ') }} - fi python -m uv pip install pytest-reportlog - name: Environment run: | diff --git a/.github/workflows/pr_tests_gpu.yml b/.github/workflows/pr_tests_gpu.yml index d86eccc28bb5..82f824c8f192 100644 --- a/.github/workflows/pr_tests_gpu.yml +++ b/.github/workflows/pr_tests_gpu.yml @@ -28,51 +28,7 @@ env: PIPELINE_USAGE_CUTOFF: 1000000000 # set high cutoff so that only always-test pipelines run jobs: - check_code_quality: - runs-on: ubuntu-22.04 - steps: - - uses: actions/checkout@v3 - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: "3.8" - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install .[quality] - - name: Check quality - run: make quality - - name: Check if failure - if: ${{ failure() }} - run: | - echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY - - check_repository_consistency: - needs: check_code_quality - runs-on: ubuntu-22.04 - steps: - - uses: actions/checkout@v3 - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: "3.8" - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install .[quality] - - name: Check repo consistency - run: | - python utils/check_copies.py - python utils/check_dummies.py - python utils/check_support_list.py - make deps_table_check_updated - - name: Check if failure - if: ${{ failure() }} - run: | - echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY - setup_torch_cuda_pipeline_matrix: - needs: [check_code_quality, check_repository_consistency] name: Setup Torch Pipelines CUDA Slow Tests Matrix runs-on: group: aws-general-8-plus diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index d1805ff605d8..9438fe1a55e1 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -81,8 +81,6 @@ title: Overview - local: hybrid_inference/vae_decode title: VAE Decode - - local: hybrid_inference/vae_encode - title: VAE Encode - local: hybrid_inference/api_reference title: API Reference title: Hybrid Inference @@ -175,8 +173,6 @@ title: gguf - local: quantization/torchao title: torchao - - local: quantization/quanto - title: quanto title: Quantization Methods - sections: - local: optimization/fp16 diff --git a/docs/source/en/api/pipelines/ltx_video.md b/docs/source/en/api/pipelines/ltx_video.md index 4bc22c0f9f6c..f31c621293fc 100644 --- a/docs/source/en/api/pipelines/ltx_video.md +++ b/docs/source/en/api/pipelines/ltx_video.md @@ -196,12 +196,6 @@ export_to_video(video, "ship.mp4", fps=24) - all - __call__ -## LTXConditionPipeline - -[[autodoc]] LTXConditionPipeline - - all - - __call__ - ## LTXPipelineOutput [[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput diff --git a/docs/source/en/api/pipelines/lumina.md b/docs/source/en/api/pipelines/lumina.md index ce5cf8b103cc..1967e85f173a 100644 --- a/docs/source/en/api/pipelines/lumina.md +++ b/docs/source/en/api/pipelines/lumina.md @@ -58,10 +58,10 @@ Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fa First, load the pipeline: ```python -from diffusers import LuminaPipeline +from diffusers import LuminaText2ImgPipeline import torch -pipeline = LuminaPipeline.from_pretrained( +pipeline = LuminaText2ImgPipeline.from_pretrained( "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16 ).to("cuda") ``` @@ -86,11 +86,11 @@ image = pipeline(prompt="Upper body of a young woman in a Victorian-era outfit w Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model. -Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`LuminaPipeline`] for inference with bitsandbytes. +Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`LuminaText2ImgPipeline`] for inference with bitsandbytes. ```py import torch -from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, Transformer2DModel, LuminaPipeline +from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, Transformer2DModel, LuminaText2ImgPipeline from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel quant_config = BitsAndBytesConfig(load_in_8bit=True) @@ -109,7 +109,7 @@ transformer_8bit = Transformer2DModel.from_pretrained( torch_dtype=torch.float16, ) -pipeline = LuminaPipeline.from_pretrained( +pipeline = LuminaText2ImgPipeline.from_pretrained( "Alpha-VLLM/Lumina-Next-SFT-diffusers", text_encoder=text_encoder_8bit, transformer=transformer_8bit, @@ -122,9 +122,9 @@ image = pipeline(prompt).images[0] image.save("lumina.png") ``` -## LuminaPipeline +## LuminaText2ImgPipeline -[[autodoc]] LuminaPipeline +[[autodoc]] LuminaText2ImgPipeline - all - __call__ diff --git a/docs/source/en/api/pipelines/lumina2.md b/docs/source/en/api/pipelines/lumina2.md index 57f0e8e2105d..cf04bc17e3ef 100644 --- a/docs/source/en/api/pipelines/lumina2.md +++ b/docs/source/en/api/pipelines/lumina2.md @@ -36,14 +36,14 @@ Single file loading for Lumina Image 2.0 is available for the `Lumina2Transforme ```python import torch -from diffusers import Lumina2Transformer2DModel, Lumina2Pipeline +from diffusers import Lumina2Transformer2DModel, Lumina2Text2ImgPipeline ckpt_path = "https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0/blob/main/consolidated.00-of-01.pth" transformer = Lumina2Transformer2DModel.from_single_file( ckpt_path, torch_dtype=torch.bfloat16 ) -pipe = Lumina2Pipeline.from_pretrained( +pipe = Lumina2Text2ImgPipeline.from_pretrained( "Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16 ) pipe.enable_model_cpu_offload() @@ -60,7 +60,7 @@ image.save("lumina-single-file.png") GGUF Quantized checkpoints for the `Lumina2Transformer2DModel` can be loaded via `from_single_file` with the `GGUFQuantizationConfig` ```python -from diffusers import Lumina2Transformer2DModel, Lumina2Pipeline, GGUFQuantizationConfig +from diffusers import Lumina2Transformer2DModel, Lumina2Text2ImgPipeline, GGUFQuantizationConfig ckpt_path = "https://huggingface.co/calcuis/lumina-gguf/blob/main/lumina2-q4_0.gguf" transformer = Lumina2Transformer2DModel.from_single_file( @@ -69,7 +69,7 @@ transformer = Lumina2Transformer2DModel.from_single_file( torch_dtype=torch.bfloat16, ) -pipe = Lumina2Pipeline.from_pretrained( +pipe = Lumina2Text2ImgPipeline.from_pretrained( "Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16 ) pipe.enable_model_cpu_offload() @@ -80,8 +80,8 @@ image = pipe( image.save("lumina-gguf.png") ``` -## Lumina2Pipeline +## Lumina2Text2ImgPipeline -[[autodoc]] Lumina2Pipeline +[[autodoc]] Lumina2Text2ImgPipeline - all - __call__ diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md index a35b73cb8a2e..b16bf92a6370 100644 --- a/docs/source/en/api/pipelines/wan.md +++ b/docs/source/en/api/pipelines/wan.md @@ -14,10 +14,6 @@ # Wan -
- LoRA -
- [Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team. diff --git a/docs/source/en/api/quantization.md b/docs/source/en/api/quantization.md index 2c728cff3c07..168a9a03473f 100644 --- a/docs/source/en/api/quantization.md +++ b/docs/source/en/api/quantization.md @@ -31,11 +31,6 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui ## GGUFQuantizationConfig [[autodoc]] GGUFQuantizationConfig - -## QuantoConfig - -[[autodoc]] QuantoConfig - ## TorchAoConfig [[autodoc]] TorchAoConfig diff --git a/docs/source/en/hybrid_inference/api_reference.md b/docs/source/en/hybrid_inference/api_reference.md index 865aaba5ebb6..aa0a5e5ae58f 100644 --- a/docs/source/en/hybrid_inference/api_reference.md +++ b/docs/source/en/hybrid_inference/api_reference.md @@ -3,7 +3,3 @@ ## Remote Decode [[autodoc]] utils.remote_utils.remote_decode - -## Remote Encode - -[[autodoc]] utils.remote_utils.remote_encode diff --git a/docs/source/en/hybrid_inference/overview.md b/docs/source/en/hybrid_inference/overview.md index b44393c77cbd..9bbe245901df 100644 --- a/docs/source/en/hybrid_inference/overview.md +++ b/docs/source/en/hybrid_inference/overview.md @@ -36,7 +36,7 @@ Hybrid Inference offers a fast and simple way to offload local generation requir ## Available Models * **VAE Decode ๐Ÿ–ผ๏ธ:** Quickly decode latent representations into high-quality images without compromising performance or workflow speed. -* **VAE Encode ๐Ÿ”ข:** Efficiently encode images into latent representations for generation and training. +* **VAE Encode ๐Ÿ”ข (coming soon):** Efficiently encode images into latent representations for generation and training. * **Text Encoders ๐Ÿ“ƒ (coming soon):** Compute text embeddings for your prompts quickly and accurately, ensuring a smooth and high-quality workflow. --- @@ -46,15 +46,9 @@ Hybrid Inference offers a fast and simple way to offload local generation requir * **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference. * **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference. -## Changelog - -- March 10 2025: Added VAE encode -- March 2 2025: Initial release with VAE decoding - ## Contents -The documentation is organized into three sections: +The documentation is organized into two sections: * **VAE Decode** Learn the basics of how to use VAE Decode with Hybrid Inference. -* **VAE Encode** Learn the basics of how to use VAE Encode with Hybrid Inference. * **API Reference** Dive into task-specific settings and parameters. diff --git a/docs/source/en/hybrid_inference/vae_encode.md b/docs/source/en/hybrid_inference/vae_encode.md deleted file mode 100644 index dd285fa25c03..000000000000 --- a/docs/source/en/hybrid_inference/vae_encode.md +++ /dev/null @@ -1,183 +0,0 @@ -# Getting Started: VAE Encode with Hybrid Inference - -VAE encode is used for training, image-to-image and image-to-video - turning into images or videos into latent representations. - -## Memory - -These tables demonstrate the VRAM requirements for VAE encode with SD v1 and SD XL on different GPUs. - -For the majority of these GPUs the memory usage % dictates other models (text encoders, UNet/Transformer) must be offloaded, or tiled encoding has to be used which increases time taken and impacts quality. - -
SD v1.5 - -| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) | -|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:| -| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 | -| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 | -| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 | -| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 | -| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 | -| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 | -| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 | -| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 | -| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 | -| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 | -| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 | -| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 | -| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 | -| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 | -| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 | -| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 | -| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 | -| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 | -| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 | -| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 | - - -
- -
SDXL - -| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) | -|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:| -| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 | -| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 | -| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 | -| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 | -| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 | -| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 | -| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 | -| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 | -| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 | -| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 | -| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 | -| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 | -| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 | -| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 | -| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 | -| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 | -| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 | -| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 | -| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 | -| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 | - -
- -## Available VAEs - -| | **Endpoint** | **Model** | -|:-:|:-----------:|:--------:| -| **Stable Diffusion v1** | [https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud](https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) | -| **Stable Diffusion XL** | [https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud](https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) | -| **Flux** | [https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud](https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) | - - -> [!TIP] -> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml). - - -## Code - -> [!TIP] -> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main` - - -A helper method simplifies interacting with Hybrid Inference. - -```python -from diffusers.utils.remote_utils import remote_encode -``` - -### Basic example - -Let's encode an image, then decode it to demonstrate. - -
- -
- -
Code - -```python -from diffusers.utils import load_image -from diffusers.utils.remote_utils import remote_decode - -image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true") - -latent = remote_encode( - endpoint="https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/", - scaling_factor=0.3611, - shift_factor=0.1159, -) - -decoded = remote_decode( - endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/", - tensor=latent, - scaling_factor=0.3611, - shift_factor=0.1159, -) -``` - -
- -
- -
- - -### Generation - -Now let's look at a generation example, we'll encode the image, generate then remotely decode too! - -
Code - -```python -import torch -from diffusers import StableDiffusionImg2ImgPipeline -from diffusers.utils import load_image -from diffusers.utils.remote_utils import remote_decode, remote_encode - -pipe = StableDiffusionImg2ImgPipeline.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", - torch_dtype=torch.float16, - variant="fp16", - vae=None, -).to("cuda") - -init_image = load_image( - "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" -) -init_image = init_image.resize((768, 512)) - -init_latent = remote_encode( - endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/", - image=init_image, - scaling_factor=0.18215, -) - -prompt = "A fantasy landscape, trending on artstation" -latent = pipe( - prompt=prompt, - image=init_latent, - strength=0.75, - output_type="latent", -).images - -image = remote_decode( - endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/", - tensor=latent, - scaling_factor=0.18215, -) -image.save("fantasy_landscape.jpg") -``` - -
- -
- -
- -## Integrations - -* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference. -* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference. diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index 93323f86c7fc..794098e210a6 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -36,6 +36,5 @@ Diffusers currently supports the following quantization methods. - [BitsandBytes](./bitsandbytes) - [TorchAO](./torchao) - [GGUF](./gguf) -- [Quanto](./quanto.md) [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques. diff --git a/docs/source/en/quantization/quanto.md b/docs/source/en/quantization/quanto.md deleted file mode 100644 index d322d76be267..000000000000 --- a/docs/source/en/quantization/quanto.md +++ /dev/null @@ -1,148 +0,0 @@ - - -# Quanto - -[Quanto](https://github.com/huggingface/optimum-quanto) is a PyTorch quantization backend for [Optimum](https://huggingface.co/docs/optimum/en/index). It has been designed with versatility and simplicity in mind: - -- All features are available in eager mode (works with non-traceable models) -- Supports quantization aware training -- Quantized models are compatible with `torch.compile` -- Quantized models are Device agnostic (e.g CUDA,XPU,MPS,CPU) - -In order to use the Quanto backend, you will first need to install `optimum-quanto>=0.2.6` and `accelerate` - -```shell -pip install optimum-quanto accelerate -``` - -Now you can quantize a model by passing the `QuantoConfig` object to the `from_pretrained()` method. Although the Quanto library does allow quantizing `nn.Conv2d` and `nn.LayerNorm` modules, currently, Diffusers only supports quantizing the weights in the `nn.Linear` layers of a model. The following snippet demonstrates how to apply `float8` quantization with Quanto. - -```python -import torch -from diffusers import FluxTransformer2DModel, QuantoConfig - -model_id = "black-forest-labs/FLUX.1-dev" -quantization_config = QuantoConfig(weights_dtype="float8") -transformer = FluxTransformer2DModel.from_pretrained( - model_id, - subfolder="transformer", - quantization_config=quantization_config, - torch_dtype=torch.bfloat16, -) - -pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch_dtype) -pipe.to("cuda") - -prompt = "A cat holding a sign that says hello world" -image = pipe( - prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512 -).images[0] -image.save("output.png") -``` - -## Skipping Quantization on specific modules - -It is possible to skip applying quantization on certain modules using the `modules_to_not_convert` argument in the `QuantoConfig`. Please ensure that the modules passed in to this argument match the keys of the modules in the `state_dict` - -```python -import torch -from diffusers import FluxTransformer2DModel, QuantoConfig - -model_id = "black-forest-labs/FLUX.1-dev" -quantization_config = QuantoConfig(weights_dtype="float8", modules_to_not_convert=["proj_out"]) -transformer = FluxTransformer2DModel.from_pretrained( - model_id, - subfolder="transformer", - quantization_config=quantization_config, - torch_dtype=torch.bfloat16, -) -``` - -## Using `from_single_file` with the Quanto Backend - -`QuantoConfig` is compatible with `~FromOriginalModelMixin.from_single_file`. - -```python -import torch -from diffusers import FluxTransformer2DModel, QuantoConfig - -ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors" -quantization_config = QuantoConfig(weights_dtype="float8") -transformer = FluxTransformer2DModel.from_single_file(ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16) -``` - -## Saving Quantized models - -Diffusers supports serializing Quanto models using the `~ModelMixin.save_pretrained` method. - -The serialization and loading requirements are different for models quantized directly with the Quanto library and models quantized -with Diffusers using Quanto as the backend. It is currently not possible to load models quantized directly with Quanto into Diffusers using `~ModelMixin.from_pretrained` - -```python -import torch -from diffusers import FluxTransformer2DModel, QuantoConfig - -model_id = "black-forest-labs/FLUX.1-dev" -quantization_config = QuantoConfig(weights_dtype="float8") -transformer = FluxTransformer2DModel.from_pretrained( - model_id, - subfolder="transformer", - quantization_config=quantization_config, - torch_dtype=torch.bfloat16, -) -# save quantized model to reuse -transformer.save_pretrained("") - -# you can reload your quantized model with -model = FluxTransformer2DModel.from_pretrained("") -``` - -## Using `torch.compile` with Quanto - -Currently the Quanto backend supports `torch.compile` for the following quantization types: - -- `int8` weights - -```python -import torch -from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig - -model_id = "black-forest-labs/FLUX.1-dev" -quantization_config = QuantoConfig(weights_dtype="int8") -transformer = FluxTransformer2DModel.from_pretrained( - model_id, - subfolder="transformer", - quantization_config=quantization_config, - torch_dtype=torch.bfloat16, -) -transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True) - -pipe = FluxPipeline.from_pretrained( - model_id, transformer=transformer, torch_dtype=torch_dtype -) -pipe.to("cuda") -images = pipe("A cat holding a sign that says hello").images[0] -images.save("flux-quanto-compile.png") -``` - -## Supported Quantization Types - -### Weights - -- float8 -- int8 -- int4 -- int2 - - diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md index 19a8970fa9df..c056876c2f09 100644 --- a/docs/source/en/quantization/torchao.md +++ b/docs/source/en/quantization/torchao.md @@ -126,7 +126,7 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0] image.save("output.png") ``` -If you are using `torch<=2.6.0`, some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source. +Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source. ```python import torch diff --git a/examples/advanced_diffusion_training/README_flux.md b/examples/advanced_diffusion_training/README_flux.md index f2a571d5eae4..1f83235ad50a 100644 --- a/examples/advanced_diffusion_training/README_flux.md +++ b/examples/advanced_diffusion_training/README_flux.md @@ -79,13 +79,13 @@ This command will prompt you for a token. Copy-paste yours from your [settings/t ### Target Modules When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them. More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore -applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma separated string +applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma seperated string the exact modules for LoRA training. Here are some examples of target modules you can provide: - for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"` - to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"` - to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"` > [!NOTE] -> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma separated string: +> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma seperated string: > **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k` > **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k` > [!NOTE] diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index b8194507d822..7cb0d666fe69 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -378,7 +378,7 @@ def parse_args(input_args=None): default=None, help="the concept to use to initialize the new inserted tokens when training with " "--train_text_encoder_ti = True. By default, new tokens () are initialized with random value. " - "Alternatively, you could specify a different word/words whose value will be used as the starting point for the new inserted tokens. " + "Alternatively, you could specify a different word/words whos value will be used as the starting point for the new inserted tokens. " "--num_new_tokens_per_abstraction is ignored when initializer_concept is provided", ) parser.add_argument( @@ -662,7 +662,7 @@ def parse_args(input_args=None): type=str, default=None, help=( - "The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. " + "The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. " 'E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md' ), ) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py index 8cd1d777c00c..41ab1eb660d7 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py @@ -662,7 +662,7 @@ def parse_args(input_args=None): action="store_true", default=False, help=( - "Whether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. " + "Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. " "Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`" ), ) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 38b6e8dab209..5ec028026364 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -773,7 +773,7 @@ def parse_args(input_args=None): action="store_true", default=False, help=( - "Whether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. " + "Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. " "Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`" ), ) @@ -1875,7 +1875,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers, clip_skip): # pack the statically computed variables appropriately here. This is so that we don't # have to pass them to the dataloader. - # if --train_text_encoder_ti we need add_special_tokens to be True for textual inversion + # if --train_text_encoder_ti we need add_special_tokens to be True fo textual inversion add_special_tokens = True if args.train_text_encoder_ti else False if not train_dataset.custom_instance_prompts: diff --git a/examples/cogview4-control/README.md b/examples/cogview4-control/README.md deleted file mode 100644 index 746a99a1a41b..000000000000 --- a/examples/cogview4-control/README.md +++ /dev/null @@ -1,201 +0,0 @@ -# Training CogView4 Control - -This (experimental) example shows how to train Control LoRAs with [CogView4](https://huggingface.co/THUDM/CogView4-6B) by conditioning it with additional structural controls (like depth maps, poses, etc.). We provide a script for full fine-tuning, too, refer to [this section](#full-fine-tuning). To know more about CogView4 Control family, refer to the following resources: - -To incorporate additional condition latents, we expand the input features of CogView-4 from 64 to 128. The first 64 channels correspond to the original input latents to be denoised, while the latter 64 channels correspond to control latents. This expansion happens on the `patch_embed` layer, where the combined latents are projected to the expected feature dimension of rest of the network. Inference is performed using the `CogView4ControlPipeline`. - -> [!NOTE] -> **Gated model** -> -> As the model is gated, before using it with diffusers you first need to go to the [CogView4 Hugging Face page](https://huggingface.co/THUDM/CogView4-6B), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows youโ€™ve accepted the gate. Use the command below to log in: - -```bash -huggingface-cli login -``` - -The example command below shows how to launch fine-tuning for pose conditions. The dataset ([`raulc0399/open_pose_controlnet`](https://huggingface.co/datasets/raulc0399/open_pose_controlnet)) being used here already has the pose conditions of the original images, so we don't have to compute them. - -```bash -accelerate launch train_control_lora_cogview4.py \ - --pretrained_model_name_or_path="THUDM/CogView4-6B" \ - --dataset_name="raulc0399/open_pose_controlnet" \ - --output_dir="pose-control-lora" \ - --mixed_precision="bf16" \ - --train_batch_size=1 \ - --rank=64 \ - --gradient_accumulation_steps=4 \ - --gradient_checkpointing \ - --use_8bit_adam \ - --learning_rate=1e-4 \ - --report_to="wandb" \ - --lr_scheduler="constant" \ - --lr_warmup_steps=0 \ - --max_train_steps=5000 \ - --validation_image="openpose.png" \ - --validation_prompt="A couple, 4k photo, highly detailed" \ - --offload \ - --seed="0" \ - --push_to_hub -``` - -`openpose.png` comes from [here](https://huggingface.co/Adapter/t2iadapter/resolve/main/openpose.png). - -You need to install `diffusers` from the branch of [this PR](https://github.com/huggingface/diffusers/pull/9999). When it's merged, you should install `diffusers` from the `main`. - -The training script exposes additional CLI args that might be useful to experiment with: - -* `use_lora_bias`: When set, additionally trains the biases of the `lora_B` layer. -* `train_norm_layers`: When set, additionally trains the normalization scales. Takes care of saving and loading. -* `lora_layers`: Specify the layers you want to apply LoRA to. If you specify "all-linear", all the linear layers will be LoRA-attached. - -### Training with DeepSpeed - -It's possible to train with [DeepSpeed](https://github.com/microsoft/DeepSpeed), specifically leveraging the Zero2 system optimization. To use it, save the following config to an YAML file (feel free to modify as needed): - -```yaml -compute_environment: LOCAL_MACHINE -debug: false -deepspeed_config: - gradient_accumulation_steps: 1 - gradient_clipping: 1.0 - offload_optimizer_device: cpu - offload_param_device: cpu - zero3_init_flag: false - zero_stage: 2 -distributed_type: DEEPSPEED -downcast_bf16: 'no' -enable_cpu_affinity: false -machine_rank: 0 -main_training_function: main -mixed_precision: bf16 -num_machines: 1 -num_processes: 1 -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false -``` - -And then while launching training, pass the config file: - -```bash -accelerate launch --config_file=CONFIG_FILE.yaml ... -``` - -### Inference - -The pose images in our dataset were computed using the [`controlnet_aux`](https://github.com/huggingface/controlnet_aux) library. Let's install it first: - -```bash -pip install controlnet_aux -``` - -And then we are ready: - -```py -from controlnet_aux import OpenposeDetector -from diffusers import CogView4ControlPipeline -from diffusers.utils import load_image -from PIL import Image -import numpy as np -import torch - -pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16).to("cuda") -pipe.load_lora_weights("...") # change this. - -open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators") - -# prepare pose condition. -url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg" -image = load_image(url) -image = open_pose(image, detect_resolution=512, image_resolution=1024) -image = np.array(image)[:, :, ::-1] -image = Image.fromarray(np.uint8(image)) - -prompt = "A couple, 4k photo, highly detailed" - -gen_images = pipe( - prompt=prompt, - control_image=image, - num_inference_steps=50, - joint_attention_kwargs={"scale": 0.9}, - guidance_scale=25., -).images[0] -gen_images.save("output.png") -``` - -## Full fine-tuning - -We provide a non-LoRA version of the training script `train_control_cogview4.py`. Here is an example command: - -```bash -accelerate launch --config_file=accelerate_ds2.yaml train_control_cogview4.py \ - --pretrained_model_name_or_path="THUDM/CogView4-6B" \ - --dataset_name="raulc0399/open_pose_controlnet" \ - --output_dir="pose-control" \ - --mixed_precision="bf16" \ - --train_batch_size=2 \ - --dataloader_num_workers=4 \ - --gradient_accumulation_steps=4 \ - --gradient_checkpointing \ - --use_8bit_adam \ - --proportion_empty_prompts=0.2 \ - --learning_rate=5e-5 \ - --adam_weight_decay=1e-4 \ - --report_to="wandb" \ - --lr_scheduler="cosine" \ - --lr_warmup_steps=1000 \ - --checkpointing_steps=1000 \ - --max_train_steps=10000 \ - --validation_steps=200 \ - --validation_image "2_pose_1024.jpg" "3_pose_1024.jpg" \ - --validation_prompt "two friends sitting by each other enjoying a day at the park, full hd, cinematic" "person enjoying a day at the park, full hd, cinematic" \ - --offload \ - --seed="0" \ - --push_to_hub -``` - -Change the `validation_image` and `validation_prompt` as needed. - -For inference, this time, we will run: - -```py -from controlnet_aux import OpenposeDetector -from diffusers import CogView4ControlPipeline, CogView4Transformer2DModel -from diffusers.utils import load_image -from PIL import Image -import numpy as np -import torch - -transformer = CogView4Transformer2DModel.from_pretrained("...") # change this. -pipe = CogView4ControlPipeline.from_pretrained( - "THUDM/CogView4-6B", transformer=transformer, torch_dtype=torch.bfloat16 -).to("cuda") - -open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators") - -# prepare pose condition. -url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg" -image = load_image(url) -image = open_pose(image, detect_resolution=512, image_resolution=1024) -image = np.array(image)[:, :, ::-1] -image = Image.fromarray(np.uint8(image)) - -prompt = "A couple, 4k photo, highly detailed" - -gen_images = pipe( - prompt=prompt, - control_image=image, - num_inference_steps=50, - guidance_scale=25., -).images[0] -gen_images.save("output.png") -``` - -## Things to note - -* The scripts provided in this directory are experimental and educational. This means we may have to tweak things around to get good results on a given condition. We believe this is best done with the community ๐Ÿค— -* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used if `--offload` is specified. -* We can extract LoRAs from the fully fine-tuned model. While we currently don't provide any utilities for that, users are welcome to refer to [this script](https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/master/control_lora_create.py) that provides a similar functionality. \ No newline at end of file diff --git a/examples/cogview4-control/requirements.txt b/examples/cogview4-control/requirements.txt deleted file mode 100644 index 6c5ec2e03f9a..000000000000 --- a/examples/cogview4-control/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -transformers==4.47.0 -wandb -torch -torchvision -accelerate==1.2.0 -peft>=0.14.0 diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py deleted file mode 100644 index 506ca0225bf7..000000000000 --- a/examples/cogview4-control/train_control_cogview4.py +++ /dev/null @@ -1,1242 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and - -import argparse -import copy -import logging -import math -import os -import random -import shutil -from contextlib import nullcontext -from pathlib import Path - -import accelerate -import numpy as np -import torch -import transformers -from accelerate import Accelerator -from accelerate.logging import get_logger -from accelerate.utils import DistributedType, ProjectConfiguration, set_seed -from datasets import load_dataset -from huggingface_hub import create_repo, upload_folder -from packaging import version -from PIL import Image -from torchvision import transforms -from tqdm.auto import tqdm - -import diffusers -from diffusers import ( - AutoencoderKL, - CogView4ControlPipeline, - CogView4Transformer2DModel, - FlowMatchEulerDiscreteScheduler, -) -from diffusers.optimization import get_scheduler -from diffusers.training_utils import ( - compute_density_for_timestep_sampling, - compute_loss_weighting_for_sd3, - free_memory, -) -from diffusers.utils import check_min_version, is_wandb_available, load_image, make_image_grid -from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card -from diffusers.utils.torch_utils import is_compiled_module - - -if is_wandb_available(): - import wandb - -# Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.33.0.dev0") - -logger = get_logger(__name__) - -NORM_LAYER_PREFIXES = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"] - - -def encode_images(pixels: torch.Tensor, vae: torch.nn.Module, weight_dtype): - pixel_latents = vae.encode(pixels.to(vae.dtype)).latent_dist.sample() - pixel_latents = (pixel_latents - vae.config.shift_factor) * vae.config.scaling_factor - return pixel_latents.to(weight_dtype) - - -def log_validation(cogview4_transformer, args, accelerator, weight_dtype, step, is_final_validation=False): - logger.info("Running validation... ") - - if not is_final_validation: - cogview4_transformer = accelerator.unwrap_model(cogview4_transformer) - pipeline = CogView4ControlPipeline.from_pretrained( - args.pretrained_model_name_or_path, - transformer=cogview4_transformer, - torch_dtype=weight_dtype, - ) - else: - transformer = CogView4Transformer2DModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype) - pipeline = CogView4ControlPipeline.from_pretrained( - args.pretrained_model_name_or_path, - transformer=transformer, - torch_dtype=weight_dtype, - ) - - pipeline.to(accelerator.device) - pipeline.set_progress_bar_config(disable=True) - - if args.seed is None: - generator = None - else: - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) - - if len(args.validation_image) == len(args.validation_prompt): - validation_images = args.validation_image - validation_prompts = args.validation_prompt - elif len(args.validation_image) == 1: - validation_images = args.validation_image * len(args.validation_prompt) - validation_prompts = args.validation_prompt - elif len(args.validation_prompt) == 1: - validation_images = args.validation_image - validation_prompts = args.validation_prompt * len(args.validation_image) - else: - raise ValueError( - "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" - ) - - image_logs = [] - if is_final_validation or torch.backends.mps.is_available(): - autocast_ctx = nullcontext() - else: - autocast_ctx = torch.autocast(accelerator.device.type, weight_dtype) - - for validation_prompt, validation_image in zip(validation_prompts, validation_images): - validation_image = load_image(validation_image) - # maybe need to inference on 1024 to get a good image - validation_image = validation_image.resize((args.resolution, args.resolution)) - - images = [] - - for _ in range(args.num_validation_images): - with autocast_ctx: - image = pipeline( - prompt=validation_prompt, - control_image=validation_image, - num_inference_steps=50, - guidance_scale=args.guidance_scale, - max_sequence_length=args.max_sequence_length, - generator=generator, - height=args.resolution, - width=args.resolution, - ).images[0] - image = image.resize((args.resolution, args.resolution)) - images.append(image) - image_logs.append( - {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} - ) - - tracker_key = "test" if is_final_validation else "validation" - for tracker in accelerator.trackers: - if tracker.name == "tensorboard": - for log in image_logs: - images = log["images"] - validation_prompt = log["validation_prompt"] - validation_image = log["validation_image"] - formatted_images = [] - formatted_images.append(np.asarray(validation_image)) - for image in images: - formatted_images.append(np.asarray(image)) - formatted_images = np.stack(formatted_images) - tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") - - elif tracker.name == "wandb": - formatted_images = [] - for log in image_logs: - images = log["images"] - validation_prompt = log["validation_prompt"] - validation_image = log["validation_image"] - formatted_images.append(wandb.Image(validation_image, caption="Conditioning")) - for image in images: - image = wandb.Image(image, caption=validation_prompt) - formatted_images.append(image) - - tracker.log({tracker_key: formatted_images}) - else: - logger.warning(f"image logging not implemented for {tracker.name}") - - del pipeline - free_memory() - return image_logs - - -def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): - img_str = "" - if image_logs is not None: - img_str = "You can find some example images below.\n\n" - for i, log in enumerate(image_logs): - images = log["images"] - validation_prompt = log["validation_prompt"] - validation_image = log["validation_image"] - validation_image.save(os.path.join(repo_folder, "image_control.png")) - img_str += f"prompt: {validation_prompt}\n" - images = [validation_image] + images - make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) - img_str += f"![images_{i})](./images_{i}.png)\n" - - model_description = f""" -# cogview4-control-{repo_id} - -These are Control weights trained on {base_model} with new type of conditioning. -{img_str} - -## License - -Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogView4-6b/blob/main/LICENSE.md) -""" - - model_card = load_or_create_model_card( - repo_id_or_path=repo_id, - from_training=True, - license="other", - base_model=base_model, - model_description=model_description, - inference=True, - ) - - tags = [ - "cogview4", - "cogview4-diffusers", - "text-to-image", - "diffusers", - "control", - "diffusers-training", - ] - model_card = populate_model_card(model_card, tags=tags) - - model_card.save(os.path.join(repo_folder, "README.md")) - - -def parse_args(input_args=None): - parser = argparse.ArgumentParser(description="Simple example of a CogView4 Control training script.") - parser.add_argument( - "--pretrained_model_name_or_path", - type=str, - default=None, - required=True, - help="Path to pretrained model or model identifier from huggingface.co/models.", - ) - parser.add_argument( - "--variant", - type=str, - default=None, - help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", - ) - parser.add_argument( - "--revision", - type=str, - default=None, - required=False, - help="Revision of pretrained model identifier from huggingface.co/models.", - ) - parser.add_argument( - "--output_dir", - type=str, - default="cogview4-control", - help="The output directory where the model predictions and checkpoints will be written.", - ) - parser.add_argument( - "--cache_dir", - type=str, - default=None, - help="The directory where the downloaded models and datasets will be stored.", - ) - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") - parser.add_argument( - "--resolution", - type=int, - default=1024, - help=( - "The resolution for input images, all the images in the train/validation dataset will be resized to this" - " resolution" - ), - ) - parser.add_argument( - "--max_sequence_length", type=int, default=128, help="The maximum sequence length for the prompt." - ) - parser.add_argument( - "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." - ) - parser.add_argument("--num_train_epochs", type=int, default=1) - parser.add_argument( - "--max_train_steps", - type=int, - default=None, - help="Total number of training steps to perform. If provided, overrides num_train_epochs.", - ) - parser.add_argument( - "--checkpointing_steps", - type=int, - default=500, - help=( - "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " - "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." - "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." - "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" - "instructions." - ), - ) - parser.add_argument( - "--checkpoints_total_limit", - type=int, - default=None, - help=("Max number of checkpoints to store."), - ) - parser.add_argument( - "--resume_from_checkpoint", - type=str, - default=None, - help=( - "Whether training should be resumed from a previous checkpoint. Use a path saved by" - ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' - ), - ) - parser.add_argument( - "--proportion_empty_prompts", - type=float, - default=0, - help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", - ) - parser.add_argument( - "--gradient_accumulation_steps", - type=int, - default=1, - help="Number of updates steps to accumulate before performing a backward/update pass.", - ) - parser.add_argument( - "--gradient_checkpointing", - action="store_true", - help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", - ) - parser.add_argument( - "--learning_rate", - type=float, - default=5e-6, - help="Initial learning rate (after the potential warmup period) to use.", - ) - parser.add_argument( - "--scale_lr", - action="store_true", - default=False, - help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", - ) - parser.add_argument( - "--lr_scheduler", - type=str, - default="constant", - help=( - 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]' - ), - ) - parser.add_argument( - "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." - ) - parser.add_argument( - "--lr_num_cycles", - type=int, - default=1, - help="Number of hard resets of the lr in cosine_with_restarts scheduler.", - ) - parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") - parser.add_argument( - "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." - ) - - parser.add_argument( - "--dataloader_num_workers", - type=int, - default=0, - help=( - "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." - ), - ) - parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") - parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") - parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") - parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") - parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") - parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") - parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") - parser.add_argument( - "--hub_model_id", - type=str, - default=None, - help="The name of the repository to keep in sync with the local `output_dir`.", - ) - parser.add_argument( - "--logging_dir", - type=str, - default="logs", - help=( - "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." - ), - ) - parser.add_argument( - "--allow_tf32", - action="store_true", - help=( - "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" - " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" - ), - ) - parser.add_argument( - "--report_to", - type=str, - default="tensorboard", - help=( - 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' - ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' - ), - ) - parser.add_argument( - "--mixed_precision", - type=str, - default=None, - choices=["no", "fp16", "bf16"], - help=( - "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" - " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." - ), - ) - parser.add_argument( - "--dataset_name", - type=str, - default=None, - help=( - "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," - " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," - " or to a folder containing files that ๐Ÿค— Datasets can understand." - ), - ) - parser.add_argument( - "--dataset_config_name", - type=str, - default=None, - help="The config of the Dataset, leave as None if there's only one config.", - ) - parser.add_argument( - "--image_column", type=str, default="image", help="The column of the dataset containing the target image." - ) - parser.add_argument( - "--conditioning_image_column", - type=str, - default="conditioning_image", - help="The column of the dataset containing the control conditioning image.", - ) - parser.add_argument( - "--caption_column", - type=str, - default="text", - help="The column of the dataset containing a caption or a list of captions.", - ) - parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log somple dataset samples.") - parser.add_argument( - "--max_train_samples", - type=int, - default=None, - help=( - "For debugging purposes or quicker training, truncate the number of training examples to this " - "value if set." - ), - ) - parser.add_argument( - "--validation_prompt", - type=str, - default=None, - nargs="+", - help=( - "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." - " Provide either a matching number of `--validation_image`s, a single `--validation_image`" - " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." - ), - ) - parser.add_argument( - "--validation_image", - type=str, - default=None, - nargs="+", - help=( - "A set of paths to the control conditioning image be evaluated every `--validation_steps`" - " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" - " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" - " `--validation_image` that will be used with all `--validation_prompt`s." - ), - ) - parser.add_argument( - "--num_validation_images", - type=int, - default=1, - help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair", - ) - parser.add_argument( - "--validation_steps", - type=int, - default=100, - help=( - "Run validation every X steps. Validation consists of running the prompt" - " `args.validation_prompt` multiple times: `args.num_validation_images`" - " and logging the images." - ), - ) - parser.add_argument( - "--tracker_project_name", - type=str, - default="cogview4_train_control", - help=( - "The `project_name` argument passed to Accelerator.init_trackers for" - " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" - ), - ) - parser.add_argument( - "--jsonl_for_train", - type=str, - default=None, - help="Path to the jsonl file containing the training data.", - ) - parser.add_argument( - "--only_target_transformer_blocks", - action="store_true", - help="If we should only target the transformer blocks to train along with the input layer (`x_embedder`).", - ) - parser.add_argument( - "--guidance_scale", - type=float, - default=3.5, - help="the guidance scale used for transformer.", - ) - - parser.add_argument( - "--upcast_before_saving", - action="store_true", - help=( - "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). " - "Defaults to precision dtype used for training to save memory" - ), - ) - - parser.add_argument( - "--weighting_scheme", - type=str, - default="none", - choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], - help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), - ) - parser.add_argument( - "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." - ) - parser.add_argument( - "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." - ) - parser.add_argument( - "--mode_scale", - type=float, - default=1.29, - help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", - ) - parser.add_argument( - "--offload", - action="store_true", - help="Whether to offload the VAE and the text encoders to CPU when they are not used.", - ) - - if input_args is not None: - args = parser.parse_args(input_args) - else: - args = parser.parse_args() - - if args.dataset_name is None and args.jsonl_for_train is None: - raise ValueError("Specify either `--dataset_name` or `--jsonl_for_train`") - - if args.dataset_name is not None and args.jsonl_for_train is not None: - raise ValueError("Specify only one of `--dataset_name` or `--jsonl_for_train`") - - if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: - raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") - - if args.validation_prompt is not None and args.validation_image is None: - raise ValueError("`--validation_image` must be set if `--validation_prompt` is set") - - if args.validation_prompt is None and args.validation_image is not None: - raise ValueError("`--validation_prompt` must be set if `--validation_image` is set") - - if ( - args.validation_image is not None - and args.validation_prompt is not None - and len(args.validation_image) != 1 - and len(args.validation_prompt) != 1 - and len(args.validation_image) != len(args.validation_prompt) - ): - raise ValueError( - "Must provide either 1 `--validation_image`, 1 `--validation_prompt`," - " or the same number of `--validation_prompt`s and `--validation_image`s" - ) - - if args.resolution % 8 != 0: - raise ValueError( - "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the cogview4 transformer." - ) - - return args - - -def get_train_dataset(args, accelerator): - dataset = None - if args.dataset_name is not None: - # Downloading and loading a dataset from the hub. - dataset = load_dataset( - args.dataset_name, - args.dataset_config_name, - cache_dir=args.cache_dir, - ) - if args.jsonl_for_train is not None: - # load from json - dataset = load_dataset("json", data_files=args.jsonl_for_train, cache_dir=args.cache_dir) - dataset = dataset.flatten_indices() - # Preprocessing the datasets. - # We need to tokenize inputs and targets. - column_names = dataset["train"].column_names - - # 6. Get the column names for input/target. - if args.image_column is None: - image_column = column_names[0] - logger.info(f"image column defaulting to {image_column}") - else: - image_column = args.image_column - if image_column not in column_names: - raise ValueError( - f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" - ) - - if args.caption_column is None: - caption_column = column_names[1] - logger.info(f"caption column defaulting to {caption_column}") - else: - caption_column = args.caption_column - if caption_column not in column_names: - raise ValueError( - f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" - ) - - if args.conditioning_image_column is None: - conditioning_image_column = column_names[2] - logger.info(f"conditioning image column defaulting to {conditioning_image_column}") - else: - conditioning_image_column = args.conditioning_image_column - if conditioning_image_column not in column_names: - raise ValueError( - f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" - ) - - with accelerator.main_process_first(): - train_dataset = dataset["train"].shuffle(seed=args.seed) - if args.max_train_samples is not None: - train_dataset = train_dataset.select(range(args.max_train_samples)) - return train_dataset - - -def prepare_train_dataset(dataset, accelerator): - image_transforms = transforms.Compose( - [ - transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR), - transforms.ToTensor(), - transforms.Lambda(lambda x: x * 2 - 1), - ] - ) - - def preprocess_train(examples): - images = [ - (image.convert("RGB") if not isinstance(image, str) else Image.open(image).convert("RGB")) - for image in examples[args.image_column] - ] - images = [image_transforms(image) for image in images] - - conditioning_images = [ - (image.convert("RGB") if not isinstance(image, str) else Image.open(image).convert("RGB")) - for image in examples[args.conditioning_image_column] - ] - conditioning_images = [image_transforms(image) for image in conditioning_images] - examples["pixel_values"] = images - examples["conditioning_pixel_values"] = conditioning_images - - is_caption_list = isinstance(examples[args.caption_column][0], list) - if is_caption_list: - examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]] - else: - examples["captions"] = list(examples[args.caption_column]) - - return examples - - with accelerator.main_process_first(): - dataset = dataset.with_transform(preprocess_train) - - return dataset - - -def collate_fn(examples): - pixel_values = torch.stack([example["pixel_values"] for example in examples]) - pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples]) - conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float() - captions = [example["captions"] for example in examples] - return {"pixel_values": pixel_values, "conditioning_pixel_values": conditioning_pixel_values, "captions": captions} - - -def main(args): - if args.report_to == "wandb" and args.hub_token is not None: - raise ValueError( - "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." - " Please use `huggingface-cli login` to authenticate with the Hub." - ) - - logging_out_dir = Path(args.output_dir, args.logging_dir) - - if torch.backends.mps.is_available() and args.mixed_precision == "bf16": - # due to pytorch#99272, MPS does not yet support bfloat16. - raise ValueError( - "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." - ) - - accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=str(logging_out_dir)) - - accelerator = Accelerator( - gradient_accumulation_steps=args.gradient_accumulation_steps, - mixed_precision=args.mixed_precision, - log_with=args.report_to, - project_config=accelerator_project_config, - ) - - # Disable AMP for MPS. A technique for accelerating machine learning computations on iOS and macOS devices. - if torch.backends.mps.is_available(): - logger.info("MPS is enabled. Disabling AMP.") - accelerator.native_amp = False - - # Make one log on every process with the configuration for debugging. - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - # DEBUG, INFO, WARNING, ERROR, CRITICAL - level=logging.INFO, - ) - logger.info(accelerator.state, main_process_only=False) - - if accelerator.is_local_main_process: - transformers.utils.logging.set_verbosity_warning() - diffusers.utils.logging.set_verbosity_info() - else: - transformers.utils.logging.set_verbosity_error() - diffusers.utils.logging.set_verbosity_error() - - # If passed along, set the training seed now. - if args.seed is not None: - set_seed(args.seed) - - # Handle the repository creation - if accelerator.is_main_process: - if args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - - if args.push_to_hub: - repo_id = create_repo( - repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token - ).repo_id - - # Load models. We will load the text encoders later in a pipeline to compute - # embeddings. - vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="vae", - revision=args.revision, - variant=args.variant, - ) - cogview4_transformer = CogView4Transformer2DModel.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="transformer", - revision=args.revision, - variant=args.variant, - ) - logger.info("All models loaded successfully") - - noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="scheduler", - ) - noise_scheduler_copy = copy.deepcopy(noise_scheduler) - if not args.only_target_transformer_blocks: - cogview4_transformer.requires_grad_(True) - vae.requires_grad_(False) - - # cast down and move to the CPU - weight_dtype = torch.float32 - if accelerator.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif accelerator.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - # let's not move the VAE to the GPU yet. - vae.to(dtype=torch.float32) # keep the VAE in float32. - - # enable image inputs - with torch.no_grad(): - patch_size = cogview4_transformer.config.patch_size - initial_input_channels = cogview4_transformer.config.in_channels * patch_size**2 - new_linear = torch.nn.Linear( - cogview4_transformer.patch_embed.proj.in_features * 2, - cogview4_transformer.patch_embed.proj.out_features, - bias=cogview4_transformer.patch_embed.proj.bias is not None, - dtype=cogview4_transformer.dtype, - device=cogview4_transformer.device, - ) - new_linear.weight.zero_() - new_linear.weight[:, :initial_input_channels].copy_(cogview4_transformer.patch_embed.proj.weight) - if cogview4_transformer.patch_embed.proj.bias is not None: - new_linear.bias.copy_(cogview4_transformer.patch_embed.proj.bias) - cogview4_transformer.patch_embed.proj = new_linear - - assert torch.all(cogview4_transformer.patch_embed.proj.weight[:, initial_input_channels:].data == 0) - cogview4_transformer.register_to_config( - in_channels=cogview4_transformer.config.in_channels * 2, out_channels=cogview4_transformer.config.in_channels - ) - - if args.only_target_transformer_blocks: - cogview4_transformer.patch_embed.proj.requires_grad_(True) - for name, module in cogview4_transformer.named_modules(): - if "transformer_blocks" in name: - module.requires_grad_(True) - else: - module.requirs_grad_(False) - - def unwrap_model(model): - model = accelerator.unwrap_model(model) - model = model._orig_mod if is_compiled_module(model) else model - return model - - # `accelerate` 0.16.0 will have better support for customized saving - if version.parse(accelerate.__version__) >= version.parse("0.16.0"): - - def save_model_hook(models, weights, output_dir): - if accelerator.is_main_process: - for model in models: - if isinstance(unwrap_model(model), type(unwrap_model(cogview4_transformer))): - model = unwrap_model(model) - model.save_pretrained(os.path.join(output_dir, "transformer")) - else: - raise ValueError(f"unexpected save model: {model.__class__}") - - # make sure to pop weight so that corresponding model is not saved again - if weights: - weights.pop() - - def load_model_hook(models, input_dir): - transformer_ = None - - if not accelerator.distributed_type == DistributedType.DEEPSPEED: - while len(models) > 0: - model = models.pop() - - if isinstance(unwrap_model(model), type(unwrap_model(cogview4_transformer))): - transformer_ = model # noqa: F841 - else: - raise ValueError(f"unexpected save model: {unwrap_model(model).__class__}") - - else: - transformer_ = CogView4Transformer2DModel.from_pretrained(input_dir, subfolder="transformer") # noqa: F841 - - accelerator.register_save_state_pre_hook(save_model_hook) - accelerator.register_load_state_pre_hook(load_model_hook) - - if args.gradient_checkpointing: - cogview4_transformer.enable_gradient_checkpointing() - - # Enable TF32 for faster training on Ampere GPUs, - # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices - if args.allow_tf32: - torch.backends.cuda.matmul.allow_tf32 = True - - if args.scale_lr: - args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes - ) - - # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs - if args.use_8bit_adam: - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError( - "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." - ) - - optimizer_class = bnb.optim.AdamW8bit - else: - optimizer_class = torch.optim.AdamW - - # Optimization parameters - optimizer = optimizer_class( - cogview4_transformer.parameters(), - lr=args.learning_rate, - betas=(args.adam_beta1, args.adam_beta2), - weight_decay=args.adam_weight_decay, - eps=args.adam_epsilon, - ) - - # Prepare dataset and dataloader. - train_dataset = get_train_dataset(args, accelerator) - train_dataset = prepare_train_dataset(train_dataset, accelerator) - train_dataloader = torch.utils.data.DataLoader( - train_dataset, - shuffle=True, - collate_fn=collate_fn, - batch_size=args.train_batch_size, - num_workers=args.dataloader_num_workers, - ) - - # Scheduler and math around the number of training steps. - # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation. - if args.max_train_steps is None: - len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes) - num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps) - num_training_steps_for_scheduler = ( - args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes - ) - else: - num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes - - lr_scheduler = get_scheduler( - args.lr_scheduler, - optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, - num_training_steps=args.max_train_steps * accelerator.num_processes, - num_cycles=args.lr_num_cycles, - power=args.lr_power, - ) - # Prepare everything with our `accelerator`. - cogview4_transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - cogview4_transformer, optimizer, train_dataloader, lr_scheduler - ) - - # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes: - logger.warning( - f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match " - f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. " - f"This inconsistency may result in the learning rate scheduler not functioning properly." - ) - # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - - # We need to initialize the trackers we use, and also store our configuration. - # The trackers initializes automatically on the main process. - if accelerator.is_main_process: - tracker_config = dict(vars(args)) - - # tensorboard cannot handle list types for config - tracker_config.pop("validation_prompt") - tracker_config.pop("validation_image") - - accelerator.init_trackers(args.tracker_project_name, config=tracker_config) - - # Train! - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - - logger.info("***** Running training *****") - logger.info(f" Num examples = {len(train_dataset)}") - logger.info(f" Num batches each epoch = {len(train_dataloader)}") - logger.info(f" Num Epochs = {args.num_train_epochs}") - logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {args.max_train_steps}") - global_step = 0 - first_epoch = 0 - - # Create a pipeline for text encoding. We will move this pipeline to GPU/CPU as needed. - text_encoding_pipeline = CogView4ControlPipeline.from_pretrained( - args.pretrained_model_name_or_path, transformer=None, vae=None, torch_dtype=weight_dtype - ) - tokenizer = text_encoding_pipeline.tokenizer - - # Potentially load in the weights and states from a previous save - if args.resume_from_checkpoint: - if args.resume_from_checkpoint != "latest": - path = os.path.basename(args.resume_from_checkpoint) - else: - # Get the most recent checkpoint - dirs = os.listdir(args.output_dir) - dirs = [d for d in dirs if d.startswith("checkpoint")] - dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) - path = dirs[-1] if len(dirs) > 0 else None - - if path is None: - logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.") - args.resume_from_checkpoint = None - initial_global_step = 0 - else: - logger.info(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(args.output_dir, path)) - global_step = int(path.split("-")[1]) - - initial_global_step = global_step - first_epoch = global_step // num_update_steps_per_epoch - else: - initial_global_step = 0 - - if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples: - logger.info("Logging some dataset samples.") - formatted_images = [] - formatted_control_images = [] - all_prompts = [] - for i, batch in enumerate(train_dataloader): - images = (batch["pixel_values"] + 1) / 2 - control_images = (batch["conditioning_pixel_values"] + 1) / 2 - prompts = batch["captions"] - - if len(formatted_images) > 10: - break - - for img, control_img, prompt in zip(images, control_images, prompts): - formatted_images.append(img) - formatted_control_images.append(control_img) - all_prompts.append(prompt) - - logged_artifacts = [] - for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts): - logged_artifacts.append(wandb.Image(control_img, caption="Conditioning")) - logged_artifacts.append(wandb.Image(img, caption=prompt)) - - wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"] - wandb_tracker[0].log({"dataset_samples": logged_artifacts}) - - progress_bar = tqdm( - range(0, args.max_train_steps), - initial=initial_global_step, - desc="Steps", - # Only show the progress bar once on each machine. - disable=not accelerator.is_local_main_process, - ) - - for epoch in range(first_epoch, args.num_train_epochs): - cogview4_transformer.train() - for step, batch in enumerate(train_dataloader): - with accelerator.accumulate(cogview4_transformer): - # Convert images to latent space - # vae encode - prompts = batch["captions"] - attention_mask = tokenizer( - prompts, - padding="longest", # not use max length - max_length=args.max_sequence_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ).attention_mask.float() - - pixel_latents = encode_images(batch["pixel_values"], vae.to(accelerator.device), weight_dtype) - control_latents = encode_images( - batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype - ) - if args.offload: - vae.cpu() - - # Sample a random timestep for each image - # for weighting schemes where we sample timesteps non-uniformly - bsz = pixel_latents.shape[0] - noise = torch.randn_like(pixel_latents, device=accelerator.device, dtype=weight_dtype) - u = compute_density_for_timestep_sampling( - weighting_scheme=args.weighting_scheme, - batch_size=bsz, - logit_mean=args.logit_mean, - logit_std=args.logit_std, - mode_scale=args.mode_scale, - ) - - # Add noise according for cogview4 - indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() - timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device) - sigmas = noise_scheduler_copy.sigmas[indices].to(device=pixel_latents.device) - captions = batch["captions"] - image_seq_lens = torch.tensor( - pixel_latents.shape[2] * pixel_latents.shape[3] // patch_size**2, - dtype=pixel_latents.dtype, - device=pixel_latents.device, - ) # H * W / VAE patch_size - mu = torch.sqrt(image_seq_lens / 256) - mu = mu * 0.75 + 0.25 - scale_factors = mu / (mu + (1 / sigmas - 1) ** 1.0).to( - dtype=pixel_latents.dtype, device=pixel_latents.device - ) - scale_factors = scale_factors.view(len(batch["captions"]), 1, 1, 1) - noisy_model_input = (1.0 - scale_factors) * pixel_latents + scale_factors * noise - concatenated_noisy_model_input = torch.cat([noisy_model_input, control_latents], dim=1) - text_encoding_pipeline = text_encoding_pipeline.to("cuda") - - with torch.no_grad(): - ( - prompt_embeds, - pooled_prompt_embeds, - ) = text_encoding_pipeline.encode_prompt(captions, "") - original_size = (args.resolution, args.resolution) - original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=prompt_embeds.device) - - target_size = (args.resolution, args.resolution) - target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=prompt_embeds.device) - - target_size = target_size.repeat(len(batch["captions"]), 1) - original_size = original_size.repeat(len(batch["captions"]), 1) - crops_coords_top_left = torch.tensor([(0, 0)], dtype=prompt_embeds.dtype, device=prompt_embeds.device) - crops_coords_top_left = crops_coords_top_left.repeat(len(batch["captions"]), 1) - - # this could be optimized by not having to do any text encoding and just - # doing zeros on specified shapes for `prompt_embeds` and `pooled_prompt_embeds` - if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts: - # Here, we directly pass 16 pad tokens from pooled_prompt_embeds to prompt_embeds. - prompt_embeds = pooled_prompt_embeds - if args.offload: - text_encoding_pipeline = text_encoding_pipeline.to("cpu") - # Predict. - noise_pred_cond = cogview4_transformer( - hidden_states=concatenated_noisy_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timesteps, - original_size=original_size, - target_size=target_size, - crop_coords=crops_coords_top_left, - return_dict=False, - attention_mask=attention_mask, - )[0] - # these weighting schemes use a uniform timestep sampling - # and instead post-weight the loss - weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) - # flow-matching loss - target = noise - pixel_latents - - weighting = weighting.view(len(batch["captions"]), 1, 1, 1) - loss = torch.mean( - (weighting.float() * (noise_pred_cond.float() - target.float()) ** 2).reshape(target.shape[0], -1), - 1, - ) - loss = loss.mean() - accelerator.backward(loss) - - if accelerator.sync_gradients: - params_to_clip = cogview4_transformer.parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues. - if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: - if global_step % args.checkpointing_steps == 0: - # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` - if args.checkpoints_total_limit is not None: - checkpoints = os.listdir(args.output_dir) - checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] - checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) - - # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints - if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 - removing_checkpoints = checkpoints[0:num_to_remove] - - logger.info( - f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" - ) - logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") - - for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) - shutil.rmtree(removing_checkpoint) - - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") - accelerator.save_state(save_path) - logger.info(f"Saved state to {save_path}") - - if args.validation_prompt is not None and global_step % args.validation_steps == 0: - image_logs = log_validation( - cogview4_transformer=cogview4_transformer, - args=args, - accelerator=accelerator, - weight_dtype=weight_dtype, - step=global_step, - ) - - logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - accelerator.log(logs, step=global_step) - - if global_step >= args.max_train_steps: - break - - # Create the pipeline using using the trained modules and save it. - accelerator.wait_for_everyone() - if accelerator.is_main_process: - cogview4_transformer = unwrap_model(cogview4_transformer) - if args.upcast_before_saving: - cogview4_transformer.to(torch.float32) - cogview4_transformer.save_pretrained(args.output_dir) - - del cogview4_transformer - del text_encoding_pipeline - del vae - free_memory() - - # Run a final round of validation. - image_logs = None - if args.validation_prompt is not None: - image_logs = log_validation( - cogview4_transformer=None, - args=args, - accelerator=accelerator, - weight_dtype=weight_dtype, - step=global_step, - is_final_validation=True, - ) - - if args.push_to_hub: - save_model_card( - repo_id, - image_logs=image_logs, - base_model=args.pretrained_model_name_or_path, - repo_folder=args.output_dir, - ) - upload_folder( - repo_id=repo_id, - folder_path=args.output_dir, - commit_message="End of training", - ignore_patterns=["step_*", "epoch_*", "checkpoint-*"], - ) - - accelerator.end_training() - - -if __name__ == "__main__": - args = parse_args() - main(args) diff --git a/examples/community/mixture_tiling_sdxl.py b/examples/community/mixture_tiling_sdxl.py index bd56ddb3d61d..f7b971bae841 100644 --- a/examples/community/mixture_tiling_sdxl.py +++ b/examples/community/mixture_tiling_sdxl.py @@ -1,4 +1,4 @@ -# Copyright 2025 The DEVAIEXP Team and The HuggingFace Team. All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -1070,32 +1070,32 @@ def __call__( text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) else: text_encoder_projection_dim = self.text_encoder_2.config.projection_dim - add_time_ids = self._get_add_time_ids( - original_size, - crops_coords_top_left[row][col], - target_size, - dtype=prompt_embeds.dtype, - text_encoder_projection_dim=text_encoder_projection_dim, - ) - if negative_original_size is not None and negative_target_size is not None: - negative_add_time_ids = self._get_add_time_ids( - negative_original_size, - negative_crops_coords_top_left[row][col], - negative_target_size, + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left[row][col], + target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) - else: - negative_add_time_ids = add_time_ids + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left[row][col], + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) - add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) - prompt_embeds = prompt_embeds.to(device) - add_text_embeds = add_text_embeds.to(device) - add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) addition_embed_type_row.append((prompt_embeds, add_text_embeds, add_time_ids)) embeddings_and_added_time.append(addition_embed_type_row) diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index aa235ad65bfe..65d6c14c5efc 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -152,7 +152,9 @@ def log_validation( validation_prompt = log["validation_prompt"] validation_image = log["validation_image"] - formatted_images = [np.asarray(validation_image)] + formatted_images = [] + + formatted_images.append(np.asarray(validation_image)) for image in images: formatted_images.append(np.asarray(image)) diff --git a/examples/controlnet/train_controlnet_flux.py b/examples/controlnet/train_controlnet_flux.py index a41615c7b546..7f93477fc5b7 100644 --- a/examples/controlnet/train_controlnet_flux.py +++ b/examples/controlnet/train_controlnet_flux.py @@ -166,7 +166,9 @@ def log_validation( validation_prompt = log["validation_prompt"] validation_image = log["validation_image"] - formatted_images = [np.asarray(validation_image)] + formatted_images = [] + + formatted_images.append(np.asarray(validation_image)) for image in images: formatted_images.append(np.asarray(image)) diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py index ffe460d72de8..f4aadc2577f7 100644 --- a/examples/controlnet/train_controlnet_sd3.py +++ b/examples/controlnet/train_controlnet_sd3.py @@ -1283,8 +1283,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise # Get the text embedding for conditioning - prompt_embeds = batch["prompt_embeds"].to(dtype=weight_dtype) - pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(dtype=weight_dtype) + prompt_embeds = batch["prompt_embeds"] + pooled_prompt_embeds = batch["pooled_prompt_embeds"] # controlnet(s) inference controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 17f313752989..b2d950e09ac1 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -157,7 +157,9 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, validation_prompt = log["validation_prompt"] validation_image = log["validation_image"] - formatted_images = [np.asarray(validation_image)] + formatted_images = [] + + formatted_images.append(np.asarray(validation_image)) for image in images: formatted_images.append(np.asarray(image)) diff --git a/examples/research_projects/anytext/README.md b/examples/research_projects/anytext/README.md deleted file mode 100644 index f5f4fe59ddfd..000000000000 --- a/examples/research_projects/anytext/README.md +++ /dev/null @@ -1,32 +0,0 @@ -# AnyTextPipeline Pipeline - -Project page: https://aigcdesigngroup.github.io/homepage_anytext - -"AnyText comprises a diffusion pipeline with two primary elements: an auxiliary latent module and a text embedding module. The former uses inputs like text glyph, position, and masked image to generate latent features for text generation or editing. The latter employs an OCR model for encoding stroke data as embeddings, which blend with image caption embeddings from the tokenizer to generate texts that seamlessly integrate with the background. We employed text-control diffusion loss and text perceptual loss for training to further enhance writing accuracy." - -Each text line that needs to be generated should be enclosed in double quotes. For any usage questions, please refer to the [paper](https://arxiv.org/abs/2311.03054). - - -```py -import torch -from diffusers import DiffusionPipeline -from anytext_controlnet import AnyTextControlNetModel -from diffusers.utils import load_image - -# I chose a font file shared by an HF staff: -# !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf - -anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16, - variant="fp16",) -pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial-unicode-ms.ttf", - controlnet=anytext_controlnet, torch_dtype=torch.float16, - trust_remote_code=False, # One needs to give permission to run this pipeline's code - ).to("cuda") - -# generate image -prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream' -draw_pos = load_image("https://raw.githubusercontent.com/tyxsspa/AnyText/refs/heads/main/example_images/gen9.png") -image = pipe(prompt, num_inference_steps=20, mode="generate", draw_pos=draw_pos, - ).images[0] -image -``` diff --git a/examples/research_projects/anytext/anytext.py b/examples/research_projects/anytext/anytext.py deleted file mode 100644 index 518452f97942..000000000000 --- a/examples/research_projects/anytext/anytext.py +++ /dev/null @@ -1,2360 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# Copyright (c) Alibaba, Inc. and its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Based on [AnyText: Multilingual Visual Text Generation And Editing](https://huggingface.co/papers/2311.03054). -# Authors: Yuxiang Tuo, Wangmeng Xiang, Jun-Yan He, Yifeng Geng, Xuansong Xie -# Code: https://github.com/tyxsspa/AnyText with Apache-2.0 license -# -# Adapted to Diffusers by [M. Tolga Cangรถz](https://github.com/tolgacangoz). - - -import inspect -import math -import os -import re -import sys -import unicodedata -from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import cv2 -import numpy as np -import PIL.Image -import torch -import torch.nn.functional as F -from huggingface_hub import hf_hub_download -from ocr_recog.RecModel import RecModel -from PIL import Image, ImageDraw, ImageFont -from safetensors.torch import load_file -from skimage.transform._geometric import _umeyama as get_sym_mat -from torch import nn -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection -from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask - -from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.image_processor import PipelineImageInput, VaeImageProcessor -from diffusers.loaders import ( - FromSingleFileMixin, - IPAdapterMixin, - StableDiffusionLoraLoaderMixin, - TextualInversionLoaderMixin, -) -from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel -from diffusers.models.lora import adjust_lora_scale_text_encoder -from diffusers.models.modeling_utils import ModelMixin -from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel -from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin -from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from diffusers.schedulers import KarrasDiffusionSchedulers -from diffusers.utils import ( - USE_PEFT_BACKEND, - deprecate, - logging, - replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, -) -from diffusers.utils.constants import HF_MODULES_CACHE -from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor - - -class Checker: - def __init__(self): - pass - - def _is_chinese_char(self, cp): - """Checks whether CP is the codepoint of a CJK character.""" - # This defines a "chinese character" as anything in the CJK Unicode block: - # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) - # - # Note that the CJK Unicode block is NOT all Japanese and Korean characters, - # despite its name. The modern Korean Hangul alphabet is a different block, - # as is Japanese Hiragana and Katakana. Those alphabets are used to write - # space-separated words, so they are not treated specially and handled - # like the all of the other languages. - if ( - (cp >= 0x4E00 and cp <= 0x9FFF) - or (cp >= 0x3400 and cp <= 0x4DBF) - or (cp >= 0x20000 and cp <= 0x2A6DF) - or (cp >= 0x2A700 and cp <= 0x2B73F) - or (cp >= 0x2B740 and cp <= 0x2B81F) - or (cp >= 0x2B820 and cp <= 0x2CEAF) - or (cp >= 0xF900 and cp <= 0xFAFF) - or (cp >= 0x2F800 and cp <= 0x2FA1F) - ): - return True - - return False - - def _clean_text(self, text): - """Performs invalid character removal and whitespace cleanup on text.""" - output = [] - for char in text: - cp = ord(char) - if cp == 0 or cp == 0xFFFD or self._is_control(char): - continue - if self._is_whitespace(char): - output.append(" ") - else: - output.append(char) - return "".join(output) - - def _is_control(self, char): - """Checks whether `chars` is a control character.""" - # These are technically control characters but we count them as whitespace - # characters. - if char == "\t" or char == "\n" or char == "\r": - return False - cat = unicodedata.category(char) - if cat in ("Cc", "Cf"): - return True - return False - - def _is_whitespace(self, char): - """Checks whether `chars` is a whitespace character.""" - # \t, \n, and \r are technically control characters but we treat them - # as whitespace since they are generally considered as such. - if char == " " or char == "\t" or char == "\n" or char == "\r": - return True - cat = unicodedata.category(char) - if cat == "Zs": - return True - return False - - -checker = Checker() - - -PLACE_HOLDER = "*" -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import torch - >>> from diffusers import DiffusionPipeline - >>> from anytext_controlnet import AnyTextControlNetModel - >>> from diffusers.utils import load_image - - >>> # I chose a font file shared by an HF staff: - >>> !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf - - >>> anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16, - ... variant="fp16",) - >>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial-unicode-ms.ttf", - ... controlnet=anytext_controlnet, torch_dtype=torch.float16, - ... trust_remote_code=False, # One needs to give permission to run this pipeline's code - ... ).to("cuda") - - - >>> # generate image - >>> prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream' - >>> draw_pos = load_image("https://raw.githubusercontent.com/tyxsspa/AnyText/refs/heads/main/example_images/gen9.png") - >>> image = pipe(prompt, num_inference_steps=20, mode="generate", draw_pos=draw_pos, - ... ).images[0] - >>> image - ``` -""" - - -def get_clip_token_for_string(tokenizer, string): - batch_encoding = tokenizer( - string, - truncation=True, - max_length=77, - return_length=True, - return_overflowing_tokens=False, - padding="max_length", - return_tensors="pt", - ) - tokens = batch_encoding["input_ids"] - assert ( - torch.count_nonzero(tokens - 49407) == 2 - ), f"String '{string}' maps to more than a single token. Please use another string" - return tokens[0, 1] - - -def get_recog_emb(encoder, img_list): - _img_list = [(img.repeat(1, 3, 1, 1) * 255)[0] for img in img_list] - encoder.predictor.eval() - _, preds_neck = encoder.pred_imglist(_img_list, show_debug=False) - return preds_neck - - -class EmbeddingManager(ModelMixin, ConfigMixin): - @register_to_config - def __init__( - self, - embedder, - placeholder_string="*", - use_fp16=False, - token_dim=768, - get_recog_emb=None, - ): - super().__init__() - get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer) - - self.proj = nn.Linear(40 * 64, token_dim) - proj_dir = hf_hub_download( - repo_id="tolgacangoz/anytext", - filename="text_embedding_module/proj.safetensors", - cache_dir=HF_MODULES_CACHE, - ) - self.proj.load_state_dict(load_file(proj_dir, device=str(embedder.device))) - if use_fp16: - self.proj = self.proj.to(dtype=torch.float16) - - self.placeholder_token = get_token_for_string(placeholder_string) - - @torch.no_grad() - def encode_text(self, text_info): - if self.config.get_recog_emb is None: - self.config.get_recog_emb = partial(get_recog_emb, self.recog) - - gline_list = [] - for i in range(len(text_info["n_lines"])): # sample index in a batch - n_lines = text_info["n_lines"][i] - for j in range(n_lines): # line - gline_list += [text_info["gly_line"][j][i : i + 1]] - - if len(gline_list) > 0: - recog_emb = self.config.get_recog_emb(gline_list) - enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1).to(self.proj.weight.dtype)) - - self.text_embs_all = [] - n_idx = 0 - for i in range(len(text_info["n_lines"])): # sample index in a batch - n_lines = text_info["n_lines"][i] - text_embs = [] - for j in range(n_lines): # line - text_embs += [enc_glyph[n_idx : n_idx + 1]] - n_idx += 1 - self.text_embs_all += [text_embs] - - @torch.no_grad() - def forward( - self, - tokenized_text, - embedded_text, - ): - b, device = tokenized_text.shape[0], tokenized_text.device - for i in range(b): - idx = tokenized_text[i] == self.placeholder_token.to(device) - if sum(idx) > 0: - if i >= len(self.text_embs_all): - print("truncation for log images...") - break - text_emb = torch.cat(self.text_embs_all[i], dim=0) - if sum(idx) != len(text_emb): - print("truncation for long caption...") - text_emb = text_emb.to(embedded_text.device) - embedded_text[i][idx] = text_emb[: sum(idx)] - return embedded_text - - def embedding_parameters(self): - return self.parameters() - - -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) - - -def min_bounding_rect(img): - ret, thresh = cv2.threshold(img, 127, 255, 0) - contours, hierarchy = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) - if len(contours) == 0: - print("Bad contours, using fake bbox...") - return np.array([[0, 0], [100, 0], [100, 100], [0, 100]]) - max_contour = max(contours, key=cv2.contourArea) - rect = cv2.minAreaRect(max_contour) - box = cv2.boxPoints(rect) - box = np.int0(box) - # sort - x_sorted = sorted(box, key=lambda x: x[0]) - left = x_sorted[:2] - right = x_sorted[2:] - left = sorted(left, key=lambda x: x[1]) - (tl, bl) = left - right = sorted(right, key=lambda x: x[1]) - (tr, br) = right - if tl[1] > bl[1]: - (tl, bl) = (bl, tl) - if tr[1] > br[1]: - (tr, br) = (br, tr) - return np.array([tl, tr, br, bl]) - - -def adjust_image(box, img): - pts1 = np.float32([box[0], box[1], box[2], box[3]]) - width = max(np.linalg.norm(pts1[0] - pts1[1]), np.linalg.norm(pts1[2] - pts1[3])) - height = max(np.linalg.norm(pts1[0] - pts1[3]), np.linalg.norm(pts1[1] - pts1[2])) - pts2 = np.float32([[0, 0], [width, 0], [width, height], [0, height]]) - # get transform matrix - M = get_sym_mat(pts1, pts2, estimate_scale=True) - C, H, W = img.shape - T = np.array([[2 / W, 0, -1], [0, 2 / H, -1], [0, 0, 1]]) - theta = np.linalg.inv(T @ M @ np.linalg.inv(T)) - theta = torch.from_numpy(theta[:2, :]).unsqueeze(0).type(torch.float32).to(img.device) - grid = F.affine_grid(theta, torch.Size([1, C, H, W]), align_corners=True) - result = F.grid_sample(img.unsqueeze(0), grid, align_corners=True) - result = torch.clamp(result.squeeze(0), 0, 255) - # crop - result = result[:, : int(height), : int(width)] - return result - - -def crop_image(src_img, mask): - box = min_bounding_rect(mask) - result = adjust_image(box, src_img) - if len(result.shape) == 2: - result = torch.stack([result] * 3, axis=-1) - return result - - -def create_predictor(model_lang="ch", device="cpu", use_fp16=False): - model_dir = hf_hub_download( - repo_id="tolgacangoz/anytext", - filename="text_embedding_module/OCR/ppv3_rec.pth", - cache_dir=HF_MODULES_CACHE, - ) - if not os.path.exists(model_dir): - raise ValueError("not find model file path {}".format(model_dir)) - - if model_lang == "ch": - n_class = 6625 - elif model_lang == "en": - n_class = 97 - else: - raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}") - rec_config = { - "in_channels": 3, - "backbone": {"type": "MobileNetV1Enhance", "scale": 0.5, "last_conv_stride": [1, 2], "last_pool_type": "avg"}, - "neck": { - "type": "SequenceEncoder", - "encoder_type": "svtr", - "dims": 64, - "depth": 2, - "hidden_dims": 120, - "use_guide": True, - }, - "head": {"type": "CTCHead", "fc_decay": 0.00001, "out_channels": n_class, "return_feats": True}, - } - - rec_model = RecModel(rec_config) - state_dict = torch.load(model_dir, map_location=device) - rec_model.load_state_dict(state_dict) - return rec_model - - -def _check_image_file(path): - img_end = ("tiff", "tif", "bmp", "rgb", "jpg", "png", "jpeg") - return path.lower().endswith(tuple(img_end)) - - -def get_image_file_list(img_file): - imgs_lists = [] - if img_file is None or not os.path.exists(img_file): - raise Exception("not found any img file in {}".format(img_file)) - if os.path.isfile(img_file) and _check_image_file(img_file): - imgs_lists.append(img_file) - elif os.path.isdir(img_file): - for single_file in os.listdir(img_file): - file_path = os.path.join(img_file, single_file) - if os.path.isfile(file_path) and _check_image_file(file_path): - imgs_lists.append(file_path) - if len(imgs_lists) == 0: - raise Exception("not found any img file in {}".format(img_file)) - imgs_lists = sorted(imgs_lists) - return imgs_lists - - -class TextRecognizer(object): - def __init__(self, args, predictor): - self.rec_image_shape = [int(v) for v in args["rec_image_shape"].split(",")] - self.rec_batch_num = args["rec_batch_num"] - self.predictor = predictor - self.chars = self.get_char_dict(args["rec_char_dict_path"]) - self.char2id = {x: i for i, x in enumerate(self.chars)} - self.is_onnx = not isinstance(self.predictor, torch.nn.Module) - self.use_fp16 = args["use_fp16"] - - # img: CHW - def resize_norm_img(self, img, max_wh_ratio): - imgC, imgH, imgW = self.rec_image_shape - assert imgC == img.shape[0] - imgW = int((imgH * max_wh_ratio)) - - h, w = img.shape[1:] - ratio = w / float(h) - if math.ceil(imgH * ratio) > imgW: - resized_w = imgW - else: - resized_w = int(math.ceil(imgH * ratio)) - resized_image = torch.nn.functional.interpolate( - img.unsqueeze(0), - size=(imgH, resized_w), - mode="bilinear", - align_corners=True, - ) - resized_image /= 255.0 - resized_image -= 0.5 - resized_image /= 0.5 - padding_im = torch.zeros((imgC, imgH, imgW), dtype=torch.float32).to(img.device) - padding_im[:, :, 0:resized_w] = resized_image[0] - return padding_im - - # img_list: list of tensors with shape chw 0-255 - def pred_imglist(self, img_list, show_debug=False): - img_num = len(img_list) - assert img_num > 0 - # Calculate the aspect ratio of all text bars - width_list = [] - for img in img_list: - width_list.append(img.shape[2] / float(img.shape[1])) - # Sorting can speed up the recognition process - indices = torch.from_numpy(np.argsort(np.array(width_list))) - batch_num = self.rec_batch_num - preds_all = [None] * img_num - preds_neck_all = [None] * img_num - for beg_img_no in range(0, img_num, batch_num): - end_img_no = min(img_num, beg_img_no + batch_num) - norm_img_batch = [] - - imgC, imgH, imgW = self.rec_image_shape[:3] - max_wh_ratio = imgW / imgH - for ino in range(beg_img_no, end_img_no): - h, w = img_list[indices[ino]].shape[1:] - if h > w * 1.2: - img = img_list[indices[ino]] - img = torch.transpose(img, 1, 2).flip(dims=[1]) - img_list[indices[ino]] = img - h, w = img.shape[1:] - # wh_ratio = w * 1.0 / h - # max_wh_ratio = max(max_wh_ratio, wh_ratio) # comment to not use different ratio - for ino in range(beg_img_no, end_img_no): - norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio) - if self.use_fp16: - norm_img = norm_img.half() - norm_img = norm_img.unsqueeze(0) - norm_img_batch.append(norm_img) - norm_img_batch = torch.cat(norm_img_batch, dim=0) - if show_debug: - for i in range(len(norm_img_batch)): - _img = norm_img_batch[i].permute(1, 2, 0).detach().cpu().numpy() - _img = (_img + 0.5) * 255 - _img = _img[:, :, ::-1] - file_name = f"{indices[beg_img_no + i]}" - if os.path.exists(file_name + ".jpg"): - file_name += "_2" # ori image - cv2.imwrite(file_name + ".jpg", _img) - if self.is_onnx: - input_dict = {} - input_dict[self.predictor.get_inputs()[0].name] = norm_img_batch.detach().cpu().numpy() - outputs = self.predictor.run(None, input_dict) - preds = {} - preds["ctc"] = torch.from_numpy(outputs[0]) - preds["ctc_neck"] = [torch.zeros(1)] * img_num - else: - preds = self.predictor(norm_img_batch.to(next(self.predictor.parameters()).device)) - for rno in range(preds["ctc"].shape[0]): - preds_all[indices[beg_img_no + rno]] = preds["ctc"][rno] - preds_neck_all[indices[beg_img_no + rno]] = preds["ctc_neck"][rno] - - return torch.stack(preds_all, dim=0), torch.stack(preds_neck_all, dim=0) - - def get_char_dict(self, character_dict_path): - character_str = [] - with open(character_dict_path, "rb") as fin: - lines = fin.readlines() - for line in lines: - line = line.decode("utf-8").strip("\n").strip("\r\n") - character_str.append(line) - dict_character = list(character_str) - dict_character = ["sos"] + dict_character + [" "] # eos is space - return dict_character - - def get_text(self, order): - char_list = [self.chars[text_id] for text_id in order] - return "".join(char_list) - - def decode(self, mat): - text_index = mat.detach().cpu().numpy().argmax(axis=1) - ignored_tokens = [0] - selection = np.ones(len(text_index), dtype=bool) - selection[1:] = text_index[1:] != text_index[:-1] - for ignored_token in ignored_tokens: - selection &= text_index != ignored_token - return text_index[selection], np.where(selection)[0] - - def get_ctcloss(self, preds, gt_text, weight): - if not isinstance(weight, torch.Tensor): - weight = torch.tensor(weight).to(preds.device) - ctc_loss = torch.nn.CTCLoss(reduction="none") - log_probs = preds.log_softmax(dim=2).permute(1, 0, 2) # NTC-->TNC - targets = [] - target_lengths = [] - for t in gt_text: - targets += [self.char2id.get(i, len(self.chars) - 1) for i in t] - target_lengths += [len(t)] - targets = torch.tensor(targets).to(preds.device) - target_lengths = torch.tensor(target_lengths).to(preds.device) - input_lengths = torch.tensor([log_probs.shape[0]] * (log_probs.shape[1])).to(preds.device) - loss = ctc_loss(log_probs, targets, input_lengths, target_lengths) - loss = loss / input_lengths * weight - return loss - - -class AbstractEncoder(nn.Module): - def __init__(self): - super().__init__() - - def encode(self, *args, **kwargs): - raise NotImplementedError - - -class FrozenCLIPEmbedderT3(AbstractEncoder, ModelMixin, ConfigMixin): - """Uses the CLIP transformer encoder for text (from Hugging Face)""" - - @register_to_config - def __init__( - self, - device="cpu", - max_length=77, - freeze=True, - use_fp16=False, - variant: Optional[str] = None, - ): - super().__init__() - self.tokenizer = CLIPTokenizer.from_pretrained("tolgacangoz/anytext", subfolder="tokenizer") - self.transformer = CLIPTextModel.from_pretrained( - "tolgacangoz/anytext", - subfolder="text_encoder", - torch_dtype=torch.float16 if use_fp16 else torch.float32, - variant="fp16" if use_fp16 else None, - ) - - if freeze: - self.freeze() - - def embedding_forward( - self, - input_ids=None, - position_ids=None, - inputs_embeds=None, - embedding_manager=None, - ): - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] - if position_ids is None: - position_ids = self.position_ids[:, :seq_length] - if inputs_embeds is None: - inputs_embeds = self.token_embedding(input_ids) - if embedding_manager is not None: - inputs_embeds = embedding_manager(input_ids, inputs_embeds) - position_embeddings = self.position_embedding(position_ids) - embeddings = inputs_embeds + position_embeddings - return embeddings - - self.transformer.text_model.embeddings.forward = embedding_forward.__get__( - self.transformer.text_model.embeddings - ) - - def encoder_forward( - self, - inputs_embeds, - attention_mask=None, - causal_attention_mask=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - hidden_states = inputs_embeds - for idx, encoder_layer in enumerate(self.layers): - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) - hidden_states = layer_outputs[0] - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - return hidden_states - - self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder) - - def text_encoder_forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - embedding_manager=None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if input_ids is None: - raise ValueError("You have to specify either input_ids") - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - hidden_states = self.embeddings( - input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager - ) - # CLIP's text model uses causal mask, prepare it here. - # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 - causal_attention_mask = _create_4d_causal_attention_mask( - input_shape, hidden_states.dtype, device=hidden_states.device - ) - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) - last_hidden_state = self.encoder( - inputs_embeds=hidden_states, - attention_mask=attention_mask, - causal_attention_mask=causal_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - last_hidden_state = self.final_layer_norm(last_hidden_state) - return last_hidden_state - - self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model) - - def transformer_forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - embedding_manager=None, - ): - return self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - embedding_manager=embedding_manager, - ) - - self.transformer.forward = transformer_forward.__get__(self.transformer) - - def freeze(self): - self.transformer = self.transformer.eval() - for param in self.parameters(): - param.requires_grad = False - - def forward(self, text, **kwargs): - batch_encoding = self.tokenizer( - text, - truncation=False, - max_length=self.config.max_length, - return_length=True, - return_overflowing_tokens=False, - padding="longest", - return_tensors="pt", - ) - input_ids = batch_encoding["input_ids"] - tokens_list = self.split_chunks(input_ids) - z_list = [] - for tokens in tokens_list: - tokens = tokens.to(self.device) - _z = self.transformer(input_ids=tokens, **kwargs) - z_list += [_z] - return torch.cat(z_list, dim=1) - - def encode(self, text, **kwargs): - return self(text, **kwargs) - - def split_chunks(self, input_ids, chunk_size=75): - tokens_list = [] - bs, n = input_ids.shape - id_start = input_ids[:, 0].unsqueeze(1) # dim --> [bs, 1] - id_end = input_ids[:, -1].unsqueeze(1) - if n == 2: # empty caption - tokens_list.append(torch.cat((id_start,) + (id_end,) * (chunk_size + 1), dim=1)) - - trimmed_encoding = input_ids[:, 1:-1] - num_full_groups = (n - 2) // chunk_size - - for i in range(num_full_groups): - group = trimmed_encoding[:, i * chunk_size : (i + 1) * chunk_size] - group_pad = torch.cat((id_start, group, id_end), dim=1) - tokens_list.append(group_pad) - - remaining_columns = (n - 2) % chunk_size - if remaining_columns > 0: - remaining_group = trimmed_encoding[:, -remaining_columns:] - padding_columns = chunk_size - remaining_group.shape[1] - padding = id_end.expand(bs, padding_columns) - remaining_group_pad = torch.cat((id_start, remaining_group, padding, id_end), dim=1) - tokens_list.append(remaining_group_pad) - return tokens_list - - -class TextEmbeddingModule(ModelMixin, ConfigMixin): - @register_to_config - def __init__(self, font_path, use_fp16=False, device="cpu"): - super().__init__() - font = ImageFont.truetype(font_path, 60) - - self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=device, use_fp16=use_fp16) - self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=use_fp16) - self.text_predictor = create_predictor(device=device, use_fp16=use_fp16).eval() - args = { - "rec_image_shape": "3, 48, 320", - "rec_batch_num": 6, - "rec_char_dict_path": hf_hub_download( - repo_id="tolgacangoz/anytext", - filename="text_embedding_module/OCR/ppocr_keys_v1.txt", - cache_dir=HF_MODULES_CACHE, - ), - "use_fp16": use_fp16, - } - self.embedding_manager.recog = TextRecognizer(args, self.text_predictor) - - self.register_to_config(font=font) - - @torch.no_grad() - def forward( - self, - prompt, - texts, - negative_prompt, - num_images_per_prompt, - mode, - draw_pos, - sort_priority="โ†•", - max_chars=77, - revise_pos=False, - h=512, - w=512, - ): - if prompt is None and texts is None: - raise ValueError("Prompt or texts must be provided!") - # preprocess pos_imgs(if numpy, make sure it's white pos in black bg) - if draw_pos is None: - pos_imgs = np.zeros((w, h, 1)) - if isinstance(draw_pos, PIL.Image.Image): - pos_imgs = np.array(draw_pos)[..., ::-1] - pos_imgs = 255 - pos_imgs - elif isinstance(draw_pos, str): - draw_pos = cv2.imread(draw_pos)[..., ::-1] - if draw_pos is None: - raise ValueError(f"Can't read draw_pos image from {draw_pos}!") - pos_imgs = 255 - draw_pos - elif isinstance(draw_pos, torch.Tensor): - pos_imgs = draw_pos.cpu().numpy() - else: - if not isinstance(draw_pos, np.ndarray): - raise ValueError(f"Unknown format of draw_pos: {type(draw_pos)}") - if mode == "edit": - pos_imgs = cv2.resize(pos_imgs, (w, h)) - pos_imgs = pos_imgs[..., 0:1] - pos_imgs = cv2.convertScaleAbs(pos_imgs) - _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY) - # separate pos_imgs - pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority) - if len(pos_imgs) == 0: - pos_imgs = [np.zeros((h, w, 1))] - n_lines = len(texts) - if len(pos_imgs) < n_lines: - if n_lines == 1 and texts[0] == " ": - pass # text-to-image without text - else: - raise ValueError( - f"Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!" - ) - elif len(pos_imgs) > n_lines: - str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt." - logger.warning(str_warning) - # get pre_pos, poly_list, hint that needed for anytext - pre_pos = [] - poly_list = [] - for input_pos in pos_imgs: - if input_pos.mean() != 0: - input_pos = input_pos[..., np.newaxis] if len(input_pos.shape) == 2 else input_pos - poly, pos_img = self.find_polygon(input_pos) - pre_pos += [pos_img / 255.0] - poly_list += [poly] - else: - pre_pos += [np.zeros((h, w, 1))] - poly_list += [None] - np_hint = np.sum(pre_pos, axis=0).clip(0, 1) - # prepare info dict - text_info = {} - text_info["glyphs"] = [] - text_info["gly_line"] = [] - text_info["positions"] = [] - text_info["n_lines"] = [len(texts)] * num_images_per_prompt - for i in range(len(texts)): - text = texts[i] - if len(text) > max_chars: - str_warning = f'"{text}" length > max_chars: {max_chars}, will be cut off...' - logger.warning(str_warning) - text = text[:max_chars] - gly_scale = 2 - if pre_pos[i].mean() != 0: - gly_line = self.draw_glyph(self.config.font, text) - glyphs = self.draw_glyph2( - self.config.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False - ) - if revise_pos: - resize_gly = cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0])) - new_pos = cv2.morphologyEx( - (resize_gly * 255).astype(np.uint8), - cv2.MORPH_CLOSE, - kernel=np.ones((resize_gly.shape[0] // 10, resize_gly.shape[1] // 10), dtype=np.uint8), - iterations=1, - ) - new_pos = new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos - contours, _ = cv2.findContours(new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) - if len(contours) != 1: - str_warning = f"Fail to revise position {i} to bounding rect, remain position unchanged..." - logger.warning(str_warning) - else: - rect = cv2.minAreaRect(contours[0]) - poly = np.int0(cv2.boxPoints(rect)) - pre_pos[i] = cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0 - else: - glyphs = np.zeros((h * gly_scale, w * gly_scale, 1)) - gly_line = np.zeros((80, 512, 1)) - pos = pre_pos[i] - text_info["glyphs"] += [self.arr2tensor(glyphs, num_images_per_prompt)] - text_info["gly_line"] += [self.arr2tensor(gly_line, num_images_per_prompt)] - text_info["positions"] += [self.arr2tensor(pos, num_images_per_prompt)] - - self.embedding_manager.encode_text(text_info) - prompt_embeds = self.frozen_CLIP_embedder_t3.encode([prompt], embedding_manager=self.embedding_manager) - - self.embedding_manager.encode_text(text_info) - negative_prompt_embeds = self.frozen_CLIP_embedder_t3.encode( - [negative_prompt or ""], embedding_manager=self.embedding_manager - ) - - return prompt_embeds, negative_prompt_embeds, text_info, np_hint - - def arr2tensor(self, arr, bs): - arr = np.transpose(arr, (2, 0, 1)) - _arr = torch.from_numpy(arr.copy()).float().cpu() - if self.config.use_fp16: - _arr = _arr.half() - _arr = torch.stack([_arr for _ in range(bs)], dim=0) - return _arr - - def separate_pos_imgs(self, img, sort_priority, gap=102): - num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img) - components = [] - for label in range(1, num_labels): - component = np.zeros_like(img) - component[labels == label] = 255 - components.append((component, centroids[label])) - if sort_priority == "โ†•": - fir, sec = 1, 0 # top-down first - elif sort_priority == "โ†”": - fir, sec = 0, 1 # left-right first - else: - raise ValueError(f"Unknown sort_priority: {sort_priority}") - components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap)) - sorted_components = [c[0] for c in components] - return sorted_components - - def find_polygon(self, image, min_rect=False): - contours, hierarchy = cv2.findContours(image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) - max_contour = max(contours, key=cv2.contourArea) # get contour with max area - if min_rect: - # get minimum enclosing rectangle - rect = cv2.minAreaRect(max_contour) - poly = np.int0(cv2.boxPoints(rect)) - else: - # get approximate polygon - epsilon = 0.01 * cv2.arcLength(max_contour, True) - poly = cv2.approxPolyDP(max_contour, epsilon, True) - n, _, xy = poly.shape - poly = poly.reshape(n, xy) - cv2.drawContours(image, [poly], -1, 255, -1) - return poly, image - - def draw_glyph(self, font, text): - g_size = 50 - W, H = (512, 80) - new_font = font.font_variant(size=g_size) - img = Image.new(mode="1", size=(W, H), color=0) - draw = ImageDraw.Draw(img) - left, top, right, bottom = new_font.getbbox(text) - text_width = max(right - left, 5) - text_height = max(bottom - top, 5) - ratio = min(W * 0.9 / text_width, H * 0.9 / text_height) - new_font = font.font_variant(size=int(g_size * ratio)) - - left, top, right, bottom = new_font.getbbox(text) - text_width = right - left - text_height = bottom - top - x = (img.width - text_width) // 2 - y = (img.height - text_height) // 2 - top // 2 - draw.text((x, y), text, font=new_font, fill="white") - img = np.expand_dims(np.array(img), axis=2).astype(np.float64) - return img - - def draw_glyph2(self, font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True): - enlarge_polygon = polygon * scale - rect = cv2.minAreaRect(enlarge_polygon) - box = cv2.boxPoints(rect) - box = np.int0(box) - w, h = rect[1] - angle = rect[2] - if angle < -45: - angle += 90 - angle = -angle - if w < h: - angle += 90 - - vert = False - if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng: - _w = max(box[:, 0]) - min(box[:, 0]) - _h = max(box[:, 1]) - min(box[:, 1]) - if _h >= _w: - vert = True - angle = 0 - - img = np.zeros((height * scale, width * scale, 3), np.uint8) - img = Image.fromarray(img) - - # infer font size - image4ratio = Image.new("RGB", img.size, "white") - draw = ImageDraw.Draw(image4ratio) - _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font) - text_w = min(w, h) * (_tw / _th) - if text_w <= max(w, h): - # add space - if len(text) > 1 and not vert and add_space: - for i in range(1, 100): - text_space = self.insert_spaces(text, i) - _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font) - if min(w, h) * (_tw2 / _th2) > max(w, h): - break - text = self.insert_spaces(text, i - 1) - font_size = min(w, h) * 0.80 - else: - shrink = 0.75 if vert else 0.85 - font_size = min(w, h) / (text_w / max(w, h)) * shrink - new_font = font.font_variant(size=int(font_size)) - - left, top, right, bottom = new_font.getbbox(text) - text_width = right - left - text_height = bottom - top - - layer = Image.new("RGBA", img.size, (0, 0, 0, 0)) - draw = ImageDraw.Draw(layer) - if not vert: - draw.text( - (rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top), - text, - font=new_font, - fill=(255, 255, 255, 255), - ) - else: - x_s = min(box[:, 0]) + _w // 2 - text_height // 2 - y_s = min(box[:, 1]) - for c in text: - draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255)) - _, _t, _, _b = new_font.getbbox(c) - y_s += _b - - rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1])) - - x_offset = int((img.width - rotated_layer.width) / 2) - y_offset = int((img.height - rotated_layer.height) / 2) - img.paste(rotated_layer, (x_offset, y_offset), rotated_layer) - img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64) - return img - - def insert_spaces(self, string, nSpace): - if nSpace == 0: - return string - new_string = "" - for char in string: - new_string += char + " " * nSpace - return new_string[:-nSpace] - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" -): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents - else: - raise AttributeError("Could not access latents of provided encoder_output") - - -class AuxiliaryLatentModule(ModelMixin, ConfigMixin): - @register_to_config - def __init__( - self, - vae, - device="cpu", - ): - super().__init__() - - @torch.no_grad() - def forward( - self, - text_info, - mode, - draw_pos, - ori_image, - num_images_per_prompt, - np_hint, - h=512, - w=512, - ): - if mode == "generate": - edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image - elif mode == "edit": - if draw_pos is None or ori_image is None: - raise ValueError("Reference image and position image are needed for text editing!") - if isinstance(ori_image, str): - ori_image = cv2.imread(ori_image)[..., ::-1] - if ori_image is None: - raise ValueError(f"Can't read ori_image image from {ori_image}!") - elif isinstance(ori_image, torch.Tensor): - ori_image = ori_image.cpu().numpy() - else: - if not isinstance(ori_image, np.ndarray): - raise ValueError(f"Unknown format of ori_image: {type(ori_image)}") - edit_image = ori_image.clip(1, 255) # for mask reason - edit_image = self.check_channels(edit_image) - edit_image = self.resize_image( - edit_image, max_length=768 - ) # make w h multiple of 64, resize if w or h > max_length - - # get masked_x - masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint) - masked_img = np.transpose(masked_img, (2, 0, 1)) - device = next(self.config.vae.parameters()).device - dtype = next(self.config.vae.parameters()).dtype - masked_img = torch.from_numpy(masked_img.copy()).float().to(device) - if dtype == torch.float16: - masked_img = masked_img.half() - masked_x = ( - retrieve_latents(self.config.vae.encode(masked_img[None, ...])) * self.config.vae.config.scaling_factor - ).detach() - if dtype == torch.float16: - masked_x = masked_x.half() - text_info["masked_x"] = torch.cat([masked_x for _ in range(num_images_per_prompt)], dim=0) - - glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True) - positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True) - - return glyphs, positions, text_info - - def check_channels(self, image): - channels = image.shape[2] if len(image.shape) == 3 else 1 - if channels == 1: - image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) - elif channels > 3: - image = image[:, :, :3] - return image - - def resize_image(self, img, max_length=768): - height, width = img.shape[:2] - max_dimension = max(height, width) - - if max_dimension > max_length: - scale_factor = max_length / max_dimension - new_width = int(round(width * scale_factor)) - new_height = int(round(height * scale_factor)) - new_size = (new_width, new_height) - img = cv2.resize(img, new_size) - height, width = img.shape[:2] - img = cv2.resize(img, (width - (width % 64), height - (height % 64))) - return img - - def insert_spaces(self, string, nSpace): - if nSpace == 0: - return string - new_string = "" - for char in string: - new_string += char + " " * nSpace - return new_string[:-nSpace] - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, -): - """ - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` - must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, - `num_inference_steps` and `sigmas` must be `None`. - sigmas (`List[float]`, *optional*): - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`. - - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -class AnyTextPipeline( - DiffusionPipeline, - StableDiffusionMixin, - TextualInversionLoaderMixin, - StableDiffusionLoraLoaderMixin, - IPAdapterMixin, - FromSingleFileMixin, -): - r""" - Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods - implemented for all pipelines (downloading, saving, running on a particular device, etc.). - - The pipeline also inherits the following loading methods: - - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. - text_encoder ([`~transformers.CLIPTextModel`]): - Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). - tokenizer ([`~transformers.CLIPTokenizer`]): - A `CLIPTokenizer` to tokenize text. - unet ([`UNet2DConditionModel`]): - A `UNet2DConditionModel` to denoise the encoded image latents. - controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): - Provides additional conditioning to the `unet` during the denoising process. If you set multiple - ControlNets as a list, the outputs from each ControlNet are added together to create one combined - additional conditioning. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details - about a model's potential harms. - feature_extractor ([`~transformers.CLIPImageProcessor`]): - A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. - """ - - model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" - _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] - _exclude_from_cpu_offload = ["safety_checker"] - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], - scheduler: KarrasDiffusionSchedulers, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPImageProcessor, - font_path: str = None, - text_embedding_module: Optional[TextEmbeddingModule] = None, - auxiliary_latent_module: Optional[AuxiliaryLatentModule] = None, - trust_remote_code: bool = False, - image_encoder: CLIPVisionModelWithProjection = None, - requires_safety_checker: bool = True, - ): - super().__init__() - if font_path is None: - raise ValueError("font_path is required!") - - text_embedding_module = TextEmbeddingModule(font_path=font_path, use_fp16=unet.dtype == torch.float16) - auxiliary_latent_module = AuxiliaryLatentModule(vae=vae) - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - - if isinstance(controlnet, (list, tuple)): - controlnet = MultiControlNetModel(controlnet) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - controlnet=controlnet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - image_encoder=image_encoder, - text_embedding_module=text_embedding_module, - auxiliary_latent_module=auxiliary_latent_module, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) - self.control_image_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False - ) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - def modify_prompt(self, prompt): - prompt = prompt.replace("โ€œ", '"') - prompt = prompt.replace("โ€", '"') - p = '"(.*?)"' - strs = re.findall(p, prompt) - if len(strs) == 0: - strs = [" "] - else: - for s in strs: - prompt = prompt.replace(f'"{s}"', f" {PLACE_HOLDER} ", 1) - if self.is_chinese(prompt): - if self.trans_pipe is None: - return None, None - old_prompt = prompt - prompt = self.trans_pipe(input=prompt + " .")["translation"][:-1] - print(f"Translate: {old_prompt} --> {prompt}") - return prompt, strs - - def is_chinese(self, text): - text = checker._clean_text(text) - for char in text: - cp = ord(char) - if checker._is_chinese_char(cp): - return True - return False - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - lora_scale: Optional[float] = None, - **kwargs, - ): - deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." - deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) - - prompt_embeds_tuple = self.encode_prompt( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, - negative_prompt=negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=lora_scale, - **kwargs, - ) - - # concatenate for backwards comp - prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) - - return prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt - def encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - lora_scale (`float`, *optional*): - A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - """ - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - # textual inversion: process multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - if clip_skip is None: - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) - prompt_embeds = prompt_embeds[0] - else: - prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True - ) - # Access the `hidden_states` first, that contains a tuple of - # all the hidden states from the encoder layers. Then index into - # the tuple to access the hidden states from the desired layer. - prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] - # We also need to apply the final LayerNorm here to not mess with the - # representations. The `last_hidden_states` that we typically use for - # obtaining the final prompt representations passes through the LayerNorm - # layer. - prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) - - if self.text_encoder is not None: - prompt_embeds_dtype = self.text_encoder.dtype - elif self.unet is not None: - prompt_embeds_dtype = self.unet.dtype - else: - prompt_embeds_dtype = prompt_embeds.dtype - - prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - # textual inversion: process multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - if self.text_encoder is not None: - if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - return prompt_embeds, negative_prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds - def prepare_ip_adapter_image_embeds( - self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance - ): - image_embeds = [] - if do_classifier_free_guidance: - negative_image_embeds = [] - if ip_adapter_image_embeds is None: - if not isinstance(ip_adapter_image, list): - ip_adapter_image = [ip_adapter_image] - - if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): - raise ValueError( - f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." - ) - - for single_ip_adapter_image, image_proj_layer in zip( - ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers - ): - output_hidden_state = not isinstance(image_proj_layer, ImageProjection) - single_image_embeds, single_negative_image_embeds = self.encode_image( - single_ip_adapter_image, device, 1, output_hidden_state - ) - - image_embeds.append(single_image_embeds[None, :]) - if do_classifier_free_guidance: - negative_image_embeds.append(single_negative_image_embeds[None, :]) - else: - for single_image_embeds in ip_adapter_image_embeds: - if do_classifier_free_guidance: - single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - negative_image_embeds.append(single_negative_image_embeds) - image_embeds.append(single_image_embeds) - - ip_adapter_image_embeds = [] - for i, single_image_embeds in enumerate(image_embeds): - single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) - if do_classifier_free_guidance: - single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - - single_image_embeds = single_image_embeds.to(device=device) - ip_adapter_image_embeds.append(single_image_embeds) - - return ip_adapter_image_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is None: - has_nsfw_concept = None - else: - if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") - else: - feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - return image, has_nsfw_concept - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" - deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) - - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents, return_dict=False)[0] - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (ฮท) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to ฮท in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, - prompt, - # image, - callback_steps, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - ip_adapter_image=None, - ip_adapter_image_embeds=None, - controlnet_conditioning_scale=1.0, - control_guidance_start=0.0, - control_guidance_end=1.0, - callback_on_step_end_tensor_inputs=None, - ): - if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - # Check `image` - is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( - self.controlnet, torch._dynamo.eval_frame.OptimizedModule - ) - - # Check `controlnet_conditioning_scale` - if ( - isinstance(self.controlnet, ControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetModel) - ): - if not isinstance(controlnet_conditioning_scale, float): - print(controlnet_conditioning_scale) - raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") - elif ( - isinstance(self.controlnet, MultiControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, MultiControlNetModel) - ): - if isinstance(controlnet_conditioning_scale, list): - if any(isinstance(i, list) for i in controlnet_conditioning_scale): - raise ValueError( - "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. " - "The conditioning scale must be fixed across the batch." - ) - elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( - self.controlnet.nets - ): - raise ValueError( - "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" - " the same length as the number of controlnets" - ) - else: - assert False - - if not isinstance(control_guidance_start, (tuple, list)): - control_guidance_start = [control_guidance_start] - - if not isinstance(control_guidance_end, (tuple, list)): - control_guidance_end = [control_guidance_end] - - if len(control_guidance_start) != len(control_guidance_end): - raise ValueError( - f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." - ) - - if isinstance(self.controlnet, MultiControlNetModel): - if len(control_guidance_start) != len(self.controlnet.nets): - raise ValueError( - f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." - ) - - for start, end in zip(control_guidance_start, control_guidance_end): - if start >= end: - raise ValueError( - f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." - ) - if start < 0.0: - raise ValueError(f"control guidance start: {start} can't be smaller than 0.") - if end > 1.0: - raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") - - if ip_adapter_image is not None and ip_adapter_image_embeds is not None: - raise ValueError( - "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." - ) - - if ip_adapter_image_embeds is not None: - if not isinstance(ip_adapter_image_embeds, list): - raise ValueError( - f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" - ) - elif ip_adapter_image_embeds[0].ndim not in [3, 4]: - raise ValueError( - f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" - ) - - def check_image(self, image, prompt, prompt_embeds): - image_is_pil = isinstance(image, PIL.Image.Image) - image_is_tensor = isinstance(image, torch.Tensor) - image_is_np = isinstance(image, np.ndarray) - image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) - image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) - image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) - - if ( - not image_is_pil - and not image_is_tensor - and not image_is_np - and not image_is_pil_list - and not image_is_tensor_list - and not image_is_np_list - ): - raise TypeError( - f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" - ) - - if image_is_pil: - image_batch_size = 1 - else: - image_batch_size = len(image) - - if prompt is not None and isinstance(prompt, str): - prompt_batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - prompt_batch_size = len(prompt) - elif prompt_embeds is not None: - prompt_batch_size = prompt_embeds.shape[0] - - if image_batch_size != 1 and image_batch_size != prompt_batch_size: - raise ValueError( - f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" - ) - - def prepare_image( - self, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - do_classifier_free_guidance=False, - guess_mode=False, - ): - image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - if do_classifier_free_guidance and not guess_mode: - image = torch.cat([image] * 2) - - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = ( - batch_size, - num_channels_latents, - int(height) // self.vae_scale_factor, - int(width) // self.vae_scale_factor, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - def get_guidance_scale_embedding( - self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 - ) -> torch.Tensor: - """ - See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 - - Args: - w (`torch.Tensor`): - Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. - embedding_dim (`int`, *optional*, defaults to 512): - Dimension of the embeddings to generate. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - Data type of the generated embeddings. - - Returns: - `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. - """ - assert len(w.shape) == 1 - w = w * 1000.0 - - half_dim = embedding_dim // 2 - emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) - emb = w.to(dtype)[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1)) - assert emb.shape == (w.shape[0], embedding_dim) - return emb - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def clip_skip(self): - return self._clip_skip - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None - - @property - def cross_attention_kwargs(self): - return self._cross_attention_kwargs - - @property - def num_timesteps(self): - return self._num_timesteps - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - mode: Optional[str] = "generate", - draw_pos: Optional[Union[str, torch.Tensor]] = None, - ori_image: Optional[Union[str, torch.Tensor]] = None, - timesteps: List[int] = None, - sigmas: List[float] = None, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - ip_adapter_image: Optional[PipelineImageInput] = None, - ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - controlnet_conditioning_scale: Union[float, List[float]] = 1.0, - guess_mode: bool = False, - control_guidance_start: Union[float, List[float]] = 0.0, - control_guidance_end: Union[float, List[float]] = 1.0, - clip_skip: Optional[int] = None, - callback_on_step_end: Optional[ - Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] - ] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - **kwargs, - ): - r""" - The call function to the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. - image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: - `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): - The ControlNet input condition to provide guidance to the `unet` for generation. If the type is - specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted - as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or - width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, - images must be passed as a list such that each element of the list can be correctly batched for input - to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single - ControlNet, each will be paired with each prompt in the `prompt` list. This also applies to multiple - ControlNets, where a list of image lists can be passed to batch for each prompt and each ControlNet. - height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument - in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is - passed will be used. Must be in descending order. - sigmas (`List[float]`, *optional*): - Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in - their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed - will be used. - guidance_scale (`float`, *optional*, defaults to 7.5): - A higher guidance scale value encourages the model to generate images closely linked to the text - `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide what to not include in image generation. If not defined, you need to - pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (ฮท) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies - to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - generation deterministic. - latents (`torch.Tensor`, *optional*): - Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor is generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not - provided, text embeddings are generated from the `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If - not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. - ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. - ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): - Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of - IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should - contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not - provided, embeddings are computed from the `ip_adapter_image` input argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated image. Choose between `PIL.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in - [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): - The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added - to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set - the corresponding scale as a list. - guess_mode (`bool`, *optional*, defaults to `False`): - The ControlNet encoder tries to recognize the content of the input image even if you remove all - prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. - control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): - The percentage of total steps at which the ControlNet starts applying. - control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): - The percentage of total steps at which the ControlNet stops applying. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): - A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of - each denoising step during the inference. with the following arguments: `callback_on_step_end(self: - DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a - list of all tensors as specified by `callback_on_step_end_tensor_inputs`. - callback_on_step_end_tensor_inputs (`List`, *optional*): - The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list - will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeline class. - - Examples: - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, - otherwise a `tuple` is returned where the first element is a list with the generated images and the - second element is a list of `bool`s indicating whether the corresponding generated image contains - "not-safe-for-work" (nsfw) content. - """ - - callback = kwargs.pop("callback", None) - callback_steps = kwargs.pop("callback_steps", None) - - if callback is not None: - deprecate( - "callback", - "1.0.0", - "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", - ) - if callback_steps is not None: - deprecate( - "callback_steps", - "1.0.0", - "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", - ) - - if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): - callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - - controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet - - # align format for control guidance - if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): - control_guidance_start = len(control_guidance_end) * [control_guidance_start] - elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): - control_guidance_end = len(control_guidance_start) * [control_guidance_end] - elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): - mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 - control_guidance_start, control_guidance_end = ( - mult * [control_guidance_start], - mult * [control_guidance_end], - ) - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - # image, - callback_steps, - negative_prompt, - prompt_embeds, - negative_prompt_embeds, - ip_adapter_image, - ip_adapter_image_embeds, - controlnet_conditioning_scale, - control_guidance_start, - control_guidance_end, - callback_on_step_end_tensor_inputs, - ) - - self._guidance_scale = guidance_scale - self._clip_skip = clip_skip - self._cross_attention_kwargs = cross_attention_kwargs - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - - if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): - controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) - - global_pool_conditions = ( - controlnet.config.global_pool_conditions - if isinstance(controlnet, ControlNetModel) - else controlnet.nets[0].config.global_pool_conditions - ) - guess_mode = guess_mode or global_pool_conditions - - prompt, texts = self.modify_prompt(prompt) - - # 3. Encode input prompt - text_encoder_lora_scale = ( - self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None - ) - draw_pos = draw_pos.to(device=device) if isinstance(draw_pos, torch.Tensor) else draw_pos - prompt_embeds, negative_prompt_embeds, text_info, np_hint = self.text_embedding_module( - prompt, - texts, - negative_prompt, - num_images_per_prompt, - mode, - draw_pos, - ) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - if ip_adapter_image is not None or ip_adapter_image_embeds is not None: - image_embeds = self.prepare_ip_adapter_image_embeds( - ip_adapter_image, - ip_adapter_image_embeds, - device, - batch_size * num_images_per_prompt, - self.do_classifier_free_guidance, - ) - - # 3.5 Optionally get Guidance Scale Embedding - timestep_cond = None - if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) - timestep_cond = self.get_guidance_scale_embedding( - guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim - ).to(device=device, dtype=latents.dtype) - - # 4. Prepare image - if isinstance(controlnet, ControlNetModel): - guided_hint = self.auxiliary_latent_module( - text_info=text_info, - mode=mode, - draw_pos=draw_pos, - ori_image=ori_image, - num_images_per_prompt=num_images_per_prompt, - np_hint=np_hint, - ) - height, width = 512, 512 - else: - assert False - - # 5. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas - ) - self._num_timesteps = len(timesteps) - - # 6. Prepare latent variables - num_channels_latents = self.unet.config.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 7.1 Add image embeds for IP-Adapter - added_cond_kwargs = ( - {"image_embeds": image_embeds} - if ip_adapter_image is not None or ip_adapter_image_embeds is not None - else None - ) - - # 7.2 Create tensor stating which controlnets to keep - controlnet_keep = [] - for i in range(len(timesteps)): - keeps = [ - 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) - for s, e in zip(control_guidance_start, control_guidance_end) - ] - controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) - - # 8. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - is_unet_compiled = is_compiled_module(self.unet) - is_controlnet_compiled = is_compiled_module(self.controlnet) - is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # Relevant thread: - # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 - if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: - torch._inductor.cudagraph_mark_step_begin() - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # controlnet(s) inference - if guess_mode and self.do_classifier_free_guidance: - # Infer ControlNet only for the conditional batch. - control_model_input = latents - control_model_input = self.scheduler.scale_model_input(control_model_input, t) - controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] - else: - control_model_input = latent_model_input - controlnet_prompt_embeds = prompt_embeds - - if isinstance(controlnet_keep[i], list): - cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] - else: - controlnet_cond_scale = controlnet_conditioning_scale - if isinstance(controlnet_cond_scale, list): - controlnet_cond_scale = controlnet_cond_scale[0] - cond_scale = controlnet_cond_scale * controlnet_keep[i] - - down_block_res_samples, mid_block_res_sample = self.controlnet( - control_model_input.to(self.controlnet.dtype), - t, - encoder_hidden_states=controlnet_prompt_embeds, - controlnet_cond=guided_hint, - conditioning_scale=cond_scale, - guess_mode=guess_mode, - return_dict=False, - ) - - if guess_mode and self.do_classifier_free_guidance: - # Inferred ControlNet only for the conditional batch. - # To apply the output of ControlNet to both the unconditional and conditional batches, - # add 0 to the unconditional batch to keep it unchanged. - down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] - mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - timestep_cond=timestep_cond, - cross_attention_kwargs=self.cross_attention_kwargs, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - added_cond_kwargs=added_cond_kwargs, - return_dict=False, - )[0] - - # perform guidance - if self.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - step_idx = i // getattr(self.scheduler, "order", 1) - callback(step_idx, t, latents) - - # If we do sequential model offloading, let's offload unet and controlnet - # manually for max memory savings - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.unet.to("cpu") - self.controlnet.to("cpu") - torch.cuda.empty_cache() - - if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ - 0 - ] - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - else: - image = latents - has_nsfw_concept = None - - if has_nsfw_concept is None: - do_denormalize = [True] * image.shape[0] - else: - do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) - - # Offload all models - self.maybe_free_model_hooks() - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) - - def to(self, *args, **kwargs): - super().to(*args, **kwargs) - self.text_embedding_module.to(*args, **kwargs) - self.auxiliary_latent_module.to(*args, **kwargs) - return self diff --git a/examples/research_projects/anytext/anytext_controlnet.py b/examples/research_projects/anytext/anytext_controlnet.py deleted file mode 100644 index 5965ceed1370..000000000000 --- a/examples/research_projects/anytext/anytext_controlnet.py +++ /dev/null @@ -1,463 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Based on [AnyText: Multilingual Visual Text Generation And Editing](https://huggingface.co/papers/2311.03054). -# Authors: Yuxiang Tuo, Wangmeng Xiang, Jun-Yan He, Yifeng Geng, Xuansong Xie -# Code: https://github.com/tyxsspa/AnyText with Apache-2.0 license -# -# Adapted to Diffusers by [M. Tolga Cangรถz](https://github.com/tolgacangoz). - - -from typing import Any, Dict, Optional, Tuple, Union - -import torch -from torch import nn - -from diffusers.configuration_utils import register_to_config -from diffusers.models.controlnets.controlnet import ( - ControlNetModel, - ControlNetOutput, -) -from diffusers.utils import logging - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class AnyTextControlNetConditioningEmbedding(nn.Module): - """ - Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN - [11] to convert the entire dataset of 512 ร— 512 images into smaller 64 ร— 64 โ€œlatent imagesโ€ for stabilized - training. This requires ControlNets to convert image-based conditions to 64 ร— 64 feature space to match the - convolution size. We use a tiny network E(ยท) of four convolution layers with 4 ร— 4 kernels and 2 ร— 2 strides - (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full - model) to encode image-space conditions ... into feature maps ..." - """ - - def __init__( - self, - conditioning_embedding_channels: int, - glyph_channels=1, - position_channels=1, - ): - super().__init__() - - self.glyph_block = nn.Sequential( - nn.Conv2d(glyph_channels, 8, 3, padding=1), - nn.SiLU(), - nn.Conv2d(8, 8, 3, padding=1), - nn.SiLU(), - nn.Conv2d(8, 16, 3, padding=1, stride=2), - nn.SiLU(), - nn.Conv2d(16, 16, 3, padding=1), - nn.SiLU(), - nn.Conv2d(16, 32, 3, padding=1, stride=2), - nn.SiLU(), - nn.Conv2d(32, 32, 3, padding=1), - nn.SiLU(), - nn.Conv2d(32, 96, 3, padding=1, stride=2), - nn.SiLU(), - nn.Conv2d(96, 96, 3, padding=1), - nn.SiLU(), - nn.Conv2d(96, 256, 3, padding=1, stride=2), - nn.SiLU(), - ) - - self.position_block = nn.Sequential( - nn.Conv2d(position_channels, 8, 3, padding=1), - nn.SiLU(), - nn.Conv2d(8, 8, 3, padding=1), - nn.SiLU(), - nn.Conv2d(8, 16, 3, padding=1, stride=2), - nn.SiLU(), - nn.Conv2d(16, 16, 3, padding=1), - nn.SiLU(), - nn.Conv2d(16, 32, 3, padding=1, stride=2), - nn.SiLU(), - nn.Conv2d(32, 32, 3, padding=1), - nn.SiLU(), - nn.Conv2d(32, 64, 3, padding=1, stride=2), - nn.SiLU(), - ) - - self.fuse_block = nn.Conv2d(256 + 64 + 4, conditioning_embedding_channels, 3, padding=1) - - def forward(self, glyphs, positions, text_info): - glyph_embedding = self.glyph_block(glyphs.to(self.glyph_block[0].weight.device)) - position_embedding = self.position_block(positions.to(self.position_block[0].weight.device)) - guided_hint = self.fuse_block(torch.cat([glyph_embedding, position_embedding, text_info["masked_x"]], dim=1)) - - return guided_hint - - -class AnyTextControlNetModel(ControlNetModel): - """ - A AnyTextControlNetModel model. - - Args: - in_channels (`int`, defaults to 4): - The number of channels in the input sample. - flip_sin_to_cos (`bool`, defaults to `True`): - Whether to flip the sin to cos in the time embedding. - freq_shift (`int`, defaults to 0): - The frequency shift to apply to the time embedding. - down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): - The tuple of downsample blocks to use. - only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): - block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): - The tuple of output channels for each block. - layers_per_block (`int`, defaults to 2): - The number of layers per block. - downsample_padding (`int`, defaults to 1): - The padding to use for the downsampling convolution. - mid_block_scale_factor (`float`, defaults to 1): - The scale factor to use for the mid block. - act_fn (`str`, defaults to "silu"): - The activation function to use. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups to use for the normalization. If None, normalization and activation layers is skipped - in post-processing. - norm_eps (`float`, defaults to 1e-5): - The epsilon to use for the normalization. - cross_attention_dim (`int`, defaults to 1280): - The dimension of the cross attention features. - transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): - The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for - [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], - [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. - encoder_hid_dim (`int`, *optional*, defaults to None): - If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` - dimension to `cross_attention_dim`. - encoder_hid_dim_type (`str`, *optional*, defaults to `None`): - If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text - embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. - attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): - The dimension of the attention heads. - use_linear_projection (`bool`, defaults to `False`): - class_embed_type (`str`, *optional*, defaults to `None`): - The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, - `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. - addition_embed_type (`str`, *optional*, defaults to `None`): - Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or - "text". "text" will use the `TextTimeEmbedding` layer. - num_class_embeds (`int`, *optional*, defaults to 0): - Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing - class conditioning with `class_embed_type` equal to `None`. - upcast_attention (`bool`, defaults to `False`): - resnet_time_scale_shift (`str`, defaults to `"default"`): - Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. - projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`): - The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when - `class_embed_type="projection"`. - controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): - The channel order of conditional image. Will convert to `rgb` if it's `bgr`. - conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): - The tuple of output channel for each block in the `conditioning_embedding` layer. - global_pool_conditions (`bool`, defaults to `False`): - TODO(Patrick) - unused parameter. - addition_embed_type_num_heads (`int`, defaults to 64): - The number of heads to use for the `TextTimeEmbedding` layer. - """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - in_channels: int = 4, - conditioning_channels: int = 1, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - down_block_types: Tuple[str, ...] = ( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", - ), - mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", - only_cross_attention: Union[bool, Tuple[bool]] = False, - block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), - layers_per_block: int = 2, - downsample_padding: int = 1, - mid_block_scale_factor: float = 1, - act_fn: str = "silu", - norm_num_groups: Optional[int] = 32, - norm_eps: float = 1e-5, - cross_attention_dim: int = 1280, - transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, - encoder_hid_dim: Optional[int] = None, - encoder_hid_dim_type: Optional[str] = None, - attention_head_dim: Union[int, Tuple[int, ...]] = 8, - num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, - use_linear_projection: bool = False, - class_embed_type: Optional[str] = None, - addition_embed_type: Optional[str] = None, - addition_time_embed_dim: Optional[int] = None, - num_class_embeds: Optional[int] = None, - upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - projection_class_embeddings_input_dim: Optional[int] = None, - controlnet_conditioning_channel_order: str = "rgb", - conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), - global_pool_conditions: bool = False, - addition_embed_type_num_heads: int = 64, - ): - super().__init__( - in_channels, - conditioning_channels, - flip_sin_to_cos, - freq_shift, - down_block_types, - mid_block_type, - only_cross_attention, - block_out_channels, - layers_per_block, - downsample_padding, - mid_block_scale_factor, - act_fn, - norm_num_groups, - norm_eps, - cross_attention_dim, - transformer_layers_per_block, - encoder_hid_dim, - encoder_hid_dim_type, - attention_head_dim, - num_attention_heads, - use_linear_projection, - class_embed_type, - addition_embed_type, - addition_time_embed_dim, - num_class_embeds, - upcast_attention, - resnet_time_scale_shift, - projection_class_embeddings_input_dim, - controlnet_conditioning_channel_order, - conditioning_embedding_out_channels, - global_pool_conditions, - addition_embed_type_num_heads, - ) - - # control net conditioning embedding - self.controlnet_cond_embedding = AnyTextControlNetConditioningEmbedding( - conditioning_embedding_channels=block_out_channels[0], - glyph_channels=conditioning_channels, - position_channels=conditioning_channels, - ) - - def forward( - self, - sample: torch.Tensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - controlnet_cond: torch.Tensor, - conditioning_scale: float = 1.0, - class_labels: Optional[torch.Tensor] = None, - timestep_cond: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - guess_mode: bool = False, - return_dict: bool = True, - ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]: - """ - The [`~PromptDiffusionControlNetModel`] forward method. - - Args: - sample (`torch.Tensor`): - The noisy input tensor. - timestep (`Union[torch.Tensor, float, int]`): - The number of timesteps to denoise an input. - encoder_hidden_states (`torch.Tensor`): - The encoder hidden states. - #controlnet_cond (`torch.Tensor`): - # The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. - conditioning_scale (`float`, defaults to `1.0`): - The scale factor for ControlNet outputs. - class_labels (`torch.Tensor`, *optional*, defaults to `None`): - Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. - timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): - Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the - timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep - embeddings. - attention_mask (`torch.Tensor`, *optional*, defaults to `None`): - An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask - is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large - negative values to the attention scores corresponding to "discard" tokens. - added_cond_kwargs (`dict`): - Additional conditions for the Stable Diffusion XL UNet. - cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): - A kwargs dictionary that if specified is passed along to the `AttnProcessor`. - guess_mode (`bool`, defaults to `False`): - In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if - you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. - return_dict (`bool`, defaults to `True`): - Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. - - Returns: - [`~models.controlnet.ControlNetOutput`] **or** `tuple`: - If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is - returned where the first element is the sample tensor. - """ - # check channel order - channel_order = self.config.controlnet_conditioning_channel_order - - if channel_order == "rgb": - # in rgb order by default - ... - # elif channel_order == "bgr": - # controlnet_cond = torch.flip(controlnet_cond, dims=[1]) - else: - raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") - - # prepare attention_mask - if attention_mask is not None: - attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - t_emb = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=sample.dtype) - - emb = self.time_embedding(t_emb, timestep_cond) - aug_emb = None - - if self.class_embedding is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - - if self.config.class_embed_type == "timestep": - class_labels = self.time_proj(class_labels) - - class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) - emb = emb + class_emb - - if self.config.addition_embed_type is not None: - if self.config.addition_embed_type == "text": - aug_emb = self.add_embedding(encoder_hidden_states) - - elif self.config.addition_embed_type == "text_time": - if "text_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" - ) - text_embeds = added_cond_kwargs.get("text_embeds") - if "time_ids" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" - ) - time_ids = added_cond_kwargs.get("time_ids") - time_embeds = self.add_time_proj(time_ids.flatten()) - time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) - - add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) - add_embeds = add_embeds.to(emb.dtype) - aug_emb = self.add_embedding(add_embeds) - - emb = emb + aug_emb if aug_emb is not None else emb - - # 2. pre-process - sample = self.conv_in(sample) - - controlnet_cond = self.controlnet_cond_embedding(*controlnet_cond) - sample = sample + controlnet_cond - - # 3. down - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - - down_block_res_samples += res_samples - - # 4. mid - if self.mid_block is not None: - if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - else: - sample = self.mid_block(sample, emb) - - # 5. Control net blocks - controlnet_down_block_res_samples = () - - for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): - down_block_res_sample = controlnet_block(down_block_res_sample) - controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) - - down_block_res_samples = controlnet_down_block_res_samples - - mid_block_res_sample = self.controlnet_mid_block(sample) - - # 6. scaling - if guess_mode and not self.config.global_pool_conditions: - scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 - scales = scales * conditioning_scale - down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] - mid_block_res_sample = mid_block_res_sample * scales[-1] # last one - else: - down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] - mid_block_res_sample = mid_block_res_sample * conditioning_scale - - if self.config.global_pool_conditions: - down_block_res_samples = [ - torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples - ] - mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) - - if not return_dict: - return (down_block_res_samples, mid_block_res_sample) - - return ControlNetOutput( - down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample - ) - - -# Copied from diffusers.models.controlnet.zero_module -def zero_module(module): - for p in module.parameters(): - nn.init.zeros_(p) - return module diff --git a/examples/research_projects/anytext/ocr_recog/RNN.py b/examples/research_projects/anytext/ocr_recog/RNN.py deleted file mode 100755 index aec796d987c0..000000000000 --- a/examples/research_projects/anytext/ocr_recog/RNN.py +++ /dev/null @@ -1,209 +0,0 @@ -import torch -from torch import nn - -from .RecSVTR import Block - - -class Swish(nn.Module): - def __int__(self): - super(Swish, self).__int__() - - def forward(self, x): - return x * torch.sigmoid(x) - - -class Im2Im(nn.Module): - def __init__(self, in_channels, **kwargs): - super().__init__() - self.out_channels = in_channels - - def forward(self, x): - return x - - -class Im2Seq(nn.Module): - def __init__(self, in_channels, **kwargs): - super().__init__() - self.out_channels = in_channels - - def forward(self, x): - B, C, H, W = x.shape - # assert H == 1 - x = x.reshape(B, C, H * W) - x = x.permute((0, 2, 1)) - return x - - -class EncoderWithRNN(nn.Module): - def __init__(self, in_channels, **kwargs): - super(EncoderWithRNN, self).__init__() - hidden_size = kwargs.get("hidden_size", 256) - self.out_channels = hidden_size * 2 - self.lstm = nn.LSTM(in_channels, hidden_size, bidirectional=True, num_layers=2, batch_first=True) - - def forward(self, x): - self.lstm.flatten_parameters() - x, _ = self.lstm(x) - return x - - -class SequenceEncoder(nn.Module): - def __init__(self, in_channels, encoder_type="rnn", **kwargs): - super(SequenceEncoder, self).__init__() - self.encoder_reshape = Im2Seq(in_channels) - self.out_channels = self.encoder_reshape.out_channels - self.encoder_type = encoder_type - if encoder_type == "reshape": - self.only_reshape = True - else: - support_encoder_dict = {"reshape": Im2Seq, "rnn": EncoderWithRNN, "svtr": EncoderWithSVTR} - assert encoder_type in support_encoder_dict, "{} must in {}".format( - encoder_type, support_encoder_dict.keys() - ) - - self.encoder = support_encoder_dict[encoder_type](self.encoder_reshape.out_channels, **kwargs) - self.out_channels = self.encoder.out_channels - self.only_reshape = False - - def forward(self, x): - if self.encoder_type != "svtr": - x = self.encoder_reshape(x) - if not self.only_reshape: - x = self.encoder(x) - return x - else: - x = self.encoder(x) - x = self.encoder_reshape(x) - return x - - -class ConvBNLayer(nn.Module): - def __init__( - self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias_attr=False, groups=1, act=nn.GELU - ): - super().__init__() - self.conv = nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - groups=groups, - # weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()), - bias=bias_attr, - ) - self.norm = nn.BatchNorm2d(out_channels) - self.act = Swish() - - def forward(self, inputs): - out = self.conv(inputs) - out = self.norm(out) - out = self.act(out) - return out - - -class EncoderWithSVTR(nn.Module): - def __init__( - self, - in_channels, - dims=64, # XS - depth=2, - hidden_dims=120, - use_guide=False, - num_heads=8, - qkv_bias=True, - mlp_ratio=2.0, - drop_rate=0.1, - attn_drop_rate=0.1, - drop_path=0.0, - qk_scale=None, - ): - super(EncoderWithSVTR, self).__init__() - self.depth = depth - self.use_guide = use_guide - self.conv1 = ConvBNLayer(in_channels, in_channels // 8, padding=1, act="swish") - self.conv2 = ConvBNLayer(in_channels // 8, hidden_dims, kernel_size=1, act="swish") - - self.svtr_block = nn.ModuleList( - [ - Block( - dim=hidden_dims, - num_heads=num_heads, - mixer="Global", - HW=None, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - act_layer="swish", - attn_drop=attn_drop_rate, - drop_path=drop_path, - norm_layer="nn.LayerNorm", - epsilon=1e-05, - prenorm=False, - ) - for i in range(depth) - ] - ) - self.norm = nn.LayerNorm(hidden_dims, eps=1e-6) - self.conv3 = ConvBNLayer(hidden_dims, in_channels, kernel_size=1, act="swish") - # last conv-nxn, the input is concat of input tensor and conv3 output tensor - self.conv4 = ConvBNLayer(2 * in_channels, in_channels // 8, padding=1, act="swish") - - self.conv1x1 = ConvBNLayer(in_channels // 8, dims, kernel_size=1, act="swish") - self.out_channels = dims - self.apply(self._init_weights) - - def _init_weights(self, m): - # weight initialization - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode="fan_out") - if m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, nn.BatchNorm2d): - nn.init.ones_(m.weight) - nn.init.zeros_(m.bias) - elif isinstance(m, nn.Linear): - nn.init.normal_(m.weight, 0, 0.01) - if m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, nn.ConvTranspose2d): - nn.init.kaiming_normal_(m.weight, mode="fan_out") - if m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, nn.LayerNorm): - nn.init.ones_(m.weight) - nn.init.zeros_(m.bias) - - def forward(self, x): - # for use guide - if self.use_guide: - z = x.clone() - z.stop_gradient = True - else: - z = x - # for short cut - h = z - # reduce dim - z = self.conv1(z) - z = self.conv2(z) - # SVTR global block - B, C, H, W = z.shape - z = z.flatten(2).permute(0, 2, 1) - - for blk in self.svtr_block: - z = blk(z) - - z = self.norm(z) - # last stage - z = z.reshape([-1, H, W, C]).permute(0, 3, 1, 2) - z = self.conv3(z) - z = torch.cat((h, z), dim=1) - z = self.conv1x1(self.conv4(z)) - - return z - - -if __name__ == "__main__": - svtrRNN = EncoderWithSVTR(56) - print(svtrRNN) diff --git a/examples/research_projects/anytext/ocr_recog/RecCTCHead.py b/examples/research_projects/anytext/ocr_recog/RecCTCHead.py deleted file mode 100755 index c066c6202b19..000000000000 --- a/examples/research_projects/anytext/ocr_recog/RecCTCHead.py +++ /dev/null @@ -1,45 +0,0 @@ -from torch import nn - - -class CTCHead(nn.Module): - def __init__( - self, in_channels, out_channels=6625, fc_decay=0.0004, mid_channels=None, return_feats=False, **kwargs - ): - super(CTCHead, self).__init__() - if mid_channels is None: - self.fc = nn.Linear( - in_channels, - out_channels, - bias=True, - ) - else: - self.fc1 = nn.Linear( - in_channels, - mid_channels, - bias=True, - ) - self.fc2 = nn.Linear( - mid_channels, - out_channels, - bias=True, - ) - - self.out_channels = out_channels - self.mid_channels = mid_channels - self.return_feats = return_feats - - def forward(self, x, labels=None): - if self.mid_channels is None: - predicts = self.fc(x) - else: - x = self.fc1(x) - predicts = self.fc2(x) - - if self.return_feats: - result = {} - result["ctc"] = predicts - result["ctc_neck"] = x - else: - result = predicts - - return result diff --git a/examples/research_projects/anytext/ocr_recog/RecModel.py b/examples/research_projects/anytext/ocr_recog/RecModel.py deleted file mode 100755 index 872ccade69e0..000000000000 --- a/examples/research_projects/anytext/ocr_recog/RecModel.py +++ /dev/null @@ -1,49 +0,0 @@ -from torch import nn - -from .RecCTCHead import CTCHead -from .RecMv1_enhance import MobileNetV1Enhance -from .RNN import Im2Im, Im2Seq, SequenceEncoder - - -backbone_dict = {"MobileNetV1Enhance": MobileNetV1Enhance} -neck_dict = {"SequenceEncoder": SequenceEncoder, "Im2Seq": Im2Seq, "None": Im2Im} -head_dict = {"CTCHead": CTCHead} - - -class RecModel(nn.Module): - def __init__(self, config): - super().__init__() - assert "in_channels" in config, "in_channels must in model config" - backbone_type = config["backbone"].pop("type") - assert backbone_type in backbone_dict, f"backbone.type must in {backbone_dict}" - self.backbone = backbone_dict[backbone_type](config["in_channels"], **config["backbone"]) - - neck_type = config["neck"].pop("type") - assert neck_type in neck_dict, f"neck.type must in {neck_dict}" - self.neck = neck_dict[neck_type](self.backbone.out_channels, **config["neck"]) - - head_type = config["head"].pop("type") - assert head_type in head_dict, f"head.type must in {head_dict}" - self.head = head_dict[head_type](self.neck.out_channels, **config["head"]) - - self.name = f"RecModel_{backbone_type}_{neck_type}_{head_type}" - - def load_3rd_state_dict(self, _3rd_name, _state): - self.backbone.load_3rd_state_dict(_3rd_name, _state) - self.neck.load_3rd_state_dict(_3rd_name, _state) - self.head.load_3rd_state_dict(_3rd_name, _state) - - def forward(self, x): - import torch - - x = x.to(torch.float32) - x = self.backbone(x) - x = self.neck(x) - x = self.head(x) - return x - - def encode(self, x): - x = self.backbone(x) - x = self.neck(x) - x = self.head.ctc_encoder(x) - return x diff --git a/examples/research_projects/anytext/ocr_recog/RecMv1_enhance.py b/examples/research_projects/anytext/ocr_recog/RecMv1_enhance.py deleted file mode 100644 index df41519b2713..000000000000 --- a/examples/research_projects/anytext/ocr_recog/RecMv1_enhance.py +++ /dev/null @@ -1,197 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .common import Activation - - -class ConvBNLayer(nn.Module): - def __init__( - self, num_channels, filter_size, num_filters, stride, padding, channels=None, num_groups=1, act="hard_swish" - ): - super(ConvBNLayer, self).__init__() - self.act = act - self._conv = nn.Conv2d( - in_channels=num_channels, - out_channels=num_filters, - kernel_size=filter_size, - stride=stride, - padding=padding, - groups=num_groups, - bias=False, - ) - - self._batch_norm = nn.BatchNorm2d( - num_filters, - ) - if self.act is not None: - self._act = Activation(act_type=act, inplace=True) - - def forward(self, inputs): - y = self._conv(inputs) - y = self._batch_norm(y) - if self.act is not None: - y = self._act(y) - return y - - -class DepthwiseSeparable(nn.Module): - def __init__( - self, num_channels, num_filters1, num_filters2, num_groups, stride, scale, dw_size=3, padding=1, use_se=False - ): - super(DepthwiseSeparable, self).__init__() - self.use_se = use_se - self._depthwise_conv = ConvBNLayer( - num_channels=num_channels, - num_filters=int(num_filters1 * scale), - filter_size=dw_size, - stride=stride, - padding=padding, - num_groups=int(num_groups * scale), - ) - if use_se: - self._se = SEModule(int(num_filters1 * scale)) - self._pointwise_conv = ConvBNLayer( - num_channels=int(num_filters1 * scale), - filter_size=1, - num_filters=int(num_filters2 * scale), - stride=1, - padding=0, - ) - - def forward(self, inputs): - y = self._depthwise_conv(inputs) - if self.use_se: - y = self._se(y) - y = self._pointwise_conv(y) - return y - - -class MobileNetV1Enhance(nn.Module): - def __init__(self, in_channels=3, scale=0.5, last_conv_stride=1, last_pool_type="max", **kwargs): - super().__init__() - self.scale = scale - self.block_list = [] - - self.conv1 = ConvBNLayer( - num_channels=in_channels, filter_size=3, channels=3, num_filters=int(32 * scale), stride=2, padding=1 - ) - - conv2_1 = DepthwiseSeparable( - num_channels=int(32 * scale), num_filters1=32, num_filters2=64, num_groups=32, stride=1, scale=scale - ) - self.block_list.append(conv2_1) - - conv2_2 = DepthwiseSeparable( - num_channels=int(64 * scale), num_filters1=64, num_filters2=128, num_groups=64, stride=1, scale=scale - ) - self.block_list.append(conv2_2) - - conv3_1 = DepthwiseSeparable( - num_channels=int(128 * scale), num_filters1=128, num_filters2=128, num_groups=128, stride=1, scale=scale - ) - self.block_list.append(conv3_1) - - conv3_2 = DepthwiseSeparable( - num_channels=int(128 * scale), - num_filters1=128, - num_filters2=256, - num_groups=128, - stride=(2, 1), - scale=scale, - ) - self.block_list.append(conv3_2) - - conv4_1 = DepthwiseSeparable( - num_channels=int(256 * scale), num_filters1=256, num_filters2=256, num_groups=256, stride=1, scale=scale - ) - self.block_list.append(conv4_1) - - conv4_2 = DepthwiseSeparable( - num_channels=int(256 * scale), - num_filters1=256, - num_filters2=512, - num_groups=256, - stride=(2, 1), - scale=scale, - ) - self.block_list.append(conv4_2) - - for _ in range(5): - conv5 = DepthwiseSeparable( - num_channels=int(512 * scale), - num_filters1=512, - num_filters2=512, - num_groups=512, - stride=1, - dw_size=5, - padding=2, - scale=scale, - use_se=False, - ) - self.block_list.append(conv5) - - conv5_6 = DepthwiseSeparable( - num_channels=int(512 * scale), - num_filters1=512, - num_filters2=1024, - num_groups=512, - stride=(2, 1), - dw_size=5, - padding=2, - scale=scale, - use_se=True, - ) - self.block_list.append(conv5_6) - - conv6 = DepthwiseSeparable( - num_channels=int(1024 * scale), - num_filters1=1024, - num_filters2=1024, - num_groups=1024, - stride=last_conv_stride, - dw_size=5, - padding=2, - use_se=True, - scale=scale, - ) - self.block_list.append(conv6) - - self.block_list = nn.Sequential(*self.block_list) - if last_pool_type == "avg": - self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) - else: - self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) - self.out_channels = int(1024 * scale) - - def forward(self, inputs): - y = self.conv1(inputs) - y = self.block_list(y) - y = self.pool(y) - return y - - -def hardsigmoid(x): - return F.relu6(x + 3.0, inplace=True) / 6.0 - - -class SEModule(nn.Module): - def __init__(self, channel, reduction=4): - super(SEModule, self).__init__() - self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.conv1 = nn.Conv2d( - in_channels=channel, out_channels=channel // reduction, kernel_size=1, stride=1, padding=0, bias=True - ) - self.conv2 = nn.Conv2d( - in_channels=channel // reduction, out_channels=channel, kernel_size=1, stride=1, padding=0, bias=True - ) - - def forward(self, inputs): - outputs = self.avg_pool(inputs) - outputs = self.conv1(outputs) - outputs = F.relu(outputs) - outputs = self.conv2(outputs) - outputs = hardsigmoid(outputs) - x = torch.mul(inputs, outputs) - - return x diff --git a/examples/research_projects/anytext/ocr_recog/RecSVTR.py b/examples/research_projects/anytext/ocr_recog/RecSVTR.py deleted file mode 100644 index 590a96995b26..000000000000 --- a/examples/research_projects/anytext/ocr_recog/RecSVTR.py +++ /dev/null @@ -1,570 +0,0 @@ -import numpy as np -import torch -import torch.nn as nn -from torch.nn import functional -from torch.nn.init import ones_, trunc_normal_, zeros_ - - -def drop_path(x, drop_prob=0.0, training=False): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... - """ - if drop_prob == 0.0 or not training: - return x - keep_prob = torch.tensor(1 - drop_prob) - shape = (x.size()[0],) + (1,) * (x.ndim - 1) - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype) - random_tensor = torch.floor(random_tensor) # binarize - output = x.divide(keep_prob) * random_tensor - return output - - -class Swish(nn.Module): - def __int__(self): - super(Swish, self).__int__() - - def forward(self, x): - return x * torch.sigmoid(x) - - -class ConvBNLayer(nn.Module): - def __init__( - self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias_attr=False, groups=1, act=nn.GELU - ): - super().__init__() - self.conv = nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - groups=groups, - # weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()), - bias=bias_attr, - ) - self.norm = nn.BatchNorm2d(out_channels) - self.act = act() - - def forward(self, inputs): - out = self.conv(inputs) - out = self.norm(out) - out = self.act(out) - return out - - -class DropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - - def __init__(self, drop_prob=None): - super(DropPath, self).__init__() - self.drop_prob = drop_prob - - def forward(self, x): - return drop_path(x, self.drop_prob, self.training) - - -class Identity(nn.Module): - def __init__(self): - super(Identity, self).__init__() - - def forward(self, input): - return input - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - if isinstance(act_layer, str): - self.act = Swish() - else: - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -class ConvMixer(nn.Module): - def __init__( - self, - dim, - num_heads=8, - HW=(8, 25), - local_k=(3, 3), - ): - super().__init__() - self.HW = HW - self.dim = dim - self.local_mixer = nn.Conv2d( - dim, - dim, - local_k, - 1, - (local_k[0] // 2, local_k[1] // 2), - groups=num_heads, - # weight_attr=ParamAttr(initializer=KaimingNormal()) - ) - - def forward(self, x): - h = self.HW[0] - w = self.HW[1] - x = x.transpose([0, 2, 1]).reshape([0, self.dim, h, w]) - x = self.local_mixer(x) - x = x.flatten(2).transpose([0, 2, 1]) - return x - - -class Attention(nn.Module): - def __init__( - self, - dim, - num_heads=8, - mixer="Global", - HW=(8, 25), - local_k=(7, 11), - qkv_bias=False, - qk_scale=None, - attn_drop=0.0, - proj_drop=0.0, - ): - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim**-0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - self.HW = HW - if HW is not None: - H = HW[0] - W = HW[1] - self.N = H * W - self.C = dim - if mixer == "Local" and HW is not None: - hk = local_k[0] - wk = local_k[1] - mask = torch.ones([H * W, H + hk - 1, W + wk - 1]) - for h in range(0, H): - for w in range(0, W): - mask[h * W + w, h : h + hk, w : w + wk] = 0.0 - mask_paddle = mask[:, hk // 2 : H + hk // 2, wk // 2 : W + wk // 2].flatten(1) - mask_inf = torch.full([H * W, H * W], fill_value=float("-inf")) - mask = torch.where(mask_paddle < 1, mask_paddle, mask_inf) - self.mask = mask[None, None, :] - # self.mask = mask.unsqueeze([0, 1]) - self.mixer = mixer - - def forward(self, x): - if self.HW is not None: - N = self.N - C = self.C - else: - _, N, C = x.shape - qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C // self.num_heads)).permute((2, 0, 3, 1, 4)) - q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] - - attn = q.matmul(k.permute((0, 1, 3, 2))) - if self.mixer == "Local": - attn += self.mask - attn = functional.softmax(attn, dim=-1) - attn = self.attn_drop(attn) - - x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C)) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -class Block(nn.Module): - def __init__( - self, - dim, - num_heads, - mixer="Global", - local_mixer=(7, 11), - HW=(8, 25), - mlp_ratio=4.0, - qkv_bias=False, - qk_scale=None, - drop=0.0, - attn_drop=0.0, - drop_path=0.0, - act_layer=nn.GELU, - norm_layer="nn.LayerNorm", - epsilon=1e-6, - prenorm=True, - ): - super().__init__() - if isinstance(norm_layer, str): - self.norm1 = eval(norm_layer)(dim, eps=epsilon) - else: - self.norm1 = norm_layer(dim) - if mixer == "Global" or mixer == "Local": - self.mixer = Attention( - dim, - num_heads=num_heads, - mixer=mixer, - HW=HW, - local_k=local_mixer, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - attn_drop=attn_drop, - proj_drop=drop, - ) - elif mixer == "Conv": - self.mixer = ConvMixer(dim, num_heads=num_heads, HW=HW, local_k=local_mixer) - else: - raise TypeError("The mixer must be one of [Global, Local, Conv]") - - self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity() - if isinstance(norm_layer, str): - self.norm2 = eval(norm_layer)(dim, eps=epsilon) - else: - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp_ratio = mlp_ratio - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - self.prenorm = prenorm - - def forward(self, x): - if self.prenorm: - x = self.norm1(x + self.drop_path(self.mixer(x))) - x = self.norm2(x + self.drop_path(self.mlp(x))) - else: - x = x + self.drop_path(self.mixer(self.norm1(x))) - x = x + self.drop_path(self.mlp(self.norm2(x))) - return x - - -class PatchEmbed(nn.Module): - """Image to Patch Embedding""" - - def __init__(self, img_size=(32, 100), in_channels=3, embed_dim=768, sub_num=2): - super().__init__() - num_patches = (img_size[1] // (2**sub_num)) * (img_size[0] // (2**sub_num)) - self.img_size = img_size - self.num_patches = num_patches - self.embed_dim = embed_dim - self.norm = None - if sub_num == 2: - self.proj = nn.Sequential( - ConvBNLayer( - in_channels=in_channels, - out_channels=embed_dim // 2, - kernel_size=3, - stride=2, - padding=1, - act=nn.GELU, - bias_attr=False, - ), - ConvBNLayer( - in_channels=embed_dim // 2, - out_channels=embed_dim, - kernel_size=3, - stride=2, - padding=1, - act=nn.GELU, - bias_attr=False, - ), - ) - if sub_num == 3: - self.proj = nn.Sequential( - ConvBNLayer( - in_channels=in_channels, - out_channels=embed_dim // 4, - kernel_size=3, - stride=2, - padding=1, - act=nn.GELU, - bias_attr=False, - ), - ConvBNLayer( - in_channels=embed_dim // 4, - out_channels=embed_dim // 2, - kernel_size=3, - stride=2, - padding=1, - act=nn.GELU, - bias_attr=False, - ), - ConvBNLayer( - in_channels=embed_dim // 2, - out_channels=embed_dim, - kernel_size=3, - stride=2, - padding=1, - act=nn.GELU, - bias_attr=False, - ), - ) - - def forward(self, x): - B, C, H, W = x.shape - assert ( - H == self.img_size[0] and W == self.img_size[1] - ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).permute(0, 2, 1) - return x - - -class SubSample(nn.Module): - def __init__(self, in_channels, out_channels, types="Pool", stride=(2, 1), sub_norm="nn.LayerNorm", act=None): - super().__init__() - self.types = types - if types == "Pool": - self.avgpool = nn.AvgPool2d(kernel_size=(3, 5), stride=stride, padding=(1, 2)) - self.maxpool = nn.MaxPool2d(kernel_size=(3, 5), stride=stride, padding=(1, 2)) - self.proj = nn.Linear(in_channels, out_channels) - else: - self.conv = nn.Conv2d( - in_channels, - out_channels, - kernel_size=3, - stride=stride, - padding=1, - # weight_attr=ParamAttr(initializer=KaimingNormal()) - ) - self.norm = eval(sub_norm)(out_channels) - if act is not None: - self.act = act() - else: - self.act = None - - def forward(self, x): - if self.types == "Pool": - x1 = self.avgpool(x) - x2 = self.maxpool(x) - x = (x1 + x2) * 0.5 - out = self.proj(x.flatten(2).permute((0, 2, 1))) - else: - x = self.conv(x) - out = x.flatten(2).permute((0, 2, 1)) - out = self.norm(out) - if self.act is not None: - out = self.act(out) - - return out - - -class SVTRNet(nn.Module): - def __init__( - self, - img_size=[48, 100], - in_channels=3, - embed_dim=[64, 128, 256], - depth=[3, 6, 3], - num_heads=[2, 4, 8], - mixer=["Local"] * 6 + ["Global"] * 6, # Local atten, Global atten, Conv - local_mixer=[[7, 11], [7, 11], [7, 11]], - patch_merging="Conv", # Conv, Pool, None - mlp_ratio=4, - qkv_bias=True, - qk_scale=None, - drop_rate=0.0, - last_drop=0.1, - attn_drop_rate=0.0, - drop_path_rate=0.1, - norm_layer="nn.LayerNorm", - sub_norm="nn.LayerNorm", - epsilon=1e-6, - out_channels=192, - out_char_num=25, - block_unit="Block", - act="nn.GELU", - last_stage=True, - sub_num=2, - prenorm=True, - use_lenhead=False, - **kwargs, - ): - super().__init__() - self.img_size = img_size - self.embed_dim = embed_dim - self.out_channels = out_channels - self.prenorm = prenorm - patch_merging = None if patch_merging != "Conv" and patch_merging != "Pool" else patch_merging - self.patch_embed = PatchEmbed( - img_size=img_size, in_channels=in_channels, embed_dim=embed_dim[0], sub_num=sub_num - ) - num_patches = self.patch_embed.num_patches - self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)] - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0])) - # self.pos_embed = self.create_parameter( - # shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_) - - # self.add_parameter("pos_embed", self.pos_embed) - - self.pos_drop = nn.Dropout(p=drop_rate) - Block_unit = eval(block_unit) - - dpr = np.linspace(0, drop_path_rate, sum(depth)) - self.blocks1 = nn.ModuleList( - [ - Block_unit( - dim=embed_dim[0], - num_heads=num_heads[0], - mixer=mixer[0 : depth[0]][i], - HW=self.HW, - local_mixer=local_mixer[0], - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - act_layer=eval(act), - attn_drop=attn_drop_rate, - drop_path=dpr[0 : depth[0]][i], - norm_layer=norm_layer, - epsilon=epsilon, - prenorm=prenorm, - ) - for i in range(depth[0]) - ] - ) - if patch_merging is not None: - self.sub_sample1 = SubSample( - embed_dim[0], embed_dim[1], sub_norm=sub_norm, stride=[2, 1], types=patch_merging - ) - HW = [self.HW[0] // 2, self.HW[1]] - else: - HW = self.HW - self.patch_merging = patch_merging - self.blocks2 = nn.ModuleList( - [ - Block_unit( - dim=embed_dim[1], - num_heads=num_heads[1], - mixer=mixer[depth[0] : depth[0] + depth[1]][i], - HW=HW, - local_mixer=local_mixer[1], - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - act_layer=eval(act), - attn_drop=attn_drop_rate, - drop_path=dpr[depth[0] : depth[0] + depth[1]][i], - norm_layer=norm_layer, - epsilon=epsilon, - prenorm=prenorm, - ) - for i in range(depth[1]) - ] - ) - if patch_merging is not None: - self.sub_sample2 = SubSample( - embed_dim[1], embed_dim[2], sub_norm=sub_norm, stride=[2, 1], types=patch_merging - ) - HW = [self.HW[0] // 4, self.HW[1]] - else: - HW = self.HW - self.blocks3 = nn.ModuleList( - [ - Block_unit( - dim=embed_dim[2], - num_heads=num_heads[2], - mixer=mixer[depth[0] + depth[1] :][i], - HW=HW, - local_mixer=local_mixer[2], - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - act_layer=eval(act), - attn_drop=attn_drop_rate, - drop_path=dpr[depth[0] + depth[1] :][i], - norm_layer=norm_layer, - epsilon=epsilon, - prenorm=prenorm, - ) - for i in range(depth[2]) - ] - ) - self.last_stage = last_stage - if last_stage: - self.avg_pool = nn.AdaptiveAvgPool2d((1, out_char_num)) - self.last_conv = nn.Conv2d( - in_channels=embed_dim[2], - out_channels=self.out_channels, - kernel_size=1, - stride=1, - padding=0, - bias=False, - ) - self.hardswish = nn.Hardswish() - self.dropout = nn.Dropout(p=last_drop) - if not prenorm: - self.norm = eval(norm_layer)(embed_dim[-1], epsilon=epsilon) - self.use_lenhead = use_lenhead - if use_lenhead: - self.len_conv = nn.Linear(embed_dim[2], self.out_channels) - self.hardswish_len = nn.Hardswish() - self.dropout_len = nn.Dropout(p=last_drop) - - trunc_normal_(self.pos_embed, std=0.02) - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=0.02) - if isinstance(m, nn.Linear) and m.bias is not None: - zeros_(m.bias) - elif isinstance(m, nn.LayerNorm): - zeros_(m.bias) - ones_(m.weight) - - def forward_features(self, x): - x = self.patch_embed(x) - x = x + self.pos_embed - x = self.pos_drop(x) - for blk in self.blocks1: - x = blk(x) - if self.patch_merging is not None: - x = self.sub_sample1(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[0], self.HW[0], self.HW[1]])) - for blk in self.blocks2: - x = blk(x) - if self.patch_merging is not None: - x = self.sub_sample2(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1]])) - for blk in self.blocks3: - x = blk(x) - if not self.prenorm: - x = self.norm(x) - return x - - def forward(self, x): - x = self.forward_features(x) - if self.use_lenhead: - len_x = self.len_conv(x.mean(1)) - len_x = self.dropout_len(self.hardswish_len(len_x)) - if self.last_stage: - if self.patch_merging is not None: - h = self.HW[0] // 4 - else: - h = self.HW[0] - x = self.avg_pool(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[2], h, self.HW[1]])) - x = self.last_conv(x) - x = self.hardswish(x) - x = self.dropout(x) - if self.use_lenhead: - return x, len_x - return x - - -if __name__ == "__main__": - a = torch.rand(1, 3, 48, 100) - svtr = SVTRNet() - - out = svtr(a) - print(svtr) - print(out.size()) diff --git a/examples/research_projects/anytext/ocr_recog/common.py b/examples/research_projects/anytext/ocr_recog/common.py deleted file mode 100644 index 207a95b17d0e..000000000000 --- a/examples/research_projects/anytext/ocr_recog/common.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class Hswish(nn.Module): - def __init__(self, inplace=True): - super(Hswish, self).__init__() - self.inplace = inplace - - def forward(self, x): - return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0 - - -# out = max(0, min(1, slop*x+offset)) -# paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None) -class Hsigmoid(nn.Module): - def __init__(self, inplace=True): - super(Hsigmoid, self).__init__() - self.inplace = inplace - - def forward(self, x): - # torch: F.relu6(x + 3., inplace=self.inplace) / 6. - # paddle: F.relu6(1.2 * x + 3., inplace=self.inplace) / 6. - return F.relu6(1.2 * x + 3.0, inplace=self.inplace) / 6.0 - - -class GELU(nn.Module): - def __init__(self, inplace=True): - super(GELU, self).__init__() - self.inplace = inplace - - def forward(self, x): - return torch.nn.functional.gelu(x) - - -class Swish(nn.Module): - def __init__(self, inplace=True): - super(Swish, self).__init__() - self.inplace = inplace - - def forward(self, x): - if self.inplace: - x.mul_(torch.sigmoid(x)) - return x - else: - return x * torch.sigmoid(x) - - -class Activation(nn.Module): - def __init__(self, act_type, inplace=True): - super(Activation, self).__init__() - act_type = act_type.lower() - if act_type == "relu": - self.act = nn.ReLU(inplace=inplace) - elif act_type == "relu6": - self.act = nn.ReLU6(inplace=inplace) - elif act_type == "sigmoid": - raise NotImplementedError - elif act_type == "hard_sigmoid": - self.act = Hsigmoid(inplace) - elif act_type == "hard_swish": - self.act = Hswish(inplace=inplace) - elif act_type == "leakyrelu": - self.act = nn.LeakyReLU(inplace=inplace) - elif act_type == "gelu": - self.act = GELU(inplace=inplace) - elif act_type == "swish": - self.act = Swish(inplace=inplace) - else: - raise NotImplementedError - - def forward(self, inputs): - return self.act(inputs) diff --git a/examples/research_projects/anytext/ocr_recog/en_dict.txt b/examples/research_projects/anytext/ocr_recog/en_dict.txt deleted file mode 100644 index 7677d31b9d3f..000000000000 --- a/examples/research_projects/anytext/ocr_recog/en_dict.txt +++ /dev/null @@ -1,95 +0,0 @@ -0 -1 -2 -3 -4 -5 -6 -7 -8 -9 -: -; -< -= -> -? -@ -A -B -C -D -E -F -G -H -I -J -K -L -M -N -O -P -Q -R -S -T -U -V -W -X -Y -Z -[ -\ -] -^ -_ -` -a -b -c -d -e -f -g -h -i -j -k -l -m -n -o -p -q -r -s -t -u -v -w -x -y -z -{ -| -} -~ -! -" -# -$ -% -& -' -( -) -* -+ -, -- -. -/ - diff --git a/examples/research_projects/controlnet/train_controlnet_webdataset.py b/examples/research_projects/controlnet/train_controlnet_webdataset.py index 829b0031156e..765bb495062e 100644 --- a/examples/research_projects/controlnet/train_controlnet_webdataset.py +++ b/examples/research_projects/controlnet/train_controlnet_webdataset.py @@ -381,7 +381,9 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step) validation_prompt = log["validation_prompt"] validation_image = log["validation_image"] - formatted_images = [np.asarray(validation_image)] + formatted_images = [] + + formatted_images.append(np.asarray(validation_image)) for image in images: formatted_images.append(np.asarray(image)) diff --git a/examples/research_projects/pixart/train_pixart_controlnet_hf.py b/examples/research_projects/pixart/train_pixart_controlnet_hf.py index 67ec30da0ece..995a20dfa28e 100644 --- a/examples/research_projects/pixart/train_pixart_controlnet_hf.py +++ b/examples/research_projects/pixart/train_pixart_controlnet_hf.py @@ -164,7 +164,9 @@ def log_validation( validation_prompt = log["validation_prompt"] validation_image = log["validation_image"] - formatted_images = [np.asarray(validation_image)] + formatted_images = [] + + formatted_images.append(np.asarray(validation_image)) for image in images: formatted_images.append(np.asarray(image)) diff --git a/examples/research_projects/pytorch_xla/inference/flux/README.md b/examples/research_projects/pytorch_xla/inference/flux/README.md index 9d482e6805a3..dd7e23c57049 100644 --- a/examples/research_projects/pytorch_xla/inference/flux/README.md +++ b/examples/research_projects/pytorch_xla/inference/flux/README.md @@ -1,6 +1,8 @@ # Generating images using Flux and PyTorch/XLA -The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation using custom flash block sizes for better performance on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested. +The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation. + +It has been tested on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested. ## Create TPU @@ -21,23 +23,20 @@ Verify that PyTorch and PyTorch/XLA were installed correctly: python3 -c "import torch; import torch_xla;" ``` -Clone the diffusers repo and install dependencies +Install dependencies ```bash -git clone https://github.com/huggingface/diffusers.git -cd diffusers pip install transformers accelerate sentencepiece structlog +pushd ../../.. pip install . -cd examples/research_projects/pytorch_xla/inference/flux/ +popd ``` ## Run the inference job ### Authenticate -**Gated Model** - -As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows youโ€™ve accepted the gate. Use the command below to log in: +Run the following command to authenticate your token in order to download Flux weights. ```bash huggingface-cli login @@ -51,116 +50,51 @@ python flux_inference.py The script loads the text encoders onto the CPU and the Flux transformer and VAE models onto the TPU. The first time the script runs, the compilation time is longer, while the cache stores the compiled programs. On subsequent runs, compilation is much faster and the subsequent passes being the fastest. -On a Trillium v6e-4, you should expect ~6 sec / 4 images or 1.5 sec / image (as devices run generation in parallel): +On a Trillium v6e-4, you should expect ~9 sec / 4 images or 2.25 sec / image (as devices run generation in parallel): ```bash WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU. -Loading checkpoint shards: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 2/2 [00:00<00:00, 7.06it/s] -Loading pipeline components...: 60%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ | 3/5 [00:00<00:00, 6.80it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers -Loading pipeline components...: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 5/5 [00:00<00:00, 6.28it/s] -2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev -2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev -Loading pipeline components...: 0%| | 0/3 [00:00 Dict[str, Any]: state_dict = saved_dict @@ -126,16 +104,12 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: def convert_transformer( ckpt_path: str, dtype: torch.dtype, - version: str = "0.9.0", ): PREFIX_KEY = "model.diffusion_model." original_state_dict = get_state_dict(load_file(ckpt_path)) - config = {} - if version == "0.9.5": - config["_use_causal_rope_fix"] = True with init_empty_weights(): - transformer = LTXVideoTransformer3DModel(**config) + transformer = LTXVideoTransformer3DModel() for key in list(original_state_dict.keys()): new_key = key[:] @@ -187,19 +161,12 @@ def get_vae_config(version: str) -> Dict[str, Any]: "out_channels": 3, "latent_channels": 128, "block_out_channels": (128, 256, 512, 512), - "down_block_types": ( - "LTXVideoDownBlock3D", - "LTXVideoDownBlock3D", - "LTXVideoDownBlock3D", - "LTXVideoDownBlock3D", - ), "decoder_block_out_channels": (128, 256, 512, 512), "layers_per_block": (4, 3, 3, 3, 4), "decoder_layers_per_block": (4, 3, 3, 3, 4), "spatio_temporal_scaling": (True, True, True, False), "decoder_spatio_temporal_scaling": (True, True, True, False), "decoder_inject_noise": (False, False, False, False, False), - "downsample_type": ("conv", "conv", "conv", "conv"), "upsample_residual": (False, False, False, False), "upsample_factor": (1, 1, 1, 1), "patch_size": 4, @@ -216,19 +183,12 @@ def get_vae_config(version: str) -> Dict[str, Any]: "out_channels": 3, "latent_channels": 128, "block_out_channels": (128, 256, 512, 512), - "down_block_types": ( - "LTXVideoDownBlock3D", - "LTXVideoDownBlock3D", - "LTXVideoDownBlock3D", - "LTXVideoDownBlock3D", - ), "decoder_block_out_channels": (256, 512, 1024), "layers_per_block": (4, 3, 3, 3, 4), "decoder_layers_per_block": (5, 6, 7, 8), "spatio_temporal_scaling": (True, True, True, False), "decoder_spatio_temporal_scaling": (True, True, True), "decoder_inject_noise": (True, True, True, False), - "downsample_type": ("conv", "conv", "conv", "conv"), "upsample_residual": (True, True, True), "upsample_factor": (2, 2, 2), "timestep_conditioning": True, @@ -240,38 +200,7 @@ def get_vae_config(version: str) -> Dict[str, Any]: "decoder_causal": False, } VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT) - elif version == "0.9.5": - config = { - "in_channels": 3, - "out_channels": 3, - "latent_channels": 128, - "block_out_channels": (128, 256, 512, 1024, 2048), - "down_block_types": ( - "LTXVideo095DownBlock3D", - "LTXVideo095DownBlock3D", - "LTXVideo095DownBlock3D", - "LTXVideo095DownBlock3D", - ), - "decoder_block_out_channels": (256, 512, 1024), - "layers_per_block": (4, 6, 6, 2, 2), - "decoder_layers_per_block": (5, 5, 5, 5), - "spatio_temporal_scaling": (True, True, True, True), - "decoder_spatio_temporal_scaling": (True, True, True), - "decoder_inject_noise": (False, False, False, False), - "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), - "upsample_residual": (True, True, True), - "upsample_factor": (2, 2, 2), - "timestep_conditioning": True, - "patch_size": 4, - "patch_size_t": 1, - "resnet_norm_eps": 1e-6, - "scaling_factor": 1.0, - "encoder_causal": True, - "decoder_causal": False, - "spatial_compression_ratio": 32, - "temporal_compression_ratio": 8, - } - VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT) + VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP) return config @@ -294,7 +223,7 @@ def get_args(): parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.") parser.add_argument( - "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model" + "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model" ) return parser.parse_args() @@ -348,17 +277,14 @@ def get_args(): for param in text_encoder.parameters(): param.data = param.data.contiguous() - if args.version == "0.9.5": - scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False) - else: - scheduler = FlowMatchEulerDiscreteScheduler( - use_dynamic_shifting=True, - base_shift=0.95, - max_shift=2.05, - base_image_seq_len=1024, - max_image_seq_len=4096, - shift_terminal=0.1, - ) + scheduler = FlowMatchEulerDiscreteScheduler( + use_dynamic_shifting=True, + base_shift=0.95, + max_shift=2.05, + base_image_seq_len=1024, + max_image_seq_len=4096, + shift_terminal=0.1, + ) pipe = LTXPipeline( scheduler=scheduler, diff --git a/scripts/convert_lumina_to_diffusers.py b/scripts/convert_lumina_to_diffusers.py index c14aad3c6bf2..a12625d1376f 100644 --- a/scripts/convert_lumina_to_diffusers.py +++ b/scripts/convert_lumina_to_diffusers.py @@ -5,7 +5,7 @@ from safetensors.torch import load_file from transformers import AutoModel, AutoTokenizer -from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaPipeline +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline def main(args): @@ -115,7 +115,7 @@ def main(args): tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b") text_encoder = AutoModel.from_pretrained("google/gemma-2b") - pipeline = LuminaPipeline( + pipeline = LuminaText2ImgPipeline( tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler ) pipeline.save_pretrained(args.dump_path) diff --git a/setup.py b/setup.py index fdc166a81ecf..93945ae040dd 100644 --- a/setup.py +++ b/setup.py @@ -128,10 +128,6 @@ "GitPython<3.1.19", "scipy", "onnx", - "optimum_quanto>=0.2.6", - "gguf>=0.10.0", - "torchao>=0.7.0", - "bitsandbytes>=0.43.3", "regex!=2019.12.17", "requests", "tensorboard", @@ -239,11 +235,6 @@ def run(self): ) extras["torch"] = deps_list("torch", "accelerate") -extras["bitsandbytes"] = deps_list("bitsandbytes", "accelerate") -extras["gguf"] = deps_list("gguf", "accelerate") -extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate") -extras["torchao"] = deps_list("torchao", "accelerate") - if os.name == "nt": # windows extras["flax"] = [] # jax is not supported on windows else: diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index ad658f1b14ff..d5cfad915e3c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -6,19 +6,14 @@ DIFFUSERS_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, - is_accelerate_available, - is_bitsandbytes_available, is_flax_available, - is_gguf_available, is_k_diffusion_available, is_librosa_available, is_note_seq_available, is_onnx_available, - is_optimum_quanto_available, is_scipy_available, is_sentencepiece_available, is_torch_available, - is_torchao_available, is_torchsde_available, is_transformers_available, ) @@ -37,7 +32,7 @@ "loaders": ["FromOriginalModelMixin"], "models": [], "pipelines": [], - "quantizers.quantization_config": [], + "quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"], "schedulers": [], "utils": [ "OptionalDependencyNotAvailable", @@ -59,54 +54,6 @@ ], } -try: - if not is_torch_available() and not is_accelerate_available() and not is_bitsandbytes_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from .utils import dummy_bitsandbytes_objects - - _import_structure["utils.dummy_bitsandbytes_objects"] = [ - name for name in dir(dummy_bitsandbytes_objects) if not name.startswith("_") - ] -else: - _import_structure["quantizers.quantization_config"].append("BitsAndBytesConfig") - -try: - if not is_torch_available() and not is_accelerate_available() and not is_gguf_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from .utils import dummy_gguf_objects - - _import_structure["utils.dummy_gguf_objects"] = [ - name for name in dir(dummy_gguf_objects) if not name.startswith("_") - ] -else: - _import_structure["quantizers.quantization_config"].append("GGUFQuantizationConfig") - -try: - if not is_torch_available() and not is_accelerate_available() and not is_torchao_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from .utils import dummy_torchao_objects - - _import_structure["utils.dummy_torchao_objects"] = [ - name for name in dir(dummy_torchao_objects) if not name.startswith("_") - ] -else: - _import_structure["quantizers.quantization_config"].append("TorchAoConfig") - -try: - if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from .utils import dummy_optimum_quanto_objects - - _import_structure["utils.dummy_optimum_quanto_objects"] = [ - name for name in dir(dummy_optimum_quanto_objects) if not name.startswith("_") - ] -else: - _import_structure["quantizers.quantization_config"].append("QuantoConfig") - try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() @@ -345,7 +292,6 @@ "CogVideoXPipeline", "CogVideoXVideoToVideoPipeline", "CogView3PlusPipeline", - "CogView4ControlPipeline", "CogView4Pipeline", "ConsisIDPipeline", "CycleDiffusionPipeline", @@ -402,12 +348,9 @@ "LDMTextToImagePipeline", "LEditsPPPipelineStableDiffusion", "LEditsPPPipelineStableDiffusionXL", - "LTXConditionPipeline", "LTXImageToVideoPipeline", "LTXPipeline", - "Lumina2Pipeline", "Lumina2Text2ImgPipeline", - "LuminaPipeline", "LuminaText2ImgPipeline", "MarigoldDepthPipeline", "MarigoldIntrinsicsPipeline", @@ -656,38 +599,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .configuration_utils import ConfigMixin - - try: - if not is_bitsandbytes_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_bitsandbytes_objects import * - else: - from .quantizers.quantization_config import BitsAndBytesConfig - - try: - if not is_gguf_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_gguf_objects import * - else: - from .quantizers.quantization_config import GGUFQuantizationConfig - - try: - if not is_torchao_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_torchao_objects import * - else: - from .quantizers.quantization_config import TorchAoConfig - - try: - if not is_optimum_quanto_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_optimum_quanto_objects import * - else: - from .quantizers.quantization_config import QuantoConfig + from .quantizers.quantization_config import BitsAndBytesConfig, GGUFQuantizationConfig, TorchAoConfig try: if not is_onnx_available(): @@ -891,7 +803,6 @@ CogVideoXPipeline, CogVideoXVideoToVideoPipeline, CogView3PlusPipeline, - CogView4ControlPipeline, CogView4Pipeline, ConsisIDPipeline, CycleDiffusionPipeline, @@ -948,12 +859,9 @@ LDMTextToImagePipeline, LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, - LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline, - Lumina2Pipeline, Lumina2Text2ImgPipeline, - LuminaPipeline, LuminaText2ImgPipeline, MarigoldDepthPipeline, MarigoldIntrinsicsPipeline, diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 8ec95ed6fc8d..17d5da60347d 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -35,10 +35,6 @@ "GitPython": "GitPython<3.1.19", "scipy": "scipy", "onnx": "onnx", - "optimum_quanto": "optimum_quanto>=0.2.6", - "gguf": "gguf>=0.10.0", - "torchao": "torchao>=0.7.0", - "bitsandbytes": "bitsandbytes>=0.43.3", "regex": "regex!=2019.12.17", "requests": "requests", "tensorboard": "tensorboard", diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index e4b9ed9307ea..c389c5dc9826 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -83,10 +83,7 @@ def onload_(self): with context: for group_module in self.modules: - for param in group_module.parameters(): - param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking) - for buffer in group_module.buffers(): - buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking) + group_module.to(self.onload_device, non_blocking=self.non_blocking) if self.parameters is not None: for param in self.parameters: param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking) @@ -101,12 +98,6 @@ def offload_(self): for group_module in self.modules: for param in group_module.parameters(): param.data = self.cpu_param_dict[param] - if self.parameters is not None: - for param in self.parameters: - param.data = self.cpu_param_dict[param] - if self.buffers is not None: - for buffer in self.buffers: - buffer.data = self.cpu_param_dict[buffer] else: for group_module in self.modules: group_module.to(self.offload_device, non_blocking=self.non_blocking) @@ -181,13 +172,6 @@ def __init__(self): self._layer_execution_tracker_module_names = set() def initialize_hook(self, module): - def make_execution_order_update_callback(current_name, current_submodule): - def callback(): - logger.debug(f"Adding {current_name} to the execution order") - self.execution_order.append((current_name, current_submodule)) - - return callback - # To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any # of the groups), we add a layer execution tracker hook that will be used to determine the order in which the # layers are executed during the forward pass. @@ -199,8 +183,14 @@ def callback(): group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING) if group_offloading_hook is not None: - # For the first forward pass, we have to load in a blocking manner - group_offloading_hook.group.non_blocking = False + + def make_execution_order_update_callback(current_name, current_submodule): + def callback(): + logger.debug(f"Adding {current_name} to the execution order") + self.execution_order.append((current_name, current_submodule)) + + return callback + layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule)) registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER) self._layer_execution_tracker_module_names.add(name) @@ -230,7 +220,6 @@ def post_forward(self, module, output): # Remove the layer execution tracker hooks from the submodules base_module_registry = module._diffusers_hook registries = [submodule._diffusers_hook for _, submodule in self.execution_order] - group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries] for i in range(num_executed): registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False) @@ -238,13 +227,8 @@ def post_forward(self, module, output): # Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False) - # LazyPrefetchGroupOffloadingHook is only used with streams, so we know that non_blocking should be True. - # We disable non_blocking for the first forward pass, but need to enable it for the subsequent passes to - # see the benefits of prefetching. - for hook in group_offloading_hooks: - hook.group.non_blocking = True - - # Set required attributes for prefetching + # Apply lazy prefetching by setting required attributes + group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries] if num_executed > 0: base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING) base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group @@ -403,7 +387,9 @@ def _apply_group_offloading_block_level( # Create a pinned CPU parameter dict for async data transfer if streams are to be used cpu_param_dict = None if stream is not None: - cpu_param_dict = _get_pinned_cpu_param_dict(module) + for param in module.parameters(): + param.data = param.data.cpu().pin_memory() + cpu_param_dict = {param: param.data for param in module.parameters()} # Create module groups for ModuleList and Sequential blocks modules_with_group_offloading = set() @@ -500,7 +486,9 @@ def _apply_group_offloading_leaf_level( # Create a pinned CPU parameter dict for async data transfer if streams are to be used cpu_param_dict = None if stream is not None: - cpu_param_dict = _get_pinned_cpu_param_dict(module) + for param in module.parameters(): + param.data = param.data.cpu().pin_memory() + cpu_param_dict = {param: param.data for param in module.parameters()} # Create module groups for leaf modules and apply group offloading hooks modules_with_group_offloading = set() @@ -616,17 +604,6 @@ def _apply_lazy_group_offloading_hook( registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING) -def _get_pinned_cpu_param_dict(module: torch.nn.Module) -> Dict[torch.nn.Parameter, torch.Tensor]: - cpu_param_dict = {} - for param in module.parameters(): - param.data = param.data.cpu().pin_memory() - cpu_param_dict[param] = param.data - for buffer in module.buffers(): - buffer.data = buffer.data.cpu().pin_memory() - cpu_param_dict[buffer] = buffer.data - return cpu_param_dict - - def _gather_parameters_with_no_group_offloading_parent( module: torch.nn.Module, modules_with_group_offloading: Set[str] ) -> List[torch.nn.Parameter]: diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 3ba1bfacf3dd..86ffffd7d5df 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -70,7 +70,6 @@ def text_encoder_attn_modules(text_encoder): "LoraLoaderMixin", "FluxLoraLoaderMixin", "CogVideoXLoraLoaderMixin", - "CogView4LoraLoaderMixin", "Mochi1LoraLoaderMixin", "HunyuanVideoLoraLoaderMixin", "SanaLoraLoaderMixin", @@ -104,7 +103,6 @@ def text_encoder_attn_modules(text_encoder): from .lora_pipeline import ( AmusedLoraLoaderMixin, CogVideoXLoraLoaderMixin, - CogView4LoraLoaderMixin, FluxLoraLoaderMixin, HunyuanVideoLoraLoaderMixin, LoraLoaderMixin, diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 21a1a70ff79b..ac0a3c635332 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -804,7 +804,9 @@ def load_ip_adapter( } self.register_modules( - feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs), + feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs).to( + self.device, dtype=self.dtype + ), image_encoder=SiglipVisionModel.from_pretrained( image_encoder_subfolder, torch_dtype=self.dtype, **kwargs ).to(self.device), diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 17ed8c5444fc..50b6448ecdca 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -339,97 +339,93 @@ def _load_lora_into_text_encoder( # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as # their prefixes. + keys = list(state_dict.keys()) prefix = text_encoder_name if prefix is None else prefix - # Load the layers corresponding to text encoder and make necessary adjustments. - if prefix is not None: - state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} - - if len(state_dict) > 0: - logger.info(f"Loading {prefix}.") - rank = {} - state_dict = convert_state_dict_to_diffusers(state_dict) - - # convert state dict - state_dict = convert_state_dict_to_peft(state_dict) - - for name, _ in text_encoder_attn_modules(text_encoder): - for module in ("out_proj", "q_proj", "k_proj", "v_proj"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in state_dict: - continue - rank[rank_key] = state_dict[rank_key].shape[1] - - for name, _ in text_encoder_mlp_modules(text_encoder): - for module in ("fc1", "fc2"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in state_dict: - continue - rank[rank_key] = state_dict[rank_key].shape[1] - - if network_alphas is not None: - alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix] - network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} - - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False) - - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"]: - if is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<", "0.9.0"): - lora_config_kwargs.pop("use_dora") - - if "lora_bias" in lora_config_kwargs: - if lora_config_kwargs["lora_bias"]: - if is_peft_version("<=", "0.13.2"): - raise ValueError( - "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<=", "0.13.2"): - lora_config_kwargs.pop("lora_bias") + # Safe prefix to check with. + if any(text_encoder_name in key for key in keys): + # Load the layers corresponding to text encoder and make necessary adjustments. + text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] + text_encoder_lora_state_dict = { + k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys + } - lora_config = LoraConfig(**lora_config_kwargs) + if len(text_encoder_lora_state_dict) > 0: + logger.info(f"Loading {prefix}.") + rank = {} + text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) + + # convert state dict + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) + + for name, _ in text_encoder_attn_modules(text_encoder): + for module in ("out_proj", "q_proj", "k_proj", "v_proj"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + for name, _ in text_encoder_mlp_modules(text_encoder): + for module in ("fc1", "fc2"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + if network_alphas is not None: + alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix] + network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} + + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<", "0.9.0"): + lora_config_kwargs.pop("use_dora") + + if "lora_bias" in lora_config_kwargs: + if lora_config_kwargs["lora_bias"]: + if is_peft_version("<=", "0.13.2"): + raise ValueError( + "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<=", "0.13.2"): + lora_config_kwargs.pop("lora_bias") - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(text_encoder) + lora_config = LoraConfig(**lora_config_kwargs) - is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline) + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(text_encoder) - # inject LoRA layers and load the state dict - # in transformers we automatically check whether the adapter name is already in use or not - text_encoder.load_adapter( - adapter_name=adapter_name, - adapter_state_dict=state_dict, - peft_config=lora_config, - **peft_kwargs, - ) + is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline) - # scale LoRA layers with `lora_scale` - scale_lora_layers(text_encoder, weight=lora_scale) - - text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) - - # Offload back. - if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> - - if prefix is not None and not state_dict: - logger.warning( - f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. " - "This is safe to ignore if LoRA state dict didn't originally have any " - f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` " - "to resolve the warning. Otherwise, open an issue if you think it's unexpected: " - "https://github.com/huggingface/diffusers/issues/new" - ) + # inject LoRA layers and load the state dict + # in transformers we automatically check whether the adapter name is already in use or not + text_encoder.load_adapter( + adapter_name=adapter_name, + adapter_state_dict=text_encoder_lora_state_dict, + peft_config=lora_config, + **peft_kwargs, + ) + + # scale LoRA layers with `lora_scale` + scale_lora_layers(text_encoder, weight=lora_scale) + + text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> def _func_optionally_disable_offloading(_pipeline): diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 20fcb61f3b80..4be6971755d2 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -1348,56 +1348,3 @@ def process_block(prefix, index, convert_norm): converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) return converted_state_dict - - -def _convert_non_diffusers_wan_lora_to_diffusers(state_dict): - converted_state_dict = {} - original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()} - - num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict}) - is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict) - - for i in range(num_blocks): - # Self-attention - for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]): - converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop( - f"blocks.{i}.self_attn.{o}.lora_A.weight" - ) - converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop( - f"blocks.{i}.self_attn.{o}.lora_B.weight" - ) - - # Cross-attention - for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]): - converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.lora_A.weight" - ) - converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.lora_B.weight" - ) - - if is_i2v_lora: - for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]): - converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.lora_A.weight" - ) - converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop( - f"blocks.{i}.cross_attn.{o}.lora_B.weight" - ) - - # FFN - for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]): - converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop( - f"blocks.{i}.{o}.lora_A.weight" - ) - converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop( - f"blocks.{i}.{o}.lora_B.weight" - ) - - if len(original_state_dict) > 0: - raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}") - - for key in list(converted_state_dict.keys()): - converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) - - return converted_state_dict diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 160793ba1b58..e48725b01ca2 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -42,7 +42,6 @@ _convert_kohya_flux_lora_to_diffusers, _convert_non_diffusers_lora_to_diffusers, _convert_non_diffusers_lumina2_lora_to_diffusers, - _convert_non_diffusers_wan_lora_to_diffusers, _convert_xlabs_flux_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers, ) @@ -299,15 +298,19 @@ def load_lora_into_unet( # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as # their prefixes. - logger.info(f"Loading {cls.unet_name}.") - unet.load_lora_adapter( - state_dict, - prefix=cls.unet_name, - network_alphas=network_alphas, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + keys = list(state_dict.keys()) + only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys) + if not only_text_encoder: + # Load the layers corresponding to UNet. + logger.info(f"Loading {cls.unet_name}.") + unet.load_lora_adapter( + state_dict, + prefix=cls.unet_name, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod def load_lora_into_text_encoder( @@ -452,11 +455,7 @@ def fuse_lora( ``` """ super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs): @@ -477,7 +476,7 @@ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the LoRA parameters then it won't have any effect. """ - super().unfuse_lora(components=components, **kwargs) + super().unfuse_lora(components=components) class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin): @@ -560,26 +559,31 @@ def load_lora_weights( _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, ) - self.load_lora_into_text_encoder( - state_dict, - network_alphas=network_alphas, - text_encoder=self.text_encoder, - prefix=self.text_encoder_name, - lora_scale=self.lora_scale, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - self.load_lora_into_text_encoder( - state_dict, - network_alphas=network_alphas, - text_encoder=self.text_encoder_2, - prefix=f"{self.text_encoder_name}_2", - lora_scale=self.lora_scale, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} + if len(text_encoder_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder, + prefix="text_encoder", + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} + if len(text_encoder_2_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_2_state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder_2, + prefix="text_encoder_2", + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod @validate_hf_hub_args @@ -734,15 +738,19 @@ def load_lora_into_unet( # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as # their prefixes. - logger.info(f"Loading {cls.unet_name}.") - unet.load_lora_adapter( - state_dict, - prefix=cls.unet_name, - network_alphas=network_alphas, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + keys = list(state_dict.keys()) + only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys) + if not only_text_encoder: + # Load the layers corresponding to UNet. + logger.info(f"Loading {cls.unet_name}.") + unet.load_lora_adapter( + state_dict, + prefix=cls.unet_name, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder @@ -896,11 +904,7 @@ def fuse_lora( ``` """ super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs): @@ -921,7 +925,7 @@ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_enc Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the LoRA parameters then it won't have any effect. """ - super().unfuse_lora(components=components, **kwargs) + super().unfuse_lora(components=components) class SD3LoraLoaderMixin(LoraBaseMixin): @@ -1081,33 +1085,43 @@ def load_lora_weights( if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - self.load_lora_into_text_encoder( - state_dict, - network_alphas=None, - text_encoder=self.text_encoder, - prefix=self.text_encoder_name, - lora_scale=self.lora_scale, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - self.load_lora_into_text_encoder( - state_dict, - network_alphas=None, - text_encoder=self.text_encoder_2, - prefix=f"{self.text_encoder_name}_2", - lora_scale=self.lora_scale, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k} + if len(transformer_state_dict) > 0: + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) + if not hasattr(self, "transformer") + else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} + if len(text_encoder_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_state_dict, + network_alphas=None, + text_encoder=self.text_encoder, + prefix="text_encoder", + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} + if len(text_encoder_2_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_2_state_dict, + network_alphas=None, + text_encoder=self.text_encoder_2, + prefix="text_encoder_2", + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod def load_lora_into_transformer( @@ -1299,11 +1313,7 @@ def fuse_lora( ``` """ super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer @@ -1325,7 +1335,7 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the LoRA parameters then it won't have any effect. """ - super().unfuse_lora(components=components, **kwargs) + super().unfuse_lora(components=components) class FluxLoraLoaderMixin(LoraBaseMixin): @@ -1531,23 +1541,18 @@ def load_lora_weights( raise ValueError("Invalid LoRA checkpoint.") transformer_lora_state_dict = { - k: state_dict.get(k) - for k in list(state_dict.keys()) - if k.startswith(f"{self.transformer_name}.") and "lora" in k + k: state_dict.pop(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k } transformer_norm_state_dict = { k: state_dict.pop(k) for k in list(state_dict.keys()) - if k.startswith(f"{self.transformer_name}.") - and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys) + if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys) } transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer - has_param_with_expanded_shape = False - if len(transformer_lora_state_dict) > 0: - has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_( - transformer, transformer_lora_state_dict, transformer_norm_state_dict - ) + has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_( + transformer, transformer_lora_state_dict, transformer_norm_state_dict + ) if has_param_with_expanded_shape: logger.info( @@ -1555,21 +1560,19 @@ def load_lora_weights( "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. " "To get a comprehensive list of parameter names that were modified, enable debug logging." ) + transformer_lora_state_dict = self._maybe_expand_lora_state_dict( + transformer=transformer, lora_state_dict=transformer_lora_state_dict + ) + if len(transformer_lora_state_dict) > 0: - transformer_lora_state_dict = self._maybe_expand_lora_state_dict( - transformer=transformer, lora_state_dict=transformer_lora_state_dict + self.load_lora_into_transformer( + transformer_lora_state_dict, + network_alphas=network_alphas, + transformer=transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) - for k in transformer_lora_state_dict: - state_dict.update({k: transformer_lora_state_dict[k]}) - - self.load_lora_into_transformer( - state_dict, - network_alphas=network_alphas, - transformer=transformer, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - ) if len(transformer_norm_state_dict) > 0: transformer._transformer_norm_layers = self._load_norm_into_transformer( @@ -1578,16 +1581,18 @@ def load_lora_weights( discard_original_layers=False, ) - self.load_lora_into_text_encoder( - state_dict, - network_alphas=network_alphas, - text_encoder=self.text_encoder, - prefix=self.text_encoder_name, - lora_scale=self.lora_scale, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} + if len(text_encoder_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder, + prefix="text_encoder", + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod def load_lora_into_transformer( @@ -1620,14 +1625,17 @@ def load_lora_into_transformer( ) # Load the layers corresponding to transformer. - logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=network_alphas, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + keys = list(state_dict.keys()) + transformer_present = any(key.startswith(cls.transformer_name) for key in keys) + if transformer_present: + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod def _load_norm_into_transformer( @@ -1841,11 +1849,7 @@ def fuse_lora( ) super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): @@ -1866,7 +1870,7 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], * if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers: transformer.load_state_dict(transformer._transformer_norm_layers, strict=False) - super().unfuse_lora(components=components, **kwargs) + super().unfuse_lora(components=components) # We override this here account for `_transformer_norm_layers` and `_overwritten_params`. def unload_lora_weights(self, reset_to_overwritten_params=False): @@ -2170,14 +2174,17 @@ def load_lora_into_transformer( ) # Load the layers corresponding to transformer. - logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=network_alphas, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + keys = list(state_dict.keys()) + transformer_present = any(key.startswith(cls.transformer_name) for key in keys) + if transformer_present: + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder @@ -2565,11 +2572,7 @@ def fuse_lora( ``` """ super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): @@ -2587,7 +2590,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. """ - super().unfuse_lora(components=components, **kwargs) + super().unfuse_lora(components=components) class Mochi1LoraLoaderMixin(LoraBaseMixin): @@ -2873,11 +2876,7 @@ def fuse_lora( ``` """ super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora @@ -2896,7 +2895,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. """ - super().unfuse_lora(components=components, **kwargs) + super().unfuse_lora(components=components) class LTXVideoLoraLoaderMixin(LoraBaseMixin): @@ -3182,11 +3181,7 @@ def fuse_lora( ``` """ super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora @@ -3205,7 +3200,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. """ - super().unfuse_lora(components=components, **kwargs) + super().unfuse_lora(components=components) class SanaLoraLoaderMixin(LoraBaseMixin): @@ -3491,11 +3486,7 @@ def fuse_lora( ``` """ super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora @@ -3514,7 +3505,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. """ - super().unfuse_lora(components=components, **kwargs) + super().unfuse_lora(components=components) class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): @@ -3803,11 +3794,7 @@ def fuse_lora( ``` """ super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora @@ -3826,7 +3813,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. """ - super().unfuse_lora(components=components, **kwargs) + super().unfuse_lora(components=components) class Lumina2LoraLoaderMixin(LoraBaseMixin): @@ -4116,11 +4103,7 @@ def fuse_lora( ``` """ super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora @@ -4139,7 +4122,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. """ - super().unfuse_lora(components=components, **kwargs) + super().unfuse_lora(components=components) class WanLoraLoaderMixin(LoraBaseMixin): @@ -4152,6 +4135,7 @@ class WanLoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], @@ -4238,8 +4222,6 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) - if any(k.startswith("diffusion_model.") for k in state_dict): - state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict) is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: @@ -4426,320 +4408,7 @@ def fuse_lora( ``` """ super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, - ) - - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora - def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): - r""" - Reverses the effect of - [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). - - - - This is an experimental API. - - - - Args: - components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. - unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. - """ - super().unfuse_lora(components=components, **kwargs) - - -class CogView4LoraLoaderMixin(LoraBaseMixin): - r""" - Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`CogView4Pipeline`]. - """ - - _lora_loadable_modules = ["transformer"] - transformer_name = TRANSFORMER_NAME - - @classmethod - @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict - def lora_state_dict( - cls, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - **kwargs, - ): - r""" - Return state dict for lora weights and the network alphas. - - - - We support loading A1111 formatted LoRA checkpoints in a limited capacity. - - This function is experimental and might change in the future. - - - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`ModelMixin.save_pretrained`]. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository on the Hub or locally. - - """ - # Load the main state dict first which has the LoRA layers for either of - # transformer and text encoder or both. - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - state_dict = _fetch_state_dict( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - use_safetensors=use_safetensors, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - allow_pickle=allow_pickle, - ) - - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - if is_dora_scale_present: - warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." - logger.warning(warn_msg) - state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - - return state_dict - - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights - def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs - ): - """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and - `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See - [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state - dict is loaded into `self.transformer`. - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # if a dict is passed, copy it instead of modifying it inplace - if isinstance(pretrained_model_name_or_path_or_dict, dict): - pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() - - # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - - is_correct_format = all("lora" in key for key in state_dict.keys()) - if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") - - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel - def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False - ): - """ - This will load the LoRA layers specified in `state_dict` into `transformer`. - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The keys can either be indexed directly - into the unet or prefixed with an additional `unet` which can be used to distinguish between text - encoder lora layers. - transformer (`CogView4Transformer2DModel`): - The Transformer model to load the LoRA layers into. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - """ - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # Load the layers corresponding to transformer. - logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=None, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights - def save_lora_weights( - cls, - save_directory: Union[str, os.PathLike], - transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - is_main_process: bool = True, - weight_name: str = None, - save_function: Callable = None, - safe_serialization: bool = True, - ): - r""" - Save the LoRA parameters corresponding to the UNet and text encoder. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to save LoRA parameters to. Will be created if it doesn't exist. - transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `transformer`. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful during distributed training and you - need to call this function on all processes. In this case, set `is_main_process=True` only on the main - process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful during distributed training when you need to - replace `torch.save` with another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - """ - state_dict = {} - - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") - - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, - save_directory=save_directory, - is_main_process=is_main_process, - weight_name=weight_name, - save_function=save_function, - safe_serialization=safe_serialization, - ) - - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora - def fuse_lora( - self, - components: List[str] = ["transformer"], - lora_scale: float = 1.0, - safe_fusing: bool = False, - adapter_names: Optional[List[str]] = None, - **kwargs, - ): - r""" - Fuses the LoRA parameters into the original parameters of the corresponding blocks. - - - - This is an experimental API. - - - - Args: - components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. - lora_scale (`float`, defaults to 1.0): - Controls how much to influence the outputs with the LoRA parameters. - safe_fusing (`bool`, defaults to `False`): - Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. - adapter_names (`List[str]`, *optional*): - Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. - - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.fuse_lora(lora_scale=0.7) - ``` - """ - super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora @@ -4758,7 +4427,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. """ - super().unfuse_lora(components=components, **kwargs) + super().unfuse_lora(components=components) class LoraLoaderMixin(StableDiffusionLoraLoaderMixin): diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 74e51445cc1e..aaa2fd4108b1 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -54,7 +54,6 @@ "SanaTransformer2DModel": lambda model_cls, weights: weights, "Lumina2Transformer2DModel": lambda model_cls, weights: weights, "WanTransformer3DModel": lambda model_cls, weights: weights, - "CogView4Transformer2DModel": lambda model_cls, weights: weights, } @@ -236,7 +235,10 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans raise ValueError("`network_alphas` cannot be None when `prefix` is None.") if prefix is not None: - state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} + keys = list(state_dict.keys()) + model_keys = [k for k in keys if k.startswith(f"{prefix}.")] + if len(model_keys) > 0: + state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys} if len(state_dict) > 0: if adapter_name in getattr(self, "peft_config", {}): @@ -353,15 +355,6 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans _pipeline.enable_sequential_cpu_offload() # Unsafe code /> - if prefix is not None and not state_dict: - logger.warning( - f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. " - "This is safe to ignore if LoRA state dict didn't originally have any " - f"{self.__class__.__name__} related params. You can also try specifying `prefix=None` " - "to resolve the warning. Otherwise, open an issue if you think it's unexpected: " - "https://github.com/huggingface/diffusers/issues/new" - ) - def save_lora_adapter( self, save_directory, diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index f72a0dd369f2..b7d61b3e8ff4 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -37,7 +37,6 @@ convert_ltx_vae_checkpoint_to_diffusers, convert_lumina2_to_diffusers, convert_mochi_transformer_checkpoint_to_diffusers, - convert_sana_transformer_to_diffusers, convert_sd3_transformer_checkpoint_to_diffusers, convert_stable_cascade_unet_single_file_to_diffusers, convert_wan_transformer_to_diffusers, @@ -120,10 +119,6 @@ "checkpoint_mapping_fn": convert_lumina2_to_diffusers, "default_subfolder": "transformer", }, - "SanaTransformer2DModel": { - "checkpoint_mapping_fn": convert_sana_transformer_to_diffusers, - "default_subfolder": "transformer", - }, "WanTransformer3DModel": { "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers, "default_subfolder": "transformer", diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 42aee4a84822..8ee7e14cb101 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -117,12 +117,6 @@ "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias", "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight", "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"], - "sana": [ - "blocks.0.cross_attn.q_linear.weight", - "blocks.0.cross_attn.q_linear.bias", - "blocks.0.cross_attn.kv_linear.weight", - "blocks.0.cross_attn.kv_linear.bias", - ], "wan": ["model.diffusion_model.head.modulation", "head.modulation"], "wan_vae": "decoder.middle.0.residual.0.gamma", } @@ -184,7 +178,6 @@ "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"}, "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"}, "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"}, - "sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"}, "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"}, "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"}, "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"}, @@ -676,9 +669,6 @@ def infer_diffusers_model_type(checkpoint): elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]): model_type = "lumina2" - elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sana"]): - model_type = "sana" - elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]): if "model.diffusion_model.patch_embedding.weight" in checkpoint: target_key = "model.diffusion_model.patch_embedding.weight" @@ -2907,111 +2897,6 @@ def convert_lumina_attn_to_diffusers(tensor, diffusers_key): return converted_state_dict -def convert_sana_transformer_to_diffusers(checkpoint, **kwargs): - converted_state_dict = {} - keys = list(checkpoint.keys()) - for k in keys: - if "model.diffusion_model." in k: - checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) - - num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1 # noqa: C401 - - # Positional and patch embeddings. - checkpoint.pop("pos_embed") - converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight") - converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias") - - # Timestep embeddings. - converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop( - "t_embedder.mlp.0.weight" - ) - converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias") - converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop( - "t_embedder.mlp.2.weight" - ) - converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") - converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight") - converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias") - - # Caption Projection. - checkpoint.pop("y_embedder.y_embedding") - converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight") - converted_state_dict["caption_projection.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias") - converted_state_dict["caption_projection.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight") - converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias") - converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight") - - for i in range(num_layers): - converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop( - f"blocks.{i}.scale_shift_table" - ) - - # Self-Attention - sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0) - converted_state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = torch.cat([sample_q]) - converted_state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = torch.cat([sample_k]) - converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v]) - - # Output Projections - converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop( - f"blocks.{i}.attn.proj.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop( - f"blocks.{i}.attn.proj.bias" - ) - - # Cross-Attention - converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop( - f"blocks.{i}.cross_attn.q_linear.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop( - f"blocks.{i}.cross_attn.q_linear.bias" - ) - - linear_sample_k, linear_sample_v = torch.chunk( - checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0 - ) - linear_sample_k_bias, linear_sample_v_bias = torch.chunk( - checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0 - ) - converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k - converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v - converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias - converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias - - # Output Projections - converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop( - f"blocks.{i}.cross_attn.proj.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop( - f"blocks.{i}.cross_attn.proj.bias" - ) - - # MLP - converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop( - f"blocks.{i}.mlp.inverted_conv.conv.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop( - f"blocks.{i}.mlp.inverted_conv.conv.bias" - ) - converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop( - f"blocks.{i}.mlp.depth_conv.conv.weight" - ) - converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop( - f"blocks.{i}.mlp.depth_conv.conv.bias" - ) - converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop( - f"blocks.{i}.mlp.point_conv.conv.weight" - ) - - # Final layer - converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") - converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") - converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table") - - return converted_state_dict - - def convert_wan_transformer_to_diffusers(checkpoint, **kwargs): converted_state_dict = {} diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 21d17d6acdab..b45cb2a7950d 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -741,14 +741,10 @@ def prepare_attention_mask( if out_dim == 3: if attention_mask.shape[0] < batch_size * head_size: - attention_mask = attention_mask.repeat_interleave( - head_size, dim=0, output_size=attention_mask.shape[0] * head_size - ) + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) elif out_dim == 4: attention_mask = attention_mask.unsqueeze(1) - attention_mask = attention_mask.repeat_interleave( - head_size, dim=1, output_size=attention_mask.shape[1] * head_size - ) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) return attention_mask @@ -2339,9 +2335,7 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -3710,10 +3704,8 @@ def __call__( if kv_heads != attn.heads: # if GQA or MQA, repeat the key/value heads to reach the number of query heads. heads_per_kv_head = attn.heads // kv_heads - key = torch.repeat_interleave(key, heads_per_kv_head, dim=1, output_size=key.shape[1] * heads_per_kv_head) - value = torch.repeat_interleave( - value, heads_per_kv_head, dim=1, output_size=value.shape[1] * heads_per_kv_head - ) + key = torch.repeat_interleave(key, heads_per_kv_head, dim=1) + value = torch.repeat_interleave(value, heads_per_kv_head, dim=1) if attn.norm_q is not None: query = attn.norm_q(query) diff --git a/src/diffusers/models/autoencoders/autoencoder_dc.py b/src/diffusers/models/autoencoders/autoencoder_dc.py index 9146aa5c7c6c..1e6a26dddca8 100644 --- a/src/diffusers/models/autoencoders/autoencoder_dc.py +++ b/src/diffusers/models/autoencoders/autoencoder_dc.py @@ -190,7 +190,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: x = F.pixel_shuffle(x, self.factor) if self.shortcut: - y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats) + y = hidden_states.repeat_interleave(self.repeats, dim=1) y = F.pixel_shuffle(y, self.factor) hidden_states = x + y else: @@ -361,9 +361,7 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.in_shortcut: - x = hidden_states.repeat_interleave( - self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats - ) + x = hidden_states.repeat_interleave(self.in_shortcut_repeats, dim=1) hidden_states = self.conv_in(hidden_states) + x else: hidden_states = self.conv_in(hidden_states) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py index a76277366c09..f79aabe91dd3 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py @@ -103,7 +103,7 @@ def forward(self, hidden_states: torch.Tensor, batch_size: int) -> torch.Tensor: if self.down_sample: identity = hidden_states[:, :, ::2] elif self.up_sample: - identity = hidden_states.repeat_interleave(2, dim=2, output_size=hidden_states.shape[2] * 2) + identity = hidden_states.repeat_interleave(2, dim=2) else: identity = hidden_states diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index 2b2f77a5509d..75709ca10dfe 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -196,55 +196,6 @@ def forward( return hidden_states -class LTXVideoDownsampler3d(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - stride: Union[int, Tuple[int, int, int]] = 1, - is_causal: bool = True, - padding_mode: str = "zeros", - ) -> None: - super().__init__() - - self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) - self.group_size = (in_channels * stride[0] * stride[1] * stride[2]) // out_channels - - out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2]) - - self.conv = LTXVideoCausalConv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - stride=1, - is_causal=is_causal, - padding_mode=padding_mode, - ) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2) - - residual = ( - hidden_states.unflatten(4, (-1, self.stride[2])) - .unflatten(3, (-1, self.stride[1])) - .unflatten(2, (-1, self.stride[0])) - ) - residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) - residual = residual.unflatten(1, (-1, self.group_size)) - residual = residual.mean(dim=2) - - hidden_states = self.conv(hidden_states) - hidden_states = ( - hidden_states.unflatten(4, (-1, self.stride[2])) - .unflatten(3, (-1, self.stride[1])) - .unflatten(2, (-1, self.stride[0])) - ) - hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) - hidden_states = hidden_states + residual - - return hidden_states - - class LTXVideoUpsampler3d(nn.Module): def __init__( self, @@ -253,7 +204,6 @@ def __init__( is_causal: bool = True, residual: bool = False, upscale_factor: int = 1, - padding_mode: str = "zeros", ) -> None: super().__init__() @@ -269,7 +219,6 @@ def __init__( kernel_size=3, stride=1, is_causal=is_causal, - padding_mode=padding_mode, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -403,118 +352,6 @@ def forward( return hidden_states -class LTXVideo095DownBlock3D(nn.Module): - r""" - Down block used in the LTXVideo model. - - Args: - in_channels (`int`): - Number of input channels. - out_channels (`int`, *optional*): - Number of output channels. If None, defaults to `in_channels`. - num_layers (`int`, defaults to `1`): - Number of resnet layers. - dropout (`float`, defaults to `0.0`): - Dropout rate. - resnet_eps (`float`, defaults to `1e-6`): - Epsilon value for normalization layers. - resnet_act_fn (`str`, defaults to `"swish"`): - Activation function to use. - spatio_temporal_scale (`bool`, defaults to `True`): - Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. - Whether or not to downsample across temporal dimension. - is_causal (`bool`, defaults to `True`): - Whether this layer behaves causally (future frames depend only on past frames) or not. - """ - - _supports_gradient_checkpointing = True - - def __init__( - self, - in_channels: int, - out_channels: Optional[int] = None, - num_layers: int = 1, - dropout: float = 0.0, - resnet_eps: float = 1e-6, - resnet_act_fn: str = "swish", - spatio_temporal_scale: bool = True, - is_causal: bool = True, - downsample_type: str = "conv", - ): - super().__init__() - - out_channels = out_channels or in_channels - - resnets = [] - for _ in range(num_layers): - resnets.append( - LTXVideoResnetBlock3d( - in_channels=in_channels, - out_channels=in_channels, - dropout=dropout, - eps=resnet_eps, - non_linearity=resnet_act_fn, - is_causal=is_causal, - ) - ) - self.resnets = nn.ModuleList(resnets) - - self.downsamplers = None - if spatio_temporal_scale: - self.downsamplers = nn.ModuleList() - - if downsample_type == "conv": - self.downsamplers.append( - LTXVideoCausalConv3d( - in_channels=in_channels, - out_channels=in_channels, - kernel_size=3, - stride=(2, 2, 2), - is_causal=is_causal, - ) - ) - elif downsample_type == "spatial": - self.downsamplers.append( - LTXVideoDownsampler3d( - in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal - ) - ) - elif downsample_type == "temporal": - self.downsamplers.append( - LTXVideoDownsampler3d( - in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), is_causal=is_causal - ) - ) - elif downsample_type == "spatiotemporal": - self.downsamplers.append( - LTXVideoDownsampler3d( - in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal - ) - ) - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.Tensor, - temb: Optional[torch.Tensor] = None, - generator: Optional[torch.Generator] = None, - ) -> torch.Tensor: - r"""Forward method of the `LTXDownBlock3D` class.""" - - for i, resnet in enumerate(self.resnets): - if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) - else: - hidden_states = resnet(hidden_states, temb, generator) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - return hidden_states - - # Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d class LTXVideoMidBlock3d(nn.Module): r""" @@ -756,15 +593,8 @@ def __init__( in_channels: int = 3, out_channels: int = 128, block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), - down_block_types: Tuple[str, ...] = ( - "LTXVideoDownBlock3D", - "LTXVideoDownBlock3D", - "LTXVideoDownBlock3D", - "LTXVideoDownBlock3D", - ), spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False), layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4), - downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"), patch_size: int = 4, patch_size_t: int = 1, resnet_norm_eps: float = 1e-6, @@ -787,37 +617,20 @@ def __init__( ) # down blocks - is_ltx_095 = down_block_types[-1] == "LTXVideo095DownBlock3D" - num_block_out_channels = len(block_out_channels) - (1 if is_ltx_095 else 0) + num_block_out_channels = len(block_out_channels) self.down_blocks = nn.ModuleList([]) for i in range(num_block_out_channels): input_channel = output_channel - if not is_ltx_095: - output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i] - else: - output_channel = block_out_channels[i + 1] - - if down_block_types[i] == "LTXVideoDownBlock3D": - down_block = LTXVideoDownBlock3D( - in_channels=input_channel, - out_channels=output_channel, - num_layers=layers_per_block[i], - resnet_eps=resnet_norm_eps, - spatio_temporal_scale=spatio_temporal_scaling[i], - is_causal=is_causal, - ) - elif down_block_types[i] == "LTXVideo095DownBlock3D": - down_block = LTXVideo095DownBlock3D( - in_channels=input_channel, - out_channels=output_channel, - num_layers=layers_per_block[i], - resnet_eps=resnet_norm_eps, - spatio_temporal_scale=spatio_temporal_scaling[i], - is_causal=is_causal, - downsample_type=downsample_type[i], - ) - else: - raise ValueError(f"Unknown down block type: {down_block_types[i]}") + output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i] + + down_block = LTXVideoDownBlock3D( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + is_causal=is_causal, + ) self.down_blocks.append(down_block) @@ -981,9 +794,7 @@ def __init__( # timestep embedding self.time_embedder = None self.scale_shift_table = None - self.timestep_scale_multiplier = None if timestep_conditioning: - self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32)) self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0) self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5) @@ -992,9 +803,6 @@ def __init__( def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: hidden_states = self.conv_in(hidden_states) - if self.timestep_scale_multiplier is not None: - temb = temb * self.timestep_scale_multiplier - if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb) @@ -1083,19 +891,12 @@ def __init__( out_channels: int = 3, latent_channels: int = 128, block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), - down_block_types: Tuple[str, ...] = ( - "LTXVideoDownBlock3D", - "LTXVideoDownBlock3D", - "LTXVideoDownBlock3D", - "LTXVideoDownBlock3D", - ), decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4), decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4), spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False), decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False), decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False), - downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"), upsample_residual: Tuple[bool, ...] = (False, False, False, False), upsample_factor: Tuple[int, ...] = (1, 1, 1, 1), timestep_conditioning: bool = False, @@ -1105,8 +906,6 @@ def __init__( scaling_factor: float = 1.0, encoder_causal: bool = True, decoder_causal: bool = False, - spatial_compression_ratio: int = None, - temporal_compression_ratio: int = None, ) -> None: super().__init__() @@ -1114,10 +913,8 @@ def __init__( in_channels=in_channels, out_channels=latent_channels, block_out_channels=block_out_channels, - down_block_types=down_block_types, spatio_temporal_scaling=spatio_temporal_scaling, layers_per_block=layers_per_block, - downsample_type=downsample_type, patch_size=patch_size, patch_size_t=patch_size_t, resnet_norm_eps=resnet_norm_eps, @@ -1144,16 +941,8 @@ def __init__( self.register_buffer("latents_mean", latents_mean, persistent=True) self.register_buffer("latents_std", latents_std, persistent=True) - self.spatial_compression_ratio = ( - patch_size * 2 ** sum(spatio_temporal_scaling) - if spatial_compression_ratio is None - else spatial_compression_ratio - ) - self.temporal_compression_ratio = ( - patch_size_t * 2 ** sum(spatio_temporal_scaling) - if temporal_compression_ratio is None - else temporal_compression_ratio - ) + self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling) + self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling) # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension # to perform decoding of a single video latent at a time. diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py index d69ec6252b00..cd3eff73ed64 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py @@ -426,9 +426,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: w = w.repeat(num_channels)[None, :, None, None, None] # [1, num_channels * num_freqs, 1, 1, 1] # Interleaved repeat of input channels to match w - h = inputs.repeat_interleave( - num_freqs, dim=1, output_size=inputs.shape[1] * num_freqs - ) # [B, C * num_freqs, T, H, W] + h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W] # Scale channels by frequency. h = w * h diff --git a/src/diffusers/models/controlnets/controlnet_sparsectrl.py b/src/diffusers/models/controlnets/controlnet_sparsectrl.py index 25348ce606d6..4edc91cacaa7 100644 --- a/src/diffusers/models/controlnets/controlnet_sparsectrl.py +++ b/src/diffusers/models/controlnets/controlnet_sparsectrl.py @@ -687,7 +687,7 @@ def forward( t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) - emb = emb.repeat_interleave(sample_num_frames, dim=0, output_size=emb.shape[0] * sample_num_frames) + emb = emb.repeat_interleave(sample_num_frames, dim=0) # 2. pre-process batch_size, channels, num_frames, height, width = sample.shape diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 006ea8b4013f..04a0b273f1fa 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -139,9 +139,7 @@ def get_3d_sincos_pos_embed( # 3. Concat pos_embed_spatial = pos_embed_spatial[None, :, :] - pos_embed_spatial = pos_embed_spatial.repeat_interleave( - temporal_size, dim=0, output_size=pos_embed_spatial.shape[0] * temporal_size - ) # [T, H*W, D // 4 * 3] + pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0) # [T, H*W, D // 4 * 3] pos_embed_temporal = pos_embed_temporal[:, None, :] pos_embed_temporal = pos_embed_temporal.repeat_interleave( @@ -1154,13 +1152,10 @@ def get_1d_rotary_pos_embed( / linear_factor ) # [D/2] freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] - is_npu = freqs.device.type == "npu" - if is_npu: - freqs = freqs.float() if use_real and repeat_interleave_real: # flux, hunyuan-dit, cogvideox - freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D] - freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D] + freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] return freqs_cos, freqs_sin elif use_real: # stable audio, allegro diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 741f7075d76d..f019a3cc67a6 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -245,9 +245,6 @@ def load_model_dict_into_meta( ): param = param.to(torch.float32) set_module_kwargs["dtype"] = torch.float32 - # For quantizers have save weights using torch.float8_e4m3fn - elif hf_quantizer is not None and param.dtype == getattr(torch, "float8_e4m3fn", None): - pass else: param = param.to(dtype) set_module_kwargs["dtype"] = dtype @@ -295,9 +292,7 @@ def load_model_dict_into_meta( elif is_quantized and ( hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device) ): - hf_quantizer.create_quantized_param( - model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype - ) + hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys) else: set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 260b4b8929b0..00b55cd9c9d6 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -366,7 +366,7 @@ def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwarg hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor.contiguous()) + input_tensor = self.conv_shortcut(input_tensor) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index 4b359021f29d..4fe1d99cb6ee 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -227,17 +227,13 @@ def forward( # Prepare text embeddings for spatial block # batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size encoder_hidden_states = self.caption_projection(encoder_hidden_states) # 3 120 1152 - encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave( - num_frame, dim=0, output_size=encoder_hidden_states.shape[0] * num_frame - ).view(-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]) + encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view( + -1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1] + ) # Prepare timesteps for spatial and temporal block - timestep_spatial = timestep.repeat_interleave( - num_frame, dim=0, output_size=timestep.shape[0] * num_frame - ).view(-1, timestep.shape[-1]) - timestep_temp = timestep.repeat_interleave( - num_patches, dim=0, output_size=timestep.shape[0] * num_patches - ).view(-1, timestep.shape[-1]) + timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1]) + timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1]) # Spatial and temporal transformer blocks for i, (spatial_block, temp_block) in enumerate( @@ -303,9 +299,7 @@ def forward( ).permute(0, 2, 1, 3) hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1]) - embedded_timestep = embedded_timestep.repeat_interleave( - num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame - ).view(-1, embedded_timestep.shape[-1]) + embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1]) shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) hidden_states = self.norm_out(hidden_states) # Modulation diff --git a/src/diffusers/models/transformers/prior_transformer.py b/src/diffusers/models/transformers/prior_transformer.py index 24d4e4d3d76f..fdb67384ff5e 100644 --- a/src/diffusers/models/transformers/prior_transformer.py +++ b/src/diffusers/models/transformers/prior_transformer.py @@ -353,11 +353,7 @@ def forward( attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0) attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype) - attention_mask = attention_mask.repeat_interleave( - self.config.num_attention_heads, - dim=0, - output_size=attention_mask.shape[0] * self.config.num_attention_heads, - ) + attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0) if self.norm_in is not None: hidden_states = self.norm_in(hidden_states) diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index b8cc96d3532c..cface676b409 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -18,7 +18,7 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...loaders import PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention_processor import ( Attention, @@ -195,7 +195,7 @@ def forward( return hidden_states -class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): r""" A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models. diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index 41c4cbbf97c7..db261ca1ea4b 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -12,22 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers -from ..attention import FeedForward -from ..attention_processor import Attention -from ..cache_utils import CacheMixin +from ...models.attention import FeedForward +from ...models.attention_processor import Attention +from ...models.modeling_utils import ModelMixin +from ...models.normalization import AdaLayerNormContinuous +from ...utils import logging from ..embeddings import CogView3CombinedTimestepSizeEmbeddings from ..modeling_outputs import Transformer2DModelOutput -from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormContinuous logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -127,8 +125,7 @@ def __call__( attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: - batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape - batch_size, image_seq_length, embed_dim = hidden_states.shape + text_seq_length = encoder_hidden_states.size(1) hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) # 1. QKV projections @@ -158,15 +155,6 @@ def __call__( ) # 4. Attention - if attention_mask is not None: - text_attention_mask = attention_mask.float().to(query.device) - actual_text_seq_length = text_attention_mask.size(1) - new_attention_mask = torch.zeros((batch_size, text_seq_length + image_seq_length), device=query.device) - new_attention_mask[:, :actual_text_seq_length] = text_attention_mask - new_attention_mask = new_attention_mask.unsqueeze(2) - attention_mask_matrix = new_attention_mask @ new_attention_mask.transpose(1, 2) - attention_mask = (attention_mask_matrix > 0).unsqueeze(1).to(query.dtype) - hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) @@ -214,8 +202,6 @@ def forward( encoder_hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - **kwargs, ) -> torch.Tensor: # 1. Timestep conditioning ( @@ -236,8 +222,6 @@ def forward( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, - attention_mask=attention_mask, - **kwargs, ) hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1) encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1) @@ -304,7 +288,7 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens return (freqs.cos(), freqs.sin()) -class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin): +class CogView4Transformer2DModel(ModelMixin, ConfigMixin): r""" Args: patch_size (`int`, defaults to `2`): @@ -399,26 +383,8 @@ def forward( original_size: torch.Tensor, target_size: torch.Tensor, crop_coords: torch.Tensor, - attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, - attention_mask: Optional[torch.Tensor] = None, - **kwargs, ) -> Union[torch.Tensor, Transformer2DModelOutput]: - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." - ) - batch_size, num_channels, height, width = hidden_states.shape # 1. RoPE @@ -438,11 +404,11 @@ def forward( for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( - block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs + block, hidden_states, encoder_hidden_states, temb, image_rotary_emb ) else: hidden_states, encoder_hidden_states = block( - hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs + hidden_states, encoder_hidden_states, temb, image_rotary_emb ) # 4. Output norm & projection @@ -453,10 +419,6 @@ def forward( hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p) output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index c1f2df587927..f5dc63f49562 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -14,7 +14,7 @@ # limitations under the License. import math -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple import torch import torch.nn as nn @@ -113,19 +113,20 @@ def __init__( self.patch_size_t = patch_size_t self.theta = theta - def _prepare_video_coords( + def forward( self, - batch_size: int, + hidden_states: torch.Tensor, num_frames: int, height: int, width: int, - rope_interpolation_scale: Tuple[torch.Tensor, float, float], - device: torch.device, - ) -> torch.Tensor: + rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = hidden_states.size(0) + # Always compute rope in fp32 - grid_h = torch.arange(height, dtype=torch.float32, device=device) - grid_w = torch.arange(width, dtype=torch.float32, device=device) - grid_f = torch.arange(num_frames, dtype=torch.float32, device=device) + grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device) + grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device) + grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device) grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij") grid = torch.stack(grid, dim=0) grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) @@ -137,38 +138,6 @@ def _prepare_video_coords( grid = grid.flatten(2, 4).transpose(1, 2) - return grid - - def forward( - self, - hidden_states: torch.Tensor, - num_frames: Optional[int] = None, - height: Optional[int] = None, - width: Optional[int] = None, - rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None, - video_coords: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - batch_size = hidden_states.size(0) - - if video_coords is None: - grid = self._prepare_video_coords( - batch_size, - num_frames, - height, - width, - rope_interpolation_scale=rope_interpolation_scale, - device=hidden_states.device, - ) - else: - grid = torch.stack( - [ - video_coords[:, 0] / self.base_num_frames, - video_coords[:, 1] / self.base_height, - video_coords[:, 2] / self.base_width, - ], - dim=-1, - ) - start = 1.0 end = self.theta freqs = self.theta ** torch.linspace( @@ -398,11 +367,10 @@ def forward( encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, encoder_attention_mask: torch.Tensor, - num_frames: Optional[int] = None, - height: Optional[int] = None, - width: Optional[int] = None, - rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None, - video_coords: Optional[torch.Tensor] = None, + num_frames: int, + height: int, + width: int, + rope_interpolation_scale: Optional[Tuple[float, float, float]] = None, attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> torch.Tensor: @@ -421,7 +389,7 @@ def forward( "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." ) - image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale, video_coords) + image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 4eb4add37601..66cdda388c06 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -441,14 +441,6 @@ def forward( # 5. Output norm, projection & unpatchify shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) - - # Move the shift and scale tensors to the same device as hidden_states. - # When using multi-GPU inference via accelerate these will be on the - # first device rather than the last device, which hidden_states ends up - # on. - shift = shift.to(hidden_states.device) - scale = scale.to(hidden_states.device) - hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) hidden_states = self.proj_out(hidden_states) diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py index a148cf6cbe06..845d93b9db09 100644 --- a/src/diffusers/models/unets/unet_3d_condition.py +++ b/src/diffusers/models/unets/unet_3d_condition.py @@ -638,10 +638,8 @@ def forward( t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb, timestep_cond) - emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames) - encoder_hidden_states = encoder_hidden_states.repeat_interleave( - num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames - ) + emb = emb.repeat_interleave(repeats=num_frames, dim=0) + encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) # 2. pre-process sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py index c275e16744f4..f0eca75de169 100644 --- a/src/diffusers/models/unets/unet_i2vgen_xl.py +++ b/src/diffusers/models/unets/unet_i2vgen_xl.py @@ -592,7 +592,7 @@ def forward( # 3. time + FPS embeddings. emb = t_emb + fps_emb - emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames) + emb = emb.repeat_interleave(repeats=num_frames, dim=0) # 4. context embeddings. # The context embeddings consist of both text embeddings from the input prompt @@ -620,7 +620,7 @@ def forward( image_emb = self.context_embedding(image_embeddings) image_emb = image_emb.view(-1, self.config.in_channels, self.config.cross_attention_dim) context_emb = torch.cat([context_emb, image_emb], dim=1) - context_emb = context_emb.repeat_interleave(num_frames, dim=0, output_size=context_emb.shape[0] * num_frames) + context_emb = context_emb.repeat_interleave(repeats=num_frames, dim=0) image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape( image_latents.shape[0] * image_latents.shape[2], diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index bd83024c9b7c..21e4db23a166 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -2059,7 +2059,7 @@ def forward( aug_emb = self.add_embedding(add_embeds) emb = emb if aug_emb is None else emb + aug_emb - emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames) + emb = emb.repeat_interleave(repeats=num_frames, dim=0) if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": if "image_embeds" not in added_cond_kwargs: @@ -2068,10 +2068,7 @@ def forward( ) image_embeds = added_cond_kwargs.get("image_embeds") image_embeds = self.encoder_hid_proj(image_embeds) - image_embeds = [ - image_embed.repeat_interleave(num_frames, dim=0, output_size=image_embed.shape[0] * num_frames) - for image_embed in image_embeds - ] + image_embeds = [image_embed.repeat_interleave(repeats=num_frames, dim=0) for image_embed in image_embeds] encoder_hidden_states = (encoder_hidden_states, image_embeds) # 2. pre-process diff --git a/src/diffusers/models/unets/unet_spatio_temporal_condition.py b/src/diffusers/models/unets/unet_spatio_temporal_condition.py index 059a6e807c8e..db4ace9656a3 100644 --- a/src/diffusers/models/unets/unet_spatio_temporal_condition.py +++ b/src/diffusers/models/unets/unet_spatio_temporal_condition.py @@ -431,11 +431,9 @@ def forward( sample = sample.flatten(0, 1) # Repeat the embeddings num_video_frames times # emb: [batch, channels] -> [batch * frames, channels] - emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames) + emb = emb.repeat_interleave(num_frames, dim=0) # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels] - encoder_hidden_states = encoder_hidden_states.repeat_interleave( - num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames - ) + encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0) # 2. pre-process sample = self.conv_in(sample) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 6b714d31c0e3..8b76e109e754 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -154,7 +154,7 @@ "CogVideoXFunControlPipeline", ] _import_structure["cogview3"] = ["CogView3PlusPipeline"] - _import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"] + _import_structure["cogview4"] = ["CogView4Pipeline"] _import_structure["consisid"] = ["ConsisIDPipeline"] _import_structure["controlnet"].extend( [ @@ -264,9 +264,9 @@ ] ) _import_structure["latte"] = ["LattePipeline"] - _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline", "LTXConditionPipeline"] - _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] - _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] + _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"] + _import_structure["lumina"] = ["LuminaText2ImgPipeline"] + _import_structure["lumina2"] = ["Lumina2Text2ImgPipeline"] _import_structure["marigold"].extend( [ "MarigoldDepthPipeline", @@ -511,7 +511,7 @@ CogVideoXVideoToVideoPipeline, ) from .cogview3 import CogView3PlusPipeline - from .cogview4 import CogView4ControlPipeline, CogView4Pipeline + from .cogview4 import CogView4Pipeline from .consisid import ConsisIDPipeline from .controlnet import ( BlipDiffusionControlNetPipeline, @@ -618,9 +618,9 @@ LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, ) - from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline - from .lumina import LuminaPipeline, LuminaText2ImgPipeline - from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline + from .ltx import LTXImageToVideoPipeline, LTXPipeline + from .lumina import LuminaText2ImgPipeline + from .lumina2 import Lumina2Text2ImgPipeline from .marigold import ( MarigoldDepthPipeline, MarigoldIntrinsicsPipeline, diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 6a5f6098b6fb..4f760ee09add 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -22,7 +22,7 @@ from ..utils import is_sentencepiece_available from .aura_flow import AuraFlowPipeline from .cogview3 import CogView3PlusPipeline -from .cogview4 import CogView4ControlPipeline, CogView4Pipeline +from .cogview4 import CogView4Pipeline from .controlnet import ( StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, @@ -69,8 +69,8 @@ ) from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline -from .lumina import LuminaPipeline -from .lumina2 import Lumina2Pipeline +from .lumina import LuminaText2ImgPipeline +from .lumina2 import Lumina2Text2ImgPipeline from .pag import ( HunyuanDiTPAGPipeline, PixArtSigmaPAGPipeline, @@ -141,11 +141,10 @@ ("flux", FluxPipeline), ("flux-control", FluxControlPipeline), ("flux-controlnet", FluxControlNetPipeline), - ("lumina", LuminaPipeline), - ("lumina2", Lumina2Pipeline), + ("lumina", LuminaText2ImgPipeline), + ("lumina2", Lumina2Text2ImgPipeline), ("cogview3", CogView3PlusPipeline), ("cogview4", CogView4Pipeline), - ("cogview4-control", CogView4ControlPipeline), ] ) diff --git a/src/diffusers/pipelines/cogview4/__init__.py b/src/diffusers/pipelines/cogview4/__init__.py index 6a365e17fee7..5a535b3feb4b 100644 --- a/src/diffusers/pipelines/cogview4/__init__.py +++ b/src/diffusers/pipelines/cogview4/__init__.py @@ -23,7 +23,6 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_cogview4"] = ["CogView4Pipeline"] - _import_structure["pipeline_cogview4_control"] = ["CogView4ControlPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): @@ -32,7 +31,6 @@ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .pipeline_cogview4 import CogView4Pipeline - from .pipeline_cogview4_control import CogView4ControlPipeline else: import sys diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py index c27a1a19774d..6005c419b5c2 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py @@ -14,7 +14,7 @@ # limitations under the License. import inspect -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -22,7 +22,6 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import VaeImageProcessor -from ...loaders import CogView4LoraLoaderMixin from ...models import AutoencoderKL, CogView4Transformer2DModel from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -134,7 +133,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin): +class CogView4Pipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation using CogView4. @@ -389,14 +388,6 @@ def do_classifier_free_guidance(self): def num_timesteps(self): return self._num_timesteps - @property - def attention_kwargs(self): - return self._attention_kwargs - - @property - def current_timestep(self): - return self._current_timestep - @property def interrupt(self): return self._interrupt @@ -422,7 +413,6 @@ def __call__( crops_coords_top_left: Tuple[int, int] = (0, 0), output_type: str = "pil", return_dict: bool = True, - attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -536,8 +526,6 @@ def __call__( negative_prompt_embeds, ) self._guidance_scale = guidance_scale - self._attention_kwargs = attention_kwargs - self._current_timestep = None self._interrupt = False # Default call parameters @@ -615,7 +603,6 @@ def __call__( if self.interrupt: continue - self._current_timestep = t latent_model_input = latents.to(transformer_dtype) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML @@ -628,7 +615,6 @@ def __call__( original_size=original_size, target_size=target_size, crop_coords=crops_coords_top_left, - attention_kwargs=attention_kwargs, return_dict=False, )[0] @@ -641,7 +627,6 @@ def __call__( original_size=original_size, target_size=target_size, crop_coords=crops_coords_top_left, - attention_kwargs=attention_kwargs, return_dict=False, )[0] @@ -667,8 +652,6 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - self._current_timestep = None - if not output_type == "latent": latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor image = self.vae.decode(latents, return_dict=False, generator=generator)[0] diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py deleted file mode 100644 index b22705ed05c9..000000000000 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py +++ /dev/null @@ -1,727 +0,0 @@ -# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -from transformers import AutoTokenizer, GlmModel - -from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...models import AutoencoderKL, CogView4Transformer2DModel -from ...pipelines.pipeline_utils import DiffusionPipeline -from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import is_torch_xla_available, logging, replace_example_docstring -from ...utils.torch_utils import randn_tensor -from .pipeline_output import CogView4PipelineOutput - - -if is_torch_xla_available(): - import torch_xla.core.xla_model as xm - - XLA_AVAILABLE = True -else: - XLA_AVAILABLE = False - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -EXAMPLE_DOC_STRING = """ - Examples: - ```python - >>> import torch - >>> from diffusers import CogView4ControlPipeline - - >>> pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B-Control", torch_dtype=torch.bfloat16) - >>> control_image = load_image( - ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" - ... ) - >>> prompt = "A bird in space" - >>> image = pipe(prompt, control_image=control_image, height=1024, width=1024, guidance_scale=3.5).images[0] - >>> image.save("cogview4-control.png") - ``` -""" - - -# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.calculate_shift -def calculate_shift( - image_seq_len, - base_seq_len: int = 256, - base_shift: float = 0.25, - max_shift: float = 0.75, -) -> float: - m = (image_seq_len / base_seq_len) ** 0.5 - mu = m * max_shift + base_shift - return mu - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, -): - r""" - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` - must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, - `num_inference_steps` and `sigmas` must be `None`. - sigmas (`List[float]`, *optional*): - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`. - - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -class CogView4ControlPipeline(DiffusionPipeline): - r""" - Pipeline for text-to-image generation using CogView4. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`GLMModel`]): - Frozen text-encoder. CogView4 uses [glm-4-9b-hf](https://huggingface.co/THUDM/glm-4-9b-hf). - tokenizer (`PreTrainedTokenizer`): - Tokenizer of class - [PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer). - transformer ([`CogView4Transformer2DModel`]): - A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `transformer` to denoise the encoded image latents. - """ - - _optional_components = [] - model_cpu_offload_seq = "text_encoder->transformer->vae" - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] - - def __init__( - self, - tokenizer: AutoTokenizer, - text_encoder: GlmModel, - vae: AutoencoderKL, - transformer: CogView4Transformer2DModel, - scheduler: FlowMatchEulerDiscreteScheduler, - ): - super().__init__() - - self.register_modules( - tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - - # Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline._get_glm_embeds - def _get_glm_embeds( - self, - prompt: Union[str, List[str]] = None, - max_sequence_length: int = 1024, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - - text_inputs = self.tokenizer( - prompt, - padding="longest", # not use max length - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - current_length = text_input_ids.shape[1] - pad_length = (16 - (current_length % 16)) % 16 - if pad_length > 0: - pad_ids = torch.full( - (text_input_ids.shape[0], pad_length), - fill_value=self.tokenizer.pad_token_id, - dtype=text_input_ids.dtype, - device=text_input_ids.device, - ) - text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1) - prompt_embeds = self.text_encoder( - text_input_ids.to(self.text_encoder.device), output_hidden_states=True - ).hidden_states[-2] - - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - return prompt_embeds - - # Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline.encode_prompt - def encode_prompt( - self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - do_classifier_free_guidance: bool = True, - num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - max_sequence_length: int = 1024, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): - Whether to use classifier free guidance or not. - num_images_per_prompt (`int`, *optional*, defaults to 1): - Number of images that should be generated per prompt. torch device to place the resulting embeddings on - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - device: (`torch.device`, *optional*): - torch device - dtype: (`torch.dtype`, *optional*): - torch dtype - max_sequence_length (`int`, defaults to `1024`): - Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. - """ - device = device or self._execution_device - - prompt = [prompt] if isinstance(prompt, str) else prompt - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - prompt_embeds = self._get_glm_embeds(prompt, max_sequence_length, device, dtype) - - seq_len = prompt_embeds.size(1) - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - if do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - - negative_prompt_embeds = self._get_glm_embeds(negative_prompt, max_sequence_length, device, dtype) - - seq_len = negative_prompt_embeds.size(1) - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds, negative_prompt_embeds - - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - if latents is not None: - return latents.to(device) - - shape = ( - batch_size, - num_channels_latents, - int(height) // self.vae_scale_factor, - int(width) // self.vae_scale_factor, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - return latents - - def prepare_image( - self, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - do_classifier_free_guidance=False, - guess_mode=False, - ): - if isinstance(image, torch.Tensor): - pass - else: - image = self.image_processor.preprocess(image, height=height, width=width) - - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0, output_size=image.shape[0] * repeat_by) - - image = image.to(device=device, dtype=dtype) - - if do_classifier_free_guidance and not guess_mode: - image = torch.cat([image] * 2) - - return image - - def check_inputs( - self, - prompt, - height, - width, - negative_prompt, - callback_on_step_end_tensor_inputs, - prompt_embeds=None, - negative_prompt_embeds=None, - ): - if height % 16 != 0 or width % 16 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") - - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - @property - def guidance_scale(self): - return self._guidance_scale - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1 - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def attention_kwargs(self): - return self._attention_kwargs - - @property - def current_timestep(self): - return self._current_timestep - - @property - def interrupt(self): - return self._interrupt - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Optional[Union[str, List[str]]] = None, - negative_prompt: Optional[Union[str, List[str]]] = None, - control_image: PipelineImageInput = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - guidance_scale: float = 5.0, - num_images_per_prompt: int = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - original_size: Optional[Tuple[int, int]] = None, - crops_coords_top_left: Tuple[int, int] = (0, 0), - output_type: str = "pil", - return_dict: bool = True, - attention_kwargs: Optional[Dict[str, Any]] = None, - callback_on_step_end: Optional[ - Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] - ] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 1024, - ) -> Union[CogView4PipelineOutput, Tuple]: - """ - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. If not provided, it is set to 1024. - width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. If not provided it is set to 1024. - num_inference_steps (`int`, *optional*, defaults to `50`): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument - in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is - passed will be used. Must be in descending order. - sigmas (`List[float]`, *optional*): - Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in - their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed - will be used. - guidance_scale (`float`, *optional*, defaults to `5.0`): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to `1`): - The number of images to generate per prompt. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): - If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. - `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as - explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). - crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): - `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position - `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting - `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead - of a plain tuple. - attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. - callback_on_step_end_tensor_inputs (`List`, *optional*): - The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list - will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeline class. - max_sequence_length (`int`, defaults to `224`): - Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. - - Examples: - - Returns: - [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] or `tuple`: - [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] if `return_dict` is True, otherwise a - `tuple`. When returning a tuple, the first element is a list with the generated images. - """ - - if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): - callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - - height = height or self.transformer.config.sample_size * self.vae_scale_factor - width = width or self.transformer.config.sample_size * self.vae_scale_factor - - original_size = original_size or (height, width) - target_size = (height, width) - - # Check inputs. Raise error if not correct - self.check_inputs( - prompt, - height, - width, - negative_prompt, - callback_on_step_end_tensor_inputs, - prompt_embeds, - negative_prompt_embeds, - ) - self._guidance_scale = guidance_scale - self._attention_kwargs = attention_kwargs - self._current_timestep = None - self._interrupt = False - - # Default call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - - # Encode input prompt - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt, - negative_prompt, - self.do_classifier_free_guidance, - num_images_per_prompt=num_images_per_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - max_sequence_length=max_sequence_length, - device=device, - ) - - # Prepare latents - latent_channels = self.transformer.config.in_channels // 2 - - control_image = self.prepare_image( - image=control_image, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=self.vae.dtype, - ) - height, width = control_image.shape[-2:] - - vae_shift_factor = 0 - - control_image = self.vae.encode(control_image).latent_dist.sample() - control_image = (control_image - vae_shift_factor) * self.vae.config.scaling_factor - - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - latent_channels, - height, - width, - torch.float32, - device, - generator, - latents, - ) - - # Prepare additional timestep conditions - original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=device) - target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device) - crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device) - - original_size = original_size.repeat(batch_size * num_images_per_prompt, 1) - target_size = target_size.repeat(batch_size * num_images_per_prompt, 1) - crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1) - - # Prepare timesteps - image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // ( - self.transformer.config.patch_size**2 - ) - - timesteps = ( - np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps) - if timesteps is None - else np.array(timesteps) - ) - timesteps = timesteps.astype(np.int64).astype(np.float32) - sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas - mu = calculate_shift( - image_seq_len, - self.scheduler.config.get("base_image_seq_len", 256), - self.scheduler.config.get("base_shift", 0.25), - self.scheduler.config.get("max_shift", 0.75), - ) - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu - ) - self._num_timesteps = len(timesteps) - # Denoising loop - transformer_dtype = self.transformer.dtype - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - self._current_timestep = t - latent_model_input = torch.cat([latents, control_image], dim=1).to(transformer_dtype) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latents.shape[0]) - - noise_pred_cond = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - original_size=original_size, - target_size=target_size, - crop_coords=crops_coords_top_left, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - # perform guidance - if self.do_classifier_free_guidance: - noise_pred_uncond = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=negative_prompt_embeds, - timestep=timestep, - original_size=original_size, - target_size=target_size, - crop_coords=crops_coords_top_left, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) - else: - noise_pred = noise_pred_cond - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - # call the callback, if provided - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs) - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - - if XLA_AVAILABLE: - xm.mark_step() - - self._current_timestep = None - - if not output_type == "latent": - latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor - image = self.vae.decode(latents, return_dict=False, generator=generator)[0] - else: - image = latents - - image = self.image_processor.postprocess(image, output_type=output_type) - - # Offload all models - self.maybe_free_model_hooks() - - if not return_dict: - return (image,) - - return CogView4PipelineOutput(images=image) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index f3f1d90204d6..eee41b9af4d1 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -63,7 +63,6 @@ >>> from diffusers import FluxControlNetPipeline >>> from diffusers import FluxControlNetModel - >>> base_model = "black-forest-labs/FLUX.1-dev" >>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny" >>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16) >>> pipe = FluxControlNetPipeline.from_pretrained( diff --git a/src/diffusers/pipelines/ltx/__init__.py b/src/diffusers/pipelines/ltx/__init__.py index 199e730d9b4d..20cc1c216522 100644 --- a/src/diffusers/pipelines/ltx/__init__.py +++ b/src/diffusers/pipelines/ltx/__init__.py @@ -23,7 +23,6 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_ltx"] = ["LTXPipeline"] - _import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"] _import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -35,7 +34,6 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_ltx import LTXPipeline - from .pipeline_ltx_condition import LTXConditionPipeline from .pipeline_ltx_image2video import LTXImageToVideoPipeline else: diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index f7b0811d1a22..866be61810a9 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -694,8 +694,9 @@ def __call__( self._num_timesteps = len(timesteps) # 6. Prepare micro-conditions + latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio rope_interpolation_scale = ( - self.vae_temporal_compression_ratio / frame_rate, + 1 / latent_frame_rate, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio, ) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py deleted file mode 100644 index e7f3666cb2c7..000000000000 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ /dev/null @@ -1,1174 +0,0 @@ -# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Union - -import PIL.Image -import torch -from transformers import T5EncoderModel, T5TokenizerFast - -from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...image_processor import PipelineImageInput -from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin -from ...models.autoencoders import AutoencoderKLLTXVideo -from ...models.transformers import LTXVideoTransformer3DModel -from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import is_torch_xla_available, logging, replace_example_docstring -from ...utils.torch_utils import randn_tensor -from ...video_processor import VideoProcessor -from ..pipeline_utils import DiffusionPipeline -from .pipeline_output import LTXPipelineOutput - - -if is_torch_xla_available(): - import torch_xla.core.xla_model as xm - - XLA_AVAILABLE = True -else: - XLA_AVAILABLE = False - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import torch - >>> from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXConditionPipeline, LTXVideoCondition - >>> from diffusers.utils import export_to_video, load_video, load_image - - >>> pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16) - >>> pipe.to("cuda") - - >>> # Load input image and video - >>> video = load_video( - ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4" - ... ) - >>> image = load_image( - ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input.jpg" - ... ) - - >>> # Create conditioning objects - >>> condition1 = LTXVideoCondition( - ... image=image, - ... frame_index=0, - ... ) - >>> condition2 = LTXVideoCondition( - ... video=video, - ... frame_index=80, - ... ) - - >>> prompt = "The video depicts a long, straight highway stretching into the distance, flanked by metal guardrails. The road is divided into multiple lanes, with a few vehicles visible in the far distance. The surrounding landscape features dry, grassy fields on one side and rolling hills on the other. The sky is mostly clear with a few scattered clouds, suggesting a bright, sunny day. And then the camera switch to a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region." - >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" - - >>> # Generate video - >>> generator = torch.Generator("cuda").manual_seed(0) - >>> video = pipe( - ... conditions=[condition1, condition2], - ... prompt=prompt, - ... negative_prompt=negative_prompt, - ... width=768, - ... height=512, - ... num_frames=161, - ... num_inference_steps=40, - ... generator=generator, - ... ).frames[0] - - >>> export_to_video(video, "output.mp4", fps=24) - ``` -""" - - -@dataclass -class LTXVideoCondition: - """ - Defines a single frame-conditioning item for LTX Video - a single frame or a sequence of frames. - - Attributes: - image (`PIL.Image.Image`): - The image to condition the video on. - video (`List[PIL.Image.Image]`): - The video to condition the video on. - frame_index (`int`): - The frame index at which the image or video will conditionally effect the video generation. - strength (`float`, defaults to `1.0`): - The strength of the conditioning effect. A value of `1.0` means the conditioning effect is fully applied. - """ - - image: Optional[PIL.Image.Image] = None - video: Optional[List[PIL.Image.Image]] = None - frame_index: int = 0 - strength: float = 1.0 - - -# from LTX-Video/ltx_video/schedulers/rf.py -def linear_quadratic_schedule(num_steps, threshold_noise=0.025, linear_steps=None): - if linear_steps is None: - linear_steps = num_steps // 2 - if num_steps < 2: - return torch.tensor([1.0]) - linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)] - threshold_noise_step_diff = linear_steps - threshold_noise * num_steps - quadratic_steps = num_steps - linear_steps - quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) - linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2) - const = quadratic_coef * (linear_steps**2) - quadratic_sigma_schedule = [ - quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps) - ] - sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0] - sigma_schedule = [1.0 - x for x in sigma_schedule] - return torch.tensor(sigma_schedule[:-1]) - - -# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift -def calculate_shift( - image_seq_len, - base_seq_len: int = 256, - max_seq_len: int = 4096, - base_shift: float = 0.5, - max_shift: float = 1.15, -): - m = (max_shift - base_shift) / (max_seq_len - base_seq_len) - b = base_shift - m * base_seq_len - mu = image_seq_len * m + b - return mu - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, -): - r""" - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` - must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, - `num_inference_steps` and `sigmas` must be `None`. - sigmas (`List[float]`, *optional*): - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`. - - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" -): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents - else: - raise AttributeError("Could not access latents of provided encoder_output") - - -class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin): - r""" - Pipeline for image-to-video generation. - - Reference: https://github.com/Lightricks/LTX-Video - - Args: - transformer ([`LTXVideoTransformer3DModel`]): - Conditional Transformer architecture to denoise the encoded video latents. - scheduler ([`FlowMatchEulerDiscreteScheduler`]): - A scheduler to be used in combination with `transformer` to denoise the encoded image latents. - vae ([`AutoencoderKLLTXVideo`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`T5EncoderModel`]): - [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically - the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). - tokenizer (`T5TokenizerFast`): - Second Tokenizer of class - [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). - """ - - model_cpu_offload_seq = "text_encoder->transformer->vae" - _optional_components = [] - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] - - def __init__( - self, - scheduler: FlowMatchEulerDiscreteScheduler, - vae: AutoencoderKLLTXVideo, - text_encoder: T5EncoderModel, - tokenizer: T5TokenizerFast, - transformer: LTXVideoTransformer3DModel, - ): - super().__init__() - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - transformer=transformer, - scheduler=scheduler, - ) - - self.vae_spatial_compression_ratio = ( - self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 - ) - self.vae_temporal_compression_ratio = ( - self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 - ) - self.transformer_spatial_patch_size = ( - self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 - ) - self.transformer_temporal_patch_size = ( - self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1 - ) - - self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) - self.tokenizer_max_length = ( - self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128 - ) - - self.default_height = 512 - self.default_width = 704 - self.default_frames = 121 - - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_videos_per_prompt: int = 1, - max_sequence_length: int = 256, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - prompt_attention_mask = text_inputs.attention_mask - prompt_attention_mask = prompt_attention_mask.bool().to(device) - - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0] - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - - prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) - prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) - - return prompt_embeds, prompt_attention_mask - - # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt - def encode_prompt( - self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - do_classifier_free_guidance: bool = True, - num_videos_per_prompt: int = 1, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - prompt_attention_mask: Optional[torch.Tensor] = None, - negative_prompt_attention_mask: Optional[torch.Tensor] = None, - max_sequence_length: int = 256, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): - Whether to use classifier free guidance or not. - num_videos_per_prompt (`int`, *optional*, defaults to 1): - Number of videos that should be generated per prompt. torch device to place the resulting embeddings on - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - device: (`torch.device`, *optional*): - torch device - dtype: (`torch.dtype`, *optional*): - torch dtype - """ - device = device or self._execution_device - - prompt = [prompt] if isinstance(prompt, str) else prompt - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( - prompt=prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) - - if do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - - negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds( - prompt=negative_prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) - - return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask - - def check_inputs( - self, - prompt, - conditions, - image, - video, - frame_index, - strength, - height, - width, - callback_on_step_end_tensor_inputs=None, - prompt_embeds=None, - negative_prompt_embeds=None, - prompt_attention_mask=None, - negative_prompt_attention_mask=None, - ): - if height % 32 != 0 or width % 32 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") - - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if prompt_embeds is not None and prompt_attention_mask is None: - raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") - - if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: - raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: - raise ValueError( - "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" - f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" - f" {negative_prompt_attention_mask.shape}." - ) - - if conditions is not None and (image is not None or video is not None): - raise ValueError("If `conditions` is provided, `image` and `video` must not be provided.") - - if conditions is None and (image is None and video is None): - raise ValueError("If `conditions` is not provided, `image` or `video` must be provided.") - - if conditions is None: - if isinstance(image, list) and isinstance(frame_index, list) and len(image) != len(frame_index): - raise ValueError( - "If `conditions` is not provided, `image` and `frame_index` must be of the same length." - ) - elif isinstance(image, list) and isinstance(strength, list) and len(image) != len(strength): - raise ValueError("If `conditions` is not provided, `image` and `strength` must be of the same length.") - elif isinstance(video, list) and isinstance(frame_index, list) and len(video) != len(frame_index): - raise ValueError( - "If `conditions` is not provided, `video` and `frame_index` must be of the same length." - ) - elif isinstance(video, list) and isinstance(strength, list) and len(video) != len(strength): - raise ValueError("If `conditions` is not provided, `video` and `strength` must be of the same length.") - - @staticmethod - def _prepare_video_ids( - batch_size: int, - num_frames: int, - height: int, - width: int, - patch_size: int = 1, - patch_size_t: int = 1, - device: torch.device = None, - ) -> torch.Tensor: - latent_sample_coords = torch.meshgrid( - torch.arange(0, num_frames, patch_size_t, device=device), - torch.arange(0, height, patch_size, device=device), - torch.arange(0, width, patch_size, device=device), - indexing="ij", - ) - latent_sample_coords = torch.stack(latent_sample_coords, dim=0) - latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) - latent_coords = latent_coords.reshape(batch_size, -1, num_frames * height * width) - - return latent_coords - - @staticmethod - def _scale_video_ids( - video_ids: torch.Tensor, - scale_factor: int = 32, - scale_factor_t: int = 8, - frame_index: int = 0, - device: torch.device = None, - ) -> torch.Tensor: - scaled_latent_coords = ( - video_ids - * torch.tensor([scale_factor_t, scale_factor, scale_factor], device=video_ids.device)[None, :, None] - ) - scaled_latent_coords[:, 0] = (scaled_latent_coords[:, 0] + 1 - scale_factor_t).clamp(min=0) - scaled_latent_coords[:, 0] += frame_index - - return scaled_latent_coords - - @staticmethod - # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents - def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: - # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. - # The patch dimensions are then permuted and collapsed into the channel dimension of shape: - # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). - # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features - batch_size, num_channels, num_frames, height, width = latents.shape - post_patch_num_frames = num_frames // patch_size_t - post_patch_height = height // patch_size - post_patch_width = width // patch_size - latents = latents.reshape( - batch_size, - -1, - post_patch_num_frames, - patch_size_t, - post_patch_height, - patch_size, - post_patch_width, - patch_size, - ) - latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) - return latents - - @staticmethod - # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents - def _unpack_latents( - latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 - ) -> torch.Tensor: - # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) - # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of - # what happens in the `_pack_latents` method. - batch_size = latents.size(0) - latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) - latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) - return latents - - @staticmethod - # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents - def _normalize_latents( - latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 - ) -> torch.Tensor: - # Normalize latents across the channel dimension [B, C, F, H, W] - latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) - latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) - latents = (latents - latents_mean) * scaling_factor / latents_std - return latents - - @staticmethod - # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents - def _denormalize_latents( - latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 - ) -> torch.Tensor: - # Denormalize latents across the channel dimension [B, C, F, H, W] - latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) - latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) - latents = latents * latents_std / scaling_factor + latents_mean - return latents - - def trim_conditioning_sequence(self, start_frame: int, sequence_num_frames: int, target_num_frames: int): - """ - Trim a conditioning sequence to the allowed number of frames. - - Args: - start_frame (int): The target frame number of the first frame in the sequence. - sequence_num_frames (int): The number of frames in the sequence. - target_num_frames (int): The target number of frames in the generated video. - Returns: - int: updated sequence length - """ - scale_factor = self.vae_temporal_compression_ratio - num_frames = min(sequence_num_frames, target_num_frames - start_frame) - # Trim down to a multiple of temporal_scale_factor frames plus 1 - num_frames = (num_frames - 1) // scale_factor * scale_factor + 1 - return num_frames - - @staticmethod - def add_noise_to_image_conditioning_latents( - t: float, - init_latents: torch.Tensor, - latents: torch.Tensor, - noise_scale: float, - conditioning_mask: torch.Tensor, - generator, - eps=1e-6, - ): - """ - Add timestep-dependent noise to the hard-conditioning latents. This helps with motion continuity, especially - when conditioned on a single frame. - """ - noise = randn_tensor( - latents.shape, - generator=generator, - device=latents.device, - dtype=latents.dtype, - ) - # Add noise only to hard-conditioning latents (conditioning_mask = 1.0) - need_to_noise = (conditioning_mask > 1.0 - eps).unsqueeze(-1) - noised_latents = init_latents + noise_scale * noise * (t**2) - latents = torch.where(need_to_noise, noised_latents, latents) - return latents - - def prepare_latents( - self, - conditions: List[torch.Tensor], - condition_strength: List[float], - condition_frame_index: List[int], - batch_size: int = 1, - num_channels_latents: int = 128, - height: int = 512, - width: int = 704, - num_frames: int = 161, - num_prefix_latent_frames: int = 2, - generator: Optional[torch.Generator] = None, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ) -> None: - num_latent_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 - latent_height = height // self.vae_spatial_compression_ratio - latent_width = width // self.vae_spatial_compression_ratio - - shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - - condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32) - - extra_conditioning_latents = [] - extra_conditioning_video_ids = [] - extra_conditioning_mask = [] - extra_conditioning_num_latents = 0 - for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index): - condition_latents = retrieve_latents(self.vae.encode(data), generator=generator) - condition_latents = self._normalize_latents( - condition_latents, self.vae.latents_mean, self.vae.latents_std - ).to(device, dtype=dtype) - - num_data_frames = data.size(2) - num_cond_frames = condition_latents.size(2) - - if frame_index == 0: - latents[:, :, :num_cond_frames] = torch.lerp( - latents[:, :, :num_cond_frames], condition_latents, strength - ) - condition_latent_frames_mask[:, :num_cond_frames] = strength - - else: - if num_data_frames > 1: - if num_cond_frames < num_prefix_latent_frames: - raise ValueError( - f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}." - ) - - if num_cond_frames > num_prefix_latent_frames: - start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames - end_frame = start_frame + num_cond_frames - num_prefix_latent_frames - latents[:, :, start_frame:end_frame] = torch.lerp( - latents[:, :, start_frame:end_frame], - condition_latents[:, :, num_prefix_latent_frames:], - strength, - ) - condition_latent_frames_mask[:, start_frame:end_frame] = strength - condition_latents = condition_latents[:, :, :num_prefix_latent_frames] - - noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype) - condition_latents = torch.lerp(noise, condition_latents, strength) - - condition_video_ids = self._prepare_video_ids( - batch_size, - condition_latents.size(2), - latent_height, - latent_width, - patch_size=self.transformer_spatial_patch_size, - patch_size_t=self.transformer_temporal_patch_size, - device=device, - ) - condition_video_ids = self._scale_video_ids( - condition_video_ids, - scale_factor=self.vae_spatial_compression_ratio, - scale_factor_t=self.vae_temporal_compression_ratio, - frame_index=frame_index, - device=device, - ) - condition_latents = self._pack_latents( - condition_latents, - self.transformer_spatial_patch_size, - self.transformer_temporal_patch_size, - ) - condition_conditioning_mask = torch.full( - condition_latents.shape[:2], strength, device=device, dtype=dtype - ) - - extra_conditioning_latents.append(condition_latents) - extra_conditioning_video_ids.append(condition_video_ids) - extra_conditioning_mask.append(condition_conditioning_mask) - extra_conditioning_num_latents += condition_latents.size(1) - - video_ids = self._prepare_video_ids( - batch_size, - num_latent_frames, - latent_height, - latent_width, - patch_size_t=self.transformer_temporal_patch_size, - patch_size=self.transformer_spatial_patch_size, - device=device, - ) - conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0]) - video_ids = self._scale_video_ids( - video_ids, - scale_factor=self.vae_spatial_compression_ratio, - scale_factor_t=self.vae_temporal_compression_ratio, - frame_index=0, - device=device, - ) - latents = self._pack_latents( - latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size - ) - - if len(extra_conditioning_latents) > 0: - latents = torch.cat([*extra_conditioning_latents, latents], dim=1) - video_ids = torch.cat([*extra_conditioning_video_ids, video_ids], dim=2) - conditioning_mask = torch.cat([*extra_conditioning_mask, conditioning_mask], dim=1) - - return latents, conditioning_mask, video_ids, extra_conditioning_num_latents - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1.0 - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def attention_kwargs(self): - return self._attention_kwargs - - @property - def interrupt(self): - return self._interrupt - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - conditions: Union[LTXVideoCondition, List[LTXVideoCondition]] = None, - image: Union[PipelineImageInput, List[PipelineImageInput]] = None, - video: List[PipelineImageInput] = None, - frame_index: Union[int, List[int]] = 0, - strength: Union[float, List[float]] = 1.0, - prompt: Union[str, List[str]] = None, - negative_prompt: Optional[Union[str, List[str]]] = None, - height: int = 512, - width: int = 704, - num_frames: int = 161, - frame_rate: int = 25, - num_inference_steps: int = 50, - timesteps: List[int] = None, - guidance_scale: float = 3, - image_cond_noise_scale: float = 0.15, - num_videos_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, - prompt_embeds: Optional[torch.Tensor] = None, - prompt_attention_mask: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_attention_mask: Optional[torch.Tensor] = None, - decode_timestep: Union[float, List[float]] = 0.0, - decode_noise_scale: Optional[Union[float, List[float]]] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - attention_kwargs: Optional[Dict[str, Any]] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 256, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - conditions (`List[LTXVideoCondition], *optional*`): - The list of frame-conditioning items for the video generation.If not provided, conditions will be - created using `image`, `video`, `frame_index` and `strength`. - image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*): - The image or images to condition the video generation. If not provided, one has to pass `video` or - `conditions`. - video (`List[PipelineImageInput]`, *optional*): - The video to condition the video generation. If not provided, one has to pass `image` or `conditions`. - frame_index (`int` or `List[int]`, *optional*): - The frame index or frame indices at which the image or video will conditionally effect the video - generation. If not provided, one has to pass `conditions`. - strength (`float` or `List[float]`, *optional*): - The strength or strengths of the conditioning effect. If not provided, one has to pass `conditions`. - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - height (`int`, defaults to `512`): - The height in pixels of the generated image. This is set to 480 by default for the best results. - width (`int`, defaults to `704`): - The width in pixels of the generated image. This is set to 848 by default for the best results. - num_frames (`int`, defaults to `161`): - The number of video frames to generate - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument - in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is - passed will be used. Must be in descending order. - guidance_scale (`float`, defaults to `3 `): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_videos_per_prompt (`int`, *optional*, defaults to 1): - The number of videos to generate per prompt. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.Tensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - prompt_attention_mask (`torch.Tensor`, *optional*): - Pre-generated attention mask for text embeddings. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not - provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. - negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): - Pre-generated attention mask for negative text embeddings. - decode_timestep (`float`, defaults to `0.0`): - The timestep at which generated video is decoded. - decode_noise_scale (`float`, defaults to `None`): - The interpolation factor between random noise and denoised latents at the decode timestep. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple. - attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. - callback_on_step_end_tensor_inputs (`List`, *optional*): - The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list - will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeline class. - max_sequence_length (`int` defaults to `128 `): - Maximum sequence length to use with the `prompt`. - - Examples: - - Returns: - [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is - returned where the first element is a list with the generated images. - """ - - if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): - callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - if latents is not None: - raise ValueError("Passing latents is not yet supported.") - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt=prompt, - conditions=conditions, - image=image, - video=video, - frame_index=frame_index, - strength=strength, - height=height, - width=width, - callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - prompt_attention_mask=prompt_attention_mask, - negative_prompt_attention_mask=negative_prompt_attention_mask, - ) - - self._guidance_scale = guidance_scale - self._attention_kwargs = attention_kwargs - self._interrupt = False - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if conditions is not None: - if not isinstance(conditions, list): - conditions = [conditions] - - strength = [condition.strength for condition in conditions] - frame_index = [condition.frame_index for condition in conditions] - image = [condition.image for condition in conditions] - video = [condition.video for condition in conditions] - else: - if not isinstance(image, list): - image = [image] - num_conditions = 1 - elif isinstance(image, list): - num_conditions = len(image) - if not isinstance(video, list): - video = [video] - num_conditions = 1 - elif isinstance(video, list): - num_conditions = len(video) - - if not isinstance(frame_index, list): - frame_index = [frame_index] * num_conditions - if not isinstance(strength, list): - strength = [strength] * num_conditions - - device = self._execution_device - - # 3. Prepare text embeddings - ( - prompt_embeds, - prompt_attention_mask, - negative_prompt_embeds, - negative_prompt_attention_mask, - ) = self.encode_prompt( - prompt=prompt, - negative_prompt=negative_prompt, - do_classifier_free_guidance=self.do_classifier_free_guidance, - num_videos_per_prompt=num_videos_per_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - prompt_attention_mask=prompt_attention_mask, - negative_prompt_attention_mask=negative_prompt_attention_mask, - max_sequence_length=max_sequence_length, - device=device, - ) - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) - - vae_dtype = self.vae.dtype - - conditioning_tensors = [] - for condition_image, condition_video, condition_frame_index, condition_strength in zip( - image, video, frame_index, strength - ): - if condition_image is not None: - condition_tensor = ( - self.video_processor.preprocess(condition_image, height, width) - .unsqueeze(2) - .to(device, dtype=vae_dtype) - ) - elif condition_video is not None: - condition_tensor = self.video_processor.preprocess_video(condition_video, height, width) - num_frames_input = condition_tensor.size(2) - num_frames_output = self.trim_conditioning_sequence( - condition_frame_index, num_frames_input, num_frames - ) - condition_tensor = condition_tensor[:, :, :num_frames_output] - condition_tensor = condition_tensor.to(device, dtype=vae_dtype) - else: - raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.") - - if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1: - raise ValueError( - f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) " - f"but got {condition_tensor.size(2)} frames." - ) - conditioning_tensors.append(condition_tensor) - - # 4. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels - latents, conditioning_mask, video_coords, extra_conditioning_num_latents = self.prepare_latents( - conditioning_tensors, - strength, - frame_index, - batch_size=batch_size * num_videos_per_prompt, - num_channels_latents=num_channels_latents, - height=height, - width=width, - num_frames=num_frames, - generator=generator, - device=device, - dtype=torch.float32, - ) - - video_coords = video_coords.float() - video_coords[:, 0] = video_coords[:, 0] * (1.0 / frame_rate) - - init_latents = latents.clone() - - if self.do_classifier_free_guidance: - video_coords = torch.cat([video_coords, video_coords], dim=0) - - # 5. Prepare timesteps - latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 - latent_height = height // self.vae_spatial_compression_ratio - latent_width = width // self.vae_spatial_compression_ratio - sigmas = linear_quadratic_schedule(num_inference_steps) - timesteps = sigmas * 1000 - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, - num_inference_steps, - device, - timesteps=timesteps, - ) - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - self._num_timesteps = len(timesteps) - - # 7. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - if image_cond_noise_scale > 0: - # Add timestep-dependent noise to the hard-conditioning latents - # This helps with motion continuity, especially when conditioned on a single frame - latents = self.add_noise_to_image_conditioning_latents( - t / 1000.0, - init_latents, - latents, - image_cond_noise_scale, - conditioning_mask, - generator, - ) - - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - conditioning_mask_model_input = ( - torch.cat([conditioning_mask, conditioning_mask]) - if self.do_classifier_free_guidance - else conditioning_mask - ) - latent_model_input = latent_model_input.to(prompt_embeds.dtype) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float() - timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0) - - noise_pred = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - encoder_attention_mask=prompt_attention_mask, - video_coords=video_coords, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - if self.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - timestep, _ = timestep.chunk(2) - - denoised_latents = self.scheduler.step( - -noise_pred, t, latents, per_token_timesteps=timestep, return_dict=False - )[0] - tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1) - latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents) - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - - if XLA_AVAILABLE: - xm.mark_step() - - latents = latents[:, extra_conditioning_num_latents:] - latents = self._unpack_latents( - latents, - latent_num_frames, - latent_height, - latent_width, - self.transformer_spatial_patch_size, - self.transformer_temporal_patch_size, - ) - - if output_type == "latent": - video = latents - else: - latents = self._denormalize_latents( - latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor - ) - latents = latents.to(prompt_embeds.dtype) - - if not self.vae.config.timestep_conditioning: - timestep = None - else: - noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype) - if not isinstance(decode_timestep, list): - decode_timestep = [decode_timestep] * batch_size - if decode_noise_scale is None: - decode_noise_scale = decode_timestep - elif not isinstance(decode_noise_scale, list): - decode_noise_scale = [decode_noise_scale] * batch_size - - timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) - decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ - :, None, None, None, None - ] - latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise - - video = self.vae.decode(latents, timestep, return_dict=False)[0] - video = self.video_processor.postprocess_video(video, output_type=output_type) - - # Offload all models - self.maybe_free_model_hooks() - - if not return_dict: - return (video,) - - return LTXPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 6c4214fe1b26..0577a56ec13d 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -764,8 +764,9 @@ def __call__( self._num_timesteps = len(timesteps) # 6. Prepare micro-conditions + latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio rope_interpolation_scale = ( - self.vae_temporal_compression_ratio / frame_rate, + 1 / latent_frame_rate, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio, ) diff --git a/src/diffusers/pipelines/lumina/__init__.py b/src/diffusers/pipelines/lumina/__init__.py index a19dc7e94641..ca1396359721 100644 --- a/src/diffusers/pipelines/lumina/__init__.py +++ b/src/diffusers/pipelines/lumina/__init__.py @@ -22,7 +22,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["pipeline_lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] + _import_structure["pipeline_lumina"] = ["LuminaText2ImgPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -32,7 +32,7 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: - from .pipeline_lumina import LuminaPipeline, LuminaText2ImgPipeline + from .pipeline_lumina import LuminaText2ImgPipeline else: import sys diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index 816213f105cb..b50079532f94 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -30,7 +30,6 @@ from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( BACKENDS_MAPPING, - deprecate, is_bs4_available, is_ftfy_available, is_torch_xla_available, @@ -61,9 +60,11 @@ Examples: ```py >>> import torch - >>> from diffusers import LuminaPipeline + >>> from diffusers import LuminaText2ImgPipeline - >>> pipe = LuminaPipeline.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16) + >>> pipe = LuminaText2ImgPipeline.from_pretrained( + ... "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16 + ... ) >>> # Enable memory optimizations. >>> pipe.enable_model_cpu_offload() @@ -133,7 +134,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class LuminaPipeline(DiffusionPipeline): +class LuminaText2ImgPipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation using Lumina-T2I. @@ -931,23 +932,3 @@ def __call__( return (image,) return ImagePipelineOutput(images=image) - - -class LuminaText2ImgPipeline(LuminaPipeline): - def __init__( - self, - transformer: LuminaNextDiT2DModel, - scheduler: FlowMatchEulerDiscreteScheduler, - vae: AutoencoderKL, - text_encoder: GemmaPreTrainedModel, - tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], - ): - deprecation_message = "`LuminaText2ImgPipeline` has been renamed to `LuminaPipeline` and will be removed in a future version. Please use `LuminaPipeline` instead." - deprecate("diffusers.pipelines.lumina.pipeline_lumina.LuminaText2ImgPipeline", "0.34", deprecation_message) - super().__init__( - transformer=transformer, - scheduler=scheduler, - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - ) diff --git a/src/diffusers/pipelines/lumina2/__init__.py b/src/diffusers/pipelines/lumina2/__init__.py index b1d6bfeb0d58..0e51a768a785 100644 --- a/src/diffusers/pipelines/lumina2/__init__.py +++ b/src/diffusers/pipelines/lumina2/__init__.py @@ -22,7 +22,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["pipeline_lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] + _import_structure["pipeline_lumina2"] = ["Lumina2Text2ImgPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -32,7 +32,7 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: - from .pipeline_lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline + from .pipeline_lumina2 import Lumina2Text2ImgPipeline else: import sys diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py index e0905a2f131f..514192cb70c7 100644 --- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py +++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py @@ -25,7 +25,6 @@ from ...models.transformers.transformer_lumina2 import Lumina2Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( - deprecate, is_torch_xla_available, logging, replace_example_docstring, @@ -48,9 +47,9 @@ Examples: ```py >>> import torch - >>> from diffusers import Lumina2Pipeline + >>> from diffusers import Lumina2Text2ImgPipeline - >>> pipe = Lumina2Pipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16) + >>> pipe = Lumina2Text2ImgPipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16) >>> # Enable memory optimizations. >>> pipe.enable_model_cpu_offload() @@ -134,7 +133,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class Lumina2Pipeline(DiffusionPipeline, Lumina2LoraLoaderMixin): +class Lumina2Text2ImgPipeline(DiffusionPipeline, Lumina2LoraLoaderMixin): r""" Pipeline for text-to-image generation using Lumina-T2I. @@ -768,23 +767,3 @@ def __call__( return (image,) return ImagePipelineOutput(images=image) - - -class Lumina2Text2ImgPipeline(Lumina2Pipeline): - def __init__( - self, - transformer: Lumina2Transformer2DModel, - scheduler: FlowMatchEulerDiscreteScheduler, - vae: AutoencoderKL, - text_encoder: Gemma2PreTrainedModel, - tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], - ): - deprecation_message = "`Lumina2Text2ImgPipeline` has been renamed to `Lumina2Pipeline` and will be removed in a future version. Please use `Lumina2Pipeline` instead." - deprecate("diffusers.pipelines.lumina2.pipeline_lumina2.Lumina2Text2ImgPipeline", "0.34", deprecation_message) - super().__init__( - transformer=transformer, - scheduler=scheduler, - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - ) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index e80325ed42b0..07da8b5e2e2e 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -667,12 +667,9 @@ def load_sub_model( use_safetensors: bool, dduf_entries: Optional[Dict[str, DDUFEntry]], provider_options: Any, - quantization_config: Optional[Any] = None, ): """Helper method to load the module `name` from `library_name` and `class_name`""" - from ..quantizers import PipelineQuantizationConfig - # retrieve class candidates class_obj, class_candidates = get_class_obj_and_candidates( @@ -764,17 +761,6 @@ def load_sub_model( else: loading_kwargs["low_cpu_mem_usage"] = False - if ( - quantization_config is not None - and isinstance(quantization_config, PipelineQuantizationConfig) - and issubclass(class_obj, torch.nn.Module) - ): - model_quant_config = quantization_config._resolve_quant_config( - is_diffusers=is_diffusers_model, module_name=name - ) - if model_quant_config is not None: - loading_kwargs["quantization_config"] = model_quant_config - # check if the module is in a subdirectory if dduf_entries: loading_kwargs["dduf_entries"] = dduf_entries diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 040eb8e8c74f..cb60350be1b0 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -702,7 +702,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P use_safetensors = kwargs.pop("use_safetensors", None) use_onnx = kwargs.pop("use_onnx", None) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) - quantization_config = kwargs.pop("quantization_config", None) if not isinstance(torch_dtype, torch.dtype): torch_dtype = torch.float32 @@ -875,9 +874,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P } init_kwargs = {**init_kwargs, **passed_pipe_kwargs} - # TODO: add checking for quantization_config `mapping` i.e., if the modules specified there actually exist. - ######################### - # remove `null` components def load_module(name, value): if value[0] is None: @@ -977,7 +973,6 @@ def load_module(name, value): use_safetensors=use_safetensors, dduf_entries=dduf_entries, provider_options=provider_options, - quantization_config=quantization_config, ) logger.info( f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}." @@ -1615,7 +1610,7 @@ def _get_signature_keys(cls, obj): expected_modules.add(name) optional_parameters.remove(name) - return sorted(expected_modules), sorted(optional_parameters) + return expected_modules, optional_parameters @classmethod def _get_signature_types(cls): @@ -1657,12 +1652,10 @@ def components(self) -> Dict[str, Any]: k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters } - actual = sorted(set(components.keys())) - expected = sorted(expected_modules) - if actual != expected: + if set(components.keys()) != expected_modules: raise ValueError( f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected" - f" {expected} to be defined, but {actual} are defined." + f" {expected_modules} to be defined, but {components.keys()} are defined." ) return components diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index e5699718ea71..863178e7c434 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -108,7 +108,6 @@ def prompt_clean(text): return text -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" ): @@ -386,6 +385,13 @@ def prepare_latents( ) video_condition = video_condition.to(device=device, dtype=dtype) + if isinstance(generator, list): + latent_condition = [retrieve_latents(self.vae.encode(video_condition), g) for g in generator] + latents = latent_condition = torch.cat(latent_condition) + else: + latent_condition = retrieve_latents(self.vae.encode(video_condition), generator) + latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + latents_mean = ( torch.tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) @@ -395,15 +401,6 @@ def prepare_latents( latents.device, latents.dtype ) - if isinstance(generator, list): - latent_condition = [ - retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator - ] - latent_condition = torch.cat(latent_condition) - else: - latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") - latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) - latent_condition = (latent_condition - latents_mean) * latents_std mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py index 975bf00afac2..4c8483a3d6ee 100644 --- a/src/diffusers/quantizers/__init__.py +++ b/src/diffusers/quantizers/__init__.py @@ -12,163 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect -from typing import Dict, List, Optional - -from ..utils import is_transformers_available, logging from .auto import DiffusersAutoQuantizer from .base import DiffusersQuantizer - - -logger = logging.get_logger(__name__) - - -class PipelineQuantizationConfig: - """TODO""" - - def __init__( - self, - quant_backend: str = None, - quant_kwargs: Dict[str, str] = None, - modules_to_quantize: Optional[List[str]] = None, - quant_mapping: Dict[str,] = None, - ): - self.quant_backend = quant_backend - # Initialize kwargs to be {} to set to the defaults. - self.quant_kwargs = quant_kwargs or {} - self.modules_to_quantize = modules_to_quantize - self.quant_mapping = quant_mapping - - self.post_init() - - def post_init(self): - quant_mapping = self.quant_mapping - self.is_granular = True if quant_mapping is not None else False - - self._validate_init_args() - - def _validate_init_args(self): - if self.quant_backend and self.quant_mapping: - raise ValueError("Both `quant_backend` and `quant_mapping` cannot be set.") - - if not self.quant_mapping and not self.quant_backend: - raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.") - - if not self.quant_kwargs and not self.quant_mapping: - raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.") - - if self.quant_backend is not None: - self._validate_init_kwargs_in_backends() - - if self.quant_mapping is not None: - self._validate_quant_mapping_args() - - def _validate_init_kwargs_in_backends(self): - quant_backend = self.quant_backend - - self._check_backend_availability(quant_backend) - - quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() - - if quant_config_mapping_transformers is not None: - if quant_backend not in quant_config_mapping_transformers: - raise ValueError( - f"`{quant_backend=}` is not available in `transformers`, available ones are: {list(quant_config_mapping_transformers.keys())}." - ) - init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__) - init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"} - else: - init_kwargs_transformers = None - - if quant_backend not in quant_config_mapping_diffusers: - raise ValueError( - f"`{quant_backend=}` is not available in `diffusers`, available ones are: {list(quant_config_mapping_diffusers.keys())}." - ) - init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__) - init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"} - - if init_kwargs_transformers != init_kwargs_diffusers: - raise ValueError( - "The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. " - f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class." - ) - - def _validate_quant_mapping_args(self): - quant_mapping = self.quant_mapping - quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() - - available_configs_transformers = ( - list(quant_config_mapping_transformers.values()) if quant_config_mapping_transformers else None - ) - available_configs_diffusers = list(quant_config_mapping_diffusers.values()) - - for module_name, config in quant_mapping.items(): - if config not in available_configs_diffusers or ( - available_configs_transformers and config not in available_configs_transformers - ): - msg = f"Provided config for {module_name=} could not be found. Available ones for `diffusers` are: {available_configs_diffusers}.)" - if available_configs_transformers is not None: - msg += f" Available ones for `diffusers` are: {available_configs_transformers}." - raise ValueError(msg) - - def _check_backend_availability(self, quant_backend: str): - quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() - - available_backends_transformers = ( - list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None - ) - available_backends_diffusers = list(quant_config_mapping_diffusers.keys()) - - if ( - available_backends_transformers and quant_backend not in available_backends_transformers - ) or quant_backend not in quant_config_mapping_diffusers: - error_message = f"Provided quant_backend={quant_backend} was not found." - if available_backends_transformers: - error_message += f"\nAvailable ones (transformers): {available_backends_transformers}." - error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}." - raise ValueError(error_message) - - def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None): - quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() - - quant_mapping = self.quant_mapping - modules_to_quantize = self.modules_to_quantize - - # Granular case - if self.is_granular and module_name in quant_mapping: - logger.debug(f"Initializing quantization config class for {module_name}.") - config = quant_mapping[module_name] - return config - - # Global config case - else: - should_quantize = False - # Only quantize the modules requested for. - if modules_to_quantize and module_name in modules_to_quantize: - should_quantize = True - # No specification for `modules_to_quantize` means all modules should be quantized. - elif not self.is_granular and not modules_to_quantize: - should_quantize = True - - if should_quantize: - logger.debug(f"Initializing quantization config class for {module_name}.") - mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers - quant_config_cls = mapping_to_use[self.quant_backend] - # If `quant_kwargs` is None we default to initializing with the defaults of `quant_config_cls`. - quant_kwargs = self.quant_kwargs or {} - return quant_config_cls(**quant_kwargs) - - # Fallback: no applicable configuration found. - return None - - def _get_quant_config_list(self): - if is_transformers_available(): - from transformers.quantizers.auto import ( - AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers, - ) - else: - quant_config_mapping_transformers = None - - from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers - - return quant_config_mapping_transformers, quant_config_mapping_diffusers diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py index ce214ae7bc17..d9874cc282ae 100644 --- a/src/diffusers/quantizers/auto.py +++ b/src/diffusers/quantizers/auto.py @@ -26,10 +26,8 @@ GGUFQuantizationConfig, QuantizationConfigMixin, QuantizationMethod, - QuantoConfig, TorchAoConfig, ) -from .quanto import QuantoQuantizer from .torchao import TorchAoHfQuantizer @@ -37,7 +35,6 @@ "bitsandbytes_4bit": BnB4BitDiffusersQuantizer, "bitsandbytes_8bit": BnB8BitDiffusersQuantizer, "gguf": GGUFQuantizer, - "quanto": QuantoQuantizer, "torchao": TorchAoHfQuantizer, } @@ -45,7 +42,6 @@ "bitsandbytes_4bit": BitsAndBytesConfig, "bitsandbytes_8bit": BitsAndBytesConfig, "gguf": GGUFQuantizationConfig, - "quanto": QuantoConfig, "torchao": TorchAoConfig, } diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index f4aa1504534c..ada75588a42a 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -135,7 +135,6 @@ def create_quantized_param( target_device: "torch.device", state_dict: Dict[str, Any], unexpected_keys: Optional[List[str]] = None, - **kwargs, ): import bitsandbytes as bnb @@ -446,7 +445,6 @@ def create_quantized_param( target_device: "torch.device", state_dict: Dict[str, Any], unexpected_keys: Optional[List[str]] = None, - **kwargs, ): import bitsandbytes as bnb diff --git a/src/diffusers/quantizers/gguf/gguf_quantizer.py b/src/diffusers/quantizers/gguf/gguf_quantizer.py index 6da69c7bd60c..0c760e277ce4 100644 --- a/src/diffusers/quantizers/gguf/gguf_quantizer.py +++ b/src/diffusers/quantizers/gguf/gguf_quantizer.py @@ -108,7 +108,6 @@ def create_quantized_param( target_device: "torch.device", state_dict: Optional[Dict[str, Any]] = None, unexpected_keys: Optional[List[str]] = None, - **kwargs, ): module, tensor_name = get_module_from_name(model, param_name) if tensor_name not in module._parameters and tensor_name not in module._buffers: diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 0bc433be0ff3..4fac8dd3829f 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -45,7 +45,6 @@ class QuantizationMethod(str, Enum): BITS_AND_BYTES = "bitsandbytes" GGUF = "gguf" TORCHAO = "torchao" - QUANTO = "quanto" if is_torchao_available(): @@ -687,38 +686,3 @@ def __repr__(self): return ( f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True, cls=TorchAoJSONEncoder)}\n" ) - - -@dataclass -class QuantoConfig(QuantizationConfigMixin): - """ - This is a wrapper class about all possible attributes and features that you can play with a model that has been - loaded using `quanto`. - - Args: - weights_dtype (`str`, *optional*, defaults to `"int8"`): - The target dtype for the weights after quantization. Supported values are ("float8","int8","int4","int2") - modules_to_not_convert (`list`, *optional*, default to `None`): - The list of modules to not quantize, useful for quantizing models that explicitly require to have some - modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers). - """ - - def __init__( - self, - weights_dtype: str = "int8", - modules_to_not_convert: Optional[List[str]] = None, - **kwargs, - ): - self.quant_method = QuantizationMethod.QUANTO - self.weights_dtype = weights_dtype - self.modules_to_not_convert = modules_to_not_convert - - self.post_init() - - def post_init(self): - r""" - Safety checker that arguments are correct - """ - accepted_weights = ["float8", "int8", "int4", "int2"] - if self.weights_dtype not in accepted_weights: - raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}") diff --git a/src/diffusers/quantizers/quanto/__init__.py b/src/diffusers/quantizers/quanto/__init__.py deleted file mode 100644 index a4e8a1f41a1e..000000000000 --- a/src/diffusers/quantizers/quanto/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .quanto_quantizer import QuantoQuantizer diff --git a/src/diffusers/quantizers/quanto/quanto_quantizer.py b/src/diffusers/quantizers/quanto/quanto_quantizer.py deleted file mode 100644 index 0120163804c9..000000000000 --- a/src/diffusers/quantizers/quanto/quanto_quantizer.py +++ /dev/null @@ -1,177 +0,0 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Union - -from diffusers.utils.import_utils import is_optimum_quanto_version - -from ...utils import ( - get_module_from_name, - is_accelerate_available, - is_accelerate_version, - is_optimum_quanto_available, - is_torch_available, - logging, -) -from ..base import DiffusersQuantizer - - -if TYPE_CHECKING: - from ...models.modeling_utils import ModelMixin - - -if is_torch_available(): - import torch - -if is_accelerate_available(): - from accelerate.utils import CustomDtype, set_module_tensor_to_device - -if is_optimum_quanto_available(): - from .utils import _replace_with_quanto_layers - -logger = logging.get_logger(__name__) - - -class QuantoQuantizer(DiffusersQuantizer): - r""" - Diffusers Quantizer for Optimum Quanto - """ - - use_keep_in_fp32_modules = True - requires_calibration = False - required_packages = ["quanto", "accelerate"] - - def __init__(self, quantization_config, **kwargs): - super().__init__(quantization_config, **kwargs) - - def validate_environment(self, *args, **kwargs): - if not is_optimum_quanto_available(): - raise ImportError( - "Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)" - ) - if not is_optimum_quanto_version(">=", "0.2.6"): - raise ImportError( - "Loading an optimum-quanto quantized model requires `optimum-quanto>=0.2.6`. " - "Please upgrade your installation with `pip install --upgrade optimum-quanto" - ) - - if not is_accelerate_available(): - raise ImportError( - "Loading an optimum-quanto quantized model requires accelerate library (`pip install accelerate`)" - ) - - device_map = kwargs.get("device_map", None) - if isinstance(device_map, dict) and len(device_map.keys()) > 1: - raise ValueError( - "`device_map` for multi-GPU inference or CPU/disk offload is currently not supported with Diffusers and the Quanto backend" - ) - - def check_if_quantized_param( - self, - model: "ModelMixin", - param_value: "torch.Tensor", - param_name: str, - state_dict: Dict[str, Any], - **kwargs, - ): - # Quanto imports diffusers internally. This is here to prevent circular imports - from optimum.quanto import QModuleMixin, QTensor - from optimum.quanto.tensor.packed import PackedTensor - - module, tensor_name = get_module_from_name(model, param_name) - if self.pre_quantized and any(isinstance(module, t) for t in [QTensor, PackedTensor]): - return True - elif isinstance(module, QModuleMixin) and "weight" in tensor_name: - return not module.frozen - - return False - - def create_quantized_param( - self, - model: "ModelMixin", - param_value: "torch.Tensor", - param_name: str, - target_device: "torch.device", - *args, - **kwargs, - ): - """ - Create the quantized parameter by calling .freeze() after setting it to the module. - """ - - dtype = kwargs.get("dtype", torch.float32) - module, tensor_name = get_module_from_name(model, param_name) - if self.pre_quantized: - setattr(module, tensor_name, param_value) - else: - set_module_tensor_to_device(model, param_name, target_device, param_value, dtype) - module.freeze() - module.weight.requires_grad = False - - def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: - max_memory = {key: val * 0.90 for key, val in max_memory.items()} - return max_memory - - def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": - if is_accelerate_version(">=", "0.27.0"): - mapping = { - "int8": torch.int8, - "float8": CustomDtype.FP8, - "int4": CustomDtype.INT4, - "int2": CustomDtype.INT2, - } - target_dtype = mapping[self.quantization_config.weights_dtype] - - return target_dtype - - def update_torch_dtype(self, torch_dtype: "torch.dtype" = None) -> "torch.dtype": - if torch_dtype is None: - logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.") - torch_dtype = torch.float32 - return torch_dtype - - def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]: - # Quanto imports diffusers internally. This is here to prevent circular imports - from optimum.quanto import QModuleMixin - - not_missing_keys = [] - for name, module in model.named_modules(): - if isinstance(module, QModuleMixin): - for missing in missing_keys: - if ( - (name in missing or name in f"{prefix}.{missing}") - and not missing.endswith(".weight") - and not missing.endswith(".bias") - ): - not_missing_keys.append(missing) - return [k for k in missing_keys if k not in not_missing_keys] - - def _process_model_before_weight_loading( - self, - model: "ModelMixin", - device_map, - keep_in_fp32_modules: List[str] = [], - **kwargs, - ): - self.modules_to_not_convert = self.quantization_config.modules_to_not_convert - - if not isinstance(self.modules_to_not_convert, list): - self.modules_to_not_convert = [self.modules_to_not_convert] - - self.modules_to_not_convert.extend(keep_in_fp32_modules) - - model = _replace_with_quanto_layers( - model, - modules_to_not_convert=self.modules_to_not_convert, - quantization_config=self.quantization_config, - pre_quantized=self.pre_quantized, - ) - model.config.quantization_config = self.quantization_config - - def _process_model_after_weight_loading(self, model, **kwargs): - return model - - @property - def is_trainable(self): - return True - - @property - def is_serializable(self): - return True diff --git a/src/diffusers/quantizers/quanto/utils.py b/src/diffusers/quantizers/quanto/utils.py deleted file mode 100644 index 6f41fd36b43a..000000000000 --- a/src/diffusers/quantizers/quanto/utils.py +++ /dev/null @@ -1,60 +0,0 @@ -import torch.nn as nn - -from ...utils import is_accelerate_available, logging - - -logger = logging.get_logger(__name__) - -if is_accelerate_available(): - from accelerate import init_empty_weights - - -def _replace_with_quanto_layers(model, quantization_config, modules_to_not_convert: list, pre_quantized=False): - # Quanto imports diffusers internally. These are placed here to avoid circular imports - from optimum.quanto import QLinear, freeze, qfloat8, qint2, qint4, qint8 - - def _get_weight_type(dtype: str): - return {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}[dtype] - - def _replace_layers(model, quantization_config, modules_to_not_convert): - has_children = list(model.children()) - if not has_children: - return model - - for name, module in model.named_children(): - _replace_layers(module, quantization_config, modules_to_not_convert) - - if name in modules_to_not_convert: - continue - - if isinstance(module, nn.Linear): - with init_empty_weights(): - qlinear = QLinear( - in_features=module.in_features, - out_features=module.out_features, - bias=module.bias is not None, - dtype=module.weight.dtype, - weights=_get_weight_type(quantization_config.weights_dtype), - ) - model._modules[name] = qlinear - model._modules[name].source_cls = type(module) - model._modules[name].requires_grad_(False) - - return model - - model = _replace_layers(model, quantization_config, modules_to_not_convert) - has_been_replaced = any(isinstance(replaced_module, QLinear) for _, replaced_module in model.named_modules()) - - if not has_been_replaced: - logger.warning( - f"{model.__class__.__name__} does not appear to have any `nn.Linear` modules. Quantization will not be applied." - " Please check your model architecture, or submit an issue on Github if you think this is a bug." - " https://github.com/huggingface/diffusers/issues/new" - ) - - # We need to freeze the pre_quantized model in order for the loaded state_dict and model state dict - # to match when trying to load weights with load_model_dict_into_meta - if pre_quantized: - freeze(model) - - return model diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index f9fb217ed6bd..e86ce2f64278 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -23,14 +23,7 @@ from packaging import version -from ...utils import ( - get_module_from_name, - is_torch_available, - is_torch_version, - is_torchao_available, - is_torchao_version, - logging, -) +from ...utils import get_module_from_name, is_torch_available, is_torch_version, is_torchao_available, logging from ..base import DiffusersQuantizer @@ -69,43 +62,6 @@ from torchao.quantization import quantize_ -def _update_torch_safe_globals(): - safe_globals = [ - (torch.uint1, "torch.uint1"), - (torch.uint2, "torch.uint2"), - (torch.uint3, "torch.uint3"), - (torch.uint4, "torch.uint4"), - (torch.uint5, "torch.uint5"), - (torch.uint6, "torch.uint6"), - (torch.uint7, "torch.uint7"), - ] - try: - from torchao.dtypes import NF4Tensor - from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl - from torchao.dtypes.uintx.uint4_layout import UInt4Tensor - from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor - - safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor]) - - except (ImportError, ModuleNotFoundError) as e: - logger.warning( - "Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`" - ) - logger.debug(e) - - finally: - torch.serialization.add_safe_globals(safe_globals=safe_globals) - - -if ( - is_torch_available() - and is_torch_version(">=", "2.6.0") - and is_torchao_available() - and is_torchao_version(">=", "0.7.0") -): - _update_torch_safe_globals() - - logger = logging.get_logger(__name__) @@ -259,7 +215,6 @@ def create_quantized_param( target_device: "torch.device", state_dict: Dict[str, Any], unexpected_keys: List[str], - **kwargs, ): r""" Each nn.Linear layer that needs to be quantized is processsed here. First, we set the value the weight tensor, diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index cbb27e5fad63..e3bff7582cd9 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -377,7 +377,6 @@ def step( s_tmax: float = float("inf"), s_noise: float = 1.0, generator: Optional[torch.Generator] = None, - per_token_timesteps: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]: """ @@ -398,8 +397,6 @@ def step( Scaling factor for noise added to the sample. generator (`torch.Generator`, *optional*): A random number generator. - per_token_timesteps (`torch.Tensor`, *optional*): - The timesteps for each token in the sample. return_dict (`bool`): Whether or not to return a [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple. @@ -430,26 +427,16 @@ def step( # Upcast to avoid precision issues when computing prev_sample sample = sample.to(torch.float32) - if per_token_timesteps is not None: - per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps + sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] - sigmas = self.sigmas[:, None, None] - lower_mask = sigmas < per_token_sigmas[None] - 1e-6 - lower_sigmas = lower_mask * sigmas - lower_sigmas, _ = lower_sigmas.max(dim=0) - dt = (per_token_sigmas - lower_sigmas)[..., None] - else: - sigma = self.sigmas[self.step_index] - sigma_next = self.sigmas[self.step_index + 1] - dt = sigma_next - sigma + prev_sample = sample + (sigma_next - sigma) * model_output - prev_sample = sample + dt * model_output + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) # upon completion increase step index by one self._step_index += 1 - if per_token_timesteps is None: - # Cast sample back to model compatible dtype - prev_sample = prev_sample.to(model_output.dtype) if not return_dict: return (prev_sample,) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 50a470772772..6702ea2efbc8 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -79,8 +79,6 @@ is_matplotlib_available, is_note_seq_available, is_onnx_available, - is_optimum_quanto_available, - is_optimum_quanto_version, is_peft_available, is_peft_version, is_safetensors_available, @@ -94,7 +92,6 @@ is_torch_xla_available, is_torch_xla_version, is_torchao_available, - is_torchao_version, is_torchsde_available, is_torchvision_available, is_transformers_available, diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index fa12318f4714..3f88f347710f 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -56,14 +56,3 @@ if USE_PEFT_BACKEND and _CHECK_PEFT: dep_version_check("peft") - - -DECODE_ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/" -DECODE_ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/" -DECODE_ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/" -DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/" - - -ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/" -ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/" -ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/" diff --git a/src/diffusers/utils/dummy_bitsandbytes_objects.py b/src/diffusers/utils/dummy_bitsandbytes_objects.py deleted file mode 100644 index 2dc589428de9..000000000000 --- a/src/diffusers/utils/dummy_bitsandbytes_objects.py +++ /dev/null @@ -1,17 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. -from ..utils import DummyObject, requires_backends - - -class BitsAndBytesConfig(metaclass=DummyObject): - _backends = ["bitsandbytes"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["bitsandbytes"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["bitsandbytes"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["bitsandbytes"]) diff --git a/src/diffusers/utils/dummy_gguf_objects.py b/src/diffusers/utils/dummy_gguf_objects.py deleted file mode 100644 index 4a6d9a060a13..000000000000 --- a/src/diffusers/utils/dummy_gguf_objects.py +++ /dev/null @@ -1,17 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. -from ..utils import DummyObject, requires_backends - - -class GGUFQuantizationConfig(metaclass=DummyObject): - _backends = ["gguf"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["gguf"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["gguf"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["gguf"]) diff --git a/src/diffusers/utils/dummy_optimum_quanto_objects.py b/src/diffusers/utils/dummy_optimum_quanto_objects.py deleted file mode 100644 index 44f8eaffc246..000000000000 --- a/src/diffusers/utils/dummy_optimum_quanto_objects.py +++ /dev/null @@ -1,17 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. -from ..utils import DummyObject, requires_backends - - -class QuantoConfig(metaclass=DummyObject): - _backends = ["optimum_quanto"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["optimum_quanto"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["optimum_quanto"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["optimum_quanto"]) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 0c916bbbc1bc..ded30d16cf93 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -362,21 +362,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class CogView4ControlPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - class CogView4Pipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -1217,21 +1202,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class LTXConditionPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - class LTXImageToVideoPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -1262,21 +1232,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class Lumina2Pipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - class Lumina2Text2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -1292,21 +1247,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class LuminaPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - class LuminaText2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/src/diffusers/utils/dummy_torchao_objects.py b/src/diffusers/utils/dummy_torchao_objects.py deleted file mode 100644 index 16f0f6a55f64..000000000000 --- a/src/diffusers/utils/dummy_torchao_objects.py +++ /dev/null @@ -1,17 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. -from ..utils import DummyObject, requires_backends - - -class TorchAoConfig(metaclass=DummyObject): - _backends = ["torchao"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torchao"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torchao"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torchao"]) diff --git a/src/diffusers/utils/export_utils.py b/src/diffusers/utils/export_utils.py index 30d2c8bebd8e..00805433ceba 100644 --- a/src/diffusers/utils/export_utils.py +++ b/src/diffusers/utils/export_utils.py @@ -3,7 +3,7 @@ import struct import tempfile from contextlib import contextmanager -from typing import List, Optional, Union +from typing import List, Union import numpy as np import PIL.Image @@ -139,31 +139,8 @@ def _legacy_export_to_video( def export_to_video( - video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], - output_video_path: str = None, - fps: int = 10, - quality: float = 5.0, - bitrate: Optional[int] = None, - macro_block_size: Optional[int] = 16, + video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 10 ) -> str: - """ - quality: - Video output quality. Default is 5. Uses variable bit rate. Highest quality is 10, lowest is 0. Set to None to - prevent variable bitrate flags to FFMPEG so you can manually specify them using output_params instead. - Specifying a fixed bitrate using `bitrate` disables this parameter. - - bitrate: - Set a constant bitrate for the video encoding. Default is None causing `quality` parameter to be used instead. - Better quality videos with smaller file sizes will result from using the `quality` variable bitrate parameter - rather than specifiying a fixed bitrate with this parameter. - - macro_block_size: - Size constraint for video. Width and height, must be divisible by this number. If not divisible by this number - imageio will tell ffmpeg to scale the image up to the next closest size divisible by this number. Most codecs - are compatible with a macroblock size of 16 (default), some can go smaller (4, 8). To disable this automatic - feature set it to None or 1, however be warned many players can't decode videos that are odd in size and some - codecs will produce poor results or fail. See https://en.wikipedia.org/wiki/Macroblock. - """ # TODO: Dhruv. Remove by Diffusers release 0.33.0 # Added to prevent breaking existing code if not is_imageio_available(): @@ -200,9 +177,7 @@ def export_to_video( elif isinstance(video_frames[0], PIL.Image.Image): video_frames = [np.array(frame) for frame in video_frames] - with imageio.get_writer( - output_video_path, fps=fps, quality=quality, bitrate=bitrate, macro_block_size=macro_block_size - ) as writer: + with imageio.get_writer(output_video_path, fps=fps) as writer: for frame in video_frames: writer.append_data(frame) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 98b9c75451c8..ae1b9cae6edc 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -25,6 +25,7 @@ from typing import Any, Union from huggingface_hub.utils import is_jinja_available # noqa: F401 +from packaging import version from packaging.version import Version, parse from . import logging @@ -51,30 +52,36 @@ STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} -_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ) - - -def _is_package_available(pkg_name: str): - pkg_exists = importlib.util.find_spec(pkg_name) is not None - pkg_version = "N/A" - - if pkg_exists: - try: - pkg_version = importlib_metadata.version(pkg_name) - logger.debug(f"Successfully imported {pkg_name} version {pkg_version}") - except (ImportError, importlib_metadata.PackageNotFoundError): - pkg_exists = False - - return pkg_exists, pkg_version - - +_torch_version = "N/A" if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: - _torch_available, _torch_version = _is_package_available("torch") - + _torch_available = importlib.util.find_spec("torch") is not None + if _torch_available: + try: + _torch_version = importlib_metadata.version("torch") + logger.info(f"PyTorch version {_torch_version} available.") + except importlib_metadata.PackageNotFoundError: + _torch_available = False else: logger.info("Disabling PyTorch because USE_TORCH is set") _torch_available = False +_torch_xla_available = importlib.util.find_spec("torch_xla") is not None +if _torch_xla_available: + try: + _torch_xla_version = importlib_metadata.version("torch_xla") + logger.info(f"PyTorch XLA version {_torch_xla_version} available.") + except ImportError: + _torch_xla_available = False + +# check whether torch_npu is available +_torch_npu_available = importlib.util.find_spec("torch_npu") is not None +if _torch_npu_available: + try: + _torch_npu_version = importlib_metadata.version("torch_npu") + logger.info(f"torch_npu version {_torch_npu_version} available.") + except ImportError: + _torch_npu_available = False + _jax_version = "N/A" _flax_version = "N/A" if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: @@ -90,12 +97,47 @@ def _is_package_available(pkg_name: str): _flax_available = False if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES: - _safetensors_available, _safetensors_version = _is_package_available("safetensors") - + _safetensors_available = importlib.util.find_spec("safetensors") is not None + if _safetensors_available: + try: + _safetensors_version = importlib_metadata.version("safetensors") + logger.info(f"Safetensors version {_safetensors_version} available.") + except importlib_metadata.PackageNotFoundError: + _safetensors_available = False else: logger.info("Disabling Safetensors because USE_TF is set") _safetensors_available = False +_transformers_available = importlib.util.find_spec("transformers") is not None +try: + _transformers_version = importlib_metadata.version("transformers") + logger.debug(f"Successfully imported transformers version {_transformers_version}") +except importlib_metadata.PackageNotFoundError: + _transformers_available = False + +_hf_hub_available = importlib.util.find_spec("huggingface_hub") is not None +try: + _hf_hub_version = importlib_metadata.version("huggingface_hub") + logger.debug(f"Successfully imported huggingface_hub version {_hf_hub_version}") +except importlib_metadata.PackageNotFoundError: + _hf_hub_available = False + + +_inflect_available = importlib.util.find_spec("inflect") is not None +try: + _inflect_version = importlib_metadata.version("inflect") + logger.debug(f"Successfully imported inflect version {_inflect_version}") +except importlib_metadata.PackageNotFoundError: + _inflect_available = False + + +_unidecode_available = importlib.util.find_spec("unidecode") is not None +try: + _unidecode_version = importlib_metadata.version("unidecode") + logger.debug(f"Successfully imported unidecode version {_unidecode_version}") +except importlib_metadata.PackageNotFoundError: + _unidecode_available = False + _onnxruntime_version = "N/A" _onnx_available = importlib.util.find_spec("onnxruntime") is not None if _onnx_available: @@ -144,6 +186,85 @@ def _is_package_available(pkg_name: str): except importlib_metadata.PackageNotFoundError: _opencv_available = False +_scipy_available = importlib.util.find_spec("scipy") is not None +try: + _scipy_version = importlib_metadata.version("scipy") + logger.debug(f"Successfully imported scipy version {_scipy_version}") +except importlib_metadata.PackageNotFoundError: + _scipy_available = False + +_librosa_available = importlib.util.find_spec("librosa") is not None +try: + _librosa_version = importlib_metadata.version("librosa") + logger.debug(f"Successfully imported librosa version {_librosa_version}") +except importlib_metadata.PackageNotFoundError: + _librosa_available = False + +_accelerate_available = importlib.util.find_spec("accelerate") is not None +try: + _accelerate_version = importlib_metadata.version("accelerate") + logger.debug(f"Successfully imported accelerate version {_accelerate_version}") +except importlib_metadata.PackageNotFoundError: + _accelerate_available = False + +_xformers_available = importlib.util.find_spec("xformers") is not None +try: + _xformers_version = importlib_metadata.version("xformers") + if _torch_available: + _torch_version = importlib_metadata.version("torch") + if version.Version(_torch_version) < version.Version("1.12"): + raise ValueError("xformers is installed in your environment and requires PyTorch >= 1.12") + + logger.debug(f"Successfully imported xformers version {_xformers_version}") +except importlib_metadata.PackageNotFoundError: + _xformers_available = False + +_k_diffusion_available = importlib.util.find_spec("k_diffusion") is not None +try: + _k_diffusion_version = importlib_metadata.version("k_diffusion") + logger.debug(f"Successfully imported k-diffusion version {_k_diffusion_version}") +except importlib_metadata.PackageNotFoundError: + _k_diffusion_available = False + +_note_seq_available = importlib.util.find_spec("note_seq") is not None +try: + _note_seq_version = importlib_metadata.version("note_seq") + logger.debug(f"Successfully imported note-seq version {_note_seq_version}") +except importlib_metadata.PackageNotFoundError: + _note_seq_available = False + +_wandb_available = importlib.util.find_spec("wandb") is not None +try: + _wandb_version = importlib_metadata.version("wandb") + logger.debug(f"Successfully imported wandb version {_wandb_version }") +except importlib_metadata.PackageNotFoundError: + _wandb_available = False + + +_tensorboard_available = importlib.util.find_spec("tensorboard") +try: + _tensorboard_version = importlib_metadata.version("tensorboard") + logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}") +except importlib_metadata.PackageNotFoundError: + _tensorboard_available = False + + +_compel_available = importlib.util.find_spec("compel") +try: + _compel_version = importlib_metadata.version("compel") + logger.debug(f"Successfully imported compel version {_compel_version}") +except importlib_metadata.PackageNotFoundError: + _compel_available = False + + +_ftfy_available = importlib.util.find_spec("ftfy") is not None +try: + _ftfy_version = importlib_metadata.version("ftfy") + logger.debug(f"Successfully imported ftfy version {_ftfy_version}") +except importlib_metadata.PackageNotFoundError: + _ftfy_available = False + + _bs4_available = importlib.util.find_spec("bs4") is not None try: # importlib metadata under different name @@ -152,6 +273,13 @@ def _is_package_available(pkg_name: str): except importlib_metadata.PackageNotFoundError: _bs4_available = False +_torchsde_available = importlib.util.find_spec("torchsde") is not None +try: + _torchsde_version = importlib_metadata.version("torchsde") + logger.debug(f"Successfully imported torchsde version {_torchsde_version}") +except importlib_metadata.PackageNotFoundError: + _torchsde_available = False + _invisible_watermark_available = importlib.util.find_spec("imwatermark") is not None try: _invisible_watermark_version = importlib_metadata.version("invisible-watermark") @@ -159,42 +287,82 @@ def _is_package_available(pkg_name: str): except importlib_metadata.PackageNotFoundError: _invisible_watermark_available = False -_torch_xla_available, _torch_xla_version = _is_package_available("torch_xla") -_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu") -_transformers_available, _transformers_version = _is_package_available("transformers") -_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub") -_inflect_available, _inflect_version = _is_package_available("inflect") -_unidecode_available, _unidecode_version = _is_package_available("unidecode") -_k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion") -_note_seq_available, _note_seq_version = _is_package_available("note_seq") -_wandb_available, _wandb_version = _is_package_available("wandb") -_tensorboard_available, _tensorboard_version = _is_package_available("tensorboard") -_compel_available, _compel_version = _is_package_available("compel") -_sentencepiece_available, _sentencepiece_version = _is_package_available("sentencepiece") -_torchsde_available, _torchsde_version = _is_package_available("torchsde") -_peft_available, _peft_version = _is_package_available("peft") -_torchvision_available, _torchvision_version = _is_package_available("torchvision") -_matplotlib_available, _matplotlib_version = _is_package_available("matplotlib") -_timm_available, _timm_version = _is_package_available("timm") -_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes") -_imageio_available, _imageio_version = _is_package_available("imageio") -_ftfy_available, _ftfy_version = _is_package_available("ftfy") -_scipy_available, _scipy_version = _is_package_available("scipy") -_librosa_available, _librosa_version = _is_package_available("librosa") -_accelerate_available, _accelerate_version = _is_package_available("accelerate") -_xformers_available, _xformers_version = _is_package_available("xformers") -_gguf_available, _gguf_version = _is_package_available("gguf") -_torchao_available, _torchao_version = _is_package_available("torchao") -_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes") -_torchao_available, _torchao_version = _is_package_available("torchao") - -_optimum_quanto_available = importlib.util.find_spec("optimum") is not None -if _optimum_quanto_available: + +_peft_available = importlib.util.find_spec("peft") is not None +try: + _peft_version = importlib_metadata.version("peft") + logger.debug(f"Successfully imported peft version {_peft_version}") +except importlib_metadata.PackageNotFoundError: + _peft_available = False + +_torchvision_available = importlib.util.find_spec("torchvision") is not None +try: + _torchvision_version = importlib_metadata.version("torchvision") + logger.debug(f"Successfully imported torchvision version {_torchvision_version}") +except importlib_metadata.PackageNotFoundError: + _torchvision_available = False + +_sentencepiece_available = importlib.util.find_spec("sentencepiece") is not None +try: + _sentencepiece_version = importlib_metadata.version("sentencepiece") + logger.info(f"Successfully imported sentencepiece version {_sentencepiece_version}") +except importlib_metadata.PackageNotFoundError: + _sentencepiece_available = False + +_matplotlib_available = importlib.util.find_spec("matplotlib") is not None +try: + _matplotlib_version = importlib_metadata.version("matplotlib") + logger.debug(f"Successfully imported matplotlib version {_matplotlib_version}") +except importlib_metadata.PackageNotFoundError: + _matplotlib_available = False + +_timm_available = importlib.util.find_spec("timm") is not None +if _timm_available: + try: + _timm_version = importlib_metadata.version("timm") + logger.info(f"Timm version {_timm_version} available.") + except importlib_metadata.PackageNotFoundError: + _timm_available = False + + +def is_timm_available(): + return _timm_available + + +_bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None +try: + _bitsandbytes_version = importlib_metadata.version("bitsandbytes") + logger.debug(f"Successfully imported bitsandbytes version {_bitsandbytes_version}") +except importlib_metadata.PackageNotFoundError: + _bitsandbytes_available = False + +_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ) + +_imageio_available = importlib.util.find_spec("imageio") is not None +if _imageio_available: + try: + _imageio_version = importlib_metadata.version("imageio") + logger.debug(f"Successfully imported imageio version {_imageio_version}") + + except importlib_metadata.PackageNotFoundError: + _imageio_available = False + +_is_gguf_available = importlib.util.find_spec("gguf") is not None +if _is_gguf_available: try: - _optimum_quanto_version = importlib_metadata.version("optimum_quanto") - logger.debug(f"Successfully import optimum-quanto version {_optimum_quanto_version}") + _gguf_version = importlib_metadata.version("gguf") + logger.debug(f"Successfully import gguf version {_gguf_version}") except importlib_metadata.PackageNotFoundError: - _optimum_quanto_available = False + _is_gguf_available = False + + +_is_torchao_available = importlib.util.find_spec("torchao") is not None +if _is_torchao_available: + try: + _torchao_version = importlib_metadata.version("torchao") + logger.debug(f"Successfully import torchao version {_torchao_version}") + except importlib_metadata.PackageNotFoundError: + _is_torchao_available = False def is_torch_available(): @@ -318,19 +486,11 @@ def is_imageio_available(): def is_gguf_available(): - return _gguf_available + return _is_gguf_available def is_torchao_available(): - return _torchao_available - - -def is_optimum_quanto_available(): - return _optimum_quanto_available - - -def is_timm_available(): - return _timm_available + return _is_torchao_available # docstyle-ignore @@ -476,11 +636,6 @@ def is_timm_available(): torchao` """ -QUANTO_IMPORT_ERROR = """ -{0} requires the optimum-quanto library but it was not found in your environment. You can install it with pip: `pip -install optimum-quanto` -""" - BACKENDS_MAPPING = OrderedDict( [ ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), @@ -508,7 +663,6 @@ def is_timm_available(): ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)), ("gguf", (is_gguf_available, GGUF_IMPORT_ERROR)), ("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)), - ("quanto", (is_optimum_quanto_available, QUANTO_IMPORT_ERROR)), ] ) @@ -690,26 +844,11 @@ def is_gguf_version(operation: str, version: str): version (`str`): A version string """ - if not _gguf_available: + if not _is_gguf_available: return False return compare_versions(parse(_gguf_version), operation, version) -def is_torchao_version(operation: str, version: str): - """ - Compares the current torchao version to a given reference with an operation. - - Args: - operation (`str`): - A string representation of an operator, such as `">"` or `"<="` - version (`str`): - A version string - """ - if not _torchao_available: - return False - return compare_versions(parse(_torchao_version), operation, version) - - def is_k_diffusion_version(operation: str, version: str): """ Compares the current k-diffusion version to a given reference with an operation. @@ -725,21 +864,6 @@ def is_k_diffusion_version(operation: str, version: str): return compare_versions(parse(_k_diffusion_version), operation, version) -def is_optimum_quanto_version(operation: str, version: str): - """ - Compares the current Accelerate version to a given reference with an operation. - - Args: - operation (`str`): - A string representation of an operator, such as `">"` or `"<="` - version (`str`): - A version string - """ - if not _optimum_quanto_available: - return False - return compare_versions(parse(_optimum_quanto_version), operation, version) - - def get_objects_from_module(module): """ Returns a dict of object names and values in a module, while skipping private/internal objects diff --git a/src/diffusers/utils/remote_utils.py b/src/diffusers/utils/remote_utils.py index fbce33d97f54..12bcc94af74f 100644 --- a/src/diffusers/utils/remote_utils.py +++ b/src/diffusers/utils/remote_utils.py @@ -55,7 +55,7 @@ def detect_image_type(data: bytes) -> str: return "unknown" -def check_inputs_decode( +def check_inputs( endpoint: str, tensor: "torch.Tensor", processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None, @@ -89,7 +89,7 @@ def check_inputs_decode( ) -def postprocess_decode( +def postprocess( response: requests.Response, processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None, output_type: Literal["mp4", "pil", "pt"] = "pil", @@ -142,7 +142,7 @@ def postprocess_decode( return output -def prepare_decode( +def prepare( tensor: "torch.Tensor", processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None, do_scaling: bool = True, @@ -293,7 +293,7 @@ def remote_decode( standard_warn=False, ) output_tensor_type = "binary" - check_inputs_decode( + check_inputs( endpoint, tensor, processor, @@ -309,7 +309,7 @@ def remote_decode( height, width, ) - kwargs = prepare_decode( + kwargs = prepare( tensor=tensor, processor=processor, do_scaling=do_scaling, @@ -324,7 +324,7 @@ def remote_decode( response = requests.post(endpoint, **kwargs) if not response.ok: raise RuntimeError(response.json()) - output = postprocess_decode( + output = postprocess( response=response, processor=processor, output_type=output_type, @@ -332,94 +332,3 @@ def remote_decode( partial_postprocess=partial_postprocess, ) return output - - -def check_inputs_encode( - endpoint: str, - image: Union["torch.Tensor", Image.Image], - scaling_factor: Optional[float] = None, - shift_factor: Optional[float] = None, -): - pass - - -def postprocess_encode( - response: requests.Response, -): - output_tensor = response.content - parameters = response.headers - shape = json.loads(parameters["shape"]) - dtype = parameters["dtype"] - torch_dtype = DTYPE_MAP[dtype] - output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape) - return output_tensor - - -def prepare_encode( - image: Union["torch.Tensor", Image.Image], - scaling_factor: Optional[float] = None, - shift_factor: Optional[float] = None, -): - headers = {} - parameters = {} - if scaling_factor is not None: - parameters["scaling_factor"] = scaling_factor - if shift_factor is not None: - parameters["shift_factor"] = shift_factor - if isinstance(image, torch.Tensor): - data = safetensors.torch._tobytes(image, "tensor") - parameters["shape"] = list(image.shape) - parameters["dtype"] = str(image.dtype).split(".")[-1] - else: - buffer = io.BytesIO() - image.save(buffer, format="PNG") - data = buffer.getvalue() - return {"data": data, "params": parameters, "headers": headers} - - -def remote_encode( - endpoint: str, - image: Union["torch.Tensor", Image.Image], - scaling_factor: Optional[float] = None, - shift_factor: Optional[float] = None, -) -> "torch.Tensor": - """ - Hugging Face Hybrid Inference that allow running VAE encode remotely. - - Args: - endpoint (`str`): - Endpoint for Remote Decode. - image (`torch.Tensor` or `PIL.Image.Image`): - Image to be encoded. - scaling_factor (`float`, *optional*): - Scaling is applied when passed e.g. [`latents * self.vae.config.scaling_factor`]. - - SD v1: 0.18215 - - SD XL: 0.13025 - - Flux: 0.3611 - If `None`, input must be passed with scaling applied. - shift_factor (`float`, *optional*): - Shift is applied when passed e.g. `latents - self.vae.config.shift_factor`. - - Flux: 0.1159 - If `None`, input must be passed with scaling applied. - - Returns: - output (`torch.Tensor`). - """ - check_inputs_encode( - endpoint, - image, - scaling_factor, - shift_factor, - ) - kwargs = prepare_encode( - image=image, - scaling_factor=scaling_factor, - shift_factor=shift_factor, - ) - response = requests.post(endpoint, **kwargs) - if not response.ok: - raise RuntimeError(response.json()) - output = postprocess_encode( - response=response, - ) - return output diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 2a3feae967d7..7eda13716025 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -101,8 +101,6 @@ mps_backend_registered = hasattr(torch.backends, "mps") torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device - from .torch_utils import get_torch_cuda_device_capability - def torch_all_close(a, b, *args, **kwargs): if not is_torch_available(): @@ -284,20 +282,6 @@ def require_torch_gpu(test_case): ) -def require_torch_cuda_compatibility(expected_compute_capability): - def decorator(test_case): - if not torch.cuda.is_available(): - return unittest.skip(test_case) - else: - current_compute_capability = get_torch_cuda_device_capability() - return unittest.skipUnless( - float(current_compute_capability) == float(expected_compute_capability), - "Test not supported for this compute capability.", - ) - - return decorator - - # These decorators are for accelerator-specific behaviours that are not GPU-specific def require_torch_accelerator(test_case): """Decorator marking a test that requires an accelerator backend and PyTorch.""" diff --git a/tests/lora/test_lora_layers_cogview4.py b/tests/lora/test_lora_layers_cogview4.py deleted file mode 100644 index 178de2069b7e..000000000000 --- a/tests/lora/test_lora_layers_cogview4.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright 2024 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys -import tempfile -import unittest - -import numpy as np -import torch -from transformers import AutoTokenizer, GlmModel - -from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler -from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device - - -sys.path.append(".") - -from utils import PeftLoraLoaderMixinTests # noqa: E402 - - -class TokenizerWrapper: - @staticmethod - def from_pretrained(*args, **kwargs): - return AutoTokenizer.from_pretrained( - "hf-internal-testing/tiny-random-cogview4", subfolder="tokenizer", trust_remote_code=True - ) - - -@require_peft_backend -@skip_mps -class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): - pipeline_class = CogView4Pipeline - scheduler_cls = FlowMatchEulerDiscreteScheduler - scheduler_classes = [FlowMatchEulerDiscreteScheduler] - scheduler_kwargs = {} - - transformer_kwargs = { - "patch_size": 2, - "in_channels": 4, - "num_layers": 2, - "attention_head_dim": 4, - "num_attention_heads": 4, - "out_channels": 4, - "text_embed_dim": 32, - "time_embed_dim": 8, - "condition_dim": 4, - } - transformer_cls = CogView4Transformer2DModel - vae_kwargs = { - "block_out_channels": [32, 64], - "in_channels": 3, - "out_channels": 3, - "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], - "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], - "latent_channels": 4, - "sample_size": 128, - } - vae_cls = AutoencoderKL - tokenizer_cls, tokenizer_id, tokenizer_subfolder = ( - TokenizerWrapper, - "hf-internal-testing/tiny-random-cogview4", - "tokenizer", - ) - text_encoder_cls, text_encoder_id, text_encoder_subfolder = ( - GlmModel, - "hf-internal-testing/tiny-random-cogview4", - "text_encoder", - ) - - @property - def output_shape(self): - return (1, 32, 32, 3) - - def get_dummy_inputs(self, with_generator=True): - batch_size = 1 - sequence_length = 16 - num_channels = 4 - sizes = (4, 4) - - generator = torch.manual_seed(0) - noise = floats_tensor((batch_size, num_channels) + sizes) - input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) - - pipeline_inputs = { - "prompt": "", - "num_inference_steps": 1, - "guidance_scale": 6.0, - "height": 32, - "width": 32, - "max_sequence_length": sequence_length, - "output_type": "np", - } - if with_generator: - pipeline_inputs.update({"generator": generator}) - - return noise, input_ids, pipeline_inputs - - def test_simple_inference_with_text_lora_denoiser_fused_multi(self): - super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) - - def test_simple_inference_with_text_denoiser_lora_unfused(self): - super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) - - def test_simple_inference_save_pretrained(self): - """ - Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained - """ - for scheduler_cls in self.scheduler_classes: - components, _, _ = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - _, _, inputs = self.get_dummy_inputs(with_generator=False) - - output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.assertTrue(output_no_lora.shape == self.output_shape) - - images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - - with tempfile.TemporaryDirectory() as tmpdirname: - pipe.save_pretrained(tmpdirname) - - pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname) - pipe_from_pretrained.to(torch_device) - - images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0] - - self.assertTrue( - np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3), - "Loading from saved checkpoints should give same results.", - ) - - @unittest.skip("Not supported in CogView4.") - def test_simple_inference_with_text_denoiser_block_scale(self): - pass - - @unittest.skip("Not supported in CogView4.") - def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): - pass - - @unittest.skip("Not supported in CogView4.") - def test_modify_padding_mode(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 860aa6511689..06bbcc62a0d5 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -371,8 +371,9 @@ def test_with_norm_in_state_dict(self): lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue( - "The provided state dict contains normalization layers in addition to LoRA layers" - in cap_logger.out + cap_logger.out.startswith( + "The provided state dict contains normalization layers in addition to LoRA layers" + ) ) self.assertTrue(len(pipe.transformer._transformer_norm_layers) > 0) @@ -391,7 +392,7 @@ def test_with_norm_in_state_dict(self): pipe.load_lora_weights(norm_state_dict) self.assertTrue( - "Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out + cap_logger.out.startswith("Unsupported keys found in state dict when trying to load normalization layers") ) def test_lora_parameter_expanded_shapes(self): diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 8cdb43c9d085..17f6c9ccdf98 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1948,50 +1948,6 @@ def set_pad_mode(network, mode="circular"): _, _, inputs = self.get_dummy_inputs() _ = pipe(**inputs)[0] - def test_logs_info_when_no_lora_keys_found(self): - scheduler_cls = self.scheduler_classes[0] - # Skip text encoder check for now as that is handled with `transformers`. - components, _, _ = self.get_dummy_components(scheduler_cls) - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - _, _, inputs = self.get_dummy_inputs(with_generator=False) - original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] - - no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)} - logger = logging.get_logger("diffusers.loaders.peft") - logger.setLevel(logging.WARNING) - - with CaptureLogger(logger) as cap_logger: - pipe.load_lora_weights(no_op_state_dict) - out_after_lora_attempt = pipe(**inputs, generator=torch.manual_seed(0))[0] - - denoiser = getattr(pipe, "unet") if self.unet_kwargs is not None else getattr(pipe, "transformer") - self.assertTrue(cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}")) - self.assertTrue(np.allclose(original_out, out_after_lora_attempt, atol=1e-5, rtol=1e-5)) - - # test only for text encoder - for lora_module in self.pipeline_class._lora_loadable_modules: - if "text_encoder" in lora_module: - text_encoder = getattr(pipe, lora_module) - if lora_module == "text_encoder": - prefix = "text_encoder" - elif lora_module == "text_encoder_2": - prefix = "text_encoder_2" - - logger = logging.get_logger("diffusers.loaders.lora_base") - logger.setLevel(logging.WARNING) - - with CaptureLogger(logger) as cap_logger: - self.pipeline_class.load_lora_into_text_encoder( - no_op_state_dict, network_alphas=None, text_encoder=text_encoder, prefix=prefix - ) - - self.assertTrue( - cap_logger.out.startswith(f"No LoRA keys associated to {text_encoder.__class__.__name__}") - ) - def test_set_adapters_match_attention_kwargs(self): """Test to check if outputs after `set_adapters()` and attention kwargs match.""" call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys() diff --git a/tests/pipelines/ltx/test_ltx_condition.py b/tests/pipelines/ltx/test_ltx_condition.py deleted file mode 100644 index dbb9a740b433..000000000000 --- a/tests/pipelines/ltx/test_ltx_condition.py +++ /dev/null @@ -1,284 +0,0 @@ -# Copyright 2024 The HuggingFace Team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -import unittest - -import numpy as np -import torch -from transformers import AutoTokenizer, T5EncoderModel - -from diffusers import ( - AutoencoderKLLTXVideo, - FlowMatchEulerDiscreteScheduler, - LTXConditionPipeline, - LTXVideoTransformer3DModel, -) -from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition -from diffusers.utils.testing_utils import enable_full_determinism, torch_device - -from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, to_np - - -enable_full_determinism() - - -class LTXConditionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): - pipeline_class = LTXConditionPipeline - params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} - batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"}) - image_params = TEXT_TO_IMAGE_IMAGE_PARAMS - image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS - required_optional_params = frozenset( - [ - "num_inference_steps", - "generator", - "latents", - "return_dict", - "callback_on_step_end", - "callback_on_step_end_tensor_inputs", - ] - ) - test_xformers_attention = False - - def get_dummy_components(self): - torch.manual_seed(0) - transformer = LTXVideoTransformer3DModel( - in_channels=8, - out_channels=8, - patch_size=1, - patch_size_t=1, - num_attention_heads=4, - attention_head_dim=8, - cross_attention_dim=32, - num_layers=1, - caption_channels=32, - ) - - torch.manual_seed(0) - vae = AutoencoderKLLTXVideo( - in_channels=3, - out_channels=3, - latent_channels=8, - block_out_channels=(8, 8, 8, 8), - decoder_block_out_channels=(8, 8, 8, 8), - layers_per_block=(1, 1, 1, 1, 1), - decoder_layers_per_block=(1, 1, 1, 1, 1), - spatio_temporal_scaling=(True, True, False, False), - decoder_spatio_temporal_scaling=(True, True, False, False), - decoder_inject_noise=(False, False, False, False, False), - upsample_residual=(False, False, False, False), - upsample_factor=(1, 1, 1, 1), - timestep_conditioning=False, - patch_size=1, - patch_size_t=1, - encoder_causal=True, - decoder_causal=False, - ) - vae.use_framewise_encoding = False - vae.use_framewise_decoding = False - - torch.manual_seed(0) - scheduler = FlowMatchEulerDiscreteScheduler() - text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") - - components = { - "transformer": transformer, - "vae": vae, - "scheduler": scheduler, - "text_encoder": text_encoder, - "tokenizer": tokenizer, - } - return components - - def get_dummy_inputs(self, device, seed=0, use_conditions=False): - if str(device).startswith("mps"): - generator = torch.manual_seed(seed) - else: - generator = torch.Generator(device=device).manual_seed(seed) - - image = torch.randn((1, 3, 32, 32), generator=generator, device=device) - if use_conditions: - conditions = LTXVideoCondition( - image=image, - ) - else: - conditions = None - - inputs = { - "conditions": conditions, - "image": None if use_conditions else image, - "prompt": "dance monkey", - "negative_prompt": "", - "generator": generator, - "num_inference_steps": 2, - "guidance_scale": 3.0, - "height": 32, - "width": 32, - # 8 * k + 1 is the recommendation - "num_frames": 9, - "max_sequence_length": 16, - "output_type": "pt", - } - - return inputs - - def test_inference(self): - device = "cpu" - - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe.to(device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(device) - inputs2 = self.get_dummy_inputs(device, use_conditions=True) - video = pipe(**inputs).frames - generated_video = video[0] - video2 = pipe(**inputs2).frames - generated_video2 = video2[0] - - self.assertEqual(generated_video.shape, (9, 3, 32, 32)) - - max_diff = np.abs(generated_video - generated_video2).max() - self.assertLessEqual(max_diff, 1e-3) - - def test_callback_inputs(self): - sig = inspect.signature(self.pipeline_class.__call__) - has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters - has_callback_step_end = "callback_on_step_end" in sig.parameters - - if not (has_callback_tensor_inputs and has_callback_step_end): - return - - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - self.assertTrue( - hasattr(pipe, "_callback_tensor_inputs"), - f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", - ) - - def callback_inputs_subset(pipe, i, t, callback_kwargs): - # iterate over callback args - for tensor_name, tensor_value in callback_kwargs.items(): - # check that we're only passing in allowed tensor inputs - assert tensor_name in pipe._callback_tensor_inputs - - return callback_kwargs - - def callback_inputs_all(pipe, i, t, callback_kwargs): - for tensor_name in pipe._callback_tensor_inputs: - assert tensor_name in callback_kwargs - - # iterate over callback args - for tensor_name, tensor_value in callback_kwargs.items(): - # check that we're only passing in allowed tensor inputs - assert tensor_name in pipe._callback_tensor_inputs - - return callback_kwargs - - inputs = self.get_dummy_inputs(torch_device) - - # Test passing in a subset - inputs["callback_on_step_end"] = callback_inputs_subset - inputs["callback_on_step_end_tensor_inputs"] = ["latents"] - output = pipe(**inputs)[0] - - # Test passing in a everything - inputs["callback_on_step_end"] = callback_inputs_all - inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs - output = pipe(**inputs)[0] - - def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): - is_last = i == (pipe.num_timesteps - 1) - if is_last: - callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) - return callback_kwargs - - inputs["callback_on_step_end"] = callback_inputs_change_tensor - inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs - output = pipe(**inputs)[0] - assert output.abs().sum() < 1e10 - - def test_inference_batch_single_identical(self): - self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3) - - def test_attention_slicing_forward_pass( - self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 - ): - if not self.test_attention_slicing: - return - - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - for component in pipe.components.values(): - if hasattr(component, "set_default_attn_processor"): - component.set_default_attn_processor() - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - generator_device = "cpu" - inputs = self.get_dummy_inputs(generator_device) - output_without_slicing = pipe(**inputs)[0] - - pipe.enable_attention_slicing(slice_size=1) - inputs = self.get_dummy_inputs(generator_device) - output_with_slicing1 = pipe(**inputs)[0] - - pipe.enable_attention_slicing(slice_size=2) - inputs = self.get_dummy_inputs(generator_device) - output_with_slicing2 = pipe(**inputs)[0] - - if test_max_difference: - max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() - max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() - self.assertLess( - max(max_diff1, max_diff2), - expected_max_diff, - "Attention slicing should not affect the inference results", - ) - - def test_vae_tiling(self, expected_diff_max: float = 0.2): - generator_device = "cpu" - components = self.get_dummy_components() - - pipe = self.pipeline_class(**components) - pipe.to("cpu") - pipe.set_progress_bar_config(disable=None) - - # Without tiling - inputs = self.get_dummy_inputs(generator_device) - inputs["height"] = inputs["width"] = 128 - output_without_tiling = pipe(**inputs)[0] - - # With tiling - pipe.vae.enable_tiling( - tile_sample_min_height=96, - tile_sample_min_width=96, - tile_sample_stride_height=64, - tile_sample_stride_width=64, - ) - inputs = self.get_dummy_inputs(generator_device) - inputs["height"] = inputs["width"] = 128 - output_with_tiling = pipe(**inputs)[0] - - self.assertLess( - (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), - expected_diff_max, - "VAE tiling should not affect the inference results", - ) diff --git a/tests/pipelines/lumina/test_lumina_nextdit.py b/tests/pipelines/lumina/test_lumina_nextdit.py index 0c1fe8eb2fcd..034a0185d338 100644 --- a/tests/pipelines/lumina/test_lumina_nextdit.py +++ b/tests/pipelines/lumina/test_lumina_nextdit.py @@ -5,13 +5,7 @@ import torch from transformers import AutoTokenizer, GemmaConfig, GemmaForCausalLM -from diffusers import ( - AutoencoderKL, - FlowMatchEulerDiscreteScheduler, - LuminaNextDiT2DModel, - LuminaPipeline, - LuminaText2ImgPipeline, -) +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline from diffusers.utils.testing_utils import ( backend_empty_cache, numpy_cosine_similarity_distance, @@ -23,8 +17,8 @@ from ..test_pipelines_common import PipelineTesterMixin -class LuminaPipelineFastTests(unittest.TestCase, PipelineTesterMixin): - pipeline_class = LuminaPipeline +class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = LuminaText2ImgPipeline params = frozenset( [ "prompt", @@ -105,17 +99,11 @@ def get_dummy_inputs(self, device, seed=0): def test_xformers_attention_forwardGenerator_pass(self): pass - def test_deprecation_raises_warning(self): - with self.assertWarns(FutureWarning) as warning: - _ = LuminaText2ImgPipeline(**self.get_dummy_components()).to(torch_device) - warning_message = str(warning.warnings[0].message) - assert "renamed to `LuminaPipeline`" in warning_message - @slow @require_torch_accelerator -class LuminaPipelineSlowTests(unittest.TestCase): - pipeline_class = LuminaPipeline +class LuminaText2ImgPipelineSlowTests(unittest.TestCase): + pipeline_class = LuminaText2ImgPipeline repo_id = "Alpha-VLLM/Lumina-Next-SFT-diffusers" def setUp(self): diff --git a/tests/pipelines/lumina2/test_pipeline_lumina2.py b/tests/pipelines/lumina2/test_pipeline_lumina2.py index 33fc870bcd34..aa0571559b45 100644 --- a/tests/pipelines/lumina2/test_pipeline_lumina2.py +++ b/tests/pipelines/lumina2/test_pipeline_lumina2.py @@ -6,17 +6,15 @@ from diffusers import ( AutoencoderKL, FlowMatchEulerDiscreteScheduler, - Lumina2Pipeline, Lumina2Text2ImgPipeline, Lumina2Transformer2DModel, ) -from diffusers.utils.testing_utils import torch_device from ..test_pipelines_common import PipelineTesterMixin -class Lumina2PipelineFastTests(unittest.TestCase, PipelineTesterMixin): - pipeline_class = Lumina2Pipeline +class Lumina2Text2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = Lumina2Text2ImgPipeline params = frozenset( [ "prompt", @@ -117,9 +115,3 @@ def get_dummy_inputs(self, device, seed=0): "output_type": "np", } return inputs - - def test_deprecation_raises_warning(self): - with self.assertWarns(FutureWarning) as warning: - _ = Lumina2Text2ImgPipeline(**self.get_dummy_components()).to(torch_device) - warning_message = str(warning.warnings[0].message) - assert "renamed to `Lumina2Pipeline`" in warning_message diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index 423c2b8ab146..964b55fde651 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -19,7 +19,7 @@ UNet2DConditionModel, ) from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings -from diffusers.utils.testing_utils import require_torch_gpu, torch_device +from diffusers.utils.testing_utils import torch_device class IsSafetensorsCompatibleTests(unittest.TestCase): @@ -826,104 +826,3 @@ def test_video_to_video(self): with io.StringIO() as stderr, contextlib.redirect_stderr(stderr): _ = pipe(**inputs) self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled") - - -@require_torch_gpu -class PipelineDeviceAndDtypeStabilityTests(unittest.TestCase): - expected_pipe_device = torch.device("cuda:0") - expected_pipe_dtype = torch.float64 - - def get_dummy_components_image_generation(self): - cross_attention_dim = 8 - - torch.manual_seed(0) - unet = UNet2DConditionModel( - block_out_channels=(4, 8), - layers_per_block=1, - sample_size=32, - in_channels=4, - out_channels=4, - down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), - up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), - cross_attention_dim=cross_attention_dim, - norm_num_groups=2, - ) - scheduler = DDIMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - clip_sample=False, - set_alpha_to_one=False, - ) - torch.manual_seed(0) - vae = AutoencoderKL( - block_out_channels=[4, 8], - in_channels=3, - out_channels=3, - down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], - up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], - latent_channels=4, - norm_num_groups=2, - ) - torch.manual_seed(0) - text_encoder_config = CLIPTextConfig( - bos_token_id=0, - eos_token_id=2, - hidden_size=cross_attention_dim, - intermediate_size=16, - layer_norm_eps=1e-05, - num_attention_heads=2, - num_hidden_layers=2, - pad_token_id=1, - vocab_size=1000, - ) - text_encoder = CLIPTextModel(text_encoder_config) - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - - components = { - "unet": unet, - "scheduler": scheduler, - "vae": vae, - "text_encoder": text_encoder, - "tokenizer": tokenizer, - "safety_checker": None, - "feature_extractor": None, - "image_encoder": None, - } - return components - - def test_deterministic_device(self): - components = self.get_dummy_components_image_generation() - - pipe = StableDiffusionPipeline(**components) - pipe.to(device=torch_device, dtype=torch.float32) - - pipe.unet.to(device="cpu") - pipe.vae.to(device="cuda") - pipe.text_encoder.to(device="cuda:0") - - pipe_device = pipe.device - - self.assertEqual( - self.expected_pipe_device, - pipe_device, - f"Wrong expected device. Expected {self.expected_pipe_device}. Got {pipe_device}.", - ) - - def test_deterministic_dtype(self): - components = self.get_dummy_components_image_generation() - - pipe = StableDiffusionPipeline(**components) - pipe.to(device=torch_device, dtype=torch.float32) - - pipe.unet.to(dtype=torch.float16) - pipe.vae.to(dtype=torch.float32) - pipe.text_encoder.to(dtype=torch.float64) - - pipe_dtype = pipe.dtype - - self.assertEqual( - self.expected_pipe_dtype, - pipe_dtype, - f"Wrong expected dtype. Expected {self.expected_pipe_dtype}. Got {pipe_dtype}.", - ) diff --git a/tests/quantization/__init__.py b/tests/quantization/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index a80286fbb8dd..6f85e6f38955 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -33,7 +33,6 @@ numpy_cosine_similarity_distance, require_accelerate, require_bitsandbytes_version_greater, - require_peft_backend, require_torch, require_torch_gpu, require_transformers_version_greater, @@ -55,8 +54,29 @@ def get_some_linear_layer(model): if is_torch_available(): import torch + import torch.nn as nn - from ..utils import LoRALayer, get_memory_consumption_stat + class LoRALayer(nn.Module): + """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only + + Taken from + https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77 + """ + + def __init__(self, module: nn.Module, rank: int): + super().__init__() + self.module = module + self.adapter = nn.Sequential( + nn.Linear(module.in_features, rank, bias=False), + nn.Linear(rank, module.out_features, bias=False), + ) + small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5 + nn.init.normal_(self.adapter[0].weight, std=small_std) + nn.init.zeros_(self.adapter[1].weight) + self.adapter.to(module.weight.device) + + def forward(self, input, *args, **kwargs): + return self.module(input, *args, **kwargs) + self.adapter(input) if is_bitsandbytes_available(): @@ -76,8 +96,6 @@ class Base4bitTests(unittest.TestCase): # This was obtained on audace so the number might slightly change expected_rel_difference = 3.69 - expected_memory_saving_ratio = 0.8 - prompt = "a beautiful sunset amidst the mountains." num_inference_steps = 10 seed = 0 @@ -122,10 +140,8 @@ def setUp(self): ) def tearDown(self): - if hasattr(self, "model_fp16"): - del self.model_fp16 - if hasattr(self, "model_4bit"): - del self.model_4bit + del self.model_fp16 + del self.model_4bit gc.collect() torch.cuda.empty_cache() @@ -164,32 +180,6 @@ def test_memory_footprint(self): linear = get_some_linear_layer(self.model_4bit) self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit) - def test_model_memory_usage(self): - # Delete to not let anything interfere. - del self.model_4bit, self.model_fp16 - - # Re-instantiate. - inputs = self.get_dummy_inputs() - inputs = { - k: v.to(device=torch_device, dtype=torch.float16) for k, v in inputs.items() if not isinstance(v, bool) - } - model_fp16 = SD3Transformer2DModel.from_pretrained( - self.model_name, subfolder="transformer", torch_dtype=torch.float16 - ).to(torch_device) - unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs) - del model_fp16 - - nf4_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_compute_dtype=torch.float16, - ) - model_4bit = SD3Transformer2DModel.from_pretrained( - self.model_name, subfolder="transformer", quantization_config=nf4_config, torch_dtype=torch.float16 - ) - quantized_model_memory = get_memory_consumption_stat(model_4bit, inputs) - assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_saving_ratio - def test_original_dtype(self): r""" A simple test to check if the model succesfully stores the original dtype @@ -669,7 +659,6 @@ def test_quality(self): max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) self.assertTrue(max_diff < 1e-3) - @require_peft_backend def test_lora_loading(self): self.pipeline_4bit.load_lora_weights( hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd" diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index 4964f8c9af07..4be420e7dffa 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -60,8 +60,29 @@ def get_some_linear_layer(model): if is_torch_available(): import torch + import torch.nn as nn - from ..utils import LoRALayer, get_memory_consumption_stat + class LoRALayer(nn.Module): + """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only + + Taken from + https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_8bit.py#L62C5-L78C77 + """ + + def __init__(self, module: nn.Module, rank: int): + super().__init__() + self.module = module + self.adapter = nn.Sequential( + nn.Linear(module.in_features, rank, bias=False), + nn.Linear(rank, module.out_features, bias=False), + ) + small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5 + nn.init.normal_(self.adapter[0].weight, std=small_std) + nn.init.zeros_(self.adapter[1].weight) + self.adapter.to(module.weight.device) + + def forward(self, input, *args, **kwargs): + return self.module(input, *args, **kwargs) + self.adapter(input) if is_bitsandbytes_available(): @@ -81,8 +102,6 @@ class Base8bitTests(unittest.TestCase): # This was obtained on audace so the number might slightly change expected_rel_difference = 1.94 - expected_memory_saving_ratio = 0.7 - prompt = "a beautiful sunset amidst the mountains." num_inference_steps = 10 seed = 0 @@ -123,10 +142,8 @@ def setUp(self): ) def tearDown(self): - if hasattr(self, "model_fp16"): - del self.model_fp16 - if hasattr(self, "model_8bit"): - del self.model_8bit + del self.model_fp16 + del self.model_8bit gc.collect() torch.cuda.empty_cache() @@ -165,28 +182,6 @@ def test_memory_footprint(self): linear = get_some_linear_layer(self.model_8bit) self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params) - def test_model_memory_usage(self): - # Delete to not let anything interfere. - del self.model_8bit, self.model_fp16 - - # Re-instantiate. - inputs = self.get_dummy_inputs() - inputs = { - k: v.to(device=torch_device, dtype=torch.float16) for k, v in inputs.items() if not isinstance(v, bool) - } - model_fp16 = SD3Transformer2DModel.from_pretrained( - self.model_name, subfolder="transformer", torch_dtype=torch.float16 - ).to(torch_device) - unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs) - del model_fp16 - - config = BitsAndBytesConfig(load_in_8bit=True) - model_8bit = SD3Transformer2DModel.from_pretrained( - self.model_name, subfolder="transformer", quantization_config=config, torch_dtype=torch.float16 - ) - quantized_model_memory = get_memory_consumption_stat(model_8bit, inputs) - assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_saving_ratio - def test_original_dtype(self): r""" A simple test to check if the model succesfully stores the original dtype @@ -253,7 +248,7 @@ def test_llm_skip(self): self.assertTrue(linear.weight.dtype == torch.int8) self.assertTrue(isinstance(linear, bnb.nn.Linear8bitLt)) - self.assertTrue(isinstance(model_8bit.proj_out, torch.nn.Linear)) + self.assertTrue(isinstance(model_8bit.proj_out, nn.Linear)) self.assertTrue(model_8bit.proj_out.weight.dtype != torch.int8) def test_config_from_pretrained(self): diff --git a/tests/quantization/quanto/__init__.py b/tests/quantization/quanto/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/tests/quantization/quanto/test_quanto.py b/tests/quantization/quanto/test_quanto.py deleted file mode 100644 index 9eb6958d2183..000000000000 --- a/tests/quantization/quanto/test_quanto.py +++ /dev/null @@ -1,328 +0,0 @@ -import gc -import tempfile -import unittest - -from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig -from diffusers.models.attention_processor import Attention -from diffusers.utils import is_optimum_quanto_available, is_torch_available -from diffusers.utils.testing_utils import ( - nightly, - numpy_cosine_similarity_distance, - require_accelerate, - require_big_gpu_with_torch_cuda, - require_torch_cuda_compatibility, - torch_device, -) - - -if is_optimum_quanto_available(): - from optimum.quanto import QLinear - -if is_torch_available(): - import torch - - from ..utils import LoRALayer, get_memory_consumption_stat - - -@nightly -@require_big_gpu_with_torch_cuda -@require_accelerate -class QuantoBaseTesterMixin: - model_id = None - pipeline_model_id = None - model_cls = None - torch_dtype = torch.bfloat16 - # the expected reduction in peak memory used compared to an unquantized model expressed as a percentage - expected_memory_reduction = 0.0 - keep_in_fp32_module = "" - modules_to_not_convert = "" - _test_torch_compile = False - - def setUp(self): - torch.cuda.reset_peak_memory_stats() - torch.cuda.empty_cache() - gc.collect() - - def tearDown(self): - torch.cuda.reset_peak_memory_stats() - torch.cuda.empty_cache() - gc.collect() - - def get_dummy_init_kwargs(self): - return {"weights_dtype": "float8"} - - def get_dummy_model_init_kwargs(self): - return { - "pretrained_model_name_or_path": self.model_id, - "torch_dtype": self.torch_dtype, - "quantization_config": QuantoConfig(**self.get_dummy_init_kwargs()), - } - - def test_quanto_layers(self): - model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) - for name, module in model.named_modules(): - if isinstance(module, torch.nn.Linear): - assert isinstance(module, QLinear) - - def test_quanto_memory_usage(self): - inputs = self.get_dummy_inputs() - inputs = { - k: v.to(device=torch_device, dtype=torch.bfloat16) for k, v in inputs.items() if not isinstance(v, bool) - } - - unquantized_model = self.model_cls.from_pretrained(self.model_id, torch_dtype=self.torch_dtype) - unquantized_model.to(torch_device) - unquantized_model_memory = get_memory_consumption_stat(unquantized_model, inputs) - - quantized_model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) - quantized_model.to(torch_device) - quantized_model_memory = get_memory_consumption_stat(quantized_model, inputs) - - assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_reduction - - def test_keep_modules_in_fp32(self): - r""" - A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32. - Also ensures if inference works. - """ - _keep_in_fp32_modules = self.model_cls._keep_in_fp32_modules - self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module - - model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) - model.to("cuda") - - for name, module in model.named_modules(): - if isinstance(module, torch.nn.Linear): - if name in model._keep_in_fp32_modules: - assert module.weight.dtype == torch.float32 - self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules - - def test_modules_to_not_convert(self): - init_kwargs = self.get_dummy_model_init_kwargs() - - quantization_config_kwargs = self.get_dummy_init_kwargs() - quantization_config_kwargs.update({"modules_to_not_convert": self.modules_to_not_convert}) - quantization_config = QuantoConfig(**quantization_config_kwargs) - - init_kwargs.update({"quantization_config": quantization_config}) - - model = self.model_cls.from_pretrained(**init_kwargs) - model.to("cuda") - - for name, module in model.named_modules(): - if name in self.modules_to_not_convert: - assert not isinstance(module, QLinear) - - def test_dtype_assignment(self): - model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) - - with self.assertRaises(ValueError): - # Tries with a `dtype` - model.to(torch.float16) - - with self.assertRaises(ValueError): - # Tries with a `device` and `dtype` - model.to(device="cuda:0", dtype=torch.float16) - - with self.assertRaises(ValueError): - # Tries with a cast - model.float() - - with self.assertRaises(ValueError): - # Tries with a cast - model.half() - - # This should work - model.to("cuda") - - def test_serialization(self): - model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) - inputs = self.get_dummy_inputs() - - model.to(torch_device) - with torch.no_grad(): - model_output = model(**inputs) - - with tempfile.TemporaryDirectory() as tmp_dir: - model.save_pretrained(tmp_dir) - saved_model = self.model_cls.from_pretrained( - tmp_dir, - torch_dtype=torch.bfloat16, - ) - - saved_model.to(torch_device) - with torch.no_grad(): - saved_model_output = saved_model(**inputs) - - assert torch.allclose(model_output.sample, saved_model_output.sample, rtol=1e-5, atol=1e-5) - - def test_torch_compile(self): - if not self._test_torch_compile: - return - - model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) - compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True, dynamic=False) - - model.to(torch_device) - with torch.no_grad(): - model_output = model(**self.get_dummy_inputs()).sample - - compiled_model.to(torch_device) - with torch.no_grad(): - compiled_model_output = compiled_model(**self.get_dummy_inputs()).sample - - model_output = model_output.detach().float().cpu().numpy() - compiled_model_output = compiled_model_output.detach().float().cpu().numpy() - - max_diff = numpy_cosine_similarity_distance(model_output.flatten(), compiled_model_output.flatten()) - assert max_diff < 1e-3 - - def test_device_map_error(self): - with self.assertRaises(ValueError): - _ = self.model_cls.from_pretrained( - **self.get_dummy_model_init_kwargs(), device_map={0: "8GB", "cpu": "16GB"} - ) - - -class FluxTransformerQuantoMixin(QuantoBaseTesterMixin): - model_id = "hf-internal-testing/tiny-flux-transformer" - model_cls = FluxTransformer2DModel - pipeline_cls = FluxPipeline - torch_dtype = torch.bfloat16 - keep_in_fp32_module = "proj_out" - modules_to_not_convert = ["proj_out"] - _test_torch_compile = False - - def get_dummy_inputs(self): - return { - "hidden_states": torch.randn((1, 4096, 64), generator=torch.Generator("cpu").manual_seed(0)).to( - torch_device, self.torch_dtype - ), - "encoder_hidden_states": torch.randn( - (1, 512, 4096), - generator=torch.Generator("cpu").manual_seed(0), - ).to(torch_device, self.torch_dtype), - "pooled_projections": torch.randn( - (1, 768), - generator=torch.Generator("cpu").manual_seed(0), - ).to(torch_device, self.torch_dtype), - "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype), - "img_ids": torch.randn((4096, 3), generator=torch.Generator("cpu").manual_seed(0)).to( - torch_device, self.torch_dtype - ), - "txt_ids": torch.randn((512, 3), generator=torch.Generator("cpu").manual_seed(0)).to( - torch_device, self.torch_dtype - ), - "guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype), - } - - def get_dummy_training_inputs(self, device=None, seed: int = 0): - batch_size = 1 - num_latent_channels = 4 - num_image_channels = 3 - height = width = 4 - sequence_length = 48 - embedding_dim = 32 - - torch.manual_seed(seed) - hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16) - - torch.manual_seed(seed) - encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to( - device, dtype=torch.bfloat16 - ) - - torch.manual_seed(seed) - pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16) - - torch.manual_seed(seed) - text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16) - - torch.manual_seed(seed) - image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16) - - timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size) - - return { - "hidden_states": hidden_states, - "encoder_hidden_states": encoder_hidden_states, - "pooled_projections": pooled_prompt_embeds, - "txt_ids": text_ids, - "img_ids": image_ids, - "timestep": timestep, - } - - def test_model_cpu_offload(self): - init_kwargs = self.get_dummy_init_kwargs() - transformer = self.model_cls.from_pretrained( - "hf-internal-testing/tiny-flux-pipe", - quantization_config=QuantoConfig(**init_kwargs), - subfolder="transformer", - torch_dtype=torch.bfloat16, - ) - pipe = self.pipeline_cls.from_pretrained( - "hf-internal-testing/tiny-flux-pipe", transformer=transformer, torch_dtype=torch.bfloat16 - ) - pipe.enable_model_cpu_offload(device=torch_device) - _ = pipe("a cat holding a sign that says hello", num_inference_steps=2) - - def test_training(self): - quantization_config = QuantoConfig(**self.get_dummy_init_kwargs()) - quantized_model = self.model_cls.from_pretrained( - "hf-internal-testing/tiny-flux-pipe", - subfolder="transformer", - quantization_config=quantization_config, - torch_dtype=torch.bfloat16, - ).to(torch_device) - - for param in quantized_model.parameters(): - # freeze the model as only adapter layers will be trained - param.requires_grad = False - if param.ndim == 1: - param.data = param.data.to(torch.float32) - - for _, module in quantized_model.named_modules(): - if isinstance(module, Attention): - module.to_q = LoRALayer(module.to_q, rank=4) - module.to_k = LoRALayer(module.to_k, rank=4) - module.to_v = LoRALayer(module.to_v, rank=4) - - with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16): - inputs = self.get_dummy_training_inputs(torch_device) - output = quantized_model(**inputs)[0] - output.norm().backward() - - for module in quantized_model.modules(): - if isinstance(module, LoRALayer): - self.assertTrue(module.adapter[1].weight.grad is not None) - - -class FluxTransformerFloat8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase): - expected_memory_reduction = 0.6 - - def get_dummy_init_kwargs(self): - return {"weights_dtype": "float8"} - - -class FluxTransformerInt8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase): - expected_memory_reduction = 0.6 - _test_torch_compile = True - - def get_dummy_init_kwargs(self): - return {"weights_dtype": "int8"} - - -@require_torch_cuda_compatibility(8.0) -class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase): - expected_memory_reduction = 0.55 - - def get_dummy_init_kwargs(self): - return {"weights_dtype": "int4"} - - -@require_torch_cuda_compatibility(8.0) -class FluxTransformerInt2WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase): - expected_memory_reduction = 0.65 - - def get_dummy_init_kwargs(self): - return {"weights_dtype": "int2"} diff --git a/tests/quantization/torchao/__init__.py b/tests/quantization/torchao/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 0e671307dd18..e14a1cc0369e 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -50,7 +50,27 @@ import torch import torch.nn as nn - from ..utils import LoRALayer, get_memory_consumption_stat + class LoRALayer(nn.Module): + """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only + + Taken from + https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77 + """ + + def __init__(self, module: nn.Module, rank: int): + super().__init__() + self.module = module + self.adapter = nn.Sequential( + nn.Linear(module.in_features, rank, bias=False), + nn.Linear(rank, module.out_features, bias=False), + ) + small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5 + nn.init.normal_(self.adapter[0].weight, std=small_std) + nn.init.zeros_(self.adapter[1].weight) + self.adapter.to(module.weight.device) + + def forward(self, input, *args, **kwargs): + return self.module(input, *args, **kwargs) + self.adapter(input) if is_torchao_available(): @@ -483,22 +503,6 @@ def test_memory_footprint(self): # there is additional overhead of scales and zero points self.assertTrue(total_bf16 < total_int4wo) - def test_model_memory_usage(self): - model_id = "hf-internal-testing/tiny-flux-pipe" - expected_memory_saving_ratio = 2.0 - - inputs = self.get_dummy_tensor_inputs(device=torch_device) - - transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"] - transformer_bf16.to(torch_device) - unquantized_model_memory = get_memory_consumption_stat(transformer_bf16, inputs) - del transformer_bf16 - - transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"] - transformer_int8wo.to(torch_device) - quantized_model_memory = get_memory_consumption_stat(transformer_int8wo, inputs) - assert unquantized_model_memory / quantized_model_memory >= expected_memory_saving_ratio - def test_wrong_config(self): with self.assertRaises(ValueError): self.get_dummy_components(TorchAoConfig("int42")) diff --git a/tests/quantization/utils.py b/tests/quantization/utils.py deleted file mode 100644 index 04ebf9e159f4..000000000000 --- a/tests/quantization/utils.py +++ /dev/null @@ -1,38 +0,0 @@ -from diffusers.utils import is_torch_available - - -if is_torch_available(): - import torch - import torch.nn as nn - - class LoRALayer(nn.Module): - """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only - - Taken from - https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77 - """ - - def __init__(self, module: nn.Module, rank: int): - super().__init__() - self.module = module - self.adapter = nn.Sequential( - nn.Linear(module.in_features, rank, bias=False), - nn.Linear(rank, module.out_features, bias=False), - ) - small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5 - nn.init.normal_(self.adapter[0].weight, std=small_std) - nn.init.zeros_(self.adapter[1].weight) - self.adapter.to(module.weight.device) - - def forward(self, input, *args, **kwargs): - return self.module(input, *args, **kwargs) + self.adapter(input) - - @torch.no_grad() - @torch.inference_mode() - def get_memory_consumption_stat(model, inputs): - torch.cuda.reset_peak_memory_stats() - torch.cuda.empty_cache() - - model(**inputs) - max_memory_mem_allocated = torch.cuda.max_memory_allocated() - return max_memory_mem_allocated diff --git a/tests/remote/test_remote_decode.py b/tests/remote/test_remote_decode.py index cec96e729a48..11f9c24d16f6 100644 --- a/tests/remote/test_remote_decode.py +++ b/tests/remote/test_remote_decode.py @@ -21,15 +21,7 @@ import torch from diffusers.image_processor import VaeImageProcessor -from diffusers.utils.constants import ( - DECODE_ENDPOINT_FLUX, - DECODE_ENDPOINT_HUNYUAN_VIDEO, - DECODE_ENDPOINT_SD_V1, - DECODE_ENDPOINT_SD_XL, -) -from diffusers.utils.remote_utils import ( - remote_decode, -) +from diffusers.utils.remote_utils import remote_decode from diffusers.utils.testing_utils import ( enable_full_determinism, slow, @@ -41,6 +33,11 @@ enable_full_determinism() +ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/" +ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/" +ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/" +ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/" + class RemoteAutoencoderKLMixin: shape: Tuple[int, ...] = None @@ -353,7 +350,7 @@ class RemoteAutoencoderKLSDv1Tests( 512, 512, ) - endpoint = DECODE_ENDPOINT_SD_V1 + endpoint = ENDPOINT_SD_V1 dtype = torch.float16 scaling_factor = 0.18215 shift_factor = None @@ -377,7 +374,7 @@ class RemoteAutoencoderKLSDXLTests( 1024, 1024, ) - endpoint = DECODE_ENDPOINT_SD_XL + endpoint = ENDPOINT_SD_XL dtype = torch.float16 scaling_factor = 0.13025 shift_factor = None @@ -401,7 +398,7 @@ class RemoteAutoencoderKLFluxTests( 1024, 1024, ) - endpoint = DECODE_ENDPOINT_FLUX + endpoint = ENDPOINT_FLUX dtype = torch.bfloat16 scaling_factor = 0.3611 shift_factor = 0.1159 @@ -428,7 +425,7 @@ class RemoteAutoencoderKLFluxPackedTests( ) height = 1024 width = 1024 - endpoint = DECODE_ENDPOINT_FLUX + endpoint = ENDPOINT_FLUX dtype = torch.bfloat16 scaling_factor = 0.3611 shift_factor = 0.1159 @@ -456,7 +453,7 @@ class RemoteAutoencoderKLHunyuanVideoTests( 320, 512, ) - endpoint = DECODE_ENDPOINT_HUNYUAN_VIDEO + endpoint = ENDPOINT_HUNYUAN_VIDEO dtype = torch.float16 scaling_factor = 0.476986 processor_cls = VideoProcessor @@ -507,7 +504,7 @@ class RemoteAutoencoderKLSDv1SlowTests( RemoteAutoencoderKLSlowTestMixin, unittest.TestCase, ): - endpoint = DECODE_ENDPOINT_SD_V1 + endpoint = ENDPOINT_SD_V1 dtype = torch.float16 scaling_factor = 0.18215 shift_factor = None @@ -518,7 +515,7 @@ class RemoteAutoencoderKLSDXLSlowTests( RemoteAutoencoderKLSlowTestMixin, unittest.TestCase, ): - endpoint = DECODE_ENDPOINT_SD_XL + endpoint = ENDPOINT_SD_XL dtype = torch.float16 scaling_factor = 0.13025 shift_factor = None @@ -530,7 +527,7 @@ class RemoteAutoencoderKLFluxSlowTests( unittest.TestCase, ): channels = 16 - endpoint = DECODE_ENDPOINT_FLUX + endpoint = ENDPOINT_FLUX dtype = torch.bfloat16 scaling_factor = 0.3611 shift_factor = 0.1159 diff --git a/tests/remote/test_remote_encode.py b/tests/remote/test_remote_encode.py deleted file mode 100644 index 62ed97ee8f49..000000000000 --- a/tests/remote/test_remote_encode.py +++ /dev/null @@ -1,224 +0,0 @@ -# coding=utf-8 -# Copyright 2025 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import PIL.Image -import torch - -from diffusers.utils import load_image -from diffusers.utils.constants import ( - DECODE_ENDPOINT_FLUX, - DECODE_ENDPOINT_SD_V1, - DECODE_ENDPOINT_SD_XL, - ENCODE_ENDPOINT_FLUX, - ENCODE_ENDPOINT_SD_V1, - ENCODE_ENDPOINT_SD_XL, -) -from diffusers.utils.remote_utils import ( - remote_decode, - remote_encode, -) -from diffusers.utils.testing_utils import ( - enable_full_determinism, - slow, -) - - -enable_full_determinism() - -IMAGE = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true" - - -class RemoteAutoencoderKLEncodeMixin: - channels: int = None - endpoint: str = None - decode_endpoint: str = None - dtype: torch.dtype = None - scaling_factor: float = None - shift_factor: float = None - image: PIL.Image.Image = None - - def get_dummy_inputs(self): - if self.image is None: - self.image = load_image(IMAGE) - inputs = { - "endpoint": self.endpoint, - "image": self.image, - "scaling_factor": self.scaling_factor, - "shift_factor": self.shift_factor, - } - return inputs - - def test_image_input(self): - inputs = self.get_dummy_inputs() - height, width = inputs["image"].height, inputs["image"].width - output = remote_encode(**inputs) - self.assertEqual(list(output.shape), [1, self.channels, height // 8, width // 8]) - decoded = remote_decode( - tensor=output, - endpoint=self.decode_endpoint, - scaling_factor=self.scaling_factor, - shift_factor=self.shift_factor, - image_format="png", - ) - self.assertEqual(decoded.height, height) - self.assertEqual(decoded.width, width) - # image_slice = torch.from_numpy(np.array(inputs["image"])[0, -3:, -3:].flatten()) - # decoded_slice = torch.from_numpy(np.array(decoded)[0, -3:, -3:].flatten()) - # TODO: how to test this? encode->decode is lossy. expected slice of encoded latent? - - -class RemoteAutoencoderKLSDv1Tests( - RemoteAutoencoderKLEncodeMixin, - unittest.TestCase, -): - channels = 4 - endpoint = ENCODE_ENDPOINT_SD_V1 - decode_endpoint = DECODE_ENDPOINT_SD_V1 - dtype = torch.float16 - scaling_factor = 0.18215 - shift_factor = None - - -class RemoteAutoencoderKLSDXLTests( - RemoteAutoencoderKLEncodeMixin, - unittest.TestCase, -): - channels = 4 - endpoint = ENCODE_ENDPOINT_SD_XL - decode_endpoint = DECODE_ENDPOINT_SD_XL - dtype = torch.float16 - scaling_factor = 0.13025 - shift_factor = None - - -class RemoteAutoencoderKLFluxTests( - RemoteAutoencoderKLEncodeMixin, - unittest.TestCase, -): - channels = 16 - endpoint = ENCODE_ENDPOINT_FLUX - decode_endpoint = DECODE_ENDPOINT_FLUX - dtype = torch.bfloat16 - scaling_factor = 0.3611 - shift_factor = 0.1159 - - -class RemoteAutoencoderKLEncodeSlowTestMixin: - channels: int = 4 - endpoint: str = None - decode_endpoint: str = None - dtype: torch.dtype = None - scaling_factor: float = None - shift_factor: float = None - image: PIL.Image.Image = None - - def get_dummy_inputs(self): - if self.image is None: - self.image = load_image(IMAGE) - inputs = { - "endpoint": self.endpoint, - "image": self.image, - "scaling_factor": self.scaling_factor, - "shift_factor": self.shift_factor, - } - return inputs - - def test_multi_res(self): - inputs = self.get_dummy_inputs() - for height in { - 320, - 512, - 640, - 704, - 896, - 1024, - 1208, - 1384, - 1536, - 1608, - 1864, - 2048, - }: - for width in { - 320, - 512, - 640, - 704, - 896, - 1024, - 1208, - 1384, - 1536, - 1608, - 1864, - 2048, - }: - inputs["image"] = inputs["image"].resize( - ( - width, - height, - ) - ) - output = remote_encode(**inputs) - self.assertEqual(list(output.shape), [1, self.channels, height // 8, width // 8]) - decoded = remote_decode( - tensor=output, - endpoint=self.decode_endpoint, - scaling_factor=self.scaling_factor, - shift_factor=self.shift_factor, - image_format="png", - ) - self.assertEqual(decoded.height, height) - self.assertEqual(decoded.width, width) - decoded.save(f"test_multi_res_{height}_{width}.png") - - -@slow -class RemoteAutoencoderKLSDv1SlowTests( - RemoteAutoencoderKLEncodeSlowTestMixin, - unittest.TestCase, -): - endpoint = ENCODE_ENDPOINT_SD_V1 - decode_endpoint = DECODE_ENDPOINT_SD_V1 - dtype = torch.float16 - scaling_factor = 0.18215 - shift_factor = None - - -@slow -class RemoteAutoencoderKLSDXLSlowTests( - RemoteAutoencoderKLEncodeSlowTestMixin, - unittest.TestCase, -): - endpoint = ENCODE_ENDPOINT_SD_XL - decode_endpoint = DECODE_ENDPOINT_SD_XL - dtype = torch.float16 - scaling_factor = 0.13025 - shift_factor = None - - -@slow -class RemoteAutoencoderKLFluxSlowTests( - RemoteAutoencoderKLEncodeSlowTestMixin, - unittest.TestCase, -): - channels = 16 - endpoint = ENCODE_ENDPOINT_FLUX - decode_endpoint = DECODE_ENDPOINT_FLUX - dtype = torch.bfloat16 - scaling_factor = 0.3611 - shift_factor = 0.1159 diff --git a/tests/single_file/test_sana_transformer.py b/tests/single_file/test_sana_transformer.py deleted file mode 100644 index 7695e1577711..000000000000 --- a/tests/single_file/test_sana_transformer.py +++ /dev/null @@ -1,61 +0,0 @@ -import gc -import unittest - -import torch - -from diffusers import ( - SanaTransformer2DModel, -) -from diffusers.utils.testing_utils import ( - backend_empty_cache, - enable_full_determinism, - require_torch_accelerator, - torch_device, -) - - -enable_full_determinism() - - -@require_torch_accelerator -class SanaTransformer2DModelSingleFileTests(unittest.TestCase): - model_class = SanaTransformer2DModel - ckpt_path = ( - "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth" - ) - alternate_keys_ckpt_paths = [ - "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth" - ] - - repo_id = "Efficient-Large-Model/Sana_1600M_1024px_diffusers" - - def setUp(self): - super().setUp() - gc.collect() - backend_empty_cache(torch_device) - - def tearDown(self): - super().tearDown() - gc.collect() - backend_empty_cache(torch_device) - - def test_single_file_components(self): - model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer") - model_single_file = self.model_class.from_single_file(self.ckpt_path) - - PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"] - for param_name, param_value in model_single_file.config.items(): - if param_name in PARAMS_TO_IGNORE: - continue - assert ( - model.config[param_name] == param_value - ), f"{param_name} differs between single file loading and pretrained loading" - - def test_checkpoint_loading(self): - for ckpt_path in self.alternate_keys_ckpt_paths: - torch.cuda.empty_cache() - model = self.model_class.from_single_file(ckpt_path) - - del model - gc.collect() - torch.cuda.empty_cache() From 4d3dedee070840329fd3baf22cf06f48b484303d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 21 Mar 2025 07:47:04 +0530 Subject: [PATCH 03/19] feat: implement pipeline-level quantization config Co-authored-by: SunMarc --- .../pipelines/pipeline_loading_utils.py | 13 ++ src/diffusers/pipelines/pipeline_utils.py | 2 + src/diffusers/quantizers/__init__.py | 162 ++++++++++++++++++ 3 files changed, 177 insertions(+) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 07da8b5e2e2e..e3d7202f0852 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -667,8 +667,10 @@ def load_sub_model( use_safetensors: bool, dduf_entries: Optional[Dict[str, DDUFEntry]], provider_options: Any, + quantization_config: Optional[Any] = None, ): """Helper method to load the module `name` from `library_name` and `class_name`""" + from ..quantizers import PipelineQuantizationConfig # retrieve class candidates @@ -761,6 +763,17 @@ def load_sub_model( else: loading_kwargs["low_cpu_mem_usage"] = False + if ( + quantization_config is not None + and isinstance(quantization_config, PipelineQuantizationConfig) + and issubclass(class_obj, torch.nn.Module) + ): + model_quant_config = quantization_config._resolve_quant_config( + is_diffusers=is_diffusers_model, module_name=name + ) + if model_quant_config is not None: + loading_kwargs["quantization_config"] = model_quant_config + # check if the module is in a subdirectory if dduf_entries: loading_kwargs["dduf_entries"] = dduf_entries diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 0896a14d64af..d74574d30c26 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -702,6 +702,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P use_safetensors = kwargs.pop("use_safetensors", None) use_onnx = kwargs.pop("use_onnx", None) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) + quantization_config = kwargs.pop("quantization_config", None) if not isinstance(torch_dtype, torch.dtype): torch_dtype = torch.float32 @@ -973,6 +974,7 @@ def load_module(name, value): use_safetensors=use_safetensors, dduf_entries=dduf_entries, provider_options=provider_options, + quantization_config=quantization_config, ) logger.info( f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}." diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py index 4c8483a3d6ee..e6554f2e1f1d 100644 --- a/src/diffusers/quantizers/__init__.py +++ b/src/diffusers/quantizers/__init__.py @@ -12,5 +12,167 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from ..utils import is_transformers_available, logging from .auto import DiffusersAutoQuantizer from .base import DiffusersQuantizer + + +if TYPE_CHECKING: + from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin + + try: + from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin + except ImportError: + + class TransformersQuantConfigMixin: + pass + + +logger = logging.get_logger(__name__) + + +class PipelineQuantizationConfig: + """TODO""" + + def __init__( + self, + quant_backend: str = None, + quant_kwargs: Dict[str, Union[str, float, int]] = None, + modules_to_quantize: Optional[List[str]] = None, + quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None, + ): + self.quant_backend = quant_backend + # Initialize kwargs to be {} to set to the defaults. + self.quant_kwargs = quant_kwargs or {} + self.modules_to_quantize = modules_to_quantize + self.quant_mapping = quant_mapping + + self.post_init() + + def post_init(self): + quant_mapping = self.quant_mapping + self.is_granular = True if quant_mapping is not None else False + + self._validate_init_args() + + def _validate_init_args(self): + if self.quant_backend and self.quant_mapping: + raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.") + + if not self.quant_mapping and not self.quant_backend: + raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.") + + if not self.quant_kwargs and not self.quant_mapping: + raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.") + + if self.quant_backend is not None: + self._validate_init_kwargs_in_backends() + + if self.quant_mapping is not None: + self._validate_quant_mapping_args() + + def _validate_init_kwargs_in_backends(self): + quant_backend = self.quant_backend + + self._check_backend_availability(quant_backend) + + quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() + + if quant_config_mapping_transformers is not None: + init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__) + init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"} + else: + init_kwargs_transformers = None + + init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__) + init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"} + + if init_kwargs_transformers != init_kwargs_diffusers: + raise ValueError( + "The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. " + f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class." + ) + + def _validate_quant_mapping_args(self): + quant_mapping = self.quant_mapping + quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() + + available_configs_transformers = ( + list(quant_config_mapping_transformers.values()) if quant_config_mapping_transformers else None + ) + available_configs_diffusers = list(quant_config_mapping_diffusers.values()) + + for module_name, config in quant_mapping.items(): + msg = "" + if not (any(isinstance(config, available) for available in available_configs_diffusers)): + msg = f"Provided config for {module_name=} could not be found. Available ones for `diffusers` are: {available_configs_diffusers}.)" + elif available_configs_transformers is not None and not ( + any(isinstance(config, available) for available in available_configs_transformers) + ): + msg = f"Provided config for {module_name=} could not be found. Available ones for `transformers` are: {available_configs_transformers}.)" + if msg: + raise ValueError(msg) + + def _check_backend_availability(self, quant_backend: str): + quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() + + available_backends_transformers = ( + list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None + ) + available_backends_diffusers = list(quant_config_mapping_diffusers.keys()) + + if ( + available_backends_transformers and quant_backend not in available_backends_transformers + ) or quant_backend not in quant_config_mapping_diffusers: + error_message = f"Provided quant_backend={quant_backend} was not found." + if available_backends_transformers: + error_message += f"\nAvailable ones (transformers): {available_backends_transformers}." + error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}." + raise ValueError(error_message) + + def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None): + quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() + + quant_mapping = self.quant_mapping + modules_to_quantize = self.modules_to_quantize + + # Granular case + if self.is_granular and module_name in quant_mapping: + logger.debug(f"Initializing quantization config class for {module_name}.") + config = quant_mapping[module_name] + return config + + # Global config case + else: + should_quantize = False + # Only quantize the modules requested for. + if modules_to_quantize and module_name in modules_to_quantize: + should_quantize = True + # No specification for `modules_to_quantize` means all modules should be quantized. + elif not self.is_granular and not modules_to_quantize: + should_quantize = True + + if should_quantize: + logger.debug(f"Initializing quantization config class for {module_name}.") + mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers + quant_config_cls = mapping_to_use[self.quant_backend] + quant_kwargs = self.quant_kwargs + return quant_config_cls(**quant_kwargs) + + # Fallback: no applicable configuration found. + return None + + def _get_quant_config_list(self): + if is_transformers_available(): + from transformers.quantizers.auto import ( + AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers, + ) + else: + quant_config_mapping_transformers = None + + from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers + + return quant_config_mapping_transformers, quant_config_mapping_diffusers From dc79f32bede0dfe4c581a1fcf4b603a54847ba06 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 21 Mar 2025 07:59:25 +0530 Subject: [PATCH 04/19] update --- src/diffusers/quantizers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py index e6554f2e1f1d..280d88128816 100644 --- a/src/diffusers/quantizers/__init__.py +++ b/src/diffusers/quantizers/__init__.py @@ -40,7 +40,7 @@ class PipelineQuantizationConfig: def __init__( self, quant_backend: str = None, - quant_kwargs: Dict[str, Union[str, float, int]] = None, + quant_kwargs: Dict[str, Union[str, float, int, dict]] = None, modules_to_quantize: Optional[List[str]] = None, quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None, ): From df749e4758a09bdb7968635e15b1309248ef3aaf Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 21 Mar 2025 08:38:23 +0530 Subject: [PATCH 05/19] fixes --- src/diffusers/quantizers/__init__.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py index 280d88128816..cb8f35f6af3a 100644 --- a/src/diffusers/quantizers/__init__.py +++ b/src/diffusers/quantizers/__init__.py @@ -13,22 +13,20 @@ # limitations under the License. import inspect -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union from ..utils import is_transformers_available, logging from .auto import DiffusersAutoQuantizer from .base import DiffusersQuantizer +from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin -if TYPE_CHECKING: - from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin +try: + from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin +except ImportError: - try: - from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin - except ImportError: - - class TransformersQuantConfigMixin: - pass + class TransformersQuantConfigMixin: + pass logger = logging.get_logger(__name__) From 82bcce077edab7cd374c0930e645ce1520e01fb4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 29 Apr 2025 09:45:21 +0800 Subject: [PATCH 06/19] fix validation. --- src/diffusers/quantizers/__init__.py | 34 ++++++++++++++++------------ 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py index cb8f35f6af3a..4195a4653295 100644 --- a/src/diffusers/quantizers/__init__.py +++ b/src/diffusers/quantizers/__init__.py @@ -96,23 +96,29 @@ def _validate_init_kwargs_in_backends(self): def _validate_quant_mapping_args(self): quant_mapping = self.quant_mapping - quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() + transformers_map, diffusers_map = self._get_quant_config_list() - available_configs_transformers = ( - list(quant_config_mapping_transformers.values()) if quant_config_mapping_transformers else None - ) - available_configs_diffusers = list(quant_config_mapping_diffusers.values()) + available_transformers = list(transformers_map.values()) if transformers_map else None + available_diffusers = list(diffusers_map.values()) for module_name, config in quant_mapping.items(): - msg = "" - if not (any(isinstance(config, available) for available in available_configs_diffusers)): - msg = f"Provided config for {module_name=} could not be found. Available ones for `diffusers` are: {available_configs_diffusers}.)" - elif available_configs_transformers is not None and not ( - any(isinstance(config, available) for available in available_configs_transformers) - ): - msg = f"Provided config for {module_name=} could not be found. Available ones for `transformers` are: {available_configs_transformers}.)" - if msg: - raise ValueError(msg) + if any(isinstance(config, cfg) for cfg in available_diffusers): + continue + + if available_transformers and any(isinstance(config, cfg) for cfg in available_transformers): + continue + + if available_transformers: + raise ValueError( + f"Provided config for module_name={module_name} could not be found. " + f"Available diffusers configs: {available_diffusers}; " + f"Available transformers configs: {available_transformers}." + ) + else: + raise ValueError( + f"Provided config for module_name={module_name} could not be found. " + f"Available diffusers configs: {available_diffusers}." + ) def _check_backend_availability(self, quant_backend: str): quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() From 78f134ba2f7dedb925fc76781e87efaad28d03d5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 29 Apr 2025 10:53:11 +0800 Subject: [PATCH 07/19] add tests and other improvements. --- src/diffusers/quantizers/__init__.py | 17 ++++++++++------- src/diffusers/utils/testing_utils.py | 8 ++++++++ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py index 4195a4653295..4a5b97a53a01 100644 --- a/src/diffusers/quantizers/__init__.py +++ b/src/diffusers/quantizers/__init__.py @@ -39,13 +39,13 @@ def __init__( self, quant_backend: str = None, quant_kwargs: Dict[str, Union[str, float, int, dict]] = None, - modules_to_quantize: Optional[List[str]] = None, + components_to_quantize: Optional[List[str]] = None, quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None, ): self.quant_backend = quant_backend # Initialize kwargs to be {} to set to the defaults. self.quant_kwargs = quant_kwargs or {} - self.modules_to_quantize = modules_to_quantize + self.components_to_quantize = components_to_quantize self.quant_mapping = quant_mapping self.post_init() @@ -91,7 +91,8 @@ def _validate_init_kwargs_in_backends(self): if init_kwargs_transformers != init_kwargs_diffusers: raise ValueError( "The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. " - f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class." + f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to the docs to learn more about how " + "this mapping would look like: TODO." ) def _validate_quant_mapping_args(self): @@ -100,6 +101,8 @@ def _validate_quant_mapping_args(self): available_transformers = list(transformers_map.values()) if transformers_map else None available_diffusers = list(diffusers_map.values()) + print(f"{quant_mapping=}") + print(f"{available_diffusers=}") for module_name, config in quant_mapping.items(): if any(isinstance(config, cfg) for cfg in available_diffusers): @@ -141,7 +144,7 @@ def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = No quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() quant_mapping = self.quant_mapping - modules_to_quantize = self.modules_to_quantize + components_to_quantize = self.components_to_quantize # Granular case if self.is_granular and module_name in quant_mapping: @@ -153,10 +156,10 @@ def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = No else: should_quantize = False # Only quantize the modules requested for. - if modules_to_quantize and module_name in modules_to_quantize: + if components_to_quantize and module_name in components_to_quantize: should_quantize = True - # No specification for `modules_to_quantize` means all modules should be quantized. - elif not self.is_granular and not modules_to_quantize: + # No specification for `components_to_quantize` means all modules should be quantized. + elif not self.is_granular and not components_to_quantize: should_quantize = True if should_quantize: diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 7a524e76f16e..00aad9d71a61 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -38,6 +38,7 @@ is_note_seq_available, is_onnx_available, is_opencv_available, + is_optimum_quanto_available, is_peft_available, is_timm_available, is_torch_available, @@ -486,6 +487,13 @@ def require_bitsandbytes(test_case): return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case) +def require_quanto(test_case): + """ + Decorator marking a test that requires quanto. These tests are skipped when quanto isn't installed. + """ + return unittest.skipUnless(is_optimum_quanto_available(), "test requires quanto")(test_case) + + def require_accelerate(test_case): """ Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed. From 3b76e0a6965397cc9db2356cc5d948e02b1a4688 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 29 Apr 2025 11:01:43 +0800 Subject: [PATCH 08/19] add tests --- .../test_pipeline_level_quantization.py | 178 ++++++++++++++++++ 1 file changed, 178 insertions(+) create mode 100644 tests/quantization/test_pipeline_level_quantization.py diff --git a/tests/quantization/test_pipeline_level_quantization.py b/tests/quantization/test_pipeline_level_quantization.py new file mode 100644 index 000000000000..337a093d4611 --- /dev/null +++ b/tests/quantization/test_pipeline_level_quantization.py @@ -0,0 +1,178 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a clone of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tempfile +import unittest + +import torch + +from diffusers import ( + DiffusionPipeline, + QuantoConfig, +) +from diffusers.quantizers import PipelineQuantizationConfig +from diffusers.utils.testing_utils import ( + is_transformers_available, + require_accelerate, + require_bitsandbytes_version_greater, + require_quanto, + require_torch, + require_torch_accelerator, + slow, + torch_device, +) + + +if is_transformers_available(): + from transformers import BitsAndBytesConfig as TranBitsAndBytesConfig +else: + TranBitsAndBytesConfig = None + + +@require_bitsandbytes_version_greater("0.43.2") +@require_quanto +@require_accelerate +@require_torch +@require_torch_accelerator +@slow +class PipelineQuantizationTests(unittest.TestCase): + model_name = "hf-internal-testing/tiny-flux-pipe" + prompt = "a beautiful sunset amidst the mountains." + num_inference_steps = 10 + seed = 0 + + def test_quant_config_set_correctly_through_kwargs(self): + components_to_quantize = ["transformer", "text_encoder_2"] + quant_config = PipelineQuantizationConfig( + quant_backend="bitsandbytes_4bit", + quant_kwargs={ + "load_in_4bit": True, + "bnb_4bit_quant_type": "nf4", + "bnb_4bit_compute_dtype": torch.bfloat16, + }, + components_to_quantize=components_to_quantize, + ) + pipe = DiffusionPipeline.from_pretrained( + self.model_name, + quantization_config=quant_config, + torch_dtype=torch.bfloat16, + ).to(torch_device) + for name, component in pipe.components.items(): + if name in components_to_quantize: + self.assertTrue(getattr(component.config, "quantization_config", None) is not None) + quantization_config = component.config.quantization_config + self.assertTrue(quantization_config.load_in_4bit) + self.assertTrue(quantization_config.quant_method == "bitsandbytes") + + _ = pipe(self.prompt, num_inference_steps=self.num_inference_steps) + + def test_quant_config_set_correctly_granular(self): + quant_config = PipelineQuantizationConfig( + quant_mapping={ + "transformer": QuantoConfig(weights_dtype="int8"), + "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16), + } + ) + components_to_quantize = list(quant_config.quant_mapping.keys()) + pipe = DiffusionPipeline.from_pretrained( + self.model_name, + quantization_config=quant_config, + torch_dtype=torch.bfloat16, + ).to(torch_device) + for name, component in pipe.components.items(): + if name in components_to_quantize: + self.assertTrue(getattr(component.config, "quantization_config", None) is not None) + quantization_config = component.config.quantization_config + + if name == "text_encoder_2": + self.assertTrue(quantization_config.load_in_4bit) + self.assertTrue(quantization_config.quant_method == "bitsandbytes") + else: + self.assertTrue(quantization_config.quant_method == "quanto") + + _ = pipe(self.prompt, num_inference_steps=self.num_inference_steps) + + def test_raises_error_for_invalid_config(self): + with self.assertRaises(ValueError) as err_context: + _ = PipelineQuantizationConfig( + quant_mapping={ + "transformer": QuantoConfig(weights_dtype="int8"), + "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16), + }, + quant_backend="bitsandbytes_4bit", + ) + + self.assertTrue( + str(err_context.exception) + == "Both `quant_backend` and `quant_mapping` cannot be specified at the same time." + ) + + def test_validation_for_kwargs(self): + components_to_quantize = ["transformer", "text_encoder_2"] + with self.assertRaises(ValueError) as err_context: + _ = PipelineQuantizationConfig( + quant_backend="quanto", + quant_kwargs={"weights_dtype": "int8"}, + components_to_quantize=components_to_quantize, + ) + + self.assertTrue( + "The signatures of the __init__ methods of the quantization config classes" in str(err_context.exception) + ) + + def test_validation_for_mapping(self): + with self.assertRaises(ValueError) as err_context: + _ = PipelineQuantizationConfig( + quant_mapping={ + "transformer": DiffusionPipeline(), + "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16), + } + ) + + self.assertTrue("Provided config for module_name=transformer could not be found" in str(err_context.exception)) + + def test_saving_loading(self): + quant_config = PipelineQuantizationConfig( + quant_mapping={ + "transformer": QuantoConfig(weights_dtype="int8"), + "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16), + } + ) + components_to_quantize = list(quant_config.quant_mapping.keys()) + pipe = DiffusionPipeline.from_pretrained( + self.model_name, + quantization_config=quant_config, + torch_dtype=torch.bfloat16, + ).to(torch_device) + + pipe_inputs = {"prompt": self.prompt, "num_inference_steps": self.num_inference_steps, "output_type": "latent"} + output_1 = pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir) + loaded_pipe = DiffusionPipeline.from_pretrained(tmpdir, torch_dtype=torch.bfloat16).to(torch_device) + for name, component in loaded_pipe.components.items(): + if name in components_to_quantize: + self.assertTrue(getattr(component.config, "quantization_config", None) is not None) + quantization_config = component.config.quantization_config + + if name == "text_encoder_2": + self.assertTrue(quantization_config.load_in_4bit) + self.assertTrue(quantization_config.quant_method == "bitsandbytes") + else: + self.assertTrue(quantization_config.quant_method == "quanto") + + output_2 = loaded_pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images + + self.assertTrue(torch.allclose(output_1, output_2)) From 695061b4ad450f1295e4de53508f83b9a02dbcb1 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 29 Apr 2025 11:02:54 +0800 Subject: [PATCH 09/19] import quality --- tests/quantization/test_pipeline_level_quantization.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/quantization/test_pipeline_level_quantization.py b/tests/quantization/test_pipeline_level_quantization.py index 337a093d4611..77df8be29a6d 100644 --- a/tests/quantization/test_pipeline_level_quantization.py +++ b/tests/quantization/test_pipeline_level_quantization.py @@ -17,10 +17,7 @@ import torch -from diffusers import ( - DiffusionPipeline, - QuantoConfig, -) +from diffusers import DiffusionPipeline, QuantoConfig from diffusers.quantizers import PipelineQuantizationConfig from diffusers.utils.testing_utils import ( is_transformers_available, From 969325119da83504e6ea078fb23951988ebcb5a7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 29 Apr 2025 22:30:12 +0800 Subject: [PATCH 10/19] remove prints. --- src/diffusers/quantizers/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py index 4a5b97a53a01..990641eeaeda 100644 --- a/src/diffusers/quantizers/__init__.py +++ b/src/diffusers/quantizers/__init__.py @@ -101,8 +101,6 @@ def _validate_quant_mapping_args(self): available_transformers = list(transformers_map.values()) if transformers_map else None available_diffusers = list(diffusers_map.values()) - print(f"{quant_mapping=}") - print(f"{available_diffusers=}") for module_name, config in quant_mapping.items(): if any(isinstance(config, cfg) for cfg in available_diffusers): From 872c91ef66d3bc55f3af3cdc508fc0fb72619b39 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 2 May 2025 11:21:19 +0530 Subject: [PATCH 11/19] add docs. --- docs/source/en/api/quantization.md | 3 + docs/source/en/quantization/overview.md | 59 +++++++++++++++++++ src/diffusers/quantizers/__init__.py | 44 +++++++++++++- .../quantizers/quantization_config.py | 2 +- 4 files changed, 106 insertions(+), 2 deletions(-) diff --git a/docs/source/en/api/quantization.md b/docs/source/en/api/quantization.md index 2c728cff3c07..8779adf8daf9 100644 --- a/docs/source/en/api/quantization.md +++ b/docs/source/en/api/quantization.md @@ -23,6 +23,9 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui +## PipelineQuantizationConfig + +[[autodoc]] PipelineQuantizationConfig ## BitsAndBytesConfig diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index 93323f86c7fc..d783b2cf90ec 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -39,3 +39,62 @@ Diffusers currently supports the following quantization methods. - [Quanto](./quanto.md) [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques. + +## Pipeline-level quantization + +Diffusers allows users to directly initialize pipelines from checkpoints that may contain quantized models([example](https://huggingface.co/hf-internal-testing/flux.1-dev-nf4-pkg)). However, users may want to apply +quantization on-the-fly when initializing a pipeline from a pre-trained and non-quantized checkpoint. You can +do this with [`PipelineQuantizationConfig`]. + +Start by defining a `PipelineQuantizationConfig`: + +```py +import torch +from diffusers import DiffusionPipeline +from diffusers.quantizers.quantization_config import QuantoConfig +from diffusers.quantizers import PipelineQuantizationConfig +from transformers import BitsAndBytesConfig + +pipeline_quant_config = PipelineQuantizationConfig( + quant_mapping={ + "transformer": QuantoConfig(weights_dtype="int8"), + "text_encoder_2": BitsAndBytesConfig( + load_in_4bit=True, compute_dtype=torch.bfloat16 + ), + } +) +``` + +Then pass it to [`~DiffusionPipeline.from_pretrained`] and run inference: + +```py +pipe = DiffusionPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + quantization_config=pipeline_quant_config, + torch_dtype=torch.bfloat16, +).to("cuda") + +image = pipe("photo of a cute dog").images[0] +``` + +This method allows for more granular control over the quantization specifications of individual +model-level components of a pipeline. It also allows for different quantization backends for +different components. In the above example, you used a combination of Quanto and BitsandBytes. + +The other method is simpler in terms of experience but is +less-flexible. Start by defining a `PipelineQuantizationConfig` but in a different way: + +```py +pipeline_quant_config = PipelineQuantizationConfig( + quant_backend="bitsandbytes_4bit", + quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16}, + components_to_quantize=["transformer", "text_encoder_2"], +) +``` + +This `pipeline_quant_config` can now be passed to [`~DiffusionPipeline.from_pretrained`] similar to the above example. + +In this case, `quant_kwargs` will be used to initialize the quantization specifications +of the respective quantization configuration class of `quant_backend`. `components_to_quantize` +is used to denote the components that will be quantized. For most pipelines, you would want to +keep `transformer` in the list as that is often the most compute and memory intensive. \ No newline at end of file diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py index 990641eeaeda..e527267f6400 100644 --- a/src/diffusers/quantizers/__init__.py +++ b/src/diffusers/quantizers/__init__.py @@ -33,7 +33,49 @@ class TransformersQuantConfigMixin: class PipelineQuantizationConfig: - """TODO""" + """ + Configuration class to be used when applying quantization on-the-fly to [`~DiffusionPipeline.from_pretrained`]. + + Args: + quant_backend (`str`): Quantization backend to be used. When using this option, we assume that the backend + is available to both `diffusers` and `transformers`. + quant_kwargs (`dict`): Params to initialize the quantization backend class. + components_to_quantize (`list`): Components of a pipeline to be quantized. + quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline + components. When using this argument, users are not expected to provide `quant_backend`, `quant_kawargs`, + and `components_to_quantize`. + + Examples: + + When using with `quant_backend`: + + >>> import torch >>> from diffusers import DiffusionPipeline >>> from diffusers.quantizers import + PipelineQuantizationConfig + + >>> pipeline_quant_config = PipelineQuantizationConfig( ... quant_backend="bitsandbytes_4bit", ... quant_kwargs={ + ... "load_in_4bit": True, ... "bnb_4bit_quant_type": "nf4", ... "bnb_4bit_compute_dtype": torch.bfloat16, ... }, + ... components_to_quantize=["transformer", "text_encoder_2"], ... ) + + >>> pipe = DiffusionPipeline.from_pretrained( ... "black-forest-labs/FLUX.1-dev", ... + quantization_config=pipeline_quant_config, ... torch_dtype=torch.bfloat16, ... ).to("cuda") + + >>> image = pipe("photo of a cute dog").images[0] + + When using with `quant_mapping`: + + >>> import torch >>> from diffusers import DiffusionPipeline >>> from diffusers.quantizers.quantization_config + import QuantoConfig >>> from diffusers.quantizers import PipelineQuantizationConfig >>> from transformers import + BitsAndBytesConfig + + >>> pipeline_quant_config = PipelineQuantizationConfig( ... quant_mapping={ ... "transformer": + QuantoConfig(weights_dtype="int8"), ... "text_encoder_2": BitsAndBytesConfig( ... load_in_4bit=True, + compute_dtype=torch.bfloat16 ... ), ... } ... ) + + >>> pipe = DiffusionPipeline.from_pretrained( ... "black-forest-labs/FLUX.1-dev", ... + quantization_config=pipeline_quant_config, ... torch_dtype=torch.bfloat16, ... ).to("cuda") + + >>> image = pipe("photo of a cute dog").images[0] + """ def __init__( self, diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 0bc433be0ff3..52cfbb980792 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -75,7 +75,7 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs): Args: config_dict (`Dict[str, Any]`): Dictionary that will be used to instantiate the configuration object. - return_unused_kwargs (`bool`,*optional*, defaults to `False`): + return_unused_kwargs (`bool`, *optional*, defaults to `False`): Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in `PreTrainedModel`. kwargs (`Dict[str, Any]`): From fbdf4c6422940d36a412b55605510be234d9dda1 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 2 May 2025 11:36:16 +0530 Subject: [PATCH 12/19] fixes to docs. --- docs/source/en/api/quantization.md | 6 ++---- src/diffusers/quantizers/__init__.py | 31 ---------------------------- 2 files changed, 2 insertions(+), 35 deletions(-) diff --git a/docs/source/en/api/quantization.md b/docs/source/en/api/quantization.md index 8779adf8daf9..3192f1f78380 100644 --- a/docs/source/en/api/quantization.md +++ b/docs/source/en/api/quantization.md @@ -13,9 +13,7 @@ specific language governing permissions and limitations under the License. # Quantization -Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference. Diffusers supports 8-bit and 4-bit quantization with [bitsandbytes](https://huggingface.co/docs/bitsandbytes/en/index). - -Quantization techniques that aren't supported in Transformers can be added with the [`DiffusersQuantizer`] class. +Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference. @@ -25,7 +23,7 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui ## PipelineQuantizationConfig -[[autodoc]] PipelineQuantizationConfig +[[autodoc]] quantizers.__init__.PipelineQuantizationConfig ## BitsAndBytesConfig diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py index e527267f6400..1d683d07dcce 100644 --- a/src/diffusers/quantizers/__init__.py +++ b/src/diffusers/quantizers/__init__.py @@ -44,37 +44,6 @@ class PipelineQuantizationConfig: quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline components. When using this argument, users are not expected to provide `quant_backend`, `quant_kawargs`, and `components_to_quantize`. - - Examples: - - When using with `quant_backend`: - - >>> import torch >>> from diffusers import DiffusionPipeline >>> from diffusers.quantizers import - PipelineQuantizationConfig - - >>> pipeline_quant_config = PipelineQuantizationConfig( ... quant_backend="bitsandbytes_4bit", ... quant_kwargs={ - ... "load_in_4bit": True, ... "bnb_4bit_quant_type": "nf4", ... "bnb_4bit_compute_dtype": torch.bfloat16, ... }, - ... components_to_quantize=["transformer", "text_encoder_2"], ... ) - - >>> pipe = DiffusionPipeline.from_pretrained( ... "black-forest-labs/FLUX.1-dev", ... - quantization_config=pipeline_quant_config, ... torch_dtype=torch.bfloat16, ... ).to("cuda") - - >>> image = pipe("photo of a cute dog").images[0] - - When using with `quant_mapping`: - - >>> import torch >>> from diffusers import DiffusionPipeline >>> from diffusers.quantizers.quantization_config - import QuantoConfig >>> from diffusers.quantizers import PipelineQuantizationConfig >>> from transformers import - BitsAndBytesConfig - - >>> pipeline_quant_config = PipelineQuantizationConfig( ... quant_mapping={ ... "transformer": - QuantoConfig(weights_dtype="int8"), ... "text_encoder_2": BitsAndBytesConfig( ... load_in_4bit=True, - compute_dtype=torch.bfloat16 ... ), ... } ... ) - - >>> pipe = DiffusionPipeline.from_pretrained( ... "black-forest-labs/FLUX.1-dev", ... - quantization_config=pipeline_quant_config, ... torch_dtype=torch.bfloat16, ... ).to("cuda") - - >>> image = pipe("photo of a cute dog").images[0] """ def __init__( From da6df86be610ff41cdb79a12e4ada2523a1798e7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 2 May 2025 11:49:27 +0530 Subject: [PATCH 13/19] doc fixes. --- docs/source/en/api/quantization.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/quantization.md b/docs/source/en/api/quantization.md index 3192f1f78380..e2ca990190e6 100644 --- a/docs/source/en/api/quantization.md +++ b/docs/source/en/api/quantization.md @@ -23,7 +23,7 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui ## PipelineQuantizationConfig -[[autodoc]] quantizers.__init__.PipelineQuantizationConfig +[[autodoc]] quantizers.PipelineQuantizationConfig ## BitsAndBytesConfig From 9a418a9eaca3532fcb0f0c87a69c70dfcde08bdb Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 2 May 2025 12:21:15 +0530 Subject: [PATCH 14/19] doc fixes. --- docs/source/en/quantization/overview.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index d783b2cf90ec..7cebd72e626a 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -42,9 +42,9 @@ Diffusers currently supports the following quantization methods. ## Pipeline-level quantization -Diffusers allows users to directly initialize pipelines from checkpoints that may contain quantized models([example](https://huggingface.co/hf-internal-testing/flux.1-dev-nf4-pkg)). However, users may want to apply +Diffusers allows users to directly initialize pipelines from checkpoints that may contain quantized models ([example](https://huggingface.co/hf-internal-testing/flux.1-dev-nf4-pkg)). However, users may want to apply quantization on-the-fly when initializing a pipeline from a pre-trained and non-quantized checkpoint. You can -do this with [`PipelineQuantizationConfig`]. +do this with [`~quantizers.PipelineQuantizationConfig`]. Start by defining a `PipelineQuantizationConfig`: From 478a353e832d6f3335f5f3d5f9648fe72f011dde Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 6 May 2025 08:22:21 +0530 Subject: [PATCH 15/19] add validation to the input quantization_config. --- src/diffusers/pipelines/pipeline_utils.py | 4 ++++ .../test_pipeline_level_quantization.py | 17 ++++++++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index c5496f76d20d..7cb2a12d3c94 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -47,6 +47,7 @@ from ..models import AutoencoderKL from ..models.attention_processor import FusedAttnProcessor2_0 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin +from ..quantizers import PipelineQuantizationConfig from ..quantizers.bitsandbytes.utils import _check_bnb_status from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from ..utils import ( @@ -742,6 +743,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P " install accelerate\n```\n." ) + if quantization_config is not None and not isinstance(quantization_config, PipelineQuantizationConfig): + raise ValueError("`quantization_config` must be an instance of `PipelineQuantizationConfig`.") + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): raise NotImplementedError( "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" diff --git a/tests/quantization/test_pipeline_level_quantization.py b/tests/quantization/test_pipeline_level_quantization.py index 77df8be29a6d..b82b2889d72d 100644 --- a/tests/quantization/test_pipeline_level_quantization.py +++ b/tests/quantization/test_pipeline_level_quantization.py @@ -74,7 +74,7 @@ def test_quant_config_set_correctly_through_kwargs(self): _ = pipe(self.prompt, num_inference_steps=self.num_inference_steps) - def test_quant_config_set_correctly_granular(self): + def test_quant_config_set_correctly_through_granular(self): quant_config = PipelineQuantizationConfig( quant_mapping={ "transformer": QuantoConfig(weights_dtype="int8"), @@ -128,6 +128,21 @@ def test_validation_for_kwargs(self): "The signatures of the __init__ methods of the quantization config classes" in str(err_context.exception) ) + def test_raises_error_for_wrong_config_class(self): + quant_config = { + "transformer": QuantoConfig(weights_dtype="int8"), + "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16), + } + with self.assertRaises(ValueError) as err_context: + _ = DiffusionPipeline.from_pretrained( + self.model_name, + quantization_config=quant_config, + torch_dtype=torch.bfloat16, + ) + self.assertTrue( + str(err_context.exception) == "`quantization_config` must be an instance of `PipelineQuantizationConfig`." + ) + def test_validation_for_mapping(self): with self.assertRaises(ValueError) as err_context: _ = PipelineQuantizationConfig( From d6b48ea6e76cd6ffc7178dcb92d8ce1dd6114933 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 8 May 2025 08:44:26 +0530 Subject: [PATCH 16/19] clarify recommendations. --- docs/source/en/quantization/overview.md | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index 7cebd72e626a..d1887419c3ff 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -97,4 +97,21 @@ This `pipeline_quant_config` can now be passed to [`~DiffusionPipeline.from_pret In this case, `quant_kwargs` will be used to initialize the quantization specifications of the respective quantization configuration class of `quant_backend`. `components_to_quantize` is used to denote the components that will be quantized. For most pipelines, you would want to -keep `transformer` in the list as that is often the most compute and memory intensive. \ No newline at end of file +keep `transformer` in the list as that is often the most compute and memory intensive. + +The config below will work for most diffusion pipelines that have a `transformer` component present. +In most case, you will want to quantize the `transformer` component as that is often the most compute- +intensive part of a diffusion pipeline. + +```py +pipeline_quant_config = PipelineQuantizationConfig( + quant_backend="bitsandbytes_4bit", + quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16}, + components_to_quantize=["transformer"], +) +``` + +Diffusion pipelines can have multiple text encoders. [`FluxPipeline`] has two, for example. It's +recommended to quantize the text encoders that are memory-intensive. Some examples include T5, +Llama, Gemma, etc. In the above example, you quantized the T5 model of [`FluxPipeline`] through +`text_encoder_2` while keeping the CLIP model intact (accessible through `text_encoder`). \ No newline at end of file From ffb974f973097cbdb90102a98973d2ea7b22b5e7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 8 May 2025 15:22:11 +0530 Subject: [PATCH 17/19] docs --- docs/source/en/quantization/overview.md | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index d1887419c3ff..68b99f524ec0 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -79,7 +79,9 @@ image = pipe("photo of a cute dog").images[0] This method allows for more granular control over the quantization specifications of individual model-level components of a pipeline. It also allows for different quantization backends for -different components. In the above example, you used a combination of Quanto and BitsandBytes. +different components. In the above example, you used a combination of Quanto and BitsandBytes. However, +one caveat of this method is that users need to know which components come from `transformers` to be able +to import the right quantization config class. The other method is simpler in terms of experience but is less-flexible. Start by defining a `PipelineQuantizationConfig` but in a different way: @@ -111,6 +113,15 @@ pipeline_quant_config = PipelineQuantizationConfig( ) ``` +Below is a list of the supported quantization backends available in both `diffusers` and `transformers`: + +* `bitsandbytes_4bit` +* `bitsandbytes_8bit` +* `gguf` +* `quanto` +* `torchao` + + Diffusion pipelines can have multiple text encoders. [`FluxPipeline`] has two, for example. It's recommended to quantize the text encoders that are memory-intensive. Some examples include T5, Llama, Gemma, etc. In the above example, you quantized the T5 model of [`FluxPipeline`] through From 86ee7735a7fab1f2a24351c876fc55fc19a1094c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 8 May 2025 15:33:23 +0530 Subject: [PATCH 18/19] add to ci. --- .github/workflows/nightly_tests.yml | 54 +++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index 7696852ecd44..9af0bb6d2280 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -525,6 +525,60 @@ jobs: pip install slack_sdk tabulate python utils/log_reports.py >> $GITHUB_STEP_SUMMARY + run_nightly_pipeline_level_quantization_tests: + name: Torch quantization nightly tests + strategy: + fail-fast: false + max-parallel: 2 + runs-on: + group: aws-g6e-xlarge-plus + container: + image: diffusers/diffusers-pytorch-cuda + options: --shm-size "20gb" --ipc host --gpus 0 + steps: + - name: Checkout diffusers + uses: actions/checkout@v3 + with: + fetch-depth: 2 + - name: NVIDIA-SMI + run: nvidia-smi + - name: Install dependencies + run: | + python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" + python -m uv pip install -e [quality,test] + python -m uv pip install -U bitsandbytes optimum_quanto + python -m uv pip install pytest-reportlog + - name: Environment + run: | + python utils/print_env.py + - name: Pipeline-level quantization tests on GPU + env: + HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} + # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms + CUBLAS_WORKSPACE_CONFIG: :16:8 + BIG_GPU_MEMORY: 40 + run: | + python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \ + --make-reports=tests_pipeline_level_quant_torch_cuda \ + --report-log=tests_pipeline_level_quant_torch_cuda.log \ + tests/quantization/test_pipeline_level_quantization.py + - name: Failure short reports + if: ${{ failure() }} + run: | + cat reports/tests_pipeline_level_quant_torch_cuda_stats.txt + cat reports/tests_pipeline_level_quant_torch_cuda_failures_short.txt + - name: Test suite reports artifacts + if: ${{ always() }} + uses: actions/upload-artifact@v4 + with: + name: torch_cuda_pipeline_level_quant_reports + path: reports + - name: Generate Report and Notify Channel + if: always() + run: | + pip install slack_sdk tabulate + python utils/log_reports.py >> $GITHUB_STEP_SUMMARY + # M1 runner currently not well supported # TODO: (Dhruv) add these back when we setup better testing for Apple Silicon # run_nightly_tests_apple_m1: From 7b8a73d78f0e4cddcd679441ea460e745612803a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 9 May 2025 09:21:58 +0530 Subject: [PATCH 19/19] todo. --- src/diffusers/quantizers/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py index 1d683d07dcce..bd9e2303c93b 100644 --- a/src/diffusers/quantizers/__init__.py +++ b/src/diffusers/quantizers/__init__.py @@ -102,8 +102,8 @@ def _validate_init_kwargs_in_backends(self): if init_kwargs_transformers != init_kwargs_diffusers: raise ValueError( "The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. " - f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to the docs to learn more about how " - "this mapping would look like: TODO." + f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to [the docs](https://huggingface.co/docs/diffusers/main/en/quantization/overview#pipeline-level-quantization) to learn more about how " + "this mapping would look like." ) def _validate_quant_mapping_args(self):