Skip to content

Commit d8c74ba

Browse files
committed
adds 2d/3d example
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent ffa6c1d commit d8c74ba

File tree

1 file changed

+208
-0
lines changed

1 file changed

+208
-0
lines changed

modules/training_with_2d_slices.py

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import logging
13+
import os
14+
import sys
15+
import tempfile
16+
from glob import glob
17+
18+
import matplotlib.pyplot as plt
19+
import monai
20+
import nibabel as nib
21+
import numpy as np
22+
import torch
23+
from monai.data import DataLoader, PatchDataset, create_test_image_3d, list_data_collate
24+
from monai.inferers import SliceInferer
25+
from monai.transforms import (
26+
AsChannelFirstd,
27+
Compose,
28+
EnsureTyped,
29+
LoadImaged,
30+
RandCropByPosNegLabeld,
31+
RandRotate90d,
32+
Resized,
33+
ResizeWithPadOrCropd,
34+
ScaleIntensityd,
35+
SqueezeDimd,
36+
)
37+
from monai.visualize import matshow3d
38+
39+
40+
def main(tempdir):
41+
monai.config.print_config()
42+
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
43+
44+
# -----
45+
# make demo data
46+
# -----
47+
# create a temporary directory and 40 random image, mask pairs
48+
print(f"generating synthetic data to {tempdir} (this may take a while)")
49+
for i in range(40):
50+
# make the input volumes different spatial shapes for demo purposes
51+
H, W, D = 30 + i, 40 + i, 50 + i
52+
im, seg = create_test_image_3d(
53+
H, W, D, num_seg_classes=1, channel_dim=-1, rad_max=10
54+
)
55+
56+
n = nib.Nifti1Image(im, np.eye(4))
57+
nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))
58+
59+
n = nib.Nifti1Image(seg, np.eye(4))
60+
nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))
61+
62+
images = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
63+
segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
64+
train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:35], segs[:35])]
65+
66+
# -----
67+
# volume-level preprocessing
68+
# -----
69+
# volume-level transforms for both image and segmentation
70+
train_transforms = Compose(
71+
[
72+
LoadImaged(keys=["img", "seg"]),
73+
AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
74+
ScaleIntensityd(keys="img"),
75+
RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
76+
EnsureTyped(keys=["img", "seg"]),
77+
]
78+
)
79+
# 3D dataset with preprocessing transforms
80+
volume_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
81+
# use batch_size=1 to check the volumes because the input volumes have different shapes
82+
check_loader = DataLoader(volume_ds, batch_size=1, collate_fn=list_data_collate)
83+
check_data = monai.utils.misc.first(check_loader)
84+
print("first volume's shape: ", check_data["img"].shape, check_data["seg"].shape)
85+
86+
# -----
87+
# volume to patch sampling
88+
# -----
89+
# define the sampling transforms, could also be other spatial cropping transforms
90+
# https://docs.monai.io/en/stable/transforms.html#crop-and-pad-dict
91+
num_samples = 4
92+
patch_func = RandCropByPosNegLabeld(
93+
keys=["img", "seg"],
94+
label_key="seg",
95+
spatial_size=[-1, -1, 1], # dynamic spatial_size for the first two dimensions
96+
pos=1,
97+
neg=1,
98+
num_samples=num_samples,
99+
)
100+
101+
# -----
102+
# patch-level preprocessing
103+
# -----
104+
# resize the sampled slices to a consistent size so that we can batch
105+
# the last spatial dim is always 1, so we squeeze dim.
106+
patch_transform = Compose(
107+
[
108+
SqueezeDimd(keys=["img", "seg"], dim=-1), # squeeze the last dim
109+
Resized(keys=["img", "seg"], spatial_size=[48, 48]),
110+
# ResizeWithPadOrCropd(keys=["img", "seg"], spatial_size=[48, 48], mode="replicate"),
111+
]
112+
)
113+
patch_ds = PatchDataset(
114+
volume_ds,
115+
transform=patch_transform,
116+
patch_func=patch_func,
117+
samples_per_image=num_samples,
118+
)
119+
train_loader = DataLoader(
120+
patch_ds,
121+
batch_size=3,
122+
shuffle=True, # this shuffles slices from different volumes
123+
num_workers=2,
124+
pin_memory=torch.cuda.is_available(),
125+
)
126+
check_data = monai.utils.misc.first(train_loader)
127+
print("first patch's shape: ", check_data["img"].shape, check_data["seg"].shape)
128+
129+
# -----
130+
# network defined for 2d inputs
131+
# -----
132+
# create UNet, DiceLoss and Adam optimizer
133+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
134+
model = monai.networks.nets.UNet(
135+
spatial_dims=2,
136+
in_channels=1,
137+
out_channels=1,
138+
channels=(16, 32, 64, 128),
139+
strides=(2, 2, 2),
140+
num_res_units=2,
141+
).to(device)
142+
loss_function = monai.losses.DiceLoss(sigmoid=True)
143+
optimizer = torch.optim.Adam(model.parameters(), 5e-3)
144+
145+
# -----
146+
# training
147+
# -----
148+
epoch_loss_values = []
149+
num_epochs = 5
150+
for epoch in range(num_epochs):
151+
print("-" * 10)
152+
print(f"epoch {epoch + 1}/{num_epochs}")
153+
model.train()
154+
epoch_loss, step = 0, 0
155+
for batch_data in train_loader:
156+
step += 1
157+
inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
158+
optimizer.zero_grad()
159+
outputs = model(inputs)
160+
loss = loss_function(outputs, labels)
161+
loss.backward()
162+
optimizer.step()
163+
epoch_loss += loss.item()
164+
epoch_len = len(patch_ds) // train_loader.batch_size
165+
if step % 25 == 0:
166+
print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
167+
epoch_loss /= step
168+
epoch_loss_values.append(epoch_loss)
169+
print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
170+
print("train completed")
171+
172+
# -----
173+
# inference with a SliceInferer
174+
# -----
175+
model.eval()
176+
val_transform = Compose(
177+
[
178+
LoadImaged(keys=["img", "seg"]),
179+
AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
180+
ScaleIntensityd(keys="img"),
181+
EnsureTyped(keys=["img", "seg"]),
182+
]
183+
)
184+
val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-3:], segs[-3:])]
185+
val_ds = monai.data.Dataset(data=val_files, transform=val_transform)
186+
data_loader = DataLoader(val_ds, pin_memory=torch.cuda.is_available())
187+
188+
with torch.no_grad():
189+
for val_data in data_loader:
190+
val_images = val_data["img"].to(device)
191+
roi_size = (48, 48)
192+
sw_batch_size = 3
193+
slice_inferer = SliceInferer(
194+
roi_size=roi_size,
195+
sw_batch_size=sw_batch_size,
196+
spatial_dim=2, # Spatial dim to slice along is defined here
197+
device=torch.device("cpu"),
198+
padding_mode="replicate",
199+
)
200+
val_output = slice_inferer(val_images, model)
201+
matshow3d(val_output[0] > 0.5)
202+
matshow3d(val_images[0])
203+
plt.show()
204+
205+
206+
if __name__ == "__main__":
207+
with tempfile.TemporaryDirectory() as tempdir:
208+
main(tempdir)

0 commit comments

Comments
 (0)