Skip to content

Commit 967d22b

Browse files
vaibhawvipulsoumith
andcommitted
torch.uint8 is now deprecated, moving to torch.bool (#812)
Co-authored-by: Soumith Chintala <soumith@gmail.com>
1 parent 8244bff commit 967d22b

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

intermediate_source/reinforcement_q_learning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ def optimize_model():
408408
# Compute a mask of non-final states and concatenate the batch elements
409409
# (a final state would've been the one after which simulation ended)
410410
non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
411-
batch.next_state)), device=device, dtype=torch.uint8)
411+
batch.next_state)), device=device, dtype=torch.bool)
412412
non_final_next_states = torch.cat([s for s in batch.next_state
413413
if s is not None])
414414
state_batch = torch.cat(batch.state)

0 commit comments

Comments
 (0)