Skip to content

Commit c535b15

Browse files
committed
change the get_screen method so that it does not fill the gpu memory
1 parent b38343e commit c535b15

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

intermediate_source/reinforcement_q_learning.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ def conv2d_size_out(size, kernel_size = 5, stride = 2):
229229
# Called with either one element to determine next action, or a batch
230230
# during optimization. Returns tensor([[left0exp,right0exp]...]).
231231
def forward(self, x):
232+
x = x.to(device)
232233
x = F.relu(self.bn1(self.conv1(x)))
233234
x = F.relu(self.bn2(self.conv2(x)))
234235
x = F.relu(self.bn3(self.conv3(x)))
@@ -278,7 +279,7 @@ def get_screen():
278279
screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
279280
screen = torch.from_numpy(screen)
280281
# Resize, and add a batch dimension (BCHW)
281-
return resize(screen).unsqueeze(0).to(device)
282+
return resize(screen).unsqueeze(0)
282283

283284

284285
env.reset()

0 commit comments

Comments
 (0)