From 0d13813c6a0db225df759babd2d543a98d2c0a06 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 18 Jun 2023 11:36:37 -0400 Subject: [PATCH] fixes #1428 Signed-off-by: Wenqi Li --- 3d_segmentation/challenge_baseline/run_net.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/3d_segmentation/challenge_baseline/run_net.py b/3d_segmentation/challenge_baseline/run_net.py index 2431333110..db14dffcb9 100644 --- a/3d_segmentation/challenge_baseline/run_net.py +++ b/3d_segmentation/challenge_baseline/run_net.py @@ -16,7 +16,6 @@ import shutil import sys -import numpy as np import torch import torch.nn as nn from ignite.contrib.handlers import ProgressBar @@ -42,7 +41,7 @@ def get_xforms(mode="train", keys=("image", "label")): """returns a composed transform for train/val/infer.""" xforms = [ - LoadImaged(keys, ensure_channel_first=True), + LoadImaged(keys, ensure_channel_first=True, image_only=True), Orientationd(keys, axcodes="LPS"), Spacingd(keys, pixdim=(1.25, 1.25, 5.0), mode=("bilinear", "nearest")[: len(keys)]), ScaleIntensityRanged(keys[0], a_min=-1000.0, a_max=500.0, b_min=0.0, b_max=1.0, clip=True), @@ -239,7 +238,7 @@ def infer(data_folder=".", model_folder="runs", prediction_folder="output"): ) inferer = get_inferer() - saver = monai.data.NiftiSaver(output_dir=prediction_folder, mode="nearest") + saver = monai.transforms.SaveImage(output_dir=prediction_folder, mode="nearest", resample=True) with torch.no_grad(): for infer_data in infer_loader: logging.info(f"segmenting {infer_data['image'].meta['filename_or_obj']}") @@ -258,7 +257,8 @@ def infer(data_folder=".", model_folder="runs", prediction_folder="output"): n = n + 1.0 preds = preds / n preds = (preds.argmax(dim=1, keepdims=True)).float() - saver.save_batch(preds, infer_data["image"].meta) + for p in preds: # save each image+metadata in the batch respectively + saver(p) # copy the saved segmentations into the required folder structure for submission submission_dir = os.path.join(prediction_folder, "to_submit")