Skip to content

Commit 5e9c396

Browse files
committed
stmt: add json.RawMessage for converter and prepared statement
Following #1058, in order for the driver.Value to get as a json.RawMessage, the converter should accept it as a valid value, and handle it as bytes in case where interpolation is disabled
1 parent c4f1976 commit 5e9c396

File tree

4 files changed

+39
-10
lines changed

4 files changed

+39
-10
lines changed

AUTHORS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Alex Snast <alexsn at fb.com>
1717
Alexey Palazhchenko <alexey.palazhchenko at gmail.com>
1818
Andrew Reid <andrew.reid at tixtrack.com>
1919
Arne Hormann <arnehormann at gmail.com>
20+
Ariel Mashraki <ariel at mashraki.co.il>
2021
Asta Xie <xiemengjun at gmail.com>
2122
Bulat Gaifullin <gaifullinbf at gmail.com>
2223
Carlos Nieto <jose.carlos at menteslibres.net>

packets.go

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"crypto/tls"
1414
"database/sql/driver"
1515
"encoding/binary"
16+
"encoding/json"
1617
"errors"
1718
"fmt"
1819
"io"
@@ -1063,19 +1064,26 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
10631064
paramValues = append(paramValues, 0x00)
10641065
}
10651066

1066-
case []byte:
1067+
case []byte, json.RawMessage:
1068+
var (
1069+
ok bool
1070+
b []byte
1071+
)
1072+
if b, ok = v.(json.RawMessage); !ok {
1073+
b = v.([]byte)
1074+
}
10671075
// Common case (non-nil value) first
1068-
if v != nil {
1076+
if b != nil {
10691077
paramTypes[i+i] = byte(fieldTypeString)
10701078
paramTypes[i+i+1] = 0x00
10711079

1072-
if len(v) < longDataSize {
1080+
if len(b) < longDataSize {
10731081
paramValues = appendLengthEncodedInteger(paramValues,
1074-
uint64(len(v)),
1082+
uint64(len(b)),
10751083
)
1076-
paramValues = append(paramValues, v...)
1084+
paramValues = append(paramValues, b...)
10771085
} else {
1078-
if err := stmt.writeCommandLongData(i, v); err != nil {
1086+
if err := stmt.writeCommandLongData(i, b); err != nil {
10791087
return err
10801088
}
10811089
}

statement.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ package mysql
1010

1111
import (
1212
"database/sql/driver"
13+
"encoding/json"
1314
"fmt"
1415
"io"
1516
"reflect"
@@ -129,6 +130,8 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
129130
return rows, err
130131
}
131132

133+
var jsonType = reflect.TypeOf(json.RawMessage{})
134+
132135
type converter struct{}
133136

134137
// ConvertValue mirrors the reference/default converter in database/sql/driver
@@ -151,7 +154,6 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
151154
}
152155
return sv, nil
153156
}
154-
155157
rv := reflect.ValueOf(v)
156158
switch rv.Kind() {
157159
case reflect.Ptr:
@@ -170,11 +172,14 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
170172
case reflect.Bool:
171173
return rv.Bool(), nil
172174
case reflect.Slice:
173-
ek := rv.Type().Elem().Kind()
174-
if ek == reflect.Uint8 {
175+
switch t := rv.Type(); {
176+
case t == jsonType:
177+
return v, nil
178+
case t.Elem().Kind() == reflect.Uint8:
175179
return rv.Bytes(), nil
180+
default:
181+
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, t.Elem().Kind())
176182
}
177-
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek)
178183
case reflect.String:
179184
return rv.String(), nil
180185
}

statement_test.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ package mysql
1010

1111
import (
1212
"bytes"
13+
"encoding/json"
1314
"testing"
1415
)
1516

@@ -124,3 +125,17 @@ func TestConvertUnsignedIntegers(t *testing.T) {
124125
t.Fatalf("uint64 high-bit converted, got %#v %T", output, output)
125126
}
126127
}
128+
129+
func TestConvertJSON(t *testing.T) {
130+
raw := json.RawMessage("{}")
131+
132+
out, err := converter{}.ConvertValue(&raw)
133+
134+
if err != nil {
135+
t.Fatal("json.RawMessage was failed in convert", err)
136+
}
137+
138+
if _, ok := out.(json.RawMessage); !ok {
139+
t.Fatalf("json.RawMessage converted, got %#v %T", out, out)
140+
}
141+
}

0 commit comments

Comments
 (0)