Skip to content

Commit 3fec41e

Browse files
authored
Add custom ops ROIAlign and nms into Autocast (#9)
1 parent 1ba2f71 commit 3fec41e

File tree

6 files changed

+141
-42
lines changed

6 files changed

+141
-42
lines changed

intel_pytorch_extension_py/ops/nms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import _torch_ipex as core
1+
import torch
22

3-
nms = core.nms
4-
batch_score_nms = core.batch_score_nms
3+
nms = torch.ops.torch_ipex.nms
4+
batch_score_nms = torch.ops.torch_ipex.batch_score_nms

intel_pytorch_extension_py/ops/roi_align.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
from torch.autograd.function import once_differentiable
66
from torch.nn.modules.utils import _pair
77

8-
import _torch_ipex as core
9-
108

119
class _ROIAlign(Function):
1210
@staticmethod
@@ -16,7 +14,7 @@ def forward(ctx, input, roi, output_size, spatial_scale, sampling_ratio):
1614
ctx.spatial_scale = spatial_scale
1715
ctx.sampling_ratio = sampling_ratio
1816
ctx.input_shape = input.size()
19-
output = core.roi_align_forward(
17+
output = torch.ops.torch_ipex.ROIAlign_forward(
2018
input, roi, spatial_scale, output_size[0], output_size[1], sampling_ratio
2119
)
2220
return output
@@ -29,7 +27,7 @@ def backward(ctx, grad_output):
2927
spatial_scale = ctx.spatial_scale
3028
sampling_ratio = ctx.sampling_ratio
3129
bs, ch, h, w = ctx.input_shape
32-
grad_input = core.roi_align_backward(
30+
grad_input = torch.ops.torch_ipex.ROIAlign_backward(
3331
grad_output,
3432
rois,
3533
spatial_scale,

torch_ipex/csrc/cpu/ExtendOPs.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,29 +11,29 @@ class AtenIpexTypeExt {
1111
public:
1212
static at::Tensor ROIAlign_forward(const at::Tensor& input,
1313
const at::Tensor& rois,
14-
const float spatial_scale,
15-
const int pooled_height,
16-
const int pooled_width,
17-
const int sampling_ratio);
14+
const double spatial_scale,
15+
const int64_t pooled_height,
16+
const int64_t pooled_width,
17+
const int64_t sampling_ratio);
1818

1919
static at::Tensor ROIAlign_backward(const at::Tensor& grad,
2020
const at::Tensor& rois,
21-
const float spatial_scale,
22-
const int pooled_height,
23-
const int pooled_width,
24-
const int batch_size,
25-
const int channels,
26-
const int height,
27-
const int width,
28-
const int sampling_ratio);
21+
const double spatial_scale,
22+
const int64_t pooled_height,
23+
const int64_t pooled_width,
24+
const int64_t batch_size,
25+
const int64_t channels,
26+
const int64_t height,
27+
const int64_t width,
28+
const int64_t sampling_ratio);
2929

3030
static at::Tensor nms(const at::Tensor& dets,
3131
const at::Tensor& scores,
32-
const float threshold);
32+
const double threshold);
3333

3434
static std::tuple<at::Tensor, at::Tensor, at::Tensor> batch_score_nms(const at::Tensor& dets,
3535
const at::Tensor& scores,
36-
const float threshold);
36+
const double threshold);
3737

3838
static at::Tensor interaction_forward(const std::vector<at::Tensor> & input);
3939
static std::vector<at::Tensor> interaction_backward(const at::Tensor & grad_out,

torch_ipex/csrc/cpu/ROIAlign.cpp

Lines changed: 72 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
22
#include "ExtendOPs.h"
3+
#include "torch_ipex/csrc/autocast_mode.h"
4+
#include "torch_ipex/csrc/autocast_verbose.h"
35
namespace torch_ipex {
46

57
// implementation taken from Caffe2
@@ -492,37 +494,93 @@ at::Tensor ROIAlign_backward_cpu(const at::Tensor& grad,
492494

493495
at::Tensor AtenIpexTypeExt::ROIAlign_forward(const at::Tensor& input,
494496
const at::Tensor& rois,
495-
const float spatial_scale,
496-
const int pooled_height,
497-
const int pooled_width,
498-
const int sampling_ratio) {
497+
const double spatial_scale,
498+
const int64_t pooled_height,
499+
const int64_t pooled_width,
500+
const int64_t sampling_ratio) {
499501
#if defined(IPEX_DISP_OP)
500502
printf("AtenIpexTypeExt::ROIAlign_forward\n");
501503
#endif
502504
#if defined(IPEX_PROFILE_OP)
503505
RECORD_FUNCTION("AtenIpexTypeExt::ROIAlign_forward", std::vector<c10::IValue>({}));
504506
#endif
505-
return ROIAlign_forward_cpu(input.contiguous().to(torch::kFloat), rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
507+
// input needs to be converted to contiguous temporarily, because ROIAlign does not support the channels-last format yet.
508+
return ROIAlign_forward_cpu(input.contiguous(), rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
506509
}
507510

508511
at::Tensor AtenIpexTypeExt::ROIAlign_backward(const at::Tensor& grad,
509512
const at::Tensor& rois,
510-
const float spatial_scale,
511-
const int pooled_height,
512-
const int pooled_width,
513-
const int batch_size,
514-
const int channels,
515-
const int height,
516-
const int width,
517-
const int sampling_ratio) {
513+
const double spatial_scale,
514+
const int64_t pooled_height,
515+
const int64_t pooled_width,
516+
const int64_t batch_size,
517+
const int64_t channels,
518+
const int64_t height,
519+
const int64_t width,
520+
const int64_t sampling_ratio) {
518521
#if defined(IPEX_DISP_OP)
519522
printf("AtenIpexTypeExt::ROIAlign_backward\n");
520523
#endif
521524
#if defined(IPEX_PROFILE_OP)
522525
RECORD_FUNCTION("AtenIpexTypeExt::ROIAlign_backward", std::vector<c10::IValue>({}));
523526
#endif
524-
return ROIAlign_backward_cpu(grad.contiguous().to(torch::kFloat), rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio);
527+
// grad needs to be converted to contiguous temporarily, because ROIAlign does not support the channels-last format yet.
528+
return ROIAlign_backward_cpu(grad.contiguous(), rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio);
525529
}
526530

531+
} // namespace torch_ipex
532+
533+
namespace {
534+
static auto dispatch =
535+
torch::RegisterOperators()
536+
.op("torch_ipex::ROIAlign_forward", &torch_ipex::AtenIpexTypeExt::ROIAlign_forward)
537+
.op("torch_ipex::ROIAlign_backward", &torch_ipex::AtenIpexTypeExt::ROIAlign_backward);
538+
}
539+
540+
namespace torch_ipex {
541+
namespace autocast {
542+
543+
at::Tensor ROIAlign_forward(const at::Tensor& input,
544+
const at::Tensor& rois,
545+
const double spatial_scale,
546+
const int64_t pooled_height,
547+
const int64_t pooled_width,
548+
const int64_t sampling_ratio) {
549+
c10::impl::ExcludeDispatchKeyGuard no_autocastCPU(DispatchKey::AutocastCPU);
550+
static auto op = torch::Dispatcher::singleton()
551+
.findSchemaOrThrow("torch_ipex::ROIAlign_forward", "")
552+
.typed<decltype(ROIAlign_forward)>();
553+
#if defined(ENABLE_AUTOCAST_VERBOSE)
554+
verbose::OpNameGuard op_name("ROIAlign_forward");
555+
#endif
556+
return op.call(cpu_cached_cast(at::kFloat, input), rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
527557
}
528558

559+
at::Tensor ROIAlign_backward(const at::Tensor& grad,
560+
const at::Tensor& rois,
561+
const double spatial_scale,
562+
const int64_t pooled_height,
563+
const int64_t pooled_width,
564+
const int64_t batch_size,
565+
const int64_t channels,
566+
const int64_t height,
567+
const int64_t width,
568+
const int64_t sampling_ratio) {
569+
c10::impl::ExcludeDispatchKeyGuard no_autocastCPU(DispatchKey::AutocastCPU);
570+
static auto op = torch::Dispatcher::singleton()
571+
.findSchemaOrThrow("torch_ipex::ROIAlign_backward", "")
572+
.typed<decltype(ROIAlign_backward)>();
573+
#if defined(ENABLE_AUTOCAST_VERBOSE)
574+
verbose::OpNameGuard op_name("ROIAlign_backward");
575+
#endif
576+
return op.call(cpu_cached_cast(at::kFloat, grad), rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio);
577+
}
578+
579+
TORCH_LIBRARY_IMPL(torch_ipex, AutocastCPU, m){
580+
m.impl("ROIAlign_forward", torch_ipex::autocast::ROIAlign_forward);
581+
m.impl("ROIAlign_backward", torch_ipex::autocast::ROIAlign_backward);
582+
}
583+
584+
} // namespace autocast
585+
} // namespace torch_ipex
586+

torch_ipex/csrc/cpu/nms.cpp

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#include <algorithm>
55
#include <c10/util/Exception.h>
66
#include <torch/csrc/autograd/function.h>
7+
#include "torch_ipex/csrc/autocast_mode.h"
8+
#include "torch_ipex/csrc/autocast_verbose.h"
79
namespace torch_ipex {
810

911
/*
@@ -176,7 +178,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> batch_score_nms_cpu(const at::Ten
176178

177179
at::Tensor AtenIpexTypeExt::nms(const at::Tensor& dets,
178180
const at::Tensor& scores,
179-
const float threshold) {
181+
const double threshold) {
180182
#if defined(IPEX_DISP_OP)
181183
printf("IpexExternal::nms\n");
182184
#endif
@@ -195,7 +197,7 @@ at::Tensor AtenIpexTypeExt::nms(const at::Tensor& dets,
195197

196198
std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexTypeExt::batch_score_nms(const at::Tensor& dets,
197199
const at::Tensor& scores,
198-
const float threshold) {
200+
const double threshold) {
199201
#if defined(IPEX_DISP_OP)
200202
printf("IpexExternal::batch_score_nms\n");
201203
#endif
@@ -211,4 +213,50 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexTypeExt::batch_score_nms(
211213
//return std::tuple<at::Tensor,at::Tensor,at::Tensor>(bridge::shallowUpgradeToDPCPPTensor(std::get<0>(_ipex_result)), bridge::shallowUpgradeToDPCPPTensor(std::get<1>(_ipex_result)), bridge::shallowUpgradeToDPCPPTensor(std::get<2>(_ipex_result)));
212214
return std::tuple<at::Tensor,at::Tensor,at::Tensor>(std::get<0>(result), std::get<1>(result), std::get<2>(result));
213215
}
216+
} // namespace torch_ipex
217+
218+
219+
namespace {
220+
static auto dispatch =
221+
torch::RegisterOperators()
222+
.op("torch_ipex::nms", &torch_ipex::AtenIpexTypeExt::nms)
223+
.op("torch_ipex::batch_score_nms", &torch_ipex::AtenIpexTypeExt::batch_score_nms);
214224
}
225+
226+
namespace torch_ipex {
227+
namespace autocast {
228+
229+
at::Tensor nms(const at::Tensor& dets,
230+
const at::Tensor& scores,
231+
const double threshold) {
232+
c10::impl::ExcludeDispatchKeyGuard no_autocastCPU(DispatchKey::AutocastCPU);
233+
static auto op = torch::Dispatcher::singleton()
234+
.findSchemaOrThrow("torch_ipex::nms", "")
235+
.typed<decltype(nms)>();
236+
#if defined(ENABLE_AUTOCAST_VERBOSE)
237+
verbose::OpNameGuard op_name("nms");
238+
#endif
239+
return op.call(dets, cpu_cached_cast(at::kFloat, scores), threshold);
240+
}
241+
242+
std::tuple<at::Tensor, at::Tensor, at::Tensor> batch_score_nms(const at::Tensor& dets,
243+
const at::Tensor& scores,
244+
const double threshold) {
245+
c10::impl::ExcludeDispatchKeyGuard no_autocastCPU(DispatchKey::AutocastCPU);
246+
static auto op = torch::Dispatcher::singleton()
247+
.findSchemaOrThrow("torch_ipex::batch_score_nms", "")
248+
.typed<decltype(batch_score_nms)>();
249+
#if defined(ENABLE_AUTOCAST_VERBOSE)
250+
verbose::OpNameGuard op_name("batch_score_nms");
251+
#endif
252+
return op.call(dets, cpu_cached_cast(at::kFloat, scores), threshold);
253+
}
254+
255+
TORCH_LIBRARY_IMPL(torch_ipex, AutocastCPU, m){
256+
m.impl("nms", torch_ipex::autocast::nms);
257+
m.impl("batch_score_nms", torch_ipex::autocast::batch_score_nms);
258+
}
259+
260+
} // namespace autocast
261+
} // namespace torch_ipex
262+

torch_ipex/csrc/init_python_bindings.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -176,11 +176,6 @@ void InitIpexModuleBindings(py::module m) {
176176
});
177177

178178
// extend OPs
179-
m.def("roi_align_forward", &AtenIpexTypeExt::ROIAlign_forward);
180-
m.def("roi_align_backward", &AtenIpexTypeExt::ROIAlign_backward);
181-
182-
m.def("nms", &AtenIpexTypeExt::nms);
183-
m.def("batch_score_nms", &AtenIpexTypeExt::batch_score_nms);
184179
m.def("embedding_bag_fast_path_sum", &AtenIpexTypeExt::embedding_bag_fast_path_sum);
185180
}
186181
} // namespace

0 commit comments

Comments
 (0)