Skip to content

Commit 672c6aa

Browse files
jiayisunxEikanWang
authored andcommitted
modify nms
1 parent 403627d commit 672c6aa

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

torch_ipex/csrc/cpu/nms.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,13 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> batch_score_nms_kernel(const at::
100100
auto ndets = dets.size(0);
101101
auto nscore = scores.size(1);
102102

103-
std::vector<at::Tensor> scores_split = scores.split(1, 1);
103+
// Make the dimension of the score is always physically dense.
104+
std::vector<at::Tensor> scores_split;
105+
if (scores.stride(0) == 1) {
106+
scores_split = scores.split(1, 1);
107+
} else {
108+
scores_split = scores.t().contiguous().split(1, 0);
109+
}
104110

105111
std::vector<at::Tensor> bboxes_out(nscore);
106112
std::vector<at::Tensor> scores_out(nscore);
@@ -115,7 +121,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> batch_score_nms_kernel(const at::
115121
#endif
116122
// skip background (i = 0)
117123
for (int64_t i = 1; i < nscore; i++) {
118-
at::Tensor score = scores_split[i].squeeze(1);
124+
at::Tensor score = scores_split[i].squeeze();
119125

120126
at::Tensor mask_index = at::nonzero(score > 0.05).squeeze(1);
121127
at::Tensor bboxes = at::index_select(dets, /*dim*/0, mask_index);

0 commit comments

Comments
 (0)