diff --git a/modules/dynunet_pipeline/train.py b/modules/dynunet_pipeline/train.py index e5d88fa924..b954fae69a 100644 --- a/modules/dynunet_pipeline/train.py +++ b/modules/dynunet_pipeline/train.py @@ -52,9 +52,7 @@ def validation(args): net = net.to(device) if multi_gpu_flag: - net = DistributedDataParallel( - module=net, device_ids=[device], find_unused_parameters=True - ) + net = DistributedDataParallel(module=net, device_ids=[device]) num_classes = len(properties["labels"]) @@ -139,9 +137,7 @@ def train(args): net = net.to(device) if multi_gpu_flag: - net = DistributedDataParallel( - module=net, device_ids=[device], find_unused_parameters=True - ) + net = DistributedDataParallel(module=net, device_ids=[device]) optimizer = torch.optim.SGD( net.parameters(), @@ -193,7 +189,9 @@ def train(args): train_handlers += [ ValidationHandler(validator=evaluator, interval=interval, epoch_level=True), - StatsHandler(tag_name="train_loss", output_transform=from_engine(["loss"], first=True)), + StatsHandler( + tag_name="train_loss", output_transform=from_engine(["loss"], first=True) + ), ] trainer = DynUNetTrainer(