Commit d991b9b

Author: Ankita De (committed)
[WIP] Add torchmultimodal tutorial for flava finetuning
ghstack-source-id: e043284 Pull Request resolved: #2055
1 parent b202420 commit d991b9b

File tree: 4 files changed, +196 -0 lines changed

Makefile

Lines changed: 4 additions & 0 deletions
@@ -102,6 +102,10 @@ download:
 	wget -nv -N https://download.pytorch.org/models/resnet18-5c106cde.pth -P $(DATADIR)
 	cp $(DATADIR)/resnet18-5c106cde.pth prototype_source/data/resnet18_pretrained_float.pth
 
+	# Download vocab for beginner_source/flava_finetuning_tutorial.py
+	wget -nv -N http://dl.fbaipublicfiles.com/pythia/data/vocab.tar.gz -P $(DATADIR)
+	tar $(TAROPTS) -xzf $(DATADIR)/vocab.tar.gz -C ./beginner_source/data/
+
 
 docs:
 	make download
beginner_source/flava_finetuning_tutorial.py

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
######################################################################
# TorchMultimodal Tutorial: FLAVA finetuning
# --------------------------------------------
#

######################################################################
# Multimodal AI has recently become very popular owing to its ubiquitous
# nature, from use cases such as image captioning and visual search to more
# recent applications like image generation from text. **TorchMultimodal
# is a library powered by PyTorch consisting of building blocks and
# end-to-end examples, aiming to enable and accelerate research in
# multimodality**.
#
# In this tutorial, we will demonstrate how to use a **pretrained SoTA
# model called** `FLAVA <https://arxiv.org/pdf/2112.04482.pdf>`__ **from
# the TorchMultimodal library to finetune on a multimodal task, namely
# visual question answering** (VQA).
#

######################################################################
# Installations
#
# We will use the TextVQA dataset from HuggingFace for this tutorial, so we
# install ``datasets`` and ``transformers`` in addition to TorchMultimodal.
#

# TODO: replace with install from pip when binary is ready
!git clone https://github.com/facebookresearch/multimodal.git
!pip install -r multimodal/requirements.txt
import os
import sys
sys.path.append(os.path.join(os.getcwd(), "multimodal"))
sys.path.append(os.getcwd())
!pip install datasets
!pip install transformers

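######################################################################
# As a quick, optional sanity check, we can confirm that the cloned
# TorchMultimodal package is importable from the path appended above
# (this assumes the clone and requirements install succeeded).
#

import torchmultimodal
print(torchmultimodal.__file__)
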
######################################################################
# For this tutorial, we treat VQA as a classification task, so we need to
# download the vocab file with the answer classes and create the
# answer-to-label mapping.
#
# We also load the `TextVQA
# dataset <https://arxiv.org/pdf/1904.08920.pdf>`__ from HuggingFace.
#

!wget http://dl.fbaipublicfiles.com/pythia/data/vocab.tar.gz
!tar xf vocab.tar.gz

with open("vocabs/answers_textvqa_more_than_1.txt") as f:
  vocab = f.readlines()

answer_to_idx = {}
for idx, entry in enumerate(vocab):
  answer_to_idx[entry.strip("\n")] = idx

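######################################################################
# A small optional aside: since the model predicts label indices, it can be
# handy to also keep the inverse mapping to turn a predicted index back into
# a human-readable answer string. ``idx_to_answer`` is just an illustrative
# helper name and is not used elsewhere in this tutorial.
#

# Inverse mapping: label index -> answer string
idx_to_answer = {idx: answer for answer, idx in answer_to_idx.items()}
# Index 0 is also used as the fallback label in the transform defined below
print(idx_to_answer[0])
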
######################################################################
# We see there are 3997 answer classes, including a class representing
# unknown answers.
#

print(len(vocab))
print(vocab[:5])

# Load the TextVQA dataset from HuggingFace and look at a sample entry.
from datasets import load_dataset
dataset = load_dataset("textvqa")

from IPython.display import display, Image
idx = 5
print("Question: ", dataset["train"][idx]["question"])
print("Answers: ", dataset["train"][idx]["answers"])
display(dataset["train"][idx]["image"].resize((500, 500)))

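######################################################################
# Optionally, printing the ``DatasetDict`` shows the available splits and
# their sizes; the exact row counts depend on the dataset version hosted on
# HuggingFace.
#

print(dataset)
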
######################################################################
# Next, we write the transform function to convert the image and text into
# Tensors consumable by our model:
#
# -  For images, we use the transforms from torchvision to convert to a
#    Tensor and resize to a uniform size.
# -  For text, we tokenize (and pad) the question using the ``BertTokenizer``
#    from HuggingFace.
# -  For answers (i.e., labels), we take the most frequently occurring answer
#    as the label to train with.
#

import torch
from torchvision import transforms
from collections import defaultdict
from transformers import BertTokenizer
from functools import partial

def transform(tokenizer, input):
  batch = {}
  # Convert the image to an RGB tensor and resize it to 224x224
  image_transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([224, 224])])
  image = image_transform(input["image"][0].convert("RGB"))
  batch["image"] = [image]

  # Tokenize and pad the question
  tokenized = tokenizer(input["question"], return_tensors="pt", padding="max_length", max_length=512)
  batch.update(tokenized)

  # Use the most frequently occurring answer as the label
  ans_to_count = defaultdict(int)
  for ans in input["answers"][0]:
    ans_to_count[ans] += 1
  most_common_answer = max(ans_to_count, key=ans_to_count.get)
  ans_idx = answer_to_idx.get(most_common_answer, 0)
  batch["answers"] = torch.as_tensor([ans_idx])

  return batch

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", padding="max_length", max_length=512)
transform = partial(transform, tokenizer)
dataset.set_transform(transform)

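######################################################################
# As an optional sanity check, we can fetch one transformed sample and look
# at the keys and value shapes the model will receive; the exact unbatching
# of single-item access is handled internally by ``datasets``.
#

sample = dataset["train"][0]
for key, value in sample.items():
  # Tensors (or lists of tensors) produced by the transform above
  shape = getattr(value, "shape", None)
  print(key, shape if shape is not None else type(value))
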
######################################################################
# Finally, we import ``flava_model_for_classification`` from
# torchmultimodal. It loads the pretrained FLAVA checkpoint by default and
# includes a classification head.
#
# The model forward function passes the image through the visual encoder
# and the question through the text encoder. The image and question
# embeddings are then passed through the multimodal encoder. The final
# embedding corresponding to the CLS token is passed through an MLP head,
# which produces the probability distribution over the possible answers.
#

from torchmultimodal.models.flava.model import flava_model_for_classification
model = flava_model_for_classification(num_classes=len(vocab))

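######################################################################
# Optionally, a rough look at the model size; the exact parameter count
# depends on the FLAVA checkpoint bundled with TorchMultimodal.
#

num_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {num_params / 1e6:.1f}M")
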
######################################################################
# We put together the dataset and model in a toy training loop to
# demonstrate how to train the model for 3 iterations.
#

from torch.utils.data import DataLoader

BATCH_SIZE = 2
MAX_STEPS = 3

train_dataloader = DataLoader(dataset["train"], batch_size=BATCH_SIZE)
optimizer = torch.optim.AdamW(model.parameters())

epochs = 1
for _ in range(epochs):
  for idx, batch in enumerate(train_dataloader):
    optimizer.zero_grad()
    out = model(text=batch["input_ids"], image=batch["image"], labels=batch["answers"], required_embedding="mm")
    loss = out.loss
    loss.backward()
    optimizer.step()
    print(f"Loss at step {idx} = {loss}")
    if idx >= MAX_STEPS - 1:
      break

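######################################################################
# After the toy loop, the finetuned weights can be saved in the usual
# PyTorch way if you want to reuse them later; the filename here is only an
# example.
#

torch.save(model.state_dict(), "flava_finetuned_textvqa.pt")
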
######################################################################
# Conclusion
#
# This tutorial introduced the basics of finetuning on a
# multimodal task using FLAVA from TorchMultimodal. Please also check out
# other examples from the library like
# `MDETR <https://github.com/facebookresearch/multimodal/tree/main/torchmultimodal/models/mdetr>`__,
# which is a multimodal model for object detection, and
# `Omnivore <https://github.com/facebookresearch/multimodal/blob/main/torchmultimodal/models/omnivore.py>`__,
# which is a multitask model spanning image, video, and 3D classification.
#

index.rst

Lines changed: 17 additions & 0 deletions
@@ -646,6 +646,15 @@ What's new in PyTorch tutorials?
    :link: advanced/sharding.html
    :tags: TorchRec,Recommender
 
+.. Multimodality
+
+.. customcarditem::
+   :header: Introduction to TorchMultimodal
+   :card_description: TorchMultimodal is a library that provides models, primitives and examples for training multimodal tasks
+   :image: _static/img/thumbnails/torchrec.png
+   :link: beginner/flava_finetuning_tutorial.html
+   :tags: TorchMultimodal
+
 
 .. End of tutorial card section
 
@@ -919,3 +928,11 @@ Additional Resources
 
    intermediate/torchrec_tutorial
    advanced/sharding
+
+.. toctree::
+   :maxdepth: 2
+   :includehidden:
+   :hidden:
+   :caption: Multimodality
+
+   beginner/flava_finetuning_tutorial

requirements.txt

Lines changed: 3 additions & 0 deletions
@@ -27,6 +27,9 @@ pytorch-lightning
 torchx
 ax-platform
 nbformat>=4.2.0
+datasets
+transformers
+torchmultimodal-nightly
 
 # PyTorch Theme
 -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
