Skip to content

Commit 2b320e8

Browse files
committed
adds eval
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent d8c74ba commit 2b320e8

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

modules/training_with_2d_slices.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import torch
2323
from monai.data import DataLoader, PatchDataset, create_test_image_3d, list_data_collate
2424
from monai.inferers import SliceInferer
25+
from monai.metrics import DiceMetric
2526
from monai.transforms import (
2627
AsChannelFirstd,
2728
Compose,
@@ -39,6 +40,7 @@
3940

4041
def main(tempdir):
4142
monai.config.print_config()
43+
monai.utils.set_determinism(0)
4244
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
4345

4446
# -----
@@ -55,7 +57,6 @@ def main(tempdir):
5557

5658
n = nib.Nifti1Image(im, np.eye(4))
5759
nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))
58-
5960
n = nib.Nifti1Image(seg, np.eye(4))
6061
nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))
6162

@@ -79,7 +80,7 @@ def main(tempdir):
7980
# 3D dataset with preprocessing transforms
8081
volume_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
8182
# use batch_size=1 to check the volumes because the input volumes have different shapes
82-
check_loader = DataLoader(volume_ds, batch_size=1, collate_fn=list_data_collate)
83+
check_loader = DataLoader(volume_ds, batch_size=1)
8384
check_data = monai.utils.misc.first(check_loader)
8485
print("first volume's shape: ", check_data["img"].shape, check_data["seg"].shape)
8586

@@ -172,7 +173,6 @@ def main(tempdir):
172173
# -----
173174
# inference with a SliceInferer
174175
# -----
175-
model.eval()
176176
val_transform = Compose(
177177
[
178178
LoadImaged(keys=["img", "seg"]),
@@ -183,8 +183,9 @@ def main(tempdir):
183183
)
184184
val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-3:], segs[-3:])]
185185
val_ds = monai.data.Dataset(data=val_files, transform=val_transform)
186-
data_loader = DataLoader(val_ds, pin_memory=torch.cuda.is_available())
187-
186+
data_loader = DataLoader(val_ds, num_workers=1, pin_memory=torch.cuda.is_available())
187+
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
188+
model.eval()
188189
with torch.no_grad():
189190
for val_data in data_loader:
190191
val_images = val_data["img"].to(device)
@@ -193,14 +194,17 @@ def main(tempdir):
193194
slice_inferer = SliceInferer(
194195
roi_size=roi_size,
195196
sw_batch_size=sw_batch_size,
196-
spatial_dim=2, # Spatial dim to slice along is defined here
197+
spatial_dim=1, # Spatial dim to slice along is defined here
197198
device=torch.device("cpu"),
198199
padding_mode="replicate",
199200
)
200201
val_output = slice_inferer(val_images, model)
202+
dice_metric(y_pred=val_output > 0.5, y=val_data["seg"])
203+
print("Dice: ", dice_metric.get_buffer()[-1][0])
201204
matshow3d(val_output[0] > 0.5)
202205
matshow3d(val_images[0])
203206
plt.show()
207+
print(f"Avg Dice: {dice_metric.aggregate().item()}")
204208

205209

206210
if __name__ == "__main__":

0 commit comments

Comments
 (0)