
Commit 5185031

ankitade, malfet, and Svetlana Karslioglu
authored
Add TorchMultimodal tutorial for FLAVA finetuning (#2054)
* Add a TorchMultimodal tutorial for FLAVA finetuning

Co-authored-by: Nikita Shulga <nshulga@fb.com>
Co-authored-by: Svetlana Karslioglu <svekars@fb.com>
1 parent adda5fe commit 5185031

File tree

Makefile
beginner_source/flava_finetuning_tutorial.py
index.rst
requirements.txt

4 files changed: +216 -0 lines changed

Makefile

Lines changed: 4 additions & 0 deletions
@@ -102,6 +102,10 @@ download:
 	wget -nv -N https://download.pytorch.org/models/resnet18-5c106cde.pth -P $(DATADIR)
 	cp $(DATADIR)/resnet18-5c106cde.pth prototype_source/data/resnet18_pretrained_float.pth
 
+	# Download vocab for beginner_source/flava_finetuning_tutorial.py
+	wget -nv -N http://dl.fbaipublicfiles.com/pythia/data/vocab.tar.gz -P $(DATADIR)
+	tar $(TAROPTS) -xzf $(DATADIR)/vocab.tar.gz -C ./beginner_source/data/
+
 
 docs:
 	make download
beginner_source/flava_finetuning_tutorial.py

Lines changed: 190 additions & 0 deletions
@@ -0,0 +1,190 @@
# -*- coding: utf-8 -*-
"""
TorchMultimodal Tutorial: Finetuning FLAVA
============================================
"""

######################################################################
# Multimodal AI has recently become very popular owing to its ubiquitous
# nature, from use cases like image captioning and visual search to more
# recent applications like image generation from text. **TorchMultimodal
# is a library powered by PyTorch consisting of building blocks and
# end-to-end examples, aiming to enable and accelerate research in
# multimodality**.
#
# In this tutorial, we will demonstrate how to use a **pretrained SoTA
# model called** `FLAVA <https://arxiv.org/pdf/2112.04482.pdf>`__ **from
# the TorchMultimodal library to finetune on a multimodal task, namely visual
# question answering** (VQA). The model consists of two unimodal transformer-based
# encoders for text and image and a multimodal encoder that combines
# the two embeddings. It is pretrained using contrastive, image-text matching, and
# text, image, and multimodal masking losses.


######################################################################
# Installation
# -----------------
# We will use the TextVQA dataset and the BERT tokenizer from HuggingFace for this
# tutorial, so you need to install ``datasets`` and ``transformers`` in addition to TorchMultimodal.
#
# .. note::
#
#    When running this tutorial in Google Colab, install the required packages by
#    creating a new cell and running the following commands:
#
#    .. code-block::
#
#       !pip install torchmultimodal-nightly
#       !pip install datasets
#       !pip install transformers
#

######################################################################
# Steps
# -----
#
# 1. Download the vocab file containing the answer classes to a directory on your
#    computer by running the following commands:
#
#    .. code-block::
#
#       wget http://dl.fbaipublicfiles.com/pythia/data/vocab.tar.gz
#       tar xf vocab.tar.gz
#
#    .. note::
#       If you are running this tutorial in Google Colab, run these commands
#       in a new cell and prepend each command with an exclamation mark (!).
#
#
# 2. For this tutorial, we treat VQA as a classification task where
#    the inputs are an image and a question (text) and the output is an answer class,
#    so we need to load the vocab file with the answer classes and create the
#    answer-to-label mapping.
#
#    We also load the `TextVQA
#    dataset <https://arxiv.org/pdf/1904.08920.pdf>`__, containing 34602 training samples
#    (images, questions, and answers), from HuggingFace.
#
#    We see there are 3997 answer classes, including a class representing
#    unknown answers.
#

with open("data/vocabs/answers_textvqa_more_than_1.txt") as f:
    vocab = f.readlines()

answer_to_idx = {}
for idx, entry in enumerate(vocab):
    answer_to_idx[entry.strip("\n")] = idx
print(len(vocab))
print(vocab[:5])

from datasets import load_dataset
dataset = load_dataset("textvqa")

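######################################################################
# Since we will want to read predictions back as answer strings later, it is
# convenient to also keep the inverse, index-to-answer mapping. This is a small
# convenience sketch built on the vocab loaded above; the variable name
# ``idx_to_answer`` is our own choice:

idx_to_answer = {idx: answer for answer, idx in answer_to_idx.items()}
print(list(idx_to_answer.items())[:5])  # first few (label index, answer string) pairs
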
######################################################################
# Let's display a sample entry from the dataset:
#

import matplotlib.pyplot as plt
import numpy as np

idx = 5
print("Question: ", dataset["train"][idx]["question"])
print("Answers: ", dataset["train"][idx]["answers"])
im = np.asarray(dataset["train"][idx]["image"].resize((500, 500)))
plt.imshow(im)
plt.show()


######################################################################
# 3. Next, we write the transform function to convert the image and text into
#    Tensors consumable by our model:
#
#    - For images, we use the transforms from torchvision to convert to Tensor
#      and resize to a uniform size.
#    - For text, we tokenize (and pad) the question using the BertTokenizer
#      from HuggingFace.
#    - For answers (that is, labels), we take the most frequently occurring answer
#      as the label to train with:
#

import torch
from torchvision import transforms
from collections import defaultdict
from transformers import BertTokenizer
from functools import partial

def transform(tokenizer, input):
    batch = {}
    image_transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([224, 224])])
    image = image_transform(input["image"][0].convert("RGB"))
    batch["image"] = [image]

    tokenized = tokenizer(input["question"], return_tensors='pt', padding="max_length", max_length=512)
    batch.update(tokenized)

    ans_to_count = defaultdict(int)
    for ans in input["answers"][0]:
        ans_to_count[ans] += 1
    max_value = max(ans_to_count, key=ans_to_count.get)
    ans_idx = answer_to_idx.get(max_value, 0)
    batch["answers"] = torch.as_tensor([ans_idx])
    return batch

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", padding="max_length", max_length=512)
transform = partial(transform, tokenizer)
dataset.set_transform(transform)

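######################################################################
# As an optional sanity check, we can inspect what a single transformed sample
# looks like. Exactly how ``datasets`` unbatches single-row access can vary by
# version, so this sketch just prints the shape (or type) of each field rather
# than asserting exact sizes; expect a 3 x 224 x 224 image tensor and
# length-512 token sequences:

sample = dataset["train"][0]
for key, value in sample.items():
    # tensors expose .shape; anything else (for example, list-wrapped values) is shown by type
    print(key, value.shape if hasattr(value, "shape") else type(value))
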
######################################################################
# 4. Finally, we import the flava_model_for_classification from
#    torchmultimodal. It loads the pretrained FLAVA checkpoint by default and
#    includes a classification head.
#
#    The model forward function passes the image through the visual encoder
#    and the question through the text encoder. The image and question
#    embeddings are then passed through the multimodal encoder. The final
#    embedding corresponding to the CLS token is passed through an MLP head,
#    which finally gives the probability distribution over the possible
#    answers.
#

from torchmultimodal.models.flava.model import flava_model_for_classification
model = flava_model_for_classification(num_classes=len(vocab))

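######################################################################
# To get a rough feel for the size of the model we are about to finetune,
# we can count its parameters with plain PyTorch. This is purely an
# informational sketch and not required for training:

num_params = sum(p.numel() for p in model.parameters())
print(f"FLAVA classification model has {num_params / 1e6:.1f}M parameters")
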
######################################################################
# 5. We put together the dataset and model in a toy training loop to
#    demonstrate how to train the model for 3 iterations:
#

from torch import nn
from torch.utils.data import DataLoader

BATCH_SIZE = 2
MAX_STEPS = 3

train_dataloader = DataLoader(dataset["train"], batch_size=BATCH_SIZE)
optimizer = torch.optim.AdamW(model.parameters())


epochs = 1
for _ in range(epochs):
    for idx, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        out = model(text=batch["input_ids"], image=batch["image"], labels=batch["answers"])
        loss = out.loss
        loss.backward()
        optimizer.step()
        print(f"Loss at step {idx} = {loss}")
        # stop the toy loop after MAX_STEPS batches
        if idx >= MAX_STEPS - 1:
            break

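######################################################################
# As a follow-up sketch, we can decode the model's predictions for one batch
# back into answer strings using the ``idx_to_answer`` mapping built earlier,
# and then save the finetuned weights. This assumes the classification output
# exposes the class scores as ``out.logits`` (alongside the ``out.loss`` used
# above); the checkpoint filename is just an example:

model.eval()
with torch.no_grad():
    batch = next(iter(train_dataloader))
    out = model(text=batch["input_ids"], image=batch["image"], labels=batch["answers"])
    # assumption: ``out.logits`` holds one score per answer class
    predicted = out.logits.argmax(dim=-1)
print([idx_to_answer.get(int(i), "<unk>") for i in predicted])

# persist the finetuned weights with standard PyTorch checkpointing
torch.save(model.state_dict(), "flava_textvqa_finetuned.pt")
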
######################################################################
# Conclusion
# -------------------
#
# This tutorial introduced the basics of finetuning on a
# multimodal task using FLAVA from TorchMultimodal. Please also check out
# other examples from the library like
# `MDETR <https://github.com/facebookresearch/multimodal/tree/main/torchmultimodal/models/mdetr>`__,
# which is a multimodal model for object detection, and
# `Omnivore <https://github.com/facebookresearch/multimodal/blob/main/torchmultimodal/models/omnivore.py>`__,
# which is a multitask model spanning image, video, and 3D classification.
#

index.rst

Lines changed: 17 additions & 0 deletions
@@ -658,6 +658,15 @@ What's new in PyTorch tutorials?
    :link: advanced/sharding.html
    :tags: TorchRec,Recommender
 
+.. Multimodality
+
+.. customcarditem::
+   :header: Introduction to TorchMultimodal
+   :card_description: TorchMultimodal is a library that provides models, primitives and examples for training multimodal tasks
+   :image: _static/img/thumbnails/torchrec.png
+   :link: beginner/flava_finetuning_tutorial.html
+   :tags: TorchMultimodal
+
 
 .. End of tutorial card section

@@ -934,3 +943,11 @@ Additional Resources
 
    intermediate/torchrec_tutorial
    advanced/sharding
+
+.. toctree::
+   :maxdepth: 2
+   :includehidden:
+   :hidden:
+   :caption: Multimodality
+
+   beginner/flava_finetuning_tutorial

requirements.txt

Lines changed: 5 additions & 0 deletions
@@ -27,6 +27,9 @@ pytorch-lightning
 torchx
 ax-platform
 nbformat>=4.2.0
+datasets
+transformers
+torchmultimodal-nightly # needs to be updated to stable as soon as it's available
 deep_phonemizer==0.0.17
 
 # the following is necessary due to https://github.com/python/importlib_metadata/issues/411

@@ -50,4 +53,6 @@ wget
 gym==0.25.1
 gym-super-mario-bros==7.4.0
 timm
+iopath
 pygame==2.1.2
+

0 commit comments
