-
Notifications
You must be signed in to change notification settings - Fork 2.3k
tls.Config: strict key check + remove data race #106
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,63 +23,25 @@ import ( | |
"time" | ||
) | ||
|
||
// NullTime represents a time.Time that may be NULL. | ||
// NullTime implements the Scanner interface so | ||
// it can be used as a scan destination: | ||
// | ||
// var nt NullTime | ||
// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt) | ||
// ... | ||
// if nt.Valid { | ||
// // use nt.Time | ||
// } else { | ||
// // NULL value | ||
// } | ||
// | ||
// This NullTime implementation is not driver-specific | ||
type NullTime struct { | ||
Time time.Time | ||
Valid bool // Valid is true if Time is not NULL | ||
} | ||
|
||
// Scan implements the Scanner interface. | ||
// The value type must be time.Time or string / []byte (formatted time-string), | ||
// otherwise Scan fails. | ||
func (nt *NullTime) Scan(value interface{}) (err error) { | ||
if value == nil { | ||
nt.Time, nt.Valid = time.Time{}, false | ||
return | ||
} | ||
var ( | ||
errLog *log.Logger // Error Logger | ||
dsnPattern *regexp.Regexp // Data Source Name Parser | ||
tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs | ||
) | ||
|
||
switch v := value.(type) { | ||
case time.Time: | ||
nt.Time, nt.Valid = v, true | ||
return | ||
case []byte: | ||
nt.Time, err = parseDateTime(string(v), time.UTC) | ||
nt.Valid = (err == nil) | ||
return | ||
case string: | ||
nt.Time, err = parseDateTime(v, time.UTC) | ||
nt.Valid = (err == nil) | ||
return | ||
} | ||
func init() { | ||
errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile) | ||
|
||
nt.Valid = false | ||
return fmt.Errorf("Can't convert %T to time.Time", value) | ||
} | ||
dsnPattern = regexp.MustCompile( | ||
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@] | ||
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]] | ||
`\/(?P<dbname>.*?)` + // /dbname | ||
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1¶mN=valueN] | ||
|
||
// Value implements the driver Valuer interface. | ||
func (nt NullTime) Value() (driver.Value, error) { | ||
if !nt.Valid { | ||
return nil, nil | ||
} | ||
return nt.Time, nil | ||
tlsConfigRegister = make(map[string]*tls.Config) | ||
} | ||
|
||
var tlsConfigMap map[string]*tls.Config | ||
|
||
// Registers a custom tls.Config to be used with sql.Open. | ||
// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open. | ||
// Use the key as a value in the DSN where tls=value. | ||
// | ||
// rootCertPool := x509.NewCertPool() | ||
|
@@ -102,39 +64,20 @@ var tlsConfigMap map[string]*tls.Config | |
// }) | ||
// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") | ||
// | ||
func RegisterTLSConfig(key string, config *tls.Config) { | ||
if tlsConfigMap == nil { | ||
tlsConfigMap = make(map[string]*tls.Config) | ||
func RegisterTLSConfig(key string, config *tls.Config) error { | ||
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" { | ||
return fmt.Errorf("Key '%s' is reserved", key) | ||
} | ||
tlsConfigMap[key] = config | ||
} | ||
|
||
// Removes tls.Config associated with key. | ||
func DeregisterTLSConfig(key string) { | ||
if tlsConfigMap == nil { | ||
return | ||
} | ||
delete(tlsConfigMap, key) | ||
tlsConfigRegister[key] = config | ||
return nil | ||
} | ||
|
||
// Logger | ||
var ( | ||
errLog *log.Logger | ||
) | ||
|
||
func init() { | ||
errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile) | ||
|
||
dsnPattern = regexp.MustCompile( | ||
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@] | ||
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]] | ||
`\/(?P<dbname>.*?)` + // /dbname | ||
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1¶mN=valueN] | ||
// DeregisterTLSConfig removes the tls.Config associated with key. | ||
func DeregisterTLSConfig(key string) { | ||
delete(tlsConfigRegister, key) | ||
} | ||
|
||
// Data Source Name Parser | ||
var dsnPattern *regexp.Regexp | ||
|
||
func parseDSN(dsn string) (cfg *config, err error) { | ||
cfg = new(config) | ||
cfg.params = make(map[string]string) | ||
|
@@ -166,11 +109,21 @@ func parseDSN(dsn string) (cfg *config, err error) { | |
|
||
// Disable INFILE whitelist / enable all files | ||
case "allowAllFiles": | ||
cfg.allowAllFiles = readBool(value) | ||
var isBool bool | ||
cfg.allowAllFiles, isBool = readBool(value) | ||
if !isBool { | ||
err = fmt.Errorf("Invalid Bool value: %s", value) | ||
return | ||
} | ||
|
||
// Switch "rowsAffected" mode | ||
case "clientFoundRows": | ||
cfg.clientFoundRows = readBool(value) | ||
var isBool bool | ||
cfg.clientFoundRows, isBool = readBool(value) | ||
if !isBool { | ||
err = fmt.Errorf("Invalid Bool value: %s", value) | ||
return | ||
} | ||
|
||
// Time Location | ||
case "loc": | ||
|
@@ -188,13 +141,20 @@ func parseDSN(dsn string) (cfg *config, err error) { | |
|
||
// TLS-Encryption | ||
case "tls": | ||
if readBool(value) { | ||
cfg.tls = &tls.Config{} | ||
} else if strings.ToLower(value) == "skip-verify" { | ||
cfg.tls = &tls.Config{InsecureSkipVerify: true} | ||
// TODO: Check for Boolean false | ||
} else if tlsConfig, ok := tlsConfigMap[value]; ok { | ||
cfg.tls = tlsConfig | ||
boolValue, isBool := readBool(value) | ||
if isBool { | ||
if boolValue { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't really like the variable name, but that's a minor issue 😀 |
||
cfg.tls = &tls.Config{} | ||
} | ||
} else { | ||
if strings.ToLower(value) == "skip-verify" { | ||
cfg.tls = &tls.Config{InsecureSkipVerify: true} | ||
} else if tlsConfig, ok := tlsConfigRegister[value]; ok { | ||
cfg.tls = tlsConfig | ||
} else { | ||
err = fmt.Errorf("Invalid value / unknown config name: %s", value) | ||
return | ||
} | ||
} | ||
|
||
default: | ||
|
@@ -253,6 +213,78 @@ func scramblePassword(scramble, password []byte) []byte { | |
return scramble | ||
} | ||
|
||
// Returns the bool value of the input. | ||
// The 2nd return value indicates if the input was a valid bool value | ||
func readBool(input string) (value bool, valid bool) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no, then we'd have to compare with the err in some cases which I'd like to avoid. |
||
switch input { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think that this is necessary. I wouldn't consider "trUE" etc. as a valid bool formatting, neither does Go's stdlib. Also the current way is more efficient. |
||
case "1", "true", "TRUE", "True": | ||
return true, true | ||
case "0", "false", "FALSE", "False": | ||
return false, true | ||
} | ||
|
||
// Not a valid bool value | ||
return | ||
} | ||
|
||
/****************************************************************************** | ||
* Time related utils * | ||
******************************************************************************/ | ||
|
||
// NullTime represents a time.Time that may be NULL. | ||
// NullTime implements the Scanner interface so | ||
// it can be used as a scan destination: | ||
// | ||
// var nt NullTime | ||
// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt) | ||
// ... | ||
// if nt.Valid { | ||
// // use nt.Time | ||
// } else { | ||
// // NULL value | ||
// } | ||
// | ||
// This NullTime implementation is not driver-specific | ||
type NullTime struct { | ||
Time time.Time | ||
Valid bool // Valid is true if Time is not NULL | ||
} | ||
|
||
// Scan implements the Scanner interface. | ||
// The value type must be time.Time or string / []byte (formatted time-string), | ||
// otherwise Scan fails. | ||
func (nt *NullTime) Scan(value interface{}) (err error) { | ||
if value == nil { | ||
nt.Time, nt.Valid = time.Time{}, false | ||
return | ||
} | ||
|
||
switch v := value.(type) { | ||
case time.Time: | ||
nt.Time, nt.Valid = v, true | ||
return | ||
case []byte: | ||
nt.Time, err = parseDateTime(string(v), time.UTC) | ||
nt.Valid = (err == nil) | ||
return | ||
case string: | ||
nt.Time, err = parseDateTime(v, time.UTC) | ||
nt.Valid = (err == nil) | ||
return | ||
} | ||
|
||
nt.Valid = false | ||
return fmt.Errorf("Can't convert %T to time.Time", value) | ||
} | ||
|
||
// Value implements the driver Valuer interface. | ||
func (nt NullTime) Value() (driver.Value, error) { | ||
if !nt.Valid { | ||
return nil, nil | ||
} | ||
return nt.Time, nil | ||
} | ||
|
||
func parseDateTime(str string, loc *time.Location) (t time.Time, err error) { | ||
switch len(str) { | ||
case 10: // YYYY-MM-DD | ||
|
@@ -369,16 +401,6 @@ func formatBinaryDateTime(num uint64, data []byte) (driver.Value, error) { | |
return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num) | ||
} | ||
|
||
func readBool(value string) bool { | ||
switch strings.ToLower(value) { | ||
case "true": | ||
return true | ||
case "1": | ||
return true | ||
} | ||
return false | ||
} | ||
|
||
/****************************************************************************** | ||
* Convert from and to bytes * | ||
******************************************************************************/ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd prefer a switch case for all the stuff below... this if chaining is painful to read and each alternative terminates the chain.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I personally find the if-chaining much more readable. The case-switch would need a lot of redundant checks, e.g.
!isBool && ...
for the complete else branch below