22
22
import torch
23
23
from monai .data import DataLoader , PatchDataset , create_test_image_3d , list_data_collate
24
24
from monai .inferers import SliceInferer
25
+ from monai .metrics import DiceMetric
25
26
from monai .transforms import (
26
27
AsChannelFirstd ,
27
28
Compose ,
39
40
40
41
def main (tempdir ):
41
42
monai .config .print_config ()
43
+ monai .utils .set_determinism (0 )
42
44
logging .basicConfig (stream = sys .stdout , level = logging .INFO )
43
45
44
46
# -----
@@ -55,7 +57,6 @@ def main(tempdir):
55
57
56
58
n = nib .Nifti1Image (im , np .eye (4 ))
57
59
nib .save (n , os .path .join (tempdir , f"img{ i :d} .nii.gz" ))
58
-
59
60
n = nib .Nifti1Image (seg , np .eye (4 ))
60
61
nib .save (n , os .path .join (tempdir , f"seg{ i :d} .nii.gz" ))
61
62
@@ -79,7 +80,7 @@ def main(tempdir):
79
80
# 3D dataset with preprocessing transforms
80
81
volume_ds = monai .data .Dataset (data = train_files , transform = train_transforms )
81
82
# use batch_size=1 to check the volumes because the input volumes have different shapes
82
- check_loader = DataLoader (volume_ds , batch_size = 1 , collate_fn = list_data_collate )
83
+ check_loader = DataLoader (volume_ds , batch_size = 1 )
83
84
check_data = monai .utils .misc .first (check_loader )
84
85
print ("first volume's shape: " , check_data ["img" ].shape , check_data ["seg" ].shape )
85
86
@@ -172,7 +173,6 @@ def main(tempdir):
172
173
# -----
173
174
# inference with a SliceInferer
174
175
# -----
175
- model .eval ()
176
176
val_transform = Compose (
177
177
[
178
178
LoadImaged (keys = ["img" , "seg" ]),
@@ -183,8 +183,9 @@ def main(tempdir):
183
183
)
184
184
val_files = [{"img" : img , "seg" : seg } for img , seg in zip (images [- 3 :], segs [- 3 :])]
185
185
val_ds = monai .data .Dataset (data = val_files , transform = val_transform )
186
- data_loader = DataLoader (val_ds , pin_memory = torch .cuda .is_available ())
187
-
186
+ data_loader = DataLoader (val_ds , num_workers = 1 , pin_memory = torch .cuda .is_available ())
187
+ dice_metric = DiceMetric (include_background = True , reduction = "mean" , get_not_nans = False )
188
+ model .eval ()
188
189
with torch .no_grad ():
189
190
for val_data in data_loader :
190
191
val_images = val_data ["img" ].to (device )
@@ -193,14 +194,17 @@ def main(tempdir):
193
194
slice_inferer = SliceInferer (
194
195
roi_size = roi_size ,
195
196
sw_batch_size = sw_batch_size ,
196
- spatial_dim = 2 , # Spatial dim to slice along is defined here
197
+ spatial_dim = 1 , # Spatial dim to slice along is defined here
197
198
device = torch .device ("cpu" ),
198
199
padding_mode = "replicate" ,
199
200
)
200
201
val_output = slice_inferer (val_images , model )
202
+ dice_metric (y_pred = val_output > 0.5 , y = val_data ["seg" ])
203
+ print (dice_metric .get_buffer ()[0 ][- 1 ])
201
204
matshow3d (val_output [0 ] > 0.5 )
202
205
matshow3d (val_images [0 ])
203
206
plt .show ()
207
+ print (f"Dice score: { dice_metric .aggregate ().item ()} " )
204
208
205
209
206
210
if __name__ == "__main__" :
0 commit comments