From edcc7e9e08cf27e64e0471ea7a7eeaa8df9cbf97 Mon Sep 17 00:00:00 2001 From: saurabhkthakur Date: Sat, 4 Nov 2023 22:49:15 +0530 Subject: [PATCH] updated the code to avoid deepcopy() --- intermediate_source/mario_rl_tutorial.py | 30 ++++++++++++++---------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/intermediate_source/mario_rl_tutorial.py b/intermediate_source/mario_rl_tutorial.py index 74a08a47b37..3f6380f9689 100755 --- a/intermediate_source/mario_rl_tutorial.py +++ b/intermediate_source/mario_rl_tutorial.py @@ -424,7 +424,23 @@ def __init__(self, input_dim, output_dim): if w != 84: raise ValueError(f"Expecting input width: 84, got: {w}") - self.online = nn.Sequential( + self.online = self._build_cnn(input_dim, output_dim) + self.target = self._build_cnn(input_dim, output_dim) + + # Q_target parameters are frozen. + for p in self.target.parameters(): + p.requires_grad = False + + def forward(self, input, model): + if model == "online": + return self.online(input) + elif model == "target": + return self.target(input) + + def _build_cnn(self, input_dim, output_dim): + c, _, _ = input_dim + + cnn = nn.Sequential( nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4), nn.ReLU(), nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2), @@ -437,17 +453,7 @@ def __init__(self, input_dim, output_dim): nn.Linear(512, output_dim), ) - self.target = copy.deepcopy(self.online) - - # Q_target parameters are frozen. - for p in self.target.parameters(): - p.requires_grad = False - - def forward(self, input, model): - if model == "online": - return self.online(input) - elif model == "target": - return self.target(input) + return cnn ######################################################################