Skip to content

Commit a9d6280

Browse files
committed
stmt: add json.RawMessage for converter and prepared statement
Following #1058, in order for the driver.Value to get as a json.RawMessage, the converter should accept it as a valid value, and handle it as bytes in case where interpolation is disabled
1 parent c4f1976 commit a9d6280

File tree

4 files changed

+29
-4
lines changed

4 files changed

+29
-4
lines changed

AUTHORS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Alex Snast <alexsn at fb.com>
1717
Alexey Palazhchenko <alexey.palazhchenko at gmail.com>
1818
Andrew Reid <andrew.reid at tixtrack.com>
1919
Arne Hormann <arnehormann at gmail.com>
20+
Ariel Mashraki <ariel at mashraki.co.il>
2021
Asta Xie <xiemengjun at gmail.com>
2122
Bulat Gaifullin <gaifullinbf at gmail.com>
2223
Carlos Nieto <jose.carlos at menteslibres.net>

packets.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"crypto/tls"
1414
"database/sql/driver"
1515
"encoding/binary"
16+
"encoding/json"
1617
"errors"
1718
"fmt"
1819
"io"
@@ -1003,6 +1004,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
10031004
continue
10041005
}
10051006

1007+
if v, ok := arg.(json.RawMessage); ok {
1008+
arg = []byte(v)
1009+
}
10061010
// cache types and values
10071011
switch v := arg.(type) {
10081012
case int64:

statement.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ package mysql
1010

1111
import (
1212
"database/sql/driver"
13+
"encoding/json"
1314
"fmt"
1415
"io"
1516
"reflect"
@@ -129,6 +130,8 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
129130
return rows, err
130131
}
131132

133+
var jsonType = reflect.TypeOf(json.RawMessage{})
134+
132135
type converter struct{}
133136

134137
// ConvertValue mirrors the reference/default converter in database/sql/driver
@@ -151,7 +154,6 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
151154
}
152155
return sv, nil
153156
}
154-
155157
rv := reflect.ValueOf(v)
156158
switch rv.Kind() {
157159
case reflect.Ptr:
@@ -170,11 +172,14 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
170172
case reflect.Bool:
171173
return rv.Bool(), nil
172174
case reflect.Slice:
173-
ek := rv.Type().Elem().Kind()
174-
if ek == reflect.Uint8 {
175+
switch t := rv.Type(); {
176+
case t == jsonType:
177+
return v, nil
178+
case t.Elem().Kind() == reflect.Uint8:
175179
return rv.Bytes(), nil
180+
default:
181+
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, t.Elem().Kind())
176182
}
177-
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek)
178183
case reflect.String:
179184
return rv.String(), nil
180185
}

statement_test.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ package mysql
1010

1111
import (
1212
"bytes"
13+
"encoding/json"
1314
"testing"
1415
)
1516

@@ -124,3 +125,17 @@ func TestConvertUnsignedIntegers(t *testing.T) {
124125
t.Fatalf("uint64 high-bit converted, got %#v %T", output, output)
125126
}
126127
}
128+
129+
func TestConvertJSON(t *testing.T) {
130+
raw := json.RawMessage("{}")
131+
132+
out, err := converter{}.ConvertValue(&raw)
133+
134+
if err != nil {
135+
t.Fatal("json.RawMessage was failed in convert", err)
136+
}
137+
138+
if _, ok := out.(json.RawMessage); !ok {
139+
t.Fatalf("json.RawMessage converted, got %#v %T", out, out)
140+
}
141+
}

0 commit comments

Comments
 (0)