Skip to content

Commit 2e04012

Browse files
committed
Make logger configurable per connection
1 parent d83ecdc commit 2e04012

14 files changed

+247
-38
lines changed

AUTHORS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ Stan Putrya <root.vagner at gmail.com>
9595
Stanley Gunawan <gunawan.stanley at gmail.com>
9696
Steven Hartland <steven.hartland at multiplay.co.uk>
9797
Tan Jinhua <312841925 at qq.com>
98+
Tetsuro Aoki <t.aoki1130 at gmail.com>
9899
Thomas Wodarek <wodarekwebpage at gmail.com>
99100
Tim Ruffles <timruffles at gmail.com>
100101
Tom Jenkinson <tom at tjenkinson.me>

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,16 @@ Note that this sets the location for time.Time values but does not change MySQL'
279279

280280
Please keep in mind, that param values must be [url.QueryEscape](https://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `loc=US%2FPacific`.
281281

282+
##### `logging`
283+
284+
```
285+
Type: bool / string
286+
Valid Values: true, false, <name>
287+
Default: true
288+
```
289+
290+
`logging=false` disables logging. You can use a custom logger after registering it with [`mysql.RegisterLogger`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterLogger).
291+
282292
##### `maxAllowedPacket`
283293
```
284294
Type: decimal number

auth.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
291291
return enc, err
292292

293293
default:
294-
errLog.Print("unknown auth plugin:", plugin)
294+
mc.errLog().Print("unknown auth plugin:", plugin)
295295
return nil, ErrUnknownPlugin
296296
}
297297
}

connection.go

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {
105105

106106
func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
107107
if mc.closed.Load() {
108-
errLog.Print(ErrInvalidConn)
108+
mc.errLog().Print(ErrInvalidConn)
109109
return nil, driver.ErrBadConn
110110
}
111111
var q string
@@ -147,7 +147,7 @@ func (mc *mysqlConn) cleanup() {
147147
return
148148
}
149149
if err := mc.netConn.Close(); err != nil {
150-
errLog.Print(err)
150+
mc.errLog().Print(err)
151151
}
152152
}
153153

@@ -163,14 +163,14 @@ func (mc *mysqlConn) error() error {
163163

164164
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
165165
if mc.closed.Load() {
166-
errLog.Print(ErrInvalidConn)
166+
mc.errLog().Print(ErrInvalidConn)
167167
return nil, driver.ErrBadConn
168168
}
169169
// Send command
170170
err := mc.writeCommandPacketStr(comStmtPrepare, query)
171171
if err != nil {
172172
// STMT_PREPARE is safe to retry. So we can return ErrBadConn here.
173-
errLog.Print(err)
173+
mc.errLog().Print(err)
174174
return nil, driver.ErrBadConn
175175
}
176176

@@ -204,7 +204,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
204204
buf, err := mc.buf.takeCompleteBuffer()
205205
if err != nil {
206206
// can not take the buffer. Something must be wrong with the connection
207-
errLog.Print(err)
207+
mc.errLog().Print(err)
208208
return "", ErrInvalidConn
209209
}
210210
buf = buf[:0]
@@ -296,7 +296,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
296296

297297
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
298298
if mc.closed.Load() {
299-
errLog.Print(ErrInvalidConn)
299+
mc.errLog().Print(ErrInvalidConn)
300300
return nil, driver.ErrBadConn
301301
}
302302
if len(args) != 0 {
@@ -357,7 +357,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
357357

358358
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
359359
if mc.closed.Load() {
360-
errLog.Print(ErrInvalidConn)
360+
mc.errLog().Print(ErrInvalidConn)
361361
return nil, driver.ErrBadConn
362362
}
363363
if len(args) != 0 {
@@ -451,7 +451,7 @@ func (mc *mysqlConn) finish() {
451451
// Ping implements driver.Pinger interface
452452
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
453453
if mc.closed.Load() {
454-
errLog.Print(ErrInvalidConn)
454+
mc.errLog().Print(ErrInvalidConn)
455455
return driver.ErrBadConn
456456
}
457457

@@ -648,3 +648,10 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error {
648648
func (mc *mysqlConn) IsValid() bool {
649649
return !mc.closed.Load()
650650
}
651+
652+
func (mc *mysqlConn) errLog() Logger {
653+
if mc.cfg.Logger != nil {
654+
return mc.cfg.Logger
655+
}
656+
return defaultLogger
657+
}

connection_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ func TestPingErrInvalidConn(t *testing.T) {
179179
buf: newBuffer(nc),
180180
maxAllowedPacket: defaultMaxAllowedPacket,
181181
closech: make(chan struct{}),
182+
cfg: NewConfig(),
182183
}
183184

184185
err := ms.Ping(context.Background())

connector.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
9292
authResp, err := mc.auth(authData, plugin)
9393
if err != nil {
9494
// try the default auth plugin, if using the requested plugin failed
95-
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
95+
c.errLog().Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
9696
plugin = defaultAuthPlugin
9797
authResp, err = mc.auth(authData, plugin)
9898
if err != nil {
@@ -144,3 +144,10 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
144144
func (c *connector) Driver() driver.Driver {
145145
return &MySQLDriver{}
146146
}
147+
148+
func (c *connector) errLog() Logger {
149+
if c.cfg.Logger != nil {
150+
return c.cfg.Logger
151+
}
152+
return defaultLogger
153+
}

driver_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1995,7 +1995,7 @@ func TestInsertRetrieveEscapedData(t *testing.T) {
19951995
func TestUnixSocketAuthFail(t *testing.T) {
19961996
runTests(t, dsn, func(dbt *DBTest) {
19971997
// Save the current logger so we can restore it.
1998-
oldLogger := errLog
1998+
oldLogger := defaultLogger
19991999

20002000
// Set a new logger so we can capture its output.
20012001
buffer := bytes.NewBuffer(make([]byte, 0, 64))

dsn.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ type Config struct {
5050
Timeout time.Duration // Dial timeout
5151
ReadTimeout time.Duration // I/O read timeout
5252
WriteTimeout time.Duration // I/O write timeout
53+
LoggingConfig string // Logging configuration
54+
Logger Logger // Logger
5355

5456
AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
5557
AllowCleartextPasswords bool // Allows the cleartext client side plugin
@@ -153,6 +155,20 @@ func (cfg *Config) normalize() error {
153155
}
154156
}
155157

158+
if cfg.Logger == nil {
159+
switch cfg.LoggingConfig {
160+
case "true", "":
161+
// use default logger
162+
case "false":
163+
cfg.Logger = defaultNopLogger
164+
default:
165+
cfg.Logger = getLogger(cfg.LoggingConfig)
166+
if cfg.Logger == nil {
167+
return errors.New("invalid value / unknown logger name: " + cfg.LoggingConfig)
168+
}
169+
}
170+
}
171+
156172
return nil
157173
}
158174

@@ -282,6 +298,10 @@ func (cfg *Config) FormatDSN() string {
282298
writeDSNParam(&buf, &hasParam, "maxAllowedPacket", strconv.Itoa(cfg.MaxAllowedPacket))
283299
}
284300

301+
if len(cfg.LoggingConfig) > 0 {
302+
writeDSNParam(&buf, &hasParam, "logging", url.QueryEscape(cfg.LoggingConfig))
303+
}
304+
285305
// other params
286306
if cfg.Params != nil {
287307
var params []string
@@ -554,6 +574,23 @@ func parseDSNParams(cfg *Config, params string) (err error) {
554574
if err != nil {
555575
return
556576
}
577+
578+
case "logging":
579+
boolValue, isBool := readBool(value)
580+
if isBool {
581+
if boolValue {
582+
cfg.LoggingConfig = "true"
583+
} else {
584+
cfg.LoggingConfig = "false"
585+
}
586+
} else {
587+
name, err := url.QueryUnescape(value)
588+
if err != nil {
589+
return fmt.Errorf("invalid value for logger name: %v", err)
590+
}
591+
cfg.LoggingConfig = name
592+
}
593+
557594
default:
558595
// lazy init
559596
if cfg.Params == nil {

dsn_test.go

Lines changed: 104 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ package mysql
1111
import (
1212
"crypto/tls"
1313
"fmt"
14+
"log"
1415
"net/url"
16+
"os"
1517
"reflect"
1618
"testing"
1719
"time"
@@ -33,14 +35,14 @@ var testDSNs = []struct {
3335
"user@unix(/path/to/socket)/dbname?charset=utf8",
3436
&Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true},
3537
}, {
36-
"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true",
37-
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "true"},
38+
"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true&logging=false",
39+
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "true", LoggingConfig: "false"},
3840
}, {
39-
"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify",
40-
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "skip-verify"},
41+
"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify&logging=testLogger",
42+
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "skip-verify", LoggingConfig: "testLogger"},
4143
}, {
42-
"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216&tls=false&allowCleartextPasswords=true&parseTime=true&rejectReadOnly=true",
43-
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, TLSConfig: "false", AllowCleartextPasswords: true, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, CheckConnLiveness: true, ClientFoundRows: true, MaxAllowedPacket: 16777216, ParseTime: true, RejectReadOnly: true},
44+
"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216&tls=false&allowCleartextPasswords=true&parseTime=true&rejectReadOnly=true&logging=true",
45+
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, TLSConfig: "false", AllowCleartextPasswords: true, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, LoggingConfig: "true", AllowAllFiles: true, AllowOldPasswords: true, CheckConnLiveness: true, ClientFoundRows: true, MaxAllowedPacket: 16777216, ParseTime: true, RejectReadOnly: true},
4446
}, {
4547
"user:password@/dbname?allowNativePasswords=false&checkConnLiveness=false&maxAllowedPacket=0&allowFallbackToPlaintext=true",
4648
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowFallbackToPlaintext: true, AllowNativePasswords: false, CheckConnLiveness: false},
@@ -75,6 +77,10 @@ var testDSNs = []struct {
7577
}
7678

7779
func TestDSNParser(t *testing.T) {
80+
logger := log.New(os.Stderr, "", log.LstdFlags)
81+
RegisterLogger("testLogger", logger)
82+
defer DeregisterLogger("testLogger")
83+
7884
for i, tst := range testDSNs {
7985
cfg, err := ParseDSN(tst.in)
8086
if err != nil {
@@ -83,6 +89,7 @@ func TestDSNParser(t *testing.T) {
8389

8490
// pointer not static
8591
cfg.TLS = nil
92+
cfg.Logger = nil
8693

8794
if !reflect.DeepEqual(cfg, tst.out) {
8895
t.Errorf("%d. ParseDSN(%q) mismatch:\ngot %+v\nwant %+v", i, tst.in, cfg, tst.out)
@@ -112,6 +119,10 @@ func TestDSNParserInvalid(t *testing.T) {
112119
}
113120

114121
func TestDSNReformat(t *testing.T) {
122+
logger := log.New(os.Stderr, "", log.LstdFlags)
123+
RegisterLogger("testLogger", logger)
124+
defer DeregisterLogger("testLogger")
125+
115126
for i, tst := range testDSNs {
116127
dsn1 := tst.in
117128
cfg1, err := ParseDSN(dsn1)
@@ -268,6 +279,93 @@ func TestDSNWithCustomTLSQueryEscape(t *testing.T) {
268279
}
269280
}
270281

282+
func TestDSNWithCustomLogger(t *testing.T) {
283+
baseDSN := "User:password@tcp(localhost:5555)/dbname?logging="
284+
285+
t.Run("custom logger is registered", func(tt *testing.T) {
286+
logger := log.New(os.Stderr, "", log.LstdFlags)
287+
288+
RegisterLogger("testKey", logger)
289+
defer DeregisterLogger("testKey")
290+
291+
tst := baseDSN + "testKey"
292+
293+
cfg, err := ParseDSN(tst)
294+
if err != nil {
295+
tt.Fatal(err.Error())
296+
}
297+
298+
if cfg.LoggingConfig != "testKey" {
299+
tt.Errorf("unexpected cfg.LoggingConfig value: %q", cfg.LoggingConfig)
300+
}
301+
if cfg.Logger != logger {
302+
tt.Error("logger pointer doesn't match")
303+
}
304+
})
305+
306+
t.Run("custom logger is missing", func(tt *testing.T) {
307+
tst := baseDSN + "invalid_name"
308+
309+
cfg, err := ParseDSN(tst)
310+
if err == nil {
311+
tt.Errorf("invalid name in DSN (%s) but did not error. Got config: %#v", tst, cfg)
312+
}
313+
})
314+
}
315+
316+
func TestDSNLoggingConfig(t *testing.T) {
317+
t.Run("logging=true", func(tt *testing.T) {
318+
dsn := "User:password@tcp(localhost:5555)/dbname?logging=true"
319+
320+
cfg, err := ParseDSN(dsn)
321+
if err != nil {
322+
tt.Fatal(err.Error())
323+
}
324+
325+
if cfg.LoggingConfig != "true" {
326+
tt.Errorf("unexpected cfg.LoggingConfig value: %q", cfg.LoggingConfig)
327+
}
328+
if cfg.Logger != nil {
329+
tt.Error("cfg.Logger should be nil")
330+
}
331+
})
332+
333+
t.Run("logging=false", func(tt *testing.T) {
334+
dsn := "User:password@tcp(localhost:5555)/dbname?logging=false"
335+
336+
cfg, err := ParseDSN(dsn)
337+
if err != nil {
338+
tt.Fatal(err.Error())
339+
}
340+
341+
if cfg.LoggingConfig != "false" {
342+
tt.Errorf("unexpected cfg.LoggingConfig value: %q", cfg.LoggingConfig)
343+
}
344+
if cfg.Logger != defaultNopLogger {
345+
tt.Error("logger pointer doesn't match")
346+
}
347+
})
348+
}
349+
350+
func TestDSNWithCustomLoggerQueryEscape(t *testing.T) {
351+
const name = "&%!:"
352+
dsn := "User:password@tcp(localhost:5555)/dbname?logging=" + url.QueryEscape(name)
353+
354+
logger := log.New(os.Stderr, "", log.LstdFlags)
355+
356+
RegisterLogger(name, logger)
357+
defer DeregisterTLSConfig(name)
358+
359+
cfg, err := ParseDSN(dsn)
360+
if err != nil {
361+
t.Fatal(err.Error())
362+
}
363+
364+
if cfg.Logger != logger {
365+
t.Error("logger pointer doesn't match")
366+
}
367+
}
368+
271369
func TestDSNUnsafeCollation(t *testing.T) {
272370
_, err := ParseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=true")
273371
if err != errInvalidDSNUnsafeCollation {

0 commit comments

Comments
 (0)