Skip to content

tls.Config: strict key check + remove data race #106

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 1, 2013
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,19 @@ func (mc *mysqlConn) handleParams() (err error) {

// time.Time parsing
case "parseTime":
mc.parseTime = readBool(val)
var isBool bool
mc.parseTime, isBool = readBool(val)
if !isBool {
return errors.New("Invalid Bool value: " + val)
}

// Strict mode
case "strict":
mc.strict = readBool(val)
var isBool bool
mc.strict, isBool = readBool(val)
if !isBool {
return errors.New("Invalid Bool value: " + val)
}

// Compression
case "compress":
Expand Down
2 changes: 1 addition & 1 deletion driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1053,7 +1053,7 @@ func TestStmtMultiRows(t *testing.T) {
}

func TestConcurrent(t *testing.T) {
if readBool(os.Getenv("MYSQL_TEST_CONCURRENT")) != true {
if enabled, _ := readBool(os.Getenv("MYSQL_TEST_CONCURRENT")); !enabled {
t.Skip("CONCURRENT env var not set")
}

Expand Down
218 changes: 120 additions & 98 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,63 +23,25 @@ import (
"time"
)

// NullTime represents a time.Time that may be NULL.
// NullTime implements the Scanner interface so
// it can be used as a scan destination:
//
// var nt NullTime
// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
// ...
// if nt.Valid {
// // use nt.Time
// } else {
// // NULL value
// }
//
// This NullTime implementation is not driver-specific
type NullTime struct {
Time time.Time
Valid bool // Valid is true if Time is not NULL
}

// Scan implements the Scanner interface.
// The value type must be time.Time or string / []byte (formatted time-string),
// otherwise Scan fails.
func (nt *NullTime) Scan(value interface{}) (err error) {
if value == nil {
nt.Time, nt.Valid = time.Time{}, false
return
}
var (
errLog *log.Logger // Error Logger
dsnPattern *regexp.Regexp // Data Source Name Parser
tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs
)

switch v := value.(type) {
case time.Time:
nt.Time, nt.Valid = v, true
return
case []byte:
nt.Time, err = parseDateTime(string(v), time.UTC)
nt.Valid = (err == nil)
return
case string:
nt.Time, err = parseDateTime(v, time.UTC)
nt.Valid = (err == nil)
return
}
func init() {
errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)

nt.Valid = false
return fmt.Errorf("Can't convert %T to time.Time", value)
}
dsnPattern = regexp.MustCompile(
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
`\/(?P<dbname>.*?)` + // /dbname
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]

// Value implements the driver Valuer interface.
func (nt NullTime) Value() (driver.Value, error) {
if !nt.Valid {
return nil, nil
}
return nt.Time, nil
tlsConfigRegister = make(map[string]*tls.Config)
}

var tlsConfigMap map[string]*tls.Config

// Registers a custom tls.Config to be used with sql.Open.
// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
// Use the key as a value in the DSN where tls=value.
//
// rootCertPool := x509.NewCertPool()
Expand All @@ -102,39 +64,20 @@ var tlsConfigMap map[string]*tls.Config
// })
// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
//
func RegisterTLSConfig(key string, config *tls.Config) {
if tlsConfigMap == nil {
tlsConfigMap = make(map[string]*tls.Config)
func RegisterTLSConfig(key string, config *tls.Config) error {
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" {
return fmt.Errorf("Key '%s' is reserved", key)
}
tlsConfigMap[key] = config
}

// Removes tls.Config associated with key.
func DeregisterTLSConfig(key string) {
if tlsConfigMap == nil {
return
}
delete(tlsConfigMap, key)
tlsConfigRegister[key] = config
return nil
}

// Logger
var (
errLog *log.Logger
)

func init() {
errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)

dsnPattern = regexp.MustCompile(
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
`\/(?P<dbname>.*?)` + // /dbname
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
// DeregisterTLSConfig removes the tls.Config associated with key.
func DeregisterTLSConfig(key string) {
delete(tlsConfigRegister, key)
}

// Data Source Name Parser
var dsnPattern *regexp.Regexp

func parseDSN(dsn string) (cfg *config, err error) {
cfg = new(config)
cfg.params = make(map[string]string)
Expand Down Expand Up @@ -166,11 +109,21 @@ func parseDSN(dsn string) (cfg *config, err error) {

// Disable INFILE whitelist / enable all files
case "allowAllFiles":
cfg.allowAllFiles = readBool(value)
var isBool bool
cfg.allowAllFiles, isBool = readBool(value)
if !isBool {
err = fmt.Errorf("Invalid Bool value: %s", value)
return
}

// Switch "rowsAffected" mode
case "clientFoundRows":
cfg.clientFoundRows = readBool(value)
var isBool bool
cfg.clientFoundRows, isBool = readBool(value)
if !isBool {
err = fmt.Errorf("Invalid Bool value: %s", value)
return
}

// Time Location
case "loc":
Expand All @@ -188,13 +141,20 @@ func parseDSN(dsn string) (cfg *config, err error) {

// TLS-Encryption
case "tls":
if readBool(value) {
cfg.tls = &tls.Config{}
} else if strings.ToLower(value) == "skip-verify" {
cfg.tls = &tls.Config{InsecureSkipVerify: true}
// TODO: Check for Boolean false
} else if tlsConfig, ok := tlsConfigMap[value]; ok {
cfg.tls = tlsConfig
boolValue, isBool := readBool(value)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer a switch case for all the stuff below... this if chaining is painful to read and each alternative terminates the chain.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I personally find the if-chaining much more readable. The case-switch would need a lot of redundant checks, e.g. !isBool && ... for the complete else branch below

if isBool {
if boolValue {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really like the variable name, but that's a minor issue 😀

cfg.tls = &tls.Config{}
}
} else {
if strings.ToLower(value) == "skip-verify" {
cfg.tls = &tls.Config{InsecureSkipVerify: true}
} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
cfg.tls = tlsConfig
} else {
err = fmt.Errorf("Invalid value / unknown config name: %s", value)
return
}
}

default:
Expand Down Expand Up @@ -253,6 +213,78 @@ func scramblePassword(scramble, password []byte) []byte {
return scramble
}

// Returns the bool value of the input.
// The 2nd return value indicates if the input was a valid bool value
func readBool(input string) (value bool, valid bool) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bool, error instead of bool, bool as return value to get rid of some duplication?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, then we'd have to compare with the err in some cases which I'd like to avoid.

switch input {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

switch strings.ToLower(input)...?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that this is necessary. I wouldn't consider "trUE" etc. as a valid bool formatting, neither does Go's stdlib. Also the current way is more efficient.

case "1", "true", "TRUE", "True":
return true, true
case "0", "false", "FALSE", "False":
return false, true
}

// Not a valid bool value
return
}

/******************************************************************************
* Time related utils *
******************************************************************************/

// NullTime represents a time.Time that may be NULL.
// NullTime implements the Scanner interface so
// it can be used as a scan destination:
//
// var nt NullTime
// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
// ...
// if nt.Valid {
// // use nt.Time
// } else {
// // NULL value
// }
//
// This NullTime implementation is not driver-specific
type NullTime struct {
Time time.Time
Valid bool // Valid is true if Time is not NULL
}

// Scan implements the Scanner interface.
// The value type must be time.Time or string / []byte (formatted time-string),
// otherwise Scan fails.
func (nt *NullTime) Scan(value interface{}) (err error) {
if value == nil {
nt.Time, nt.Valid = time.Time{}, false
return
}

switch v := value.(type) {
case time.Time:
nt.Time, nt.Valid = v, true
return
case []byte:
nt.Time, err = parseDateTime(string(v), time.UTC)
nt.Valid = (err == nil)
return
case string:
nt.Time, err = parseDateTime(v, time.UTC)
nt.Valid = (err == nil)
return
}

nt.Valid = false
return fmt.Errorf("Can't convert %T to time.Time", value)
}

// Value implements the driver Valuer interface.
func (nt NullTime) Value() (driver.Value, error) {
if !nt.Valid {
return nil, nil
}
return nt.Time, nil
}

func parseDateTime(str string, loc *time.Location) (t time.Time, err error) {
switch len(str) {
case 10: // YYYY-MM-DD
Expand Down Expand Up @@ -369,16 +401,6 @@ func formatBinaryDateTime(num uint64, data []byte) (driver.Value, error) {
return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num)
}

func readBool(value string) bool {
switch strings.ToLower(value) {
case "true":
return true
case "1":
return true
}
return false
}

/******************************************************************************
* Convert from and to bytes *
******************************************************************************/
Expand Down