Skip to content

Commit d9a234a

Browse files
authored
Merge pull request #6 from multi-template-matching/nms-opencv
use Nms from opencv
2 parents 55911d0 + df284f6 commit d9a234a

File tree

6 files changed

+544
-156
lines changed

6 files changed

+544
-156
lines changed

MTM/NMS.py

Lines changed: 33 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -10,74 +10,11 @@
1010
1111
@author: Laurent Thomas
1212
"""
13-
from __future__ import division, print_function # for compatibility with Py2
14-
import pandas as pd
1513

16-
def Point_in_Rectangle(Point, Rectangle):
17-
'''Return True if a point (x,y) is contained in a Rectangle(x, y, width, height)'''
18-
# unpack variables
19-
Px, Py = Point
20-
Rx, Ry, w, h = Rectangle
14+
import cv2
2115

22-
return (Rx <= Px) and (Px <= Rx + w -1) and (Ry <= Py) and (Py <= Ry + h -1) # simply test if x_Point is in the range of x for the rectangle
2316

24-
25-
def computeIoU(BBox1,BBox2):
26-
'''
27-
Compute the IoU (Intersection over Union) between 2 rectangular bounding boxes defined by the top left (Xtop,Ytop) and bottom right (Xbot, Ybot) pixel coordinates
28-
Code adapted from https://www.pyimagesearch.com/2016/11/07/intersection-over-union-iou-for-object-detection/
29-
'''
30-
#print('BBox1 : ', BBox1)
31-
#print('BBox2 : ', BBox2)
32-
33-
# Unpack input (python3 - tuple are no more supported as input in function definition - PEP3113 - Tuple can be used in as argument in a call but the function will not unpack it automatically)
34-
Xleft1, Ytop1, Width1, Height1 = BBox1
35-
Xleft2, Ytop2, Width2, Height2 = BBox2
36-
37-
# Compute bottom coordinates
38-
Xright1 = Xleft1 + Width1 -1 # we remove -1 from the width since we start with 1 pixel already (the top one)
39-
Ybot1 = Ytop1 + Height1 -1 # idem for the height
40-
41-
Xright2 = Xleft2 + Width2 -1
42-
Ybot2 = Ytop2 + Height2 -1
43-
44-
# determine the (x, y)-coordinates of the top left and bottom right points of the intersection rectangle
45-
Xleft = max(Xleft1, Xleft2)
46-
Ytop = max(Ytop1, Ytop2)
47-
Xright = min(Xright1, Xright2)
48-
Ybot = min(Ybot1, Ybot2)
49-
50-
# Compute boolean for inclusion
51-
BBox1_in_BBox2 = Point_in_Rectangle((Xleft1, Ytop1), BBox2) and Point_in_Rectangle((Xleft1, Ybot1), BBox2) and Point_in_Rectangle((Xright1, Ytop1), BBox2) and Point_in_Rectangle((Xright1, Ybot1), BBox2)
52-
BBox2_in_BBox1 = Point_in_Rectangle((Xleft2, Ytop2), BBox1) and Point_in_Rectangle((Xleft2, Ybot2), BBox1) and Point_in_Rectangle((Xright2, Ytop2), BBox1) and Point_in_Rectangle((Xright2, Ybot2), BBox1)
53-
54-
# Check that for the intersection box, Xtop,Ytop is indeed on the top left of Xbot,Ybot
55-
if BBox1_in_BBox2 or BBox2_in_BBox1:
56-
#print('One BBox is included within the other')
57-
IoU = 1
58-
59-
elif Xright<Xleft or Ybot<Ytop : # it means that there is no intersection (bbox is inverted)
60-
#print('No overlap')
61-
IoU = 0
62-
63-
else:
64-
# Compute area of the intersecting box
65-
Inter = (Xright - Xleft + 1) * (Ybot - Ytop + 1) # +1 since we are dealing with pixels. See a 1D example with 3 pixels for instance
66-
#print('Intersection area : ', Inter)
67-
68-
# Compute area of the union as Sum of the 2 BBox area - Intersection
69-
Union = Width1 * Height1 + Width2 * Height2 - Inter
70-
#print('Union : ', Union)
71-
72-
# Compute Intersection over union
73-
IoU = Inter/Union
74-
75-
#print('IoU : ',IoU)
76-
return IoU
77-
78-
79-
80-
def NMS(tableHit, scoreThreshold=None, sortAscending=False, N_object=float("inf"), maxOverlap=0.5):
17+
def NMS(tableHit, scoreThreshold=0.5, sortAscending=False, N_object=float("inf"), maxOverlap=0.5):
8118
'''
8219
Perform Non-Maxima supression : it compares the hits after maxima/minima detection, and removes the ones that are too close (too large overlap)
8320
This function works both with an optionnal threshold on the score, and number of detected bbox
@@ -95,105 +32,57 @@ def NMS(tableHit, scoreThreshold=None, sortAscending=False, N_object=float("inf"
9532
- tableHit : (Panda DataFrame) Each row is a hit, with columns "TemplateName"(String),"BBox"(x,y,width,height),"Score"(float)
9633
9734
- scoreThreshold : Float (or None), used to remove hit with too low prediction score.
98-
If sortDescending=True (ie we use a correlation measure so we want to keep large scores) the scores above that threshold are kept
99-
While if we use sortDescending=False (we use a difference measure ie we want to keep low score), the scores below that threshold are kept
35+
If sortAscending=False (ie we use a correlation measure so we want to keep large scores) the scores above that threshold are kept
36+
If True (we use a difference measure ie we want to keep low score), the scores below that threshold are kept
10037
101-
- N_object : number of best hit to return (by increasing score). Min=1, eventhough it does not really make sense to do NMS with only 1 hit
38+
- N_object : maximum number of hit to return. Default=-1, ie return all hit passing NMS
10239
- maxOverlap : float between 0 and 1, the maximal overlap authorised between 2 bounding boxes, above this value, the bounding box of lower score is deleted
10340
- sortAscending : use True when low score means better prediction (Difference-based score), True otherwise (Correlation score)
10441
10542
OUTPUT
10643
Panda DataFrame with best detection after NMS, it contains max N detection (but potentially less)
10744
'''
45+
listBoxes = tableHit["BBox"].to_list()
46+
listScores = tableHit["Score"].to_list()
10847

109-
# Apply threshold on prediction score
110-
if scoreThreshold==None :
111-
threshTable = tableHit.copy() # copy to avoid modifying the input list in place
112-
113-
elif not sortAscending : # We keep rows above the threshold
114-
threshTable = tableHit[ tableHit['Score']>=scoreThreshold ]
115-
116-
elif sortAscending : # We keep hit below the threshold
117-
threshTable = tableHit[ tableHit['Score']<=scoreThreshold ]
118-
119-
# Sort score to have best predictions first (ie lower score if difference-based, higher score if correlation-based)
120-
# important as we loop testing the best boxes against the other boxes)
121-
threshTable.sort_values("Score", ascending=sortAscending, inplace=True) # Warning here is fine
122-
123-
124-
# Split the inital pool into Final Hit that are kept and restTable that can be tested
125-
# Initialisation : 1st keep is kept for sure, restTable is the rest of the list
126-
#print("\nInitialise final hit list with first best hit")
127-
outTable = threshTable.iloc[[0]].to_dict('records') # double square bracket to recover a DataFrame
128-
restTable = threshTable.iloc[1:].to_dict('records')
129-
130-
131-
# Loop to compute overlap
132-
while len(outTable)<N_object and restTable: # second condition is restTable is not empty
133-
134-
# Report state of the loop
135-
#print("\n\n\nNext while iteration")
48+
if N_object==1:
13649

137-
#print("-> Final hit list")
138-
#for hit in outTable: print(hit)
50+
# Get row with highest or lower score
51+
if sortAscending:
52+
outTable = tableHit[tableHit.Score == tableHit.Score.min()]
53+
else:
54+
outTable = tableHit[tableHit.Score == tableHit.Score.max()]
13955

140-
#print("\n-> Remaining hit list")
141-
#for hit in restTable: print(hit)
56+
return outTable
14257

143-
# pick the next best peak in the rest of peak
144-
testHit_dico = restTable[0] # dico
145-
test_bbox = testHit_dico['BBox']
146-
#print("\nTest BBox:{} for overlap against higher score bboxes".format(test_bbox))
147-
148-
# Loop over hit in outTable to compute successively overlap with testHit
149-
for hit_dico in outTable:
150-
151-
# Recover Bbox from hit
152-
bbox2 = hit_dico['BBox']
153-
154-
# Compute the Intersection over Union between test_peak and current peak
155-
IoU = computeIoU(test_bbox, bbox2)
156-
157-
# Initialise the boolean value to true before test of overlap
158-
ToAppend = True
15958

160-
if IoU>maxOverlap:
161-
ToAppend = False
162-
#print("IoU above threshold\n")
163-
break # no need to test overlap with the other peaks
164-
165-
else:
166-
#print("IoU below threshold\n")
167-
# no overlap for this particular (test_peak,peak) pair, keep looping to test the other (test_peak,peak)
168-
continue
169-
170-
171-
# After testing against all peaks (for loop is over), append or not the peak to final
172-
if ToAppend:
173-
# Move the test_hit from restTable to outTable
174-
#print("Append {} to list of final hits, remove it from Remaining hit list".format(test_hit))
175-
outTable.append(testHit_dico)
176-
restTable.remove(testHit_dico)
177-
178-
else:
179-
# only remove the test_peak from restTable
180-
#print("Remove {} from Remaining hit list".format(test_hit))
181-
restTable.remove(testHit_dico)
59+
# N object > 1 -> do NMS
60+
if sortAscending: # invert score to have always high-score for bets prediction
61+
listScores = [1-score for score in listScores] # NMS expect high-score for good predictions
62+
scoreThreshold = 1-scoreThreshold
63+
64+
# Do NMS
65+
indexes = cv2.dnn.NMSBoxes(listBoxes, listScores, scoreThreshold, maxOverlap)
18266

67+
# Get N best hit
68+
if N_object == float("inf"):
69+
indexes = [ index[0] for index in indexes ] # ordered by score
70+
else:
71+
indexes = [ index[0] for index in indexes[:N_object] ]
72+
73+
outTable = tableHit.iloc[indexes]
18374

184-
# Once function execution is done, return list of hit without overlap
185-
#print("\nCollected N expected hit, or no hit left to test")
186-
#print("NMS over\n")
187-
return pd.DataFrame(outTable)
75+
return outTable
18876

18977

19078
if __name__ == "__main__":
191-
ListHit =[
79+
import pandas as pd
80+
listHit =[
19281
{'TemplateName':1,'BBox':(780, 350, 700, 480), 'Score':0.8},
19382
{'TemplateName':1,'BBox':(806, 416, 716, 442), 'Score':0.6},
19483
{'TemplateName':1,'BBox':(1074, 530, 680, 390), 'Score':0.4}
19584
]
19685

197-
FinalHits = NMS( pd.DataFrame(ListHit), scoreThreshold=0.7, sortAscending=False, maxOverlap=0.5 )
86+
finalHits = NMS( pd.DataFrame(listHit), scoreThreshold=0.61, sortAscending=False, maxOverlap=0.8, N_object=1 )
19887

199-
print(FinalHits)
88+
print(finalHits)

MTM/__init__.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def findMatches(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=floa
7676
- method : int
7777
one of OpenCV template matching method (0 to 5), default 5=0-mean cross-correlation
7878
- N_object: int
79-
expected number of objects in the image
79+
expected number of objects in the image, -1 if unknown
8080
- score_threshold: float in range [0,1]
8181
if N>1, returns local minima/maxima respectively below/above the score_threshold
8282
- searchBox : tuple (X, Y, Width, Height) in pixel unit
@@ -89,9 +89,6 @@ def findMatches(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=floa
8989
if N_object!=float("inf") and type(N_object)!=int:
9090
raise TypeError("N_object must be an integer")
9191

92-
elif N_object<1:
93-
raise ValueError("At least one object should be expected in the image")
94-
9592
## Crop image to search region if provided
9693
if searchBox != None:
9794
xOffset, yOffset, searchWidth, searchHeight = searchBox
@@ -176,12 +173,11 @@ def matchTemplates(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=f
176173

177174
tableHit = findMatches(listTemplates, image, method, N_object, score_threshold, searchBox)
178175

179-
if method == 1: bestHits = NMS(tableHit, N_object=N_object, maxOverlap=maxOverlap, sortAscending=True)
176+
if method == 1: sortAscending = True
177+
elif method in (3,5): sortAscending = False
180178

181-
elif method in (3,5): bestHits = NMS(tableHit, N_object=N_object, maxOverlap=maxOverlap, sortAscending=False)
179+
return NMS(tableHit, score_threshold, sortAscending, N_object, maxOverlap)
182180

183-
return bestHits
184-
185181

186182
def drawBoxesOnRGB(image, tableHit, boxThickness=2, boxColor=(255, 255, 00), showLabel=False, labelColor=(255, 255, 0), labelScale=0.5 ):
187183
'''

MTM/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
# 1) we don't load dependencies by storing it in __init__.py
33
# 2) we can import it in setup.py for the same reason
44
# 3) we can import it into your module module
5-
__version__ = '1.5.3post1'
5+
__version__ = '1.6.0'

test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,22 @@
66
import MTM, cv2
77
import numpy as np
88

9+
print(MTM.__version__)
10+
911
#%% Get image and templates by cropping
1012
image = coins()
1113
smallCoin = image[37:37+38, 80:80+41]
1214
bigCoin = image[14:14+59,302:302+65]
1315

1416

1517
#%% Perform matching
16-
tableHit = MTM.matchTemplates([('small', smallCoin), ('big', bigCoin)], image, score_threshold=0.6, method=cv2.TM_CCOEFF_NORMED, maxOverlap=0) # Correlation-score
18+
tableHit = MTM.matchTemplates([('small', smallCoin), ('big', bigCoin)], image, score_threshold=0.3, method=cv2.TM_CCOEFF_NORMED, maxOverlap=0) # Correlation-score
1719
#tableHit = MTM.matchTemplates([('small', smallCoin), ('big', bigCoin)], image, score_threshold=0.4, method=cv2.TM_SQDIFF_NORMED, maxOverlap=0) # Difference-score
1820

1921
print("Found {} coins".format(len(tableHit)))
2022
print(tableHit)
2123

24+
2225
#%% Display matches
2326
Overlay = MTM.drawBoxesOnRGB(image, tableHit, showLabel=True)
2427
plt.figure()

0 commit comments

Comments
 (0)