Skip to content

Commit 54246cb

Browse files
committed
Merge branch 'master' of https://github.com/go-sql-driver/mysql into cleanup
2 parents c27f685 + 55a708b commit 54246cb

File tree

8 files changed

+198
-51
lines changed

8 files changed

+198
-51
lines changed

README.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,12 +105,14 @@ For Unix domain sockets the address is the absolute path to the MySQL-Server-soc
105105
***Parameters are case-sensitive!***
106106

107107
Possible Parameters are:
108-
* `timeout`: **Driver** side connection timeout. The value must be a string of decimal numbers, each with optional fraction and a unit suffix ( *"ms"*, *"s"*, *"m"*, *"h"* ), such as *"30s"*, *"0.5m"* or *"1m30s"*. To set a server side timeout, use the parameter [`wait_timeout`](http://dev.mysql.com/doc/refman/5.6/en/server-system-variables.html#sysvar_wait_timeout).
109-
* `charset`: Sets the charset used for client-server interaction ("SET NAMES `value`"). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset failes. This enables support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`).
110108
* `allowAllFiles`: `allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files. *Might be insecure!*
111-
* `parseTime`: `parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string`
109+
* `charset`: Sets the charset used for client-server interaction ("SET NAMES `value`"). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset failes. This enables support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`).
110+
* `clientFoundRows`: `clientFoundRows=true` causes an UPDATE to return the number of matching rows instead of the number of rows changed.
112111
* `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details.
112+
* `parseTime`: `parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string`
113113
* `strict`: Enable strict mode. MySQL warnings are treated as errors.
114+
* `timeout`: **Driver** side connection timeout. The value must be a string of decimal numbers, each with optional fraction and a unit suffix ( *"ms"*, *"s"*, *"m"*, *"h"* ), such as *"30s"*, *"0.5m"* or *"1m30s"*. To set a server side timeout, use the parameter [`wait_timeout`](http://dev.mysql.com/doc/refman/5.6/en/server-system-variables.html#sysvar_wait_timeout).
115+
* `tls`: `true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side)
114116

115117
All other parameters are interpreted as system variables:
116118
* `autocommit`: *"SET autocommit=`value`"*

connection.go

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
package mysql
1111

1212
import (
13+
"crypto/tls"
1314
"database/sql/driver"
1415
"errors"
1516
"net"
@@ -35,13 +36,15 @@ type mysqlConn struct {
3536
}
3637

3738
type config struct {
38-
user string
39-
passwd string
40-
net string
41-
addr string
42-
dbname string
43-
params map[string]string
44-
loc *time.Location
39+
user string
40+
passwd string
41+
net string
42+
addr string
43+
dbname string
44+
params map[string]string
45+
loc *time.Location
46+
timeout time.Duration
47+
tls *tls.Config
4548
}
4649

4750
// Handles parameters set in DSN
@@ -63,7 +66,7 @@ func (mc *mysqlConn) handleParams() (err error) {
6366
}
6467

6568
// handled elsewhere
66-
case "timeout", "allowAllFiles", "loc":
69+
case "allowAllFiles", "clientFoundRows":
6770
continue
6871

6972
// time.Time parsing
@@ -74,14 +77,10 @@ func (mc *mysqlConn) handleParams() (err error) {
7477
case "strict":
7578
mc.strict = readBool(val)
7679

77-
// TLS-Encryption
78-
case "tls":
79-
err = errors.New("TLS-Encryption not implemented yet")
80-
return
81-
8280
// Compression
8381
case "compress":
8482
err = errors.New("Compression not implemented yet")
83+
return
8584

8685
// System Vars
8786
default:

driver.go

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import (
1212
"database/sql"
1313
"database/sql/driver"
1414
"net"
15-
"time"
1615
)
1716

1817
type mysqlDriver struct{}
@@ -34,11 +33,9 @@ func (d *mysqlDriver) Open(dsn string) (driver.Conn, error) {
3433
}
3534

3635
// Connect to Server
37-
if _, ok := mc.cfg.params["timeout"]; ok { // with timeout
38-
var timeout time.Duration
39-
timeout, err = time.ParseDuration(mc.cfg.params["timeout"])
36+
if mc.cfg.timeout > 0 { // with timeout
4037
if err == nil {
41-
mc.netConn, err = net.DialTimeout(mc.cfg.net, mc.cfg.addr, timeout)
38+
mc.netConn, err = net.DialTimeout(mc.cfg.net, mc.cfg.addr, mc.cfg.timeout)
4239
}
4340
} else { // no timeout
4441
mc.netConn, err = net.Dial(mc.cfg.net, mc.cfg.addr)

driver_test.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -807,6 +807,42 @@ func TestStrict(t *testing.T) {
807807
})
808808
}
809809

810+
func TestTLS(t *testing.T) {
811+
runTests(t, "TestTLS", dsn+"&tls=skip-verify", func(dbt *DBTest) {
812+
/* TODO: GO 1.1 API */
813+
/*if err := dbt.db.Ping(); err != nil {
814+
if err == errNoTLS {
815+
dbt.Skip("Server does not support TLS. Skipping TestTLS")
816+
} else {
817+
dbt.Fatalf("Error on Ping: %s", err.Error())
818+
}
819+
}*/
820+
821+
/* GO 1.0 API */
822+
if _, err := dbt.db.Exec("DO 1"); err != nil {
823+
if err == errNoTLS {
824+
dbt.Log("Server does not support TLS. Skipping TestTLS")
825+
return
826+
} else {
827+
dbt.Fatalf("Error on Ping: %s", err.Error())
828+
}
829+
}
830+
831+
rows := dbt.mustQuery("SHOW STATUS LIKE 'Ssl_cipher'")
832+
833+
var variable, value *sql.RawBytes
834+
for rows.Next() {
835+
if err := rows.Scan(&variable, &value); err != nil {
836+
dbt.Fatal(err.Error())
837+
}
838+
839+
if value == nil {
840+
dbt.Fatal("No Cipher")
841+
}
842+
}
843+
})
844+
}
845+
810846
// Special cases
811847

812848
func TestRowsClose(t *testing.T) {
@@ -1036,6 +1072,51 @@ func TestConcurrent(t *testing.T) {
10361072
})
10371073
}
10381074

1075+
func TestFoundRows(t *testing.T) {
1076+
runTests(t, "TestFoundRows1", dsn, func(dbt *DBTest) {
1077+
dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
1078+
dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
1079+
1080+
res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
1081+
count, err := res.RowsAffected()
1082+
if err != nil {
1083+
dbt.Fatalf("res.RowsAffected() returned error: %v", err)
1084+
}
1085+
if count != 2 {
1086+
dbt.Fatalf("Expected 2 affected rows, got %d", count)
1087+
}
1088+
res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
1089+
count, err = res.RowsAffected()
1090+
if err != nil {
1091+
dbt.Fatalf("res.RowsAffected() returned error: %v", err)
1092+
}
1093+
if count != 2 {
1094+
dbt.Fatalf("Expected 2 affected rows, got %d", count)
1095+
}
1096+
})
1097+
runTests(t, "TestFoundRows2", dsn+"&clientFoundRows=true", func(dbt *DBTest) {
1098+
dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
1099+
dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
1100+
1101+
res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
1102+
count, err := res.RowsAffected()
1103+
if err != nil {
1104+
dbt.Fatalf("res.RowsAffected() returned error: %v", err)
1105+
}
1106+
if count != 2 {
1107+
dbt.Fatalf("Expected 2 matched rows, got %d", count)
1108+
}
1109+
res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
1110+
count, err = res.RowsAffected()
1111+
if err != nil {
1112+
dbt.Fatalf("res.RowsAffected() returned error: %v", err)
1113+
}
1114+
if count != 3 {
1115+
dbt.Fatalf("Expected 3 matched rows, got %d", count)
1116+
}
1117+
})
1118+
}
1119+
10391120
// BENCHMARKS
10401121
var sample []byte
10411122

errors.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@ import (
1818

1919
var (
2020
errMalformPkt = errors.New("Malformed Packet")
21+
errNoTLS = errors.New("TLS encryption requested but server does not support TLS")
22+
errOldPassword = errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/go-sql-driver/mysql/wiki/old_passwords")
23+
errOldProtocol = errors.New("MySQL-Server does not support required Protocol 41+")
2124
errPktSync = errors.New("Commands out of sync. You can't run this command now")
2225
errPktSyncMul = errors.New("Commands out of sync. Did you run multiple statements at once?")
23-
errOldPassword = errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/go-sql-driver/mysql/wiki/old_passwords")
2426
errPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.")
2527
)
2628

packets.go

Lines changed: 49 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ package mysql
1111

1212
import (
1313
"bytes"
14+
"crypto/tls"
1415
"database/sql/driver"
1516
"encoding/binary"
16-
"errors"
1717
"fmt"
1818
"io"
1919
"math"
@@ -167,7 +167,10 @@ func (mc *mysqlConn) readInitPacket() (err error) {
167167
// capability flags (lower 2 bytes) [2 bytes]
168168
mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
169169
if mc.flags&clientProtocol41 == 0 {
170-
err = errors.New("MySQL-Server does not support required Protocol 41+")
170+
err = errOldProtocol
171+
}
172+
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
173+
return errNoTLS
171174
}
172175
pos += 2
173176

@@ -205,15 +208,20 @@ func (mc *mysqlConn) readInitPacket() (err error) {
205208
// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::HandshakeResponse
206209
func (mc *mysqlConn) writeAuthPacket() error {
207210
// Adjust client flags based on server support
208-
clientFlags := uint32(
209-
clientProtocol41 |
210-
clientSecureConn |
211-
clientLongPassword |
212-
clientTransactions |
213-
clientLocalFiles,
214-
)
215-
if mc.flags&clientLongFlag > 0 {
216-
clientFlags |= uint32(clientLongFlag)
211+
clientFlags := clientProtocol41 |
212+
clientSecureConn |
213+
clientLongPassword |
214+
clientTransactions |
215+
clientLocalFiles |
216+
mc.flags&clientLongFlag
217+
218+
if _, ok := mc.cfg.params["clientFoundRows"]; ok {
219+
clientFlags |= clientFoundRows
220+
}
221+
222+
// To enable TLS / SSL
223+
if mc.cfg.tls != nil {
224+
clientFlags |= clientSSL
217225
}
218226

219227
// User Password
@@ -224,19 +232,13 @@ func (mc *mysqlConn) writeAuthPacket() error {
224232

225233
// To specify a db name
226234
if len(mc.cfg.dbname) > 0 {
227-
clientFlags |= uint32(clientConnectWithDB)
235+
clientFlags |= clientConnectWithDB
228236
pktLen += len(mc.cfg.dbname) + 1
229237
}
230238

231239
// Calculate packet length and make buffer with that size
232240
data := make([]byte, pktLen+4)
233241

234-
// Add the packet header [24bit length + 1 byte sequence]
235-
data[0] = byte(pktLen)
236-
data[1] = byte(pktLen >> 8)
237-
data[2] = byte(pktLen >> 16)
238-
data[3] = mc.sequence
239-
240242
// ClientFlags [32 bit]
241243
data[4] = byte(clientFlags)
242244
data[5] = byte(clientFlags >> 8)
@@ -252,6 +254,35 @@ func (mc *mysqlConn) writeAuthPacket() error {
252254
// Charset [1 byte]
253255
data[12] = mc.charset
254256

257+
// SSL Connection Request Packet
258+
// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::SSLRequest
259+
if mc.cfg.tls != nil {
260+
// Packet header [24bit length + 1 byte sequence]
261+
data[0] = byte((4 + 4 + 1 + 23))
262+
data[1] = byte((4 + 4 + 1 + 23) >> 8)
263+
data[2] = byte((4 + 4 + 1 + 23) >> 16)
264+
data[3] = mc.sequence
265+
266+
// Send TLS / SSL request packet
267+
if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
268+
return err
269+
}
270+
271+
// Switch to TLS
272+
tlsConn := tls.Client(mc.netConn, mc.cfg.tls)
273+
if err := tlsConn.Handshake(); err != nil {
274+
return err
275+
}
276+
mc.netConn = tlsConn
277+
mc.buf.rd = tlsConn
278+
}
279+
280+
// Add the packet header [24bit length + 1 byte sequence]
281+
data[0] = byte(pktLen)
282+
data[1] = byte(pktLen >> 8)
283+
data[2] = byte(pktLen >> 16)
284+
data[3] = mc.sequence
285+
255286
// Filler [23 bytes] (all 0x00)
256287
pos := 13 + 23
257288

utils.go

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ package mysql
1111

1212
import (
1313
"crypto/sha1"
14+
"crypto/tls"
1415
"database/sql/driver"
1516
"encoding/binary"
1617
"fmt"
@@ -119,7 +120,35 @@ func parseDSN(dsn string) (cfg *config, err error) {
119120
if len(param) != 2 {
120121
continue
121122
}
122-
cfg.params[param[0]] = param[1]
123+
124+
// cfg params
125+
switch value := param[1]; param[0] {
126+
127+
// Time Location
128+
case "loc":
129+
cfg.loc, err = time.LoadLocation(value)
130+
if err != nil {
131+
return
132+
}
133+
134+
// Dial Timeout
135+
case "timeout":
136+
cfg.timeout, err = time.ParseDuration(value)
137+
if err != nil {
138+
return
139+
}
140+
141+
// TLS-Encryption
142+
case "tls":
143+
if readBool(value) {
144+
cfg.tls = &tls.Config{}
145+
} else if strings.ToLower(value) == "skip-verify" {
146+
cfg.tls = &tls.Config{InsecureSkipVerify: true}
147+
}
148+
149+
default:
150+
cfg.params[param[0]] = value
151+
}
123152
}
124153
}
125154
}
@@ -134,7 +163,10 @@ func parseDSN(dsn string) (cfg *config, err error) {
134163
cfg.addr = "127.0.0.1:3306"
135164
}
136165

137-
cfg.loc, err = time.LoadLocation(cfg.params["loc"])
166+
// Set default location if not set
167+
if cfg.loc == nil {
168+
cfg.loc = time.UTC
169+
}
138170

139171
return
140172
}

0 commit comments

Comments
 (0)