Commit bc65968

Fix nested style
Signed-off-by: Onur Berk Töre <onurberk_t@hotmail.com>
1 parent 2a34c3c commit bc65968

prototype_source/nestedtensor.py

Lines changed: 11 additions & 12 deletions
@@ -25,7 +25,6 @@
 import torch.nn.functional as F
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-torch.set_default_device(device)
 
 ######################################################################
 # NestedTensor Initialization
@@ -36,7 +35,7 @@
 # From the Python frontend, a nestedtensor can be created from a list of tensors.
 # We denote nt[i] as the ith tensor component of a nestedtensor.
 nt = torch.nested.nested_tensor([torch.arange(12).reshape(
-    2, 6), torch.arange(18).reshape(3, 6)], dtype=torch.float)
+    2, 6), torch.arange(18).reshape(3, 6)], dtype=torch.float, device=device)
 print(f"{nt=}")
 
 ######################################################################
@@ -112,7 +111,7 @@
 # Applying the operation on a nestedtensor is equivalent to
 # applying the operation to the underlying tensor components,
 # with the result being a nestedtensor as well.
-nt_mm = torch.nested.nested_tensor([torch.randn((2, 3, 4)), torch.randn((2, 3, 5))])
+nt_mm = torch.nested.nested_tensor([torch.randn((2, 3, 4)), torch.randn((2, 3, 5))], device=device)
 nt3 = torch.matmul(nt_transposed, nt_mm)
 print(f"Result of Matmul:\n {nt3}")
 
@@ -319,7 +318,7 @@ def mha_padded(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, nhea
 
     # Have to manipulate masks in order to apply them to the attention weights
     key_padding_mask = attn_mask_q.view(N, 1, 1, L_t).expand(-1, nheads, -1, -1).reshape(N*nheads, 1, L_t).to(device=device)
-    attn_mask = torch.zeros(key_padding_mask.shape, dtype=torch.float32)
+    attn_mask = torch.zeros(key_padding_mask.shape, device=device, dtype=torch.float32)
     attn_mask = attn_mask.masked_fill_(key_padding_mask, float("-inf"))
 
     # Zero out the attention weights where the mask is True by adding -inf prior to softmax
@@ -385,10 +384,10 @@ def zipf_sentence_lengths(alpha: float, batch_size: int) -> np.ndarray:
 # create inputs
 
 # create parameters
-W_q, b_q = torch.randn((E_total, E_q)), torch.randn(E_total)
-W_k, b_k = torch.randn((E_total, E_k)), torch.randn(E_total)
-W_v, b_v = torch.randn((E_total, E_v)), torch.randn(E_total)
-W_out, b_out = torch.randn((E_out, E_total)), torch.randn(E_out)
+W_q, b_q = torch.randn((E_total, E_q), device=device), torch.randn(E_total, device=device)
+W_k, b_k = torch.randn((E_total, E_k), device=device), torch.randn(E_total, device=device)
+W_v, b_v = torch.randn((E_total, E_v), device=device), torch.randn(E_total, device=device)
+W_out, b_out = torch.randn((E_out, E_total), device=device), torch.randn(E_out, device=device)
 
 # create nested input
 queries = []
@@ -397,9 +396,9 @@ def zipf_sentence_lengths(alpha: float, batch_size: int) -> np.ndarray:
 for i in range(N):
     l = sentence_lengths[i]
     s = l
-    queries.append(torch.randn((l, E_q)))
-    keys   .append(torch.randn((s, E_k)))
-    values .append(torch.randn((s, E_v)))
+    queries.append(torch.randn((l, E_q), device=device))
+    keys   .append(torch.randn((s, E_k), device=device))
+    values .append(torch.randn((s, E_v), device=device))
 query = torch.nested.nested_tensor(queries)
 key = torch.nested.nested_tensor(keys)
 value = torch.nested.nested_tensor(values)
@@ -454,7 +453,7 @@ def zipf_sentence_lengths(alpha: float, batch_size: int) -> np.ndarray:
 
 # embeddings are assumed to be the same
 E = E_total
-mha_lib = torch.nn.MultiheadAttention(E, nheads, batch_first=True)
+mha_lib = torch.nn.MultiheadAttention(E, nheads, batch_first=True, device=device)
 mha_lib.eval()
 
 ######################################################################
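For reference, the pattern this diff applies is to construct tensors and modules directly on the chosen device with an explicit device= argument rather than relying on a global torch.set_default_device(device) call. Below is a minimal, untested sketch of that pattern using the public torch.nested and torch.nn.MultiheadAttention APIs; the sizes (E = 8, nheads = 2) and the sequence lengths (4 and 6) are illustrative placeholders, not values from the tutorial.

import torch

# Pick the device once, as the tutorial does.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

E, nheads = 8, 2  # illustrative sizes; embed_dim must be divisible by nheads

# Variable-length sequences allocated directly on `device` via device=,
# so no process-wide torch.set_default_device(device) call is needed.
query = torch.nested.nested_tensor(
    [torch.randn(4, E, device=device), torch.randn(6, E, device=device)])
key, value = query, query

# Module parameters are likewise placed on `device` at construction time.
mha = torch.nn.MultiheadAttention(E, nheads, batch_first=True, device=device)
mha.eval()

with torch.inference_mode():
    out, _ = mha(query, key, value, need_weights=False)

print(out.is_nested)  # expected: True, the output stays a nestedtensor on `device`

One reason to prefer per-call device= over a global default is that each snippet stays self-contained and does not depend on hidden process-wide state.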
