1package pgtype
2
3import (
4	"database/sql/driver"
5
6	"github.com/jackc/pgx/pgio"
7	"github.com/pkg/errors"
8)
9
10type Numrange struct {
11	Lower     Numeric
12	Upper     Numeric
13	LowerType BoundType
14	UpperType BoundType
15	Status    Status
16}
17
18func (dst *Numrange) Set(src interface{}) error {
19	return errors.Errorf("cannot convert %v to Numrange", src)
20}
21
22func (dst *Numrange) Get() interface{} {
23	switch dst.Status {
24	case Present:
25		return dst
26	case Null:
27		return nil
28	default:
29		return dst.Status
30	}
31}
32
33func (src *Numrange) AssignTo(dst interface{}) error {
34	return errors.Errorf("cannot assign %v to %T", src, dst)
35}
36
37func (dst *Numrange) DecodeText(ci *ConnInfo, src []byte) error {
38	if src == nil {
39		*dst = Numrange{Status: Null}
40		return nil
41	}
42
43	utr, err := ParseUntypedTextRange(string(src))
44	if err != nil {
45		return err
46	}
47
48	*dst = Numrange{Status: Present}
49
50	dst.LowerType = utr.LowerType
51	dst.UpperType = utr.UpperType
52
53	if dst.LowerType == Empty {
54		return nil
55	}
56
57	if dst.LowerType == Inclusive || dst.LowerType == Exclusive {
58		if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil {
59			return err
60		}
61	}
62
63	if dst.UpperType == Inclusive || dst.UpperType == Exclusive {
64		if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil {
65			return err
66		}
67	}
68
69	return nil
70}
71
72func (dst *Numrange) DecodeBinary(ci *ConnInfo, src []byte) error {
73	if src == nil {
74		*dst = Numrange{Status: Null}
75		return nil
76	}
77
78	ubr, err := ParseUntypedBinaryRange(src)
79	if err != nil {
80		return err
81	}
82
83	*dst = Numrange{Status: Present}
84
85	dst.LowerType = ubr.LowerType
86	dst.UpperType = ubr.UpperType
87
88	if dst.LowerType == Empty {
89		return nil
90	}
91
92	if dst.LowerType == Inclusive || dst.LowerType == Exclusive {
93		if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil {
94			return err
95		}
96	}
97
98	if dst.UpperType == Inclusive || dst.UpperType == Exclusive {
99		if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil {
100			return err
101		}
102	}
103
104	return nil
105}
106
107func (src Numrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
108	switch src.Status {
109	case Null:
110		return nil, nil
111	case Undefined:
112		return nil, errUndefined
113	}
114
115	switch src.LowerType {
116	case Exclusive, Unbounded:
117		buf = append(buf, '(')
118	case Inclusive:
119		buf = append(buf, '[')
120	case Empty:
121		return append(buf, "empty"...), nil
122	default:
123		return nil, errors.Errorf("unknown lower bound type %v", src.LowerType)
124	}
125
126	var err error
127
128	if src.LowerType != Unbounded {
129		buf, err = src.Lower.EncodeText(ci, buf)
130		if err != nil {
131			return nil, err
132		} else if buf == nil {
133			return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded")
134		}
135	}
136
137	buf = append(buf, ',')
138
139	if src.UpperType != Unbounded {
140		buf, err = src.Upper.EncodeText(ci, buf)
141		if err != nil {
142			return nil, err
143		} else if buf == nil {
144			return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded")
145		}
146	}
147
148	switch src.UpperType {
149	case Exclusive, Unbounded:
150		buf = append(buf, ')')
151	case Inclusive:
152		buf = append(buf, ']')
153	default:
154		return nil, errors.Errorf("unknown upper bound type %v", src.UpperType)
155	}
156
157	return buf, nil
158}
159
160func (src Numrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
161	switch src.Status {
162	case Null:
163		return nil, nil
164	case Undefined:
165		return nil, errUndefined
166	}
167
168	var rangeType byte
169	switch src.LowerType {
170	case Inclusive:
171		rangeType |= lowerInclusiveMask
172	case Unbounded:
173		rangeType |= lowerUnboundedMask
174	case Exclusive:
175	case Empty:
176		return append(buf, emptyMask), nil
177	default:
178		return nil, errors.Errorf("unknown LowerType: %v", src.LowerType)
179	}
180
181	switch src.UpperType {
182	case Inclusive:
183		rangeType |= upperInclusiveMask
184	case Unbounded:
185		rangeType |= upperUnboundedMask
186	case Exclusive:
187	default:
188		return nil, errors.Errorf("unknown UpperType: %v", src.UpperType)
189	}
190
191	buf = append(buf, rangeType)
192
193	var err error
194
195	if src.LowerType != Unbounded {
196		sp := len(buf)
197		buf = pgio.AppendInt32(buf, -1)
198
199		buf, err = src.Lower.EncodeBinary(ci, buf)
200		if err != nil {
201			return nil, err
202		}
203		if buf == nil {
204			return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded")
205		}
206
207		pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
208	}
209
210	if src.UpperType != Unbounded {
211		sp := len(buf)
212		buf = pgio.AppendInt32(buf, -1)
213
214		buf, err = src.Upper.EncodeBinary(ci, buf)
215		if err != nil {
216			return nil, err
217		}
218		if buf == nil {
219			return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded")
220		}
221
222		pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
223	}
224
225	return buf, nil
226}
227
228// Scan implements the database/sql Scanner interface.
229func (dst *Numrange) Scan(src interface{}) error {
230	if src == nil {
231		*dst = Numrange{Status: Null}
232		return nil
233	}
234
235	switch src := src.(type) {
236	case string:
237		return dst.DecodeText(nil, []byte(src))
238	case []byte:
239		srcCopy := make([]byte, len(src))
240		copy(srcCopy, src)
241		return dst.DecodeText(nil, srcCopy)
242	}
243
244	return errors.Errorf("cannot scan %T", src)
245}
246
247// Value implements the database/sql/driver Valuer interface.
248func (src Numrange) Value() (driver.Value, error) {
249	return EncodeValueText(src)
250}
251