Commit a29ea36

Update train_unconditional.py (#3899)
Increase the process-group timeout when using a big dataset or high resolution.
1 parent af48bf2 commit a29ea36

File tree

1 file changed (+4, -1 lines)

examples/unconditional_image_generation/train_unconditional.py

Lines changed: 4 additions & 1 deletion
@@ -4,14 +4,15 @@
 import math
 import os
 import shutil
+from datetime import timedelta
 from pathlib import Path
 from typing import Optional
 
 import accelerate
 import datasets
 import torch
 import torch.nn.functional as F
-from accelerate import Accelerator
+from accelerate import Accelerator, InitProcessGroupKwargs
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration
 from datasets import load_dataset
@@ -286,11 +287,13 @@ def main(args):
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
 
+    kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=7200))  # a big number for high resolution or big dataset
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.logger,
         project_config=accelerator_project_config,
+        kwargs_handlers=[kwargs],
     )
 
     if args.logger == "tensorboard":
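
For reference, a minimal, self-contained sketch of the pattern this commit relies on: Accelerate's InitProcessGroupKwargs overrides the default process-group timeout (1800 seconds) when passed to Accelerator via kwargs_handlers, which keeps ranks from timing out while a large dataset or high-resolution pipeline is still being prepared. Apart from the 7200-second timeout taken from the diff above, every setting below (gradient accumulation, mixed precision, the printed message) is a placeholder, not taken from train_unconditional.py.

# Minimal sketch of the timeout pattern used in this commit; placeholder
# settings are illustrative and not taken from train_unconditional.py.
from datetime import timedelta

from accelerate import Accelerator, InitProcessGroupKwargs

# Raise the distributed init/collective timeout from the 1800 s default to
# 2 hours, matching the value chosen in the commit.
process_group_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=7200))

accelerator = Accelerator(
    gradient_accumulation_steps=1,           # placeholder value
    mixed_precision="no",                    # placeholder value
    kwargs_handlers=[process_group_kwargs],  # applies the longer timeout when the process group is created
)

if accelerator.is_main_process:
    print(f"{accelerator.num_processes} process(es), init timeout raised to 2 hours")

Launched with accelerate launch (or torchrun), this gives every rank the same two-hour window before the process group declares a hang, instead of the default half hour.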
