Fix the unpack error and add code for transferring data to the GPU #2688

Closed
wants to merge 1 commit into from
45 changes: 42 additions & 3 deletions advanced_source/cpp_extension.rst
@@ -1143,7 +1143,7 @@ on it:
auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
auto d_input = d_X.slice(/*dim=*/1, state_size);

- return {d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates};
+ return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
}


@@ -1182,8 +1182,47 @@ Performance Comparison

Our hope was that parallelizing and fusing the pointwise operations of our code
with CUDA would improve the performance of our LLTM. Let's see if that holds
- true. We can run the code I listed earlier to run a benchmark. Our fastest
- version earlier was the CUDA-based C++ code::
+ true. We can modify the earlier code to transfer the module and input data to GPU for benchmarking.
+
+ import time
+
+ import torch
+
+ batch_size = 16
+ input_features = 32
+ state_size = 128
+
+ # Check if CUDA (GPU) is available
+ if torch.cuda.is_available():
+     # Set the device to CUDA
+     device = torch.device("cuda")
+     print("CUDA is available. Using GPU.")
+ else:
+     # If CUDA is not available, fall back to CPU
+     device = torch.device("cpu")
+     print("CUDA is not available. Using CPU.")
+
+ X = torch.randn(batch_size, input_features, device=device)
+ h = torch.randn(batch_size, state_size, device=device)
+ C = torch.randn(batch_size, state_size, device=device)
+
+ rnn = LLTM(input_features, state_size).to(device)
+
+ forward = 0
+ backward = 0
+ for _ in range(100000):
+     start = time.time()
+     new_h, new_C = rnn(X, (h, C))
+     forward += time.time() - start
+
+     start = time.time()
+     (new_h.sum() + new_C.sum()).backward()
+     backward += time.time() - start
+
+ print('Forward: {:.3f} s | Backward {:.3f} s'.format(forward, backward))
+
+
+ Our fastest version earlier was the CUDA-based C++ code::

Forward: 149.802 us | Backward 393.458 us
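
A note on timing GPU code such as the benchmark loop in this hunk: CUDA kernels are launched asynchronously, so a wall-clock timer read straight after a call may capture only the launch cost unless the device is synchronized first. The sketch below is one possible way to structure such a measurement, not part of this PR. The `benchmark` helper and its `sync` parameter are hypothetical names introduced for illustration; on a CUDA device one would presumably pass `torch.cuda.synchronize` as `sync`, while the stand-in workload here uses only the standard library so the example runs without a GPU.

```python
import time


def benchmark(step, sync=None, iters=1000, warmup=10):
    """Return the mean wall-clock time of step() in microseconds.

    `sync` is a hypothetical hook: a zero-argument callable that blocks
    until all queued device work has finished (for example
    torch.cuda.synchronize on a CUDA device), or None for CPU-only code.
    """
    def wait():
        if sync is not None:
            sync()

    # Warm-up iterations are excluded from the measurement.
    for _ in range(warmup):
        step()
    wait()

    start = time.perf_counter()
    for _ in range(iters):
        step()
    # Wait for outstanding device work before reading the clock.
    wait()
    elapsed = time.perf_counter() - start

    return elapsed / iters * 1e6


if __name__ == "__main__":
    # Stand-in CPU workload; replace with the forward/backward pass.
    mean_us = benchmark(lambda: sum(range(100)), iters=1000)
    print('Mean: {:.3f} us'.format(mean_us))
```

Synchronizing once before starting the timer and once after the loop keeps the synchronization overhead out of the per-iteration figure; timing forward and backward separately would need a `sync` call between the two phases as well.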
