
Commit 39d3e6e

Fix TorchRL scripts for PyTorch 2.6 and TorchRL >= 0.6 (#3199)
Fixes #3195. Updates the TorchRL tutorial scripts for the PyTorch 2.6 release.
1 parent 1150bb5 commit 39d3e6e

File tree

5 files changed (+6 lines, -8 lines)

.jenkins/validate_tutorials_built.py

Lines changed: 0 additions & 2 deletions
@@ -53,8 +53,6 @@
     "intermediate_source/tensorboard_profiler_tutorial", # reenable after 2.0 release.
     "intermediate_source/torch_export_tutorial", # reenable after 2940 is fixed.
     "prototype_source/gpu_quantization_torchao_tutorial", # enable when 3194
-    "advanced_source/pendulum", # enable when 3195 is fixed
-    "intermediate_source/reinforcement_ppo" # enable when 3195 is fixed
 ]

 def tutorial_source_dirs() -> List[Path]:

advanced_source/coding_ddpg.py

Lines changed: 1 addition & 1 deletion
@@ -893,7 +893,7 @@ def make_recorder(actor_model_explore, transform_state_dict, record_interval):
         record_frames=1000,
         policy_exploration=actor_model_explore,
         environment=environment,
-        exploration_type=ExplorationType.MEAN,
+        exploration_type=ExplorationType.DETERMINISTIC,
         record_interval=record_interval,
     )
     return recorder_obj

advanced_source/pendulum.py

Lines changed: 1 addition & 1 deletion
@@ -604,7 +604,7 @@ def __init__(self, td_params=None, seed=None, device="cpu"):
     env,
     # ``Unsqueeze`` the observations that we will concatenate
     UnsqueezeTransform(
-        unsqueeze_dim=-1,
+        dim=-1,
         in_keys=["th", "thdot"],
         in_keys_inv=["th", "thdot"],
     ),
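For context: the keyword that names the target dimension of UnsqueezeTransform is now dim rather than unsqueeze_dim. A minimal standalone sketch of the renamed argument, assuming TorchRL >= 0.6; the tensordict contents are illustrative only:

import torch
from tensordict import TensorDict
from torchrl.envs.transforms import UnsqueezeTransform

# `dim` replaces the older `unsqueeze_dim` keyword.
t = UnsqueezeTransform(dim=-1, in_keys=["th", "thdot"], in_keys_inv=["th", "thdot"])

td = TensorDict({"th": torch.zeros(3), "thdot": torch.zeros(3)}, batch_size=[])
print(t(td)["th"].shape)  # torch.Size([3, 1])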

intermediate_source/dqn_with_rnn_tutorial.py

Lines changed: 1 addition & 1 deletion
@@ -433,7 +433,7 @@
     exploration_module.step(data.numel())
     updater.step()

-    with set_exploration_type(ExplorationType.MODE), torch.no_grad():
+    with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
         rollout = env.rollout(10000, stoch_policy)
         traj_lens.append(rollout.get(("next", "step_count")).max().item())

intermediate_source/reinforcement_ppo.py

Lines changed: 3 additions & 3 deletions
@@ -419,8 +419,8 @@
     in_keys=["loc", "scale"],
     distribution_class=TanhNormal,
     distribution_kwargs={
-        "min": env.action_spec.space.low,
-        "max": env.action_spec.space.high,
+        "low": env.action_spec.space.low,
+        "high": env.action_spec.space.high,
     },
     return_log_prob=True,
     # we'll need the log-prob for the numerator of the importance weights
@@ -639,7 +639,7 @@
 # number of steps (1000, which is our ``env`` horizon).
 # The ``rollout`` method of the ``env`` can take a policy as argument:
 # it will then execute this policy at each step.
-with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
+with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
     # execute a rollout with the trained policy
     eval_rollout = env.rollout(1000, policy_module)
     logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
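For context: the first hunk tracks the renamed bounds of TanhNormal, whose keyword arguments are now low and high instead of min and max. A minimal standalone sketch of the renamed arguments, assuming TorchRL >= 0.6; the tensor values are illustrative and unrelated to the tutorial's action spec:

import torch
from torchrl.modules import TanhNormal

# `low` / `high` replace the older `min` / `max` bounds keywords.
dist = TanhNormal(
    loc=torch.zeros(3),
    scale=torch.ones(3),
    low=-torch.ones(3),
    high=torch.ones(3),
)
print(dist.sample().shape)  # samples are squashed into [low, high]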
