import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-torch.set_default_device(device)

######################################################################
# NestedTensor Initialization
# From the Python frontend, a nestedtensor can be created from a list of tensors.
# We denote nt[i] as the ith tensor component of a nestedtensor.
nt = torch.nested.nested_tensor([torch.arange(12).reshape(
-    2, 6), torch.arange(18).reshape(3, 6)], dtype=torch.float)
+    2, 6), torch.arange(18).reshape(3, 6)], dtype=torch.float, device=device)
print(f"{nt=}")
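As an aside (not part of this diff), here is a minimal sketch of the indexing semantics described in the comment above. It assumes the `nt` built in this hunk and only inspects it; each component keeps its own shape.

    # Sketch only: inspect the components of the nestedtensor constructed above.
    print(nt.is_nested)   # True
    print(nt[0].shape)    # torch.Size([2, 6]) -- first component
    print(nt[1].shape)    # torch.Size([3, 6]) -- second component
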
######################################################################
# Applying the operation on a nestedtensor is equivalent to
# applying the operation to the underlying tensor components,
# with the result being a nestedtensor as well.
-nt_mm = torch.nested.nested_tensor([torch.randn((2, 3, 4)), torch.randn((2, 3, 5))])
+nt_mm = torch.nested.nested_tensor([torch.randn((2, 3, 4)), torch.randn((2, 3, 5))], device=device)
nt3 = torch.matmul(nt_transposed, nt_mm)
print(f"Result of Matmul:\n{nt3}")
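The per-component claim in the comment above can be sanity-checked with a small sketch (not part of the diff; it reuses `nt_transposed`, `nt_mm`, and `nt3` from this hunk):

    # Sketch only: the nested matmul agrees with matmul applied to each unbound component pair.
    for a, b, c in zip(nt_transposed.unbind(), nt_mm.unbind(), nt3.unbind()):
        print(torch.allclose(torch.matmul(a, b), c))   # expected: True for every component
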
@@ -319,7 +318,7 @@ def mha_padded(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, nhea

    # Have to manipulate masks in order to apply them to the attention weights
    key_padding_mask = attn_mask_q.view(N, 1, 1, L_t).expand(-1, nheads, -1, -1).reshape(N * nheads, 1, L_t).to(device=device)
-    attn_mask = torch.zeros(key_padding_mask.shape, dtype=torch.float32)
+    attn_mask = torch.zeros(key_padding_mask.shape, device=device, dtype=torch.float32)
    attn_mask = attn_mask.masked_fill_(key_padding_mask, float("-inf"))

    # Zero out the attention weights where the mask is True by adding -inf prior to softmax
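The additive-mask trick used in this hunk is worth spelling out with a toy example (sketch only, arbitrary numbers not taken from the PR): a score that receives -inf before softmax ends up with exactly zero attention weight.

    # Sketch only: softmax over scores with -inf added at the padded position.
    scores = torch.tensor([1.0, 2.0, 3.0], device=device)
    pad = torch.tensor([False, True, False], device=device)        # True marks padding
    additive = torch.zeros(3, device=device).masked_fill_(pad, float("-inf"))
    print(F.softmax(scores + additive, dim=-1))                    # middle weight is exactly 0
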
@@ -385,10 +384,10 @@ def zipf_sentence_lengths(alpha: float, batch_size: int) -> np.ndarray:
# create inputs

# create parameters
-W_q, b_q = torch.randn((E_total, E_q)), torch.randn(E_total)
-W_k, b_k = torch.randn((E_total, E_k)), torch.randn(E_total)
-W_v, b_v = torch.randn((E_total, E_v)), torch.randn(E_total)
-W_out, b_out = torch.randn((E_out, E_total)), torch.randn(E_out)
+W_q, b_q = torch.randn((E_total, E_q), device=device), torch.randn(E_total, device=device)
+W_k, b_k = torch.randn((E_total, E_k), device=device), torch.randn(E_total, device=device)
+W_v, b_v = torch.randn((E_total, E_v), device=device), torch.randn(E_total, device=device)
+W_out, b_out = torch.randn((E_out, E_total), device=device), torch.randn(E_out, device=device)

# create nested input
queries = []
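These weights are consumed by the attention implementations defined elsewhere in the tutorial file. As a hedged sketch (assuming, as is standard for in-projections, that they are applied via F.linear, i.e. x @ W.T + b, and that E_q/E_total are the dimensions defined earlier in the file), W_q maps the last dimension from E_q to E_total:

    # Sketch only (toy dense input, not from the diff): project with W_q/b_q.
    x = torch.randn((2, E_q), device=device)
    print(F.linear(x, W_q, b_q).shape)    # torch.Size([2, E_total])
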
@@ -397,9 +396,9 @@ def zipf_sentence_lengths(alpha: float, batch_size: int) -> np.ndarray:
for i in range(N):
    l = sentence_lengths[i]
    s = l
-    queries.append(torch.randn((l, E_q)))
-    keys.append(torch.randn((s, E_k)))
-    values.append(torch.randn((s, E_v)))
+    queries.append(torch.randn((l, E_q), device=device))
+    keys.append(torch.randn((s, E_k), device=device))
+    values.append(torch.randn((s, E_v), device=device))
query = torch.nested.nested_tensor(queries)
key = torch.nested.nested_tensor(keys)
value = torch.nested.nested_tensor(values)
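For comparison against the padded code path, the nested inputs built in this hunk can be densified. A sketch (not part of the diff; 0.0 chosen here as the padding value):

    # Sketch only: convert the nested query into a dense, zero-padded tensor.
    padded_query = torch.nested.to_padded_tensor(query, padding=0.0)
    print(padded_query.shape)    # (N, max(sentence_lengths), E_q)
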
@@ -454,7 +453,7 @@ def zipf_sentence_lengths(alpha: float, batch_size: int) -> np.ndarray:

# embeddings are assumed to be the same
E = E_total
-mha_lib = torch.nn.MultiheadAttention(E, nheads, batch_first=True)
+mha_lib = torch.nn.MultiheadAttention(E, nheads, batch_first=True, device=device)
mha_lib.eval()

######################################################################
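Later in the tutorial this module is fed the nestedtensor inputs directly. A hedged usage sketch, assuming `query`/`key`/`value` from the hunks above (all with embedding dim E, as the comment notes) and that the inference fast path applies (eval mode, no autograd, need_weights=False):

    # Sketch only: nn.MultiheadAttention can take nestedtensor q/k/v and return a nestedtensor.
    with torch.inference_mode():
        out_lib, _ = mha_lib(query, key, value, need_weights=False)
    print(out_lib.is_nested)    # True
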