From e2495e56027ec41017ab23943f020a12102b2f42 Mon Sep 17 00:00:00 2001
From: aurada
Date: Sat, 18 Nov 2023 00:41:42 -0500
Subject: [PATCH] fix unpack error and move data to gpu

---
 advanced_source/cpp_extension.rst | 45 ++++++++++++++++++++++++++++---
 1 file changed, 42 insertions(+), 3 deletions(-)

diff --git a/advanced_source/cpp_extension.rst b/advanced_source/cpp_extension.rst
index cb0e990797e..bd0c463ceb0 100644
--- a/advanced_source/cpp_extension.rst
+++ b/advanced_source/cpp_extension.rst
@@ -1143,7 +1143,7 @@ on it:
     auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
     auto d_input = d_X.slice(/*dim=*/1, state_size);
 
-    return {d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates};
+    return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
   }
 
 
@@ -1182,8 +1182,47 @@ Performance Comparison
 
 Our hope was that parallelizing and fusing the pointwise operations of our code
 with CUDA would improve the performance of our LLTM. Let's see if that holds
-true. We can run the code I listed earlier to run a benchmark. Our fastest
-version earlier was the CUDA-based C++ code::
+true. We can modify the earlier benchmark to move the module and the input data to the GPU::
+
+  import time
+
+  import torch
+
+  batch_size = 16
+  input_features = 32
+  state_size = 128
+
+  # Check if CUDA (GPU) is available
+  if torch.cuda.is_available():
+      # Set the device to CUDA
+      device = torch.device("cuda")
+      print("CUDA is available. Using GPU.")
+  else:
+      # If CUDA is not available, fall back to CPU
+      device = torch.device("cpu")
+      print("CUDA is not available. Using CPU.")
+
+  X = torch.randn(batch_size, input_features, device=device)
+  h = torch.randn(batch_size, state_size, device=device)
+  C = torch.randn(batch_size, state_size, device=device)
+
+  rnn = LLTM(input_features, state_size).to(device)
+
+  forward = 0
+  backward = 0
+  for _ in range(100000):
+      start = time.time()
+      new_h, new_C = rnn(X, (h, C))
+      forward += time.time() - start
+
+      start = time.time()
+      (new_h.sum() + new_C.sum()).backward()
+      backward += time.time() - start
+
+  print('Forward: {:.3f} s | Backward {:.3f} s'.format(forward, backward))
+
+
+Our fastest version earlier was the CUDA-based C++ code::
 
   Forward: 149.802 us | Backward 393.458 us
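
A note on the first hunk: the ``d_gates`` entry is dropped because the tutorial's
Python ``torch.autograd.Function`` wrapper unpacks the list returned by the CUDA
backward into five tensors (one gradient per ``forward()`` input), so returning a
sixth tensor raises a "too many values to unpack" error on the Python side. Below
is a minimal sketch of that wrapper, assuming the tutorial's usual names
(``lltm_cuda``, ``LLTMFunction``) rather than anything shown in this patch::

  import torch
  import lltm_cuda  # the compiled CUDA extension (name assumed)

  class LLTMFunction(torch.autograd.Function):
      @staticmethod
      def forward(ctx, input, weights, bias, old_h, old_cell):
          outputs = lltm_cuda.forward(input, weights, bias, old_h, old_cell)
          new_h, new_cell = outputs[:2]
          # Stash the intermediates the backward kernel needs.
          ctx.save_for_backward(*(outputs[1:] + [weights]))
          return new_h, new_cell

      @staticmethod
      def backward(ctx, grad_h, grad_cell):
          outputs = lltm_cuda.backward(
              grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_tensors)
          # Exactly five names on the left-hand side: a sixth returned
          # tensor (d_gates) would fail to unpack here.
          d_old_h, d_input, d_weights, d_bias, d_old_cell = outputs
          # One gradient per forward() argument, in the same order.
          return d_input, d_weights, d_bias, d_old_h, d_old_cell

An alternative fix would be to keep ``d_gates`` in the C++ return and discard it
on the Python side; trimming the C++ return instead leaves the existing wrapper
untouched.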