Commit e2495e5

fix unpack error and move data to gpu
1 parent 51a3f60


advanced_source/cpp_extension.rst

Lines changed: 42 additions & 3 deletions
@@ -1143,7 +1143,7 @@ on it:
     auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
     auto d_input = d_X.slice(/*dim=*/1, state_size);

-    return {d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates};
+    return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
   }

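For context, the removed ``d_gates`` matters because the Python-side ``torch.autograd.Function`` wrapper earlier in the tutorial unpacks the result of the extension's ``backward`` into exactly five gradient tensors, so a sixth returned value triggers the unpack error named in the commit message. The snippet below is only a rough sketch of that unpacking step; the module name ``lltm_cuda`` and the surrounding class structure are assumptions, not copied from this file:

    import torch
    import lltm_cuda  # assumed name for the compiled C++/CUDA extension

    class LLTMFunction(torch.autograd.Function):
        @staticmethod
        def backward(ctx, grad_h, grad_cell):
            outputs = lltm_cuda.backward(
                grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_tensors)
            # Exactly five tensors are expected here; an extra d_gates entry
            # makes this destructuring fail ("too many values to unpack"),
            # which is what the change above fixes.
            d_old_h, d_input, d_weights, d_bias, d_old_cell = outputs
            # One gradient per differentiable forward input.
            return d_input, d_weights, d_bias, d_old_h, d_old_cell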
@@ -1182,8 +1182,47 @@ Performance Comparison

 Our hope was that parallelizing and fusing the pointwise operations of our code
 with CUDA would improve the performance of our LLTM. Let's see if that holds
-true. We can run the code I listed earlier to run a benchmark. Our fastest
-version earlier was the CUDA-based C++ code::
+true. We can modify the earlier benchmark code to move the module and the input data to the GPU::
+
+    import time
+
+    import torch
+
+    batch_size = 16
+    input_features = 32
+    state_size = 128
+
+    # Check if CUDA (GPU) is available
+    if torch.cuda.is_available():
+        # Set the device to CUDA
+        device = torch.device("cuda")
+        print("CUDA is available. Using GPU.")
+    else:
+        # If CUDA is not available, fall back to CPU
+        device = torch.device("cpu")
+        print("CUDA is not available. Using CPU.")
+
+    X = torch.randn(batch_size, input_features, device=device)
+    h = torch.randn(batch_size, state_size, device=device)
+    C = torch.randn(batch_size, state_size, device=device)
+
+    rnn = LLTM(input_features, state_size).to(device)
+
+    forward = 0
+    backward = 0
+    for _ in range(100000):
+        start = time.time()
+        new_h, new_C = rnn(X, (h, C))
+        forward += time.time() - start
+
+        start = time.time()
+        (new_h.sum() + new_C.sum()).backward()
+        backward += time.time() - start
+
+    print('Forward: {:.3f} s | Backward {:.3f} s'.format(forward, backward))
+
+
+Our fastest version earlier was the CUDA-based C++ code::

   Forward: 149.802 us | Backward 393.458 us

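A side note on timing the GPU path (a general benchmarking assumption, not something this commit changes): CUDA kernels launch asynchronously, so ``time.time()`` deltas around individual calls can under-report GPU work unless the device is synchronized. A minimal, hypothetical helper illustrating that pattern:

    import time

    import torch

    def timed(fn):
        # Synchronize before and after so the wall-clock delta covers the
        # kernels actually executed on the GPU, not just their launch.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.time()
        result = fn()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return result, time.time() - start

Used with the objects defined in the added benchmark code, ``timed(lambda: rnn(X, (h, C)))`` would return the forward outputs together with a synchronized elapsed time.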