|
78 | 78 | device = (
|
79 | 79 | torch.device("cpu") if torch.cuda.device_count() == 0 else torch.device("cuda:0")
|
80 | 80 | )
|
| 81 | +collector_device = torch.device("cpu") |
81 | 82 |
|
82 | 83 | ###############################################################################
|
83 | 84 | # TorchRL :class:`~torchrl.objectives.LossModule`
|
@@ -449,7 +450,6 @@ def make_env(from_pixels=False):
|
449 | 450 | raise NotImplementedError
|
450 | 451 |
|
451 | 452 | env_kwargs = {
|
452 |
| - "device": device, |
453 | 453 | "from_pixels": from_pixels,
|
454 | 454 | "pixels_only": from_pixels,
|
455 | 455 | "frame_skip": 2,
|
@@ -546,7 +546,7 @@ def make_transformed_env(
|
546 | 546 |
|
547 | 547 | env.append_transform(StepCounter(max_frames_per_traj))
|
548 | 548 |
|
549 |
| - # We need a marker for the start of trajectories for our Ornstein-Uhlenbeck |
| 549 | + # We need a marker for the start of trajectories for our Ornstein-Uhlenbeck (OU) |
550 | 550 | # exploration:
|
551 | 551 | env.append_transform(InitTracker())
|
552 | 552 |
|
@@ -580,34 +580,17 @@ def make_transformed_env(
|
580 | 580 | #
|
581 | 581 |
|
582 | 582 |
|
583 |
| -def parallel_env_constructor( |
584 |
| - env_per_collector, |
| 583 | +def env_constructor( |
585 | 584 | transform_state_dict,
|
586 | 585 | ):
|
587 |
| - if env_per_collector == 1: |
588 |
| - |
589 |
| - def make_t_env(): |
590 |
| - env = make_transformed_env(make_env()) |
591 |
| - env.transform[2].init_stats(3) |
592 |
| - env.transform[2].loc.copy_(transform_state_dict["loc"]) |
593 |
| - env.transform[2].scale.copy_(transform_state_dict["scale"]) |
594 |
| - return env |
595 |
| - |
596 |
| - env_creator = EnvCreator(make_t_env) |
597 |
| - return env_creator |
598 |
| - |
599 |
| - parallel_env = ParallelEnv( |
600 |
| - num_workers=env_per_collector, |
601 |
| - create_env_fn=EnvCreator(lambda: make_env()), |
602 |
| - create_env_kwargs=None, |
603 |
| - pin_memory=False, |
604 |
| - ) |
605 |
| - env = make_transformed_env(parallel_env) |
606 |
| - # we call `init_stats` for a limited number of steps, just to instantiate |
607 |
| - # the lazy buffers. |
608 |
| - env.transform[2].init_stats(3, cat_dim=1, reduce_dim=[0, 1]) |
609 |
| - env.transform[2].load_state_dict(transform_state_dict) |
610 |
| - return env |
| 586 | + def make_t_env(): |
| 587 | + env = make_transformed_env(make_env()) |
| 588 | + env.transform[2].init_stats(3) |
| 589 | + env.transform[2].loc.copy_(transform_state_dict["loc"]) |
| 590 | + env.transform[2].scale.copy_(transform_state_dict["scale"]) |
| 591 | + return env |
| 592 | + env_creator = EnvCreator(make_t_env) |
| 593 | + return env_creator |
611 | 594 |
|
612 | 595 |
|
613 | 596 | # The backend can be ``gym`` or ``dm_control``
|
@@ -868,9 +851,9 @@ def make_ddpg_actor(
|
868 | 851 | init_random_frames=init_random_frames,
|
869 | 852 | reset_at_each_iter=False,
|
870 | 853 | split_trajs=False,
|
871 |
| - device=device, |
| 854 | + device=collector_device, |
872 | 855 | # device for execution
|
873 |
| - storing_device=device, |
| 856 | + storing_device=collector_device, |
874 | 857 | # device where data will be stored and passed
|
875 | 858 | update_at_each_batch=False,
|
876 | 859 | exploration_type=ExplorationType.RANDOM,
|
|
0 commit comments