1package jsoniter
2
3import (
4	"encoding"
5	"encoding/json"
6	"unsafe"
7
8	"github.com/modern-go/reflect2"
9)
10
11var marshalerType = reflect2.TypeOfPtr((*json.Marshaler)(nil)).Elem()
12var unmarshalerType = reflect2.TypeOfPtr((*json.Unmarshaler)(nil)).Elem()
13var textMarshalerType = reflect2.TypeOfPtr((*encoding.TextMarshaler)(nil)).Elem()
14var textUnmarshalerType = reflect2.TypeOfPtr((*encoding.TextUnmarshaler)(nil)).Elem()
15
16func createDecoderOfMarshaler(ctx *ctx, typ reflect2.Type) ValDecoder {
17	ptrType := reflect2.PtrTo(typ)
18	if ptrType.Implements(unmarshalerType) {
19		return &referenceDecoder{
20			&unmarshalerDecoder{ptrType},
21		}
22	}
23	if ptrType.Implements(textUnmarshalerType) {
24		return &referenceDecoder{
25			&textUnmarshalerDecoder{ptrType},
26		}
27	}
28	return nil
29}
30
31func createEncoderOfMarshaler(ctx *ctx, typ reflect2.Type) ValEncoder {
32	if typ == marshalerType {
33		checkIsEmpty := createCheckIsEmpty(ctx, typ)
34		var encoder ValEncoder = &directMarshalerEncoder{
35			checkIsEmpty: checkIsEmpty,
36		}
37		return encoder
38	}
39	if typ.Implements(marshalerType) {
40		checkIsEmpty := createCheckIsEmpty(ctx, typ)
41		var encoder ValEncoder = &marshalerEncoder{
42			valType:      typ,
43			checkIsEmpty: checkIsEmpty,
44		}
45		return encoder
46	}
47	ptrType := reflect2.PtrTo(typ)
48	if ctx.prefix != "" && ptrType.Implements(marshalerType) {
49		checkIsEmpty := createCheckIsEmpty(ctx, ptrType)
50		var encoder ValEncoder = &marshalerEncoder{
51			valType:      ptrType,
52			checkIsEmpty: checkIsEmpty,
53		}
54		return &referenceEncoder{encoder}
55	}
56	if typ == textMarshalerType {
57		checkIsEmpty := createCheckIsEmpty(ctx, typ)
58		var encoder ValEncoder = &directTextMarshalerEncoder{
59			checkIsEmpty:  checkIsEmpty,
60			stringEncoder: ctx.EncoderOf(reflect2.TypeOf("")),
61		}
62		return encoder
63	}
64	if typ.Implements(textMarshalerType) {
65		checkIsEmpty := createCheckIsEmpty(ctx, typ)
66		var encoder ValEncoder = &textMarshalerEncoder{
67			valType:       typ,
68			stringEncoder: ctx.EncoderOf(reflect2.TypeOf("")),
69			checkIsEmpty:  checkIsEmpty,
70		}
71		return encoder
72	}
73	// if prefix is empty, the type is the root type
74	if ctx.prefix != "" && ptrType.Implements(textMarshalerType) {
75		checkIsEmpty := createCheckIsEmpty(ctx, ptrType)
76		var encoder ValEncoder = &textMarshalerEncoder{
77			valType:       ptrType,
78			stringEncoder: ctx.EncoderOf(reflect2.TypeOf("")),
79			checkIsEmpty:  checkIsEmpty,
80		}
81		return &referenceEncoder{encoder}
82	}
83	return nil
84}
85
86type marshalerEncoder struct {
87	checkIsEmpty checkIsEmpty
88	valType      reflect2.Type
89}
90
91func (encoder *marshalerEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
92	obj := encoder.valType.UnsafeIndirect(ptr)
93	if encoder.valType.IsNullable() && reflect2.IsNil(obj) {
94		stream.WriteNil()
95		return
96	}
97	marshaler := obj.(json.Marshaler)
98	bytes, err := marshaler.MarshalJSON()
99	if err != nil {
100		stream.Error = err
101	} else {
102		// html escape was already done by jsoniter
103		// but the extra '\n' should be trimed
104		l := len(bytes)
105		if l > 0 && bytes[l-1] == '\n' {
106			bytes = bytes[:l-1]
107		}
108		stream.Write(bytes)
109	}
110}
111
112func (encoder *marshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool {
113	return encoder.checkIsEmpty.IsEmpty(ptr)
114}
115
116type directMarshalerEncoder struct {
117	checkIsEmpty checkIsEmpty
118}
119
120func (encoder *directMarshalerEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
121	marshaler := *(*json.Marshaler)(ptr)
122	if marshaler == nil {
123		stream.WriteNil()
124		return
125	}
126	bytes, err := marshaler.MarshalJSON()
127	if err != nil {
128		stream.Error = err
129	} else {
130		stream.Write(bytes)
131	}
132}
133
134func (encoder *directMarshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool {
135	return encoder.checkIsEmpty.IsEmpty(ptr)
136}
137
138type textMarshalerEncoder struct {
139	valType       reflect2.Type
140	stringEncoder ValEncoder
141	checkIsEmpty  checkIsEmpty
142}
143
144func (encoder *textMarshalerEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
145	obj := encoder.valType.UnsafeIndirect(ptr)
146	if encoder.valType.IsNullable() && reflect2.IsNil(obj) {
147		stream.WriteNil()
148		return
149	}
150	marshaler := (obj).(encoding.TextMarshaler)
151	bytes, err := marshaler.MarshalText()
152	if err != nil {
153		stream.Error = err
154	} else {
155		str := string(bytes)
156		encoder.stringEncoder.Encode(unsafe.Pointer(&str), stream)
157	}
158}
159
160func (encoder *textMarshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool {
161	return encoder.checkIsEmpty.IsEmpty(ptr)
162}
163
164type directTextMarshalerEncoder struct {
165	stringEncoder ValEncoder
166	checkIsEmpty  checkIsEmpty
167}
168
169func (encoder *directTextMarshalerEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
170	marshaler := *(*encoding.TextMarshaler)(ptr)
171	if marshaler == nil {
172		stream.WriteNil()
173		return
174	}
175	bytes, err := marshaler.MarshalText()
176	if err != nil {
177		stream.Error = err
178	} else {
179		str := string(bytes)
180		encoder.stringEncoder.Encode(unsafe.Pointer(&str), stream)
181	}
182}
183
184func (encoder *directTextMarshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool {
185	return encoder.checkIsEmpty.IsEmpty(ptr)
186}
187
188type unmarshalerDecoder struct {
189	valType reflect2.Type
190}
191
192func (decoder *unmarshalerDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) {
193	valType := decoder.valType
194	obj := valType.UnsafeIndirect(ptr)
195	unmarshaler := obj.(json.Unmarshaler)
196	iter.nextToken()
197	iter.unreadByte() // skip spaces
198	bytes := iter.SkipAndReturnBytes()
199	err := unmarshaler.UnmarshalJSON(bytes)
200	if err != nil {
201		iter.ReportError("unmarshalerDecoder", err.Error())
202	}
203}
204
205type textUnmarshalerDecoder struct {
206	valType reflect2.Type
207}
208
209func (decoder *textUnmarshalerDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) {
210	valType := decoder.valType
211	obj := valType.UnsafeIndirect(ptr)
212	if reflect2.IsNil(obj) {
213		ptrType := valType.(*reflect2.UnsafePtrType)
214		elemType := ptrType.Elem()
215		elem := elemType.UnsafeNew()
216		ptrType.UnsafeSet(ptr, unsafe.Pointer(&elem))
217		obj = valType.UnsafeIndirect(ptr)
218	}
219	unmarshaler := (obj).(encoding.TextUnmarshaler)
220	str := iter.ReadString()
221	err := unmarshaler.UnmarshalText([]byte(str))
222	if err != nil {
223		iter.ReportError("textUnmarshalerDecoder", err.Error())
224	}
225}
226