Skip to content

Commit 6f83877

Browse files
authored
Improve Neural Tangent Kernels tutorial (#2476)
- Allow it to work on CPU - Disable TensorFloat on Ampere+ GPU to meet accuracy expectations - Add TensorFloat to dictionary of known words
1 parent 3eef691 commit 6f83877

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

en-wordlist.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ TPU
154154
TensorBoard
155155
TensorBoards
156156
TensorDict
157+
TensorFloat
157158
TextVQA
158159
Tokenization
159160
TorchDynamo

intermediate_source/neural_tangent_kernels.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import torch
2525
import torch.nn as nn
2626
from torch.func import functional_call, vmap, vjp, jvp, jacrev
27-
device = 'cuda'
27+
device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu'
2828

2929
class CNN(nn.Module):
3030
def __init__(self):
@@ -224,8 +224,11 @@ def get_ntk_slice(vec):
224224
if compute == 'diagonal':
225225
return torch.einsum('NMKK->NMK', result)
226226

227-
result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)
228-
result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)
227+
# Disable TensorFloat-32 for convolutions on Ampere+ GPUs to sacrifice performance in favor of accuracy
228+
with torch.backends.cudnn.flags(allow_tf32=False):
229+
result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)
230+
result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)
231+
229232
assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)
230233

231234
######################################################################

0 commit comments

Comments
 (0)