Skip to content

Commit bbc440f

Browse files
remove find unused param (#477)
1 parent d8b5a30 commit bbc440f

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

modules/dynunet_pipeline/train.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,7 @@ def validation(args):
5252
net = net.to(device)
5353

5454
if multi_gpu_flag:
55-
net = DistributedDataParallel(
56-
module=net, device_ids=[device], find_unused_parameters=True
57-
)
55+
net = DistributedDataParallel(module=net, device_ids=[device])
5856

5957
num_classes = len(properties["labels"])
6058

@@ -139,9 +137,7 @@ def train(args):
139137
net = net.to(device)
140138

141139
if multi_gpu_flag:
142-
net = DistributedDataParallel(
143-
module=net, device_ids=[device], find_unused_parameters=True
144-
)
140+
net = DistributedDataParallel(module=net, device_ids=[device])
145141

146142
optimizer = torch.optim.SGD(
147143
net.parameters(),
@@ -193,7 +189,9 @@ def train(args):
193189

194190
train_handlers += [
195191
ValidationHandler(validator=evaluator, interval=interval, epoch_level=True),
196-
StatsHandler(tag_name="train_loss", output_transform=from_engine(["loss"], first=True)),
192+
StatsHandler(
193+
tag_name="train_loss", output_transform=from_engine(["loss"], first=True)
194+
),
197195
]
198196

199197
trainer = DynUNetTrainer(

0 commit comments

Comments
 (0)