diff --git a/connection.go b/connection.go index a62db429a..df07b955e 100644 --- a/connection.go +++ b/connection.go @@ -68,11 +68,19 @@ func (mc *mysqlConn) handleParams() (err error) { // time.Time parsing case "parseTime": - mc.parseTime = readBool(val) + var isBool bool + mc.parseTime, isBool = readBool(val) + if !isBool { + return errors.New("Invalid Bool value: " + val) + } // Strict mode case "strict": - mc.strict = readBool(val) + var isBool bool + mc.strict, isBool = readBool(val) + if !isBool { + return errors.New("Invalid Bool value: " + val) + } // Compression case "compress": diff --git a/driver_test.go b/driver_test.go index 246e83ee5..adf7af431 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1053,7 +1053,7 @@ func TestStmtMultiRows(t *testing.T) { } func TestConcurrent(t *testing.T) { - if readBool(os.Getenv("MYSQL_TEST_CONCURRENT")) != true { + if enabled, _ := readBool(os.Getenv("MYSQL_TEST_CONCURRENT")); !enabled { t.Skip("CONCURRENT env var not set") } diff --git a/utils.go b/utils.go index e40fcb2f1..6658ef6f3 100644 --- a/utils.go +++ b/utils.go @@ -23,63 +23,25 @@ import ( "time" ) -// NullTime represents a time.Time that may be NULL. -// NullTime implements the Scanner interface so -// it can be used as a scan destination: -// -// var nt NullTime -// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt) -// ... -// if nt.Valid { -// // use nt.Time -// } else { -// // NULL value -// } -// -// This NullTime implementation is not driver-specific -type NullTime struct { - Time time.Time - Valid bool // Valid is true if Time is not NULL -} - -// Scan implements the Scanner interface. -// The value type must be time.Time or string / []byte (formatted time-string), -// otherwise Scan fails. -func (nt *NullTime) Scan(value interface{}) (err error) { - if value == nil { - nt.Time, nt.Valid = time.Time{}, false - return - } +var ( + errLog *log.Logger // Error Logger + dsnPattern *regexp.Regexp // Data Source Name Parser + tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs +) - switch v := value.(type) { - case time.Time: - nt.Time, nt.Valid = v, true - return - case []byte: - nt.Time, err = parseDateTime(string(v), time.UTC) - nt.Valid = (err == nil) - return - case string: - nt.Time, err = parseDateTime(v, time.UTC) - nt.Valid = (err == nil) - return - } +func init() { + errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile) - nt.Valid = false - return fmt.Errorf("Can't convert %T to time.Time", value) -} + dsnPattern = regexp.MustCompile( + `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] + `(?:(?P[^\(]*)(?:\((?P[^\)]*)\))?)?` + // [net[(addr)]] + `\/(?P.*?)` + // /dbname + `(?:\?(?P[^\?]*))?$`) // [?param1=value1¶mN=valueN] -// Value implements the driver Valuer interface. -func (nt NullTime) Value() (driver.Value, error) { - if !nt.Valid { - return nil, nil - } - return nt.Time, nil + tlsConfigRegister = make(map[string]*tls.Config) } -var tlsConfigMap map[string]*tls.Config - -// Registers a custom tls.Config to be used with sql.Open. +// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open. // Use the key as a value in the DSN where tls=value. // // rootCertPool := x509.NewCertPool() @@ -102,39 +64,20 @@ var tlsConfigMap map[string]*tls.Config // }) // db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") // -func RegisterTLSConfig(key string, config *tls.Config) { - if tlsConfigMap == nil { - tlsConfigMap = make(map[string]*tls.Config) +func RegisterTLSConfig(key string, config *tls.Config) error { + if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" { + return fmt.Errorf("Key '%s' is reserved", key) } - tlsConfigMap[key] = config -} -// Removes tls.Config associated with key. -func DeregisterTLSConfig(key string) { - if tlsConfigMap == nil { - return - } - delete(tlsConfigMap, key) + tlsConfigRegister[key] = config + return nil } -// Logger -var ( - errLog *log.Logger -) - -func init() { - errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile) - - dsnPattern = regexp.MustCompile( - `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] - `(?:(?P[^\(]*)(?:\((?P[^\)]*)\))?)?` + // [net[(addr)]] - `\/(?P.*?)` + // /dbname - `(?:\?(?P[^\?]*))?$`) // [?param1=value1¶mN=valueN] +// DeregisterTLSConfig removes the tls.Config associated with key. +func DeregisterTLSConfig(key string) { + delete(tlsConfigRegister, key) } -// Data Source Name Parser -var dsnPattern *regexp.Regexp - func parseDSN(dsn string) (cfg *config, err error) { cfg = new(config) cfg.params = make(map[string]string) @@ -166,11 +109,21 @@ func parseDSN(dsn string) (cfg *config, err error) { // Disable INFILE whitelist / enable all files case "allowAllFiles": - cfg.allowAllFiles = readBool(value) + var isBool bool + cfg.allowAllFiles, isBool = readBool(value) + if !isBool { + err = fmt.Errorf("Invalid Bool value: %s", value) + return + } // Switch "rowsAffected" mode case "clientFoundRows": - cfg.clientFoundRows = readBool(value) + var isBool bool + cfg.clientFoundRows, isBool = readBool(value) + if !isBool { + err = fmt.Errorf("Invalid Bool value: %s", value) + return + } // Time Location case "loc": @@ -188,13 +141,20 @@ func parseDSN(dsn string) (cfg *config, err error) { // TLS-Encryption case "tls": - if readBool(value) { - cfg.tls = &tls.Config{} - } else if strings.ToLower(value) == "skip-verify" { - cfg.tls = &tls.Config{InsecureSkipVerify: true} - // TODO: Check for Boolean false - } else if tlsConfig, ok := tlsConfigMap[value]; ok { - cfg.tls = tlsConfig + boolValue, isBool := readBool(value) + if isBool { + if boolValue { + cfg.tls = &tls.Config{} + } + } else { + if strings.ToLower(value) == "skip-verify" { + cfg.tls = &tls.Config{InsecureSkipVerify: true} + } else if tlsConfig, ok := tlsConfigRegister[value]; ok { + cfg.tls = tlsConfig + } else { + err = fmt.Errorf("Invalid value / unknown config name: %s", value) + return + } } default: @@ -253,6 +213,78 @@ func scramblePassword(scramble, password []byte) []byte { return scramble } +// Returns the bool value of the input. +// The 2nd return value indicates if the input was a valid bool value +func readBool(input string) (value bool, valid bool) { + switch input { + case "1", "true", "TRUE", "True": + return true, true + case "0", "false", "FALSE", "False": + return false, true + } + + // Not a valid bool value + return +} + +/****************************************************************************** +* Time related utils * +******************************************************************************/ + +// NullTime represents a time.Time that may be NULL. +// NullTime implements the Scanner interface so +// it can be used as a scan destination: +// +// var nt NullTime +// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt) +// ... +// if nt.Valid { +// // use nt.Time +// } else { +// // NULL value +// } +// +// This NullTime implementation is not driver-specific +type NullTime struct { + Time time.Time + Valid bool // Valid is true if Time is not NULL +} + +// Scan implements the Scanner interface. +// The value type must be time.Time or string / []byte (formatted time-string), +// otherwise Scan fails. +func (nt *NullTime) Scan(value interface{}) (err error) { + if value == nil { + nt.Time, nt.Valid = time.Time{}, false + return + } + + switch v := value.(type) { + case time.Time: + nt.Time, nt.Valid = v, true + return + case []byte: + nt.Time, err = parseDateTime(string(v), time.UTC) + nt.Valid = (err == nil) + return + case string: + nt.Time, err = parseDateTime(v, time.UTC) + nt.Valid = (err == nil) + return + } + + nt.Valid = false + return fmt.Errorf("Can't convert %T to time.Time", value) +} + +// Value implements the driver Valuer interface. +func (nt NullTime) Value() (driver.Value, error) { + if !nt.Valid { + return nil, nil + } + return nt.Time, nil +} + func parseDateTime(str string, loc *time.Location) (t time.Time, err error) { switch len(str) { case 10: // YYYY-MM-DD @@ -369,16 +401,6 @@ func formatBinaryDateTime(num uint64, data []byte) (driver.Value, error) { return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num) } -func readBool(value string) bool { - switch strings.ToLower(value) { - case "true": - return true - case "1": - return true - } - return false -} - /****************************************************************************** * Convert from and to bytes * ******************************************************************************/