diff --git a/intermediate_source/rpc_tutorial.rst b/intermediate_source/rpc_tutorial.rst
index bcad775496b..d7dbd74958f 100644
--- a/intermediate_source/rpc_tutorial.rst
+++ b/intermediate_source/rpc_tutorial.rst
@@ -76,9 +76,6 @@ usages.
             self.dropout = nn.Dropout(p=0.6)
             self.affine2 = nn.Linear(128, 2)
 
-            self.saved_log_probs = []
-            self.rewards = []
-
         def forward(self, x):
             x = self.affine1(x)
             x = self.dropout(x)
@@ -86,29 +83,6 @@ usages.
             action_scores = self.affine2(x)
             return F.softmax(action_scores, dim=1)
 
-Let's first prepare a helper to run functions remotely on the owner worker of an
-``RRef``. You will find this function being used in several places this
-tutorial's examples. Ideally, the `torch.distributed.rpc` package should provide
-these helper functions out of box. For example, it will be easier if
-applications can directly call ``RRef.some_func(*arg)`` which will then
-translate to RPC to the ``RRef`` owner. The progress on this API is tracked in
-`pytorch/pytorch#31743 <https://github.com/pytorch/pytorch/issues/31743>`__.
-
-.. code:: python
-
-    from torch.distributed.rpc import rpc_sync
-
-    def _call_method(method, rref, *args, **kwargs):
-        return method(rref.local_value(), *args, **kwargs)
-
-
-    def _remote_method(method, rref, *args, **kwargs):
-        args = [method, rref] + list(args)
-        return rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs)
-
-    # to call a function on an rref, we could do the following
-    # _remote_method(some_func, rref, *args)
-
 
 We are ready to present the observer. In this example, each observer creates
 its own environment, and waits for the agent's command to run an episode. In each
@@ -134,10 +108,14 @@ simple and the two steps explicit in this example.
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
 
-    parser.add_argument('--world_size', default=2, help='Number of workers')
-    parser.add_argument('--log_interval', default=1, help='Log every log_interval episodes')
-    parser.add_argument('--gamma', default=0.1, help='how much to value future rewards')
-    parser.add_argument('--seed', default=1, help='random seed for reproducibility')
+    parser.add_argument('--world_size', default=2, type=int, metavar='W',
+                        help='number of workers')
+    parser.add_argument('--log_interval', type=int, default=10, metavar='N',
+                        help='interval between training status logs')
+    parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
+                        help='how much to value future rewards')
+    parser.add_argument('--seed', type=int, default=1, metavar='S',
+                        help='random seed for reproducibility')
     args = parser.parse_args()
 
     class Observer:
@@ -147,18 +125,19 @@ simple and the two steps explicit in this example.
             self.env = gym.make('CartPole-v1')
             self.env.seed(args.seed)
 
-        def run_episode(self, agent_rref, n_steps):
+        def run_episode(self, agent_rref):
             state, ep_reward = self.env.reset(), 0
-            for step in range(n_steps):
+            for _ in range(10000):
                 # send the state to the agent to get an action
-                action = _remote_method(Agent.select_action, agent_rref, self.id, state)
+                action = agent_rref.rpc_sync().select_action(self.id, state)
 
                 # apply the action to the environment, and get the reward
                 state, reward, done, _ = self.env.step(action)
 
                 # report the reward to the agent for training purpose
-                _remote_method(Agent.report_reward, agent_rref, self.id, reward)
+                agent_rref.rpc_sync().report_reward(self.id, reward)
 
+                # finishes after the number of self.env._max_episode_steps
                 if done:
                     break
 
@@ -242,15 +221,15 @@ contain the recorded action probs and rewards.
 
     class Agent:
         ...
-        def run_episode(self, n_steps=0):
+        def run_episode(self):
             futs = []
             for ob_rref in self.ob_rrefs:
                 # make async RPC to kick off an episode on all observers
                 futs.append(
                     rpc_async(
                         ob_rref.owner(),
-                        _call_method,
-                        args=(Observer.run_episode, ob_rref, self.agent_rref, n_steps)
+                        ob_rref.rpc_sync().run_episode,
+                        args=(self.agent_rref,)
                     )
                 )
 
@@ -324,8 +303,7 @@ available in the `API page `__.
     import torch.multiprocessing as mp
 
     AGENT_NAME = "agent"
-    OBSERVER_NAME="obs"
-    TOTAL_EPISODE_STEP = 100
+    OBSERVER_NAME="obs{}"
 
     def run_worker(rank, world_size):
         os.environ['MASTER_ADDR'] = 'localhost'
@@ -335,17 +313,17 @@ available in the `API page `__.
             rpc.init_rpc(AGENT_NAME, rank=rank, world_size=world_size)
 
             agent = Agent(world_size)
+            print(f"This will run until reward threshold of {agent.reward_threshold}"
+                    " is reached. Ctrl+C to exit.")
             for i_episode in count(1):
-                n_steps = int(TOTAL_EPISODE_STEP / (args.world_size - 1))
-                agent.run_episode(n_steps=n_steps)
+                agent.run_episode()
                 last_reward = agent.finish_episode()
 
                 if i_episode % args.log_interval == 0:
-                    print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}'.format(
-                        i_episode, last_reward, agent.running_reward))
-
+                    print(f"Episode {i_episode}\tLast reward: {last_reward:.2f}\tAverage reward: "
+                        f"{agent.running_reward:.2f}")
                 if agent.running_reward > agent.reward_threshold:
-                    print("Solved! Running reward is now {}!".format(agent.running_reward))
+                    print(f"Solved! Running reward is now {agent.running_reward}!")
                     break
         else:
             # other ranks are the observer
@@ -367,6 +345,7 @@ Below are some sample outputs when training with `world_size=2`.
 
 ::
 
+    This will run until reward threshold of 475.0 is reached. Ctrl+C to exit.
     Episode 10      Last reward: 26.00      Average reward: 10.01
     Episode 20      Last reward: 16.00      Average reward: 11.27
     Episode 30      Last reward: 49.00      Average reward: 18.62
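The patch above replaces the tutorial's hand-written ``_remote_method``/``_call_method``
helpers with the built-in ``RRef.rpc_sync()`` proxy, so calling a method on the object
behind an ``RRef`` becomes a single chained expression that executes on the owner worker.
Below is a minimal, self-contained sketch of that proxy pattern, separate from the
tutorial code; the ``Counter`` class, worker names, and port number are illustrative
assumptions, not part of the diff.

.. code:: python

    import os

    import torch.distributed.rpc as rpc
    import torch.multiprocessing as mp


    class Counter:
        def __init__(self):
            self.value = 0

        def bump(self, amount):
            self.value += amount
            return self.value


    def run_worker(rank, world_size):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '29500'
        rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)

        if rank == 0:
            # create a Counter instance owned by worker1 and keep an RRef to it
            counter_rref = rpc.remote("worker1", Counter)
            # rpc_sync() returns a proxy; method calls on the proxy run on the
            # RRef owner, which is what the removed _remote_method helper emulated
            print(counter_rref.rpc_sync().bump(3))   # prints 3
            print(counter_rref.rpc_sync().bump(4))   # prints 7

        # block until outstanding RPCs finish, then tear down the RPC agents
        rpc.shutdown()


    if __name__ == "__main__":
        mp.spawn(run_worker, args=(2,), nprocs=2, join=True)

The same proxy idea backs the asynchronous call sites in the diff (for example,
``agent_rref.rpc_sync().select_action(...)`` in ``Observer.run_episode``), removing the
need to ship a helper function to the remote side on every call.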