diff --git a/intermediate_source/mario_rl_tutorial.py b/intermediate_source/mario_rl_tutorial.py index 74a08a47b37..3f6380f9689 100755 --- a/intermediate_source/mario_rl_tutorial.py +++ b/intermediate_source/mario_rl_tutorial.py @@ -424,7 +424,23 @@ def __init__(self, input_dim, output_dim): if w != 84: raise ValueError(f"Expecting input width: 84, got: {w}") - self.online = nn.Sequential( + self.online = self._build_cnn(input_dim, output_dim) + self.target = self._build_cnn(input_dim, output_dim) + + # Q_target parameters are frozen. + for p in self.target.parameters(): + p.requires_grad = False + + def forward(self, input, model): + if model == "online": + return self.online(input) + elif model == "target": + return self.target(input) + + def _build_cnn(self, input_dim, output_dim): + c, _, _ = input_dim + + cnn = nn.Sequential( nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4), nn.ReLU(), nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2), @@ -437,17 +453,7 @@ def __init__(self, input_dim, output_dim): nn.Linear(512, output_dim), ) - self.target = copy.deepcopy(self.online) - - # Q_target parameters are frozen. - for p in self.target.parameters(): - p.requires_grad = False - - def forward(self, input, model): - if model == "online": - return self.online(input) - elif model == "target": - return self.target(input) + return cnn ######################################################################