Skip to content

Commit 1dd0ac9

Browse files
authored
[DPO Training] pass tracker name as argument (#6542)
pass tracker name as argumentw
1 parent c6b0458 commit 1dd0ac9

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

examples/research_projects/diffusion_dpo/train_diffusion_dpo.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,12 @@ def parse_args(input_args=None):
414414
default=4,
415415
help=("The dimension of the LoRA update matrices."),
416416
)
417+
parser.add_argument(
418+
"--tracker_name",
419+
type=str,
420+
default="diffusion-dpo-lora",
421+
help=("The name of the tracker to report results to."),
422+
)
417423

418424
if input_args is not None:
419425
args = parser.parse_args(input_args)
@@ -726,7 +732,7 @@ def collate_fn(examples):
726732
# We need to initialize the trackers we use, and also store our configuration.
727733
# The trackers initializes automatically on the main process.
728734
if accelerator.is_main_process:
729-
accelerator.init_trackers("diffusion-dpo-lora", config=vars(args))
735+
accelerator.init_trackers(args.tracker_name, config=vars(args))
730736

731737
# Train!
732738
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,12 @@ def parse_args(input_args=None):
429429
default=4,
430430
help=("The dimension of the LoRA update matrices."),
431431
)
432+
parser.add_argument(
433+
"--tracker_name",
434+
type=str,
435+
default="diffusion-dpo-lora-sdxl",
436+
help=("The name of the tracker to report results to."),
437+
)
432438

433439
if input_args is not None:
434440
args = parser.parse_args(input_args)
@@ -821,7 +827,7 @@ def collate_fn(examples):
821827
# We need to initialize the trackers we use, and also store our configuration.
822828
# The trackers initializes automatically on the main process.
823829
if accelerator.is_main_process:
824-
accelerator.init_trackers("diffusion-dpo-lora-sdxl", config=vars(args))
830+
accelerator.init_trackers(args.tracker_name, config=vars(args))
825831

826832
# Train!
827833
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

0 commit comments

Comments
 (0)