1package jsoniter
2
3import (
4	"encoding"
5	"encoding/json"
6	"github.com/modern-go/reflect2"
7	"unsafe"
8)
9
10var marshalerType = reflect2.TypeOfPtr((*json.Marshaler)(nil)).Elem()
11var unmarshalerType = reflect2.TypeOfPtr((*json.Unmarshaler)(nil)).Elem()
12var textMarshalerType = reflect2.TypeOfPtr((*encoding.TextMarshaler)(nil)).Elem()
13var textUnmarshalerType = reflect2.TypeOfPtr((*encoding.TextUnmarshaler)(nil)).Elem()
14
15func createDecoderOfMarshaler(ctx *ctx, typ reflect2.Type) ValDecoder {
16	ptrType := reflect2.PtrTo(typ)
17	if ptrType.Implements(unmarshalerType) {
18		return &referenceDecoder{
19			&unmarshalerDecoder{ptrType},
20		}
21	}
22	if ptrType.Implements(textUnmarshalerType) {
23		return &referenceDecoder{
24			&textUnmarshalerDecoder{ptrType},
25		}
26	}
27	return nil
28}
29
30func createEncoderOfMarshaler(ctx *ctx, typ reflect2.Type) ValEncoder {
31	if typ == marshalerType {
32		checkIsEmpty := createCheckIsEmpty(ctx, typ)
33		var encoder ValEncoder = &directMarshalerEncoder{
34			checkIsEmpty: checkIsEmpty,
35		}
36		return encoder
37	}
38	if typ.Implements(marshalerType) {
39		checkIsEmpty := createCheckIsEmpty(ctx, typ)
40		var encoder ValEncoder = &marshalerEncoder{
41			valType:      typ,
42			checkIsEmpty: checkIsEmpty,
43		}
44		return encoder
45	}
46	ptrType := reflect2.PtrTo(typ)
47	if ctx.prefix != "" && ptrType.Implements(marshalerType) {
48		checkIsEmpty := createCheckIsEmpty(ctx, ptrType)
49		var encoder ValEncoder = &marshalerEncoder{
50			valType:      ptrType,
51			checkIsEmpty: checkIsEmpty,
52		}
53		return &referenceEncoder{encoder}
54	}
55	if typ == textMarshalerType {
56		checkIsEmpty := createCheckIsEmpty(ctx, typ)
57		var encoder ValEncoder = &directTextMarshalerEncoder{
58			checkIsEmpty:  checkIsEmpty,
59			stringEncoder: ctx.EncoderOf(reflect2.TypeOf("")),
60		}
61		return encoder
62	}
63	if typ.Implements(textMarshalerType) {
64		checkIsEmpty := createCheckIsEmpty(ctx, typ)
65		var encoder ValEncoder = &textMarshalerEncoder{
66			valType:       typ,
67			stringEncoder: ctx.EncoderOf(reflect2.TypeOf("")),
68			checkIsEmpty:  checkIsEmpty,
69		}
70		return encoder
71	}
72	// if prefix is empty, the type is the root type
73	if ctx.prefix != "" && ptrType.Implements(textMarshalerType) {
74		checkIsEmpty := createCheckIsEmpty(ctx, ptrType)
75		var encoder ValEncoder = &textMarshalerEncoder{
76			valType:       ptrType,
77			stringEncoder: ctx.EncoderOf(reflect2.TypeOf("")),
78			checkIsEmpty:  checkIsEmpty,
79		}
80		return &referenceEncoder{encoder}
81	}
82	return nil
83}
84
85type marshalerEncoder struct {
86	checkIsEmpty checkIsEmpty
87	valType      reflect2.Type
88}
89
90func (encoder *marshalerEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
91	obj := encoder.valType.UnsafeIndirect(ptr)
92	if encoder.valType.IsNullable() && reflect2.IsNil(obj) {
93		stream.WriteNil()
94		return
95	}
96	bytes, err := json.Marshal(obj)
97	if err != nil {
98		stream.Error = err
99	} else {
100		stream.Write(bytes)
101	}
102}
103
104func (encoder *marshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool {
105	return encoder.checkIsEmpty.IsEmpty(ptr)
106}
107
108type directMarshalerEncoder struct {
109	checkIsEmpty checkIsEmpty
110}
111
112func (encoder *directMarshalerEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
113	marshaler := *(*json.Marshaler)(ptr)
114	if marshaler == nil {
115		stream.WriteNil()
116		return
117	}
118	bytes, err := marshaler.MarshalJSON()
119	if err != nil {
120		stream.Error = err
121	} else {
122		stream.Write(bytes)
123	}
124}
125
126func (encoder *directMarshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool {
127	return encoder.checkIsEmpty.IsEmpty(ptr)
128}
129
130type textMarshalerEncoder struct {
131	valType       reflect2.Type
132	stringEncoder ValEncoder
133	checkIsEmpty  checkIsEmpty
134}
135
136func (encoder *textMarshalerEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
137	obj := encoder.valType.UnsafeIndirect(ptr)
138	if encoder.valType.IsNullable() && reflect2.IsNil(obj) {
139		stream.WriteNil()
140		return
141	}
142	marshaler := (obj).(encoding.TextMarshaler)
143	bytes, err := marshaler.MarshalText()
144	if err != nil {
145		stream.Error = err
146	} else {
147		str := string(bytes)
148		encoder.stringEncoder.Encode(unsafe.Pointer(&str), stream)
149	}
150}
151
152func (encoder *textMarshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool {
153	return encoder.checkIsEmpty.IsEmpty(ptr)
154}
155
156type directTextMarshalerEncoder struct {
157	stringEncoder ValEncoder
158	checkIsEmpty  checkIsEmpty
159}
160
161func (encoder *directTextMarshalerEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
162	marshaler := *(*encoding.TextMarshaler)(ptr)
163	if marshaler == nil {
164		stream.WriteNil()
165		return
166	}
167	bytes, err := marshaler.MarshalText()
168	if err != nil {
169		stream.Error = err
170	} else {
171		str := string(bytes)
172		encoder.stringEncoder.Encode(unsafe.Pointer(&str), stream)
173	}
174}
175
176func (encoder *directTextMarshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool {
177	return encoder.checkIsEmpty.IsEmpty(ptr)
178}
179
180type unmarshalerDecoder struct {
181	valType reflect2.Type
182}
183
184func (decoder *unmarshalerDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) {
185	valType := decoder.valType
186	obj := valType.UnsafeIndirect(ptr)
187	unmarshaler := obj.(json.Unmarshaler)
188	iter.nextToken()
189	iter.unreadByte() // skip spaces
190	bytes := iter.SkipAndReturnBytes()
191	err := unmarshaler.UnmarshalJSON(bytes)
192	if err != nil {
193		iter.ReportError("unmarshalerDecoder", err.Error())
194	}
195}
196
197type textUnmarshalerDecoder struct {
198	valType reflect2.Type
199}
200
201func (decoder *textUnmarshalerDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) {
202	valType := decoder.valType
203	obj := valType.UnsafeIndirect(ptr)
204	if reflect2.IsNil(obj) {
205		ptrType := valType.(*reflect2.UnsafePtrType)
206		elemType := ptrType.Elem()
207		elem := elemType.UnsafeNew()
208		ptrType.UnsafeSet(ptr, unsafe.Pointer(&elem))
209		obj = valType.UnsafeIndirect(ptr)
210	}
211	unmarshaler := (obj).(encoding.TextUnmarshaler)
212	str := iter.ReadString()
213	err := unmarshaler.UnmarshalText([]byte(str))
214	if err != nil {
215		iter.ReportError("textUnmarshalerDecoder", err.Error())
216	}
217}
218