Skip to content

Commit 8f146ca

Browse files
enable op: parallel_scale_back and Combined_NMS (#10)
1 parent 3fec41e commit 8f146ca

File tree

5 files changed

+379
-58
lines changed

5 files changed

+379
-58
lines changed

intel_pytorch_extension_py/ops/nms.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22

33
# Module-level aliases that expose operators looked up from the
# ``torch.ops.torch_ipex`` namespace under short Python names.
nms = torch.ops.torch_ipex.nms
44
batch_score_nms = torch.ops.torch_ipex.batch_score_nms
5+
# Alias introduced by this change (the only added line in this hunk).
parallel_scale_back_batch = torch.ops.torch_ipex.parallel_scale_back_batch

tests/cpu/data/nms_dboxes_xywh.pt

474 KB
Binary file not shown.

tests/cpu/test_nms.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
import unittest, copy
2+
import torch
3+
import torch.nn as nn
4+
import intel_pytorch_extension as ipex
5+
from common_utils import TestCase
6+
import time, sys
7+
from intel_pytorch_extension import batch_score_nms, parallel_scale_back_batch
8+
import torch.nn.functional as F
9+
import os
10+
11+
def get_rand_seed():
    """Return the current wall-clock time, scaled to nanoseconds, as an int."""
    nanos_per_second = 10 ** 9
    return int(time.time() * nanos_per_second)
13+
14+
# This function is from https://github.com/kuangliu/pytorch-ssd.
15+
def calc_iou_tensor(box1, box2):
    """Compute the pairwise IoU between two sets of boxes.

    Reference: https://github.com/kuangliu/pytorch-ssd

    Each box is read as (left, top, right, bottom): the first two columns
    are treated as the left-top corner and the last two as the right-bottom
    corner.

    Args:
        box1: tensor of shape (N, 4).
        box2: tensor of shape (M, 4).

    Returns:
        Tensor of shape (N, M) whose entry [i, j] is the IoU of box1[i] and
        box2[j]. For floating-point inputs, a pair whose union area is zero
        yields NaN (0 / 0). The inputs are not modified.
    """
    num_first = box1.size(0)
    num_second = box2.size(0)

    # Broadcast both sets to (N, M, 4) so every pair is handled at once.
    be1 = box1.unsqueeze(1).expand(-1, num_second, -1)
    be2 = box2.unsqueeze(0).expand(num_first, -1, -1)

    # Corners of the overlap rectangle: left-top and right-bottom.
    lt = torch.max(be1[:, :, :2], be2[:, :, :2])
    rb = torch.min(be1[:, :, 2:], be2[:, :, 2:])

    # Disjoint boxes give a negative extent; clamp so their overlap is zero.
    delta = (rb - lt).clamp(min=0)
    intersect = delta[:, :, 0] * delta[:, :, 1]

    delta1 = be1[:, :, 2:] - be1[:, :, :2]
    area1 = delta1[:, :, 0] * delta1[:, :, 1]
    delta2 = be2[:, :, 2:] - be2[:, :, :2]
    area2 = delta2[:, :, 0] * delta2[:, :, 1]

    return intersect / (area1 + area2 - intersect)
45+
46+
class TestScaleBackBatch(TestCase):
    """Compares ``parallel_scale_back_batch`` with a plain-PyTorch reference."""

    def scale_back_batch(self, bboxes_in, scores_in, dboxes_xywh, scale_xy, scale_wh):
        """
        Reference Python version of Encoder::scale_back_batch, see
        https://github.com/mlcommons/inference/blob/v0.7/others/cloud/single_stage_detector/pytorch/utils.py

        ``bboxes_in`` is overwritten in place (xywh -> ltrb) and returned
        together with the softmax of ``scores_in`` over its last dimension.
        """
        anchor_xy = dboxes_xywh[:, :, :2]
        anchor_wh = dboxes_xywh[:, :, 2:]

        # Apply the scalar scale factors, then decode against dboxes_xywh.
        bboxes_in[:, :, :2] = scale_xy * bboxes_in[:, :, :2]
        bboxes_in[:, :, 2:] = scale_wh * bboxes_in[:, :, 2:]
        bboxes_in[:, :, :2] = bboxes_in[:, :, :2] * anchor_wh + anchor_xy
        bboxes_in[:, :, 2:] = bboxes_in[:, :, 2:].exp() * anchor_wh

        # xywh -> ltrb. All four edges are computed before any of them is
        # written back, because the writes overwrite the values being read.
        half_w = 0.5 * bboxes_in[:, :, 2]
        half_h = 0.5 * bboxes_in[:, :, 3]
        left = bboxes_in[:, :, 0] - half_w
        top = bboxes_in[:, :, 1] - half_h
        right = bboxes_in[:, :, 0] + half_w
        bottom = bboxes_in[:, :, 1] + half_h
        bboxes_in[:, :, 0] = left
        bboxes_in[:, :, 1] = top
        bboxes_in[:, :, 2] = right
        bboxes_in[:, :, 3] = bottom

        normalized_scores = F.softmax(scores_in, dim=-1)
        return bboxes_in, normalized_scores

    def test_scale_back_batch_result(self):
        """The op must match the Python reference within 1e-4 tolerances."""
        batch_size, number_boxes = 16, 1024
        scale_xy, scale_wh = 0.1, 0.2

        predicted_loc = torch.randn((batch_size, number_boxes, 4)).contiguous().to(torch.float32)
        predicted_score = torch.randn((batch_size, number_boxes, 81)).contiguous().to(torch.float32)
        dboxes_xywh = torch.randn((1, number_boxes, 4)).contiguous().to(torch.float64)

        # The reference writes into its first argument, so it gets clones.
        expected_boxes, expected_scores = self.scale_back_batch(
            predicted_loc.clone(),
            predicted_score.clone(),
            dboxes_xywh.clone(),
            scale_xy,
            scale_wh,
        )
        actual_boxes, actual_scores = parallel_scale_back_batch(
            predicted_loc, predicted_score, dboxes_xywh, scale_xy, scale_wh
        )

        boxes_match = torch.allclose(expected_boxes, actual_boxes, rtol=1e-4, atol=1e-4)
        self.assertTrue(boxes_match)
        scores_match = torch.allclose(expected_scores, actual_scores, rtol=1e-4, atol=1e-4)
        self.assertTrue(scores_match)
78+
79+
class TestNMS(TestCase):
    """Checks ``batch_score_nms`` against a per-image Python reference."""

    def decode_single(self, bboxes_in, scores_in, criteria, max_output, max_num=200):
        """
        Python implementation of Encoder::decode_single, refer to
        https://github.com/mlcommons/inference/blob/v0.7/others/cloud/single_stage_detector/pytorch/utils.py

        Performs per-class non-maximum suppression
        (reference: https://github.com/amdegroot/ssd.pytorch).

        Args:
            bboxes_in: boxes of one image, indexed as [box, coordinate].
            scores_in: scores of one image, indexed as [box, class];
                class 0 is skipped as background.
            criteria: IoU threshold; candidates whose IoU with the picked
                box is not below it are dropped.
            max_output: maximum number of detections returned overall.
            max_num: maximum number of top-scoring candidates considered
                per class.

        Returns:
            Tuple ``(bboxes, labels, scores)`` ordered by ascending score.
            All three are empty when no class has a score above 0.05.
        """
        bboxes_out = []
        scores_out = []
        labels_out = []
        for i, score in enumerate(scores_in.split(1, 1)):
            # skip background
            if i == 0:
                continue
            score = score.squeeze(1)
            mask = score > 0.05
            bboxes, score = bboxes_in[mask, :], score[mask]
            if score.size(0) == 0:
                continue
            _, score_idx_sorted = score.sort(dim=0)
            # keep only the max_num highest-scoring candidates
            score_idx_sorted = score_idx_sorted[-max_num:]
            candidates = []
            while score_idx_sorted.numel() > 0:
                # ascending sort: the last index is the best remaining box
                idx = score_idx_sorted[-1].item()
                bboxes_sorted = bboxes[score_idx_sorted, :]
                bboxes_idx = bboxes[idx, :].unsqueeze(dim=0)
                # squeeze only the singleton column so the mask stays 1-D
                # even when a single candidate is left
                iou_sorted = calc_iou_tensor(bboxes_sorted, bboxes_idx).squeeze(1)
                # we only need iou < criteria
                score_idx_sorted = score_idx_sorted[iou_sorted < criteria]
                candidates.append(idx)

            bboxes_out.append(bboxes[candidates, :])
            scores_out.append(score[candidates])
            labels_out.extend([i] * len(candidates))

        if not bboxes_out:
            # torch.cat rejects an empty list; report "no detections" instead.
            return (
                bboxes_in.new_zeros((0, bboxes_in.size(-1))),
                torch.zeros(0, dtype=torch.long),
                scores_in.new_zeros((0,)),
            )

        bboxes_out = torch.cat(bboxes_out, dim=0)
        labels_out = torch.tensor(labels_out, dtype=torch.long)
        scores_out = torch.cat(scores_out, dim=0)

        _, max_ids = scores_out.sort(dim=0)
        max_ids = max_ids[-max_output:]
        return bboxes_out[max_ids, :], labels_out[max_ids], scores_out[max_ids]

    def test_nms_result(self):
        """``batch_score_nms`` must agree with ``decode_single`` per image."""
        batch_size = 1
        # NOTE(review): presumably has to match the number of boxes stored in
        # the nms_dboxes_xywh.pt fixture - confirm if the fixture changes.
        number_boxes = 15130
        scale_xy = 0.1
        scale_wh = 0.2
        criteria = 0.50
        max_output = 200
        predicted_loc = torch.randn((batch_size, number_boxes, 4)).contiguous().to(torch.float32)
        predicted_score = torch.randn((batch_size, number_boxes, 81)).contiguous().to(torch.float32)

        # Resolve the fixture from this file's absolute location: a relative
        # __file__ has an empty dirname, and plain string concatenation would
        # then point at "/data/..." in the filesystem root.
        data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
        dboxes_xywh = torch.load(os.path.join(data_dir, "nms_dboxes_xywh.pt"))

        bboxes, probs = parallel_scale_back_batch(
            predicted_loc, predicted_score, dboxes_xywh, scale_xy, scale_wh
        )
        bboxes_clone = bboxes.clone()
        probs_clone = probs.clone()

        expected = [
            self.decode_single(bbox.squeeze(0), prob.squeeze(0), criteria, max_output)
            for bbox, prob in zip(bboxes.split(1, 0), probs.split(1, 0))
        ]
        actual = batch_score_nms(bboxes_clone, probs_clone, criteria, max_output)

        for i in range(batch_size):
            loc, label, prob = expected[i]
            loc2, label2, prob2 = actual[i]
            self.assertTrue(torch.allclose(loc, loc2, rtol=1e-4, atol=1e-4))
            self.assertEqual(label, label2)
            self.assertTrue(torch.allclose(prob, prob2, rtol=1e-4, atol=1e-4))
149+
150+
if __name__ == '__main__':
    # unittest.main() runs the tests and, by default, exits the interpreter
    # itself, so its result is not bound to a name.
    unittest.main()

torch_ipex/csrc/cpu/ExtendOPs.h

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,24 @@ class AtenIpexTypeExt {
3131
const at::Tensor& scores,
3232
const double threshold);
3333

34-
static std::tuple<at::Tensor, at::Tensor, at::Tensor> batch_score_nms(const at::Tensor& dets,
34+
/// \brief Perform non-maximum suppression.
35+
///
36+
/// C++ version of Encoder::decode_single.
37+
/// Refer to https://github.com/mlcommons/inference/blob/v0.7/others/cloud/single_stage_detector/pytorch/utils.py.
38+
///
39+
/// \param dets: predicted loc in ltrb format, size [BS, number_boxes, 4], for example: [1, 15130, 4].
40+
/// \param scores: predicted score, size [BS, number_boxes, class_number], for example: [1, 15130, 81].
41+
/// \param threshold: IoU threshold (scalar); bboxes whose IoU is larger than the threshold are suppressed.
42+
/// \param max_output: the maximum number of output bboxes.
43+
///
44+
/// \return a list of tuples. Each tuple holds 3 tensors:
45+
/// bboxes_out_: the coordinates of the selected bboxes, size [max_output, 4].
46+
/// labels_out_: the label of each selected bbox, size [max_output].
47+
/// scores_out_: the score of each selected bbox, size [max_output].
48+
static std::vector<std::tuple<at::Tensor, at::Tensor, at::Tensor>> batch_score_nms(const at::Tensor& dets,
3549
const at::Tensor& scores,
36-
const double threshold);
50+
const double threshold,
51+
const int64_t max_output);
3752

3853
static at::Tensor interaction_forward(const std::vector<at::Tensor> & input);
3954
static std::vector<at::Tensor> interaction_backward(const at::Tensor & grad_out,
@@ -52,6 +67,25 @@ class AtenIpexTypeExt {
5267
int64_t mode,
5368
const c10::optional<int64_t> padding_idx);
5469

70+
/// \brief Scale the predicted loc and transform it from xywh to ltrb, and apply Softmax along the last dim of the predicted score.
71+
///
72+
/// C++ version of Encoder::scale_back_batch.
73+
/// Refer to https://github.com/mlcommons/inference/blob/v0.7/others/cloud/single_stage_detector/pytorch/utils.py.
74+
///
75+
/// \param bboxes_in: predicted loc in xywh format, size [BS, number_boxes, 4], for example: [1, 15130, 4].
76+
/// \param scores_in: predicted score, size [BS, number_boxes, class_number], for example: [1, 15130, 81].
77+
/// \param dboxes_xywh: scale factor for each bbox from predicted loc to true loc, size [1, number_boxes, 4].
78+
/// \param scale_xy: scale factor (scalar) of the xy dimension for bboxes_in.
79+
/// \param scale_wh: scale factor (scalar) of the wh dimension for bboxes_in.
80+
///
81+
/// \return tuple<bbox_result, scores_result>,
82+
/// bbox_result: True loc in ltrb format, size [BS, number_boxes, 4], for example: [1, 15130, 4].
83+
/// scores_result: Normalized score, size [BS, number_boxes, class_number], for example: [1, 15130, 81].
84+
static std::tuple<at::Tensor, at::Tensor> parallel_scale_back_batch(const at::Tensor& bboxes_in,
85+
const at::Tensor& scores_in,
86+
const at::Tensor& dboxes_xywh,
87+
const double scale_xy,
88+
const double scale_wh);
5589
};
5690

5791
} // namespace torch_ipex

0 commit comments

Comments
 (0)