[WIP] Add torchmultimodal tutorial for flava finetuning #2055

4 changes: 4 additions & 0 deletions Makefile
@@ -102,6 +102,10 @@ download:
	wget -nv -N https://download.pytorch.org/models/resnet18-5c106cde.pth -P $(DATADIR)
	cp $(DATADIR)/resnet18-5c106cde.pth prototype_source/data/resnet18_pretrained_float.pth

	# Download vocab for beginner_source/flava_finetuning_tutorial.py
	wget -nv -N http://dl.fbaipublicfiles.com/pythia/data/vocab.tar.gz -P $(DATADIR)
	tar $(TAROPTS) -xzf $(DATADIR)/vocab.tar.gz -C ./beginner_source/data/


docs:
	make download
172 changes: 172 additions & 0 deletions beginner_source/flava_finetuning_tutorial.py
@@ -0,0 +1,172 @@
######################################################################
# TorchMultimodal Tutorial: FLAVA finetuning
# --------------------------------------------
#

######################################################################
# Multimodal AI has recently become very popular owing to its ubiquitous
# nature, spanning use cases from image captioning and visual search to
# more recent applications like image generation from text. **TorchMultimodal
# is a library powered by PyTorch consisting of building blocks and
# end-to-end examples, aiming to enable and accelerate research in
# multimodality**.
#
# In this tutorial, we will demonstrate how to use a **pretrained SoTA
# model called** `FLAVA <https://arxiv.org/pdf/2112.04482.pdf>`__ **from
# the TorchMultimodal library to finetune on a multimodal task, i.e.,
# visual question answering** (VQA).
#


######################################################################
# Installations
# -------------
#
# We will use the TextVQA dataset from Hugging Face for this
# tutorial, so we install ``datasets`` and ``transformers`` in addition
# to TorchMultimodal.
#

# TODO: replace with install from pip when binary is ready
!git clone https://github.com/facebookresearch/multimodal.git
!pip install -r multimodal/requirements.txt
import os
import sys
sys.path.append(os.path.join(os.getcwd(),"multimodal"))
sys.path.append(os.getcwd())
!pip install datasets
!pip install transformers


######################################################################
# For this tutorial, we treat VQA as a classification task, so we need to
# download the vocab file with the answer classes and create the
# answer-to-label mapping.
#
# We also load the `TextVQA
# dataset <https://arxiv.org/pdf/1904.08920.pdf>`__ from Hugging Face.
#

!wget http://dl.fbaipublicfiles.com/pythia/data/vocab.tar.gz
!tar xf vocab.tar.gz


with open("vocabs/answers_textvqa_more_than_1.txt") as f:
    vocab = f.readlines()

answer_to_idx = {}
for idx, entry in enumerate(vocab):
    answer_to_idx[entry.strip("\n")] = idx
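
######################################################################
# For later convenience, we also build the reverse mapping so that a
# predicted label index can be decoded back into an answer string. This
# is a small sketch on top of the tutorial; it assumes index 0 is the
# unknown-answer class, matching the ``.get(..., 0)`` fallback used in
# the transform below.

idx_to_answer = {idx: ans for ans, idx in answer_to_idx.items()}
print(idx_to_answer[0])  # expected: the unknown-answer class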


######################################################################
# We see there are 3997 answer classes, including a class representing
# unknown answers.
#

print(len(vocab))
print(vocab[:5])

from datasets import load_dataset
dataset = load_dataset("textvqa")

from IPython.display import display, Image

idx = 5
print("Question: ", dataset["train"][idx]["question"])
print("Answers: ", dataset["train"][idx]["answers"])
display(dataset["train"][idx]["image"].resize((500, 500)))


######################################################################
# Next, we write the transform function to convert the image and text into
# tensors consumable by our model:
#
# - For images, we use the transforms from torchvision to convert to a
#   tensor and resize to a uniform size.
# - For text, we tokenize (and pad) the question using the ``BertTokenizer``
#   from Hugging Face.
# - For answers (i.e. labels), we take the most frequently occurring answer
#   as the label to train with.
#

import torch
from torchvision import transforms
from collections import defaultdict
from transformers import BertTokenizer
from functools import partial

def transform(tokenizer, input):
    batch = {}
    image_transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([224, 224])])
    image = image_transform(input["image"][0].convert("RGB"))
    batch["image"] = [image]

    tokenized = tokenizer(input["question"], return_tensors="pt", padding="max_length", max_length=512)
    batch.update(tokenized)

    ans_to_count = defaultdict(int)
    for ans in input["answers"][0]:
        ans_to_count[ans] += 1
    max_value = max(ans_to_count, key=ans_to_count.get)
    ans_idx = answer_to_idx.get(max_value, 0)
    batch["answers"] = torch.as_tensor([ans_idx])

    return batch

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", padding="max_length", max_length=512)
transform = partial(transform, tokenizer)
dataset.set_transform(transform)
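
######################################################################
# To make the label-selection step concrete: the ``max(..., key=...)``
# call above implements a simple majority vote over the annotators'
# answers. A standalone toy illustration (not part of the pipeline):

toy_answers = ["yes", "yes", "no"]
toy_counts = defaultdict(int)
for ans in toy_answers:
    toy_counts[ans] += 1  # count occurrences of each answer
print(max(toy_counts, key=toy_counts.get))  # prints "yes"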


######################################################################
# Finally, we import ``flava_model_for_classification`` from
# torchmultimodal. By default it loads the pretrained FLAVA checkpoint and
# includes a classification head.
#
# The model forward function passes the image through the visual encoder
# and the question through the text encoder. The image and question
# embeddings are then passed through the multimodal encoder. The final
# embedding corresponding to the CLS token is passed through an MLP head,
# which produces a probability distribution over the possible
# answers.
#

from torchmultimodal.models.flava.model import flava_model_for_classification
model = flava_model_for_classification(num_classes=len(vocab))
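
######################################################################
# As a quick sanity check (a minimal sketch, not part of the original
# tutorial), we can report the parameter count and the number of answer
# classes the classification head was built with:

num_params = sum(p.numel() for p in model.parameters())
print(f"{num_params / 1e6:.1f}M parameters, {len(vocab)} answer classes")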


######################################################################
# We put together the dataset and model in a toy training loop to
# demonstrate how to train the model for 3 iterations.
#

from torch.utils.data import DataLoader

BATCH_SIZE = 2
MAX_STEPS = 3

train_dataloader = DataLoader(dataset["train"], batch_size=BATCH_SIZE)
optimizer = torch.optim.AdamW(model.parameters())


epochs = 1
for _ in range(epochs):
    for idx, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        out = model(text=batch["input_ids"], image=batch["image"], labels=batch["answers"], required_embedding="mm")
        loss = out.loss
        loss.backward()
        optimizer.step()
        print(f"Loss at step {idx} = {loss}")
        if idx >= MAX_STEPS - 1:
            break
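
######################################################################
# Finally, a hedged inference sketch: we decode one batch of predictions
# back into answer strings with the ``idx_to_answer`` mapping built
# earlier. This assumes the classification output exposes a ``logits``
# field alongside ``loss``; adjust if your torchmultimodal version
# returns a different structure.

model.eval()
with torch.no_grad():
    batch = next(iter(train_dataloader))
    out = model(text=batch["input_ids"], image=batch["image"],
                labels=batch["answers"], required_embedding="mm")
preds = out.logits.argmax(dim=-1)  # ``logits`` is assumed; see note above
print([idx_to_answer[p.item()] for p in preds])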


######################################################################
# Conclusion
# ----------
#
# This tutorial covered the basics of finetuning on a multimodal task
# using FLAVA from TorchMultimodal. Please also check out other examples
# from the library like
# `MDETR <https://github.com/facebookresearch/multimodal/tree/main/torchmultimodal/models/mdetr>`__,
# a multimodal model for object detection, and
# `Omnivore <https://github.com/facebookresearch/multimodal/blob/main/torchmultimodal/models/omnivore.py>`__,
# a multitask model spanning image, video, and 3D classification.
#

17 changes: 17 additions & 0 deletions index.rst
@@ -646,6 +646,15 @@ What's new in PyTorch tutorials?
   :link: advanced/sharding.html
   :tags: TorchRec,Recommender

.. Multimodality

.. customcarditem::
   :header: Introduction to TorchMultimodal
   :card_description: TorchMultimodal is a library that provides models, primitives, and examples for training on multimodal tasks
   :image: _static/img/thumbnails/torchrec.png
   :link: beginner/flava_finetuning_tutorial.html
   :tags: TorchMultimodal


.. End of tutorial card section

@@ -919,3 +928,11 @@ Additional Resources

   intermediate/torchrec_tutorial
   advanced/sharding

.. toctree::
   :maxdepth: 2
   :includehidden:
   :hidden:
   :caption: Multimodality

   beginner/flava_finetuning_tutorial
3 changes: 3 additions & 0 deletions requirements.txt
@@ -27,6 +27,9 @@ pytorch-lightning
torchx
ax-platform
nbformat>=4.2.0
datasets
transformers
torchmultimodal-nightly

# PyTorch Theme
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme