diff --git a/generate.py b/generate.py index 19e1de26..e833e94f 100644 --- a/generate.py +++ b/generate.py @@ -310,7 +310,7 @@ def main( decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) # Uncomment to squeeze more perf out of prefill - if args.compile_prefill: + if compile_prefill: prefill = torch.compile(prefill, fullgraph=True, dynamic=True)