|
20 | 20 | "import torch\n",
|
21 | 21 | "from monai.losses import DiceLoss\n",
|
22 | 22 | "from monai.metrics import DiceMetric\n",
|
23 |
| - "from monai.transforms import EnsureChannelFirst, AsDiscrete, Compose\n", |
| 23 | + "from monai.transforms import AsDiscrete, Compose\n", |
24 | 24 | "\n",
|
25 | 25 | "\n",
|
26 | 26 | "def print_tensor(name, t):\n",
|
|
108 | 108 | "outputs": [],
|
109 | 109 | "source": [
|
110 | 110 | "# make one hot and add batch dimension\n",
|
111 |
| - "make_2_class = Compose([AsDiscrete(to_onehot=2), EnsureChannelFirst()])\n", |
| 111 | + "make_2_class = Compose([AsDiscrete(to_onehot=2), lambda x: x[None]])\n", |
112 | 112 | "\n",
|
113 | 113 | "grnd2 = make_2_class(grnd)\n",
|
114 | 114 | "pred2 = make_2_class(pred)"
|
|
305 | 305 | "outputs": [],
|
306 | 306 | "source": [
|
307 | 307 | "# make one hot and add batch dimension\n",
|
308 |
| - "make_3_class = Compose([AsDiscrete(to_onehot=3), EnsureChannelFirst()])\n", |
| 308 | + "make_3_class = Compose([AsDiscrete(to_onehot=3), lambda x: x[None]])\n", |
309 | 309 | "\n",
|
310 | 310 | "mgrnd2 = make_3_class(mgrnd)\n",
|
311 | 311 | "mpred2 = make_3_class(mpred)"
|
|
412 | 412 | "name": "python",
|
413 | 413 | "nbconvert_exporter": "python",
|
414 | 414 | "pygments_lexer": "ipython3",
|
415 |
| - "version": "3.8.13" |
| 415 | + "version": "3.8.12" |
416 | 416 | }
|
417 | 417 | },
|
418 | 418 | "nbformat": 4,
|
|
0 commit comments