Skip to content

Commit c58044f

Browse files
committed
[DEBUG] ppo compile
1 parent b538c66 commit c58044f

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

sota-implementations/ppo/ppo_atari.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,12 +168,18 @@ def update(batch, num_network_updates):
168168
num_network_updates = num_network_updates + 1
169169
# Get a data batch
170170
batch = batch.to(device, non_blocking=True)
171+
def forward(batch, num_network_updates):
172+
173+
# Forward pass PPO loss
174+
loss = loss_module(batch)
175+
loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
176+
return loss, loss_sum
177+
178+
loss, loss_sum = torch.compile(forward, backend="inductor", mode="reduce-overhead")(batch, num_network_updates)
171179

172-
# Forward pass PPO loss
173-
loss = loss_module(batch)
174-
loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
175180
# Backward pass
176181
loss_sum.backward()
182+
177183
torch.nn.utils.clip_grad_norm_(
178184
loss_module.parameters(), max_norm=cfg_optim_max_grad_norm
179185
)

torchrl/_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -985,9 +985,10 @@ def count_and_compile(*model_args, **model_kwargs):
985985
nonlocal count
986986
nonlocal compiled_model
987987
count += 1
988-
if count == warmup:
989-
compiled_model = torch.compile(model, *args, **kwargs)
990-
return compiled_model(*model_args, **model_kwargs)
988+
#if count == warmup:
989+
# compiled_model = torch.compile(model, fullgraph=True, backend="inductor")
990+
out = compiled_model(*model_args, **model_kwargs)
991+
return out
991992

992993
return count_and_compile
993994

0 commit comments

Comments
 (0)