@@ -580,17 +580,34 @@ def make_transformed_env(
580
580
#
581
581
582
582
583
- def env_constructor (
583
+ def parallel_env_constructor (
584
+ env_per_collector ,
584
585
transform_state_dict ,
585
586
):
586
- def make_t_env ():
587
- env = make_transformed_env (make_env ())
588
- env .transform [2 ].init_stats (3 )
589
- env .transform [2 ].loc .copy_ (transform_state_dict ["loc" ])
590
- env .transform [2 ].scale .copy_ (transform_state_dict ["scale" ])
591
- return env
592
- env_creator = EnvCreator (make_t_env )
593
- return env_creator
587
+ if env_per_collector == 1 :
588
+
589
+ def make_t_env ():
590
+ env = make_transformed_env (make_env ())
591
+ env .transform [2 ].init_stats (3 )
592
+ env .transform [2 ].loc .copy_ (transform_state_dict ["loc" ])
593
+ env .transform [2 ].scale .copy_ (transform_state_dict ["scale" ])
594
+ return env
595
+
596
+ env_creator = EnvCreator (make_t_env )
597
+ return env_creator
598
+
599
+ parallel_env = ParallelEnv (
600
+ num_workers = env_per_collector ,
601
+ create_env_fn = EnvCreator (lambda : make_env ()),
602
+ create_env_kwargs = None ,
603
+ pin_memory = False ,
604
+ )
605
+ env = make_transformed_env (parallel_env )
606
+ # we call `init_stats` for a limited number of steps, just to instantiate
607
+ # the lazy buffers.
608
+ env .transform [2 ].init_stats (3 , cat_dim = 1 , reduce_dim = [0 , 1 ])
609
+ env .transform [2 ].load_state_dict (transform_state_dict )
610
+ return env
594
611
595
612
596
613
# The backend can be ``gym`` or ``dm_control``
0 commit comments