Skip to content

Commit 19097f0

Browse files
committed
Create DeviceInterface in addStream
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
1 parent f0e6421 commit 19097f0

File tree

7 files changed

+61
-34
lines changed

7 files changed

+61
-34
lines changed

src/torchcodec/decoders/_core/CudaDevice.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ extern "C" {
1616
namespace facebook::torchcodec {
1717
namespace {
1818

19-
bool g_cuda = registerDeviceInterface("cuda", [](const std::string& device) {
20-
return new CudaDevice(device);
21-
});
19+
bool g_cuda = registerDeviceInterface(
20+
torch::kCUDA,
21+
[](const torch::Device& device) { return new CudaDevice(device); });
2222

2323
// We reuse cuda contexts across VideoDecoder instances. This is because
2424
// creating a cuda context is expensive. The cache mechanism is as follows:
@@ -162,7 +162,7 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
162162
}
163163
} // namespace
164164

165-
CudaDevice::CudaDevice(const std::string& device) : DeviceInterface(device) {
165+
CudaDevice::CudaDevice(const torch::Device& device) : DeviceInterface(device) {
166166
if (device_.type() != torch::kCUDA) {
167167
throw std::runtime_error("Unsupported device: " + device_.str());
168168
}

src/torchcodec/decoders/_core/CudaDevice.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace facebook::torchcodec {
1212

1313
class CudaDevice : public DeviceInterface {
1414
public:
15-
CudaDevice(const std::string& device);
15+
CudaDevice(const torch::Device& device);
1616

1717
virtual ~CudaDevice(){};
1818

src/torchcodec/decoders/_core/DeviceInterface.cpp

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace facebook::torchcodec {
1212

1313
namespace {
1414
std::mutex g_interface_mutex;
15-
std::map<std::string, CreateDeviceInterfaceFn> g_interface_map;
15+
std::map<torch::DeviceType, CreateDeviceInterfaceFn> g_interface_map;
1616

1717
std::string getDeviceType(const std::string& device) {
1818
size_t pos = device.find(':');
@@ -25,7 +25,7 @@ std::string getDeviceType(const std::string& device) {
2525
} // namespace
2626

2727
bool registerDeviceInterface(
28-
const std::string deviceType,
28+
torch::DeviceType deviceType,
2929
CreateDeviceInterfaceFn createInterface) {
3030
std::scoped_lock lock(g_interface_mutex);
3131
TORCH_CHECK(
@@ -36,15 +36,39 @@ bool registerDeviceInterface(
3636
return true;
3737
}
3838

39-
std::unique_ptr<DeviceInterface> createDeviceInterface(
40-
const std::string device) {
39+
torch::Device createTorchDevice(const std::string device) {
4140
// TODO: remove once DeviceInterface for CPU is implemented
4241
if (device == "cpu") {
43-
return nullptr;
42+
return torch::kCPU;
4443
}
4544

4645
std::scoped_lock lock(g_interface_mutex);
4746
std::string deviceType = getDeviceType(device);
47+
TORCH_CHECK(
48+
std::find_if(
49+
g_interface_map.begin(),
50+
g_interface_map.end(),
51+
[&](const std::pair<torch::DeviceType, CreateDeviceInterfaceFn>&
52+
arg) {
53+
return device.rfind(
54+
torch::DeviceTypeName(arg.first, /*lower_case*/ true),
55+
0) == 0;
56+
}) != g_interface_map.end(),
57+
"Unsupported device: ",
58+
device);
59+
60+
return torch::Device(device);
61+
}
62+
63+
std::unique_ptr<DeviceInterface> createDeviceInterface(
64+
const torch::Device& device) {
65+
auto deviceType = device.type();
66+
// TODO: remove once DeviceInterface for CPU is implemented
67+
if (deviceType == torch::kCPU) {
68+
return nullptr;
69+
}
70+
71+
std::scoped_lock lock(g_interface_mutex);
4872
TORCH_CHECK(
4973
g_interface_map.find(deviceType) != g_interface_map.end(),
5074
"Unsupported device: ",

src/torchcodec/decoders/_core/DeviceInterface.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ namespace facebook::torchcodec {
2626

2727
class DeviceInterface {
2828
public:
29-
DeviceInterface(const std::string& device) : device_(device) {}
29+
DeviceInterface(const torch::Device& device) : device_(device) {}
3030

3131
virtual ~DeviceInterface(){};
3232

@@ -53,13 +53,15 @@ class DeviceInterface {
5353
};
5454

5555
using CreateDeviceInterfaceFn =
56-
std::function<DeviceInterface*(const std::string& device)>;
56+
std::function<DeviceInterface*(const torch::Device& device)>;
5757

5858
bool registerDeviceInterface(
59-
const std::string deviceType,
59+
torch::DeviceType deviceType,
6060
const CreateDeviceInterfaceFn createInterface);
6161

62+
torch::Device createTorchDevice(const std::string device);
63+
6264
std::unique_ptr<DeviceInterface> createDeviceInterface(
63-
const std::string device);
65+
const torch::Device& device);
6466

6567
} // namespace facebook::torchcodec

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,9 @@ VideoDecoder::VideoDecoder(
9797

9898
VideoDecoder::~VideoDecoder() {
9999
for (auto& [streamIndex, streamInfo] : streamInfos_) {
100-
auto& device = streamInfo.videoStreamOptions.device;
101-
if (device) {
102-
device->releaseContext(streamInfo.codecContext.get());
100+
auto& deviceInterface = streamInfo.deviceInterface;
101+
if (deviceInterface) {
102+
deviceInterface->releaseContext(streamInfo.codecContext.get());
103103
}
104104
}
105105
}
@@ -388,7 +388,7 @@ torch::Tensor VideoDecoder::getKeyFrameIndices() {
388388
void VideoDecoder::addStream(
389389
int streamIndex,
390390
AVMediaType mediaType,
391-
DeviceInterface* device,
391+
const torch::Device& device,
392392
std::optional<int> ffmpegThreadCount) {
393393
TORCH_CHECK(
394394
activeStreamIndex_ == NO_ACTIVE_STREAM,
@@ -416,6 +416,7 @@ void VideoDecoder::addStream(
416416
streamInfo.timeBase = formatContext_->streams[activeStreamIndex_]->time_base;
417417
streamInfo.stream = formatContext_->streams[activeStreamIndex_];
418418
streamInfo.avMediaType = mediaType;
419+
streamInfo.deviceInterface = createDeviceInterface(device);
419420

420421
// This should never happen, checking just to be safe.
421422
TORCH_CHECK(
@@ -427,9 +428,10 @@ void VideoDecoder::addStream(
427428
// TODO_CODE_QUALITY it's pretty meh to have video-specific logic within
428429
// addStream() which is supposed to be generic
429430
if (mediaType == AVMEDIA_TYPE_VIDEO) {
430-
if (device) {
431+
if (streamInfo.deviceInterface) {
431432
avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream(
432-
device->findCodec(streamInfo.stream->codecpar->codec_id)
433+
streamInfo.deviceInterface
434+
->findCodec(streamInfo.stream->codecpar->codec_id)
433435
.value_or(avCodec));
434436
}
435437
}
@@ -447,8 +449,8 @@ void VideoDecoder::addStream(
447449

448450
// TODO_CODE_QUALITY same as above.
449451
if (mediaType == AVMEDIA_TYPE_VIDEO) {
450-
if (device) {
451-
device->initializeContext(codecContext);
452+
if (streamInfo.deviceInterface) {
453+
streamInfo.deviceInterface->initializeContext(codecContext);
452454
}
453455
}
454456

@@ -478,7 +480,7 @@ void VideoDecoder::addVideoStream(
478480
addStream(
479481
streamIndex,
480482
AVMEDIA_TYPE_VIDEO,
481-
videoStreamOptions.device.get(),
483+
videoStreamOptions.device,
482484
videoStreamOptions.ffmpegThreadCount);
483485

484486
auto& streamMetadata =
@@ -1212,11 +1214,11 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
12121214
formatContext_->streams[activeStreamIndex_]->time_base);
12131215
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
12141216
convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput);
1215-
} else if (!streamInfo.videoStreamOptions.device) {
1217+
} else if (!streamInfo.deviceInterface) {
12161218
convertAVFrameToFrameOutputOnCPU(
12171219
avFrame, frameOutput, preAllocatedOutputTensor);
1218-
} else if (streamInfo.videoStreamOptions.device) {
1219-
streamInfo.videoStreamOptions.device->convertAVFrameToFrameOutput(
1220+
} else if (streamInfo.deviceInterface) {
1221+
streamInfo.deviceInterface->convertAVFrameToFrameOutput(
12201222
streamInfo.videoStreamOptions,
12211223
avFrame,
12221224
frameOutput,
@@ -1559,10 +1561,8 @@ VideoDecoder::FrameBatchOutput::FrameBatchOutput(
15591561
videoStreamOptions, streamMetadata);
15601562
int height = frameDims.height;
15611563
int width = frameDims.width;
1562-
torch::Device device = (videoStreamOptions.device)
1563-
? videoStreamOptions.device->device()
1564-
: torch::kCPU;
1565-
data = allocateEmptyHWCTensor(height, width, device, numFrames);
1564+
data = allocateEmptyHWCTensor(
1565+
height, width, videoStreamOptions.device, numFrames);
15661566
}
15671567

15681568
torch::Tensor allocateEmptyHWCTensor(

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ class VideoDecoder {
140140
std::optional<int> height;
141141
std::optional<ColorConversionLibrary> colorConversionLibrary;
142142
// By default we use CPU for decoding for both C++ and python users.
143-
std::shared_ptr<DeviceInterface> device;
143+
torch::Device device = torch::kCPU;
144144
};
145145

146146
struct AudioStreamOptions {
@@ -359,6 +359,8 @@ class VideoDecoder {
359359
// Used to know whether a new FilterGraphContext or UniqueSwsContext should
360360
// be created before decoding a new frame.
361361
DecodedFrameContext prevFrameContext;
362+
363+
std::unique_ptr<DeviceInterface> deviceInterface;
362364
};
363365

364366
// --------------------------------------------------------------------------
@@ -461,7 +463,7 @@ class VideoDecoder {
461463
void addStream(
462464
int streamIndex,
463465
AVMediaType mediaType,
464-
DeviceInterface* device = nullptr,
466+
const torch::Device& device = torch::kCPU,
465467
std::optional<int> ffmpegThreadCount = std::nullopt);
466468

467469
// Returns the "best" stream index for a given media type. The "best" is

src/torchcodec/decoders/_core/custom_ops.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,7 @@ void _add_video_stream(
238238
}
239239
}
240240
if (device.has_value()) {
241-
videoStreamOptions.device =
242-
createDeviceInterface(std::string(device.value()));
241+
videoStreamOptions.device = createTorchDevice(std::string(device.value()));
243242
}
244243

245244
auto videoDecoder = unwrapTensorToGetDecoder(decoder);

0 commit comments

Comments
 (0)