diff --git a/collations.go b/collations.go index 1cdf97b6..29b1aa43 100644 --- a/collations.go +++ b/collations.go @@ -8,7 +8,7 @@ package mysql -const defaultCollation = "utf8mb4_general_ci" +const defaultCollationID = 45 // utf8mb4_general_ci const binaryCollationID = 63 // A list of available collations mapped to the internal ID. diff --git a/connection.go b/connection.go index 2b19c927..ef6fc9e4 100644 --- a/connection.go +++ b/connection.go @@ -67,45 +67,20 @@ func (mc *mysqlConn) handleParams() (err error) { var cmdSet strings.Builder for param, val := range mc.cfg.Params { - switch param { - // Charset: character_set_connection, character_set_client, character_set_results - case "charset": - charsets := strings.Split(val, ",") - for _, cs := range charsets { - // ignore errors here - a charset may not exist - if mc.cfg.Collation != "" { - err = mc.exec("SET NAMES " + cs + " COLLATE " + mc.cfg.Collation) - } else { - err = mc.exec("SET NAMES " + cs) - } - if err == nil { - break - } - } - if err != nil { - return - } - - // Other system vars accumulated in a single SET command - default: - if cmdSet.Len() == 0 { - // Heuristic: 29 chars for each other key=value to reduce reallocations - cmdSet.Grow(4 + len(param) + 3 + len(val) + 30*(len(mc.cfg.Params)-1)) - cmdSet.WriteString("SET ") - } else { - cmdSet.WriteString(", ") - } - cmdSet.WriteString(param) - cmdSet.WriteString(" = ") - cmdSet.WriteString(val) + if cmdSet.Len() == 0 { + // Heuristic: 29 chars for each other key=value to reduce reallocations + cmdSet.Grow(4 + len(param) + 3 + len(val) + 30*(len(mc.cfg.Params)-1)) + cmdSet.WriteString("SET ") + } else { + cmdSet.WriteString(", ") } + cmdSet.WriteString(param) + cmdSet.WriteString(" = ") + cmdSet.WriteString(val) } if cmdSet.Len() > 0 { err = mc.exec(cmdSet.String()) - if err != nil { - return - } } return diff --git a/connector.go b/connector.go index b6707759..62012dba 100644 --- a/connector.go +++ b/connector.go @@ -180,6 +180,25 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.maxWriteSize = mc.maxAllowedPacket } + // Charset: character_set_connection, character_set_client, character_set_results + if len(mc.cfg.charsets) > 0 { + for _, cs := range mc.cfg.charsets { + // ignore errors here - a charset may not exist + if mc.cfg.Collation != "" { + err = mc.exec("SET NAMES " + cs + " COLLATE " + mc.cfg.Collation) + } else { + err = mc.exec("SET NAMES " + cs) + } + if err == nil { + break + } + } + if err != nil { + mc.Close() + return nil, err + } + } + // Handle DSN Params err = mc.handleParams() if err != nil { diff --git a/dsn.go b/dsn.go index 65f5a024..3c7a6e21 100644 --- a/dsn.go +++ b/dsn.go @@ -44,7 +44,8 @@ type Config struct { DBName string // Database name Params map[string]string // Connection parameters ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs - Collation string // Connection collation + charsets []string // Connection charset. When set, this will be set in SET NAMES query + Collation string // Connection collation. When set, this will be set in SET NAMES COLLATE query Loc *time.Location // Location for time.Time values MaxAllowedPacket int // Max packet size allowed ServerPubKey string // Server public key name @@ -282,6 +283,10 @@ func (cfg *Config) FormatDSN() string { writeDSNParam(&buf, &hasParam, "clientFoundRows", "true") } + if charsets := cfg.charsets; len(charsets) > 0 { + writeDSNParam(&buf, &hasParam, "charset", strings.Join(charsets, ",")) + } + if col := cfg.Collation; col != "" { writeDSNParam(&buf, &hasParam, "collation", col) } @@ -501,6 +506,10 @@ func parseDSNParams(cfg *Config, params string) (err error) { return errors.New("invalid bool value: " + value) } + // charset + case "charset": + cfg.charsets = strings.Split(value, ",") + // Collation case "collation": cfg.Collation = value diff --git a/dsn_test.go b/dsn_test.go index dd8cd935..863d1482 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -31,13 +31,13 @@ var testDSNs = []struct { &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, ColumnsWithAlias: true, MultiStatements: true}, }, { "user@unix(/path/to/socket)/dbname?charset=utf8", - &Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", charsets: []string{"utf8"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "true"}, + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", charsets: []string{"utf8"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "true"}, }, { "user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "skip-verify"}, + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", charsets: []string{"utf8mb4", "utf8"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "skip-verify"}, }, { "user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216&tls=false&allowCleartextPasswords=true&parseTime=true&rejectReadOnly=true", &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, TLSConfig: "false", AllowCleartextPasswords: true, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, Logger: defaultLogger, AllowAllFiles: true, AllowOldPasswords: true, CheckConnLiveness: true, ClientFoundRows: true, MaxAllowedPacket: 16777216, ParseTime: true, RejectReadOnly: true}, diff --git a/packets.go b/packets.go index df850fd4..00101962 100644 --- a/packets.go +++ b/packets.go @@ -322,17 +322,15 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string data[11] = 0x00 // Collation ID [1 byte] - cname := mc.cfg.Collation - if cname == "" { - cname = defaultCollation - } - var found bool - data[12], found = collations[cname] - if !found { - // Note possibility for false negatives: - // could be triggered although the collation is valid if the - // collations map does not contain entries the server supports. - return fmt.Errorf("unknown collation: %q", cname) + data[12] = defaultCollationID + if cname := mc.cfg.Collation; cname != "" { + colID, ok := collations[cname] + if ok { + data[12] = colID + } else if len(mc.cfg.charsets) > 0 { + // When cfg.charset is set, the collation is set by `SET NAMES COLLATE `. + return fmt.Errorf("unknown collation: %q", cname) + } } // Filler [23 bytes] (all 0x00)