Skip to content

Commit 7c48bff

Browse files
author
Svetlana Karslioglu
authored
Update intermediate_source/coding_ddpg.py
1 parent b575982 commit 7c48bff

File tree

1 file changed

+26
-9
lines changed

1 file changed

+26
-9
lines changed

intermediate_source/coding_ddpg.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -580,17 +580,34 @@ def make_transformed_env(
580580
#
581581

582582

583-
def env_constructor(
583+
def parallel_env_constructor(
584+
env_per_collector,
584585
transform_state_dict,
585586
):
586-
def make_t_env():
587-
env = make_transformed_env(make_env())
588-
env.transform[2].init_stats(3)
589-
env.transform[2].loc.copy_(transform_state_dict["loc"])
590-
env.transform[2].scale.copy_(transform_state_dict["scale"])
591-
return env
592-
env_creator = EnvCreator(make_t_env)
593-
return env_creator
587+
if env_per_collector == 1:
588+
589+
def make_t_env():
590+
env = make_transformed_env(make_env())
591+
env.transform[2].init_stats(3)
592+
env.transform[2].loc.copy_(transform_state_dict["loc"])
593+
env.transform[2].scale.copy_(transform_state_dict["scale"])
594+
return env
595+
596+
env_creator = EnvCreator(make_t_env)
597+
return env_creator
598+
599+
parallel_env = ParallelEnv(
600+
num_workers=env_per_collector,
601+
create_env_fn=EnvCreator(lambda: make_env()),
602+
create_env_kwargs=None,
603+
pin_memory=False,
604+
)
605+
env = make_transformed_env(parallel_env)
606+
# we call `init_stats` for a limited number of steps, just to instantiate
607+
# the lazy buffers.
608+
env.transform[2].init_stats(3, cat_dim=1, reduce_dim=[0, 1])
609+
env.transform[2].load_state_dict(transform_state_dict)
610+
return env
594611

595612

596613
# The backend can be ``gym`` or ``dm_control``

0 commit comments

Comments
 (0)