Skip to content

Commit 1015af6

Browse files
committed
more version handling for v.25 and v.26
1 parent 1a878bf commit 1015af6

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

intermediate_source/reinforcement_q_learning.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,10 @@ def forward(self, x):
261261
# Get number of actions from gym action space
262262
n_actions = env.action_space.n
263263
# Get the number of state observations
264-
state, _ = env.reset()
264+
if gym.__version__[:4] == '0.26':
265+
state, _ = env.reset()
266+
elif gym.__version__[:4] == '0.25':
267+
state, _ = env.reset(return_info=True)
265268
n_observations = len(state)
266269

267270
policy_net = DQN(n_observations, n_actions).to(device)
@@ -401,7 +404,10 @@ def optimize_model():
401404

402405
for i_episode in range(num_episodes):
403406
# Initialize the environment and get it's state
404-
state, _ = env.reset()
407+
if gym.__version__[:4] == '0.26':
408+
state, _ = env.reset()
409+
elif gym.__version__[:4] == '0.25':
410+
state, _ = env.reset(return_info=True)
405411
state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
406412
for t in count():
407413
action = select_action(state)

0 commit comments

Comments
 (0)