Example for combining DDP + RPC #800

Merged · 3 commits · Jul 15, 2020

distributed/rpc/ddp_rpc/README.md (+21, −0)

Distributed DataParallel + Distributed RPC Framework Example

This example shows how to combine Distributed DataParallel (DDP) with the
Distributed RPC Framework. It runs four processes: two trainers, one master,
and one parameter server.

The master creates an embedding table on the parameter server and drives the
training loop on the trainers. The model consists of a dense part
(nn.Linear), replicated across the trainers via Distributed DataParallel, and
a sparse part (nn.EmbeddingBag) that resides on the parameter server. Each
trainer performs the embedding lookup on the parameter server (via the
Distributed RPC Framework) and then runs its local nn.Linear module. During
the backward pass, DDP aggregates the gradients for the dense part via
allreduce, while distributed autograd propagates the gradients for the
embedding table to the parameter server, where the distributed optimizer
applies them.
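
In outline, the key calls (condensed from main.py below) are:

```
# Master: create the embedding table on the parameter server.
emb_rref = rpc.remote("ps", torch.nn.EmbeddingBag,
                      args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
                      kwargs={"mode": "sum"})

# Trainer forward pass: remote embedding lookup, then the local
# DDP-wrapped nn.Linear.
emb_lookup = emb_rref.rpc_sync().forward(indices, offsets)
output = model.fc(emb_lookup.cuda(device))

# Backward pass and optimizer step under a distributed autograd context.
with dist_autograd.context() as context_id:
    dist_autograd.backward(context_id, [loss])
    opt.step(context_id)
```

Note that each trainer places its DDP replica on its own GPU (cuda:0 and
cuda:1), so the example assumes a machine with at least two GPUs.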


To run the example:

```
pip install -r requirements.txt
python main.py
```
distributed/rpc/ddp_rpc/main.py (+180, −0)

import os
import random

import torch
import torch.distributed as dist
import torch.distributed.autograd as dist_autograd
from torch.distributed.optim import DistributedOptimizer
import torch.distributed.rpc as rpc
from torch.distributed.rpc import RRef
from torch.distributed.rpc import ProcessGroupRpcBackendOptions
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim

NUM_EMBEDDINGS = 100
EMBEDDING_DIM = 16

class HybridModel(torch.nn.Module):
    r"""
    The model consists of a sparse part and a dense part. The dense part is
    an nn.Linear module replicated across all trainers using
    DistributedDataParallel. The sparse part is an nn.EmbeddingBag stored on
    the parameter server.

    The model holds a Remote Reference (RRef) to the embedding table on the
    parameter server.
    """

    def __init__(self, emb_rref, device):
        super(HybridModel, self).__init__()
        self.emb_rref = emb_rref
        self.fc = DDP(
            torch.nn.Linear(EMBEDDING_DIM, 8).cuda(device),
            device_ids=[device])
        self.device = device

    def forward(self, indices, offsets):
        # Look up the embeddings on the parameter server via RPC, then run
        # the local DDP-wrapped linear layer on the result.
        emb_lookup = self.emb_rref.rpc_sync().forward(indices, offsets)
        return self.fc(emb_lookup.cuda(self.device))

def _retrieve_embedding_parameters(emb_rref):
    # Runs on the parameter server: wrap each parameter of the embedding
    # table in an RRef so that DistributedOptimizer can address it remotely.
    return [RRef(p) for p in emb_rref.local_value().parameters()]
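
# Note: the RRef proxy call used in HybridModel.forward,
#     emb_rref.rpc_sync().forward(indices, offsets),
# is shorthand for an explicit rpc_sync to the RRef's owner; a rough
# equivalent (the helper below is illustrative, not defined in this file):
#
#     def _call_method(method, rref, *args, **kwargs):
#         # Runs on the owner: fetch the local module, invoke the method.
#         return method(rref.local_value(), *args, **kwargs)
#
#     emb_lookup = rpc.rpc_sync(
#         emb_rref.owner(),
#         _call_method,
#         args=(torch.nn.EmbeddingBag.forward, emb_rref, indices, offsets))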


def _run_trainer(emb_rref, rank):
    r"""
    Each trainer runs a forward pass which involves an embedding lookup on
    the parameter server and running nn.Linear locally. During the backward
    pass, DDP is responsible for aggregating the gradients for the dense part
    (nn.Linear), and distributed autograd ensures that gradient updates are
    propagated to the parameter server.
    """

    # Set up the model.
    model = HybridModel(emb_rref, rank)

    # Retrieve all model parameters as RRefs for DistributedOptimizer.

    # Retrieve parameters for the embedding table.
    model_parameter_rrefs = rpc.rpc_sync(
        "ps", _retrieve_embedding_parameters, args=(emb_rref,))

    # model.parameters() only includes local parameters.
    for param in model.parameters():
        model_parameter_rrefs.append(RRef(param))

    # Set up the distributed optimizer: it creates a local SGD instance on
    # every worker that owns one of the parameter RRefs.
    opt = DistributedOptimizer(
        optim.SGD,
        model_parameter_rrefs,
        lr=0.05,
    )

    criterion = torch.nn.CrossEntropyLoss()

    def get_next_batch(rank):
        for _ in range(10):
            num_indices = random.randint(20, 50)
            indices = torch.LongTensor(num_indices).random_(0, NUM_EMBEDDINGS)

            # Generate offsets: as nn.EmbeddingBag expects, each offset marks
            # the position in `indices` where one bag (one sample) starts.
            offsets = []
            start = 0
            batch_size = 0
            while start < num_indices:
                offsets.append(start)
                start += random.randint(1, 10)
                batch_size += 1

            offsets_tensor = torch.LongTensor(offsets)
            target = torch.LongTensor(batch_size).random_(8).cuda(rank)
            yield indices, offsets_tensor, target

    # Train for 100 epochs.
    for epoch in range(100):
        for indices, offsets, target in get_next_batch(rank):
            # Create a distributed autograd context for each iteration.
            with dist_autograd.context() as context_id:
                output = model(indices, offsets)
                loss = criterion(output, target)

                # Run the distributed backward pass.
                dist_autograd.backward(context_id, [loss])

                # Run the distributed optimizer.
                opt.step(context_id)

                # No need to zero grads: each iteration creates a new
                # distributed autograd context, which hosts its own grads.
        print("Training done for epoch {}".format(epoch))


def run_worker(rank, world_size):
    r"""
    A wrapper function that initializes RPC, calls the function, and shuts
    down RPC.
    """

    # We need to use different port numbers in the TCP init_method for
    # init_rpc and init_process_group to avoid port conflicts.
    rpc_backend_options = ProcessGroupRpcBackendOptions()
    rpc_backend_options.init_method = 'tcp://localhost:29501'

    # Rank 2 is the master, rank 3 is the parameter server, and ranks 0 and 1
    # are the trainers.
    if rank == 2:
        rpc.init_rpc(
            "master",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options)

        # Build the embedding table on the parameter server.
        emb_rref = rpc.remote(
            "ps",
            torch.nn.EmbeddingBag,
            args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
            kwargs={"mode": "sum"})

        # Run the training loop on the trainers. Each trainer is passed its
        # own rank so that it uses its own GPU.
        futs = []
        for trainer_rank in [0, 1]:
            trainer_name = "trainer{}".format(trainer_rank)
            fut = rpc.rpc_async(
                trainer_name, _run_trainer, args=(emb_rref, trainer_rank))
            futs.append(fut)

        # Wait for all training to finish.
        for fut in futs:
            fut.wait()
    elif rank <= 1:
        # Initialize the process group for Distributed DataParallel on the
        # trainers.
        dist.init_process_group(
            backend="gloo", rank=rank, world_size=2,
            init_method='tcp://localhost:29500')

        # Initialize RPC. The trainer then just waits for RPCs from the
        # master.
        trainer_name = "trainer{}".format(rank)
        rpc.init_rpc(
            trainer_name,
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options)
    else:
        # The parameter server does nothing besides initializing RPC; it just
        # waits for RPCs from the trainers and the master.
        rpc.init_rpc(
            "ps",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options)

    # Block until all RPCs finish.
    rpc.shutdown()


if __name__ == "__main__":
    # 2 trainers, 1 parameter server, 1 master.
    world_size = 4
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)

Review thread (on the `__main__` block):

Contributor:
Can we add this to the run_python_examples.sh script?

Contributor Author (@pritamdamania87, Jul 8, 2020):
It looks like distributed is commented out in that script: https://github.com/pytorch/examples/blob/master/run_python_examples.sh#L178? I don't see any other distributed/rpc examples in that script either. I'm wondering if there was a reason to disable them.

Contributor:
Hmm... it was just commented out a week ago, could have been by mistake: #794

Contributor:
Oh, actually, they were added as commented out. I think we can just add a distributed function and uncomment it.

Contributor Author:
Looking at #794, it seems like the goal was to have the entire script run within 5 minutes. If we add distributed to it, I don't think we can satisfy that goal without updating the other examples in distributed. I'd prefer to make this change in a separate PR.

distributed/rpc/ddp_rpc/requirements.txt (+1, −0)
torch>=1.6.0