diff --git a/modules/dice_loss_metric_notes.ipynb b/modules/dice_loss_metric_notes.ipynb index 3d626a6adc..c9c4b2e33e 100644 --- a/modules/dice_loss_metric_notes.ipynb +++ b/modules/dice_loss_metric_notes.ipynb @@ -20,7 +20,7 @@ "import torch\n", "from monai.losses import DiceLoss\n", "from monai.metrics import DiceMetric\n", - "from monai.transforms import EnsureChannelFirst, AsDiscrete, Compose\n", + "from monai.transforms import AsDiscrete, Compose\n", "\n", "\n", "def print_tensor(name, t):\n", @@ -108,7 +108,7 @@ "outputs": [], "source": [ "# make one hot and add batch dimension\n", - "make_2_class = Compose([AsDiscrete(to_onehot=2), EnsureChannelFirst()])\n", + "make_2_class = Compose([AsDiscrete(to_onehot=2), lambda x: x[None]])\n", "\n", "grnd2 = make_2_class(grnd)\n", "pred2 = make_2_class(pred)" @@ -305,7 +305,7 @@ "outputs": [], "source": [ "# make one hot and add batch dimension\n", - "make_3_class = Compose([AsDiscrete(to_onehot=3), EnsureChannelFirst()])\n", + "make_3_class = Compose([AsDiscrete(to_onehot=3), lambda x: x[None]])\n", "\n", "mgrnd2 = make_3_class(mgrnd)\n", "mpred2 = make_3_class(mpred)" @@ -412,7 +412,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.8.12" } }, "nbformat": 4,