|
69 | 69 | "outputs": [],
|
70 | 70 | "source": [
|
71 | 71 | "import os\n",
|
| 72 | + "\n", |
72 | 73 | "datadir = \"./monai_data\"\n",
|
73 | 74 | "if not os.path.exists(datadir):\n",
|
74 |
| - " os.makedirs(datadir)\n" |
| 75 | + " os.makedirs(datadir)" |
75 | 76 | ]
|
76 | 77 | },
|
77 | 78 | {
|
|
133 | 134 | "from monai.networks.nets import Transchex\n",
|
134 | 135 | "from monai.config import print_config\n",
|
135 | 136 | "from monai.utils import set_determinism\n",
|
| 137 | + "\n", |
136 | 138 | "torch.backends.cudnn.benchmark = True\n",
|
137 | 139 | "\n",
|
138 | 140 | "print_config()"
|
|
176 | 178 | " self.img_name = self.data.id\n",
|
177 | 179 | " self.targets = self.data.list\n",
|
178 | 180 | "\n",
|
179 |
| - " self.preprocess = transforms.Compose([\n", |
180 |
| - " transforms.Resize(256),\n", |
181 |
| - " transforms.ToTensor(),\n", |
182 |
| - " transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\n", |
183 |
| - " ])\n", |
| 181 | + " self.preprocess = transforms.Compose(\n", |
| 182 | + " [\n", |
| 183 | + " transforms.Resize(256),\n", |
| 184 | + " transforms.ToTensor(),\n", |
| 185 | + " transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),\n", |
| 186 | + " ]\n", |
| 187 | + " )\n", |
184 | 188 | " self.parent_dir = parent_dir\n",
|
185 | 189 | "\n",
|
186 | 190 | " def __len__(self):\n",
|
|
189 | 193 | " def encode_features(self, sent, max_seq_length, tokenizer):\n",
|
190 | 194 | " tokens = tokenizer.tokenize(sent.strip())\n",
|
191 | 195 | " if len(tokens) > max_seq_length - 2:\n",
|
192 |
| - " tokens = tokens[:(max_seq_length - 2)]\n", |
193 |
| - " tokens = ['[CLS]'] + tokens + ['[SEP]']\n", |
| 196 | + " tokens = tokens[: (max_seq_length - 2)]\n", |
| 197 | + " tokens = [\"[CLS]\"] + tokens + [\"[SEP]\"]\n", |
194 | 198 | " input_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
|
195 | 199 | " segment_ids = [0] * len(input_ids)\n",
|
196 | 200 | " while len(input_ids) < max_seq_length:\n",
|
|
201 | 205 | " return input_ids, segment_ids\n",
|
202 | 206 | "\n",
|
203 | 207 | " def __getitem__(self, index):\n",
|
204 |
| - " name = self.img_name[index].split('.')[0]\n", |
| 208 | + " name = self.img_name[index].split(\".\")[0]\n", |
205 | 209 | " img_address = os.path.join(self.parent_dir, self.img_name[index])\n",
|
206 | 210 | " image = Image.open(img_address)\n",
|
207 | 211 | " images = self.preprocess(image)\n",
|
208 | 212 | " report = str(self.report_summary[index])\n",
|
209 | 213 | " report = \" \".join(report.split())\n",
|
210 |
| - " input_ids, segment_ids = self.encode_features(report, self.max_seq_length, self.tokenizer)\n", |
| 214 | + " input_ids, segment_ids = self.encode_features(\n", |
| 215 | + " report, self.max_seq_length, self.tokenizer\n", |
| 216 | + " )\n", |
211 | 217 | " input_ids = torch.tensor(input_ids, dtype=torch.long)\n",
|
212 | 218 | " segment_ids = torch.tensor(segment_ids, dtype=torch.long)\n",
|
213 | 219 | " targets = torch.tensor(self.targets[index], dtype=torch.float)\n",
|
214 | 220 | " return {\n",
|
215 |
| - " 'ids': input_ids,\n", |
216 |
| - " 'segment_ids': segment_ids,\n", |
217 |
| - " 'name': name,\n", |
218 |
| - " 'targets': targets,\n", |
219 |
| - " 'images': images,\n", |
| 221 | + " \"ids\": input_ids,\n", |
| 222 | + " \"segment_ids\": segment_ids,\n", |
| 223 | + " \"name\": name,\n", |
| 224 | + " \"targets\": targets,\n", |
| 225 | + " \"images\": images,\n", |
220 | 226 | " }"
|
221 | 227 | ]
|
222 | 228 | },
|
|
320 | 326 | "def save_ckp(state, checkpoint_dir):\n",
|
321 | 327 | " torch.save(state, checkpoint_dir)\n",
|
322 | 328 | "\n",
|
| 329 | + "\n", |
323 | 330 | "def compute_AUCs(gt, pred, num_classes=14):\n",
|
324 | 331 | " with torch.no_grad():\n",
|
325 | 332 | " AUROCs = []\n",
|
|
329 | 336 | " AUROCs.append(roc_auc_score(gt_np[:, i].tolist(), pred_np[:, i].tolist()))\n",
|
330 | 337 | " return AUROCs\n",
|
331 | 338 | "\n",
|
| 339 | + "\n", |
332 | 340 | "def train(epoch):\n",
|
333 | 341 | " model.train()\n",
|
334 | 342 | " for i, data in enumerate(training_loader, 0):\n",
|
335 |
| - " input_ids = data['ids'].cuda()\n", |
336 |
| - " segment_ids = data['segment_ids'].cuda()\n", |
337 |
| - " img = data['images'].cuda()\n", |
338 |
| - " targets = data['targets'].cuda()\n", |
339 |
| - " logits_lang = model(input_ids=input_ids,vision_feats=img,token_type_ids=segment_ids)\n", |
| 343 | + " input_ids = data[\"ids\"].cuda()\n", |
| 344 | + " segment_ids = data[\"segment_ids\"].cuda()\n", |
| 345 | + " img = data[\"images\"].cuda()\n", |
| 346 | + " targets = data[\"targets\"].cuda()\n", |
| 347 | + " logits_lang = model(\n", |
| 348 | + " input_ids=input_ids, vision_feats=img, token_type_ids=segment_ids\n", |
| 349 | + " )\n", |
340 | 350 | " loss = loss_bce(torch.sigmoid(logits_lang), targets)\n",
|
341 | 351 | " optimizer.zero_grad()\n",
|
342 | 352 | " loss.backward()\n",
|
343 | 353 | " optimizer.step()\n",
|
344 |
| - " print(f'Epoch: {epoch}, Iteration: {i}, Loss_Tot: {loss}')\n", |
| 354 | + " print(f\"Epoch: {epoch}, Iteration: {i}, Loss_Tot: {loss}\")\n", |
| 355 | + "\n", |
345 | 356 | "\n",
|
346 | 357 | "def validation(testing_loader):\n",
|
347 | 358 | " model.eval()\n",
|
|
350 | 361 | " val_loss = []\n",
|
351 | 362 | " with torch.no_grad():\n",
|
352 | 363 | " for _, data in enumerate(testing_loader, 0):\n",
|
353 |
| - " input_ids = data['ids'].cuda()\n", |
354 |
| - " segment_ids = data['segment_ids'].cuda()\n", |
355 |
| - " img = data['images'].cuda()\n", |
356 |
| - " targets = data['targets'].cuda()\n", |
357 |
| - " logits_lang = model(input_ids=input_ids, vision_feats=img, token_type_ids=segment_ids)\n", |
| 364 | + " input_ids = data[\"ids\"].cuda()\n", |
| 365 | + " segment_ids = data[\"segment_ids\"].cuda()\n", |
| 366 | + " img = data[\"images\"].cuda()\n", |
| 367 | + " targets = data[\"targets\"].cuda()\n", |
| 368 | + " logits_lang = model(\n", |
| 369 | + " input_ids=input_ids, vision_feats=img, token_type_ids=segment_ids\n", |
| 370 | + " )\n", |
358 | 371 | " prob = torch.sigmoid(logits_lang)\n",
|
359 | 372 | " loss = loss_bce(prob, targets).item()\n",
|
360 | 373 | " targets_in[_, :] = targets.detach().cpu().numpy()\n",
|
|
363 | 376 | " auc = compute_AUCs(targets_in, preds_cls, 14)\n",
|
364 | 377 | " mean_auc = np.mean(auc)\n",
|
365 | 378 | " mean_loss = np.mean(val_loss)\n",
|
366 |
| - " print('Evaluation Statistics: Mean AUC : {}, Mean Loss : {}'.format(mean_auc, mean_loss))\n", |
| 379 | + " print(\n", |
| 380 | + " \"Evaluation Statistics: Mean AUC : {}, Mean Loss : {}\".format(\n", |
| 381 | + " mean_auc, mean_loss\n", |
| 382 | + " )\n", |
| 383 | + " )\n", |
367 | 384 | " return mean_auc, mean_loss, auc\n",
|
368 | 385 | "\n",
|
| 386 | + "\n", |
369 | 387 | "auc_val_best = 0.0\n",
|
370 | 388 | "epoch_loss_values = []\n",
|
371 | 389 | "metric_values = []\n",
|
|
375 | 393 | " epoch_loss_values.append(loss_val)\n",
|
376 | 394 | " metric_values.append(auc_val)\n",
|
377 | 395 | " if auc_val > auc_val_best:\n",
|
378 |
| - " checkpoint = {'epoch': epoch,\n", |
379 |
| - " 'state_dict': model.state_dict(),\n", |
380 |
| - " 'optimizer': optimizer.state_dict()}\n", |
381 |
| - " save_ckp(checkpoint, logdir+'/transchex.pt')\n", |
| 396 | + " checkpoint = {\n", |
| 397 | + " \"epoch\": epoch,\n", |
| 398 | + " \"state_dict\": model.state_dict(),\n", |
| 399 | + " \"optimizer\": optimizer.state_dict(),\n", |
| 400 | + " }\n", |
| 401 | + " save_ckp(checkpoint, logdir + \"/transchex.pt\")\n", |
382 | 402 | " auc_val_best = auc_val\n",
|
383 |
| - " print('Model Was Saved ! Current Best Validation AUC: {} Current AUC: {}'.format(auc_val_best, auc_val))\n", |
| 403 | + " print(\n", |
| 404 | + " \"Model Was Saved ! Current Best Validation AUC: {} Current AUC: {}\".format(\n", |
| 405 | + " auc_val_best, auc_val\n", |
| 406 | + " )\n", |
| 407 | + " )\n", |
384 | 408 | " else:\n",
|
385 |
| - " print('Model Was NOT Saved ! Current Best Validation AUC: {} Current AUC: {}'.format(auc_val_best, auc_val))\n", |
| 409 | + " print(\n", |
| 410 | + " \"Model Was NOT Saved ! Current Best Validation AUC: {} Current AUC: {}\".format(\n", |
| 411 | + " auc_val_best, auc_val\n", |
| 412 | + " )\n", |
| 413 | + " )\n", |
386 | 414 | " scheduler.step()"
|
387 | 415 | ]
|
388 | 416 | },
|
|
400 | 428 | }
|
401 | 429 | ],
|
402 | 430 | "source": [
|
403 |
| - "print(\n", |
404 |
| - " f\"Training Finished ! Best Validation AUC: {auc_val_best:.4f} \"\n", |
405 |
| - ")" |
| 431 | + "print(f\"Training Finished ! Best Validation AUC: {auc_val_best:.4f} \")" |
406 | 432 | ]
|
407 | 433 | },
|
408 | 434 | {
|
|
503 | 529 | "\n",
|
504 | 530 | "print(\n",
|
505 | 531 | " \"\\nMean test AUC for each class in 14 disease categories:\\n\\nAtelectasis: {}\\nCardiomegaly: {}\\nConsolidation: {}\\nEdema: {}\\nEnlarged-Cardiomediastinum: {}\\nFracture: {}\\nLung-Lesion: {}\\nLung-Opacity: {}\\nNo-Finding: {}\\nPleural-Effusion: {}\\nPleural_Other: {}\\nPneumonia: {}\\nPneumothorax: {}\\nSupport-Devices: {}\".format(\n",
|
506 |
| - " auc[0], auc[1], auc[2], auc[3], auc[4], auc[5], auc[6], auc[7], auc[8], auc[9], auc[10], auc[11], auc[12], auc[13]\n", |
| 532 | + " auc[0],\n", |
| 533 | + " auc[1],\n", |
| 534 | + " auc[2],\n", |
| 535 | + " auc[3],\n", |
| 536 | + " auc[4],\n", |
| 537 | + " auc[5],\n", |
| 538 | + " auc[6],\n", |
| 539 | + " auc[7],\n", |
| 540 | + " auc[8],\n", |
| 541 | + " auc[9],\n", |
| 542 | + " auc[10],\n", |
| 543 | + " auc[11],\n", |
| 544 | + " auc[12],\n", |
| 545 | + " auc[13],\n", |
507 | 546 | " )\n",
|
508 | 547 | ")"
|
509 | 548 | ]
|
|
0 commit comments