diff --git a/AUTHORS b/AUTHORS index 14e8398fd..04c586d99 100644 --- a/AUTHORS +++ b/AUTHORS @@ -15,6 +15,7 @@ Aaron Hopkins Achille Roussel Alexey Palazhchenko Andrew Reid +Anthony Gruetzmacher Arne Hormann Asta Xie Bulat Gaifullin diff --git a/driver.go b/driver.go index d42ce7a3d..18f77adb0 100644 --- a/driver.go +++ b/driver.go @@ -17,11 +17,20 @@ package mysql import ( + "context" "database/sql" "database/sql/driver" + "errors" "net" ) +var ( + errInvalidUser = errors.New("invalid Connection: User is not set or longer than 32 chars") + errInvalidAddr = errors.New("invalid Connection: Addr config is missing") + errInvalidNet = errors.New("invalid Connection: Only tcp is valid for Net") + errInvalidDBName = errors.New("invalid Connection: DBName config is missing") +) + // watcher interface is used for context support (From Go 1.8) type watcher interface { startWatcher() @@ -29,7 +38,12 @@ type watcher interface { // MySQLDriver is exported to make the driver directly accessible. // In general the driver is used via the database/sql package. -type MySQLDriver struct{} +type MySQLDriver struct { +} + +type MySQLConnector struct { + Cfg *Config +} // DialFunc is a function which can be used to establish the network connection. // Custom dial functions must be registered with RegisterDial @@ -47,33 +61,22 @@ func RegisterDial(net string, dial DialFunc) { dials[net] = dial } -// Open new Connection. -// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how -// the DSN string is formated -func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { +//Open a new Connection +func connectServer(cxt context.Context, mc *mysqlConn) error { var err error - - // New mysqlConn - mc := &mysqlConn{ - maxAllowedPacket: maxPacketSize, - maxWriteSize: maxPacketSize - 1, - closech: make(chan struct{}), - } - mc.cfg, err = ParseDSN(dsn) - if err != nil { - return nil, err - } - mc.parseTime = mc.cfg.ParseTime - // Connect to Server if dial, ok := dials[mc.cfg.Net]; ok { mc.netConn, err = dial(mc.cfg.Addr) } else { nd := net.Dialer{Timeout: mc.cfg.Timeout} - mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) + if cxt == nil { + mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) + } else { + mc.netConn, err = nd.DialContext(cxt, mc.cfg.Net, mc.cfg.Addr) + } } if err != nil { - return nil, err + return err } // Enable TCP Keepalives on TCP connections @@ -82,7 +85,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { // Don't send COM_QUIT before handshake. mc.netConn.Close() mc.netConn = nil - return nil, err + return err } } @@ -101,13 +104,13 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { cipher, err := mc.readInitPacket() if err != nil { mc.cleanup() - return nil, err + return err } // Send Client Authentication Packet if err = mc.writeAuthPacket(cipher); err != nil { mc.cleanup() - return nil, err + return err } // Handle response to auth packet, switch methods if possible @@ -116,7 +119,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). // Do not send COM_QUIT, just cleanup and return the error. mc.cleanup() - return nil, err + return err } if mc.cfg.MaxAllowedPacket > 0 { @@ -126,7 +129,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { maxap, err := mc.getSystemVar("max_allowed_packet") if err != nil { mc.Close() - return nil, err + return err } mc.maxAllowedPacket = stringToInt(maxap) - 1 } @@ -134,6 +137,91 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.maxWriteSize = mc.maxAllowedPacket } + return err +} + +//Connect opens a new connection without using a DSN +func (c MySQLConnector) Connect(cxt context.Context) (driver.Conn, error) { + var err error + + //Validate the connection parameters + //the following are required User,Pass,Net,Addr,DBName + //Pass may be blank + //The other optional parameters are not checks + //as GO will automatically enforce proper bool types on the options + if len(c.Cfg.User) > 32 || len(c.Cfg.User) <= 0 { + return nil, errInvalidUser + } + + if len(c.Cfg.Addr) <= 0 { + return nil, errInvalidAddr + } + + if len(c.Cfg.DBName) <= 0 { + return nil, errInvalidDBName + } + + if c.Cfg.Net != "tcp" { + return nil, errInvalidNet + } + + //New mysqlConn + mc := &mysqlConn{ + maxAllowedPacket: maxPacketSize, + maxWriteSize: maxPacketSize - 1, + closech: make(chan struct{}), + cfg: c.Cfg, + parseTime: c.Cfg.ParseTime, + } + + //Check if the there is a canelation before creating the connection + select { + case <-cxt.Done(): + return nil, cxt.Err() + default: + //Connect to the server and setting the connection settings + err = connectServer(cxt, mc) + if err != nil { + return nil, err + } + } + return mc, nil +} + +//Driver returns a driver interface +func (d MySQLDriver) Driver() driver.Driver { + return MySQLDriver{} +} + +//Driver returns a driver interface +func (c MySQLConnector) Driver() driver.Driver { + return MySQLDriver{} +} + +// Open new Connection using a DSN. +// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how +// the DSN string is formated +func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { + var err error + + // New mysqlConn + mc := &mysqlConn{ + maxAllowedPacket: maxPacketSize, + maxWriteSize: maxPacketSize - 1, + closech: make(chan struct{}), + } + mc.cfg, err = ParseDSN(dsn) + if err != nil { + return nil, err + } + mc.parseTime = mc.cfg.ParseTime + + err = connectServer(nil, mc) + // Connect to Server + if err != nil { + return nil, err + } + // Handle DSN Params err = mc.handleParams() if err != nil { diff --git a/driver_go110_test.go b/driver_go110_test.go new file mode 100644 index 000000000..2af9b8769 --- /dev/null +++ b/driver_go110_test.go @@ -0,0 +1,82 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build go1.10 + +package mysql + +import ( + "context" + "database/sql" + "database/sql/driver" + "sync" + "testing" +) + +type Connector struct { + m sync.Mutex + mysql *MySQLConnector +} + +func (c *Connector) Connect(cxt context.Context) (driver.Conn, error) { + var err error + + if c.mysql == nil { + c.mysql = c.init() + } + + //Just use the global DSN because we just want to test the connector + //interface and we do not care about any custom functionality in the Connector + c.m.Lock() + c.mysql.Cfg, err = ParseDSN(dsn) + c.m.Unlock() + if err != nil { + println(err) + return nil, err + } + + return c.mysql.Connect(cxt) +} + +func (c *Connector) Driver() driver.Driver { + return c.mysql.Driver() +} + +func (c *Connector) init() *MySQLConnector { + return &MySQLConnector{} +} + +func runtestsWithConnector(t *testing.T, tests ...func(dbt *DBTest)) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + connector := &Connector{} + + db := sql.OpenDB(connector) + if err := db.Ping(); err != nil { + db.Close() + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + dbt := &DBTest{t, db} + for _, test := range tests { + test(dbt) + dbt.db.Exec("DROP TABLE IF EXISTS test") + } + +} + +func TestPingWithConnector(t *testing.T) { + runtestsWithConnector(t, func(dbt *DBTest) { + if err := dbt.db.Ping(); err != nil { + dbt.fail("Ping With Connector", "Ping With Connector", err) + } + }) +}