Skip to content

use Nms from opencv #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Aug 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 33 additions & 144 deletions MTM/NMS.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,74 +10,11 @@

@author: Laurent Thomas
"""
from __future__ import division, print_function # for compatibility with Py2
import pandas as pd

def Point_in_Rectangle(Point, Rectangle):
'''Return True if a point (x,y) is contained in a Rectangle(x, y, width, height)'''
# unpack variables
Px, Py = Point
Rx, Ry, w, h = Rectangle
import cv2

return (Rx <= Px) and (Px <= Rx + w -1) and (Ry <= Py) and (Py <= Ry + h -1) # simply test if x_Point is in the range of x for the rectangle


def computeIoU(BBox1,BBox2):
'''
Compute the IoU (Intersection over Union) between 2 rectangular bounding boxes defined by the top left (Xtop,Ytop) and bottom right (Xbot, Ybot) pixel coordinates
Code adapted from https://www.pyimagesearch.com/2016/11/07/intersection-over-union-iou-for-object-detection/
'''
#print('BBox1 : ', BBox1)
#print('BBox2 : ', BBox2)

# Unpack input (python3 - tuple are no more supported as input in function definition - PEP3113 - Tuple can be used in as argument in a call but the function will not unpack it automatically)
Xleft1, Ytop1, Width1, Height1 = BBox1
Xleft2, Ytop2, Width2, Height2 = BBox2

# Compute bottom coordinates
Xright1 = Xleft1 + Width1 -1 # we remove -1 from the width since we start with 1 pixel already (the top one)
Ybot1 = Ytop1 + Height1 -1 # idem for the height

Xright2 = Xleft2 + Width2 -1
Ybot2 = Ytop2 + Height2 -1

# determine the (x, y)-coordinates of the top left and bottom right points of the intersection rectangle
Xleft = max(Xleft1, Xleft2)
Ytop = max(Ytop1, Ytop2)
Xright = min(Xright1, Xright2)
Ybot = min(Ybot1, Ybot2)

# Compute boolean for inclusion
BBox1_in_BBox2 = Point_in_Rectangle((Xleft1, Ytop1), BBox2) and Point_in_Rectangle((Xleft1, Ybot1), BBox2) and Point_in_Rectangle((Xright1, Ytop1), BBox2) and Point_in_Rectangle((Xright1, Ybot1), BBox2)
BBox2_in_BBox1 = Point_in_Rectangle((Xleft2, Ytop2), BBox1) and Point_in_Rectangle((Xleft2, Ybot2), BBox1) and Point_in_Rectangle((Xright2, Ytop2), BBox1) and Point_in_Rectangle((Xright2, Ybot2), BBox1)

# Check that for the intersection box, Xtop,Ytop is indeed on the top left of Xbot,Ybot
if BBox1_in_BBox2 or BBox2_in_BBox1:
#print('One BBox is included within the other')
IoU = 1

elif Xright<Xleft or Ybot<Ytop : # it means that there is no intersection (bbox is inverted)
#print('No overlap')
IoU = 0

else:
# Compute area of the intersecting box
Inter = (Xright - Xleft + 1) * (Ybot - Ytop + 1) # +1 since we are dealing with pixels. See a 1D example with 3 pixels for instance
#print('Intersection area : ', Inter)

# Compute area of the union as Sum of the 2 BBox area - Intersection
Union = Width1 * Height1 + Width2 * Height2 - Inter
#print('Union : ', Union)

# Compute Intersection over union
IoU = Inter/Union

#print('IoU : ',IoU)
return IoU



def NMS(tableHit, scoreThreshold=None, sortAscending=False, N_object=float("inf"), maxOverlap=0.5):
def NMS(tableHit, scoreThreshold=0.5, sortAscending=False, N_object=float("inf"), maxOverlap=0.5):
'''
Perform Non-Maxima supression : it compares the hits after maxima/minima detection, and removes the ones that are too close (too large overlap)
This function works both with an optionnal threshold on the score, and number of detected bbox
Expand All @@ -95,105 +32,57 @@ def NMS(tableHit, scoreThreshold=None, sortAscending=False, N_object=float("inf"
- tableHit : (Panda DataFrame) Each row is a hit, with columns "TemplateName"(String),"BBox"(x,y,width,height),"Score"(float)

- scoreThreshold : Float (or None), used to remove hit with too low prediction score.
If sortDescending=True (ie we use a correlation measure so we want to keep large scores) the scores above that threshold are kept
While if we use sortDescending=False (we use a difference measure ie we want to keep low score), the scores below that threshold are kept
If sortAscending=False (ie we use a correlation measure so we want to keep large scores) the scores above that threshold are kept
If True (we use a difference measure ie we want to keep low score), the scores below that threshold are kept

- N_object : number of best hit to return (by increasing score). Min=1, eventhough it does not really make sense to do NMS with only 1 hit
- N_object : maximum number of hit to return. Default=-1, ie return all hit passing NMS
- maxOverlap : float between 0 and 1, the maximal overlap authorised between 2 bounding boxes, above this value, the bounding box of lower score is deleted
- sortAscending : use True when low score means better prediction (Difference-based score), True otherwise (Correlation score)

OUTPUT
Panda DataFrame with best detection after NMS, it contains max N detection (but potentially less)
'''
listBoxes = tableHit["BBox"].to_list()
listScores = tableHit["Score"].to_list()

# Apply threshold on prediction score
if scoreThreshold==None :
threshTable = tableHit.copy() # copy to avoid modifying the input list in place

elif not sortAscending : # We keep rows above the threshold
threshTable = tableHit[ tableHit['Score']>=scoreThreshold ]

elif sortAscending : # We keep hit below the threshold
threshTable = tableHit[ tableHit['Score']<=scoreThreshold ]

# Sort score to have best predictions first (ie lower score if difference-based, higher score if correlation-based)
# important as we loop testing the best boxes against the other boxes)
threshTable.sort_values("Score", ascending=sortAscending, inplace=True) # Warning here is fine


# Split the inital pool into Final Hit that are kept and restTable that can be tested
# Initialisation : 1st keep is kept for sure, restTable is the rest of the list
#print("\nInitialise final hit list with first best hit")
outTable = threshTable.iloc[[0]].to_dict('records') # double square bracket to recover a DataFrame
restTable = threshTable.iloc[1:].to_dict('records')


# Loop to compute overlap
while len(outTable)<N_object and restTable: # second condition is restTable is not empty

# Report state of the loop
#print("\n\n\nNext while iteration")
if N_object==1:

#print("-> Final hit list")
#for hit in outTable: print(hit)
# Get row with highest or lower score
if sortAscending:
outTable = tableHit[tableHit.Score == tableHit.Score.min()]
else:
outTable = tableHit[tableHit.Score == tableHit.Score.max()]

#print("\n-> Remaining hit list")
#for hit in restTable: print(hit)
return outTable

# pick the next best peak in the rest of peak
testHit_dico = restTable[0] # dico
test_bbox = testHit_dico['BBox']
#print("\nTest BBox:{} for overlap against higher score bboxes".format(test_bbox))

# Loop over hit in outTable to compute successively overlap with testHit
for hit_dico in outTable:

# Recover Bbox from hit
bbox2 = hit_dico['BBox']

# Compute the Intersection over Union between test_peak and current peak
IoU = computeIoU(test_bbox, bbox2)

# Initialise the boolean value to true before test of overlap
ToAppend = True

if IoU>maxOverlap:
ToAppend = False
#print("IoU above threshold\n")
break # no need to test overlap with the other peaks

else:
#print("IoU below threshold\n")
# no overlap for this particular (test_peak,peak) pair, keep looping to test the other (test_peak,peak)
continue


# After testing against all peaks (for loop is over), append or not the peak to final
if ToAppend:
# Move the test_hit from restTable to outTable
#print("Append {} to list of final hits, remove it from Remaining hit list".format(test_hit))
outTable.append(testHit_dico)
restTable.remove(testHit_dico)

else:
# only remove the test_peak from restTable
#print("Remove {} from Remaining hit list".format(test_hit))
restTable.remove(testHit_dico)
# N object > 1 -> do NMS
if sortAscending: # invert score to have always high-score for bets prediction
listScores = [1-score for score in listScores] # NMS expect high-score for good predictions
scoreThreshold = 1-scoreThreshold

# Do NMS
indexes = cv2.dnn.NMSBoxes(listBoxes, listScores, scoreThreshold, maxOverlap)

# Get N best hit
if N_object == float("inf"):
indexes = [ index[0] for index in indexes ] # ordered by score
else:
indexes = [ index[0] for index in indexes[:N_object] ]

outTable = tableHit.iloc[indexes]

# Once function execution is done, return list of hit without overlap
#print("\nCollected N expected hit, or no hit left to test")
#print("NMS over\n")
return pd.DataFrame(outTable)
return outTable


if __name__ == "__main__":
ListHit =[
import pandas as pd
listHit =[
{'TemplateName':1,'BBox':(780, 350, 700, 480), 'Score':0.8},
{'TemplateName':1,'BBox':(806, 416, 716, 442), 'Score':0.6},
{'TemplateName':1,'BBox':(1074, 530, 680, 390), 'Score':0.4}
]

FinalHits = NMS( pd.DataFrame(ListHit), scoreThreshold=0.7, sortAscending=False, maxOverlap=0.5 )
finalHits = NMS( pd.DataFrame(listHit), scoreThreshold=0.61, sortAscending=False, maxOverlap=0.8, N_object=1 )

print(FinalHits)
print(finalHits)
12 changes: 4 additions & 8 deletions MTM/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def findMatches(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=floa
- method : int
one of OpenCV template matching method (0 to 5), default 5=0-mean cross-correlation
- N_object: int
expected number of objects in the image
expected number of objects in the image, -1 if unknown
- score_threshold: float in range [0,1]
if N>1, returns local minima/maxima respectively below/above the score_threshold
- searchBox : tuple (X, Y, Width, Height) in pixel unit
Expand All @@ -89,9 +89,6 @@ def findMatches(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=floa
if N_object!=float("inf") and type(N_object)!=int:
raise TypeError("N_object must be an integer")

elif N_object<1:
raise ValueError("At least one object should be expected in the image")

## Crop image to search region if provided
if searchBox != None:
xOffset, yOffset, searchWidth, searchHeight = searchBox
Expand Down Expand Up @@ -176,12 +173,11 @@ def matchTemplates(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=f

tableHit = findMatches(listTemplates, image, method, N_object, score_threshold, searchBox)

if method == 1: bestHits = NMS(tableHit, N_object=N_object, maxOverlap=maxOverlap, sortAscending=True)
if method == 1: sortAscending = True
elif method in (3,5): sortAscending = False

elif method in (3,5): bestHits = NMS(tableHit, N_object=N_object, maxOverlap=maxOverlap, sortAscending=False)
return NMS(tableHit, score_threshold, sortAscending, N_object, maxOverlap)

return bestHits


def drawBoxesOnRGB(image, tableHit, boxThickness=2, boxColor=(255, 255, 00), showLabel=False, labelColor=(255, 255, 0), labelScale=0.5 ):
'''
Expand Down
2 changes: 1 addition & 1 deletion MTM/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
# 1) we don't load dependencies by storing it in __init__.py
# 2) we can import it in setup.py for the same reason
# 3) we can import it into your module module
__version__ = '1.5.3post1'
__version__ = '1.6.0'
5 changes: 4 additions & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,22 @@
import MTM, cv2
import numpy as np

print(MTM.__version__)

#%% Get image and templates by cropping
image = coins()
smallCoin = image[37:37+38, 80:80+41]
bigCoin = image[14:14+59,302:302+65]


#%% Perform matching
tableHit = MTM.matchTemplates([('small', smallCoin), ('big', bigCoin)], image, score_threshold=0.6, method=cv2.TM_CCOEFF_NORMED, maxOverlap=0) # Correlation-score
tableHit = MTM.matchTemplates([('small', smallCoin), ('big', bigCoin)], image, score_threshold=0.3, method=cv2.TM_CCOEFF_NORMED, maxOverlap=0) # Correlation-score
#tableHit = MTM.matchTemplates([('small', smallCoin), ('big', bigCoin)], image, score_threshold=0.4, method=cv2.TM_SQDIFF_NORMED, maxOverlap=0) # Difference-score

print("Found {} coins".format(len(tableHit)))
print(tableHit)


#%% Display matches
Overlay = MTM.drawBoxesOnRGB(image, tableHit, showLabel=True)
plt.figure()
Expand Down
Loading