Skip to content

Commit d922430

Browse files
committed
Create DeviceInterface in addStream
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
1 parent a80acf2 commit d922430

File tree

7 files changed

+61
-34
lines changed

7 files changed

+61
-34
lines changed

src/torchcodec/_core/CudaDevice.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ extern "C" {
1616
namespace facebook::torchcodec {
1717
namespace {
1818

19-
bool g_cuda = registerDeviceInterface("cuda", [](const std::string& device) {
20-
return new CudaDevice(device);
21-
});
19+
bool g_cuda = registerDeviceInterface(
20+
torch::kCUDA,
21+
[](const torch::Device& device) { return new CudaDevice(device); });
2222

2323
// We reuse cuda contexts across VideoDecoder instances. This is because
2424
// creating a cuda context is expensive. The cache mechanism is as follows:
@@ -164,7 +164,7 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
164164
}
165165
} // namespace
166166

167-
CudaDevice::CudaDevice(const std::string& device) : DeviceInterface(device) {
167+
CudaDevice::CudaDevice(const torch::Device& device) : DeviceInterface(device) {
168168
if (device_.type() != torch::kCUDA) {
169169
throw std::runtime_error("Unsupported device: " + device_.str());
170170
}

src/torchcodec/_core/CudaDevice.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace facebook::torchcodec {
1212

1313
class CudaDevice : public DeviceInterface {
1414
public:
15-
CudaDevice(const std::string& device);
15+
CudaDevice(const torch::Device& device);
1616

1717
virtual ~CudaDevice(){};
1818

src/torchcodec/_core/DeviceInterface.cpp

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace facebook::torchcodec {
1212

1313
namespace {
1414
std::mutex g_interface_mutex;
15-
std::map<std::string, CreateDeviceInterfaceFn> g_interface_map;
15+
std::map<torch::DeviceType, CreateDeviceInterfaceFn> g_interface_map;
1616

1717
std::string getDeviceType(const std::string& device) {
1818
size_t pos = device.find(':');
@@ -25,7 +25,7 @@ std::string getDeviceType(const std::string& device) {
2525
} // namespace
2626

2727
bool registerDeviceInterface(
28-
const std::string deviceType,
28+
torch::DeviceType deviceType,
2929
CreateDeviceInterfaceFn createInterface) {
3030
std::scoped_lock lock(g_interface_mutex);
3131
TORCH_CHECK(
@@ -36,15 +36,39 @@ bool registerDeviceInterface(
3636
return true;
3737
}
3838

39-
std::unique_ptr<DeviceInterface> createDeviceInterface(
40-
const std::string device) {
39+
torch::Device createTorchDevice(const std::string device) {
4140
// TODO: remove once DeviceInterface for CPU is implemented
4241
if (device == "cpu") {
43-
return nullptr;
42+
return torch::kCPU;
4443
}
4544

4645
std::scoped_lock lock(g_interface_mutex);
4746
std::string deviceType = getDeviceType(device);
47+
TORCH_CHECK(
48+
std::find_if(
49+
g_interface_map.begin(),
50+
g_interface_map.end(),
51+
[&](const std::pair<torch::DeviceType, CreateDeviceInterfaceFn>&
52+
arg) {
53+
return device.rfind(
54+
torch::DeviceTypeName(arg.first, /*lower_case*/ true),
55+
0) == 0;
56+
}) != g_interface_map.end(),
57+
"Unsupported device: ",
58+
device);
59+
60+
return torch::Device(device);
61+
}
62+
63+
std::unique_ptr<DeviceInterface> createDeviceInterface(
64+
const torch::Device& device) {
65+
auto deviceType = device.type();
66+
// TODO: remove once DeviceInterface for CPU is implemented
67+
if (deviceType == torch::kCPU) {
68+
return nullptr;
69+
}
70+
71+
std::scoped_lock lock(g_interface_mutex);
4872
TORCH_CHECK(
4973
g_interface_map.find(deviceType) != g_interface_map.end(),
5074
"Unsupported device: ",

src/torchcodec/_core/DeviceInterface.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ namespace facebook::torchcodec {
2626

2727
class DeviceInterface {
2828
public:
29-
DeviceInterface(const std::string& device) : device_(device) {}
29+
DeviceInterface(const torch::Device& device) : device_(device) {}
3030

3131
virtual ~DeviceInterface(){};
3232

@@ -53,13 +53,15 @@ class DeviceInterface {
5353
};
5454

5555
using CreateDeviceInterfaceFn =
56-
std::function<DeviceInterface*(const std::string& device)>;
56+
std::function<DeviceInterface*(const torch::Device& device)>;
5757

5858
bool registerDeviceInterface(
59-
const std::string deviceType,
59+
torch::DeviceType deviceType,
6060
const CreateDeviceInterfaceFn createInterface);
6161

62+
torch::Device createTorchDevice(const std::string device);
63+
6264
std::unique_ptr<DeviceInterface> createDeviceInterface(
63-
const std::string device);
65+
const torch::Device& device);
6466

6567
} // namespace facebook::torchcodec

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,9 @@ SingleStreamDecoder::SingleStreamDecoder(
9494

9595
SingleStreamDecoder::~SingleStreamDecoder() {
9696
for (auto& [streamIndex, streamInfo] : streamInfos_) {
97-
auto& device = streamInfo.videoStreamOptions.device;
98-
if (device) {
99-
device->releaseContext(streamInfo.codecContext.get());
97+
auto& deviceInterface = streamInfo.deviceInterface;
98+
if (deviceInterface) {
99+
deviceInterface->releaseContext(streamInfo.codecContext.get());
100100
}
101101
}
102102
}
@@ -386,7 +386,7 @@ torch::Tensor SingleStreamDecoder::getKeyFrameIndices() {
386386
void SingleStreamDecoder::addStream(
387387
int streamIndex,
388388
AVMediaType mediaType,
389-
DeviceInterface* device,
389+
const torch::Device& device,
390390
std::optional<int> ffmpegThreadCount) {
391391
TORCH_CHECK(
392392
activeStreamIndex_ == NO_ACTIVE_STREAM,
@@ -414,6 +414,7 @@ void SingleStreamDecoder::addStream(
414414
streamInfo.timeBase = formatContext_->streams[activeStreamIndex_]->time_base;
415415
streamInfo.stream = formatContext_->streams[activeStreamIndex_];
416416
streamInfo.avMediaType = mediaType;
417+
streamInfo.deviceInterface = createDeviceInterface(device);
417418

418419
// This should never happen, checking just to be safe.
419420
TORCH_CHECK(
@@ -425,9 +426,10 @@ void SingleStreamDecoder::addStream(
425426
// TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within
426427
// addStream() which is supposed to be generic
427428
if (mediaType == AVMEDIA_TYPE_VIDEO) {
428-
if (device) {
429+
if (streamInfo.deviceInterface) {
429430
avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream(
430-
device->findCodec(streamInfo.stream->codecpar->codec_id)
431+
streamInfo.deviceInterface
432+
->findCodec(streamInfo.stream->codecpar->codec_id)
431433
.value_or(avCodec));
432434
}
433435
}
@@ -445,8 +447,8 @@ void SingleStreamDecoder::addStream(
445447

446448
// TODO_CODE_QUALITY same as above.
447449
if (mediaType == AVMEDIA_TYPE_VIDEO) {
448-
if (device) {
449-
device->initializeContext(codecContext);
450+
if (streamInfo.deviceInterface) {
451+
streamInfo.deviceInterface->initializeContext(codecContext);
450452
}
451453
}
452454

@@ -476,7 +478,7 @@ void SingleStreamDecoder::addVideoStream(
476478
addStream(
477479
streamIndex,
478480
AVMEDIA_TYPE_VIDEO,
479-
videoStreamOptions.device.get(),
481+
videoStreamOptions.device,
480482
videoStreamOptions.ffmpegThreadCount);
481483

482484
auto& streamMetadata =
@@ -1217,11 +1219,11 @@ SingleStreamDecoder::convertAVFrameToFrameOutput(
12171219
formatContext_->streams[activeStreamIndex_]->time_base);
12181220
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
12191221
convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput);
1220-
} else if (!streamInfo.videoStreamOptions.device) {
1222+
} else if (!streamInfo.deviceInterface) {
12211223
convertAVFrameToFrameOutputOnCPU(
12221224
avFrame, frameOutput, preAllocatedOutputTensor);
1223-
} else if (streamInfo.videoStreamOptions.device) {
1224-
streamInfo.videoStreamOptions.device->convertAVFrameToFrameOutput(
1225+
} else if (streamInfo.deviceInterface) {
1226+
streamInfo.deviceInterface->convertAVFrameToFrameOutput(
12251227
streamInfo.videoStreamOptions,
12261228
avFrame,
12271229
frameOutput,
@@ -1564,10 +1566,8 @@ SingleStreamDecoder::FrameBatchOutput::FrameBatchOutput(
15641566
videoStreamOptions, streamMetadata);
15651567
int height = frameDims.height;
15661568
int width = frameDims.width;
1567-
torch::Device device = (videoStreamOptions.device)
1568-
? videoStreamOptions.device->device()
1569-
: torch::kCPU;
1570-
data = allocateEmptyHWCTensor(height, width, device, numFrames);
1569+
data = allocateEmptyHWCTensor(
1570+
height, width, videoStreamOptions.device, numFrames);
15711571
}
15721572

15731573
torch::Tensor allocateEmptyHWCTensor(

src/torchcodec/_core/SingleStreamDecoder.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ class SingleStreamDecoder {
139139
std::optional<int> height;
140140
std::optional<ColorConversionLibrary> colorConversionLibrary;
141141
// By default we use CPU for decoding for both C++ and python users.
142-
std::shared_ptr<DeviceInterface> device;
142+
torch::Device device = torch::kCPU;
143143
};
144144

145145
struct AudioStreamOptions {
@@ -358,6 +358,8 @@ class SingleStreamDecoder {
358358
// Used to know whether a new FilterGraphContext or UniqueSwsContext should
359359
// be created before decoding a new frame.
360360
DecodedFrameContext prevFrameContext;
361+
362+
std::unique_ptr<DeviceInterface> deviceInterface;
361363
};
362364

363365
// --------------------------------------------------------------------------
@@ -460,7 +462,7 @@ class SingleStreamDecoder {
460462
void addStream(
461463
int streamIndex,
462464
AVMediaType mediaType,
463-
DeviceInterface* device = nullptr,
465+
const torch::Device& device = torch::kCPU,
464466
std::optional<int> ffmpegThreadCount = std::nullopt);
465467

466468
// Returns the "best" stream index for a given media type. The "best" is

src/torchcodec/_core/custom_ops.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,7 @@ void _add_video_stream(
243243
}
244244
}
245245
if (device.has_value()) {
246-
videoStreamOptions.device =
247-
createDeviceInterface(std::string(device.value()));
246+
videoStreamOptions.device = createTorchDevice(std::string(device.value()));
248247
}
249248

250249
auto videoDecoder = unwrapTensorToGetDecoder(decoder);

0 commit comments

Comments
 (0)