@@ -1143,7 +1143,7 @@ on it:
auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
auto d_input = d_X.slice(/*dim=*/1, state_size);
- return {d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates };
+ return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
}
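For reference, this five-tensor return is exactly what the Python wrapper unpacks. A minimal sketch of the ``torch.autograd.Function`` side, assuming the compiled extension is imported as ``lltm_cpp`` as elsewhere in the tutorial (the saved-tensor bookkeeping here is illustrative and may differ from your version)::

  import torch
  import lltm_cpp  # the compiled C++ extension; name assumed from the tutorial

  class LLTMFunction(torch.autograd.Function):
      @staticmethod
      def forward(ctx, input, weights, bias, old_h, old_cell):
          # The C++ forward returns new_h and new_cell first, followed by
          # the intermediate buffers that backward needs.
          outputs = lltm_cpp.forward(input, weights, bias, old_h, old_cell)
          new_h, new_cell = outputs[:2]
          ctx.save_for_backward(*outputs[1:], weights)
          return new_h, new_cell

      @staticmethod
      def backward(ctx, grad_h, grad_cell):
          # The C++ backward now returns exactly five tensors, one gradient
          # per differentiable input; d_gates stays internal to the kernel.
          d_old_h, d_input, d_weights, d_bias, d_old_cell = lltm_cpp.backward(
              grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_tensors)
          # Reordered to mirror forward(input, weights, bias, old_h, old_cell).
          return d_input, d_weights, d_bias, d_old_h, d_old_cell
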
@@ -1182,8 +1182,51 @@ Performance Comparison
Our hope was that parallelizing and fusing the pointwise operations of our code
with CUDA would improve the performance of our LLTM. Let's see if that holds
- true. We can run the code I listed earlier to run a benchmark. Our fastest
- version earlier was the CUDA-based C++ code::
+ true. We can modify the earlier code to transfer the module and input data to the GPU for benchmarking::
+
+   import time
+
+   import torch
+
+   batch_size = 16
+   input_features = 32
+   state_size = 128
+
+   # Check if CUDA (GPU) is available
+   if torch.cuda.is_available():
+       # Set the device to CUDA
+       device = torch.device("cuda")
+       print("CUDA is available. Using GPU.")
+   else:
+       # If CUDA is not available, fall back to CPU
+       device = torch.device("cpu")
+       print("CUDA is not available. Using CPU.")
+
+   X = torch.randn(batch_size, input_features, device=device)
+   h = torch.randn(batch_size, state_size, device=device)
+   C = torch.randn(batch_size, state_size, device=device)
+
+   rnn = LLTM(input_features, state_size).to(device)
+
+   forward = 0
+   backward = 0
+   for _ in range(100000):
+       start = time.time()
+       new_h, new_C = rnn(X, (h, C))
+       if device.type == "cuda":
+           torch.cuda.synchronize()  # kernels run asynchronously; wait before stopping the clock
+       forward += time.time() - start
+
+       start = time.time()
+       (new_h.sum() + new_C.sum()).backward()
+       if device.type == "cuda":
+           torch.cuda.synchronize()
+       backward += time.time() - start
+
+   print('Forward: {:.3f} s | Backward {:.3f} s'.format(forward, backward))
+
+
+ Our fastest version earlier was the CUDA-based C++ code::
Forward: 149.802 us | Backward 393.458 us
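
Note that the wall-clock timings above are only meaningful because of the explicit ``torch.cuda.synchronize()`` calls: CUDA kernels execute asynchronously, so ``time.time()`` would otherwise measure little more than launch overhead. As a minimal sketch of an alternative, CUDA events time directly on the device (this reuses ``rnn``, ``X``, ``h`` and ``C`` from the benchmark above and assumes a CUDA device)::

  fwd_start = torch.cuda.Event(enable_timing=True)
  fwd_end = torch.cuda.Event(enable_timing=True)

  fwd_start.record()             # enqueue a timestamp on the current CUDA stream
  new_h, new_C = rnn(X, (h, C))
  fwd_end.record()
  torch.cuda.synchronize()       # wait until both events have actually occurred
  print('Forward: {:.3f} ms'.format(fwd_start.elapsed_time(fwd_end)))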