@@ -260,7 +260,7 @@ def forward(self,
260
260
261
261
###############################################################################
262
262
# Utilities
263
- # =========
263
+ # ~~~~~~~~~
264
264
# In this section, we include a utility to generate semi-realistic data using
265
265
# Zipf distribution for sentence lengths. This is used to generate the nested
266
266
# query, key and value tensors. We also include a benchmark utility.
@@ -400,11 +400,13 @@ def benchmark(func, *args, **kwargs):
400
400
######################################################################################
401
401
# For reference some sample outputs on A100:
402
402
#
403
- # padded_time=0.03454, padded_peak_memory=4.14 GB
404
- # nested_time=0.00612, nested_peak_memory=0.76 GB
405
- # Difference between vanilla and nested result 0.0
406
- # Nested speedup: 5.65
407
- # Nested peak memory reduction 3.39 GB
403
+ # ..code::
404
+ #
405
+ # padded_time=0.03454, padded_peak_memory=4.14 GB
406
+ # nested_time=0.00612, nested_peak_memory=0.76 GB
407
+ # Difference between vanilla and nested result 0.0
408
+ # Nested speedup: 5.65
409
+ # Nested peak memory reduction 3.39 GB
408
410
#
409
411
# We can also see the same for backward pass
410
412
@@ -428,14 +430,16 @@ def benchmark(func, *args, **kwargs):
428
430
##################################################################################
429
431
# Sample outputs on A100:
430
432
#
431
- # ``padded_bw_time``=2.09337, ``padded_bw_peak_mem``=5.10 GB
432
- # ``nested_bw_time``=0.01452, ``nested_bw_peak_mem``=3.24 GB
433
- # Nested backward speedup: 144.13
434
- # Nested backward peak memory reduction 1.86 GB
435
- # Difference in ``out_proj.weight.grad`` 0.000244140625
436
- # Difference in ``packed_proj.weight.grad`` 0.001556396484375
437
- # Difference in ``out_proj.bias.grad`` 0.0
438
- # Difference in ``packed_proj.bias.grad`` 0.001953125
433
+ # ..code::
434
+ #
435
+ # ``padded_bw_time``=2.09337, ``padded_bw_peak_mem``=5.10 GB
436
+ # ``nested_bw_time``=0.01452, ``nested_bw_peak_mem``=3.24 GB
437
+ # Nested backward speedup: 144.13
438
+ # Nested backward peak memory reduction 1.86 GB
439
+ # Difference in ``out_proj.weight.grad`` 0.000244140625
440
+ # Difference in ``packed_proj.weight.grad`` 0.001556396484375
441
+ # Difference in ``out_proj.bias.grad`` 0.0
442
+ # Difference in ``packed_proj.bias.grad`` 0.001953125
439
443
#
440
444
441
445
##################################################################################
0 commit comments