diff --git a/intermediate_source/reinforcement_ppo.py b/intermediate_source/reinforcement_ppo.py index 6b0e8522f8d..8dee73969db 100644 --- a/intermediate_source/reinforcement_ppo.py +++ b/intermediate_source/reinforcement_ppo.py @@ -604,7 +604,7 @@ data_view = tensordict_data.reshape(-1) replay_buffer.extend(data_view.cpu()) for _ in range(frames_per_batch // sub_batch_size): - subdata, *_ = replay_buffer.sample(sub_batch_size) + subdata = replay_buffer.sample(sub_batch_size) loss_vals = loss_module(subdata.to(device)) loss_value = ( loss_vals["loss_objective"]