Skip to content

Commit 72b6df1

Browse files
change the get_screen method so that it does not fill the gpu memory (#415)
Co-authored-by: holly1238 <77758406+holly1238@users.noreply.github.com>
1 parent e18d233 commit 72b6df1

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

intermediate_source/reinforcement_q_learning.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def conv2d_size_out(size, kernel_size = 5, stride = 2):
224224
# Called with either one element to determine next action, or a batch
225225
# during optimization. Returns tensor([[left0exp,right0exp]...]).
226226
def forward(self, x):
227+
x = x.to(device)
227228
x = F.relu(self.bn1(self.conv1(x)))
228229
x = F.relu(self.bn2(self.conv2(x)))
229230
x = F.relu(self.bn3(self.conv3(x)))
@@ -273,7 +274,7 @@ def get_screen():
273274
screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
274275
screen = torch.from_numpy(screen)
275276
# Resize, and add a batch dimension (BCHW)
276-
return resize(screen).unsqueeze(0).to(device)
277+
return resize(screen).unsqueeze(0)
277278

278279

279280
env.reset()

0 commit comments

Comments
 (0)