1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16
17package array
18
19import (
20	"math"
21
22	"github.com/apache/arrow/go/v6/arrow"
23	"github.com/apache/arrow/go/v6/arrow/float16"
24	"golang.org/x/xerrors"
25)
26
27// RecordEqual reports whether the two provided records are equal.
28func RecordEqual(left, right Record) bool {
29	switch {
30	case left.NumCols() != right.NumCols():
31		return false
32	case left.NumRows() != right.NumRows():
33		return false
34	}
35
36	for i := range left.Columns() {
37		lc := left.Column(i)
38		rc := right.Column(i)
39		if !ArrayEqual(lc, rc) {
40			return false
41		}
42	}
43	return true
44}
45
46// RecordApproxEqual reports whether the two provided records are approximately equal.
47// For non-floating point columns, it is equivalent to RecordEqual.
48func RecordApproxEqual(left, right Record, opts ...EqualOption) bool {
49	switch {
50	case left.NumCols() != right.NumCols():
51		return false
52	case left.NumRows() != right.NumRows():
53		return false
54	}
55
56	opt := newEqualOption(opts...)
57
58	for i := range left.Columns() {
59		lc := left.Column(i)
60		rc := right.Column(i)
61		if !arrayApproxEqual(lc, rc, opt) {
62			return false
63		}
64	}
65	return true
66}
67
68// helper function to evaluate a function on two chunked object having possibly different
69// chunk layouts. the function passed in will be called for each corresponding slice of the
70// two chunked arrays and if the function returns false it will end the loop early.
71func chunkedBinaryApply(left, right *Chunked, fn func(left Interface, lbeg, lend int64, right Interface, rbeg, rend int64) bool) {
72	var (
73		pos               int64
74		length            int64 = int64(left.length)
75		leftIdx, rightIdx int
76		leftPos, rightPos int64
77	)
78
79	for pos < length {
80		var cleft, cright Interface
81		for {
82			cleft, cright = left.Chunk(leftIdx), right.Chunk(rightIdx)
83			if leftPos == int64(cleft.Len()) {
84				leftPos = 0
85				leftIdx++
86				continue
87			}
88			if rightPos == int64(cright.Len()) {
89				rightPos = 0
90				rightIdx++
91				continue
92			}
93			break
94		}
95
96		sz := int64(min(cleft.Len()-int(leftPos), cright.Len()-int(rightPos)))
97		pos += sz
98		if !fn(cleft, leftPos, leftPos+sz, cright, rightPos, rightPos+sz) {
99			return
100		}
101
102		leftPos += sz
103		rightPos += sz
104	}
105}
106
107// ChunkedEqual reports whether two chunked arrays are equal regardless of their chunkings
108func ChunkedEqual(left, right *Chunked) bool {
109	switch {
110	case left == right:
111		return true
112	case left.length != right.length:
113		return false
114	case left.nulls != right.nulls:
115		return false
116	case !arrow.TypeEqual(left.dtype, right.dtype):
117		return false
118	}
119
120	var isequal bool
121	chunkedBinaryApply(left, right, func(left Interface, lbeg, lend int64, right Interface, rbeg, rend int64) bool {
122		isequal = ArraySliceEqual(left, lbeg, lend, right, rbeg, rend)
123		return isequal
124	})
125
126	return isequal
127}
128
129// ChunkedApproxEqual reports whether two chunked arrays are approximately equal regardless of their chunkings
130// for non-floating point arrays, this is equivalent to ChunkedEqual
131func ChunkedApproxEqual(left, right *Chunked, opts ...EqualOption) bool {
132	switch {
133	case left == right:
134		return true
135	case left.length != right.length:
136		return false
137	case left.nulls != right.nulls:
138		return false
139	case !arrow.TypeEqual(left.dtype, right.dtype):
140		return false
141	}
142
143	var isequal bool
144	chunkedBinaryApply(left, right, func(left Interface, lbeg, lend int64, right Interface, rbeg, rend int64) bool {
145		isequal = ArraySliceApproxEqual(left, lbeg, lend, right, rbeg, rend, opts...)
146		return isequal
147	})
148
149	return isequal
150}
151
152// TableEqual returns if the two tables have the same data in the same schema
153func TableEqual(left, right Table) bool {
154	switch {
155	case left.NumCols() != right.NumCols():
156		return false
157	case left.NumRows() != right.NumRows():
158		return false
159	}
160
161	for i := 0; int64(i) < left.NumCols(); i++ {
162		lc := left.Column(i)
163		rc := right.Column(i)
164		if !lc.field.Equal(rc.field) {
165			return false
166		}
167
168		if !ChunkedEqual(lc.data, rc.data) {
169			return false
170		}
171	}
172	return true
173}
174
175// TableEqual returns if the two tables have the approximately equal data in the same schema
176func TableApproxEqual(left, right Table, opts ...EqualOption) bool {
177	switch {
178	case left.NumCols() != right.NumCols():
179		return false
180	case left.NumRows() != right.NumRows():
181		return false
182	}
183
184	for i := 0; int64(i) < left.NumCols(); i++ {
185		lc := left.Column(i)
186		rc := right.Column(i)
187		if !lc.field.Equal(rc.field) {
188			return false
189		}
190
191		if !ChunkedApproxEqual(lc.data, rc.data, opts...) {
192			return false
193		}
194	}
195	return true
196}
197
198// ArrayEqual reports whether the two provided arrays are equal.
199func ArrayEqual(left, right Interface) bool {
200	switch {
201	case !baseArrayEqual(left, right):
202		return false
203	case left.Len() == 0:
204		return true
205	case left.NullN() == left.Len():
206		return true
207	}
208
209	// at this point, we know both arrays have same type, same length, same number of nulls
210	// and nulls at the same place.
211	// compare the values.
212
213	switch l := left.(type) {
214	case *Null:
215		return true
216	case *Boolean:
217		r := right.(*Boolean)
218		return arrayEqualBoolean(l, r)
219	case *FixedSizeBinary:
220		r := right.(*FixedSizeBinary)
221		return arrayEqualFixedSizeBinary(l, r)
222	case *Binary:
223		r := right.(*Binary)
224		return arrayEqualBinary(l, r)
225	case *String:
226		r := right.(*String)
227		return arrayEqualString(l, r)
228	case *Int8:
229		r := right.(*Int8)
230		return arrayEqualInt8(l, r)
231	case *Int16:
232		r := right.(*Int16)
233		return arrayEqualInt16(l, r)
234	case *Int32:
235		r := right.(*Int32)
236		return arrayEqualInt32(l, r)
237	case *Int64:
238		r := right.(*Int64)
239		return arrayEqualInt64(l, r)
240	case *Uint8:
241		r := right.(*Uint8)
242		return arrayEqualUint8(l, r)
243	case *Uint16:
244		r := right.(*Uint16)
245		return arrayEqualUint16(l, r)
246	case *Uint32:
247		r := right.(*Uint32)
248		return arrayEqualUint32(l, r)
249	case *Uint64:
250		r := right.(*Uint64)
251		return arrayEqualUint64(l, r)
252	case *Float16:
253		r := right.(*Float16)
254		return arrayEqualFloat16(l, r)
255	case *Float32:
256		r := right.(*Float32)
257		return arrayEqualFloat32(l, r)
258	case *Float64:
259		r := right.(*Float64)
260		return arrayEqualFloat64(l, r)
261	case *Decimal128:
262		r := right.(*Decimal128)
263		return arrayEqualDecimal128(l, r)
264	case *Date32:
265		r := right.(*Date32)
266		return arrayEqualDate32(l, r)
267	case *Date64:
268		r := right.(*Date64)
269		return arrayEqualDate64(l, r)
270	case *Time32:
271		r := right.(*Time32)
272		return arrayEqualTime32(l, r)
273	case *Time64:
274		r := right.(*Time64)
275		return arrayEqualTime64(l, r)
276	case *Timestamp:
277		r := right.(*Timestamp)
278		return arrayEqualTimestamp(l, r)
279	case *List:
280		r := right.(*List)
281		return arrayEqualList(l, r)
282	case *FixedSizeList:
283		r := right.(*FixedSizeList)
284		return arrayEqualFixedSizeList(l, r)
285	case *Struct:
286		r := right.(*Struct)
287		return arrayEqualStruct(l, r)
288	case *MonthInterval:
289		r := right.(*MonthInterval)
290		return arrayEqualMonthInterval(l, r)
291	case *DayTimeInterval:
292		r := right.(*DayTimeInterval)
293		return arrayEqualDayTimeInterval(l, r)
294	case *MonthDayNanoInterval:
295		r := right.(*MonthDayNanoInterval)
296		return arrayEqualMonthDayNanoInterval(l, r)
297	case *Duration:
298		r := right.(*Duration)
299		return arrayEqualDuration(l, r)
300	case *Map:
301		r := right.(*Map)
302		return arrayEqualMap(l, r)
303	case ExtensionArray:
304		r := right.(ExtensionArray)
305		return arrayEqualExtension(l, r)
306	default:
307		panic(xerrors.Errorf("arrow/array: unknown array type %T", l))
308	}
309}
310
311// ArraySliceEqual reports whether slices left[lbeg:lend] and right[rbeg:rend] are equal.
312func ArraySliceEqual(left Interface, lbeg, lend int64, right Interface, rbeg, rend int64) bool {
313	l := NewSlice(left, lbeg, lend)
314	defer l.Release()
315	r := NewSlice(right, rbeg, rend)
316	defer r.Release()
317
318	return ArrayEqual(l, r)
319}
320
321// ArraySliceApproxEqual reports whether slices left[lbeg:lend] and right[rbeg:rend] are approximately equal.
322func ArraySliceApproxEqual(left Interface, lbeg, lend int64, right Interface, rbeg, rend int64, opts ...EqualOption) bool {
323	l := NewSlice(left, lbeg, lend)
324	defer l.Release()
325	r := NewSlice(right, rbeg, rend)
326	defer r.Release()
327
328	return ArrayApproxEqual(l, r, opts...)
329}
330
// defaultAbsoluteTolerance is the absolute tolerance used by the approximate
// comparison functions unless WithAbsTolerance overrides it.
const defaultAbsoluteTolerance = 1e-5

// equalOption holds the settings that control approximate floating point comparisons.
type equalOption struct {
	// atol is the absolute tolerance: v1 and v2 match when |v1-v2| <= atol.
	atol float64
	// nansEq, when true, makes two NaN values compare as equal.
	nansEq bool
}
337
338func (eq equalOption) f16(f1, f2 float16.Num) bool {
339	v1 := float64(f1.Float32())
340	v2 := float64(f2.Float32())
341	switch {
342	case eq.nansEq:
343		return math.Abs(v1-v2) <= eq.atol || (math.IsNaN(v1) && math.IsNaN(v2))
344	default:
345		return math.Abs(v1-v2) <= eq.atol
346	}
347}
348
349func (eq equalOption) f32(f1, f2 float32) bool {
350	v1 := float64(f1)
351	v2 := float64(f2)
352	switch {
353	case eq.nansEq:
354		return math.Abs(v1-v2) <= eq.atol || (math.IsNaN(v1) && math.IsNaN(v2))
355	default:
356		return math.Abs(v1-v2) <= eq.atol
357	}
358}
359
360func (eq equalOption) f64(v1, v2 float64) bool {
361	switch {
362	case eq.nansEq:
363		return math.Abs(v1-v2) <= eq.atol || (math.IsNaN(v1) && math.IsNaN(v2))
364	default:
365		return math.Abs(v1-v2) <= eq.atol
366	}
367}
368
369func newEqualOption(opts ...EqualOption) equalOption {
370	eq := equalOption{
371		atol:   defaultAbsoluteTolerance,
372		nansEq: false,
373	}
374	for _, opt := range opts {
375		opt(&eq)
376	}
377
378	return eq
379}
380
381// EqualOption is a functional option type used to configure how Records and Arrays are compared.
382type EqualOption func(*equalOption)
383
384// WithNaNsEqual configures the comparison functions so that NaNs are considered equal.
385func WithNaNsEqual(v bool) EqualOption {
386	return func(o *equalOption) {
387		o.nansEq = v
388	}
389}
390
391// WithAbsTolerance configures the comparison functions so that 2 floating point values
392// v1 and v2 are considered equal if |v1-v2| <= atol.
393func WithAbsTolerance(atol float64) EqualOption {
394	return func(o *equalOption) {
395		o.atol = atol
396	}
397}
398
399// ArrayApproxEqual reports whether the two provided arrays are approximately equal.
400// For non-floating point arrays, it is equivalent to ArrayEqual.
401func ArrayApproxEqual(left, right Interface, opts ...EqualOption) bool {
402	opt := newEqualOption(opts...)
403	return arrayApproxEqual(left, right, opt)
404}
405
406func arrayApproxEqual(left, right Interface, opt equalOption) bool {
407	switch {
408	case !baseArrayEqual(left, right):
409		return false
410	case left.Len() == 0:
411		return true
412	case left.NullN() == left.Len():
413		return true
414	}
415
416	// at this point, we know both arrays have same type, same length, same number of nulls
417	// and nulls at the same place.
418	// compare the values.
419
420	switch l := left.(type) {
421	case *Null:
422		return true
423	case *Boolean:
424		r := right.(*Boolean)
425		return arrayEqualBoolean(l, r)
426	case *FixedSizeBinary:
427		r := right.(*FixedSizeBinary)
428		return arrayEqualFixedSizeBinary(l, r)
429	case *Binary:
430		r := right.(*Binary)
431		return arrayEqualBinary(l, r)
432	case *String:
433		r := right.(*String)
434		return arrayEqualString(l, r)
435	case *Int8:
436		r := right.(*Int8)
437		return arrayEqualInt8(l, r)
438	case *Int16:
439		r := right.(*Int16)
440		return arrayEqualInt16(l, r)
441	case *Int32:
442		r := right.(*Int32)
443		return arrayEqualInt32(l, r)
444	case *Int64:
445		r := right.(*Int64)
446		return arrayEqualInt64(l, r)
447	case *Uint8:
448		r := right.(*Uint8)
449		return arrayEqualUint8(l, r)
450	case *Uint16:
451		r := right.(*Uint16)
452		return arrayEqualUint16(l, r)
453	case *Uint32:
454		r := right.(*Uint32)
455		return arrayEqualUint32(l, r)
456	case *Uint64:
457		r := right.(*Uint64)
458		return arrayEqualUint64(l, r)
459	case *Float16:
460		r := right.(*Float16)
461		return arrayApproxEqualFloat16(l, r, opt)
462	case *Float32:
463		r := right.(*Float32)
464		return arrayApproxEqualFloat32(l, r, opt)
465	case *Float64:
466		r := right.(*Float64)
467		return arrayApproxEqualFloat64(l, r, opt)
468	case *Decimal128:
469		r := right.(*Decimal128)
470		return arrayEqualDecimal128(l, r)
471	case *Date32:
472		r := right.(*Date32)
473		return arrayEqualDate32(l, r)
474	case *Date64:
475		r := right.(*Date64)
476		return arrayEqualDate64(l, r)
477	case *Time32:
478		r := right.(*Time32)
479		return arrayEqualTime32(l, r)
480	case *Time64:
481		r := right.(*Time64)
482		return arrayEqualTime64(l, r)
483	case *Timestamp:
484		r := right.(*Timestamp)
485		return arrayEqualTimestamp(l, r)
486	case *List:
487		r := right.(*List)
488		return arrayApproxEqualList(l, r, opt)
489	case *FixedSizeList:
490		r := right.(*FixedSizeList)
491		return arrayApproxEqualFixedSizeList(l, r, opt)
492	case *Struct:
493		r := right.(*Struct)
494		return arrayApproxEqualStruct(l, r, opt)
495	case *MonthInterval:
496		r := right.(*MonthInterval)
497		return arrayEqualMonthInterval(l, r)
498	case *DayTimeInterval:
499		r := right.(*DayTimeInterval)
500		return arrayEqualDayTimeInterval(l, r)
501	case *MonthDayNanoInterval:
502		r := right.(*MonthDayNanoInterval)
503		return arrayEqualMonthDayNanoInterval(l, r)
504	case *Duration:
505		r := right.(*Duration)
506		return arrayEqualDuration(l, r)
507	case *Map:
508		r := right.(*Map)
509		return arrayApproxEqualList(l.List, r.List, opt)
510	case ExtensionArray:
511		r := right.(ExtensionArray)
512		return arrayApproxEqualExtension(l, r, opt)
513	default:
514		panic(xerrors.Errorf("arrow/array: unknown array type %T", l))
515	}
516
517	return false
518}
519
520func baseArrayEqual(left, right Interface) bool {
521	switch {
522	case left.Len() != right.Len():
523		return false
524	case left.NullN() != right.NullN():
525		return false
526	case !arrow.TypeEqual(left.DataType(), right.DataType()): // We do not check for metadata as in the C++ implementation.
527		return false
528	case !validityBitmapEqual(left, right):
529		return false
530	}
531	return true
532}
533
534func validityBitmapEqual(left, right Interface) bool {
535	// TODO(alexandreyc): make it faster by comparing byte slices of the validity bitmap?
536	n := left.Len()
537	if n != right.Len() {
538		return false
539	}
540	for i := 0; i < n; i++ {
541		if left.IsNull(i) != right.IsNull(i) {
542			return false
543		}
544	}
545	return true
546}
547
548func arrayApproxEqualFloat16(left, right *Float16, opt equalOption) bool {
549	for i := 0; i < left.Len(); i++ {
550		if left.IsNull(i) {
551			continue
552		}
553		if !opt.f16(left.Value(i), right.Value(i)) {
554			return false
555		}
556	}
557	return true
558}
559
560func arrayApproxEqualFloat32(left, right *Float32, opt equalOption) bool {
561	for i := 0; i < left.Len(); i++ {
562		if left.IsNull(i) {
563			continue
564		}
565		if !opt.f32(left.Value(i), right.Value(i)) {
566			return false
567		}
568	}
569	return true
570}
571
572func arrayApproxEqualFloat64(left, right *Float64, opt equalOption) bool {
573	for i := 0; i < left.Len(); i++ {
574		if left.IsNull(i) {
575			continue
576		}
577		if !opt.f64(left.Value(i), right.Value(i)) {
578			return false
579		}
580	}
581	return true
582}
583
584func arrayApproxEqualList(left, right *List, opt equalOption) bool {
585	for i := 0; i < left.Len(); i++ {
586		if left.IsNull(i) {
587			continue
588		}
589		o := func() bool {
590			l := left.newListValue(i)
591			defer l.Release()
592			r := right.newListValue(i)
593			defer r.Release()
594			return arrayApproxEqual(l, r, opt)
595		}()
596		if !o {
597			return false
598		}
599	}
600	return true
601}
602
603func arrayApproxEqualFixedSizeList(left, right *FixedSizeList, opt equalOption) bool {
604	for i := 0; i < left.Len(); i++ {
605		if left.IsNull(i) {
606			continue
607		}
608		o := func() bool {
609			l := left.newListValue(i)
610			defer l.Release()
611			r := right.newListValue(i)
612			defer r.Release()
613			return arrayApproxEqual(l, r, opt)
614		}()
615		if !o {
616			return false
617		}
618	}
619	return true
620}
621
622func arrayApproxEqualStruct(left, right *Struct, opt equalOption) bool {
623	for i, lf := range left.fields {
624		rf := right.fields[i]
625		if !arrayApproxEqual(lf, rf, opt) {
626			return false
627		}
628	}
629	return true
630}
631