diff --git a/multimodal/openi_multilabel_classification_transchex/README.md b/multimodal/openi_multilabel_classification_transchex/README.md new file mode 100644 index 0000000000..970e3a5969 --- /dev/null +++ b/multimodal/openi_multilabel_classification_transchex/README.md @@ -0,0 +1,23 @@ +# Preprocessing Open-I Dataset + +The Open-I dataset provides a collection of 3,996 radiology reports +with 8,121 associated images in PA, AP and lateral views. In this tutorial, we utilize the images from the frontal view with their corresponding reports for training and +evaluation of the TransCheX model. The chest x-ray images and reports are originally from the Indiana University hospital (see the licensing information below). +The 14 finding categories in this work include Atelectasis, Cardiomegaly, Consolidation, Edema, Enlarged-Cardiomediastinum, Fracture, Lung-Lesion, Lung-Opacity, No-Finding, Pleural-Effusion, Pleural-Other, Pneumonia, Pneumothorax and Support-Devices. More information can be found in the following link: +https://openi.nlm.nih.gov/faq + +License: Attribution-NonCommercial-NoDerivatives 4.0 International (CC BY-NC-ND 4.0) + +In this section, we provide the steps that are needed for preprocessing the Open-I dataset for +the multi-label disease classification tutorial using the TransCheX model. As a result, once the following steps are +completed, the dataset can be readily used for the tutorial. + +### Preprocessing Steps +1) Create a new folder named 'monai_data' for downloading the raw data and preprocessing. +2) Download the chest X-ray images in PNG format from this [link](https://openi.nlm.nih.gov/imgs/collections/NLMCXR_png.tgz). Copy the downloaded file (NLMCXR_png.tgz) +to 'monai_data' directory and extract it. +3) Download the reports in XML format from this [link](https://openi.nlm.nih.gov/imgs/collections/NLMCXR_reports.tgz). Copy the downloaded file (NLMCXR_reports.tgz) +to 'monai_data' directory and extract it. 
+4) Download the splits of train, validation and test datasets from this [link](https://drive.google.com/u/1/uc?id=1jvT0jVl9mgtWy4cS7LYbF43bQE4mrXAY&export=download). Copy the downloaded file (TransChex_openi.zip) +to 'monai_data' directory and extract it. +5) Run 'preprocess_openi.py' to process the images and reports. diff --git a/multimodal/openi_multilabel_classification_transchex/preprocess_openi.py b/multimodal/openi_multilabel_classification_transchex/preprocess_openi.py new file mode 100644 index 0000000000..dcba4b0d05 --- /dev/null +++ b/multimodal/openi_multilabel_classification_transchex/preprocess_openi.py @@ -0,0 +1,119 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from os import listdir +from os.path import isfile, join +import numpy as np +from xml.dom import minidom +from PIL import Image +import pandas as pd +import xml.etree.ElementTree as ET + +def create_report(img_names_list_, report_list_, gt_list_, save_add): + pd.DataFrame({'id': img_names_list_, 'report': report_list_, 'Atelectasis': gt_list_[:, 0], + 'Cardiomegaly': gt_list_[:, 1], 'Consolidation': gt_list_[:, 2],'Edema': gt_list_[:, 3], + 'Enlarged-Cardiomediastinum': gt_list_[:, 4], 'Fracture': gt_list_[:, 5], 'Lung-Lesion': gt_list_[:, 6], + 'Lung-Opacity': gt_list_[:, 7], 'No-Finding': gt_list_[:, 8], 'Pleural-Effusion': gt_list_[:, 9], + 'Pleural_Other': gt_list_[:, 10], 'Pneumonia': gt_list_[:, 11], 'Pneumothorax': gt_list_[:, 12], + 'Support-Devices': gt_list_[:, 13]}).to_csv(save_add, index=False) + +report_file_add= './monai_data/dataset_orig/NLMCXR_reports/ecgen-radiology' +img_file_add= './monai_data/dataset_orig/NLMCXR_png' +img_save_add = './monai_data/dataset_proc/images' +report_train_save_add = './monai_data/dataset_proc/train.csv' +report_val_save_add = './monai_data/dataset_proc/validation.csv' +report_test_save_add = './monai_data/dataset_proc/test.csv' + +if not os.path.isdir(img_save_add): + os.makedirs(img_save_add) +report_files = [f for f in listdir(report_file_add) if isfile(join(report_file_add, f))] + +train_data = np.load('./train.npy', allow_pickle=True).item() +train_data_id = train_data['id_GT'] +train_data_gt = train_data['label_GT'] + +val_data = np.load('./validation.npy', allow_pickle=True).item() +val_data_id = val_data['id_GT'] +val_data_gt = val_data['label_GT'] + +test_data = np.load('./test.npy', allow_pickle=True).item() +test_data_id = test_data['id_GT'] +test_data_gt = test_data['label_GT'] + +all_cases = np.union1d(np.union1d(train_data_id, val_data_id), test_data_id) + +img_names_list_train = [] +img_names_list_val = [] +img_names_list_test = [] + +report_list_train = [] +report_list_val = [] 
+report_list_test = [] + +gt_list_train = [] +gt_list_val = [] +gt_list_test = [] + +for file in report_files: + print('Processing {}'.format(file)) + add_xml = os.path.join(report_file_add, file) + docs = minidom.parse(add_xml) + tree = ET.parse(add_xml) + report = "" + for node in tree.iter('AbstractText'): + for elem in node.iter(): + if elem.attrib['Label'] == "FINDINGS": + if elem.text is None: + report = "FINDINGS : " + else: + report = "FINDINGS : " + elem.text + elif elem.attrib['Label'] == "IMPRESSION": + if elem.text is None: + report = report + " IMPRESSION : " + else: + report = report + " IMPRESSION : " + elem.text + images = docs.getElementsByTagName("parentImage") + for i in images: + img_name = i.getAttribute("id") + '.png' + if img_name in all_cases: + Image.open(os.path.join(img_file_add, img_name)).resize((512, 512)).save( + os.path.join(img_save_add, img_name)) + if img_name in train_data_id: + img_names_list_train.append(img_name) + report_list_train.append(report) + gt_list_train.append(train_data_gt[np.where(train_data_id==img_name)[0][0]]) + elif img_name in val_data_id: + img_names_list_val.append(img_name) + report_list_val.append(report) + gt_list_val.append(val_data_gt[np.where(val_data_id == img_name)[0][0]]) + elif img_name in test_data_id: + img_names_list_test.append(img_name) + report_list_test.append(report) + gt_list_test.append(test_data_gt[np.where(test_data_id == img_name)[0][0]]) + +datasets = [{"save_add": report_train_save_add, + "img_name": np.array(img_names_list_train), + "report": np.array(report_list_train), + "gt": np.array(gt_list_train)}, + {"save_add": report_val_save_add, + "img_name": np.array(img_names_list_val), + "report": np.array(report_list_val), + "gt": np.array(gt_list_val)}, + {"save_add": report_test_save_add, + "img_name": np.array(img_names_list_test), + "report": np.array(report_list_test), + "gt": np.array(gt_list_test)} + ] +for dataset in datasets: + create_report(dataset["img_name"], dataset["report"], 
dataset["gt"], dataset["save_add"]) + +print('Processed Dataset Files Are Saved !') diff --git a/multimodal/openi_multilabel_classification_transchex/transchex_openi_multilabel_classification.ipynb b/multimodal/openi_multilabel_classification_transchex/transchex_openi_multilabel_classification.ipynb new file mode 100644 index 0000000000..3a68226137 --- /dev/null +++ b/multimodal/openi_multilabel_classification_transchex/transchex_openi_multilabel_classification.ipynb @@ -0,0 +1,593 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Chest X-ray Multi-Label Disease Classification With TransCheX \n", + "\n", + "This tutorial demonstrates how to construct a training workflow of the TransCheX model [1] for chest X-ray multi-label disease classification using the Open-I dataset. TransCheX is a multi-modal transformer-based model consisting of vision, language and mixed modality encoders that is designed for chest X-ray image classification.\n", + "\n", + "The Open-I dataset provides a collection of 3,996 radiology reports with 8,121 associated images in PA, AP and lateral views. In this tutorial, we utilize the images from the frontal view with their corresponding reports for training and evaluation of the TransCheX model. The 14 finding categories in this work include Atelectasis, Cardiomegaly, Consolidation, Edema, Enlarged-Cardiomediastinum, Fracture, Lung-Lesion, Lung-Opacity, No-Finding, Pleural-Effusion, Pleural-Other, Pneumonia, Pneumothorax and Support-Devices. 
More information can be found in the following link: \n", + "https://openi.nlm.nih.gov/faq\n", + "\n", + "License: Attribution-NonCommercial-NoDerivatives 4.0 International (CC BY-NC-ND 4.0)\n", + "\n", + "An example of images and corresponding reports in the Open-I dataset is presented as follows [2]:\n", + "![image](https://lh3.googleusercontent.com/bmqrTg0oKbfuwced1SiNdZruqbex_srBbJ9p4nmceAfZaf0FFkl9pzc9UUFP3-6AxxxWDaWbLXfmev5E6_0RmEzd0rLQ1NciF7PTzOkbUcRTJIUKgcpxKZsYnw3L17ATvIFBD47xSIWWiCD28vWBVN1k72P2UPorK1GQJUFEbmDAfGn0XRM2rzwB29SXB2hEtQmbWbe4u4msvcX4spx2rEH-6Qrd-iQRMyDAhq0lstRYBvxtu7ZLRrwtj_P5FQRKeW0hEFqTCQZvKmC75FKoUiltHDfsAl2mig2nsUH0KDBc3atPn9lSBGBFOXsHZdsqw4Q86sXz0roz1vKQWJWcSG7l5YqmPoz5KGrspIs5OJ7QxVvVSmmbe8ctk-T7eBoz3juZ3ux5QhYT2C1BYxGVutLh017FAskyZ1on4BkDTlkLrKSUpbU5la9IrugKM_lAso_cM2ALWb07n-yjsYUJL55oyJBMLCRXyIIutrQSGJW0RwM5LBIgwyklV9P_bRF3_w36hoqtHFNbzN5zrW-RAeJS2nCTYOElmRhzbdl4CwbgVUuStEm66vfUhwtWBMgybyQKb3WVTx69FcgnNC7tuDiPHpU3UuDlNXjKkuh35kxNcbJGYh8ZTY3jmoiVd_nrN9Yh5scCaxxdMtNRgxMWaGFoj7Dl3enBM2wR2FNotZ10smre6F7acOfKSYceAvQXWCzSnZ_C5PJ1szrEFa6v3wn4=w805-h556-no?authuser=0)\n", + "\n", + "In this tutorial, we use the TransCheX model with 2 layers for each of the vision, language and mixed modality encoders. As an input to the TransCheX, we use the patient **report** and corresponding **chest X-ray image**. The image itself will be divided into non-overlapping patches with a specified patch resolution and projected into an embedding space. Similarly, the reports are tokenized and projected into their respective embedding space. The language and vision encoders separately encode their respective features from the projected embeddings in each modality. Furthermore, the outputs of the vision and language encoders are fed into a mixed modality encoder which extracts mutual information. The output of the mixed modality encoder is then utilized for the classification application. 
\n", + "\n", + "[1] : \"Hatamizadeh et al.,TransCheX: Self-Supervised Pretraining of Vision-Language Transformers for Chest X-ray Analysis\"\n", + "\n", + "[2] : \"Shin et al.,Learning to Read Chest X-Rays: Recurrent Neural Cascade Model for Automated Image Annotation\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup environment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -q \"monai[transformers, pandas]\"\n", + "!pip install -q scikit-learn==0.20.3\n", + "!python -c \"import matplotlib\" || pip install -q matplotlib\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Import dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import torch\n", + "import numpy as np\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "from PIL import Image\n", + "from torchvision import transforms\n", + "from sklearn.metrics import roc_auc_score\n", + "from monai.optimizers.lr_scheduler import WarmupCosineSchedule\n", + "from monai.networks.nets import Transchex\n", + "from monai.config import print_config\n", + "from monai.utils import set_determinism\n", + "from torch.utils.data import Dataset, DataLoader\n", + "from transformers import BertTokenizer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Download and pre-process the dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Download the Open-I dataset from the following link, for both the chest X-ray images and corresponding reports, and pre-process the dataset using the provided script:\n", + "\n", + "https://openi.nlm.nih.gov/faq\n", + "\n", + "Please refer to the pre-processing guide for more details. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "datadir = \"./monai_data\"\n", + "if not os.path.exists(datadir):\n", + " os.makedirs(datadir)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Print Configurations " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MONAI version: 0.8.0\n", + "Numpy version: 1.21.0\n", + "Pytorch version: 1.6.0\n", + "MONAI flags: HAS_EXT = False, USE_COMPILED = False\n", + "MONAI rev id: 714d00dffe6653e21260160666c4c201ab66511b\n", + "\n", + "Optional dependencies:\n", + "Pytorch Ignite version: 0.4.4\n", + "Nibabel version: 3.1.1\n", + "scikit-image version: 0.14.2\n", + "Pillow version: 8.3.1\n", + "Tensorboard version: 2.2.0\n", + "gdown version: 3.13.0\n", + "TorchVision version: 0.7.0\n", + "tqdm version: 4.59.0\n", + "lmdb version: 1.2.1\n", + "psutil version: 5.6.1\n", + "pandas version: 0.24.2\n", + "einops version: 0.3.0\n", + "transformers version: 4.10.2\n", + "mlflow version: NOT INSTALLED or UNKNOWN VERSION.\n", + "\n", + "For details about installing the optional dependencies, please visit:\n", + " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n", + "\n" + ] + } + ], + "source": [ + "torch.backends.cudnn.benchmark = True\n", + "\n", + "print_config()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set deterministic training for reproducibility\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "set_determinism(seed=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup dataloaders and transforms for training/validation/testomg" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "class 
MultiModalDataset(Dataset):\n", + " def __init__(self, dataframe, tokenizer, parent_dir, max_seq_length=512):\n", + " self.max_seq_length = max_seq_length\n", + " self.tokenizer = tokenizer\n", + " self.data = dataframe\n", + " self.report_summary = self.data.report\n", + " self.img_name = self.data.id\n", + " self.targets = self.data.list\n", + "\n", + " self.preprocess = transforms.Compose(\n", + " [\n", + " transforms.Resize(256),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),\n", + " ]\n", + " )\n", + " self.parent_dir = parent_dir\n", + "\n", + " def __len__(self):\n", + " return len(self.report_summary)\n", + "\n", + " def encode_features(self, sent, max_seq_length, tokenizer):\n", + " tokens = tokenizer.tokenize(sent.strip())\n", + " if len(tokens) > max_seq_length - 2:\n", + " tokens = tokens[: (max_seq_length - 2)]\n", + " tokens = [\"[CLS]\"] + tokens + [\"[SEP]\"]\n", + " input_ids = tokenizer.convert_tokens_to_ids(tokens)\n", + " segment_ids = [0] * len(input_ids)\n", + " while len(input_ids) < max_seq_length:\n", + " input_ids.append(0)\n", + " segment_ids.append(0)\n", + " assert len(input_ids) == max_seq_length\n", + " assert len(segment_ids) == max_seq_length\n", + " return input_ids, segment_ids\n", + "\n", + " def __getitem__(self, index):\n", + " name = self.img_name[index].split(\".\")[0]\n", + " img_address = os.path.join(self.parent_dir, self.img_name[index])\n", + " image = Image.open(img_address)\n", + " images = self.preprocess(image)\n", + " report = str(self.report_summary[index])\n", + " report = \" \".join(report.split())\n", + " input_ids, segment_ids = self.encode_features(\n", + " report, self.max_seq_length, self.tokenizer\n", + " )\n", + " input_ids = torch.tensor(input_ids, dtype=torch.long)\n", + " segment_ids = torch.tensor(segment_ids, dtype=torch.long)\n", + " targets = torch.tensor(self.targets[index], dtype=torch.float)\n", + " return {\n", + " \"ids\": input_ids,\n", + " 
\"segment_ids\": segment_ids,\n", + " \"name\": name,\n", + " \"targets\": targets,\n", + " \"images\": images,\n", + " }" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " ## Setup the model directory, tokenizer and dataloaders\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def load_txt_gt(add):\n", + " txt_gt = pd.read_csv(add)\n", + " txt_gt[\"list\"] = txt_gt[txt_gt.columns[2:]].values.tolist()\n", + " txt_gt = txt_gt[[\"id\", \"report\", \"list\"]].copy()\n", + " return txt_gt\n", + "\n", + "\n", + "logdir = \"./logdir\"\n", + "if not os.path.exists(logdir):\n", + " os.makedirs(logdir)\n", + "\n", + "parent_dir = \"./monai_data/dataset_proc/images/\"\n", + "train_txt_gt = load_txt_gt(\"./monai_data/dataset_proc/train.csv\")\n", + "val_txt_gt = load_txt_gt(\"./monai_data/dataset_proc/validation.csv\")\n", + "test_txt_gt = load_txt_gt(\"./monai_data/dataset_proc/test.csv\")\n", + "batch_size = 32\n", + "num_workers = 8\n", + "tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\", do_lower_case=False)\n", + "training_set = MultiModalDataset(train_txt_gt, tokenizer, parent_dir)\n", + "train_params = {\n", + " \"batch_size\": batch_size,\n", + " \"shuffle\": True,\n", + " \"num_workers\": num_workers,\n", + " \"pin_memory\": True,\n", + "}\n", + "training_loader = DataLoader(training_set, **train_params)\n", + "valid_set = MultiModalDataset(val_txt_gt, tokenizer, parent_dir)\n", + "test_set = MultiModalDataset(test_txt_gt, tokenizer, parent_dir)\n", + "valid_params = {\"batch_size\": 1, \"shuffle\": False, \"num_workers\": 1, \"pin_memory\": True}\n", + "val_loader = DataLoader(valid_set, **valid_params)\n", + "test_loader = DataLoader(test_set, **valid_params)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create Model, Loss, Optimizer\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + 
"outputs": [], + "source": [ + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "total_epochs = 15\n", + "eval_num = 1\n", + "lr = 1e-4\n", + "weight_decay = 1e-5\n", + "\n", + "model = Transchex(\n", + " in_channels=3,\n", + " img_size=(256, 256),\n", + " num_classes=14,\n", + " patch_size=(32, 32),\n", + " num_language_layers=2,\n", + " num_vision_layers=2,\n", + " num_mixed_layers=2,\n", + ").to(device)\n", + "\n", + "loss_bce = torch.nn.BCELoss().cuda()\n", + "optimizer = torch.optim.Adam(\n", + " params=model.parameters(), lr=lr, weight_decay=weight_decay\n", + ")\n", + "scheduler = WarmupCosineSchedule(optimizer, warmup_steps=5, t_total=total_epochs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Execute a typical PyTorch training process" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "def save_ckp(state, checkpoint_dir):\n", + " torch.save(state, checkpoint_dir)\n", + "\n", + "\n", + "def compute_AUCs(gt, pred, num_classes=14):\n", + " with torch.no_grad():\n", + " AUROCs = []\n", + " gt_np = gt\n", + " pred_np = pred\n", + " for i in range(num_classes):\n", + " AUROCs.append(roc_auc_score(gt_np[:, i].tolist(), pred_np[:, i].tolist()))\n", + " return AUROCs\n", + "\n", + "\n", + "def train(epoch):\n", + " model.train()\n", + " for i, data in enumerate(training_loader, 0):\n", + " input_ids = data[\"ids\"].cuda()\n", + " segment_ids = data[\"segment_ids\"].cuda()\n", + " img = data[\"images\"].cuda()\n", + " targets = data[\"targets\"].cuda()\n", + " logits_lang = model(\n", + " input_ids=input_ids, vision_feats=img, token_type_ids=segment_ids\n", + " )\n", + " loss = loss_bce(torch.sigmoid(logits_lang), targets)\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + " print(f\"Epoch: {epoch}, Iteration: {i}, Loss_Tot: {loss}\")\n", + "\n", + "\n", + "def 
validation(testing_loader):\n", + " model.eval()\n", + " targets_in = np.zeros((len(testing_loader), 14))\n", + " preds_cls = np.zeros((len(testing_loader), 14))\n", + " val_loss = []\n", + " with torch.no_grad():\n", + " for _, data in enumerate(testing_loader, 0):\n", + " input_ids = data[\"ids\"].cuda()\n", + " segment_ids = data[\"segment_ids\"].cuda()\n", + " img = data[\"images\"].cuda()\n", + " targets = data[\"targets\"].cuda()\n", + " logits_lang = model(\n", + " input_ids=input_ids, vision_feats=img, token_type_ids=segment_ids\n", + " )\n", + " prob = torch.sigmoid(logits_lang)\n", + " loss = loss_bce(prob, targets).item()\n", + " targets_in[_, :] = targets.detach().cpu().numpy()\n", + " preds_cls[_, :] = prob.detach().cpu().numpy()\n", + " val_loss.append(loss)\n", + " auc = compute_AUCs(targets_in, preds_cls, 14)\n", + " mean_auc = np.mean(auc)\n", + " mean_loss = np.mean(val_loss)\n", + " print(\n", + " \"Evaluation Statistics: Mean AUC : {}, Mean Loss : {}\".format(\n", + " mean_auc, mean_loss\n", + " )\n", + " )\n", + " return mean_auc, mean_loss, auc\n", + "\n", + "\n", + "auc_val_best = 0.0\n", + "epoch_loss_values = []\n", + "metric_values = []\n", + "for epoch in range(total_epochs):\n", + " train(epoch)\n", + " auc_val, loss_val, _ = validation(val_loader)\n", + " epoch_loss_values.append(loss_val)\n", + " metric_values.append(auc_val)\n", + " if auc_val > auc_val_best:\n", + " checkpoint = {\n", + " \"epoch\": epoch,\n", + " \"state_dict\": model.state_dict(),\n", + " \"optimizer\": optimizer.state_dict(),\n", + " }\n", + " save_ckp(checkpoint, logdir + \"/transchex.pt\")\n", + " auc_val_best = auc_val\n", + " print(\n", + " \"Model Was Saved ! Current Best Validation AUC: {} Current AUC: {}\".format(\n", + " auc_val_best, auc_val\n", + " )\n", + " )\n", + " else:\n", + " print(\n", + " \"Model Was NOT Saved ! 
Current Best Validation AUC: {} Current AUC: {}\".format(\n", + " auc_val_best, auc_val\n", + " )\n", + " )\n", + " scheduler.step()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training Finished ! Best Validation AUC: 0.9533 \n" + ] + } + ], + "source": [ + "print(f\"Training Finished ! Best Validation AUC: {auc_val_best:.4f} \")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plot the loss and metric" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.figure(\"train\", (12, 6))\n", + "plt.subplot(1, 2, 1)\n", + "plt.title(\"Average Loss\")\n", + "x = [eval_num * (i + 1) for i in range(len(epoch_loss_values))]\n", + "y = epoch_loss_values\n", + "plt.xlabel(\"Epoch\")\n", + "plt.plot(x, y)\n", + "plt.subplot(1, 2, 2)\n", + "plt.title(\"Val Mean AUC\")\n", + "x = [eval_num * (i + 1) for i in range(len(metric_values))]\n", + "y = metric_values\n", + "plt.xlabel(\"Epoch\")\n", + "plt.plot(x, y)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Check best model output with the input image and label" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After training is completed, we use the best validation checkpoint to test the model performance on the Open-I testing set. " + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluation Statistics: Mean AUC : 0.9629915902329793, Mean Loss : 0.06937971694082447\n", + "\n", + "Mean test AUC for each class in 14 disease categories :\n", + "\n", + "Atelectasis: 0.9933158010081088\n", + "Cardiomegaly: 0.974534284581847\n", + "Consolidation: 0.9532794249775381\n", + "Edema: 0.9901960784313726\n", + "Enlarged-Cardiomediastinum: 0.9449934738019765\n", + "Fracture: 0.9911196911196911\n", + "Lung-Lesion: 0.9471389645776568\n", + "Lung-Opacity: 0.986452330401375\n", + "No-Finding: 0.9574158854734394\n", + "Pleural-Effusion: 0.8975490196078432\n", + "Pleural_Other: 0.9973118279569892\n", + "Pneumonia: 0.9714795008912656\n", + "Pneumothorax: 0.9787234042553191\n", + "Support-Devices: 0.8983725761772853\n" + ] + } + ], + "source": [ + "model.load_state_dict(torch.load(os.path.join(logdir, \"transchex.pt\"))[\"state_dict\"])\n", + "model.eval()\n", + "with 
torch.no_grad():\n", + " auc_val, loss_val, auc = validation(test_loader)\n", + "\n", + "print(\n", + " \"\\nMean test AUC for each class in 14 disease categories\\\n", + " :\\n\\nAtelectasis: {}\\nCardiomegaly: {}\\nConsolidation: {}\\nEdema: \\\n", + " {}\\nEnlarged-Cardiomediastinum: {}\\nFracture: {}\\nLung-Lesion: {}\\nLung-Opacity: \\\n", + " {}\\nNo-Finding: {}\\nPleural-Effusion: {}\\nPleural_Other: {}\\nPneumonia: \\\n", + " {}\\nPneumothorax: {}\\nSupport-Devices: {}\".format(\n", + " auc[0],\n", + " auc[1],\n", + " auc[2],\n", + " auc[3],\n", + " auc[4],\n", + " auc[5],\n", + " auc[6],\n", + " auc[7],\n", + " auc[8],\n", + " auc[9],\n", + " auc[10],\n", + " auc[11],\n", + " auc[12],\n", + " auc[13],\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This tutorial demonstrates the effectiveness of the TransCheX model for multi-modal training using chest X-ray images and corresponding reports. By using the Open-I dataset, we demonstrate how the TransCheX model can be leveraged for multi-label classification problems involving 2 different modalities of data. \n", + "\n", + "As seen above, the mean AUC for the test dataset is 0.9629, which is 1.007% better than the best validation mean AUC." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.10" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/runner.sh b/runner.sh index 46981768aa..b021cc4099 100755 --- a/runner.sh +++ b/runner.sh @@ -59,6 +59,7 @@ doRun=true autofix=false failfast=false pattern="-and -name '*' -and ! -wholename '*federated_learning*'\ + -and ! -wholename '*transchex_openi*'\ -and ! -wholename '*unetr_btcv*'\ -and ! 
-wholename '*profiling_camelyon*'\ -and ! -wholename '*profiling_train_base_nvtx*'\