1package jsoninfo
2
3import (
4	"encoding/json"
5	"fmt"
6	"reflect"
7)
8
9// UnmarshalStrictStruct function:
10//   * Unmarshals struct fields, ignoring UnmarshalJSON(...) and fields without 'json' tag.
11//   * Correctly handles StrictStruct
12func UnmarshalStrictStruct(data []byte, value StrictStruct) error {
13	decoder, err := NewObjectDecoder(data)
14	if err != nil {
15		return err
16	}
17	return value.DecodeWith(decoder, value)
18}
19
// ObjectDecoder holds a JSON object that has been split into its top-level
// properties so that they can be consumed piece by piece.
type ObjectDecoder struct {
	// Data is the raw JSON the decoder was built from.
	Data []byte

	// remainingFields maps each property name not yet consumed to its raw
	// JSON value; successful field decodes delete entries from it.
	remainingFields map[string]json.RawMessage
}
24
25func NewObjectDecoder(data []byte) (*ObjectDecoder, error) {
26	var remainingFields map[string]json.RawMessage
27	if err := json.Unmarshal(data, &remainingFields); err != nil {
28		return nil, fmt.Errorf("failed to unmarshal extension properties: %v (%s)", err, data)
29	}
30	return &ObjectDecoder{
31		Data:            data,
32		remainingFields: remainingFields,
33	}, nil
34}
35
36// DecodeExtensionMap returns all properties that were not decoded previously.
37func (decoder *ObjectDecoder) DecodeExtensionMap() map[string]json.RawMessage {
38	return decoder.remainingFields
39}
40
41func (decoder *ObjectDecoder) DecodeStructFieldsAndExtensions(value interface{}) error {
42	reflection := reflect.ValueOf(value)
43	if reflection.Kind() != reflect.Ptr {
44		panic(fmt.Errorf("value %T is not a pointer", value))
45	}
46	if reflection.IsNil() {
47		panic(fmt.Errorf("value %T is nil", value))
48	}
49	reflection = reflection.Elem()
50	for (reflection.Kind() == reflect.Interface || reflection.Kind() == reflect.Ptr) && !reflection.IsNil() {
51		reflection = reflection.Elem()
52	}
53	reflectionType := reflection.Type()
54	if reflectionType.Kind() != reflect.Struct {
55		panic(fmt.Errorf("value %T is not a struct", value))
56	}
57	typeInfo := GetTypeInfo(reflectionType)
58
59	// Supported fields
60	fields := typeInfo.Fields
61	remainingFields := decoder.remainingFields
62	for fieldIndex, field := range fields {
63		// Fields without JSON tag are ignored
64		if !field.HasJSONTag {
65			continue
66		}
67
68		// Get data
69		fieldData, exists := remainingFields[field.JSONName]
70		if !exists {
71			continue
72		}
73
74		// Unmarshal
75		if field.TypeIsUnmarshaller {
76			fieldType := field.Type
77			isPtr := false
78			if fieldType.Kind() == reflect.Ptr {
79				fieldType = fieldType.Elem()
80				isPtr = true
81			}
82			fieldValue := reflect.New(fieldType)
83			if err := fieldValue.Interface().(json.Unmarshaler).UnmarshalJSON(fieldData); err != nil {
84				if field.MultipleFields {
85					i := fieldIndex + 1
86					if i < len(fields) && fields[i].JSONName == field.JSONName {
87						continue
88					}
89				}
90				return fmt.Errorf("failed to unmarshal property %q (%s): %v",
91					field.JSONName, fieldValue.Type().String(), err)
92			}
93			if !isPtr {
94				fieldValue = fieldValue.Elem()
95			}
96			reflection.FieldByIndex(field.Index).Set(fieldValue)
97
98			// Remove the field from remaining fields
99			delete(remainingFields, field.JSONName)
100		} else {
101			fieldPtr := reflection.FieldByIndex(field.Index)
102			if fieldPtr.Kind() != reflect.Ptr || fieldPtr.IsNil() {
103				fieldPtr = fieldPtr.Addr()
104			}
105			if err := json.Unmarshal(fieldData, fieldPtr.Interface()); err != nil {
106				if field.MultipleFields {
107					i := fieldIndex + 1
108					if i < len(fields) && fields[i].JSONName == field.JSONName {
109						continue
110					}
111				}
112				return fmt.Errorf("failed to unmarshal property %q (%s): %v",
113					field.JSONName, fieldPtr.Type().String(), err)
114			}
115
116			// Remove the field from remaining fields
117			delete(remainingFields, field.JSONName)
118		}
119	}
120	return nil
121}
122