Skip to content

Commit bbc4e18

Browse files
authored
1428 update to the latest usage SaveImage (#1429)
Fixes #1428 ### Checks <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Avoid including large-size files in the PR. - [x] Clean up long text outputs from code cells in the notebook. - [x] For security purposes, please check the contents and remove any sensitive info such as user names and private key. - [x] Ensure (1) hyperlinks and markdown anchors are working (2) use relative paths for tutorial repo files (3) put figure and graphs in the `./figure` folder - [x] Notebook runs automatically `./runner.sh -t <path to .ipynb file>` Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent 5de3e5f commit bbc4e18

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

3d_segmentation/challenge_baseline/run_net.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import shutil
1717
import sys
1818

19-
import numpy as np
2019
import torch
2120
import torch.nn as nn
2221
from ignite.contrib.handlers import ProgressBar
@@ -42,7 +41,7 @@ def get_xforms(mode="train", keys=("image", "label")):
4241
"""returns a composed transform for train/val/infer."""
4342

4443
xforms = [
45-
LoadImaged(keys, ensure_channel_first=True),
44+
LoadImaged(keys, ensure_channel_first=True, image_only=True),
4645
Orientationd(keys, axcodes="LPS"),
4746
Spacingd(keys, pixdim=(1.25, 1.25, 5.0), mode=("bilinear", "nearest")[: len(keys)]),
4847
ScaleIntensityRanged(keys[0], a_min=-1000.0, a_max=500.0, b_min=0.0, b_max=1.0, clip=True),
@@ -239,7 +238,7 @@ def infer(data_folder=".", model_folder="runs", prediction_folder="output"):
239238
)
240239

241240
inferer = get_inferer()
242-
saver = monai.data.NiftiSaver(output_dir=prediction_folder, mode="nearest")
241+
saver = monai.transforms.SaveImage(output_dir=prediction_folder, mode="nearest", resample=True)
243242
with torch.no_grad():
244243
for infer_data in infer_loader:
245244
logging.info(f"segmenting {infer_data['image'].meta['filename_or_obj']}")
@@ -258,7 +257,8 @@ def infer(data_folder=".", model_folder="runs", prediction_folder="output"):
258257
n = n + 1.0
259258
preds = preds / n
260259
preds = (preds.argmax(dim=1, keepdims=True)).float()
261-
saver.save_batch(preds, infer_data["image"].meta)
260+
for p in preds: # save each image+metadata in the batch respectively
261+
saver(p)
262262

263263
# copy the saved segmentations into the required folder structure for submission
264264
submission_dir = os.path.join(prediction_folder, "to_submit")

0 commit comments

Comments
 (0)