# -*- coding: utf-8 -*-
"""
Model ensembling
================

This tutorial illustrates how to vectorize model ensembling using ``torch.vmap``.

What is model ensembling?
-------------------------
Model ensembling combines the predictions from multiple models together.
Traditionally this is done by running each model on some inputs separately
and then combining the predictions. However, if you're running models with
the same architecture, then it may be possible to combine them together
using ``torch.vmap``. ``vmap`` is a function transform that maps functions across
dimensions of the input tensors. One of its use cases is eliminating
for-loops and speeding them up through vectorization.

Let's demonstrate how to do this using an ensemble of simple MLPs.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)
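
######################################################################
# Before building the ensemble, here is a tiny, self-contained sketch of what
# ``vmap`` does: it maps a function over a leading dimension of its inputs,
# letting us replace an explicit Python for-loop with a single vectorized call.
# The tensors and the function being mapped here are purely illustrative.

xs = torch.randn(8, 3)
looped = torch.stack([x.sum() for x in xs])  # explicit for-loop over the leading dim
vectorized = torch.vmap(torch.sum)(xs)       # maps torch.sum over the leading dim
assert torch.allclose(looped, vectorized)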

# Here's a simple MLP
class SimpleMLP(nn.Module):
    def __init__(self):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.flatten(1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x

######################################################################
# Let's generate a batch of dummy data and pretend that we're working with
# an MNIST dataset. Thus, the dummy images are 28 by 28, and we have a
# minibatch of size 64. Furthermore, let's say we want to combine the
# predictions from 10 different models.

device = 'cuda'
num_models = 10

data = torch.randn(100, 64, 1, 28, 28, device=device)
targets = torch.randint(10, (6400,), device=device)

models = [SimpleMLP().to(device) for _ in range(num_models)]

######################################################################
# We have a couple of options for generating predictions. Maybe we want to
# give each model a different randomized minibatch of data. Alternatively,
# maybe we want to run the same minibatch of data through each model (e.g.
# if we were testing the effect of different model initializations).

######################################################################
# Option 1: different minibatch for each model

minibatches = data[:num_models]
predictions_diff_minibatch_loop = [model(minibatch) for model, minibatch in zip(models, minibatches)]

######################################################################
# Option 2: Same minibatch

minibatch = data[0]
predictions2 = [model(minibatch) for model in models]

######################################################################
# Using vmap to vectorize the ensemble
# ------------------------------------
#
# Let's use vmap to speed up the for-loop. We must first prepare the models
# for use with vmap.
#
# First, let's combine the states of the models together by stacking each
# parameter. For example, ``model[i].fc1.weight`` has shape ``[128, 784]``; we are
# going to stack the .fc1.weight of each of the 10 models to produce a big
# weight of shape ``[10, 128, 784]``.
#
# PyTorch offers the ``torch.func.stack_module_state`` convenience function to do
# this.
from torch.func import stack_module_state

params, buffers = stack_module_state(models)
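
# Sanity check (illustrative): each stacked parameter now has a leading
# 'num_models' dimension; e.g. the stacked fc1 weight has shape [10, 128, 784].
print(params['fc1.weight'].shape)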

######################################################################
# Next, we need to define a function to vmap over. Given parameters, buffers,
# and inputs, the function should run the model using those parameters,
# buffers, and inputs. We'll use ``torch.func.functional_call`` to help out:

from torch.func import functional_call
import copy

# Construct a "stateless" version of one of the models. It is "stateless" in
# the sense that the parameters are meta Tensors and do not have storage.
base_model = copy.deepcopy(models[0])
base_model = base_model.to('meta')

def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))
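
######################################################################
# As a quick, illustrative sanity check, calling ``fmodel`` with the parameters
# and buffers of a single model should match calling that model directly.

params0 = dict(models[0].named_parameters())
buffers0 = dict(models[0].named_buffers())
assert torch.allclose(fmodel(params0, buffers0, data[0]), models[0](data[0]))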

######################################################################
# Option 1: get predictions using a different minibatch for each model.
#
# By default, vmap maps a function across the first dimension of all inputs to
# the passed-in function. After using ``stack_module_state``, each of
# the params and buffers has an additional dimension of size 'num_models' at
# the front, and minibatches has a dimension of size 'num_models'.

print([p.size(0) for p in params.values()])  # show the leading 'num_models' dimension

assert minibatches.shape == (num_models, 64, 1, 28, 28)  # verify minibatch has leading dimension of size 'num_models'

from torch import vmap

predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)

# verify the vmap predictions match the predictions from the for-loop
assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)

######################################################################
# Option 2: get predictions using the same minibatch of data.
#
# vmap has an ``in_dims`` argument that specifies which dimensions to map over.
# By using ``None``, we tell vmap we want the same minibatch to apply for all of
# the 10 models.

predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)

assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)
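
######################################################################
# With the stacked predictions in hand, one common way to combine the ensemble
# (shown here purely as an illustration) is to average the per-model logits
# along the leading 'num_models' dimension and then take the argmax.

ensemble_logits = predictions2_vmap.mean(dim=0)    # [64, 10]
ensemble_classes = ensemble_logits.argmax(dim=-1)  # [64]
print(ensemble_classes.shape)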

######################################################################
# A quick note: there are limitations around what types of functions can be
# transformed by vmap. The best functions to transform are pure functions:
# functions whose outputs are determined only by their inputs and that have no
# side effects (e.g. mutation). vmap is unable to handle mutation of arbitrary
# Python data structures, but it is able to handle many in-place PyTorch
# operations.
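
######################################################################
# For example, a function that performs an in-place PyTorch operation on a
# tensor it created itself is fine to vmap over (an illustrative sketch):

def add_one_inplace(x):
    y = x.clone()
    y.add_(1)  # in-place PyTorch op; vmap can handle this
    return y

assert torch.equal(vmap(add_one_inplace)(torch.zeros(3, 5, device=device)),
                   torch.ones(3, 5, device=device))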

######################################################################
# Performance
# -----------
# Curious about performance numbers? Here's how the numbers look.

from torch.utils.benchmark import Timer
without_vmap = Timer(
    stmt="[model(minibatch) for model, minibatch in zip(models, minibatches)]",
    globals=globals())
with_vmap = Timer(
    stmt="vmap(fmodel)(params, buffers, minibatches)",
    globals=globals())
print(f'Predictions without vmap {without_vmap.timeit(100)}')
print(f'Predictions with vmap {with_vmap.timeit(100)}')

######################################################################
# There's a large speedup using vmap!
#
# In general, vectorization with vmap should be faster than running a function
# in a for-loop and competitive with manual batching. There are some exceptions
# though, like if we haven't implemented the vmap rule for a particular
# operation or if the underlying kernels weren't optimized for older hardware
# (GPUs). If you see any of these cases, please let us know by opening an issue
# on GitHub.