diff --git a/Makefile b/Makefile
index 8c21384967c..a01ea69bb50 100644
--- a/Makefile
+++ b/Makefile
@@ -102,6 +102,10 @@ download:
 	wget -nv -N https://download.pytorch.org/models/resnet18-5c106cde.pth -P $(DATADIR)
 	cp $(DATADIR)/resnet18-5c106cde.pth prototype_source/data/resnet18_pretrained_float.pth
 
+	# Download vocab for beginner_source/flava_finetuning_tutorial.py
+	wget -nv -N http://dl.fbaipublicfiles.com/pythia/data/vocab.tar.gz -P $(DATADIR)
+	tar $(TAROPTS) -xzf $(DATADIR)/vocab.tar.gz -C ./beginner_source/data/
+
 docs:
 	make download
diff --git a/beginner_source/flava_finetuning_tutorial.py b/beginner_source/flava_finetuning_tutorial.py
new file mode 100644
index 00000000000..92bc5031709
--- /dev/null
+++ b/beginner_source/flava_finetuning_tutorial.py
@@ -0,0 +1,190 @@
+# -*- coding: utf-8 -*-
+"""
+TorchMultimodal Tutorial: Finetuning FLAVA
+============================================
+"""
+
+######################################################################
+# Multimodal AI has recently become very popular owing to its ubiquitous
+# nature, from use cases like image captioning and visual search to more
+# recent applications like image generation from text. **TorchMultimodal
+# is a library powered by PyTorch consisting of building blocks and
+# end-to-end examples, aiming to enable and accelerate research in
+# multimodality**.
+#
+# In this tutorial, we will demonstrate how to use a **pretrained SoTA
+# model called** `FLAVA `__ **from the
+# TorchMultimodal library to finetune on a multimodal task, namely visual
+# question answering (VQA)**. The model consists of two unimodal
+# transformer-based encoders, one for text and one for images, and a
+# multimodal encoder that combines the two embeddings. It is pretrained
+# with contrastive and image-text matching losses, as well as text, image,
+# and multimodal masking losses.
+
+
+######################################################################
+# Installation
+# -----------------
+# We will use the TextVQA dataset and the BERT tokenizer from HuggingFace for
+# this tutorial, so you need to install ``datasets`` and ``transformers`` in
+# addition to TorchMultimodal.
+#
+# .. note::
+#
+#    When running this tutorial in Google Colab, install the required packages by
+#    creating a new cell and running the following commands:
+#
+#    .. code-block::
+#
+#       !pip install torchmultimodal-nightly
+#       !pip install datasets
+#       !pip install transformers
+#
+
+######################################################################
+# Steps
+# -----
+#
+# 1. Download the vocab file containing the answer classes to a directory on
+#    your computer by running the following commands:
+#
+#    .. code-block::
+#
+#       wget http://dl.fbaipublicfiles.com/pythia/data/vocab.tar.gz
+#       tar xf vocab.tar.gz
+#
+#    .. note::
+#       If you are running this tutorial in Google Colab, run these commands
+#       in a new cell and prefix each command with an exclamation mark (!).
+#
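+#    The same download can also be done from Python. The snippet below is an
+#    optional, minimal sketch that uses only the standard library; it assumes
+#    you want the archive extracted under ``data/`` so that the
+#    ``data/vocabs/answers_textvqa_more_than_1.txt`` path used later in this
+#    tutorial resolves.
+#
+#    .. code-block:: python
+#
+#       import os
+#       import tarfile
+#       import urllib.request
+#
+#       os.makedirs("data", exist_ok=True)
+#       # Fetch the same archive as the wget command above.
+#       archive_path, _ = urllib.request.urlretrieve(
+#           "http://dl.fbaipublicfiles.com/pythia/data/vocab.tar.gz", "vocab.tar.gz"
+#       )
+#       with tarfile.open(archive_path) as tar:
+#           # Creates data/vocabs/answers_textvqa_more_than_1.txt among other files.
+#           tar.extractall("data")
+#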
+# 2. For this tutorial, we treat VQA as a classification task where
+#    the inputs are an image and a question (text) and the output is an
+#    answer class. So we need the vocab file with the answer classes,
+#    downloaded in step 1, to create the answer-to-label mapping.
+#
+#    We also load the `textvqa
+#    dataset `__ from HuggingFace, which contains 34602 training samples
+#    (images, questions, and answers).
+#
+#    We will see that there are 3997 answer classes, including a class
+#    representing unknown answers.
+#
+
+with open("data/vocabs/answers_textvqa_more_than_1.txt") as f:
+    vocab = f.readlines()
+
+answer_to_idx = {}
+for idx, entry in enumerate(vocab):
+    answer_to_idx[entry.strip("\n")] = idx
+print(len(vocab))
+print(vocab[:5])
+
+from datasets import load_dataset
+dataset = load_dataset("textvqa")
+
+######################################################################
+# Let's display a sample entry from the dataset:
+#
+
+import matplotlib.pyplot as plt
+import numpy as np
+idx = 5
+print("Question: ", dataset["train"][idx]["question"])
+print("Answers: ", dataset["train"][idx]["answers"])
+im = np.asarray(dataset["train"][idx]["image"].resize((500, 500)))
+plt.imshow(im)
+plt.show()
+
+
+######################################################################
+# 3. Next, we write the transform function to convert the image and text into
+#    Tensors consumable by our model:
+#
+#    - For images, we use the transforms from torchvision to convert to a
+#      Tensor and resize to a uniform size.
+#    - For text, we tokenize (and pad) the questions using the BertTokenizer
+#      from HuggingFace.
+#    - For answers (i.e. labels), we take the most frequently occurring answer
+#      as the label to train with.
+#
+
+import torch
+from torchvision import transforms
+from collections import defaultdict
+from transformers import BertTokenizer
+from functools import partial
+
+def transform(tokenizer, input):
+    batch = {}
+    image_transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([224, 224])])
+    image = image_transform(input["image"][0].convert("RGB"))
+    batch["image"] = [image]
+
+    tokenized = tokenizer(input["question"], return_tensors="pt", padding="max_length", max_length=512)
+    batch.update(tokenized)
+
+    # Use the most frequently occurring answer as the training label.
+    ans_to_count = defaultdict(int)
+    for ans in input["answers"][0]:
+        ans_to_count[ans] += 1
+    max_value = max(ans_to_count, key=ans_to_count.get)
+    ans_idx = answer_to_idx.get(max_value, 0)
+    batch["answers"] = torch.as_tensor([ans_idx])
+    return batch
+
+tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", padding="max_length", max_length=512)
+transform = partial(transform, tokenizer)
+dataset.set_transform(transform)
+
+
+######################################################################
+# 4. Finally, we import ``flava_model_for_classification`` from
+#    torchmultimodal. It loads the pretrained FLAVA checkpoint by default and
+#    includes a classification head.
+#
+#    The model forward function passes the image through the visual encoder
+#    and the question through the text encoder. The image and question
+#    embeddings are then passed through the multimodal encoder. The final
+#    embedding corresponding to the CLS token is passed through an MLP head,
+#    which produces a probability distribution over the possible answers.
+#
+
+from torchmultimodal.models.flava.model import flava_model_for_classification
+model = flava_model_for_classification(num_classes=len(vocab))
+
+
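+######################################################################
+# As an optional sanity check before training, you can run a single forward
+# pass with dummy inputs. This is only an illustrative sketch: the shapes
+# below are assumptions that mirror the transform above (a 224x224 RGB image
+# and a question padded to 512 token ids), and we rely only on the ``loss``
+# field of the output, which the training loop below also uses.
+#
+# .. code-block:: python
+#
+#    dummy_image = torch.randn(1, 3, 224, 224)         # one 224x224 RGB image
+#    dummy_text = torch.ones(1, 512, dtype=torch.long)  # one question padded to 512 token ids
+#    dummy_answer = torch.zeros(1, dtype=torch.long)    # one answer class index
+#
+#    with torch.no_grad():
+#        out = model(text=dummy_text, image=dummy_image, labels=dummy_answer)
+#    print("Sanity-check loss:", out.loss.item())
+#
+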
+######################################################################
+# 5. We put together the dataset and model in a toy training loop to
+#    demonstrate how to train the model for 3 iterations:
+#
+
+from torch.utils.data import DataLoader
+
+BATCH_SIZE = 2
+MAX_STEPS = 3
+
+train_dataloader = DataLoader(dataset["train"], batch_size=BATCH_SIZE)
+optimizer = torch.optim.AdamW(model.parameters())
+
+epochs = 1
+for _ in range(epochs):
+    for idx, batch in enumerate(train_dataloader):
+        optimizer.zero_grad()
+        out = model(text=batch["input_ids"], image=batch["image"], labels=batch["answers"])
+        loss = out.loss
+        loss.backward()
+        optimizer.step()
+        print(f"Loss at step {idx} = {loss}")
+        if idx >= MAX_STEPS - 1:
+            break
+
+
+######################################################################
+# Conclusion
+# -------------------
+#
+# This tutorial introduced the basics of finetuning on a
+# multimodal task using FLAVA from TorchMultimodal. Please also check out
+# other examples from the library like
+# `MDETR `__,
+# a multimodal model for object detection, and
+# `Omnivore `__,
+# a multitask model spanning image, video, and 3D classification.
+#
diff --git a/index.rst b/index.rst
index f57cdb27c34..1123ffbccf9 100644
--- a/index.rst
+++ b/index.rst
@@ -658,6 +658,15 @@ What's new in PyTorch tutorials?
    :link: advanced/sharding.html
    :tags: TorchRec,Recommender
 
+.. Multimodality
+
+.. customcarditem::
+   :header: Introduction to TorchMultimodal
+   :card_description: TorchMultimodal is a library that provides models, primitives, and examples for training multimodal tasks
+   :image: _static/img/thumbnails/torchrec.png
+   :link: beginner/flava_finetuning_tutorial.html
+   :tags: TorchMultimodal
+
 .. End of tutorial card section
 
@@ -934,3 +943,11 @@ Additional Resources
 
    intermediate/torchrec_tutorial
    advanced/sharding
+
+.. toctree::
+   :maxdepth: 2
+   :includehidden:
+   :hidden:
+   :caption: Multimodality
+
+   beginner/flava_finetuning_tutorial
diff --git a/requirements.txt b/requirements.txt
index 3ca757e2bc2..28a37d88ac2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -27,6 +27,9 @@ pytorch-lightning
 torchx
 ax-platform
 nbformat>=4.2.0
+datasets
+transformers
+torchmultimodal-nightly # needs to be updated to stable as soon as it's available
 deep_phonemizer==0.0.17
 
 # the following is necessary due to https://github.com/python/importlib_metadata/issues/411
@@ -50,4 +53,6 @@ wget
 gym==0.25.1
 gym-super-mario-bros==7.4.0
 timm
+iopath
 pygame==2.1.2
+