Commit 8e581e2

wrap in fused kernel scope
1 parent 6c67d6d

File tree

1 file changed: +9 -3 lines changed


intermediate_source/scaled_dot_product_attention_tutorial.py

Lines changed: 9 additions & 3 deletions
@@ -228,12 +228,18 @@ def generate_rand_batch(
         seq_len_list,
     )
 
-# Currently the fastpaths don't support NestedTensor for training
 random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
 random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)
+
+# Currently the fused implementations don't support NestedTensor for training
 model.eval()
-print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
-print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
+
+with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
+    try:
+        print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
+        print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
+    except RuntimeError:
+        print("FlashAttention is not supported. See warnings for reasons.")
 
 
 ######################################################################
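
For context, the new with-block relies on sdp_kernel, SDPBackend, and a backend_map helper that the tutorial defines earlier in the file. A minimal sketch of what that setup likely looks like; the exact backend_map dictionary here is an assumption for illustration, not part of this commit:

# Sketch of the helpers the new with-block assumes are already defined.
# backend_map below is an assumed mapping, not taken from this commit.
from torch.backends.cuda import sdp_kernel, SDPBackend

# Each entry enables exactly one scaled_dot_product_attention backend.
backend_map = {
    SDPBackend.MATH: {
        "enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
    SDPBackend.FLASH_ATTENTION: {
        "enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {
        "enable_math": False, "enable_flash": False, "enable_mem_efficient": True},
}

Unpacking one entry into sdp_kernel restricts attention to that single backend for everything inside the with-block, which is why an unsupported configuration surfaces as a RuntimeError that the new try/except catches.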
