1package jsonw
2
3import (
4	"bufio"
5	"bytes"
6	"encoding/json"
7	"fmt"
8	"io"
9	"reflect"
10	"strconv"
11	"strings"
12)
13
type (
	// Wrapper is a navigable handle on a decoded JSON value, or on the error
	// encountered while navigating to it.
	Wrapper struct {
		dat    interface{} // wrapped value: map[string]interface{}, []interface{}, json.Number, string, bool, nil, ...
		err    *Error      // non-nil when the lookup/conversion that produced this Wrapper failed
		access []string    // path segments from the root; concatenated by AccessPath
	}

	// Error is the error type produced by Wrapper lookups. Messages built via
	// Wrapper.NewError are prefixed with the access path of the failing node.
	Error struct {
		msg string
	}
)
23
// DepthError is returned by EnsureMaxDepth when its input nests too deeply
// (or is malformed in a way that made it look so).
type DepthError struct {
	msg string
}

// Error implements the error interface; the message is prefixed with
// "DepthError: ".
func (e DepthError) Error() string {
	return fmt.Sprintf("DepthError: %s", e.msg)
}

// defaultMaxDepth is the nesting limit used by Unmarshal and the
// EnsureMaxDepth*Default helpers.
const defaultMaxDepth int = 50
33
34func (w *Wrapper) Marshal() ([]byte, error) {
35	return json.Marshal(w.dat)
36}
37
38// MarshalJSON makes Wrapper satisfy the encoding/json.Marshaler interface.
39func (w *Wrapper) MarshalJSON() ([]byte, error) {
40	return w.Marshal()
41}
42
43func (w *Wrapper) MarshalPretty() string {
44	encoded, err := json.MarshalIndent(w.dat, "", "    ")
45	if err != nil {
46		return fmt.Sprintf("<bad JSON structure: %s>", err.Error())
47	} else {
48		return string(encoded)
49	}
50}
51
52func (w *Wrapper) MarshalToDebug() string {
53	buf, err := w.Marshal()
54	if err != nil {
55		return fmt.Sprintf("<bad JSON structure: %s>", err.Error())
56	} else {
57		return string(buf)
58	}
59}
60
61func Unmarshal(unsafeRaw []byte) (*Wrapper, error) {
62	return UnmarshalWithMaxDepth(unsafeRaw, defaultMaxDepth)
63}
64
65func UnmarshalWithMaxDepth(unsafeRaw []byte, maxDepth int) (*Wrapper, error) {
66	err := EnsureMaxDepth(bufio.NewReader(bytes.NewReader(unsafeRaw)), maxDepth)
67	if err != nil {
68		return nil, err
69	}
70	raw := unsafeRaw
71
72	var iface interface{}
73	dec := json.NewDecoder(bytes.NewReader(raw))
74	dec.UseNumber()
75	err = dec.Decode(&iface)
76	var ret *Wrapper
77	if err == nil {
78		ret = NewWrapper(iface)
79	}
80	return ret, err
81}
82
83func WrapperFromObject(obj interface{}) (*Wrapper, error) {
84	// Round tripping through []byte isn't very efficient. Is there a smarter way?
85	encoded, err := json.Marshal(obj)
86	if err != nil {
87		return nil, err
88	}
89	return Unmarshal(encoded)
90}
91
92func (e Error) Error() string { return e.msg }
93
94func (w *Wrapper) NewError(format string, a ...interface{}) *Error {
95	m1 := fmt.Sprintf(format, a...)
96	p := w.AccessPath()
97	m2 := fmt.Sprintf("%s: %s", p, m1)
98	return &Error{m2}
99}
100
101func (w *Wrapper) wrongType(want string, got reflect.Kind) *Error {
102	return w.NewError("type error: wanted %s, got %s", want, got)
103}
104
105func (i *Wrapper) getData() interface{} { return i.dat }
106func (i *Wrapper) IsOk() bool           { return i.Error() == nil }
107
108func (i *Wrapper) GetData() (dat interface{}, err error) {
109	if i.err != nil {
110		err = *i.err
111	} else {
112		dat = i.dat
113	}
114	return
115}
116
117func (i *Wrapper) GetDataVoid(dp *interface{}, ep *error) {
118	d, e := i.GetData()
119	if e == nil {
120		*dp = d
121	} else if e != nil && ep != nil && *ep == nil {
122		*ep = e
123	}
124
125}
126
127func (i *Wrapper) Error() (e error) {
128	if i.err != nil {
129		e = *i.err
130	}
131	return
132}
133
134func (i *Wrapper) GetDataOrNil() interface{} { return i.getData() }
135
136func NewWrapper(i interface{}) (rd *Wrapper) {
137	rd = new(Wrapper)
138	rd.dat = i
139	rd.access = make([]string, 1, 1)
140	rd.access[0] = "<root>"
141	return rd
142}
143
144// NewObjectWrapper takes a Go object that has JSON field struct annotations
145// and inserts it into a JSON wrapper. The serialization happens eagerly,
146// and the object is copied into the wrapper, so that subsequent updates to the
147// object will not be reflected in the Wrapper.
148func NewObjectWrapper(i interface{}) (*Wrapper, error) {
149	rd := NewDictionary()
150	err := (NewWrapper(i)).UnmarshalAgain(&rd.dat)
151	if err != nil {
152		rd = nil
153	}
154	return rd, err
155}
156
157func NewDictionary() *Wrapper {
158	m := make(map[string]interface{})
159	return NewWrapper(m)
160}
161
162func NewArray(l int) *Wrapper {
163	m := make([]interface{}, l)
164	return NewWrapper(m)
165}
166
167func NewNil() *Wrapper {
168	return NewWrapper(nil)
169}
170
171func NewInt(i int) *Wrapper {
172	return NewWrapper(i)
173}
174
175func NewInt64(i int64) *Wrapper {
176	return NewWrapper(i)
177}
178
179func NewFloat64(f float64) *Wrapper {
180	return NewWrapper(f)
181}
182
183func NewUint64(u uint64) *Wrapper {
184	return NewWrapper(u)
185}
186
187func NewString(s string) *Wrapper {
188	return NewWrapper(s)
189}
190
191func NewBool(b bool) *Wrapper {
192	return NewWrapper(b)
193}
194
// isInt reports whether v holds a signed integer kind.
func isInt(v reflect.Value) bool {
	switch v.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return true
	}
	return false
}

// isUint reports whether v holds an unsigned integer kind (uintptr excluded).
func isUint(v reflect.Value) bool {
	switch v.Kind() {
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return true
	}
	return false
}

// isFloat reports whether v holds a floating-point kind.
func isFloat(v reflect.Value) bool {
	switch v.Kind() {
	case reflect.Float32, reflect.Float64:
		return true
	}
	return false
}
213
214func (i *Wrapper) sameType(w *Wrapper) bool {
215	return reflect.ValueOf(i.dat).Kind() == reflect.ValueOf(w.dat).Kind()
216}
217
218func (i *Wrapper) AccessPath() string {
219	return strings.Join(i.access, "")
220}
221
222func (rd *Wrapper) GetFloat() (ret float64, err error) {
223	if rd.err != nil {
224		err = rd.err
225	} else if n, ok := rd.dat.(json.Number); ok {
226		ret, err = n.Float64()
227	} else if v := reflect.ValueOf(rd.dat); isFloat(v) {
228		ret = float64(v.Float())
229	} else if isInt(v) {
230		ret = float64(v.Int())
231	} else if isUint(v) {
232		ret = float64(v.Uint())
233	} else {
234		err = rd.wrongType("float-like", v.Kind())
235	}
236	return
237}
238
239func (w *Wrapper) GetFloatVoid(fp *float64, errp *error) {
240	f, e := w.GetFloat()
241	if e == nil {
242		*fp = f
243	} else if e != nil && errp != nil && *errp == nil {
244		*errp = e
245	}
246}
247
248func (rd *Wrapper) GetInt64() (ret int64, err error) {
249	if rd.err != nil {
250		err = rd.err
251	} else if n, ok := rd.dat.(json.Number); ok {
252		ret, err = n.Int64()
253	} else if v := reflect.ValueOf(rd.dat); isInt(v) {
254		ret = v.Int()
255	} else if isFloat(v) {
256		ret = int64(v.Float())
257	} else if !isUint(v) {
258		err = rd.wrongType("int", v.Kind())
259	} else if v.Uint() <= (1<<63 - 1) {
260		ret = int64(v.Uint())
261	} else {
262		err = rd.NewError("Signed int64 overflow error")
263	}
264	return
265}
266
267func (w *Wrapper) GetInt64Void(ip *int64, errp *error) {
268	i, e := w.GetInt64()
269	if e == nil {
270		*ip = i
271	} else if e != nil && errp != nil && *errp == nil {
272		*errp = e
273	}
274}
275
276func (rd *Wrapper) GetInt() (i int, err error) {
277	i64, e := rd.GetInt64()
278	return int(i64), e
279}
280
281func (w *Wrapper) GetIntVoid(ip *int, errp *error) {
282	i, e := w.GetInt()
283	if e == nil {
284		*ip = i
285	} else if e != nil && errp != nil && *errp == nil {
286		*errp = e
287	}
288}
289
290func (rd *Wrapper) GetUint() (u uint, err error) {
291	u64, e := rd.GetUint64()
292	return uint(u64), e
293}
294
295func (w *Wrapper) GetUintVoid(ip *uint, errp *error) {
296	i, e := w.GetUint()
297	if e == nil {
298		*ip = i
299	} else if e != nil && errp != nil && *errp == nil {
300		*errp = e
301	}
302}
303
304func (rd *Wrapper) GetUint64() (ret uint64, err error) {
305	underflow := false
306	if rd.err != nil {
307		err = rd.err
308	} else if n, ok := rd.dat.(json.Number); ok {
309		var tmp int64
310		if tmp, err = n.Int64(); err == nil && tmp < 0 {
311			underflow = true
312		} else if err == nil {
313			ret = uint64(tmp)
314		}
315	} else if v := reflect.ValueOf(rd.dat); isUint(v) {
316		ret = v.Uint()
317	} else if isFloat(v) {
318		if v.Float() < 0 {
319			underflow = true
320		} else {
321			ret = uint64(v.Float())
322		}
323	} else if !isInt(v) {
324		err = rd.wrongType("uint", v.Kind())
325	} else if v.Int() >= 0 {
326		ret = uint64(v.Int())
327	} else {
328		underflow = true
329	}
330
331	if underflow {
332		err = rd.NewError("Unsigned uint64 underflow error")
333
334	}
335	return
336}
337
338func (w *Wrapper) GetUint64Void(ip *uint64, errp *error) {
339	i, e := w.GetUint64()
340	if e == nil {
341		*ip = i
342	} else if e != nil && errp != nil && *errp == nil {
343		*errp = e
344	}
345}
346
347func (rd *Wrapper) GetInterface() (v interface{}, err error) {
348	if rd.err != nil {
349		err = rd.err
350	} else {
351		v = rd.dat
352	}
353	return v, err
354}
355
356func (rd *Wrapper) GetBool() (ret bool, err error) {
357	if rd.err != nil {
358		err = rd.err
359	} else {
360		v := reflect.ValueOf(rd.dat)
361		k := v.Kind()
362		if k == reflect.Bool {
363			ret = v.Bool()
364		} else {
365			err = rd.wrongType("bool", k)
366		}
367	}
368	return
369}
370
371func (w *Wrapper) GetBoolVoid(bp *bool, errp *error) {
372	b, e := w.GetBool()
373	if e == nil {
374		*bp = b
375	} else if e != nil && errp != nil && *errp == nil {
376		*errp = e
377	}
378}
379
380func (rd *Wrapper) GetString() (ret string, err error) {
381	if rd.err != nil {
382		err = rd.err
383	} else if v := reflect.ValueOf(rd.dat); v.Kind() == reflect.String {
384		ret = v.String()
385	} else if b, ok := rd.dat.([]uint8); ok {
386		ret = string(b)
387	} else if b, ok := rd.dat.([]byte); ok {
388		ret = string(b)
389	} else {
390		err = rd.wrongType("string", v.Kind())
391	}
392	return
393}
394
395func (rd *Wrapper) GetBytes() (ret []byte, err error) {
396	if rd.err != nil {
397		err = rd.err
398	} else if b, ok := rd.dat.([]byte); ok {
399		ret = b
400	} else {
401		err = rd.wrongType("[]byte", reflect.ValueOf(rd.dat).Kind())
402	}
403	return
404}
405
406func (w *Wrapper) GetBytesVoid(bp *[]byte, errp *error) {
407	b, e := w.GetBytes()
408	if e == nil {
409		*bp = b
410	} else if e != nil && errp != nil && *errp == nil {
411		*errp = e
412	}
413}
414
415func (w *Wrapper) GetStringVoid(sp *string, errp *error) {
416	s, e := w.GetString()
417	if e == nil {
418		*sp = s
419	} else if e != nil && errp != nil && *errp == nil {
420		*errp = e
421	}
422}
423
424func (rd *Wrapper) AtIndex(i int) *Wrapper {
425	ret, v := rd.asArray()
426	if v == nil {
427
428	} else if i < 0 {
429		ret.err = rd.NewError("index out of bounds %d < 0", i)
430	} else if len(v) <= i {
431		ret.err = rd.NewError("index out of bounds %d >= %d", i, len(v))
432	} else {
433		ret.dat = v[i]
434	}
435	ret.access = append(ret.access, fmt.Sprintf("[%d]", i))
436	return ret
437}
438
439func (rd *Wrapper) Len() (ret int, err error) {
440	tmp, v := rd.asArray()
441	if v == nil {
442		err = tmp.err
443	} else {
444		ret = len(v)
445	}
446	return
447}
448
449func (i *Wrapper) Keys() (v []string, err error) {
450	tmp, d := i.asDictionary()
451	if d == nil {
452		err = tmp.err
453	} else {
454		v = make([]string, len(d))
455		var i int = 0
456		for k, _ := range d {
457			v[i] = k
458			i++
459		}
460	}
461	return
462}
463
464func (i *Wrapper) asArray() (ret *Wrapper, v []interface{}) {
465	if i.err != nil {
466		ret = i
467	} else {
468		var ok bool
469		v, ok = (i.dat).([]interface{})
470		ret = new(Wrapper)
471		ret.access = i.access
472		if !ok {
473			ret.err = i.wrongType("array", reflect.ValueOf(i.dat).Kind())
474		}
475	}
476	return
477}
478
479func (rd *Wrapper) IsNil() bool {
480	return rd.dat == nil
481}
482
483func (rd *Wrapper) AtKey(s string) *Wrapper {
484	ret, d := rd.asDictionary()
485
486	ret.access = append(ret.access, fmt.Sprintf(".%s", s))
487	if d != nil {
488		val, found := d[s]
489		if found {
490			ret.dat = val
491		} else {
492			ret.dat = nil
493			ret.err = ret.NewError("no such key: %s", s)
494		}
495	}
496	return ret
497}
498
499func (rd *Wrapper) ToDictionary() (out *Wrapper, e error) {
500	tmp, _ := rd.asDictionary()
501	if tmp.err != nil {
502		e = tmp.err
503	} else {
504		out = rd
505	}
506	return
507}
508
509func (rd *Wrapper) ToArray() (out *Wrapper, e error) {
510	tmp, _ := rd.asArray()
511	if tmp.err != nil {
512		e = tmp.err
513	} else {
514		out = rd
515	}
516	return
517}
518
519func (w *Wrapper) SetKey(s string, val *Wrapper) error {
520	b, d := w.asDictionary()
521	if d != nil {
522		d[s] = val.getData()
523	}
524	return b.Error()
525}
526
527func (w *Wrapper) DeleteKey(s string) error {
528	b, d := w.asDictionary()
529	if d != nil {
530		delete(d, s)
531	}
532	return b.Error()
533}
534
535func (w *Wrapper) SetIndex(i int, val *Wrapper) error {
536	b, d := w.asArray()
537	if d != nil {
538		d[i] = val.getData()
539	}
540	return b.Error()
541
542}
543
544func (i *Wrapper) asDictionary() (ret *Wrapper, d map[string]interface{}) {
545	if i.err != nil {
546		ret = i
547	} else {
548		var ok bool
549		d, ok = (i.dat).(map[string]interface{})
550		ret = new(Wrapper)
551		ret.access = i.access
552		if !ok {
553			ret.err = i.wrongType("dict", reflect.ValueOf(i.dat).Kind())
554		}
555	}
556	return
557}
558
// tryInt reports whether a path segment should be treated as an array index.
// That is the case when it starts with an ASCII digit and parses fully with
// strconv.Atoi. On failure ret is whatever Atoi returned (0 for syntax errors).
func tryInt(bit string) (ret int, ok bool) {
	if bit == "" || bit[0] < '0' || bit[0] > '9' {
		// Empty or non-numeric: this is an object key.
		return 0, false
	}
	n, err := strconv.Atoi(bit)
	return n, err == nil
}
570
571func (w *Wrapper) AtPath(path string) (ret *Wrapper) {
572	bits := strings.Split(path, ".")
573	ret = w
574	for _, bit := range bits {
575		if val, isInt := tryInt(bit); isInt {
576			ret = ret.AtIndex(val)
577		} else if len(bit) > 0 {
578			ret = ret.AtKey(bit)
579		} else {
580			break
581		}
582
583		if ret.dat == nil || ret.err != nil {
584			break
585		}
586	}
587	return ret
588}
589
590func (w *Wrapper) AtPathGetInt(path string) (ret int, ok bool) {
591	tmp := w.AtPath(path)
592	if tmp != nil {
593		var err error
594		ret, err = tmp.GetInt()
595		ok = (err == nil)
596	} else {
597		ok = false
598	}
599	return
600}
601
602func (w *Wrapper) SetValueAtPath(path string, value *Wrapper) error {
603	bits := strings.Split(path, ".")
604	currW := w
605	var err error
606	for i, bit := range bits {
607		// at each key, create an empty dictionary if one doesn't exist yet
608		var nextVal, d *Wrapper
609		// if the next bit is an integer, and it's not the last key
610		// in the path, then the next value should be an array
611		if i == len(bits)-1 {
612			nextVal = value
613		} else if nextInt, nextIsInt := tryInt(bits[i+1]); nextIsInt {
614			// Default size of the array is just big enough to fit the next
615			// value.
616			nextVal = NewArray(nextInt + 1)
617		} else {
618			nextVal = NewDictionary()
619		}
620
621		// If we're looking at an index, treat like an array
622		if val, is_int := tryInt(bit); is_int {
623			d = currW.AtIndex(val)
624		} else {
625			d = currW.AtKey(bit)
626		}
627
628		// if we've hit nil or a wrong type of node, or the last bit,
629		// write in the correct value
630		if d.IsNil() || !d.sameType(nextVal) || i == len(bits)-1 {
631			d = nextVal
632			if val, is_int := tryInt(bit); is_int {
633				// TODO: resize array if it's not big enough?
634				err = currW.SetIndex(val, d)
635			} else {
636				err = currW.SetKey(bit, d)
637			}
638			if err != nil {
639				return err
640			}
641		}
642
643		currW = d
644	}
645
646	return err
647}
648
649func (w *Wrapper) DeleteValueAtPath(path string) error {
650	bits := strings.Split(path, ".")
651	currW := w
652	var err error
653	for _, bit := range bits[:len(bits)-1] {
654		//  if the any key on the path doesn't exist yet, we're done
655		// If we're looking at an index, treat like an array
656		var d *Wrapper
657		if val, is_int := tryInt(bit); is_int {
658			d = currW.AtIndex(val)
659		} else {
660			d = currW.AtKey(bit)
661		}
662
663		if d.IsNil() {
664			return nil
665		}
666
667		currW = d
668	}
669
670	lastBit := bits[len(bits)-1]
671	if val, is_int := tryInt(lastBit); is_int {
672		// can't do much for arrays besides just make it nil
673		err = currW.SetIndex(val, NewNil())
674	} else {
675		err = currW.DeleteKey(lastBit)
676	}
677	return err
678}
679
680func (w *Wrapper) UnmarshalAgain(i interface{}) (err error) {
681	var tmp []byte
682	if tmp, err = w.Marshal(); err != nil {
683		return
684	}
685	err = json.Unmarshal(tmp, i)
686	return
687}
688
689func Canonicalize(in []byte) ([]byte, error) {
690	if v, err := Unmarshal(in); err != nil {
691		return nil, err
692	} else if ret, err := v.Marshal(); err != nil {
693		return nil, err
694	} else {
695		return ret, nil
696	}
697}
698
699func (w *Wrapper) AssertEqAtPath(path string, obj *Wrapper, errp *error) {
700	v := w.AtPath(path)
701	if b1, err := v.Marshal(); err != nil {
702		*errp = err
703	} else if b2, err := w.Marshal(); err != nil {
704		*errp = err
705	} else if !bytes.Equal(b1, b2) {
706		err = fmt.Errorf("Equality assertion failed at %s: %s != %s",
707			path, string(b1), string(b2))
708	}
709	return
710}
711
// Bytes that are structurally significant to EnsureMaxDepth's scanner.
const (
	JSONEscape              = byte('\\')
	JSONDoubleQuotationMark = byte('"')
	JSONLeftSquareBracket   = byte('[')
	JSONLeftCurlyBracket    = byte('{')
	JSONRightSquareBracket  = byte(']')
	JSONRightCurlyBracket   = byte('}')
)
718
719// ensureMaxDepth returns an error if raw represents a valid JSON string whose
720// deserialization's maximum depth exceeds maxDepth.
721// If raw represents an invalid JSON string with a prefix that is a valid JSON prefix
722// whose depth exceeds maxDepth, an error will be returned as well).
723// See https://github.com/golang/go/blob/master/src/encoding/json/decode.go#L96.
724// Otherwise, behavior is undefined and an error may or may not be returned.
725// See the spec at https://json.org.
726func EnsureMaxDepth(unsafeRawReader *bufio.Reader, maxDepth int) error {
727	depth := 1
728	inString := false
729	lastByteWasEscape := false
730	for {
731		b, err := unsafeRawReader.ReadByte()
732		switch err {
733		case io.EOF:
734			return nil
735		case nil:
736		default:
737			return err
738		}
739		if depth >= maxDepth {
740			return DepthError{fmt.Sprintf("Invalid JSON or exceeds max depth %d.", maxDepth)}
741		}
742		if inString {
743			if lastByteWasEscape {
744				// i.e., if the last byte was an escape, we are no longer in an
745				// escape sequence. This is not strictly true: JSON unicode codepoint
746				// escape sequences are of the form \uXXXX where X is a hexadecimal
747				// character. However since X cannot be JSONEscape or JSONDoubleQuotationMark
748				// in valid JSON, there is no problem: later there will be an
749				// error parsing the JSON and this will occur before maxDepth
750				// is reached in the JSON parser.
751				lastByteWasEscape = false
752			} else if b == JSONEscape {
753				lastByteWasEscape = true
754			} else if b == JSONDoubleQuotationMark {
755				inString = false
756			}
757		} else {
758			switch b {
759			case JSONDoubleQuotationMark:
760				inString = true
761			case JSONLeftSquareBracket, JSONLeftCurlyBracket:
762				depth += 1
763			case JSONRightSquareBracket, JSONRightCurlyBracket:
764				depth -= 1
765			}
766		}
767	}
768}
769
770func EnsureMaxDepthDefault(unsafeRawReader *bufio.Reader) error {
771	return EnsureMaxDepth(unsafeRawReader, defaultMaxDepth)
772}
773
774func EnsureMaxDepthBytesDefault(unsafeRaw []byte) error {
775	return EnsureMaxDepthDefault(bufio.NewReader(bytes.NewReader(unsafeRaw)))
776}
777