### Describe the bug
I am attempting to run ZeRO-3 Flux ControlNet training on 8 GPUs, following the guidance in the README. The error below occurred:
[rank0]: RuntimeError: 'weight' must be 2-D
### Reproduction
accelerate config:

```yaml
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 8
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
training command:

```shell
accelerate launch --config_file "./accelerate_config_zero3.yaml" train_controlnet_flux_zero3.py \
  --pretrained_model_name_or_path=/srv/mindone/wty/flux.1-dev/ \
  --jsonl_for_train=/srv/mindone/wty/diffusers/examples/controlnet/train_1000.jsonl \
  --conditioning_image_column=conditioning_image \
  --image_column=image \
  --caption_column=text \
  --output_dir=/srv/mindone/wty/diffusers/examples/controlnet/single_layer \
  --mixed_precision="bf16" \
  --resolution=512 \
  --learning_rate=1e-5 \
  --max_train_steps=100 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=8 \
  --num_double_layers=4 \
  --num_single_layers=0 \
  --seed=42 \
  --gradient_checkpointing \
  --cache_dir=/srv/mindone/wty/diffusers/examples/controlnet/cache \
  --dataloader_num_workers=8 \
  --resume_from_checkpoint="latest"
```
### Logs
```
Map:   0%|          | 0/1000 [00:00<?, ? examples/s]
[rank0]: Traceback (most recent call last):
[rank0]:   File "/srv/mindone/wty/diffusers/examples/controlnet/train_controlnet_flux_zero3.py", line 1481, in <module>
[rank0]:     main(args)
[rank0]:   File "/srv/mindone/wty/diffusers/examples/controlnet/train_controlnet_flux_zero3.py", line 1182, in main
[rank0]:     train_dataset = train_dataset.map(
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/datasets/arrow_dataset.py", line 562, in wrapper
[rank0]:     out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/datasets/arrow_dataset.py", line 3079, in map
[rank0]:     for rank, done, content in Dataset._map_single(**dataset_kwargs):
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/datasets/arrow_dataset.py", line 3519, in _map_single
[rank0]:     for i, batch in iter_outputs(shard_iterable):
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/datasets/arrow_dataset.py", line 3469, in iter_outputs
[rank0]:     yield i, apply_function(example, i, offset=offset)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/datasets/arrow_dataset.py", line 3392, in apply_function
[rank0]:     processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
[rank0]:   File "/srv/mindone/wty/diffusers/examples/controlnet/train_controlnet_flux_zero3.py", line 1094, in compute_embeddings
[rank0]:     prompt_embeds, pooled_prompt_embeds, text_ids = flux_controlnet_pipeline.encode_prompt(
[rank0]:   File "/srv/mindone/wty/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py", line 396, in encode_prompt
[rank0]:     pooled_prompt_embeds = self._get_clip_prompt_embeds(
[rank0]:   File "/srv/mindone/wty/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py", line 328, in _get_clip_prompt_embeds
[rank0]:     prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/transformers/models/clip/modeling_clip.py", line 1056, in forward
[rank0]:     return self.text_model(
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/transformers/models/clip/modeling_clip.py", line 947, in forward
[rank0]:     hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/transformers/models/clip/modeling_clip.py", line 292, in forward
[rank0]:     inputs_embeds = self.token_embedding(input_ids)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/torch/nn/modules/sparse.py", line 190, in forward
[rank0]:     return F.embedding(
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/torch/nn/functional.py", line 2551, in embedding
[rank0]:     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
[rank0]: RuntimeError: 'weight' must be 2-D
```
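The traceback bottoms out in `torch.nn.functional.embedding`, which requires a 2-D `(vocab_size, embedding_dim)` weight matrix. A plausible explanation (my assumption, not confirmed here): with `zero3_init_flag: true`, ZeRO-3 partitions every parameter at initialization and leaves only an empty placeholder tensor on each rank, and since the CLIP text encoder used for prompt pre-computation is called outside the DeepSpeed engine, its token-embedding weight is still that flattened placeholder. A minimal sketch of the shape requirement that fails:

```python
import torch
import torch.nn.functional as F

ids = torch.tensor([[1, 2, 3]])

# Normal case: F.embedding expects a 2-D (vocab_size, embedding_dim) weight.
weight = torch.randn(10, 4)
out = F.embedding(ids, weight)
assert out.shape == (1, 3, 4)

# ZeRO-3-style placeholder: a partitioned parameter appears on each rank as
# an empty tensor until it is gathered, which reproduces the error above.
flat_placeholder = torch.empty(0)
try:
    F.embedding(ids, flat_placeholder)
except RuntimeError as e:
    print(e)  # 'weight' must be 2-D
```

If that is indeed the cause, the directions to test would be gathering the encoder's parameters before the `map` step (e.g. via `deepspeed.zero.GatheredParameters`) or disabling `zero3_init_flag` for the frozen text encoders.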
### System Info
- 🤗 Diffusers version: 0.33.0.dev0 (HEAD at "Fix redundant prev_output_channel assignment in UNet2DModel", #10945)
- Platform: Linux-4.15.0-156-generic-x86_64-with-glibc2.27
- Running on Google Colab?: No
- Python version: 3.9.21
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.29.1
- Transformers version: 4.49.0
- Accelerate version: 1.4.0
- PEFT version: not installed
- Bitsandbytes version: not installed
- Safetensors version: 0.5.3
- xFormers version: not installed
- Accelerator: NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB