Skip to content

Commit 1b01dad

Browse files
committed
[DLMED] add more features
Signed-off-by: Nic Ma <nma@nvidia.com>
1 parent 88b13ec commit 1b01dad

File tree

9 files changed

+203
-19
lines changed

9 files changed

+203
-19
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
python ../scripts/inference.py
2+
--base_config ../configs/inference.json
3+
--config ../configs/inference_v2.json
4+
--meta ../configs/metadata.json
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
python ../scripts/inference.py
2+
--config ../configs/trtinfer.json
3+
--meta ../configs/metadata.json

modules/model_package/spleen_segmentation/configs/inference.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"multi_gpu": false,
33
"amp": true,
4-
"model": "monai.data.load_net_with_metadata('../models/model.ts')[0]",
4+
"model": "#monai.data.load_net_with_metadata('../models/model.ts')[0]",
55
"network": {
66
"name": "UNet",
77
"args": {
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
{
2+
"amp": false,
3+
"network": {
4+
"name": "UNet",
5+
"args": {
6+
"spatial_dims": 3,
7+
"in_channels": 1,
8+
"out_channels": 2,
9+
"channels": [32, 64, 128, 256, 512],
10+
"strides": [2, 2, 2, 2],
11+
"num_res_units": 2,
12+
"norm": "group"
13+
}
14+
},
15+
"inferer": {
16+
"name": "SlidingWindowInferer",
17+
"args": {
18+
"roi_size": [96, 96, 96],
19+
"sw_batch_size": 4,
20+
"overlap": 0.6
21+
}
22+
}
23+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
{
2+
"preprocessing": {
3+
"ref": {
4+
"path": "../inference.json/preprocessing"
5+
}
6+
},
7+
"dataset": {
8+
"ref": {
9+
"path": "../inference.json/dataset"
10+
}
11+
},
12+
"model": "#load_trt_model(...)",
13+
"dataloader": {
14+
"name": "DALIpipeline"
15+
},
16+
"inferer": {
17+
"name": "TensorRTInferer"
18+
}
19+
}

modules/model_package/spleen_segmentation/scripts/export.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import json
1515

1616
import torch
17+
from monai.apps import ConfigParser
1718
from ignite.handlers import Checkpoint
1819
from monai.data import save_net_with_metadata
1920
from monai.networks import convert_to_torchscript
@@ -28,15 +29,15 @@ def main():
2829

2930
# load config file
3031
with open(args.config, "r") as f:
31-
cofnig_dict = json.load(f)
32+
config_dict = json.load(f)
3233
# load meta data
3334
with open(args.meta, "r") as f:
3435
meta_dict = json.load(f)
3536

3637
net: torch.nn.Module = None
3738
# TODO: parse network definition from config file and construct network instance
38-
# config_parser = ConfigParser(config_dict, meta_dict)
39-
# net = config_parser.get_component("network")
39+
config_parser = ConfigParser(config_dict)
40+
net = config_parser.get_instance("network")
4041

4142
checkpoint = torch.load(args.weights)
4243
# here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver
@@ -51,7 +52,7 @@ def main():
5152
include_config_vals=False,
5253
append_timestamp=False,
5354
meta_values=meta_dict,
54-
more_extra_files={args.config: json.dumps(cofnig_dict).encode()},
55+
more_extra_files={args.config: json.dumps(config_dict).encode()},
5556
)
5657

5758

modules/model_package/spleen_segmentation/scripts/inference.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import json
1515

1616
import torch
17+
from monai.apps import ConfigParser
1718
from monai.data import decollate_batch
1819
from monai.inferers import Inferer
1920
from monai.transforms import Transform
@@ -22,38 +23,45 @@
2223

2324
def main():
2425
parser = argparse.ArgumentParser()
25-
parser.add_argument('--config', '-c', type=str, help='file path of config file that defines network', required=True)
26+
parser.add_argument('--config', '-c', type=str, help='config file that defines components', required=True)
2627
parser.add_argument('--meta', '-e', type=str, help='file path of the meta data')
2728
args = parser.parse_args()
2829

2930
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31+
configs = {}
3032

31-
# load config file
32-
with open(args.config, "r") as f:
33-
cofnig_dict = json.load(f)
3433
# load meta data
3534
with open(args.meta, "r") as f:
36-
meta_dict = json.load(f)
35+
configs.update(json.load(f))
36+
# load config file, can override meta data in config
37+
with open(args.config, "r") as f:
38+
configs.update(json.load(f))
3739

38-
net: torch.nn.Module = None
40+
model: torch.nn.Module = None
3941
dataloader: torch.utils.data.DataLoader = None
4042
inferer: Inferer = None
4143
postprocessing: Transform = None
4244
# TODO: parse inference config file and construct instances
43-
# config_parser = ConfigParser(config_dict, meta_dict)
44-
# net = config_parser.get_component("model").to(device)
45-
# dataloader = config_parser.get_component("dataloader")
46-
# inferer = config_parser.get_component("inferer")
47-
# postprocessing = config_parser.get_component("postprocessing")
45+
config_parser = ConfigParser(configs)
46+
47+
# change JSON config content in python code, lazy instantiation
48+
model_conf = config_parser.get_config("model")
49+
model_conf["disabled"] = False
50+
model = config_parser.build(model_conf).to(device)
51+
52+
# instantiate the components immediately
53+
dataloader = config_parser.get_instance("dataloader")
54+
inferer = config_parser.get_instance("inferer")
55+
postprocessing = config_parser.get_instance("postprocessing")
4856

49-
net.eval()
57+
model.eval()
5058
with torch.no_grad():
5159
for d in dataloader:
5260
images = d[CommonKeys.IMAGE].to(device)
5361
# define sliding window size and batch size for windows inference
54-
d[CommonKeys.PRED] = inferer(inputs=images, predictor=net)
62+
d[CommonKeys.PRED] = inferer(inputs=images, predictor=model)
5563
# decollate the batch data into a list of dictionaries, then execute postprocessing transforms
56-
d = [postprocessing(i) for i in decollate_batch(d)]
64+
[postprocessing(i) for i in decollate_batch(d)]
5765

5866

5967
if __name__ == '__main__':
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
2+
# Copyright 2020 - 2021 MONAI Consortium
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
13+
import argparse
14+
import json
15+
16+
import torch
17+
from monai.apps import ConfigParser
18+
from monai.data import decollate_batch
19+
from monai.inferers import Inferer
20+
from monai.transforms import Transform
21+
from monai.utils.enums import CommonKeys
22+
23+
def main():
    """Run model inference driven by layered JSON configs.

    Config precedence (later files override earlier keys):
    meta data < base config < config.

    Command-line arguments:
        --base_config / -b: optional base config file path.
        --config / -c: config file overriding the base config (required).
        --meta / -e: optional meta data file path.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): the original registered '-c' for BOTH --base_config and
    # --config, which makes argparse raise a conflict before parsing; give
    # --base_config its own short flag so the script can start at all.
    parser.add_argument('--base_config', '-b', type=str, help='file path of base config', required=False)
    parser.add_argument('--config', '-c', type=str, help='config file to override base config', required=True)
    parser.add_argument('--meta', '-e', type=str, help='file path of the meta data')
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    configs = {}

    # load meta data first so later files can override it; skip when omitted
    # (the argument is optional, so open(None) would otherwise raise)
    if args.meta:
        with open(args.meta, "r") as f:
            configs.update(json.load(f))
    # load base config file, can override meta data; optional as well
    if args.base_config:
        with open(args.base_config, "r") as f:
            configs.update(json.load(f))
    # load config file, add or override the content of the base config
    with open(args.config, "r") as f:
        configs.update(json.load(f))

    model: torch.nn.Module = None
    dataloader: torch.utils.data.DataLoader = None
    inferer: Inferer = None
    postprocessing: Transform = None
    config_parser = ConfigParser(configs)
    # instantiate the components immediately from the merged config
    model = config_parser.get_instance("model").to(device)
    dataloader = config_parser.get_instance("dataloader")
    inferer = config_parser.get_instance("inferer")
    postprocessing = config_parser.get_instance("postprocessing")

    model.eval()
    with torch.no_grad():
        for d in dataloader:
            images = d[CommonKeys.IMAGE].to(device)
            # sliding window size and batch size come from the inferer config
            d[CommonKeys.PRED] = inferer(inputs=images, predictor=model)
            # decollate the batch into per-item dicts, then run postprocessing
            [postprocessing(i) for i in decollate_batch(d)]


if __name__ == '__main__':
    main()
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
2+
# Copyright 2020 - 2021 MONAI Consortium
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
13+
import argparse
14+
import json
15+
16+
import torch
17+
from monai.apps import ConfigParser
18+
from monai.data import decollate_batch
19+
from monai.transforms import Transform
20+
21+
def main():
    """Run prototype TensorRT/DALI-style inference driven by JSON configs.

    Config precedence (later files override earlier keys): meta data < config.

    Command-line arguments:
        --config / -c: config file that defines the components (required).
        --meta / -e: optional meta data file path.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str, help='config file that defines components', required=True)
    parser.add_argument('--meta', '-e', type=str, help='file path of the meta data')
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    configs = {}

    # load meta data first so the config file can override it; skip when
    # omitted (the argument is optional, so open(None) would otherwise raise)
    if args.meta:
        with open(args.meta, "r") as f:
            configs.update(json.load(f))
    # load config file, can override meta data
    with open(args.config, "r") as f:
        configs.update(json.load(f))

    # fake code to simulate TensorRT and DALI logic; the annotated types
    # (TRTModel, DALIpipeline, TRTInfer) are placeholders that do not exist
    # yet, so they are written as forward-reference strings
    model: "TRTModel" = None
    dataloader: "DALIpipeline" = None
    inferer: "TRTInfer" = None
    postprocessing: Transform = None
    config_parser = ConfigParser(configs)

    # instantiate the components immediately from the merged config
    model = config_parser.get_instance("model").to(device)
    dataloader = config_parser.get_instance("dataloader")
    inferer = config_parser.get_instance("inferer")
    postprocessing = config_parser.get_instance("postprocessing")

    # simulate TensorRT and DALI logic
    for d in dataloader:
        r = inferer(inputs=d, predictor=model)
        [postprocessing(i) for i in decollate_batch(r)]


if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)