Skip to content

Commit db8d2aa

Browse files
authored
fixes EnsureChannel dice_loss_metric_notes (#964)
Signed-off-by: Wenqi Li <wenqil@nvidia.com> Fixes #963 ### Description be specific with the channel dim ### Checks <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [ ] Notebook runs automatically `./runner [-p <regex_pattern>]` Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent 297d424 commit db8d2aa

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

modules/dice_loss_metric_notes.ipynb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
"import torch\n",
2121
"from monai.losses import DiceLoss\n",
2222
"from monai.metrics import DiceMetric\n",
23-
"from monai.transforms import EnsureChannelFirst, AsDiscrete, Compose\n",
23+
"from monai.transforms import AsDiscrete, Compose\n",
2424
"\n",
2525
"\n",
2626
"def print_tensor(name, t):\n",
@@ -108,7 +108,7 @@
108108
"outputs": [],
109109
"source": [
110110
"# make one hot and add batch dimension\n",
111-
"make_2_class = Compose([AsDiscrete(to_onehot=2), EnsureChannelFirst()])\n",
111+
"make_2_class = Compose([AsDiscrete(to_onehot=2), lambda x: x[None]])\n",
112112
"\n",
113113
"grnd2 = make_2_class(grnd)\n",
114114
"pred2 = make_2_class(pred)"
@@ -305,7 +305,7 @@
305305
"outputs": [],
306306
"source": [
307307
"# make one hot and add batch dimension\n",
308-
"make_3_class = Compose([AsDiscrete(to_onehot=3), EnsureChannelFirst()])\n",
308+
"make_3_class = Compose([AsDiscrete(to_onehot=3), lambda x: x[None]])\n",
309309
"\n",
310310
"mgrnd2 = make_3_class(mgrnd)\n",
311311
"mpred2 = make_3_class(mpred)"
@@ -412,7 +412,7 @@
412412
"name": "python",
413413
"nbconvert_exporter": "python",
414414
"pygments_lexer": "ipython3",
415-
"version": "3.8.13"
415+
"version": "3.8.12"
416416
}
417417
},
418418
"nbformat": 4,

0 commit comments

Comments
 (0)