4
4
#include < torch/types.h>
5
5
#include < mutex>
6
6
7
- #include " src/torchcodec/_core/DeviceInterface .h"
7
+ #include " src/torchcodec/_core/CudaDevice .h"
8
8
#include " src/torchcodec/_core/FFMPEGCommon.h"
9
9
#include " src/torchcodec/_core/SingleStreamDecoder.h"
10
10
@@ -16,6 +16,10 @@ extern "C" {
16
16
namespace facebook ::torchcodec {
17
17
namespace {
18
18
19
// Registers the CUDA backend with the device-interface registry at static
// initialization time. The bool exists only to anchor the registration call
// in this translation unit; its value is never read.
bool g_cuda = registerDeviceInterface(
    torch::kCUDA,
    [](const torch::Device& device) { return new CudaDevice(device); });
19
23
// We reuse cuda contexts across VideoDeoder instances. This is because
20
24
// creating a cuda context is expensive. The cache mechanism is as follows:
21
25
// 1. There is a cache of size MAX_CONTEXTS_PER_GPU_IN_CACHE cuda contexts for
@@ -49,7 +53,7 @@ torch::DeviceIndex getFFMPEGCompatibleDeviceIndex(const torch::Device& device) {
49
53
50
54
void addToCacheIfCacheHasCapacity (
51
55
const torch::Device& device,
52
- AVCodecContext* codecContext ) {
56
+ AVBufferRef* hwContext ) {
53
57
torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex (device);
54
58
if (static_cast <int >(deviceIndex) >= MAX_CUDA_GPUS) {
55
59
return ;
@@ -60,8 +64,7 @@ void addToCacheIfCacheHasCapacity(
60
64
MAX_CONTEXTS_PER_GPU_IN_CACHE) {
61
65
return ;
62
66
}
63
- g_cached_hw_device_ctxs[deviceIndex].push_back (codecContext->hw_device_ctx );
64
- codecContext->hw_device_ctx = nullptr ;
67
+ g_cached_hw_device_ctxs[deviceIndex].push_back (av_buffer_ref (hwContext));
65
68
}
66
69
67
70
AVBufferRef* getFromCache (const torch::Device& device) {
@@ -158,39 +161,35 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
158
161
device, nonNegativeDeviceIndex, type);
159
162
#endif
160
163
}
164
+ } // namespace
161
165
162
- void throwErrorIfNonCudaDevice (const torch::Device& device) {
163
- TORCH_CHECK (
164
- device.type () != torch::kCPU ,
165
- " Device functions should only be called if the device is not CPU." )
166
- if (device.type () != torch::kCUDA ) {
167
- throw std::runtime_error (" Unsupported device: " + device.str ());
166
+ CudaDevice::CudaDevice (const torch::Device& device) : DeviceInterface(device) {
167
+ if (device_.type () != torch::kCUDA ) {
168
+ throw std::runtime_error (" Unsupported device: " + device_.str ());
168
169
}
169
170
}
170
- } // namespace
171
171
172
// Releases this instance's FFmpeg hardware device context. The context is
// first offered to the per-GPU cache — addToCacheIfCacheHasCapacity takes
// its own av_buffer_ref — and only then is our reference dropped, so the
// underlying context can stay alive for reuse by later instances. Order
// matters: unref-ing first could destroy the context before it is cached.
CudaDevice::~CudaDevice() {
  if (ctx_) {
    addToCacheIfCacheHasCapacity(device_, ctx_);
    av_buffer_unref(&ctx_);
  }
}
178
178
179
- void initializeContextOnCuda (
180
- const torch::Device& device,
181
- AVCodecContext* codecContext) {
182
- throwErrorIfNonCudaDevice (device);
179
+ void CudaDevice::initializeContext (AVCodecContext* codecContext) {
180
+ TORCH_CHECK (!ctx_, " FFmpeg HW device context already initialized" );
181
+
183
182
// It is important for pytorch itself to create the cuda context. If ffmpeg
184
183
// creates the context it may not be compatible with pytorch.
185
184
// This is a dummy tensor to initialize the cuda context.
186
185
torch::Tensor dummyTensorForCudaInitialization = torch::empty (
187
- {1 }, torch::TensorOptions ().dtype (torch::kUInt8 ).device (device));
188
- codecContext->hw_device_ctx = getCudaContext (device);
186
+ {1 }, torch::TensorOptions ().dtype (torch::kUInt8 ).device (device_));
187
+ ctx_ = getCudaContext (device_);
188
+ codecContext->hw_device_ctx = av_buffer_ref (ctx_);
189
189
return ;
190
190
}
191
191
192
- void convertAVFrameToFrameOutputOnCuda (
193
- const torch::Device& device,
192
+ void CudaDevice::convertAVFrameToFrameOutput (
194
193
const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
195
194
UniqueAVFrame& avFrame,
196
195
SingleStreamDecoder::FrameOutput& frameOutput,
@@ -217,11 +216,11 @@ void convertAVFrameToFrameOutputOnCuda(
217
216
" x3, got " ,
218
217
shape);
219
218
} else {
220
- dst = allocateEmptyHWCTensor (height, width, videoStreamOptions. device );
219
+ dst = allocateEmptyHWCTensor (height, width, device_ );
221
220
}
222
221
223
222
// Use the user-requested GPU for running the NPP kernel.
224
- c10::cuda::CUDAGuard deviceGuard (device );
223
+ c10::cuda::CUDAGuard deviceGuard (device_ );
225
224
226
225
NppiSize oSizeROI = {width, height};
227
226
Npp8u* input[2 ] = {avFrame->data [0 ], avFrame->data [1 ]};
@@ -249,7 +248,7 @@ void convertAVFrameToFrameOutputOnCuda(
249
248
// output.
250
249
at::cuda::CUDAEvent nppDoneEvent;
251
250
at::cuda::CUDAStream nppStreamWrapper =
252
- c10::cuda::getStreamFromExternal (nppGetStream (), device .index ());
251
+ c10::cuda::getStreamFromExternal (nppGetStream (), device_ .index ());
253
252
nppDoneEvent.record (nppStreamWrapper);
254
253
nppDoneEvent.block (at::cuda::getCurrentCUDAStream ());
255
254
@@ -264,11 +263,7 @@ void convertAVFrameToFrameOutputOnCuda(
264
263
// we have to do this because of an FFmpeg bug where hardware decoding is not
265
264
// appropriately set, so we just go off and find the matching codec for the CUDA
266
265
// device
267
- std::optional<const AVCodec*> findCudaCodec (
268
- const torch::Device& device,
269
- const AVCodecID& codecId) {
270
- throwErrorIfNonCudaDevice (device);
271
-
266
+ std::optional<const AVCodec*> CudaDevice::findCodec (const AVCodecID& codecId) {
272
267
void * i = nullptr ;
273
268
const AVCodec* codec = nullptr ;
274
269
while ((codec = av_codec_iterate (&i)) != nullptr ) {
0 commit comments