Skip to content

Commit c63239b

Browse files
committed
by pass NMS for N=1
1 parent 19be6e9 commit c63239b

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

MTM/NMS.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,26 @@ def NMS(tableHit, scoreThreshold=0, sortAscending=False, N_object=float("inf"),
4545
listBoxes = tableHit["BBox"].to_list()
4646
listScores = tableHit["Score"].to_list()
4747

48-
if sortAscending:
48+
if N_object==1:
49+
50+
# Get row with highest or lower score
51+
if sortAscending:
52+
outTable = tableHit[tableHit.Score == tableHit.Score.min()]
53+
else:
54+
outTable = tableHit[tableHit.Score == tableHit.Score.max()]
55+
56+
return outTable
57+
58+
59+
# N object > 1 -> do NMS
60+
if sortAscending: # invert score to have always high-score for bets prediction
4961
listScores = [1-score for score in listScores] # NMS expect high-score for good predictions
5062
scoreThreshold = 1-scoreThreshold
51-
63+
64+
# Do NMS
5265
indexes = cv2.dnn.NMSBoxes(listBoxes, listScores, scoreThreshold, maxOverlap)
5366

67+
# Get N best hit
5468
if N_object == float("inf"):
5569
indexes = [ index[0] for index in indexes ] # ordered by score
5670
else:
@@ -69,6 +83,6 @@ def NMS(tableHit, scoreThreshold=0, sortAscending=False, N_object=float("inf"),
6983
{'TemplateName':1,'BBox':(1074, 530, 680, 390), 'Score':0.4}
7084
]
7185

72-
FinalHits = NMS( pd.DataFrame(ListHit), scoreThreshold=0.61, sortAscending=True, maxOverlap=0.8, N_object=2 )
86+
FinalHits = NMS( pd.DataFrame(ListHit), scoreThreshold=0.61, sortAscending=False, maxOverlap=0.8, N_object=1 )
7387

7488
print(FinalHits)

0 commit comments

Comments
 (0)