# -*- coding: utf-8 -*-
"""
TorchMultimodal Tutorial: Finetuning FLAVA
============================================
"""

######################################################################
# Multimodal AI has recently become very popular owing to its ubiquitous
# nature, from use cases like image captioning and visual search to more
# recent applications like image generation from text. **TorchMultimodal
# is a library powered by PyTorch consisting of building blocks and end-to-end
# examples, aiming to enable and accelerate research in
# multimodality**.
#
# In this tutorial, we will demonstrate how to use a **pretrained SoTA
# model called** `FLAVA <https://arxiv.org/pdf/2112.04482.pdf>`__ **from
# the TorchMultimodal library to finetune on a multimodal task, i.e. visual
# question answering** (VQA). The model consists of two unimodal transformer-based
# encoders for text and image and a multimodal encoder to combine
# the two embeddings. It is pretrained using contrastive, image-text matching, and
# text, image, and multimodal masking losses.


######################################################################
# Installation
# -----------------
# We will use the TextVQA dataset and the BERT tokenizer from Hugging Face for this
# tutorial. So you need to install ``datasets`` and ``transformers`` in addition to TorchMultimodal.
#
# .. note::
#
#    When running this tutorial in Google Colab, install the required packages by
#    creating a new cell and running the following commands:
#
#    .. code-block::
#
#       !pip install torchmultimodal-nightly
#       !pip install datasets
#       !pip install transformers
#

######################################################################
# Steps
# -----
#
# 1. Download the vocabulary file containing the answer classes to a directory
#    on your computer by running the following commands:
#
#    .. code-block::
#
#       wget http://dl.fbaipublicfiles.com/pythia/data/vocab.tar.gz
#       tar xf vocab.tar.gz
#
#    .. note::
#
#       If you are running this tutorial in Google Colab, run these commands
#       in a new cell and prepend them with an exclamation mark (``!``).
#
#
# 2. For this tutorial, we treat VQA as a classification task where
#    the inputs are an image and a question (text) and the output is an answer class.
#    So we use the vocabulary file with the answer classes downloaded above to
#    create the answer-to-label mapping.
#
#    We also load the `TextVQA
#    dataset <https://arxiv.org/pdf/1904.08920.pdf>`__, which contains 34602 training samples
#    (images, questions and answers), from Hugging Face.
#
#    We see there are 3997 answer classes, including a class representing
#    unknown answers.
#
with open("data/vocabs/answers_textvqa_more_than_1.txt") as f:
    vocab = f.readlines()

answer_to_idx = {}
for idx, entry in enumerate(vocab):
    answer_to_idx[entry.strip("\n")] = idx
print(len(vocab))
print(vocab[:5])

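######################################################################
# The reverse of this mapping is handy when turning a class index back into a
# human-readable answer, for example when inspecting model predictions. It is a
# small convenience and not strictly required for training; a dictionary
# comprehension suffices:
#

idx_to_answer = {label: answer for answer, label in answer_to_idx.items()}
print(idx_to_answer[0])
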
from datasets import load_dataset
dataset = load_dataset("textvqa")

######################################################################
# Let's display a sample entry from the dataset:
#

import matplotlib.pyplot as plt
import numpy as np
idx = 5
print("Question: ", dataset["train"][idx]["question"])
print("Answers: ", dataset["train"][idx]["answers"])
im = np.asarray(dataset["train"][idx]["image"].resize((500, 500)))
plt.imshow(im)
plt.show()


######################################################################
# 3. Next, we write the transform function to convert the image and text into
# Tensors consumable by our model:
#
# - For images, we use the transforms from torchvision to convert to Tensor
#   and resize to a uniform size.
# - For text, we tokenize (and pad) the question using the ``BertTokenizer``
#   from Hugging Face.
# - For answers (i.e. labels), we take the most frequently occurring answer
#   as the label to train with.
#

import torch
from torchvision import transforms
from collections import defaultdict
from transformers import BertTokenizer
from functools import partial

def transform(tokenizer, input):
    batch = {}
    image_transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([224, 224])])
    image = image_transform(input["image"][0].convert("RGB"))
    batch["image"] = [image]

    tokenized = tokenizer(input["question"], return_tensors="pt", padding="max_length", max_length=512)
    batch.update(tokenized)

    # Use the most frequently occurring answer as the classification label,
    # falling back to index 0 for answers outside the vocabulary.
    ans_to_count = defaultdict(int)
    for ans in input["answers"][0]:
        ans_to_count[ans] += 1
    max_value = max(ans_to_count, key=ans_to_count.get)
    ans_idx = answer_to_idx.get(max_value, 0)
    batch["answers"] = torch.as_tensor([ans_idx])
    return batch

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", padding="max_length", max_length=512)
transform = partial(transform, tokenizer)
dataset.set_transform(transform)


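######################################################################
# As a quick sanity check, we can look at one transformed sample. The exact
# tensor shapes depend on how ``datasets`` unbatches single-example access, so
# here we only inspect the returned keys and the label tensor:
#

sample = dataset["train"][idx]
print(sample.keys())
print("Label (answer index): ", sample["answers"])
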
######################################################################
# 4. Finally, we import ``flava_model_for_classification`` from
# torchmultimodal. It loads the pretrained FLAVA checkpoint by default and
# includes a classification head.
#
# The model forward function passes the image through the visual encoder
# and the question through the text encoder. The image and question
# embeddings are then passed through the multimodal encoder. The final
# embedding corresponding to the CLS token is passed through an MLP head
# which finally gives the probability distribution over each possible
# answer.
#

from torchmultimodal.models.flava.model import flava_model_for_classification
model = flava_model_for_classification(num_classes=len(vocab))

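######################################################################
# Before wiring up the real data, we can run a quick smoke test with dummy
# inputs to confirm the call signature used in the training loop below:
# ``text`` token ids, an ``image`` tensor, and integer ``labels``. The shapes
# mirror what the transform above produces (a 224x224 image and a 512-token
# question); the values themselves are placeholders:
#

with torch.no_grad():
    dummy_text = torch.ones(1, 512, dtype=torch.long)  # fake token ids for one question
    dummy_image = torch.randn(1, 3, 224, 224)          # fake image for a batch of one
    dummy_label = torch.tensor([0])                    # an arbitrary answer class index
    dummy_out = model(text=dummy_text, image=dummy_image, labels=dummy_label)
    print(dummy_out.loss)
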
######################################################################
# 5. We put together the dataset and model in a toy training loop to
# demonstrate how to train the model for 3 iterations:
#

from torch import nn
from torch.utils.data import DataLoader

BATCH_SIZE = 2
MAX_STEPS = 3

train_dataloader = DataLoader(dataset["train"], batch_size=BATCH_SIZE)
optimizer = torch.optim.AdamW(model.parameters())


epochs = 1
for _ in range(epochs):
    for idx, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        out = model(text=batch["input_ids"], image=batch["image"], labels=batch["answers"])
        loss = out.loss
        loss.backward()
        optimizer.step()
        print(f"Loss at step {idx} = {loss}")
        if idx >= MAX_STEPS - 1:
            break


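######################################################################
# The same loop structure can be reused for a quick evaluation pass. The sketch
# below assumes the ``validation`` split of TextVQA is present in the dataset
# loaded above and, like the training loop, relies only on the ``loss`` field
# of the model output:
#

val_dataloader = DataLoader(dataset["validation"], batch_size=BATCH_SIZE)

model.eval()
with torch.no_grad():
    for idx, batch in enumerate(val_dataloader):
        out = model(text=batch["input_ids"], image=batch["image"], labels=batch["answers"])
        print(f"Validation loss at step {idx} = {out.loss}")
        if idx >= MAX_STEPS - 1:
            break
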
######################################################################
# Conclusion
# -------------------
#
# This tutorial introduced the basics of how to finetune on a
# multimodal task using FLAVA from TorchMultimodal. Please also check out
# other examples from the library like
# `MDETR <https://github.com/facebookresearch/multimodal/tree/main/torchmultimodal/models/mdetr>`__,
# which is a multimodal model for object detection, and
# `Omnivore <https://github.com/facebookresearch/multimodal/blob/main/torchmultimodal/models/omnivore.py>`__,
# which is a multitask model spanning image, video and 3D classification.
#