diff --git a/intermediate_source/reinforcement_q_learning.py b/intermediate_source/reinforcement_q_learning.py index 5540405ee3a..d196de61fda 100644 --- a/intermediate_source/reinforcement_q_learning.py +++ b/intermediate_source/reinforcement_q_learning.py @@ -408,7 +408,7 @@ def optimize_model(): # Compute a mask of non-final states and concatenate the batch elements # (a final state would've been the one after which simulation ended) non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, - batch.next_state)), device=device, dtype=torch.uint8) + batch.next_state)), device=device, dtype=torch.bool) non_final_next_states = torch.cat([s for s in batch.next_state if s is not None]) state_batch = torch.cat(batch.state)