1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package bsonrw
8
9import (
10	"encoding/base64"
11	"errors"
12	"fmt"
13	"math"
14	"strconv"
15	"time"
16
17	"go.mongodb.org/mongo-driver/bson/bsontype"
18	"go.mongodb.org/mongo-driver/bson/primitive"
19)
20
21func wrapperKeyBSONType(key string) bsontype.Type {
22	switch string(key) {
23	case "$numberInt":
24		return bsontype.Int32
25	case "$numberLong":
26		return bsontype.Int64
27	case "$oid":
28		return bsontype.ObjectID
29	case "$symbol":
30		return bsontype.Symbol
31	case "$numberDouble":
32		return bsontype.Double
33	case "$numberDecimal":
34		return bsontype.Decimal128
35	case "$binary":
36		return bsontype.Binary
37	case "$code":
38		return bsontype.JavaScript
39	case "$scope":
40		return bsontype.CodeWithScope
41	case "$timestamp":
42		return bsontype.Timestamp
43	case "$regularExpression":
44		return bsontype.Regex
45	case "$dbPointer":
46		return bsontype.DBPointer
47	case "$date":
48		return bsontype.DateTime
49	case "$ref":
50		fallthrough
51	case "$id":
52		fallthrough
53	case "$db":
54		return bsontype.EmbeddedDocument // dbrefs aren't bson types
55	case "$minKey":
56		return bsontype.MinKey
57	case "$maxKey":
58		return bsontype.MaxKey
59	case "$undefined":
60		return bsontype.Undefined
61	}
62
63	return bsontype.EmbeddedDocument
64}
65
66func (ejv *extJSONValue) parseBinary() (b []byte, subType byte, err error) {
67	if ejv.t != bsontype.EmbeddedDocument {
68		return nil, 0, fmt.Errorf("$binary value should be object, but instead is %s", ejv.t)
69	}
70
71	binObj := ejv.v.(*extJSONObject)
72	bFound := false
73	stFound := false
74
75	for i, key := range binObj.keys {
76		val := binObj.values[i]
77
78		switch key {
79		case "base64":
80			if bFound {
81				return nil, 0, errors.New("duplicate base64 key in $binary")
82			}
83
84			if val.t != bsontype.String {
85				return nil, 0, fmt.Errorf("$binary base64 value should be string, but instead is %s", val.t)
86			}
87
88			base64Bytes, err := base64.StdEncoding.DecodeString(val.v.(string))
89			if err != nil {
90				return nil, 0, fmt.Errorf("invalid $binary base64 string: %s", val.v.(string))
91			}
92
93			b = base64Bytes
94			bFound = true
95		case "subType":
96			if stFound {
97				return nil, 0, errors.New("duplicate subType key in $binary")
98			}
99
100			if val.t != bsontype.String {
101				return nil, 0, fmt.Errorf("$binary subType value should be string, but instead is %s", val.t)
102			}
103
104			i, err := strconv.ParseInt(val.v.(string), 16, 64)
105			if err != nil {
106				return nil, 0, fmt.Errorf("invalid $binary subType string: %s", val.v.(string))
107			}
108
109			subType = byte(i)
110			stFound = true
111		default:
112			return nil, 0, fmt.Errorf("invalid key in $binary object: %s", key)
113		}
114	}
115
116	if !bFound {
117		return nil, 0, errors.New("missing base64 field in $binary object")
118	}
119
120	if !stFound {
121		return nil, 0, errors.New("missing subType field in $binary object")
122
123	}
124
125	return b, subType, nil
126}
127
128func (ejv *extJSONValue) parseDBPointer() (ns string, oid primitive.ObjectID, err error) {
129	if ejv.t != bsontype.EmbeddedDocument {
130		return "", primitive.NilObjectID, fmt.Errorf("$dbPointer value should be object, but instead is %s", ejv.t)
131	}
132
133	dbpObj := ejv.v.(*extJSONObject)
134	oidFound := false
135	nsFound := false
136
137	for i, key := range dbpObj.keys {
138		val := dbpObj.values[i]
139
140		switch key {
141		case "$ref":
142			if nsFound {
143				return "", primitive.NilObjectID, errors.New("duplicate $ref key in $dbPointer")
144			}
145
146			if val.t != bsontype.String {
147				return "", primitive.NilObjectID, fmt.Errorf("$dbPointer $ref value should be string, but instead is %s", val.t)
148			}
149
150			ns = val.v.(string)
151			nsFound = true
152		case "$id":
153			if oidFound {
154				return "", primitive.NilObjectID, errors.New("duplicate $id key in $dbPointer")
155			}
156
157			if val.t != bsontype.String {
158				return "", primitive.NilObjectID, fmt.Errorf("$dbPointer $id value should be string, but instead is %s", val.t)
159			}
160
161			oid, err = primitive.ObjectIDFromHex(val.v.(string))
162			if err != nil {
163				return "", primitive.NilObjectID, err
164			}
165
166			oidFound = true
167		default:
168			return "", primitive.NilObjectID, fmt.Errorf("invalid key in $dbPointer object: %s", key)
169		}
170	}
171
172	if !nsFound {
173		return "", oid, errors.New("missing $ref field in $dbPointer object")
174	}
175
176	if !oidFound {
177		return "", oid, errors.New("missing $id field in $dbPointer object")
178	}
179
180	return ns, oid, nil
181}
182
const (
	// rfc3339Milli is an RFC 3339 style time layout with an optional
	// fractional-second part (".999") and a zone written either as "Z" or as
	// a colon-separated offset ("Z07:00").
	rfc3339Milli = "2006-01-02T15:04:05.999Z07:00"
)

var (
	// timeFormats lists the layouts that parseDatetimeString tries, in order,
	// when parsing a $date string. The second layout differs from
	// rfc3339Milli only in that the zone offset has no colon ("Z0700").
	timeFormats = []string{rfc3339Milli, "2006-01-02T15:04:05.999Z0700"}
)
190
191func (ejv *extJSONValue) parseDateTime() (int64, error) {
192	switch ejv.t {
193	case bsontype.Int32:
194		return int64(ejv.v.(int32)), nil
195	case bsontype.Int64:
196		return ejv.v.(int64), nil
197	case bsontype.String:
198		return parseDatetimeString(ejv.v.(string))
199	case bsontype.EmbeddedDocument:
200		return parseDatetimeObject(ejv.v.(*extJSONObject))
201	default:
202		return 0, fmt.Errorf("$date value should be string or object, but instead is %s", ejv.t)
203	}
204}
205
206func parseDatetimeString(data string) (int64, error) {
207	var t time.Time
208	var err error
209	// try acceptable time formats until one matches
210	for _, format := range timeFormats {
211		t, err = time.Parse(format, data)
212		if err == nil {
213			break
214		}
215	}
216	if err != nil {
217		return 0, fmt.Errorf("invalid $date value string: %s", data)
218	}
219
220	return int64(primitive.NewDateTimeFromTime(t)), nil
221}
222
223func parseDatetimeObject(data *extJSONObject) (d int64, err error) {
224	dFound := false
225
226	for i, key := range data.keys {
227		val := data.values[i]
228
229		switch key {
230		case "$numberLong":
231			if dFound {
232				return 0, errors.New("duplicate $numberLong key in $date")
233			}
234
235			if val.t != bsontype.String {
236				return 0, fmt.Errorf("$date $numberLong field should be string, but instead is %s", val.t)
237			}
238
239			d, err = val.parseInt64()
240			if err != nil {
241				return 0, err
242			}
243			dFound = true
244		default:
245			return 0, fmt.Errorf("invalid key in $date object: %s", key)
246		}
247	}
248
249	if !dFound {
250		return 0, errors.New("missing $numberLong field in $date object")
251	}
252
253	return d, nil
254}
255
256func (ejv *extJSONValue) parseDecimal128() (primitive.Decimal128, error) {
257	if ejv.t != bsontype.String {
258		return primitive.Decimal128{}, fmt.Errorf("$numberDecimal value should be string, but instead is %s", ejv.t)
259	}
260
261	d, err := primitive.ParseDecimal128(ejv.v.(string))
262	if err != nil {
263		return primitive.Decimal128{}, fmt.Errorf("$invalid $numberDecimal string: %s", ejv.v.(string))
264	}
265
266	return d, nil
267}
268
269func (ejv *extJSONValue) parseDouble() (float64, error) {
270	if ejv.t == bsontype.Double {
271		return ejv.v.(float64), nil
272	}
273
274	if ejv.t != bsontype.String {
275		return 0, fmt.Errorf("$numberDouble value should be string, but instead is %s", ejv.t)
276	}
277
278	switch string(ejv.v.(string)) {
279	case "Infinity":
280		return math.Inf(1), nil
281	case "-Infinity":
282		return math.Inf(-1), nil
283	case "NaN":
284		return math.NaN(), nil
285	}
286
287	f, err := strconv.ParseFloat(ejv.v.(string), 64)
288	if err != nil {
289		return 0, err
290	}
291
292	return f, nil
293}
294
295func (ejv *extJSONValue) parseInt32() (int32, error) {
296	if ejv.t == bsontype.Int32 {
297		return ejv.v.(int32), nil
298	}
299
300	if ejv.t != bsontype.String {
301		return 0, fmt.Errorf("$numberInt value should be string, but instead is %s", ejv.t)
302	}
303
304	i, err := strconv.ParseInt(ejv.v.(string), 10, 64)
305	if err != nil {
306		return 0, err
307	}
308
309	if i < math.MinInt32 || i > math.MaxInt32 {
310		return 0, fmt.Errorf("$numberInt value should be int32 but instead is int64: %d", i)
311	}
312
313	return int32(i), nil
314}
315
316func (ejv *extJSONValue) parseInt64() (int64, error) {
317	if ejv.t == bsontype.Int64 {
318		return ejv.v.(int64), nil
319	}
320
321	if ejv.t != bsontype.String {
322		return 0, fmt.Errorf("$numberLong value should be string, but instead is %s", ejv.t)
323	}
324
325	i, err := strconv.ParseInt(ejv.v.(string), 10, 64)
326	if err != nil {
327		return 0, err
328	}
329
330	return i, nil
331}
332
333func (ejv *extJSONValue) parseJavascript() (code string, err error) {
334	if ejv.t != bsontype.String {
335		return "", fmt.Errorf("$code value should be string, but instead is %s", ejv.t)
336	}
337
338	return ejv.v.(string), nil
339}
340
341func (ejv *extJSONValue) parseMinMaxKey(minmax string) error {
342	if ejv.t != bsontype.Int32 {
343		return fmt.Errorf("$%sKey value should be int32, but instead is %s", minmax, ejv.t)
344	}
345
346	if ejv.v.(int32) != 1 {
347		return fmt.Errorf("$%sKey value must be 1, but instead is %d", minmax, ejv.v.(int32))
348	}
349
350	return nil
351}
352
353func (ejv *extJSONValue) parseObjectID() (primitive.ObjectID, error) {
354	if ejv.t != bsontype.String {
355		return primitive.NilObjectID, fmt.Errorf("$oid value should be string, but instead is %s", ejv.t)
356	}
357
358	return primitive.ObjectIDFromHex(ejv.v.(string))
359}
360
361func (ejv *extJSONValue) parseRegex() (pattern, options string, err error) {
362	if ejv.t != bsontype.EmbeddedDocument {
363		return "", "", fmt.Errorf("$regularExpression value should be object, but instead is %s", ejv.t)
364	}
365
366	regexObj := ejv.v.(*extJSONObject)
367	patFound := false
368	optFound := false
369
370	for i, key := range regexObj.keys {
371		val := regexObj.values[i]
372
373		switch string(key) {
374		case "pattern":
375			if patFound {
376				return "", "", errors.New("duplicate pattern key in $regularExpression")
377			}
378
379			if val.t != bsontype.String {
380				return "", "", fmt.Errorf("$regularExpression pattern value should be string, but instead is %s", val.t)
381			}
382
383			pattern = val.v.(string)
384			patFound = true
385		case "options":
386			if optFound {
387				return "", "", errors.New("duplicate options key in $regularExpression")
388			}
389
390			if val.t != bsontype.String {
391				return "", "", fmt.Errorf("$regularExpression options value should be string, but instead is %s", val.t)
392			}
393
394			options = val.v.(string)
395			optFound = true
396		default:
397			return "", "", fmt.Errorf("invalid key in $regularExpression object: %s", key)
398		}
399	}
400
401	if !patFound {
402		return "", "", errors.New("missing pattern field in $regularExpression object")
403	}
404
405	if !optFound {
406		return "", "", errors.New("missing options field in $regularExpression object")
407
408	}
409
410	return pattern, options, nil
411}
412
413func (ejv *extJSONValue) parseSymbol() (string, error) {
414	if ejv.t != bsontype.String {
415		return "", fmt.Errorf("$symbol value should be string, but instead is %s", ejv.t)
416	}
417
418	return ejv.v.(string), nil
419}
420
421func (ejv *extJSONValue) parseTimestamp() (t, i uint32, err error) {
422	if ejv.t != bsontype.EmbeddedDocument {
423		return 0, 0, fmt.Errorf("$timestamp value should be object, but instead is %s", ejv.t)
424	}
425
426	handleKey := func(key string, val *extJSONValue, flag bool) (uint32, error) {
427		if flag {
428			return 0, fmt.Errorf("duplicate %s key in $timestamp", key)
429		}
430
431		switch val.t {
432		case bsontype.Int32:
433			value := val.v.(int32)
434
435			if value < 0 {
436				return 0, fmt.Errorf("$timestamp %s number should be uint32: %d", key, value)
437			}
438
439			return uint32(value), nil
440		case bsontype.Int64:
441			value := val.v.(int64)
442			if value < 0 || value > int64(math.MaxUint32) {
443				return 0, fmt.Errorf("$timestamp %s number should be uint32: %d", key, value)
444			}
445
446			return uint32(value), nil
447		default:
448			return 0, fmt.Errorf("$timestamp %s value should be uint32, but instead is %s", key, val.t)
449		}
450	}
451
452	tsObj := ejv.v.(*extJSONObject)
453	tFound := false
454	iFound := false
455
456	for j, key := range tsObj.keys {
457		val := tsObj.values[j]
458
459		switch key {
460		case "t":
461			if t, err = handleKey(key, val, tFound); err != nil {
462				return 0, 0, err
463			}
464
465			tFound = true
466		case "i":
467			if i, err = handleKey(key, val, iFound); err != nil {
468				return 0, 0, err
469			}
470
471			iFound = true
472		default:
473			return 0, 0, fmt.Errorf("invalid key in $timestamp object: %s", key)
474		}
475	}
476
477	if !tFound {
478		return 0, 0, errors.New("missing t field in $timestamp object")
479	}
480
481	if !iFound {
482		return 0, 0, errors.New("missing i field in $timestamp object")
483	}
484
485	return t, i, nil
486}
487
488func (ejv *extJSONValue) parseUndefined() error {
489	if ejv.t != bsontype.Boolean {
490		return fmt.Errorf("undefined value should be boolean, but instead is %s", ejv.t)
491	}
492
493	if !ejv.v.(bool) {
494		return fmt.Errorf("$undefined balue boolean should be true, but instead is %v", ejv.v.(bool))
495	}
496
497	return nil
498}
499