4
4
import intel_pytorch_extension as ipex
5
5
from common_utils import TestCase
6
6
import time , sys
7
- from intel_pytorch_extension import batch_score_nms , parallel_scale_back_batch
7
+ from intel_pytorch_extension import batch_score_nms , parallel_scale_back_batch , nms
8
8
import torch .nn .functional as F
9
9
import os
10
10
@@ -118,17 +118,16 @@ def decode_single(self, bboxes_in, scores_in, criteria, max_output, max_num=200)
118
118
max_ids = max_ids [- max_output :]
119
119
return bboxes_out [max_ids , :], labels_out [max_ids ], scores_out [max_ids ]
120
120
121
- def test_nms_result (self ):
121
+ def test_batch_nms_result (self ):
122
122
batch_size = 1
123
123
number_boxes = 15130
124
124
scale_xy = 0.1
125
125
scale_wh = 0.2
126
126
criteria = 0.50
127
127
max_output = 200
128
- predicted_loc = torch .randn ((batch_size , number_boxes , 4 )).contiguous ().to (torch .float32 )
129
- predicted_score = torch .randn ((batch_size , number_boxes , 81 )).contiguous ().to (torch .float32 )
130
- dboxes_xywh = torch .randn ((1 , number_boxes , 4 )).contiguous ().to (torch .float64 )
131
- dboxes_xywh = torch .load (os .path .dirname (__file__ ) + "/data/nms_dboxes_xywh.pt" )
128
+ predicted_loc = torch .load (os .path .join (os .path .dirname (__file__ ), "data/nms_ploc.pt" )) # sizes: [1, 15130, 4]
129
+ predicted_score = torch .load (os .path .join (os .path .dirname (__file__ ), "data/nms_plabel.pt" )) # sizes: [1, 15130, 81]
130
+ dboxes_xywh = torch .load (os .path .join (os .path .dirname (__file__ ), "data/nms_dboxes_xywh.pt" ))
132
131
bboxes , probs = parallel_scale_back_batch (predicted_loc , predicted_score , dboxes_xywh , scale_xy , scale_wh )
133
132
bboxes_clone = bboxes .clone ()
134
133
probs_clone = probs .clone ()
@@ -147,5 +146,46 @@ def test_nms_result(self):
147
146
self .assertEqual (label , label2 )
148
147
self .assertTrue (torch .allclose (prob , prob2 , rtol = 1e-4 , atol = 1e-4 ))
149
148
149
def test_nms_kernel_result(self):
    """Cross-check the standalone ``nms`` kernel on recorded detector outputs.

    Loads the stored location, score and default-box tensors, decodes them
    with ``parallel_scale_back_batch`` and then, for every non-background
    class, runs ``nms`` three ways:

    * float32 inputs pre-sorted by descending score, last argument ``True``;
    * float32 inputs in their original order, last argument ``False``;
    * float64 copies of the pre-sorted inputs, last argument ``True``.

    NOTE(review): the last ``nms`` argument is presumably an "inputs are
    already sorted" flag -- confirm against the extension's binding.

    All three runs must keep the same number of boxes, and the kept boxes
    and scores must agree within ``rtol = atol = 1e-4``.
    """
    batch_size = 1
    class_number = 81
    scale_xy = 0.1
    scale_wh = 0.2
    criteria = 0.50

    here = os.path.dirname(__file__)
    predicted_loc = torch.load(os.path.join(here, "data/nms_ploc.pt"))  # sizes: [1, 15130, 4]
    predicted_score = torch.load(os.path.join(here, "data/nms_plabel.pt"))  # sizes: [1, 15130, 81]
    dboxes_xywh = torch.load(os.path.join(here, "data/nms_dboxes_xywh.pt"))
    bboxes, probs = parallel_scale_back_batch(
        predicted_loc, predicted_score, dboxes_xywh, scale_xy, scale_wh)

    def gather_sorted(values, keep):
        # Select the kept rows and sort along dim 0 so that the comparison
        # does not depend on the order in which each nms call emits indices.
        # No squeeze here: with a single kept box a squeeze would drop the
        # box axis and the sort would then shuffle that box's coordinates.
        ordered, _ = torch.sort(torch.index_select(values, 0, keep), 0)
        return ordered

    for bs in range(batch_size):
        loc = bboxes[bs].squeeze(0)
        # Class 0 is the background, so start from class 1.
        for class_id in range(1, class_number):
            score = probs[bs, :, class_id]

            score_sorted, indices = torch.sort(score, descending=True)
            loc_sorted = torch.index_select(loc, 0, indices)

            # Inputs are cloned so that no call can observe changes another
            # call may have made to its arguments.
            result = nms(loc_sorted.clone(), score_sorted.clone(), criteria, True)
            result_ref = nms(loc.clone(), score.clone(), criteria, False)
            result_ref2 = nms(
                loc_sorted.clone().to(dtype=torch.float64),
                score_sorted.clone().to(dtype=torch.float64),
                criteria,
                True)

            # Compare counts first: a mismatch should fail as an assertion
            # rather than as a shape/broadcast problem inside allclose.
            self.assertEqual(result.size(0), result_ref.size(0))
            self.assertEqual(result.size(0), result_ref2.size(0))

            bbox_keep = gather_sorted(loc_sorted, result)
            bbox_keep_ref = gather_sorted(loc, result_ref)
            bbox_keep_ref2 = gather_sorted(loc_sorted, result_ref2)

            score_keep = gather_sorted(score_sorted, result)
            score_keep_ref = gather_sorted(score, result_ref)
            score_keep_ref2 = gather_sorted(score_sorted, result_ref2)

            self.assertTrue(torch.allclose(bbox_keep, bbox_keep_ref, rtol=1e-4, atol=1e-4))
            self.assertTrue(torch.allclose(score_keep, score_keep_ref, rtol=1e-4, atol=1e-4))
            self.assertTrue(torch.allclose(bbox_keep, bbox_keep_ref2, rtol=1e-4, atol=1e-4))
            self.assertTrue(torch.allclose(score_keep, score_keep_ref2, rtol=1e-4, atol=1e-4))
189
+
150
190
if __name__ == '__main__':
    # Executed as a script: hand control to unittest's command-line runner,
    # which discovers and runs the test cases defined in this module.
    test = unittest.main()
0 commit comments