Add torchmultimodal tutorial for flava finetuning #2054

Merged: 35 commits merged on Oct 27, 2022
Commits (35)
242f420 · [WIP] Add torchmultimodal tutorial for flava finetuning (Sep 24, 2022)
6fbb8ce · Merge branch 'master' into tmm (Sep 26, 2022)
d991b9b · [WIP] Add torchmultimodal tutorial for flava finetuning (Sep 26, 2022)
3188b66 · Update (Sep 26, 2022)
0eca85b · Merge branch 'tmm' of https://github.com/pytorch/tutorials into tmm (Sep 27, 2022)
7002ed5 · Fix imports (Sep 27, 2022)
d12df19 · Merge branch 'master' into tmm (ankitade, Sep 28, 2022)
c33c3aa · Address comments (Oct 3, 2022)
31d1ca4 · Fix syntaxerror (Oct 3, 2022)
6b6563a · Fix syntax (Oct 3, 2022)
0fe598c · Fix formatting (Oct 3, 2022)
720d370 · [DO NOT MERGE] 1.13 RC Test (Oct 10, 2022)
e67331d · Update .jenkins/build.sh (Oct 11, 2022)
38939c4 · Update build.sh (Oct 13, 2022)
59048da · Merge branch 'master' into 1.13-RC-TEST (Oct 14, 2022)
c0d5fed · Update build.sh (Oct 14, 2022)
dbab110 · Merge branch 'master' into 1.13-RC-TEST (Oct 17, 2022)
f76b30d · Merge branch 'master' into 1.13-RC-TEST (Oct 17, 2022)
f509d8e · Update build.sh (Oct 17, 2022)
d6e72e0 · Update build.sh (Oct 17, 2022)
5fbf500 · Update build.sh (Oct 17, 2022)
3c7694f · Update build.sh (Oct 17, 2022)
3559c44 · Remove functorch (Oct 17, 2022)
71e2e2c · Merge branch 'master' into 1.13-RC-TEST (Oct 18, 2022)
f665ee3 · Merge branch 'master' into 1.13-RC-TEST (Oct 18, 2022)
06b9874 · Temporarily disabling fx_numeric_suite_tutorial (Oct 19, 2022)
3c0fc31 · Update build.sh (Oct 19, 2022)
a449a55 · Disable in the validate list (Oct 20, 2022)
047a956 · Disable ax tutorial (Oct 20, 2022)
3e6a2ce · Merge branch '1.13-RC-TEST' into tmm (Oct 20, 2022)
0a10b61 · Merge branch 'master' into tmm (Oct 26, 2022)
cff152e · rebase (Oct 26, 2022)
1c11444 · Small fix (Oct 26, 2022)
3549f56 · Merge branch 'master' into tmm (Oct 26, 2022)
158e289 · Merge branch 'master' into tmm (Oct 26, 2022)
Files changed
4 changes: 4 additions & 0 deletions Makefile
@@ -102,6 +102,10 @@ download:
	wget -nv -N https://download.pytorch.org/models/resnet18-5c106cde.pth -P $(DATADIR)
	cp $(DATADIR)/resnet18-5c106cde.pth prototype_source/data/resnet18_pretrained_float.pth

	# Download vocab for beginner_source/flava_finetuning_tutorial.py
	wget -nv -N http://dl.fbaipublicfiles.com/pythia/data/vocab.tar.gz -P $(DATADIR)
	tar $(TAROPTS) -xzf $(DATADIR)/vocab.tar.gz -C ./beginner_source/data/


docs:
	make download
190 changes: 190 additions & 0 deletions beginner_source/flava_finetuning_tutorial.py
@@ -0,0 +1,190 @@
# -*- coding: utf-8 -*-
"""
TorchMultimodal Tutorial: Finetuning FLAVA
============================================
"""

######################################################################
# Multimodal AI has recently become very popular owing to its ubiquitous
# nature, from use cases like image captioning and visual search to more
# recent applications like image generation from text. **TorchMultimodal
# is a library powered by PyTorch consisting of building blocks and
# end-to-end examples, aiming to enable and accelerate research in
# multimodality**.
#
# In this tutorial, we will demonstrate how to use a **pretrained
# state-of-the-art model called** `FLAVA <https://arxiv.org/pdf/2112.04482.pdf>`__ **from
# the TorchMultimodal library to finetune on a multimodal task, namely visual
# question answering** (VQA). The model consists of two unimodal transformer-based
# encoders for text and image and a multimodal encoder to combine
# the two embeddings. It is pretrained with three kinds of losses: a
# contrastive loss, an image-text matching loss, and masked-modeling
# losses over text, image, and multimodal inputs.

Review comment: nit: can we just say state-of-the-art here?

Review comment: Can the losses be enumerated in a different way here? I feel the comma placement makes this kinda confusing


######################################################################
# Installation
# -----------------
# We will use the TextVQA dataset and the BERT tokenizer from HuggingFace
# for this tutorial, so you need to install ``datasets`` and ``transformers``
# in addition to TorchMultimodal.
#
# .. note::
#
#    When running this tutorial in Google Colab, install the required packages by
#    creating a new cell and running the following commands:
#
#    .. code-block::
#
#       !pip install torchmultimodal-nightly
#       !pip install datasets
#       !pip install transformers
#

######################################################################
# Steps
# -----
#
# 1. Download the vocab file containing the answer classes to a directory on
#    your computer by running the following commands:
#
#    .. code-block::
#
#       wget http://dl.fbaipublicfiles.com/pythia/data/vocab.tar.gz
#       tar xf vocab.tar.gz
#
#    .. note::
#       If you are running this tutorial in Google Colab, run these commands
#       in a new cell and prepend them with an exclamation mark (!)
#
#
# 2. For this tutorial, we treat VQA as a classification task where
#    the inputs are an image and a question (text) and the output is an
#    answer class. So we need to download the vocab file with the answer
#    classes and create the answer-to-label mapping.
#
#    We also load the `textvqa
#    dataset <https://arxiv.org/pdf/1904.08920.pdf>`__ containing 34602 training samples
#    (images, questions, and answers) from HuggingFace.
#
#    We see there are 3997 answer classes, including a class representing
#    unknown answers.
#

# Build the answer-to-label mapping from the vocab file.
with open("data/vocabs/answers_textvqa_more_than_1.txt") as f:
    vocab = f.readlines()

answer_to_idx = {}
for idx, entry in enumerate(vocab):
    answer_to_idx[entry.strip("\n")] = idx
print(len(vocab))
print(vocab[:5])

# Load the TextVQA dataset from HuggingFace.
from datasets import load_dataset
dataset = load_dataset("textvqa")

######################################################################
# Let's display a sample entry from the dataset:
#

import matplotlib.pyplot as plt
import numpy as np

idx = 5
print("Question: ", dataset["train"][idx]["question"])
print("Answers: ", dataset["train"][idx]["answers"])
im = np.asarray(dataset["train"][idx]["image"].resize((500, 500)))
plt.imshow(im)
plt.show()


######################################################################
# 3. Next, we write the transform function to convert the image and text into
#    Tensors consumable by our model:
#
#    - For images, we use the transforms from torchvision to convert to
#      Tensor and resize to uniform sizes.
#    - For text, we tokenize (and pad) them using the ``BertTokenizer``
#      from HuggingFace.
#    - For answers (i.e. labels), we take the most frequently occurring
#      answer as the label to train with:
#

import torch
from torchvision import transforms
from collections import defaultdict
from transformers import BertTokenizer
from functools import partial

def transform(tokenizer, input):
    batch = {}
    # Convert the image to a Tensor and resize to a uniform size.
    image_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Resize([224, 224])]
    )
    image = image_transform(input["image"][0].convert("RGB"))
    batch["image"] = [image]

    # Tokenize (and pad) the question text.
    tokenized = tokenizer(
        input["question"],
        return_tensors="pt",
        padding="max_length",
        max_length=512,
    )
    batch.update(tokenized)

    # Use the most frequently occurring answer as the training label.
    ans_to_count = defaultdict(int)
    for ans in input["answers"][0]:
        ans_to_count[ans] += 1
    most_common_answer = max(ans_to_count, key=ans_to_count.get)
    ans_idx = answer_to_idx.get(most_common_answer, 0)
    batch["answers"] = torch.as_tensor([ans_idx])
    return batch

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", padding="max_length", max_length=512)
transform = partial(transform, tokenizer)
dataset.set_transform(transform)
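
######################################################################
# As an optional sanity check, we can inspect one transformed sample.
# The ``image`` and ``answers`` keys come from our transform function,
# and the remaining keys are the standard fields returned by the
# HuggingFace tokenizer:
#

sample = dataset["train"][0]
# Expect: image and answers (from our transform), plus input_ids,
# token_type_ids, and attention_mask (from the tokenizer).
print(sample.keys())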


######################################################################
# 4. Finally, we import the ``flava_model_for_classification`` from
# ``torchmultimodal``. It loads the pretrained FLAVA checkpoint by default and
# includes a classification head.
#
# The model forward function passes the image through the visual encoder
# and the question through the text encoder. The image and question
# embeddings are then passed through the multimodal encoder. The final
# embedding corresponding to the CLS token is passed through an MLP head,
# which gives a probability distribution over the possible
# answers.
#

from torchmultimodal.models.flava.model import flava_model_for_classification
model = flava_model_for_classification(num_classes=len(vocab))
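
######################################################################
# To make the data flow described above concrete, the sketch below mirrors
# the same classification pipeline with simple stand-in modules. This is an
# illustrative assumption, not the actual TorchMultimodal implementation:
# linear layers stand in for the transformer encoders, and mean pooling
# stands in for the CLS token.
#
# .. code-block:: python
#
#    import torch
#    from torch import nn
#
#    class ToyVQAClassifier(nn.Module):
#        # Toy stand-in for flava_model_for_classification (assumed structure).
#        def __init__(self, hidden_dim=768, vocab_size=30522, num_classes=3997):
#            super().__init__()
#            # Stand-ins for the unimodal image and text encoders.
#            self.image_encoder = nn.Sequential(
#                nn.Flatten(), nn.Linear(3 * 224 * 224, hidden_dim)
#            )
#            self.text_embed = nn.Embedding(vocab_size, hidden_dim)
#            # Stand-in for the multimodal encoder that fuses both embeddings.
#            self.fusion = nn.Linear(2 * hidden_dim, hidden_dim)
#            # Classification head over the answer classes.
#            self.classifier = nn.Linear(hidden_dim, num_classes)
#
#        def forward(self, text, image, labels=None):
#            img_emb = self.image_encoder(image)          # (B, hidden_dim)
#            txt_emb = self.text_embed(text).mean(dim=1)  # (B, hidden_dim)
#            fused = self.fusion(torch.cat([img_emb, txt_emb], dim=-1))
#            logits = self.classifier(fused)              # (B, num_classes)
#            loss = None
#            if labels is not None:
#                loss = nn.functional.cross_entropy(logits, labels)
#            return logits, loss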


######################################################################
# 5. We put together the dataset and model in a toy training loop to
# demonstrate how to train the model for 3 iterations:
#

from torch.utils.data import DataLoader

BATCH_SIZE = 2
MAX_STEPS = 3

train_dataloader = DataLoader(dataset["train"], batch_size=BATCH_SIZE)
optimizer = torch.optim.AdamW(model.parameters())

epochs = 1
for _ in range(epochs):
    for idx, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        out = model(text=batch["input_ids"], image=batch["image"], labels=batch["answers"])
        loss = out.loss
        loss.backward()
        optimizer.step()
        print(f"Loss at step {idx} = {loss}")
        if idx >= MAX_STEPS - 1:
            break
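
######################################################################
# After even this toy amount of training, we can map the model's predicted
# class back to an answer string. This snippet assumes the classification
# output exposes a ``logits`` field alongside ``loss``; if it does not in
# your version of TorchMultimodal, inspect the output object for the scores.
#

# Invert the answer-to-label mapping built earlier.
idx_to_answer = {i: ans for ans, i in answer_to_idx.items()}

model.eval()
with torch.no_grad():
    batch = next(iter(train_dataloader))
    out = model(text=batch["input_ids"], image=batch["image"], labels=batch["answers"])
    preds = out.logits.argmax(dim=-1)  # ``logits`` is an assumed field name
print([idx_to_answer[p.item()] for p in preds])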


######################################################################
# Conclusion
# -------------------
#
# This tutorial introduced the basics of finetuning on a
# multimodal task using FLAVA from TorchMultimodal. Please also check out
# other examples from the library like
# `MDETR <https://github.com/facebookresearch/multimodal/tree/main/torchmultimodal/models/mdetr>`__,
# which is a multimodal model for object detection, and
# `Omnivore <https://github.com/facebookresearch/multimodal/blob/main/torchmultimodal/models/omnivore.py>`__,
# which is a multitask model spanning image, video, and 3D classification.
#

Review comment (Contributor Author): I can add it in a follow-up PR.

17 changes: 17 additions & 0 deletions index.rst
@@ -658,6 +658,15 @@ What's new in PyTorch tutorials?
:link: advanced/sharding.html
:tags: TorchRec,Recommender

.. Multimodality

.. customcarditem::
   :header: Introduction to TorchMultimodal
   :card_description: TorchMultimodal is a library that provides models, primitives, and examples for training multimodal tasks
   :image: _static/img/thumbnails/torchrec.png
   :link: beginner/flava_finetuning_tutorial.html
   :tags: TorchMultimodal


.. End of tutorial card section

@@ -934,3 +943,11 @@ Additional Resources

intermediate/torchrec_tutorial
advanced/sharding

.. toctree::
   :maxdepth: 2
   :includehidden:
   :hidden:
   :caption: Multimodality

   beginner/flava_finetuning_tutorial
5 changes: 5 additions & 0 deletions requirements.txt
@@ -27,6 +27,9 @@ pytorch-lightning
torchx
ax-platform
nbformat>=4.2.0
datasets
transformers
torchmultimodal-nightly # needs to be updated to stable as soon as it's available
deep_phonemizer==0.0.17

# the following is necessary due to https://github.com/python/importlib_metadata/issues/411
@@ -50,4 +53,6 @@ wget
gym==0.25.1
gym-super-mario-bros==7.4.0
timm
iopath
pygame==2.1.2