From abd1799c82ed6da95477d57a6d8682f751b4d475 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Tue, 2 Jul 2013 00:13:55 +0200 Subject: [PATCH 1/2] Remove possible data race in tls.Config map init also resorted code --- utils.go | 174 +++++++++++++++++++++++++++---------------------------- 1 file changed, 85 insertions(+), 89 deletions(-) diff --git a/utils.go b/utils.go index e40fcb2f1..b8bb4b867 100644 --- a/utils.go +++ b/utils.go @@ -23,62 +23,24 @@ import ( "time" ) -// NullTime represents a time.Time that may be NULL. -// NullTime implements the Scanner interface so -// it can be used as a scan destination: -// -// var nt NullTime -// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt) -// ... -// if nt.Valid { -// // use nt.Time -// } else { -// // NULL value -// } -// -// This NullTime implementation is not driver-specific -type NullTime struct { - Time time.Time - Valid bool // Valid is true if Time is not NULL -} - -// Scan implements the Scanner interface. -// The value type must be time.Time or string / []byte (formatted time-string), -// otherwise Scan fails. -func (nt *NullTime) Scan(value interface{}) (err error) { - if value == nil { - nt.Time, nt.Valid = time.Time{}, false - return - } +var ( + errLog *log.Logger // Error Logger + dsnPattern *regexp.Regexp // Data Source Name Parser + tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs +) - switch v := value.(type) { - case time.Time: - nt.Time, nt.Valid = v, true - return - case []byte: - nt.Time, err = parseDateTime(string(v), time.UTC) - nt.Valid = (err == nil) - return - case string: - nt.Time, err = parseDateTime(v, time.UTC) - nt.Valid = (err == nil) - return - } +func init() { + errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile) - nt.Valid = false - return fmt.Errorf("Can't convert %T to time.Time", value) -} + dsnPattern = regexp.MustCompile( + `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] + `(?:(?P[^\(]*)(?:\((?P[^\)]*)\))?)?` + // [net[(addr)]] + `\/(?P.*?)` + // /dbname + `(?:\?(?P[^\?]*))?$`) // [?param1=value1¶mN=valueN] -// Value implements the driver Valuer interface. -func (nt NullTime) Value() (driver.Value, error) { - if !nt.Valid { - return nil, nil - } - return nt.Time, nil + tlsConfigRegister = make(map[string]*tls.Config) } -var tlsConfigMap map[string]*tls.Config - // Registers a custom tls.Config to be used with sql.Open. // Use the key as a value in the DSN where tls=value. // @@ -103,38 +65,14 @@ var tlsConfigMap map[string]*tls.Config // db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") // func RegisterTLSConfig(key string, config *tls.Config) { - if tlsConfigMap == nil { - tlsConfigMap = make(map[string]*tls.Config) - } - tlsConfigMap[key] = config + tlsConfigRegister[key] = config } // Removes tls.Config associated with key. func DeregisterTLSConfig(key string) { - if tlsConfigMap == nil { - return - } - delete(tlsConfigMap, key) + delete(tlsConfigRegister, key) } -// Logger -var ( - errLog *log.Logger -) - -func init() { - errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile) - - dsnPattern = regexp.MustCompile( - `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] - `(?:(?P[^\(]*)(?:\((?P[^\)]*)\))?)?` + // [net[(addr)]] - `\/(?P.*?)` + // /dbname - `(?:\?(?P[^\?]*))?$`) // [?param1=value1¶mN=valueN] -} - -// Data Source Name Parser -var dsnPattern *regexp.Regexp - func parseDSN(dsn string) (cfg *config, err error) { cfg = new(config) cfg.params = make(map[string]string) @@ -192,8 +130,8 @@ func parseDSN(dsn string) (cfg *config, err error) { cfg.tls = &tls.Config{} } else if strings.ToLower(value) == "skip-verify" { cfg.tls = &tls.Config{InsecureSkipVerify: true} - // TODO: Check for Boolean false - } else if tlsConfig, ok := tlsConfigMap[value]; ok { + // TODO: Check for Boolean false + } else if tlsConfig, ok := tlsConfigRegister[value]; ok { cfg.tls = tlsConfig } @@ -253,6 +191,74 @@ func scramblePassword(scramble, password []byte) []byte { return scramble } +func readBool(value string) bool { + switch strings.ToLower(value) { + case "true": + return true + case "1": + return true + } + return false +} + +/****************************************************************************** +* Time related utils * +******************************************************************************/ + +// NullTime represents a time.Time that may be NULL. +// NullTime implements the Scanner interface so +// it can be used as a scan destination: +// +// var nt NullTime +// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt) +// ... +// if nt.Valid { +// // use nt.Time +// } else { +// // NULL value +// } +// +// This NullTime implementation is not driver-specific +type NullTime struct { + Time time.Time + Valid bool // Valid is true if Time is not NULL +} + +// Scan implements the Scanner interface. +// The value type must be time.Time or string / []byte (formatted time-string), +// otherwise Scan fails. +func (nt *NullTime) Scan(value interface{}) (err error) { + if value == nil { + nt.Time, nt.Valid = time.Time{}, false + return + } + + switch v := value.(type) { + case time.Time: + nt.Time, nt.Valid = v, true + return + case []byte: + nt.Time, err = parseDateTime(string(v), time.UTC) + nt.Valid = (err == nil) + return + case string: + nt.Time, err = parseDateTime(v, time.UTC) + nt.Valid = (err == nil) + return + } + + nt.Valid = false + return fmt.Errorf("Can't convert %T to time.Time", value) +} + +// Value implements the driver Valuer interface. +func (nt NullTime) Value() (driver.Value, error) { + if !nt.Valid { + return nil, nil + } + return nt.Time, nil +} + func parseDateTime(str string, loc *time.Location) (t time.Time, err error) { switch len(str) { case 10: // YYYY-MM-DD @@ -369,16 +375,6 @@ func formatBinaryDateTime(num uint64, data []byte) (driver.Value, error) { return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num) } -func readBool(value string) bool { - switch strings.ToLower(value) { - case "true": - return true - case "1": - return true - } - return false -} - /****************************************************************************** * Convert from and to bytes * ******************************************************************************/ From c24056d3d5902e5bc521001142cc9ad2d80901b2 Mon Sep 17 00:00:00 2001 From: Julien Schmidt Date: Tue, 2 Jul 2013 00:45:30 +0200 Subject: [PATCH 2/2] strict tls.Config key check --- connection.go | 12 ++++++++-- driver_test.go | 2 +- utils.go | 64 +++++++++++++++++++++++++++++++++++--------------- 3 files changed, 56 insertions(+), 22 deletions(-) diff --git a/connection.go b/connection.go index a62db429a..df07b955e 100644 --- a/connection.go +++ b/connection.go @@ -68,11 +68,19 @@ func (mc *mysqlConn) handleParams() (err error) { // time.Time parsing case "parseTime": - mc.parseTime = readBool(val) + var isBool bool + mc.parseTime, isBool = readBool(val) + if !isBool { + return errors.New("Invalid Bool value: " + val) + } // Strict mode case "strict": - mc.strict = readBool(val) + var isBool bool + mc.strict, isBool = readBool(val) + if !isBool { + return errors.New("Invalid Bool value: " + val) + } // Compression case "compress": diff --git a/driver_test.go b/driver_test.go index 246e83ee5..adf7af431 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1053,7 +1053,7 @@ func TestStmtMultiRows(t *testing.T) { } func TestConcurrent(t *testing.T) { - if readBool(os.Getenv("MYSQL_TEST_CONCURRENT")) != true { + if enabled, _ := readBool(os.Getenv("MYSQL_TEST_CONCURRENT")); !enabled { t.Skip("CONCURRENT env var not set") } diff --git a/utils.go b/utils.go index b8bb4b867..6658ef6f3 100644 --- a/utils.go +++ b/utils.go @@ -41,7 +41,7 @@ func init() { tlsConfigRegister = make(map[string]*tls.Config) } -// Registers a custom tls.Config to be used with sql.Open. +// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open. // Use the key as a value in the DSN where tls=value. // // rootCertPool := x509.NewCertPool() @@ -64,11 +64,16 @@ func init() { // }) // db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") // -func RegisterTLSConfig(key string, config *tls.Config) { +func RegisterTLSConfig(key string, config *tls.Config) error { + if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" { + return fmt.Errorf("Key '%s' is reserved", key) + } + tlsConfigRegister[key] = config + return nil } -// Removes tls.Config associated with key. +// DeregisterTLSConfig removes the tls.Config associated with key. func DeregisterTLSConfig(key string) { delete(tlsConfigRegister, key) } @@ -104,11 +109,21 @@ func parseDSN(dsn string) (cfg *config, err error) { // Disable INFILE whitelist / enable all files case "allowAllFiles": - cfg.allowAllFiles = readBool(value) + var isBool bool + cfg.allowAllFiles, isBool = readBool(value) + if !isBool { + err = fmt.Errorf("Invalid Bool value: %s", value) + return + } // Switch "rowsAffected" mode case "clientFoundRows": - cfg.clientFoundRows = readBool(value) + var isBool bool + cfg.clientFoundRows, isBool = readBool(value) + if !isBool { + err = fmt.Errorf("Invalid Bool value: %s", value) + return + } // Time Location case "loc": @@ -126,13 +141,20 @@ func parseDSN(dsn string) (cfg *config, err error) { // TLS-Encryption case "tls": - if readBool(value) { - cfg.tls = &tls.Config{} - } else if strings.ToLower(value) == "skip-verify" { - cfg.tls = &tls.Config{InsecureSkipVerify: true} - // TODO: Check for Boolean false - } else if tlsConfig, ok := tlsConfigRegister[value]; ok { - cfg.tls = tlsConfig + boolValue, isBool := readBool(value) + if isBool { + if boolValue { + cfg.tls = &tls.Config{} + } + } else { + if strings.ToLower(value) == "skip-verify" { + cfg.tls = &tls.Config{InsecureSkipVerify: true} + } else if tlsConfig, ok := tlsConfigRegister[value]; ok { + cfg.tls = tlsConfig + } else { + err = fmt.Errorf("Invalid value / unknown config name: %s", value) + return + } } default: @@ -191,14 +213,18 @@ func scramblePassword(scramble, password []byte) []byte { return scramble } -func readBool(value string) bool { - switch strings.ToLower(value) { - case "true": - return true - case "1": - return true +// Returns the bool value of the input. +// The 2nd return value indicates if the input was a valid bool value +func readBool(input string) (value bool, valid bool) { + switch input { + case "1", "true", "TRUE", "True": + return true, true + case "0", "false", "FALSE", "False": + return false, true } - return false + + // Not a valid bool value + return } /******************************************************************************