
Update rpc_tutorial "Getting Started with Distributed RPC Framework" #1474


Merged
merged 4 commits on Apr 20, 2021
69 changes: 24 additions & 45 deletions intermediate_source/rpc_tutorial.rst
@@ -76,39 +76,13 @@ usages.
self.dropout = nn.Dropout(p=0.6)
self.affine2 = nn.Linear(128, 2)

self.saved_log_probs = []
self.rewards = []

def forward(self, x):
x = self.affine1(x)
x = self.dropout(x)
x = F.relu(x)
action_scores = self.affine2(x)
return F.softmax(action_scores, dim=1)

Let's first prepare a helper to run functions remotely on the owner worker of an
``RRef``. You will find this function used in several places in this tutorial's
examples. Ideally, the `torch.distributed.rpc` package should provide these
helper functions out of the box. For example, it would be easier if applications
could directly call ``RRef.some_func(*arg)``, which would then translate to an
RPC to the ``RRef`` owner. The progress on this API is tracked in
`pytorch/pytorch#31743 <https://github.com/pytorch/pytorch/issues/31743>`__.

.. code:: python

from torch.distributed.rpc import rpc_sync

def _call_method(method, rref, *args, **kwargs):
return method(rref.local_value(), *args, **kwargs)


def _remote_method(method, rref, *args, **kwargs):
args = [method, rref] + list(args)
return rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs)

# to call a function on an rref, we could do the following
# _remote_method(some_func, rref, *args)
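
For comparison, newer RPC releases expose proxy helpers directly on the
``RRef`` (``rref.rpc_sync()``, ``rref.rpc_async()``, and ``rref.remote()``),
which is the pattern this update migrates the tutorial to. Below is a minimal,
self-contained sketch of that pattern; the ``Counter`` class and worker names
are made up for illustration and are not part of the tutorial.

.. code:: python

    import os

    import torch.distributed.rpc as rpc
    import torch.multiprocessing as mp


    class Counter:
        def __init__(self):
            self.value = 0

        def add(self, delta):
            self.value += delta
            return self.value


    def run(rank, world_size):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '29500'
        rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
        if rank == 0:
            # create a Counter instance owned by worker1
            counter_rref = rpc.remote("worker1", Counter)
            # call methods on the remote object through the RRef proxy;
            # each call runs on worker1 and returns the result synchronously
            print(counter_rref.rpc_sync().add(1))  # 1
            print(counter_rref.rpc_sync().add(2))  # 3
        # shutdown blocks until all workers are done
        rpc.shutdown()


    if __name__ == "__main__":
        mp.spawn(run, args=(2,), nprocs=2, join=True)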


We are ready to present the observer. In this example, each observer creates its
own environment, and waits for the agent's command to run an episode. In each
@@ -134,10 +108,14 @@ simple and the two steps explicit in this example.
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument('--world_size', default=2, help='Number of workers')
parser.add_argument('--log_interval', default=1, help='Log every log_interval episodes')
parser.add_argument('--gamma', default=0.1, help='how much to value future rewards')
parser.add_argument('--seed', default=1, help='random seed for reproducibility')
parser.add_argument('--world_size', default=2, type=int, metavar='W',
help='number of workers')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed for reproducibility')
args = parser.parse_args()

class Observer:
@@ -147,18 +125,19 @@ simple and the two steps explicit in this example.
self.env = gym.make('CartPole-v1')
self.env.seed(args.seed)

def run_episode(self, agent_rref, n_steps):
def run_episode(self, agent_rref):
state, ep_reward = self.env.reset(), 0
for step in range(n_steps):
for _ in range(10000):
# send the state to the agent to get an action
action = _remote_method(Agent.select_action, agent_rref, self.id, state)
action = agent_rref.rpc_sync().select_action(self.id, state)

# apply the action to the environment, and get the reward
state, reward, done, _ = self.env.step(action)

# report the reward to the agent for training purpose
_remote_method(Agent.report_reward, agent_rref, self.id, reward)
agent_rref.rpc_sync().report_reward(self.id, reward)

# the episode finishes after self.env._max_episode_steps steps
if done:
break

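The diff above does not show the ``Agent`` constructor, but for context, the
agent side of this handshake typically creates one ``Observer`` per remaining
rank via ``rpc.remote`` and keeps the returned ``RRef``\ s, plus an ``RRef`` to
itself that observers use to call back. The following is a hedged sketch of
that setup, reusing the ``Observer`` class and ``OBSERVER_NAME`` format string
from this tutorial; details may differ from the actual file.

.. code:: python

    import torch.distributed.rpc as rpc
    from torch.distributed.rpc import RRef


    class Agent:
        def __init__(self, world_size):
            # RRef to self, handed to observers so they can call select_action
            # and report_reward on the agent
            self.agent_rref = RRef(self)
            self.ob_rrefs = []
            # rank 0 is the agent; ranks 1..world_size-1 each host one observer
            for ob_rank in range(1, world_size):
                ob_info = rpc.get_worker_info(OBSERVER_NAME.format(ob_rank))
                self.ob_rrefs.append(rpc.remote(ob_info, Observer))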
@@ -242,15 +221,15 @@ contain the recorded action probs and rewards.

class Agent:
...
def run_episode(self, n_steps=0):
def run_episode(self):
futs = []
for ob_rref in self.ob_rrefs:
# make async RPC to kick off an episode on all observers
futs.append(
rpc_async(
ob_rref.owner(),
_call_method,
args=(Observer.run_episode, ob_rref, self.agent_rref, n_steps)
ob_rref.rpc_sync().run_episode,
args=(self.agent_rref,)
)
)
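# (hedged sketch, not shown in this hunk) after kicking off all episodes, the
# agent blocks until every observer has finished, roughly:
#     for fut in futs:
#         fut.wait()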

@@ -324,8 +303,7 @@ available in the `API page <https://pytorch.org/docs/master/rpc.html>`__.
import torch.multiprocessing as mp

AGENT_NAME = "agent"
OBSERVER_NAME="obs"
TOTAL_EPISODE_STEP = 100
OBSERVER_NAME="obs{}"

def run_worker(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
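# (not shown in this hunk) the tutorial also sets the rendezvous port before
# init_rpc, e.g. os.environ['MASTER_PORT'] = '29500'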
@@ -335,17 +313,17 @@ available in the `API page <https://pytorch.org/docs/master/rpc.html>`__.
rpc.init_rpc(AGENT_NAME, rank=rank, world_size=world_size)

agent = Agent(world_size)
print(f"This will run until reward threshold of {agent.reward_threshold}"
" is reached. Ctrl+C to exit.")
for i_episode in count(1):
n_steps = int(TOTAL_EPISODE_STEP / (args.world_size - 1))
agent.run_episode(n_steps=n_steps)
Comment on lines -339 to -340


I feel it is useful to understand how n_steps are divided/parallelized across observers when reading this tutorial.

Member Author:


Good point, it is also mentioned in the descriptions, so I will add it back in. The reason I changed it was that the default value was too small and the trainer wasn't making any progress; I will just increase the default value.
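(For a concrete example of the division: with ``TOTAL_EPISODE_STEP = 100`` and ``world_size = 3``, there are two observers and each runs ``int(100 / 2) = 50`` steps per episode; with the default ``world_size = 2`` the single observer runs all 100 steps.)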

agent.run_episode()
last_reward = agent.finish_episode()

if i_episode % args.log_interval == 0:
print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}'.format(
i_episode, last_reward, agent.running_reward))

print(f"Episode {i_episode}\tLast reward: {last_reward:.2f}\tAverage reward: "
f"{agent.running_reward:.2f}")
if agent.running_reward > agent.reward_threshold:
print("Solved! Running reward is now {}!".format(agent.running_reward))
print(f"Solved! Running reward is now {agent.running_reward}!")
break
else:
# other ranks are the observer
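# (hedged sketch, not shown in this hunk) observer ranks typically just join
# the RPC group and wait passively for the agent to drive them, roughly:
#     rpc.init_rpc(OBSERVER_NAME.format(rank), rank=rank, world_size=world_size)
# and every rank ends with rpc.shutdown(), which blocks until all RPCs finish.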
@@ -367,6 +345,7 @@ Below are some sample outputs when training with `world_size=2`.

::

This will run until reward threshold of 475.0 is reached. Ctrl+C to exit.
Episode 10 Last reward: 26.00 Average reward: 10.01
Episode 20 Last reward: 16.00 Average reward: 11.27
Episode 30 Last reward: 49.00 Average reward: 18.62