Skip to content

support uint64 parameters with high bit set #332

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 3, 2015
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,49 @@ func TestNULL(t *testing.T) {
})
}

func TestUint64(t *testing.T) {
const (
u0 = uint64(0)
uall = ^u0
uhigh = uall >> 1
utop = ^uhigh
s0 = int64(0)
sall = ^s0
shigh = int64(uhigh)
stop = ^shigh
)
runTests(t, dsn, func(dbt *DBTest) {
stmt, err := dbt.db.Prepare(`SELECT ?, ?, ? ,?, ?, ?, ?, ?`)
if err != nil {
dbt.Fatal(err)
}
defer stmt.Close()
row := stmt.QueryRow(
u0, uhigh, utop, uall,
s0, shigh, stop, sall,
)

var ua, ub, uc, ud uint64
var sa, sb, sc, sd int64

err = row.Scan(&ua, &ub, &uc, &ud, &sa, &sb, &sc, &sd)
if err != nil {
dbt.Fatal(err)
}
switch {
case ua != u0,
ub != uhigh,
uc != utop,
ud != uall,
sa != s0,
sb != shigh,
sc != stop,
sd != sall:
dbt.Fatal("Unexpected result value")
}
})
}

func TestLongData(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
var maxAllowedPacketSize int
Expand Down
48 changes: 48 additions & 0 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ package mysql

import (
"database/sql/driver"
"fmt"
"reflect"
)

type mysqlStmt struct {
Expand All @@ -34,6 +36,10 @@ func (stmt *mysqlStmt) NumInput() int {
return stmt.paramCount
}

func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
return converter{}
}

func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
if stmt.mc.netConn == nil {
errLog.Print(ErrInvalidConn)
Expand Down Expand Up @@ -110,3 +116,45 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {

return rows, err
}

type converter struct{}

func (converter) ConvertValue(v interface{}) (driver.Value, error) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if driver.IsValue(v) {
return v, nil
}

if svi, ok := v.(driver.Valuer); ok {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The database/sql package already checks this case: https://github.com/golang/go/blob/master/src/database/sql/convert.go#L48

sv, err := svi.Value()
if err != nil {
return nil, err
}
if !driver.IsValue(sv) {
return nil, fmt.Errorf("non-Value type %T returned from Value", sv)
}
return sv, nil
}

rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Ptr:
// indirect pointers
if rv.IsNil() {
return nil, nil
}
return driver.DefaultParameterConverter.ConvertValue(rv.Elem().Interface())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return rv.Int(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
return int64(rv.Uint()), nil
case reflect.Uint64:
val := rv.Uint()
if val < (^uint64(0) >> 1) {
return int64(val), nil
}
return fmt.Sprintf("%d", rv.Uint()), nil
case reflect.Float32, reflect.Float64:
return rv.Float(), nil
}
return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
}