Skip to content

Commit abd1799

Browse files
committed
Remove possible data race in tls.Config map init
also resorted code
1 parent a9ece60 commit abd1799

File tree

1 file changed

+85
-89
lines changed

1 file changed

+85
-89
lines changed

utils.go

Lines changed: 85 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -23,62 +23,24 @@ import (
2323
"time"
2424
)
2525

26-
// NullTime represents a time.Time that may be NULL.
27-
// NullTime implements the Scanner interface so
28-
// it can be used as a scan destination:
29-
//
30-
// var nt NullTime
31-
// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
32-
// ...
33-
// if nt.Valid {
34-
// // use nt.Time
35-
// } else {
36-
// // NULL value
37-
// }
38-
//
39-
// This NullTime implementation is not driver-specific
40-
type NullTime struct {
41-
Time time.Time
42-
Valid bool // Valid is true if Time is not NULL
43-
}
44-
45-
// Scan implements the Scanner interface.
46-
// The value type must be time.Time or string / []byte (formatted time-string),
47-
// otherwise Scan fails.
48-
func (nt *NullTime) Scan(value interface{}) (err error) {
49-
if value == nil {
50-
nt.Time, nt.Valid = time.Time{}, false
51-
return
52-
}
26+
var (
27+
errLog *log.Logger // Error Logger
28+
dsnPattern *regexp.Regexp // Data Source Name Parser
29+
tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs
30+
)
5331

54-
switch v := value.(type) {
55-
case time.Time:
56-
nt.Time, nt.Valid = v, true
57-
return
58-
case []byte:
59-
nt.Time, err = parseDateTime(string(v), time.UTC)
60-
nt.Valid = (err == nil)
61-
return
62-
case string:
63-
nt.Time, err = parseDateTime(v, time.UTC)
64-
nt.Valid = (err == nil)
65-
return
66-
}
32+
func init() {
33+
errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)
6734

68-
nt.Valid = false
69-
return fmt.Errorf("Can't convert %T to time.Time", value)
70-
}
35+
dsnPattern = regexp.MustCompile(
36+
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
37+
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
38+
`\/(?P<dbname>.*?)` + // /dbname
39+
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
7140

72-
// Value implements the driver Valuer interface.
73-
func (nt NullTime) Value() (driver.Value, error) {
74-
if !nt.Valid {
75-
return nil, nil
76-
}
77-
return nt.Time, nil
41+
tlsConfigRegister = make(map[string]*tls.Config)
7842
}
7943

80-
var tlsConfigMap map[string]*tls.Config
81-
8244
// Registers a custom tls.Config to be used with sql.Open.
8345
// Use the key as a value in the DSN where tls=value.
8446
//
@@ -103,38 +65,14 @@ var tlsConfigMap map[string]*tls.Config
10365
// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
10466
//
10567
func RegisterTLSConfig(key string, config *tls.Config) {
106-
if tlsConfigMap == nil {
107-
tlsConfigMap = make(map[string]*tls.Config)
108-
}
109-
tlsConfigMap[key] = config
68+
tlsConfigRegister[key] = config
11069
}
11170

11271
// Removes tls.Config associated with key.
11372
func DeregisterTLSConfig(key string) {
114-
if tlsConfigMap == nil {
115-
return
116-
}
117-
delete(tlsConfigMap, key)
73+
delete(tlsConfigRegister, key)
11874
}
11975

120-
// Logger
121-
var (
122-
errLog *log.Logger
123-
)
124-
125-
func init() {
126-
errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)
127-
128-
dsnPattern = regexp.MustCompile(
129-
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
130-
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
131-
`\/(?P<dbname>.*?)` + // /dbname
132-
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
133-
}
134-
135-
// Data Source Name Parser
136-
var dsnPattern *regexp.Regexp
137-
13876
func parseDSN(dsn string) (cfg *config, err error) {
13977
cfg = new(config)
14078
cfg.params = make(map[string]string)
@@ -192,8 +130,8 @@ func parseDSN(dsn string) (cfg *config, err error) {
192130
cfg.tls = &tls.Config{}
193131
} else if strings.ToLower(value) == "skip-verify" {
194132
cfg.tls = &tls.Config{InsecureSkipVerify: true}
195-
// TODO: Check for Boolean false
196-
} else if tlsConfig, ok := tlsConfigMap[value]; ok {
133+
// TODO: Check for Boolean false
134+
} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
197135
cfg.tls = tlsConfig
198136
}
199137

@@ -253,6 +191,74 @@ func scramblePassword(scramble, password []byte) []byte {
253191
return scramble
254192
}
255193

194+
func readBool(value string) bool {
195+
switch strings.ToLower(value) {
196+
case "true":
197+
return true
198+
case "1":
199+
return true
200+
}
201+
return false
202+
}
203+
204+
/******************************************************************************
205+
* Time related utils *
206+
******************************************************************************/
207+
208+
// NullTime represents a time.Time that may be NULL.
209+
// NullTime implements the Scanner interface so
210+
// it can be used as a scan destination:
211+
//
212+
// var nt NullTime
213+
// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
214+
// ...
215+
// if nt.Valid {
216+
// // use nt.Time
217+
// } else {
218+
// // NULL value
219+
// }
220+
//
221+
// This NullTime implementation is not driver-specific
222+
type NullTime struct {
223+
Time time.Time
224+
Valid bool // Valid is true if Time is not NULL
225+
}
226+
227+
// Scan implements the Scanner interface.
228+
// The value type must be time.Time or string / []byte (formatted time-string),
229+
// otherwise Scan fails.
230+
func (nt *NullTime) Scan(value interface{}) (err error) {
231+
if value == nil {
232+
nt.Time, nt.Valid = time.Time{}, false
233+
return
234+
}
235+
236+
switch v := value.(type) {
237+
case time.Time:
238+
nt.Time, nt.Valid = v, true
239+
return
240+
case []byte:
241+
nt.Time, err = parseDateTime(string(v), time.UTC)
242+
nt.Valid = (err == nil)
243+
return
244+
case string:
245+
nt.Time, err = parseDateTime(v, time.UTC)
246+
nt.Valid = (err == nil)
247+
return
248+
}
249+
250+
nt.Valid = false
251+
return fmt.Errorf("Can't convert %T to time.Time", value)
252+
}
253+
254+
// Value implements the driver Valuer interface.
255+
func (nt NullTime) Value() (driver.Value, error) {
256+
if !nt.Valid {
257+
return nil, nil
258+
}
259+
return nt.Time, nil
260+
}
261+
256262
func parseDateTime(str string, loc *time.Location) (t time.Time, err error) {
257263
switch len(str) {
258264
case 10: // YYYY-MM-DD
@@ -369,16 +375,6 @@ func formatBinaryDateTime(num uint64, data []byte) (driver.Value, error) {
369375
return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num)
370376
}
371377

372-
func readBool(value string) bool {
373-
switch strings.ToLower(value) {
374-
case "true":
375-
return true
376-
case "1":
377-
return true
378-
}
379-
return false
380-
}
381-
382378
/******************************************************************************
383379
* Convert from and to bytes *
384380
******************************************************************************/

0 commit comments

Comments
 (0)