Skip to content

Commit d756188

Browse files
Svetlana Karsliogluvmoens
Svetlana Karslioglu
andauthored
Apply suggestions from code review
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
1 parent 7ba9abd commit d756188

File tree

1 file changed

+13
-30
lines changed

1 file changed

+13
-30
lines changed

intermediate_source/coding_ddpg.py

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
device = (
7979
torch.device("cpu") if torch.cuda.device_count() == 0 else torch.device("cuda:0")
8080
)
81+
collector_device = torch.device("cpu")
8182

8283
###############################################################################
8384
# TorchRL :class:`~torchrl.objectives.LossModule`
@@ -449,7 +450,6 @@ def make_env(from_pixels=False):
449450
raise NotImplementedError
450451

451452
env_kwargs = {
452-
"device": device,
453453
"from_pixels": from_pixels,
454454
"pixels_only": from_pixels,
455455
"frame_skip": 2,
@@ -546,7 +546,7 @@ def make_transformed_env(
546546

547547
env.append_transform(StepCounter(max_frames_per_traj))
548548

549-
# We need a marker for the start of trajectories for our Ornstein-Uhlenbeck
549+
# We need a marker for the start of trajectories for our Ornstein-Uhlenbeck (OU)
550550
# exploration:
551551
env.append_transform(InitTracker())
552552

@@ -580,34 +580,17 @@ def make_transformed_env(
580580
#
581581

582582

583-
def parallel_env_constructor(
584-
env_per_collector,
583+
def env_constructor(
585584
transform_state_dict,
586585
):
587-
if env_per_collector == 1:
588-
589-
def make_t_env():
590-
env = make_transformed_env(make_env())
591-
env.transform[2].init_stats(3)
592-
env.transform[2].loc.copy_(transform_state_dict["loc"])
593-
env.transform[2].scale.copy_(transform_state_dict["scale"])
594-
return env
595-
596-
env_creator = EnvCreator(make_t_env)
597-
return env_creator
598-
599-
parallel_env = ParallelEnv(
600-
num_workers=env_per_collector,
601-
create_env_fn=EnvCreator(lambda: make_env()),
602-
create_env_kwargs=None,
603-
pin_memory=False,
604-
)
605-
env = make_transformed_env(parallel_env)
606-
# we call `init_stats` for a limited number of steps, just to instantiate
607-
# the lazy buffers.
608-
env.transform[2].init_stats(3, cat_dim=1, reduce_dim=[0, 1])
609-
env.transform[2].load_state_dict(transform_state_dict)
610-
return env
586+
def make_t_env():
587+
env = make_transformed_env(make_env())
588+
env.transform[2].init_stats(3)
589+
env.transform[2].loc.copy_(transform_state_dict["loc"])
590+
env.transform[2].scale.copy_(transform_state_dict["scale"])
591+
return env
592+
env_creator = EnvCreator(make_t_env)
593+
return env_creator
611594

612595

613596
# The backend can be ``gym`` or ``dm_control``
@@ -868,9 +851,9 @@ def make_ddpg_actor(
868851
init_random_frames=init_random_frames,
869852
reset_at_each_iter=False,
870853
split_trajs=False,
871-
device=device,
854+
device=collector_device,
872855
# device for execution
873-
storing_device=device,
856+
storing_device=collector_device,
874857
# device where data will be stored and passed
875858
update_at_each_batch=False,
876859
exploration_type=ExplorationType.RANDOM,

0 commit comments

Comments
 (0)