Update torchrl==0.3.0 tutos #2759

Merged (10 commits) on Feb 5, 2024
55 changes: 20 additions & 35 deletions advanced_source/coding_ddpg.py
@@ -65,26 +65,33 @@

# sphinx_gallery_start_ignore
import warnings

warnings.filterwarnings("ignore")
import multiprocessing
from torch import multiprocessing

# TorchRL prefers the spawn method, which restricts creation of ``~torchrl.envs.ParallelEnv``
# to code run under the ``__main__`` guard. For ease of reading, this tutorial switches
# to fork, which is also the default start method in Google's Colaboratory.
try:
multiprocessing.set_start_method("fork")
except RuntimeError:
assert multiprocessing.get_start_method() == "fork"
pass

# sphinx_gallery_end_ignore


import torchrl
import torch
import tqdm
from typing import Tuple


###############################################################################
# We will execute the policy on CUDA if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
is_fork = multiprocessing.get_start_method() == "fork"
device = (
torch.device(0)
if torch.cuda.is_available() and not is_fork
else torch.device("cpu")
)
collector_device = torch.device("cpu") # Change the device to ``cuda`` to use CUDA

###############################################################################
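For context, a hedged sketch of the constraint these comments refer to, assuming torchrl's ``GymEnv``/``ParallelEnv`` APIs (the environment name and worker count are arbitrary): under the default spawn start method, subprocess-backed environments should only be built under the ``__main__`` guard, whereas fork allows the linear, top-level style used in the tutorials. A forked worker also cannot safely re-initialize a CUDA context created by its parent, which is presumably why the device selection above falls back to CPU when fork is active.

# Sketch only: illustrates why the tutorials switch to "fork" for readability.
from torch import multiprocessing
from torchrl.envs import ParallelEnv
from torchrl.envs.libs.gym import GymEnv


def make_env():
    return GymEnv("Pendulum-v1")


if __name__ == "__main__":
    # Under "spawn", workers start from a fresh interpreter, so environment
    # construction must live under this guard.
    multiprocessing.set_start_method("spawn", force=True)
    env = ParallelEnv(2, make_env)
    env.rollout(3)
    env.close()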
@@ -244,23 +251,18 @@ def make_value_estimator(self, value_type: ValueEstimators, **hyperparams):
hp.update(hyperparams)
value_key = "state_action_value"
if value_type == ValueEstimators.TD1:
self._value_estimator = TD1Estimator(
value_network=self.actor_critic, value_key=value_key, **hp
)
self._value_estimator = TD1Estimator(value_network=self.actor_critic, **hp)
elif value_type == ValueEstimators.TD0:
self._value_estimator = TD0Estimator(
value_network=self.actor_critic, value_key=value_key, **hp
)
self._value_estimator = TD0Estimator(value_network=self.actor_critic, **hp)
elif value_type == ValueEstimators.GAE:
raise NotImplementedError(
f"Value type {value_type} it not implemented for loss {type(self)}."
)
elif value_type == ValueEstimators.TDLambda:
self._value_estimator = TDLambdaEstimator(
value_network=self.actor_critic, value_key=value_key, **hp
)
self._value_estimator = TDLambdaEstimator(value_network=self.actor_critic, **hp)
else:
raise NotImplementedError(f"Unknown value type {value_type}")
self._value_estimator.set_keys(value=value_key)


###############################################################################
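For reference, a minimal sketch of the pattern this hunk converges on, hedged against the torchrl 0.3 value-estimator API, with ``actor_critic`` standing in for the tutorial's own actor-critic module: the value key is no longer passed to the estimator constructor but set afterwards through ``set_keys``.

# Sketch (assumed torchrl 0.3 API): configure the value key after construction.
# ``actor_critic`` is a placeholder for the tutorial's module.
from torchrl.objectives.value import TD0Estimator

estimator = TD0Estimator(gamma=0.99, value_network=actor_critic)
estimator.set_keys(value="state_action_value")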
@@ -311,7 +313,7 @@ def _loss_actor(
def _loss_value(
self,
tensordict,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
):
td_copy = tensordict.clone()

# V(s, a)
@@ -349,7 +351,7 @@ def _loss_value(
# value and actor loss, collect the cost values and write them in a ``TensorDict``
# delivered to the user.

from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict import TensorDict, TensorDictBase


def _forward(self, input_tensordict: TensorDictBase) -> TensorDict:
@@ -457,6 +459,7 @@ def make_env(from_pixels=False):
raise NotImplementedError

env_kwargs = {
"device": device,
"from_pixels": from_pixels,
"pixels_only": from_pixels,
"frame_skip": 2,
@@ -519,16 +522,6 @@ def make_transformed_env(
# syntax.
env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling))

double_to_float_list = []
double_to_float_inv_list = []
if env_library is DMControlEnv:
# ``DMControl`` requires double-precision
double_to_float_list += [
"reward",
"action",
]
double_to_float_inv_list += ["action"]

# We concatenate all states into a single "observation_vector"
# even if there is a single tensor, it'll be renamed to "observation_vector".
# This facilitates the downstream operations as we know the name of the
@@ -544,12 +537,7 @@
# version of the transform
env.append_transform(ObservationNorm(in_keys=[out_key], standard_normal=True))

double_to_float_list.append(out_key)
env.append_transform(
DoubleToFloat(
in_keys=double_to_float_list, in_keys_inv=double_to_float_inv_list
)
)
env.append_transform(DoubleToFloat())

env.append_transform(StepCounter(max_frames_per_traj))
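A short sketch of the simplification above, hedged on torchrl 0.3 behavior: constructed without explicit key lists, ``DoubleToFloat`` converts the environment's double-precision entries to float32 on its own, so the manual ``double_to_float_list`` bookkeeping for DMControl is no longer needed. The task below is an arbitrary choice and requires ``dm_control`` to be installed.

# Sketch (assumed torchrl 0.3 behavior): no key lists required.
from torchrl.envs import TransformedEnv
from torchrl.envs.libs.dm_control import DMControlEnv
from torchrl.envs.transforms import DoubleToFloat

base_env = DMControlEnv("cheetah", "run")
env = TransformedEnv(base_env, DoubleToFloat())  # float64 entries become float32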

@@ -874,9 +862,6 @@ def make_ddpg_actor(
reset_at_each_iter=False,
split_trajs=False,
device=collector_device,
# device for execution
storing_device=collector_device,
# device where data will be stored and passed
exploration_type=ExplorationType.RANDOM,
)
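The collector arguments above can be restated in one place; this is a hedged sketch against the torchrl 0.3 ``SyncDataCollector`` signature, with ``make_env`` and ``actor`` standing in for the tutorial's environment factory and DDPG actor, and with illustrative frame counts. A single ``device`` argument is passed and the storing device is left to its default instead of being set separately.

# Sketch (assumed torchrl 0.3 API); ``make_env`` and ``actor`` are placeholders,
# and the frame counts are illustrative only.
from torchrl.collectors import SyncDataCollector
from torchrl.envs.utils import ExplorationType

collector = SyncDataCollector(
    make_env,
    actor,
    frames_per_batch=1000,
    total_frames=10_000,
    device=collector_device,
    exploration_type=ExplorationType.RANDOM,
)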

32 changes: 25 additions & 7 deletions advanced_source/pendulum.py
@@ -10,7 +10,7 @@
is an integral part of reinforcement learning and control engineering.

TorchRL provides a set of tools to do this in multiple contexts.
This tutorial demonstrates how to use PyTorch and TorchRL to code a pendulum
simulator from the ground up.
It is loosely inspired by the Pendulum-v1 implementation from the `OpenAI-Gym/Farama-Gymnasium
control library <https://github.com/Farama-Foundation/Gymnasium>`__.
@@ -49,9 +49,9 @@
# cover a broader range of features of the environment API in TorchRL.
#
# Modeling stateless environments gives users full control over the inputs and
# outputs of the simulator: one can reset an experiment at any stage or actively
# modify the dynamics from the outside. However, it assumes that we have some control
# over a task, which may not always be the case: solving a problem where we cannot
# control the current state is more challenging but has a much wider set of applications.
#
# Another advantage of stateless environments is that they can enable
@@ -73,14 +73,31 @@
# simulation graph.
# * Finally, we will train a simple policy to solve the system we implemented.
#

# sphinx_gallery_start_ignore
import warnings

warnings.filterwarnings("ignore")
from torch import multiprocessing

# TorchRL prefers the spawn method, which restricts creation of ``~torchrl.envs.ParallelEnv``
# to code run under the ``__main__`` guard. For ease of reading, this tutorial switches
# to fork, which is also the default start method in Google's Colaboratory.
try:
multiprocessing.set_start_method("fork")
except RuntimeError:
pass

# sphinx_gallery_end_ignore

from collections import defaultdict
from typing import Optional

import numpy as np
import torch
import tqdm
from tensordict import TensorDict, TensorDictBase
from tensordict.nn import TensorDictModule
from tensordict.tensordict import TensorDict, TensorDictBase
from torch import nn

from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
@@ -167,7 +184,7 @@
# of :meth:`~torchrl.envs.EnvBase.step` in the input ``tensordict`` to enforce
# input/output consistency.
#
# Typically, for stateful environments, this will look like this:
#
# .. code-block::
#
@@ -221,6 +238,7 @@
# needed as the state needs to be read from the environment.
#


def _step(tensordict):
th, thdot = tensordict["th"], tensordict["thdot"] # th := theta

@@ -896,7 +914,7 @@ def plot():
######################################################################
# Conclusion
# ----------
#
# In this tutorial, we have learned how to code a stateless environment from
# scratch. We touched the subjects of:
#
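To complement the conclusion, here is a minimal, self-contained sketch of the stateless pattern the tutorial builds up. The dynamics are toy numbers, not the tutorial's actual pendulum equations, and ``toy_step`` is a hypothetical name.

# Toy sketch of a stateless step: the state ("th", "thdot") is read from the
# input tensordict instead of from attributes of an environment object, so the
# caller fully controls (and can rewrite) the state between calls.
import torch
from tensordict import TensorDict


def toy_step(tensordict):
    th, thdot = tensordict["th"], tensordict["thdot"]
    new_thdot = thdot - 0.05 * th                    # toy dynamics only
    new_th = th + 0.05 * new_thdot
    reward = -(new_th**2 + 0.1 * new_thdot**2)
    return TensorDict(
        {
            "th": new_th,
            "thdot": new_thdot,
            "reward": reward,
            "done": torch.zeros_like(th, dtype=torch.bool),
        },
        batch_size=tensordict.batch_size,
    )


td = TensorDict({"th": torch.tensor(0.1), "thdot": torch.tensor(0.0)}, batch_size=[])
td_next = toy_step(td)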
39 changes: 32 additions & 7 deletions intermediate_source/dqn_with_rnn_tutorial.py
@@ -68,6 +68,22 @@
# -----
#

# sphinx_gallery_start_ignore
import warnings

warnings.filterwarnings("ignore")
from torch import multiprocessing

# TorchRL prefers the spawn method, which restricts creation of ``~torchrl.envs.ParallelEnv``
# to code run under the ``__main__`` guard. For ease of reading, this tutorial switches
# to fork, which is also the default start method in Google's Colaboratory.
try:
multiprocessing.set_start_method("fork")
except RuntimeError:
pass

# sphinx_gallery_end_ignore

import torch
import tqdm
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
@@ -88,10 +104,15 @@
TransformedEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import ConvNet, EGreedyWrapper, LSTMModule, MLP, QValueModule
from torchrl.modules import ConvNet, EGreedyModule, LSTMModule, MLP, QValueModule
from torchrl.objectives import DQNLoss, SoftUpdate

device = torch.device(0) if torch.cuda.device_count() else torch.device("cpu")
is_fork = multiprocessing.get_start_method() == "fork"
device = (
torch.device(0)
if torch.cuda.is_available() and not is_fork
else torch.device("cpu")
)

######################################################################
# Environment
@@ -293,11 +314,15 @@
# DQN being a deterministic algorithm, exploration is a crucial part of it.
# We'll be using an :math:`\epsilon`-greedy policy with an epsilon of 0.2 decaying
# progressively to 0.
# This decay is achieved via a call to :meth:`~torchrl.modules.EGreedyWrapper.step`
# This decay is achieved via a call to :meth:`~torchrl.modules.EGreedyModule.step`
# (see training loop below).
#
stoch_policy = EGreedyWrapper(
stoch_policy, annealing_num_steps=1_000_000, spec=env.action_spec, eps_init=0.2
exploration_module = EGreedyModule(
annealing_num_steps=1_000_000, spec=env.action_spec, eps_init=0.2
)
stoch_policy = Seq(
stoch_policy,
exploration_module,
)

######################################################################
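For readers migrating their own code, a condensed, hedged sketch of the change above (torchrl 0.3 API assumed): the deprecated ``EGreedyWrapper`` wrapped the policy directly and was stepped through it, whereas ``EGreedyModule`` is appended to a ``TensorDictSequential`` and stepped explicitly; epsilon then anneals linearly from ``eps_init`` toward ``eps_end`` over ``annealing_num_steps`` stepped frames. The numbers below are toy values.

# Sketch (assumed torchrl 0.3 API); mirrors the diff above with toy numbers.
greedy = EGreedyModule(
    spec=env.action_spec, eps_init=0.2, eps_end=0.0, annealing_num_steps=1_000
)
policy_with_exploration = Seq(stoch_policy, greedy)
for _ in range(20):
    greedy.step(50)  # 20 batches of 50 frames consume the whole annealing budget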
@@ -362,7 +387,7 @@
# For the sake of efficiency, we're only running a few thousand iterations
# here. In a real setting, the total number of frames should be set to 1M.
#
collector = SyncDataCollector(env, stoch_policy, frames_per_batch=50, total_frames=200)
collector = SyncDataCollector(env, stoch_policy, frames_per_batch=50, total_frames=200, device=device)
rb = TensorDictReplayBuffer(
storage=LazyMemmapStorage(20_000), batch_size=4, prefetch=10
)
@@ -403,7 +428,7 @@
pbar.set_description(
f"steps: {longest}, loss_val: {loss_vals['loss'].item(): 4.4f}, action_spread: {data['action'].sum(0)}"
)
stoch_policy.step(data.numel())
exploration_module.step(data.numel())
updater.step()

with set_exploration_type(ExplorationType.MODE), torch.no_grad():
4 changes: 2 additions & 2 deletions intermediate_source/mario_rl_tutorial.py
@@ -32,8 +32,8 @@
#
# %%bash
# pip install gym-super-mario-bros==7.4.0
# pip install tensordict==0.2.0
# pip install torchrl==0.2.0
# pip install tensordict==0.3.0
# pip install torchrl==0.3.0
#

import torch