1package pgtype
2
import (
	"database/sql/driver"
	"encoding/json"
	"errors"
	"fmt"
	"reflect"
)
9
10type JSON struct {
11	Bytes  []byte
12	Status Status
13}
14
// Set converts src into a JSON value and stores it in dst.
//
// A nil src, a nil *string, or a nil []byte sets dst to Null. A string,
// non-nil *string, or non-nil []byte is stored verbatim as the JSON text.
// A []byte is retained, not copied.
//
// If src has a Get() interface{} method, it is first unwrapped through Get.
// Set then recurses on the result whenever Get produced something other
// than src itself.
//
// A JSON or JSONB value that reaches the type switch (i.e. one that could
// not be unwrapped further) is rejected with an error. Any other value is
// encoded with json.Marshal, and that error is returned unchanged.
func (dst *JSON) Set(src interface{}) error {
	if src == nil {
		*dst = JSON{Status: Null}
		return nil
	}

	if value, ok := src.(interface{ Get() interface{} }); ok {
		value2 := value.Get()
		// A plain value2 != value comparison panics at run time when both
		// interfaces hold the same non-comparable dynamic type. JSON itself is
		// such a type: it contains a []byte, and its Get returns the receiver
		// when Bytes is not valid JSON. So compare dynamic types first, and use
		// == only for comparable types.
		//
		// A same-typed, non-comparable result is treated as "cannot unwrap
		// further". This also prevents unbounded recursion on such values.
		srcType, valueType := reflect.TypeOf(src), reflect.TypeOf(value2)
		if srcType != valueType || (srcType.Comparable() && value2 != value) {
			return dst.Set(value2)
		}
	}

	switch value := src.(type) {
	case string:
		*dst = JSON{Bytes: []byte(value), Status: Present}
	case *string:
		if value == nil {
			*dst = JSON{Status: Null}
		} else {
			*dst = JSON{Bytes: []byte(*value), Status: Present}
		}
	case []byte:
		if value == nil {
			*dst = JSON{Status: Null}
		} else {
			*dst = JSON{Bytes: value, Status: Present}
		}
	// Passing JSON by value instead of by pointer is a footgun: the struct
	// itself could end up being encoded instead of Bytes. Detect it and
	// return an error. See https://github.com/jackc/pgx/issues/350.
	case JSON:
		return errors.New("use pointer to pgtype.JSON instead of value")
	// Same as above, but for JSONB (the two types share an implementation).
	case JSONB:
		return errors.New("use pointer to pgtype.JSONB instead of value")

	default:
		buf, err := json.Marshal(value)
		if err != nil {
			return err
		}
		*dst = JSON{Bytes: buf, Status: Present}
	}

	return nil
}
62
// Get returns the Go representation of dst.
//
// A Null value yields nil. Any status other than Null or Present is
// returned as the Status itself. For a Present value, Bytes is decoded with
// json.Unmarshal into an interface{}. If decoding fails, dst itself is
// returned unchanged.
func (dst JSON) Get() interface{} {
	if dst.Status == Null {
		return nil
	}
	if dst.Status != Present {
		return dst.Status
	}

	var decoded interface{}
	if err := json.Unmarshal(dst.Bytes, &decoded); err != nil {
		return dst
	}
	return decoded
}
78
// AssignTo copies src into the value pointed to by dst.
//
//   - *string: receives the raw JSON text. Assigning a non-Present value
//     returns an error.
//   - **string: receives a pointer to a copy of the text, or nil when src is
//     not Present.
//   - *[]byte: receives a fresh copy of Bytes, or nil when src is not Present.
//   - anything else: Bytes is decoded into dst with json.Unmarshal. Nil
//     Bytes, or a non-Present status, is decoded as the literal null.
func (src *JSON) AssignTo(dst interface{}) error {
	present := src.Status == Present

	switch target := dst.(type) {
	case *string:
		if !present {
			return fmt.Errorf("cannot assign non-present status to %T", dst)
		}
		*target = string(src.Bytes)
		return nil

	case **string:
		if !present {
			*target = nil
			return nil
		}
		text := string(src.Bytes)
		*target = &text
		return nil

	case *[]byte:
		if !present {
			*target = nil
			return nil
		}
		// Hand out a copy so the caller cannot mutate src.Bytes.
		dup := make([]byte, len(src.Bytes))
		copy(dup, src.Bytes)
		*target = dup
		return nil

	default:
		payload := []byte("null")
		if present && src.Bytes != nil {
			payload = src.Bytes
		}
		return json.Unmarshal(payload, dst)
	}
}
115
116func (JSON) PreferredResultFormat() int16 {
117	return TextFormatCode
118}
119
120func (dst *JSON) DecodeText(ci *ConnInfo, src []byte) error {
121	if src == nil {
122		*dst = JSON{Status: Null}
123		return nil
124	}
125
126	*dst = JSON{Bytes: src, Status: Present}
127	return nil
128}
129
130func (dst *JSON) DecodeBinary(ci *ConnInfo, src []byte) error {
131	return dst.DecodeText(ci, src)
132}
133
134func (JSON) PreferredParamFormat() int16 {
135	return TextFormatCode
136}
137
138func (src JSON) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
139	switch src.Status {
140	case Null:
141		return nil, nil
142	case Undefined:
143		return nil, errUndefined
144	}
145
146	return append(buf, src.Bytes...), nil
147}
148
149func (src JSON) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
150	return src.EncodeText(ci, buf)
151}
152
// Scan implements the database/sql Scanner interface.
//
// A nil src yields Null. Strings and byte slices are decoded as raw JSON
// text. Byte slices are copied first, because database/sql may reuse the
// driver's buffer after Scan returns. Any other source type is an error.
func (dst *JSON) Scan(src interface{}) error {
	switch raw := src.(type) {
	case nil:
		*dst = JSON{Status: Null}
		return nil
	case string:
		return dst.DecodeText(nil, []byte(raw))
	case []byte:
		owned := make([]byte, len(raw))
		copy(owned, raw)
		return dst.DecodeText(nil, owned)
	default:
		return fmt.Errorf("cannot scan %T", src)
	}
}
171
// Value implements the database/sql/driver Valuer interface.
//
// A Present value is returned as its raw bytes and Null as nil. Any other
// status returns errUndefined.
func (src JSON) Value() (driver.Value, error) {
	if src.Status == Present {
		return src.Bytes, nil
	}
	if src.Status == Null {
		return nil, nil
	}
	return nil, errUndefined
}
183
184func (src JSON) MarshalJSON() ([]byte, error) {
185	switch src.Status {
186	case Present:
187		return src.Bytes, nil
188	case Null:
189		return []byte("null"), nil
190	case Undefined:
191		return nil, errUndefined
192	}
193
194	return nil, errBadStatus
195}
196
197func (dst *JSON) UnmarshalJSON(b []byte) error {
198	if b == nil || string(b) == "null" {
199		*dst = JSON{Status: Null}
200	} else {
201		*dst = JSON{Bytes: b, Status: Present}
202	}
203	return nil
204
205}
206