@@ -23,63 +23,25 @@ import (
23
23
"time"
24
24
)
25
25
26
- // NullTime represents a time.Time that may be NULL.
27
- // NullTime implements the Scanner interface so
28
- // it can be used as a scan destination:
29
- //
30
- // var nt NullTime
31
- // err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
32
- // ...
33
- // if nt.Valid {
34
- // // use nt.Time
35
- // } else {
36
- // // NULL value
37
- // }
38
- //
39
- // This NullTime implementation is not driver-specific
40
- type NullTime struct {
41
- Time time.Time
42
- Valid bool // Valid is true if Time is not NULL
43
- }
44
-
45
- // Scan implements the Scanner interface.
46
- // The value type must be time.Time or string / []byte (formatted time-string),
47
- // otherwise Scan fails.
48
- func (nt * NullTime ) Scan (value interface {}) (err error ) {
49
- if value == nil {
50
- nt .Time , nt .Valid = time.Time {}, false
51
- return
52
- }
26
+ var (
27
+ errLog * log.Logger // Error Logger
28
+ dsnPattern * regexp.Regexp // Data Source Name Parser
29
+ tlsConfigRegister map [string ]* tls.Config // Register for custom tls.Configs
30
+ )
53
31
54
- switch v := value .(type ) {
55
- case time.Time :
56
- nt .Time , nt .Valid = v , true
57
- return
58
- case []byte :
59
- nt .Time , err = parseDateTime (string (v ), time .UTC )
60
- nt .Valid = (err == nil )
61
- return
62
- case string :
63
- nt .Time , err = parseDateTime (v , time .UTC )
64
- nt .Valid = (err == nil )
65
- return
66
- }
32
+ func init () {
33
+ errLog = log .New (os .Stderr , "[MySQL] " , log .Ldate | log .Ltime | log .Lshortfile )
67
34
68
- nt .Valid = false
69
- return fmt .Errorf ("Can't convert %T to time.Time" , value )
70
- }
35
+ dsnPattern = regexp .MustCompile (
36
+ `^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
37
+ `(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
38
+ `\/(?P<dbname>.*?)` + // /dbname
39
+ `(?:\?(?P<params>[^\?]*))?$` ) // [?param1=value1¶mN=valueN]
71
40
72
- // Value implements the driver Valuer interface.
73
- func (nt NullTime ) Value () (driver.Value , error ) {
74
- if ! nt .Valid {
75
- return nil , nil
76
- }
77
- return nt .Time , nil
41
+ tlsConfigRegister = make (map [string ]* tls.Config )
78
42
}
79
43
80
- var tlsConfigMap map [string ]* tls.Config
81
-
82
- // Registers a custom tls.Config to be used with sql.Open.
44
+ // RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
83
45
// Use the key as a value in the DSN where tls=value.
84
46
//
85
47
// rootCertPool := x509.NewCertPool()
@@ -102,39 +64,20 @@ var tlsConfigMap map[string]*tls.Config
102
64
// })
103
65
// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
104
66
//
105
- func RegisterTLSConfig (key string , config * tls.Config ) {
106
- if tlsConfigMap == nil {
107
- tlsConfigMap = make ( map [ string ] * tls. Config )
67
+ func RegisterTLSConfig (key string , config * tls.Config ) error {
68
+ if _ , isBool := readBool ( key ); isBool || strings . ToLower ( key ) == "skip-verify" {
69
+ return fmt . Errorf ( "Key '%s' is reserved" , key )
108
70
}
109
- tlsConfigMap [key ] = config
110
- }
111
71
112
- // Removes tls.Config associated with key.
113
- func DeregisterTLSConfig (key string ) {
114
- if tlsConfigMap == nil {
115
- return
116
- }
117
- delete (tlsConfigMap , key )
72
+ tlsConfigRegister [key ] = config
73
+ return nil
118
74
}
119
75
120
- // Logger
121
- var (
122
- errLog * log.Logger
123
- )
124
-
125
- func init () {
126
- errLog = log .New (os .Stderr , "[MySQL] " , log .Ldate | log .Ltime | log .Lshortfile )
127
-
128
- dsnPattern = regexp .MustCompile (
129
- `^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
130
- `(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
131
- `\/(?P<dbname>.*?)` + // /dbname
132
- `(?:\?(?P<params>[^\?]*))?$` ) // [?param1=value1¶mN=valueN]
76
+ // DeregisterTLSConfig removes the tls.Config associated with key.
77
+ func DeregisterTLSConfig (key string ) {
78
+ delete (tlsConfigRegister , key )
133
79
}
134
80
135
- // Data Source Name Parser
136
- var dsnPattern * regexp.Regexp
137
-
138
81
func parseDSN (dsn string ) (cfg * config , err error ) {
139
82
cfg = new (config )
140
83
cfg .params = make (map [string ]string )
@@ -166,11 +109,21 @@ func parseDSN(dsn string) (cfg *config, err error) {
166
109
167
110
// Disable INFILE whitelist / enable all files
168
111
case "allowAllFiles" :
169
- cfg .allowAllFiles = readBool (value )
112
+ var isBool bool
113
+ cfg .allowAllFiles , isBool = readBool (value )
114
+ if ! isBool {
115
+ err = fmt .Errorf ("Invalid Bool value: %s" , value )
116
+ return
117
+ }
170
118
171
119
// Switch "rowsAffected" mode
172
120
case "clientFoundRows" :
173
- cfg .clientFoundRows = readBool (value )
121
+ var isBool bool
122
+ cfg .clientFoundRows , isBool = readBool (value )
123
+ if ! isBool {
124
+ err = fmt .Errorf ("Invalid Bool value: %s" , value )
125
+ return
126
+ }
174
127
175
128
// Time Location
176
129
case "loc" :
@@ -188,13 +141,20 @@ func parseDSN(dsn string) (cfg *config, err error) {
188
141
189
142
// TLS-Encryption
190
143
case "tls" :
191
- if readBool (value ) {
192
- cfg .tls = & tls.Config {}
193
- } else if strings .ToLower (value ) == "skip-verify" {
194
- cfg .tls = & tls.Config {InsecureSkipVerify : true }
195
- // TODO: Check for Boolean false
196
- } else if tlsConfig , ok := tlsConfigMap [value ]; ok {
197
- cfg .tls = tlsConfig
144
+ boolValue , isBool := readBool (value )
145
+ if isBool {
146
+ if boolValue {
147
+ cfg .tls = & tls.Config {}
148
+ }
149
+ } else {
150
+ if strings .ToLower (value ) == "skip-verify" {
151
+ cfg .tls = & tls.Config {InsecureSkipVerify : true }
152
+ } else if tlsConfig , ok := tlsConfigRegister [value ]; ok {
153
+ cfg .tls = tlsConfig
154
+ } else {
155
+ err = fmt .Errorf ("Invalid value / unknown config name: %s" , value )
156
+ return
157
+ }
198
158
}
199
159
200
160
default :
@@ -253,6 +213,78 @@ func scramblePassword(scramble, password []byte) []byte {
253
213
return scramble
254
214
}
255
215
216
+ // Returns the bool value of the input.
217
+ // The 2nd return value indicates if the input was a valid bool value
218
+ func readBool (input string ) (value bool , valid bool ) {
219
+ switch input {
220
+ case "1" , "true" , "TRUE" , "True" :
221
+ return true , true
222
+ case "0" , "false" , "FALSE" , "False" :
223
+ return false , true
224
+ }
225
+
226
+ // Not a valid bool value
227
+ return
228
+ }
229
+
230
+ /******************************************************************************
231
+ * Time related utils *
232
+ ******************************************************************************/
233
+
234
+ // NullTime represents a time.Time that may be NULL.
235
+ // NullTime implements the Scanner interface so
236
+ // it can be used as a scan destination:
237
+ //
238
+ // var nt NullTime
239
+ // err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
240
+ // ...
241
+ // if nt.Valid {
242
+ // // use nt.Time
243
+ // } else {
244
+ // // NULL value
245
+ // }
246
+ //
247
+ // This NullTime implementation is not driver-specific
248
+ type NullTime struct {
249
+ Time time.Time
250
+ Valid bool // Valid is true if Time is not NULL
251
+ }
252
+
253
+ // Scan implements the Scanner interface.
254
+ // The value type must be time.Time or string / []byte (formatted time-string),
255
+ // otherwise Scan fails.
256
+ func (nt * NullTime ) Scan (value interface {}) (err error ) {
257
+ if value == nil {
258
+ nt .Time , nt .Valid = time.Time {}, false
259
+ return
260
+ }
261
+
262
+ switch v := value .(type ) {
263
+ case time.Time :
264
+ nt .Time , nt .Valid = v , true
265
+ return
266
+ case []byte :
267
+ nt .Time , err = parseDateTime (string (v ), time .UTC )
268
+ nt .Valid = (err == nil )
269
+ return
270
+ case string :
271
+ nt .Time , err = parseDateTime (v , time .UTC )
272
+ nt .Valid = (err == nil )
273
+ return
274
+ }
275
+
276
+ nt .Valid = false
277
+ return fmt .Errorf ("Can't convert %T to time.Time" , value )
278
+ }
279
+
280
+ // Value implements the driver Valuer interface.
281
+ func (nt NullTime ) Value () (driver.Value , error ) {
282
+ if ! nt .Valid {
283
+ return nil , nil
284
+ }
285
+ return nt .Time , nil
286
+ }
287
+
256
288
func parseDateTime (str string , loc * time.Location ) (t time.Time , err error ) {
257
289
switch len (str ) {
258
290
case 10 : // YYYY-MM-DD
@@ -369,16 +401,6 @@ func formatBinaryDateTime(num uint64, data []byte) (driver.Value, error) {
369
401
return nil , fmt .Errorf ("Invalid DATETIME-packet length %d" , num )
370
402
}
371
403
372
- func readBool (value string ) bool {
373
- switch strings .ToLower (value ) {
374
- case "true" :
375
- return true
376
- case "1" :
377
- return true
378
- }
379
- return false
380
- }
381
-
382
404
/******************************************************************************
383
405
* Convert from and to bytes *
384
406
******************************************************************************/
0 commit comments