From 0b2f41a68c0d3ddd37630455cf10816967cef462 Mon Sep 17 00:00:00 2001 From: Ross Light Date: Tue, 27 Mar 2018 15:05:04 -0700 Subject: [PATCH] Make RegisterDial safe to call from multiple goroutines. Fixes #772 --- driver.go | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/driver.go b/driver.go index d42ce7a3d..27cf5ad4e 100644 --- a/driver.go +++ b/driver.go @@ -20,6 +20,7 @@ import ( "database/sql" "database/sql/driver" "net" + "sync" ) // watcher interface is used for context support (From Go 1.8) @@ -35,12 +36,17 @@ type MySQLDriver struct{} // Custom dial functions must be registered with RegisterDial type DialFunc func(addr string) (net.Conn, error) -var dials map[string]DialFunc +var ( + dialsLock sync.RWMutex + dials map[string]DialFunc +) // RegisterDial registers a custom dial function. It can then be used by the // network address mynet(addr), where mynet is the registered new network. // addr is passed as a parameter to the dial function. func RegisterDial(net string, dial DialFunc) { + dialsLock.Lock() + defer dialsLock.Unlock() if dials == nil { dials = make(map[string]DialFunc) } @@ -66,7 +72,10 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.parseTime = mc.cfg.ParseTime // Connect to Server - if dial, ok := dials[mc.cfg.Net]; ok { + dialsLock.RLock() + dial, ok := dials[mc.cfg.Net] + dialsLock.RUnlock() + if ok { mc.netConn, err = dial(mc.cfg.Addr) } else { nd := net.Dialer{Timeout: mc.cfg.Timeout}