Skip to content

Commit 2f1342a

Browse files
committed
Merge pull request #106 from go-sql-driver/tls
tls.Config: strict key check + remove data race
2 parents 9bf94d4 + c24056d commit 2f1342a

File tree

3 files changed

+131
-101
lines changed

3 files changed

+131
-101
lines changed

connection.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,19 @@ func (mc *mysqlConn) handleParams() (err error) {
6868

6969
// time.Time parsing
7070
case "parseTime":
71-
mc.parseTime = readBool(val)
71+
var isBool bool
72+
mc.parseTime, isBool = readBool(val)
73+
if !isBool {
74+
return errors.New("Invalid Bool value: " + val)
75+
}
7276

7377
// Strict mode
7478
case "strict":
75-
mc.strict = readBool(val)
79+
var isBool bool
80+
mc.strict, isBool = readBool(val)
81+
if !isBool {
82+
return errors.New("Invalid Bool value: " + val)
83+
}
7684

7785
// Compression
7886
case "compress":

driver_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1053,7 +1053,7 @@ func TestStmtMultiRows(t *testing.T) {
10531053
}
10541054

10551055
func TestConcurrent(t *testing.T) {
1056-
if readBool(os.Getenv("MYSQL_TEST_CONCURRENT")) != true {
1056+
if enabled, _ := readBool(os.Getenv("MYSQL_TEST_CONCURRENT")); !enabled {
10571057
t.Skip("CONCURRENT env var not set")
10581058
}
10591059

utils.go

Lines changed: 120 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -23,63 +23,25 @@ import (
2323
"time"
2424
)
2525

26-
// NullTime represents a time.Time that may be NULL.
27-
// NullTime implements the Scanner interface so
28-
// it can be used as a scan destination:
29-
//
30-
// var nt NullTime
31-
// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
32-
// ...
33-
// if nt.Valid {
34-
// // use nt.Time
35-
// } else {
36-
// // NULL value
37-
// }
38-
//
39-
// This NullTime implementation is not driver-specific
40-
type NullTime struct {
41-
Time time.Time
42-
Valid bool // Valid is true if Time is not NULL
43-
}
44-
45-
// Scan implements the Scanner interface.
46-
// The value type must be time.Time or string / []byte (formatted time-string),
47-
// otherwise Scan fails.
48-
func (nt *NullTime) Scan(value interface{}) (err error) {
49-
if value == nil {
50-
nt.Time, nt.Valid = time.Time{}, false
51-
return
52-
}
26+
var (
27+
errLog *log.Logger // Error Logger
28+
dsnPattern *regexp.Regexp // Data Source Name Parser
29+
tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs
30+
)
5331

54-
switch v := value.(type) {
55-
case time.Time:
56-
nt.Time, nt.Valid = v, true
57-
return
58-
case []byte:
59-
nt.Time, err = parseDateTime(string(v), time.UTC)
60-
nt.Valid = (err == nil)
61-
return
62-
case string:
63-
nt.Time, err = parseDateTime(v, time.UTC)
64-
nt.Valid = (err == nil)
65-
return
66-
}
32+
func init() {
33+
errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)
6734

68-
nt.Valid = false
69-
return fmt.Errorf("Can't convert %T to time.Time", value)
70-
}
35+
dsnPattern = regexp.MustCompile(
36+
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
37+
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
38+
`\/(?P<dbname>.*?)` + // /dbname
39+
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
7140

72-
// Value implements the driver Valuer interface.
73-
func (nt NullTime) Value() (driver.Value, error) {
74-
if !nt.Valid {
75-
return nil, nil
76-
}
77-
return nt.Time, nil
41+
tlsConfigRegister = make(map[string]*tls.Config)
7842
}
7943

80-
var tlsConfigMap map[string]*tls.Config
81-
82-
// Registers a custom tls.Config to be used with sql.Open.
44+
// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
8345
// Use the key as a value in the DSN where tls=value.
8446
//
8547
// rootCertPool := x509.NewCertPool()
@@ -102,39 +64,20 @@ var tlsConfigMap map[string]*tls.Config
10264
// })
10365
// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
10466
//
105-
func RegisterTLSConfig(key string, config *tls.Config) {
106-
if tlsConfigMap == nil {
107-
tlsConfigMap = make(map[string]*tls.Config)
67+
func RegisterTLSConfig(key string, config *tls.Config) error {
68+
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" {
69+
return fmt.Errorf("Key '%s' is reserved", key)
10870
}
109-
tlsConfigMap[key] = config
110-
}
11171

112-
// Removes tls.Config associated with key.
113-
func DeregisterTLSConfig(key string) {
114-
if tlsConfigMap == nil {
115-
return
116-
}
117-
delete(tlsConfigMap, key)
72+
tlsConfigRegister[key] = config
73+
return nil
11874
}
11975

120-
// Logger
121-
var (
122-
errLog *log.Logger
123-
)
124-
125-
func init() {
126-
errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)
127-
128-
dsnPattern = regexp.MustCompile(
129-
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
130-
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
131-
`\/(?P<dbname>.*?)` + // /dbname
132-
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
76+
// DeregisterTLSConfig removes the tls.Config associated with key.
77+
func DeregisterTLSConfig(key string) {
78+
delete(tlsConfigRegister, key)
13379
}
13480

135-
// Data Source Name Parser
136-
var dsnPattern *regexp.Regexp
137-
13881
func parseDSN(dsn string) (cfg *config, err error) {
13982
cfg = new(config)
14083
cfg.params = make(map[string]string)
@@ -166,11 +109,21 @@ func parseDSN(dsn string) (cfg *config, err error) {
166109

167110
// Disable INFILE whitelist / enable all files
168111
case "allowAllFiles":
169-
cfg.allowAllFiles = readBool(value)
112+
var isBool bool
113+
cfg.allowAllFiles, isBool = readBool(value)
114+
if !isBool {
115+
err = fmt.Errorf("Invalid Bool value: %s", value)
116+
return
117+
}
170118

171119
// Switch "rowsAffected" mode
172120
case "clientFoundRows":
173-
cfg.clientFoundRows = readBool(value)
121+
var isBool bool
122+
cfg.clientFoundRows, isBool = readBool(value)
123+
if !isBool {
124+
err = fmt.Errorf("Invalid Bool value: %s", value)
125+
return
126+
}
174127

175128
// Time Location
176129
case "loc":
@@ -188,13 +141,20 @@ func parseDSN(dsn string) (cfg *config, err error) {
188141

189142
// TLS-Encryption
190143
case "tls":
191-
if readBool(value) {
192-
cfg.tls = &tls.Config{}
193-
} else if strings.ToLower(value) == "skip-verify" {
194-
cfg.tls = &tls.Config{InsecureSkipVerify: true}
195-
// TODO: Check for Boolean false
196-
} else if tlsConfig, ok := tlsConfigMap[value]; ok {
197-
cfg.tls = tlsConfig
144+
boolValue, isBool := readBool(value)
145+
if isBool {
146+
if boolValue {
147+
cfg.tls = &tls.Config{}
148+
}
149+
} else {
150+
if strings.ToLower(value) == "skip-verify" {
151+
cfg.tls = &tls.Config{InsecureSkipVerify: true}
152+
} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
153+
cfg.tls = tlsConfig
154+
} else {
155+
err = fmt.Errorf("Invalid value / unknown config name: %s", value)
156+
return
157+
}
198158
}
199159

200160
default:
@@ -253,6 +213,78 @@ func scramblePassword(scramble, password []byte) []byte {
253213
return scramble
254214
}
255215

216+
// Returns the bool value of the input.
217+
// The 2nd return value indicates if the input was a valid bool value
218+
func readBool(input string) (value bool, valid bool) {
219+
switch input {
220+
case "1", "true", "TRUE", "True":
221+
return true, true
222+
case "0", "false", "FALSE", "False":
223+
return false, true
224+
}
225+
226+
// Not a valid bool value
227+
return
228+
}
229+
230+
/******************************************************************************
231+
* Time related utils *
232+
******************************************************************************/
233+
234+
// NullTime represents a time.Time that may be NULL.
235+
// NullTime implements the Scanner interface so
236+
// it can be used as a scan destination:
237+
//
238+
// var nt NullTime
239+
// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
240+
// ...
241+
// if nt.Valid {
242+
// // use nt.Time
243+
// } else {
244+
// // NULL value
245+
// }
246+
//
247+
// This NullTime implementation is not driver-specific
248+
type NullTime struct {
249+
Time time.Time
250+
Valid bool // Valid is true if Time is not NULL
251+
}
252+
253+
// Scan implements the Scanner interface.
254+
// The value type must be time.Time or string / []byte (formatted time-string),
255+
// otherwise Scan fails.
256+
func (nt *NullTime) Scan(value interface{}) (err error) {
257+
if value == nil {
258+
nt.Time, nt.Valid = time.Time{}, false
259+
return
260+
}
261+
262+
switch v := value.(type) {
263+
case time.Time:
264+
nt.Time, nt.Valid = v, true
265+
return
266+
case []byte:
267+
nt.Time, err = parseDateTime(string(v), time.UTC)
268+
nt.Valid = (err == nil)
269+
return
270+
case string:
271+
nt.Time, err = parseDateTime(v, time.UTC)
272+
nt.Valid = (err == nil)
273+
return
274+
}
275+
276+
nt.Valid = false
277+
return fmt.Errorf("Can't convert %T to time.Time", value)
278+
}
279+
280+
// Value implements the driver Valuer interface.
281+
func (nt NullTime) Value() (driver.Value, error) {
282+
if !nt.Valid {
283+
return nil, nil
284+
}
285+
return nt.Time, nil
286+
}
287+
256288
func parseDateTime(str string, loc *time.Location) (t time.Time, err error) {
257289
switch len(str) {
258290
case 10: // YYYY-MM-DD
@@ -369,16 +401,6 @@ func formatBinaryDateTime(num uint64, data []byte) (driver.Value, error) {
369401
return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num)
370402
}
371403

372-
func readBool(value string) bool {
373-
switch strings.ToLower(value) {
374-
case "true":
375-
return true
376-
case "1":
377-
return true
378-
}
379-
return false
380-
}
381-
382404
/******************************************************************************
383405
* Convert from and to bytes *
384406
******************************************************************************/

0 commit comments

Comments
 (0)