1package pgtype
2
3import (
4	"database/sql/driver"
5	"encoding/binary"
6	"math"
7	"strconv"
8
9	"github.com/jackc/pgio"
10	errors "golang.org/x/xerrors"
11)
12
// Int2 holds a PostgreSQL int2 (16-bit integer) value together with its
// presence status.
type Int2 struct {
	// Int is the integer value; Get and Value only report it when Status is Present.
	Int    int16
	// Status records whether the value is Present, Null, or Undefined.
	Status Status
}
17
18func (dst *Int2) Set(src interface{}) error {
19	if src == nil {
20		*dst = Int2{Status: Null}
21		return nil
22	}
23
24	if value, ok := src.(interface{ Get() interface{} }); ok {
25		value2 := value.Get()
26		if value2 != value {
27			return dst.Set(value2)
28		}
29	}
30
31	switch value := src.(type) {
32	case int8:
33		*dst = Int2{Int: int16(value), Status: Present}
34	case uint8:
35		*dst = Int2{Int: int16(value), Status: Present}
36	case int16:
37		*dst = Int2{Int: int16(value), Status: Present}
38	case uint16:
39		if value > math.MaxInt16 {
40			return errors.Errorf("%d is greater than maximum value for Int2", value)
41		}
42		*dst = Int2{Int: int16(value), Status: Present}
43	case int32:
44		if value < math.MinInt16 {
45			return errors.Errorf("%d is greater than maximum value for Int2", value)
46		}
47		if value > math.MaxInt16 {
48			return errors.Errorf("%d is greater than maximum value for Int2", value)
49		}
50		*dst = Int2{Int: int16(value), Status: Present}
51	case uint32:
52		if value > math.MaxInt16 {
53			return errors.Errorf("%d is greater than maximum value for Int2", value)
54		}
55		*dst = Int2{Int: int16(value), Status: Present}
56	case int64:
57		if value < math.MinInt16 {
58			return errors.Errorf("%d is greater than maximum value for Int2", value)
59		}
60		if value > math.MaxInt16 {
61			return errors.Errorf("%d is greater than maximum value for Int2", value)
62		}
63		*dst = Int2{Int: int16(value), Status: Present}
64	case uint64:
65		if value > math.MaxInt16 {
66			return errors.Errorf("%d is greater than maximum value for Int2", value)
67		}
68		*dst = Int2{Int: int16(value), Status: Present}
69	case int:
70		if value < math.MinInt16 {
71			return errors.Errorf("%d is greater than maximum value for Int2", value)
72		}
73		if value > math.MaxInt16 {
74			return errors.Errorf("%d is greater than maximum value for Int2", value)
75		}
76		*dst = Int2{Int: int16(value), Status: Present}
77	case uint:
78		if value > math.MaxInt16 {
79			return errors.Errorf("%d is greater than maximum value for Int2", value)
80		}
81		*dst = Int2{Int: int16(value), Status: Present}
82	case string:
83		num, err := strconv.ParseInt(value, 10, 16)
84		if err != nil {
85			return err
86		}
87		*dst = Int2{Int: int16(num), Status: Present}
88	case *int8:
89		if value == nil {
90			*dst = Int2{Status: Null}
91		} else {
92			return dst.Set(*value)
93		}
94	case *uint8:
95		if value == nil {
96			*dst = Int2{Status: Null}
97		} else {
98			return dst.Set(*value)
99		}
100	case *int16:
101		if value == nil {
102			*dst = Int2{Status: Null}
103		} else {
104			return dst.Set(*value)
105		}
106	case *uint16:
107		if value == nil {
108			*dst = Int2{Status: Null}
109		} else {
110			return dst.Set(*value)
111		}
112	case *int32:
113		if value == nil {
114			*dst = Int2{Status: Null}
115		} else {
116			return dst.Set(*value)
117		}
118	case *uint32:
119		if value == nil {
120			*dst = Int2{Status: Null}
121		} else {
122			return dst.Set(*value)
123		}
124	case *int64:
125		if value == nil {
126			*dst = Int2{Status: Null}
127		} else {
128			return dst.Set(*value)
129		}
130	case *uint64:
131		if value == nil {
132			*dst = Int2{Status: Null}
133		} else {
134			return dst.Set(*value)
135		}
136	case *int:
137		if value == nil {
138			*dst = Int2{Status: Null}
139		} else {
140			return dst.Set(*value)
141		}
142	case *uint:
143		if value == nil {
144			*dst = Int2{Status: Null}
145		} else {
146			return dst.Set(*value)
147		}
148	case *string:
149		if value == nil {
150			*dst = Int2{Status: Null}
151		} else {
152			return dst.Set(*value)
153		}
154	default:
155		if originalSrc, ok := underlyingNumberType(src); ok {
156			return dst.Set(originalSrc)
157		}
158		return errors.Errorf("cannot convert %v to Int2", value)
159	}
160
161	return nil
162}
163
164func (dst Int2) Get() interface{} {
165	switch dst.Status {
166	case Present:
167		return dst.Int
168	case Null:
169		return nil
170	default:
171		return dst.Status
172	}
173}
174
175func (src *Int2) AssignTo(dst interface{}) error {
176	return int64AssignTo(int64(src.Int), src.Status, dst)
177}
178
179func (dst *Int2) DecodeText(ci *ConnInfo, src []byte) error {
180	if src == nil {
181		*dst = Int2{Status: Null}
182		return nil
183	}
184
185	n, err := strconv.ParseInt(string(src), 10, 16)
186	if err != nil {
187		return err
188	}
189
190	*dst = Int2{Int: int16(n), Status: Present}
191	return nil
192}
193
194func (dst *Int2) DecodeBinary(ci *ConnInfo, src []byte) error {
195	if src == nil {
196		*dst = Int2{Status: Null}
197		return nil
198	}
199
200	if len(src) != 2 {
201		return errors.Errorf("invalid length for int2: %v", len(src))
202	}
203
204	n := int16(binary.BigEndian.Uint16(src))
205	*dst = Int2{Int: n, Status: Present}
206	return nil
207}
208
209func (src Int2) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
210	switch src.Status {
211	case Null:
212		return nil, nil
213	case Undefined:
214		return nil, errUndefined
215	}
216
217	return append(buf, strconv.FormatInt(int64(src.Int), 10)...), nil
218}
219
220func (src Int2) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
221	switch src.Status {
222	case Null:
223		return nil, nil
224	case Undefined:
225		return nil, errUndefined
226	}
227
228	return pgio.AppendInt16(buf, src.Int), nil
229}
230
231// Scan implements the database/sql Scanner interface.
232func (dst *Int2) Scan(src interface{}) error {
233	if src == nil {
234		*dst = Int2{Status: Null}
235		return nil
236	}
237
238	switch src := src.(type) {
239	case int64:
240		if src < math.MinInt16 {
241			return errors.Errorf("%d is greater than maximum value for Int2", src)
242		}
243		if src > math.MaxInt16 {
244			return errors.Errorf("%d is greater than maximum value for Int2", src)
245		}
246		*dst = Int2{Int: int16(src), Status: Present}
247		return nil
248	case string:
249		return dst.DecodeText(nil, []byte(src))
250	case []byte:
251		srcCopy := make([]byte, len(src))
252		copy(srcCopy, src)
253		return dst.DecodeText(nil, srcCopy)
254	}
255
256	return errors.Errorf("cannot scan %T", src)
257}
258
259// Value implements the database/sql/driver Valuer interface.
260func (src Int2) Value() (driver.Value, error) {
261	switch src.Status {
262	case Present:
263		return int64(src.Int), nil
264	case Null:
265		return nil, nil
266	default:
267		return nil, errUndefined
268	}
269}
270
271func (src Int2) MarshalJSON() ([]byte, error) {
272	switch src.Status {
273	case Present:
274		return []byte(strconv.FormatInt(int64(src.Int), 10)), nil
275	case Null:
276		return []byte("null"), nil
277	case Undefined:
278		return nil, errUndefined
279	}
280
281	return nil, errBadStatus
282}
283