Skip to content

Commit 2f83bc1

Browse files
authored
Merge pull request #167 from Project-MONAI/force_cpu_torchscript
Force to remap to CPU for TorchScript model if GPUdoesn't exist
2 parents 0456678 + 14c1aae commit 2f83bc1

File tree

5 files changed

+15
-23
lines changed

5 files changed

+15
-23
lines changed

examples/apps/mednist_classifier_monaideploy/mednist_classifier_monaideploy.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,10 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe
6969
image_tensor = self.transform(img) # (1, 64, 64), torch.float64
7070
image_tensor = image_tensor[None].float() # (1, 1, 64, 64), torch.float32
7171

72-
# Comment below line if you want to do CPU inference
73-
image_tensor = image_tensor.cuda()
72+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
73+
image_tensor = image_tensor.to(device)
7474

7575
model = context.models.get() # get a TorchScriptModel object
76-
# Uncomment the following line if you want to do CPU inference
77-
# model.predictor = torch.jit.load(model.path, map_location="cpu").eval()
7876

7977
with torch.no_grad():
8078
outputs = model(image_tensor)

monai/deploy/core/models/torch_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ def predictor(self) -> "torch.nn.Module": # type: ignore
4747
torch.nn.Module: the model's predictor
4848
"""
4949
if self._predictor is None:
50-
self._predictor = torch.jit.load(self.path).eval()
50+
# Use a device to dynamically remap, depending on the GPU availability.
51+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52+
self._predictor = torch.jit.load(self.path, map_location=device).eval()
5153
return self._predictor
5254

5355
@predictor.setter

monai/deploy/operators/monai_seg_inference_operator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,18 +195,18 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe
195195
pre_transforms = self._pre_transform if self._pre_transform else self.pre_process(self._reader)
196196
post_transforms = self._post_transforms if self._post_transforms else self.post_process(pre_transforms)
197197

198+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
198199
model = None
199200
if context.models:
200201
# `context.models.get(model_name)` returns a model instance if exists.
201202
# If model_name is not specified and only one model exists, it returns that model.
202203
model = context.models.get()
203204
else:
204205
print(f"Loading TorchScript model from: {MonaiSegInferenceOperator.MODEL_LOCAL_PATH}")
205-
model = torch.jit.load(MonaiSegInferenceOperator.MODEL_LOCAL_PATH)
206+
model = torch.jit.load(MonaiSegInferenceOperator.MODEL_LOCAL_PATH, map_location=device)
206207

207208
dataset = Dataset(data=[{self._input_dataset_key: img_name}], transform=pre_transforms)
208209
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
209-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
210210

211211
with torch.no_grad():
212212
for d in dataloader:

notebooks/tutorials/02_mednist_app-prebuilt.ipynb

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -432,12 +432,10 @@
432432
" image_tensor = self.transform(img) # (1, 64, 64), torch.float64\n",
433433
" image_tensor = image_tensor[None].float() # (1, 1, 64, 64), torch.float32\n",
434434
"\n",
435-
" # Comment below line if you want to do CPU inference\n",
436-
" image_tensor = image_tensor.cuda()\n",
435+
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
436+
" image_tensor = image_tensor.to(device)\n",
437437
"\n",
438438
" model = context.models.get() # get a TorchScriptModel object\n",
439-
" # Uncomment the following line if you want to do CPU inference\n",
440-
" # model.predictor = torch.jit.load(model.path, map_location=\"cpu\").eval()\n",
441439
"\n",
442440
" with torch.no_grad():\n",
443441
" outputs = model(image_tensor)\n",
@@ -652,12 +650,10 @@
652650
" image_tensor = self.transform(img) # (1, 64, 64), torch.float64\n",
653651
" image_tensor = image_tensor[None].float() # (1, 1, 64, 64), torch.float32\n",
654652
"\n",
655-
" # Comment below line if you want to do CPU inference\n",
656-
" image_tensor = image_tensor.cuda()\n",
653+
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
654+
" image_tensor = image_tensor.to(device)\n",
657655
"\n",
658656
" model = context.models.get() # get a TorchScriptModel object\n",
659-
" # Uncomment the following line if you want to do CPU inference\n",
660-
" # model.predictor = torch.jit.load(model.path, map_location=\"cpu\").eval()\n",
661657
"\n",
662658
" with torch.no_grad():\n",
663659
" outputs = model(image_tensor)\n",

notebooks/tutorials/02_mednist_app.ipynb

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -510,12 +510,10 @@
510510
" image_tensor = self.transform(img) # (1, 64, 64), torch.float64\n",
511511
" image_tensor = image_tensor[None].float() # (1, 1, 64, 64), torch.float32\n",
512512
"\n",
513-
" # Comment below line if you want to do CPU inference\n",
514-
" image_tensor = image_tensor.cuda()\n",
513+
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
514+
" image_tensor = image_tensor.to(device)\n",
515515
"\n",
516516
" model = context.models.get() # get a TorchScriptModel object\n",
517-
" # Uncomment the following line if you want to do CPU inference\n",
518-
" # model.predictor = torch.jit.load(model.path, map_location=\"cpu\").eval()\n",
519517
"\n",
520518
" with torch.no_grad():\n",
521519
" outputs = model(image_tensor)\n",
@@ -748,12 +746,10 @@
748746
" image_tensor = self.transform(img) # (1, 64, 64), torch.float64\n",
749747
" image_tensor = image_tensor[None].float() # (1, 1, 64, 64), torch.float32\n",
750748
"\n",
751-
" # Comment below line if you want to do CPU inference\n",
752-
" image_tensor = image_tensor.cuda()\n",
749+
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
750+
" image_tensor = image_tensor.to(device)\n",
753751
"\n",
754752
" model = context.models.get() # get a TorchScriptModel object\n",
755-
" # Uncomment the following line if you want to do CPU inference\n",
756-
" # model.predictor = torch.jit.load(model.path, map_location=\"cpu\").eval()\n",
757753
"\n",
758754
" with torch.no_grad():\n",
759755
" outputs = model(image_tensor)\n",

0 commit comments

Comments
 (0)