1// Copyright 2018 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package firestore
16
17import (
18	"bytes"
19	"fmt"
20	"math"
21	"sort"
22	"strings"
23
24	tspb "github.com/golang/protobuf/ptypes/timestamp"
25	pb "google.golang.org/genproto/googleapis/firestore/v1"
26)
27
28// Returns a negative number, zero, or a positive number depending on whether a is
29// less than, equal to, or greater than b according to Firestore's ordering of
30// values.
31func compareValues(a, b *pb.Value) int {
32	ta := typeOrder(a)
33	tb := typeOrder(b)
34	if ta != tb {
35		return compareInt64s(int64(ta), int64(tb))
36	}
37	switch a := a.ValueType.(type) {
38	case *pb.Value_NullValue:
39		return 0 // nulls are equal
40
41	case *pb.Value_BooleanValue:
42		av := a.BooleanValue
43		bv := b.GetBooleanValue()
44		switch {
45		case av && !bv:
46			return 1
47		case bv && !av:
48			return -1
49		default:
50			return 0
51		}
52
53	case *pb.Value_IntegerValue:
54		return compareNumbers(float64(a.IntegerValue), toFloat(b))
55
56	case *pb.Value_DoubleValue:
57		return compareNumbers(a.DoubleValue, toFloat(b))
58
59	case *pb.Value_TimestampValue:
60		return compareTimestamps(a.TimestampValue, b.GetTimestampValue())
61
62	case *pb.Value_StringValue:
63		return strings.Compare(a.StringValue, b.GetStringValue())
64
65	case *pb.Value_BytesValue:
66		return bytes.Compare(a.BytesValue, b.GetBytesValue())
67
68	case *pb.Value_ReferenceValue:
69		return compareReferences(a.ReferenceValue, b.GetReferenceValue())
70
71	case *pb.Value_GeoPointValue:
72		ag := a.GeoPointValue
73		bg := b.GetGeoPointValue()
74		if ag.Latitude != bg.Latitude {
75			return compareFloat64s(ag.Latitude, bg.Latitude)
76		}
77		return compareFloat64s(ag.Longitude, bg.Longitude)
78
79	case *pb.Value_ArrayValue:
80		return compareArrays(a.ArrayValue.Values, b.GetArrayValue().Values)
81
82	case *pb.Value_MapValue:
83		return compareMaps(a.MapValue.Fields, b.GetMapValue().Fields)
84
85	default:
86		panic(fmt.Sprintf("bad value type: %v", a))
87	}
88}
89
90// Treats NaN as less than any non-NaN.
91func compareNumbers(a, b float64) int {
92	switch {
93	case math.IsNaN(a):
94		if math.IsNaN(b) {
95			return 0
96		}
97		return -1
98	case math.IsNaN(b):
99		return 1
100	default:
101		return compareFloat64s(a, b)
102	}
103}
104
105// Return v as a float64, assuming it's an Integer or Double.
106func toFloat(v *pb.Value) float64 {
107	if x, ok := v.ValueType.(*pb.Value_IntegerValue); ok {
108		return float64(x.IntegerValue)
109	}
110	return v.GetDoubleValue()
111}
112
113func compareTimestamps(a, b *tspb.Timestamp) int {
114	if c := compareInt64s(a.Seconds, b.Seconds); c != 0 {
115		return c
116	}
117	return compareInt64s(int64(a.Nanos), int64(b.Nanos))
118}
119
120func compareReferences(a, b string) int {
121	// Compare path components lexicographically.
122	pa := strings.Split(a, "/")
123	pb := strings.Split(b, "/")
124	return compareSequences(len(pa), len(pb), func(i int) int {
125		return strings.Compare(pa[i], pb[i])
126	})
127}
128
129func compareArrays(a, b []*pb.Value) int {
130	return compareSequences(len(a), len(b), func(i int) int {
131		return compareValues(a[i], b[i])
132	})
133}
134
135func compareMaps(a, b map[string]*pb.Value) int {
136	sortedKeys := func(m map[string]*pb.Value) []string {
137		var ks []string
138		for k := range m {
139			ks = append(ks, k)
140		}
141		sort.Strings(ks)
142		return ks
143	}
144
145	aks := sortedKeys(a)
146	bks := sortedKeys(b)
147	return compareSequences(len(aks), len(bks), func(i int) int {
148		if c := strings.Compare(aks[i], bks[i]); c != 0 {
149			return c
150		}
151		k := aks[i]
152		return compareValues(a[k], b[k])
153	})
154}
155
156func compareSequences(len1, len2 int, compare func(int) int) int {
157	for i := 0; i < len1 && i < len2; i++ {
158		if c := compare(i); c != 0 {
159			return c
160		}
161	}
162	return compareInt64s(int64(len1), int64(len2))
163}
164
// compareFloat64s returns -1 if a < b, 1 if a > b, and 0 otherwise. Every
// ordered comparison involving NaN is false, so NaN compares equal to any
// value here; compareNumbers handles NaN itself before delegating.
func compareFloat64s(a, b float64) int {
	if a < b {
		return -1
	}
	if a > b {
		return 1
	}
	return 0
}
175
// compareInt64s returns -1 if a < b, 1 if a > b, and 0 if they are equal.
func compareInt64s(a, b int64) int {
	if a == b {
		return 0
	}
	if a < b {
		return -1
	}
	return 1
}
186
187// Return an integer corresponding to the type of value stored in v, such that
188// comparing the resulting integers gives the Firestore ordering for types.
189func typeOrder(v *pb.Value) int {
190	switch v.ValueType.(type) {
191	case *pb.Value_NullValue:
192		return 0
193	case *pb.Value_BooleanValue:
194		return 1
195	case *pb.Value_IntegerValue:
196		return 2
197	case *pb.Value_DoubleValue:
198		return 2
199	case *pb.Value_TimestampValue:
200		return 3
201	case *pb.Value_StringValue:
202		return 4
203	case *pb.Value_BytesValue:
204		return 5
205	case *pb.Value_ReferenceValue:
206		return 6
207	case *pb.Value_GeoPointValue:
208		return 7
209	case *pb.Value_ArrayValue:
210		return 8
211	case *pb.Value_MapValue:
212		return 9
213	default:
214		panic(fmt.Sprintf("bad value type: %v", v))
215	}
216}
217