Commit b1310a5

Add TransCheX Tutorial
Signed-off-by: ahatamizadeh <ahatamizadeh@nvidia.com>
1 parent: c93248c · commit: b1310a5

File tree

1 file changed: +76 / -37 lines


multimodal/openi_multilabel_classification_transchex/Transchex_OpenI_multilabel_classification.ipynb

Lines changed: 76 additions & 37 deletions
@@ -69,9 +69,10 @@
 "outputs": [],
 "source": [
 "import os\n",
+"\n",
 "datadir = \"./monai_data\"\n",
 "if not os.path.exists(datadir):\n",
-"    os.makedirs(datadir)\n"
+"    os.makedirs(datadir)"
 ]
 },
 {
@@ -133,6 +134,7 @@
 "from monai.networks.nets import Transchex\n",
 "from monai.config import print_config\n",
 "from monai.utils import set_determinism\n",
+"\n",
 "torch.backends.cudnn.benchmark = True\n",
 "\n",
 "print_config()"
@@ -176,11 +178,13 @@
 "        self.img_name = self.data.id\n",
 "        self.targets = self.data.list\n",
 "\n",
-"        self.preprocess = transforms.Compose([\n",
-"            transforms.Resize(256),\n",
-"            transforms.ToTensor(),\n",
-"            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\n",
-"        ])\n",
+"        self.preprocess = transforms.Compose(\n",
+"            [\n",
+"                transforms.Resize(256),\n",
+"                transforms.ToTensor(),\n",
+"                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),\n",
+"            ]\n",
+"        )\n",
 "        self.parent_dir = parent_dir\n",
 "\n",
 "    def __len__(self):\n",
@@ -189,8 +193,8 @@
 "    def encode_features(self, sent, max_seq_length, tokenizer):\n",
 "        tokens = tokenizer.tokenize(sent.strip())\n",
 "        if len(tokens) > max_seq_length - 2:\n",
-"            tokens = tokens[:(max_seq_length - 2)]\n",
-"        tokens = ['[CLS]'] + tokens + ['[SEP]']\n",
+"            tokens = tokens[: (max_seq_length - 2)]\n",
+"        tokens = [\"[CLS]\"] + tokens + [\"[SEP]\"]\n",
 "        input_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
 "        segment_ids = [0] * len(input_ids)\n",
 "        while len(input_ids) < max_seq_length:\n",
@@ -201,22 +205,24 @@
 "        return input_ids, segment_ids\n",
 "\n",
 "    def __getitem__(self, index):\n",
-"        name = self.img_name[index].split('.')[0]\n",
+"        name = self.img_name[index].split(\".\")[0]\n",
 "        img_address = os.path.join(self.parent_dir, self.img_name[index])\n",
 "        image = Image.open(img_address)\n",
 "        images = self.preprocess(image)\n",
 "        report = str(self.report_summary[index])\n",
 "        report = \" \".join(report.split())\n",
-"        input_ids, segment_ids = self.encode_features(report, self.max_seq_length, self.tokenizer)\n",
+"        input_ids, segment_ids = self.encode_features(\n",
+"            report, self.max_seq_length, self.tokenizer\n",
+"        )\n",
 "        input_ids = torch.tensor(input_ids, dtype=torch.long)\n",
 "        segment_ids = torch.tensor(segment_ids, dtype=torch.long)\n",
 "        targets = torch.tensor(self.targets[index], dtype=torch.float)\n",
 "        return {\n",
-"            'ids': input_ids,\n",
-"            'segment_ids': segment_ids,\n",
-"            'name': name,\n",
-"            'targets': targets,\n",
-"            'images': images,\n",
+"            \"ids\": input_ids,\n",
+"            \"segment_ids\": segment_ids,\n",
+"            \"name\": name,\n",
+"            \"targets\": targets,\n",
+"            \"images\": images,\n",
 "        }"
 ]
 },
@@ -320,6 +326,7 @@
 "def save_ckp(state, checkpoint_dir):\n",
 "    torch.save(state, checkpoint_dir)\n",
 "\n",
+"\n",
 "def compute_AUCs(gt, pred, num_classes=14):\n",
 "    with torch.no_grad():\n",
 "        AUROCs = []\n",
@@ -329,19 +336,23 @@
 "            AUROCs.append(roc_auc_score(gt_np[:, i].tolist(), pred_np[:, i].tolist()))\n",
 "    return AUROCs\n",
 "\n",
+"\n",
 "def train(epoch):\n",
 "    model.train()\n",
 "    for i, data in enumerate(training_loader, 0):\n",
-"        input_ids = data['ids'].cuda()\n",
-"        segment_ids = data['segment_ids'].cuda()\n",
-"        img = data['images'].cuda()\n",
-"        targets = data['targets'].cuda()\n",
-"        logits_lang = model(input_ids=input_ids,vision_feats=img,token_type_ids=segment_ids)\n",
+"        input_ids = data[\"ids\"].cuda()\n",
+"        segment_ids = data[\"segment_ids\"].cuda()\n",
+"        img = data[\"images\"].cuda()\n",
+"        targets = data[\"targets\"].cuda()\n",
+"        logits_lang = model(\n",
+"            input_ids=input_ids, vision_feats=img, token_type_ids=segment_ids\n",
+"        )\n",
 "        loss = loss_bce(torch.sigmoid(logits_lang), targets)\n",
 "        optimizer.zero_grad()\n",
 "        loss.backward()\n",
 "        optimizer.step()\n",
-"        print(f'Epoch: {epoch}, Iteration: {i}, Loss_Tot: {loss}')\n",
+"        print(f\"Epoch: {epoch}, Iteration: {i}, Loss_Tot: {loss}\")\n",
+"\n",
 "\n",
 "def validation(testing_loader):\n",
 "    model.eval()\n",
@@ -350,11 +361,13 @@
 "    val_loss = []\n",
 "    with torch.no_grad():\n",
 "        for _, data in enumerate(testing_loader, 0):\n",
-"            input_ids = data['ids'].cuda()\n",
-"            segment_ids = data['segment_ids'].cuda()\n",
-"            img = data['images'].cuda()\n",
-"            targets = data['targets'].cuda()\n",
-"            logits_lang = model(input_ids=input_ids, vision_feats=img, token_type_ids=segment_ids)\n",
+"            input_ids = data[\"ids\"].cuda()\n",
+"            segment_ids = data[\"segment_ids\"].cuda()\n",
+"            img = data[\"images\"].cuda()\n",
+"            targets = data[\"targets\"].cuda()\n",
+"            logits_lang = model(\n",
+"                input_ids=input_ids, vision_feats=img, token_type_ids=segment_ids\n",
+"            )\n",
 "            prob = torch.sigmoid(logits_lang)\n",
 "            loss = loss_bce(prob, targets).item()\n",
 "            targets_in[_, :] = targets.detach().cpu().numpy()\n",
@@ -363,9 +376,14 @@
 "    auc = compute_AUCs(targets_in, preds_cls, 14)\n",
 "    mean_auc = np.mean(auc)\n",
 "    mean_loss = np.mean(val_loss)\n",
-"    print('Evaluation Statistics: Mean AUC : {}, Mean Loss : {}'.format(mean_auc, mean_loss))\n",
+"    print(\n",
+"        \"Evaluation Statistics: Mean AUC : {}, Mean Loss : {}\".format(\n",
+"            mean_auc, mean_loss\n",
+"        )\n",
+"    )\n",
 "    return mean_auc, mean_loss, auc\n",
 "\n",
+"\n",
 "auc_val_best = 0.0\n",
 "epoch_loss_values = []\n",
 "metric_values = []\n",
@@ -375,14 +393,24 @@
 "    epoch_loss_values.append(loss_val)\n",
 "    metric_values.append(auc_val)\n",
 "    if auc_val > auc_val_best:\n",
-"        checkpoint = {'epoch': epoch,\n",
-"                      'state_dict': model.state_dict(),\n",
-"                      'optimizer': optimizer.state_dict()}\n",
-"        save_ckp(checkpoint, logdir+'/transchex.pt')\n",
+"        checkpoint = {\n",
+"            \"epoch\": epoch,\n",
+"            \"state_dict\": model.state_dict(),\n",
+"            \"optimizer\": optimizer.state_dict(),\n",
+"        }\n",
+"        save_ckp(checkpoint, logdir + \"/transchex.pt\")\n",
 "        auc_val_best = auc_val\n",
-"        print('Model Was Saved ! Current Best Validation AUC: {} Current AUC: {}'.format(auc_val_best, auc_val))\n",
+"        print(\n",
+"            \"Model Was Saved ! Current Best Validation AUC: {} Current AUC: {}\".format(\n",
+"                auc_val_best, auc_val\n",
+"            )\n",
+"        )\n",
 "    else:\n",
-"        print('Model Was NOT Saved ! Current Best Validation AUC: {} Current AUC: {}'.format(auc_val_best, auc_val))\n",
+"        print(\n",
+"            \"Model Was NOT Saved ! Current Best Validation AUC: {} Current AUC: {}\".format(\n",
+"                auc_val_best, auc_val\n",
+"            )\n",
+"        )\n",
 "    scheduler.step()"
 ]
 },
@@ -400,9 +428,7 @@
 }
 ],
 "source": [
-"print(\n",
-"    f\"Training Finished ! Best Validation AUC: {auc_val_best:.4f} \"\n",
-")"
+"print(f\"Training Finished ! Best Validation AUC: {auc_val_best:.4f} \")"
 ]
 },
 {
@@ -503,7 +529,20 @@
 "\n",
 "print(\n",
 "    \"\\nMean test AUC for each class in 14 disease categories:\\n\\nAtelectasis: {}\\nCardiomegaly: {}\\nConsolidation: {}\\nEdema: {}\\nEnlarged-Cardiomediastinum: {}\\nFracture: {}\\nLung-Lesion: {}\\nLung-Opacity: {}\\nNo-Finding: {}\\nPleural-Effusion: {}\\nPleural_Other: {}\\nPneumonia: {}\\nPneumothorax: {}\\nSupport-Devices: {}\".format(\n",
-"        auc[0], auc[1], auc[2], auc[3], auc[4], auc[5], auc[6], auc[7], auc[8], auc[9], auc[10], auc[11], auc[12], auc[13]\n",
+"        auc[0],\n",
+"        auc[1],\n",
+"        auc[2],\n",
+"        auc[3],\n",
+"        auc[4],\n",
+"        auc[5],\n",
+"        auc[6],\n",
+"        auc[7],\n",
+"        auc[8],\n",
+"        auc[9],\n",
+"        auc[10],\n",
+"        auc[11],\n",
+"        auc[12],\n",
+"        auc[13],\n",
 "    )\n",
 ")"
 ]
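
Note (not part of this commit): the checkpoint written by save_ckp in the training loop above stores the epoch, model weights, and optimizer state under the keys "epoch", "state_dict", and "optimizer" at logdir + "/transchex.pt". A minimal sketch of restoring it for later evaluation follows; it assumes the model, optimizer, and logdir objects defined earlier in the notebook.

# Sketch only: reload the best checkpoint saved by save_ckp during training.
# Assumes `model`, `optimizer`, and `logdir` are the objects created earlier in the notebook.
import torch

checkpoint = torch.load(logdir + "/transchex.pt", map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])      # restore the Transchex weights
optimizer.load_state_dict(checkpoint["optimizer"])   # only needed if resuming training
print(f"Restored checkpoint from epoch {checkpoint['epoch']}")
model.eval()                                         # switch to eval mode before calling validation()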
