Skip to content

Improve stdlib compatibility #172

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 17, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 57 additions & 56 deletions feature_reflect_native.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,9 @@ type intCodec struct {
}

func (codec *intCodec) Decode(ptr unsafe.Pointer, iter *Iterator) {
if iter.ReadNil() {
*((*int)(ptr)) = 0
return
if !iter.ReadNil() {
*((*int)(ptr)) = iter.ReadInt()
}
*((*int)(ptr)) = iter.ReadInt()
}

func (codec *intCodec) Encode(ptr unsafe.Pointer, stream *Stream) {
Expand All @@ -55,11 +53,9 @@ type uintptrCodec struct {
}

func (codec *uintptrCodec) Decode(ptr unsafe.Pointer, iter *Iterator) {
if iter.ReadNil() {
*((*uintptr)(ptr)) = 0
return
if !iter.ReadNil() {
*((*uintptr)(ptr)) = uintptr(iter.ReadUint64())
}
*((*uintptr)(ptr)) = uintptr(iter.ReadUint64())
}

func (codec *uintptrCodec) Encode(ptr unsafe.Pointer, stream *Stream) {
Expand All @@ -78,11 +74,9 @@ type int8Codec struct {
}

func (codec *int8Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
if iter.ReadNil() {
*((*uint8)(ptr)) = 0
return
if !iter.ReadNil() {
*((*int8)(ptr)) = iter.ReadInt8()
}
*((*int8)(ptr)) = iter.ReadInt8()
}

func (codec *int8Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
Expand All @@ -101,11 +95,9 @@ type int16Codec struct {
}

func (codec *int16Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
if iter.ReadNil() {
*((*int16)(ptr)) = 0
return
if !iter.ReadNil() {
*((*int16)(ptr)) = iter.ReadInt16()
}
*((*int16)(ptr)) = iter.ReadInt16()
}

func (codec *int16Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
Expand All @@ -124,11 +116,9 @@ type int32Codec struct {
}

func (codec *int32Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
if iter.ReadNil() {
*((*int32)(ptr)) = 0
return
if !iter.ReadNil() {
*((*int32)(ptr)) = iter.ReadInt32()
}
*((*int32)(ptr)) = iter.ReadInt32()
}

func (codec *int32Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
Expand All @@ -147,11 +137,9 @@ type int64Codec struct {
}

func (codec *int64Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
if iter.ReadNil() {
*((*int64)(ptr)) = 0
return
if !iter.ReadNil() {
*((*int64)(ptr)) = iter.ReadInt64()
}
*((*int64)(ptr)) = iter.ReadInt64()
}

func (codec *int64Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
Expand All @@ -170,11 +158,10 @@ type uintCodec struct {
}

func (codec *uintCodec) Decode(ptr unsafe.Pointer, iter *Iterator) {
if iter.ReadNil() {
*((*uint)(ptr)) = 0
if !iter.ReadNil() {
*((*uint)(ptr)) = iter.ReadUint()
return
}
*((*uint)(ptr)) = iter.ReadUint()
}

func (codec *uintCodec) Encode(ptr unsafe.Pointer, stream *Stream) {
Expand All @@ -193,11 +180,9 @@ type uint8Codec struct {
}

func (codec *uint8Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
if iter.ReadNil() {
*((*uint8)(ptr)) = 0
return
if !iter.ReadNil() {
*((*uint8)(ptr)) = iter.ReadUint8()
}
*((*uint8)(ptr)) = iter.ReadUint8()
}

func (codec *uint8Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
Expand All @@ -216,11 +201,9 @@ type uint16Codec struct {
}

func (codec *uint16Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
if iter.ReadNil() {
*((*uint16)(ptr)) = 0
return
if !iter.ReadNil() {
*((*uint16)(ptr)) = iter.ReadUint16()
}
*((*uint16)(ptr)) = iter.ReadUint16()
}

func (codec *uint16Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
Expand All @@ -239,11 +222,9 @@ type uint32Codec struct {
}

func (codec *uint32Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
if iter.ReadNil() {
*((*uint32)(ptr)) = 0
return
if !iter.ReadNil() {
*((*uint32)(ptr)) = iter.ReadUint32()
}
*((*uint32)(ptr)) = iter.ReadUint32()
}

func (codec *uint32Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
Expand All @@ -262,11 +243,9 @@ type uint64Codec struct {
}

func (codec *uint64Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
if iter.ReadNil() {
*((*uint64)(ptr)) = 0
return
if !iter.ReadNil() {
*((*uint64)(ptr)) = iter.ReadUint64()
}
*((*uint64)(ptr)) = iter.ReadUint64()
}

func (codec *uint64Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
Expand All @@ -285,11 +264,9 @@ type float32Codec struct {
}

func (codec *float32Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
if iter.ReadNil() {
*((*float32)(ptr)) = 0
return
if !iter.ReadNil() {
*((*float32)(ptr)) = iter.ReadFloat32()
}
*((*float32)(ptr)) = iter.ReadFloat32()
}

func (codec *float32Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
Expand All @@ -308,11 +285,9 @@ type float64Codec struct {
}

func (codec *float64Codec) Decode(ptr unsafe.Pointer, iter *Iterator) {
if iter.ReadNil() {
*((*float64)(ptr)) = 0
return
if !iter.ReadNil() {
*((*float64)(ptr)) = iter.ReadFloat64()
}
*((*float64)(ptr)) = iter.ReadFloat64()
}

func (codec *float64Codec) Encode(ptr unsafe.Pointer, stream *Stream) {
Expand Down Expand Up @@ -352,13 +327,39 @@ type emptyInterfaceCodec struct {
}

func (codec *emptyInterfaceCodec) Decode(ptr unsafe.Pointer, iter *Iterator) {
if iter.ReadNil() {
*((*interface{})(ptr)) = nil
existing := *((*interface{})(ptr))

// Checking for both typed and untyped nil pointers.
if existing != nil &&
reflect.TypeOf(existing).Kind() == reflect.Ptr &&
!reflect.ValueOf(existing).IsNil() {

var ptrToExisting interface{}
for {
elem := reflect.ValueOf(existing).Elem()
if elem.Kind() != reflect.Ptr || elem.IsNil() {
break
}
ptrToExisting = existing
existing = elem.Interface()
}

if iter.ReadNil() {
if ptrToExisting != nil {
nilPtr := reflect.Zero(reflect.TypeOf(ptrToExisting).Elem())
reflect.ValueOf(ptrToExisting).Elem().Set(nilPtr)
} else {
*((*interface{})(ptr)) = nil
}
} else {
iter.ReadVal(existing)
}

return
}
existing := *((*interface{})(ptr))
if existing != nil && reflect.TypeOf(existing).Kind() == reflect.Ptr {
iter.ReadVal(existing)

if iter.ReadNil() {
*((*interface{})(ptr)) = nil
} else {
*((*interface{})(ptr)) = iter.Read()
}
Expand Down
121 changes: 121 additions & 0 deletions jsoniter_interface_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,3 +437,124 @@ func Test_marshal_nil_nonempty_interface(t *testing.T) {
should.NoError(err)
should.Equal(nil, obj.Field)
}

func Test_overwrite_interface_ptr_value_with_nil(t *testing.T) {
type Wrapper struct {
Payload interface{} `json:"payload,omitempty"`
}
type Payload struct {
Value int `json:"val,omitempty"`
}

should := require.New(t)

payload := &Payload{}
wrapper := &Wrapper{
Payload: &payload,
}

err := json.Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper)
should.Equal(nil, err)
should.Equal(&payload, wrapper.Payload)
should.Equal(42, (*(wrapper.Payload.(**Payload))).Value)

err = json.Unmarshal([]byte(`{"payload": null}`), &wrapper)
should.Equal(nil, err)
should.Equal(&payload, wrapper.Payload)
should.Equal((*Payload)(nil), payload)

payload = &Payload{}
wrapper = &Wrapper{
Payload: &payload,
}

err = Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper)
should.Equal(nil, err)
should.Equal(&payload, wrapper.Payload)
should.Equal(42, (*(wrapper.Payload.(**Payload))).Value)

err = Unmarshal([]byte(`{"payload": null}`), &wrapper)
should.Equal(nil, err)
should.Equal(&payload, wrapper.Payload)
should.Equal((*Payload)(nil), payload)
}

func Test_overwrite_interface_value_with_nil(t *testing.T) {
type Wrapper struct {
Payload interface{} `json:"payload,omitempty"`
}
type Payload struct {
Value int `json:"val,omitempty"`
}

should := require.New(t)

payload := &Payload{}
wrapper := &Wrapper{
Payload: payload,
}

err := json.Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper)
should.Equal(nil, err)
should.Equal(42, (*(wrapper.Payload.(*Payload))).Value)

err = json.Unmarshal([]byte(`{"payload": null}`), &wrapper)
should.Equal(nil, err)
should.Equal(nil, wrapper.Payload)
should.Equal(42, payload.Value)

payload = &Payload{}
wrapper = &Wrapper{
Payload: payload,
}

err = Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper)
should.Equal(nil, err)
should.Equal(42, (*(wrapper.Payload.(*Payload))).Value)

err = Unmarshal([]byte(`{"payload": null}`), &wrapper)
should.Equal(nil, err)
should.Equal(nil, wrapper.Payload)
should.Equal(42, payload.Value)
}

func Test_unmarshal_into_nil(t *testing.T) {
type Payload struct {
Value int `json:"val,omitempty"`
}
type Wrapper struct {
Payload interface{} `json:"payload,omitempty"`
}

should := require.New(t)

var payload *Payload
wrapper := &Wrapper{
Payload: payload,
}

err := json.Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper)
should.Nil(err)
should.NotNil(wrapper.Payload)
should.Nil(payload)

err = json.Unmarshal([]byte(`{"payload": null}`), &wrapper)
should.Nil(err)
should.Nil(wrapper.Payload)
should.Nil(payload)

payload = nil
wrapper = &Wrapper{
Payload: payload,
}

err = Unmarshal([]byte(`{"payload": {"val": 42}}`), &wrapper)
should.Nil(err)
should.NotNil(wrapper.Payload)
should.Nil(payload)

err = Unmarshal([]byte(`{"payload": null}`), &wrapper)
should.Nil(err)
should.Nil(wrapper.Payload)
should.Nil(payload)
}
33 changes: 32 additions & 1 deletion jsoniter_null_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ package jsoniter
import (
"bytes"
"encoding/json"
"github.com/stretchr/testify/require"
"io"
"testing"

"github.com/stretchr/testify/require"
)

func Test_read_null(t *testing.T) {
Expand Down Expand Up @@ -135,3 +136,33 @@ func Test_encode_nil_array(t *testing.T) {
should.Nil(err)
should.Equal("null", string(output))
}

func Test_decode_nil_num(t *testing.T) {
type TestData struct {
Field int `json:"field"`
}
should := require.New(t)

data1 := []byte(`{"field": 42}`)
data2 := []byte(`{"field": null}`)

// Checking stdlib behavior as well
obj2 := TestData{}
err := json.Unmarshal(data1, &obj2)
should.Equal(nil, err)
should.Equal(42, obj2.Field)

err = json.Unmarshal(data2, &obj2)
should.Equal(nil, err)
should.Equal(42, obj2.Field)

obj := TestData{}

err = Unmarshal(data1, &obj)
should.Equal(nil, err)
should.Equal(42, obj.Field)

err = Unmarshal(data2, &obj)
should.Equal(nil, err)
should.Equal(42, obj.Field)
}