######################################################################
# TorchMultimodal Tutorial: FLAVA finetuning
# --------------------------------------------
#

######################################################################
# Multimodal AI has recently become very popular owing to its ubiquitous
# nature, from use cases like image captioning and visual search to more
# recent applications like image generation from text. **TorchMultimodal
# is a library powered by PyTorch consisting of building blocks and
# end-to-end examples, aiming to enable and accelerate research in
# multimodality**.
#
# In this tutorial, we will demonstrate how to use a **pretrained SoTA
# model called** `FLAVA <https://arxiv.org/pdf/2112.04482.pdf>`__ **from
# the TorchMultimodal library to finetune on a multimodal task, i.e., visual
# question answering** (VQA).
#


######################################################################
# Installations
# -------------
#
# We will use the TextVQA dataset from HuggingFace for this
# tutorial, so we install ``datasets`` and ``transformers`` in addition
# to TorchMultimodal.
#
| 27 | + |
| 28 | +# TODO: replace with install from pip when binary is ready |
| 29 | +!git clone https://github.com/facebookresearch/multimodal.git |
| 30 | +!pip install -r multimodal/requirements.txt |
| 31 | +import os |
| 32 | +import sys |
| 33 | +sys.path.append(os.path.join(os.getcwd(),"multimodal")) |
| 34 | +sys.path.append(os.getcwd()) |
| 35 | +!pip install datasets |
| 36 | +!pip install transformers |
| 37 | + |
| 38 | + |
######################################################################
# For this tutorial, we treat VQA as a classification task, so we need to
# download the vocab file with answer classes and create the
# answer-to-label mapping.
#
# We also load the `TextVQA
# dataset <https://arxiv.org/pdf/1904.08920.pdf>`__ from HuggingFace.
#

!wget http://dl.fbaipublicfiles.com/pythia/data/vocab.tar.gz
!tar xf vocab.tar.gz


with open("vocabs/answers_textvqa_more_than_1.txt") as f:
    vocab = f.readlines()

# Map each answer string in the vocab to a class index
answer_to_idx = {}
for idx, entry in enumerate(vocab):
    answer_to_idx[entry.strip("\n")] = idx

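
######################################################################
# As an optional convenience, we can also build the reverse mapping
# (``idx_to_answer`` is just a name we pick here) to turn a predicted class
# index back into a human-readable answer string later on.
#

idx_to_answer = {idx: answer for answer, idx in answer_to_idx.items()}
print(idx_to_answer[0])  # the first entry of the downloaded vocab file
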

######################################################################
# We see there are 3997 answer classes, including a class representing
# unknown answers.
#

print(len(vocab))
print(vocab[:5])

from datasets import load_dataset
dataset = load_dataset("textvqa")

from IPython.display import display, Image
idx = 5
print("Question: ", dataset["train"][idx]["question"])
print("Answers: ", dataset["train"][idx]["answers"])
display(dataset["train"][idx]["image"].resize((500, 500)))


######################################################################
# Next, we write the transform function to convert the image and text into
# Tensors consumable by our model:
#
# - For images, we use the transforms from torchvision to convert to
#   Tensor and resize to uniform sizes.
# - For text, we tokenize (and pad) them using the ``BertTokenizer`` from
#   HuggingFace.
# - For answers (i.e., labels), we take the most frequently occurring answer
#   as the label to train with.
#

import torch
from torchvision import transforms
from collections import defaultdict
from transformers import BertTokenizer
from functools import partial

def transform(tokenizer, input):
    batch = {}
    # Convert the image to a tensor and resize it to a uniform 224x224
    image_transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([224, 224])])
    image = image_transform(input["image"][0].convert("RGB"))
    batch["image"] = [image]

    # Tokenize and pad the question
    tokenized = tokenizer(input["question"], return_tensors="pt", padding="max_length", max_length=512)
    batch.update(tokenized)

    # Use the most frequently occurring answer as the label
    ans_to_count = defaultdict(int)
    for ans in input["answers"][0]:
        ans_to_count[ans] += 1
    max_value = max(ans_to_count, key=ans_to_count.get)
    ans_idx = answer_to_idx.get(max_value, 0)
    batch["answers"] = torch.as_tensor([ans_idx])

    return batch

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", padding="max_length", max_length=512)
transform = partial(transform, tokenizer)
dataset.set_transform(transform)

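
######################################################################
# As a quick sanity check (an optional step we add here), let us look at a
# single transformed sample. With ``set_transform`` the transform runs on
# the fly when the dataset is indexed, so the exact shapes you see may
# depend on your ``datasets`` version.
#

sample = dataset["train"][0]
# Print each field with its tensor shape (or the raw value if it has none)
print({key: getattr(value, "shape", value) for key, value in sample.items()})
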

######################################################################
# Finally, we import ``flava_model_for_classification`` from
# TorchMultimodal. It loads the pretrained FLAVA checkpoint by default and
# includes a classification head.
#
# The model forward function passes the image through the visual encoder
# and the question through the text encoder. The image and question
# embeddings are then passed through the multimodal encoder. The final
# embedding corresponding to the CLS token is passed through an MLP head,
# which finally gives the probability distribution over the possible
# answers.
#

from torchmultimodal.models.flava.model import flava_model_for_classification
model = flava_model_for_classification(num_classes=len(vocab))

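
######################################################################
# Before training, we can optionally sketch a single forward pass to see
# the flow described above in action. This sketch assumes that indexing the
# transformed dataset returns unbatched tensors, so we add the batch
# dimension by hand; if your sample already carries a batch dimension,
# drop the ``unsqueeze`` calls.
#

sample = dataset["train"][0]
with torch.no_grad():
    out = model(
        text=sample["input_ids"].unsqueeze(0),
        image=sample["image"].unsqueeze(0),
        labels=sample["answers"].unsqueeze(0),
        required_embedding="mm",
    )
print(out.loss)  # classification loss for this single, not-yet-finetuned example
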

######################################################################
# We put together the dataset and model in a toy training loop to
# demonstrate how to train the model for 3 iterations.
#

from torch import nn
from torch.utils.data import DataLoader

BATCH_SIZE = 2
MAX_STEPS = 3

train_dataloader = DataLoader(dataset["train"], batch_size=BATCH_SIZE)
optimizer = torch.optim.AdamW(model.parameters())


epochs = 1
for _ in range(epochs):
    for idx, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        out = model(text=batch["input_ids"], image=batch["image"], labels=batch["answers"], required_embedding="mm")
        loss = out.loss
        loss.backward()
        optimizer.step()
        print(f"Loss at step {idx} = {loss}")
        if idx >= MAX_STEPS - 1:
            break


######################################################################
# Conclusion
# ----------
#
# This tutorial introduced the basics of how to finetune on a
# multimodal task using FLAVA from TorchMultimodal. Please also check out
# other examples from the library like
# `MDETR <https://github.com/facebookresearch/multimodal/tree/main/torchmultimodal/models/mdetr>`__,
# which is a multimodal model for object detection, and
# `Omnivore <https://github.com/facebookresearch/multimodal/blob/main/torchmultimodal/models/omnivore.py>`__,
# which is a multitask model spanning image, video, and 3D classification.
#
