diff --git a/driver.go b/driver.go index d42ce7a3d..27cf5ad4e 100644 --- a/driver.go +++ b/driver.go @@ -20,6 +20,7 @@ import ( "database/sql" "database/sql/driver" "net" + "sync" ) // watcher interface is used for context support (From Go 1.8) @@ -35,12 +36,17 @@ type MySQLDriver struct{} // Custom dial functions must be registered with RegisterDial type DialFunc func(addr string) (net.Conn, error) -var dials map[string]DialFunc +var ( + dialsLock sync.RWMutex + dials map[string]DialFunc +) // RegisterDial registers a custom dial function. It can then be used by the // network address mynet(addr), where mynet is the registered new network. // addr is passed as a parameter to the dial function. func RegisterDial(net string, dial DialFunc) { + dialsLock.Lock() + defer dialsLock.Unlock() if dials == nil { dials = make(map[string]DialFunc) } @@ -66,7 +72,10 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.parseTime = mc.cfg.ParseTime // Connect to Server - if dial, ok := dials[mc.cfg.Net]; ok { + dialsLock.RLock() + dial, ok := dials[mc.cfg.Net] + dialsLock.RUnlock() + if ok { mc.netConn, err = dial(mc.cfg.Addr) } else { nd := net.Dialer{Timeout: mc.cfg.Timeout}