File tree Expand file tree Collapse file tree 2 files changed +14
-2
lines changed
examples/research_projects/diffusion_dpo Expand file tree Collapse file tree 2 files changed +14
-2
lines changed Original file line number Diff line number Diff line change @@ -414,6 +414,12 @@ def parse_args(input_args=None):
414
414
default = 4 ,
415
415
help = ("The dimension of the LoRA update matrices." ),
416
416
)
417
+ parser .add_argument (
418
+ "--tracker_name" ,
419
+ type = str ,
420
+ default = "diffusion-dpo-lora" ,
421
+ help = ("The name of the tracker to report results to." ),
422
+ )
417
423
418
424
if input_args is not None :
419
425
args = parser .parse_args (input_args )
@@ -726,7 +732,7 @@ def collate_fn(examples):
726
732
# We need to initialize the trackers we use, and also store our configuration.
727
733
# The trackers initializes automatically on the main process.
728
734
if accelerator .is_main_process :
729
- accelerator .init_trackers ("diffusion-dpo-lora" , config = vars (args ))
735
+ accelerator .init_trackers (args . tracker_name , config = vars (args ))
730
736
731
737
# Train!
732
738
total_batch_size = args .train_batch_size * accelerator .num_processes * args .gradient_accumulation_steps
Original file line number Diff line number Diff line change @@ -429,6 +429,12 @@ def parse_args(input_args=None):
429
429
default = 4 ,
430
430
help = ("The dimension of the LoRA update matrices." ),
431
431
)
432
+ parser .add_argument (
433
+ "--tracker_name" ,
434
+ type = str ,
435
+ default = "diffusion-dpo-lora-sdxl" ,
436
+ help = ("The name of the tracker to report results to." ),
437
+ )
432
438
433
439
if input_args is not None :
434
440
args = parser .parse_args (input_args )
@@ -821,7 +827,7 @@ def collate_fn(examples):
821
827
# We need to initialize the trackers we use, and also store our configuration.
822
828
# The trackers initializes automatically on the main process.
823
829
if accelerator .is_main_process :
824
- accelerator .init_trackers ("diffusion-dpo-lora-sdxl" , config = vars (args ))
830
+ accelerator .init_trackers (args . tracker_name , config = vars (args ))
825
831
826
832
# Train!
827
833
total_batch_size = args .train_batch_size * accelerator .num_processes * args .gradient_accumulation_steps
You can’t perform that action at this time.
0 commit comments