1// Licensed to the Apache Software Foundation (ASF) under one 2// or more contributor license agreements. See the NOTICE file 3// distributed with this work for additional information 4// regarding copyright ownership. The ASF licenses this file 5// to you under the Apache License, Version 2.0 (the 6// "License"); you may not use this file except in compliance 7// with the License. You may obtain a copy of the License at 8// 9// http://www.apache.org/licenses/LICENSE-2.0 10// 11// Unless required by applicable law or agreed to in writing, software 12// distributed under the License is distributed on an "AS IS" BASIS, 13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14// See the License for the specific language governing permissions and 15// limitations under the License. 16 17package array 18 19import ( 20 "math" 21 22 "github.com/apache/arrow/go/v6/arrow" 23 "github.com/apache/arrow/go/v6/arrow/float16" 24 "golang.org/x/xerrors" 25) 26 27// RecordEqual reports whether the two provided records are equal. 28func RecordEqual(left, right Record) bool { 29 switch { 30 case left.NumCols() != right.NumCols(): 31 return false 32 case left.NumRows() != right.NumRows(): 33 return false 34 } 35 36 for i := range left.Columns() { 37 lc := left.Column(i) 38 rc := right.Column(i) 39 if !ArrayEqual(lc, rc) { 40 return false 41 } 42 } 43 return true 44} 45 46// RecordApproxEqual reports whether the two provided records are approximately equal. 47// For non-floating point columns, it is equivalent to RecordEqual. 48func RecordApproxEqual(left, right Record, opts ...EqualOption) bool { 49 switch { 50 case left.NumCols() != right.NumCols(): 51 return false 52 case left.NumRows() != right.NumRows(): 53 return false 54 } 55 56 opt := newEqualOption(opts...) 57 58 for i := range left.Columns() { 59 lc := left.Column(i) 60 rc := right.Column(i) 61 if !arrayApproxEqual(lc, rc, opt) { 62 return false 63 } 64 } 65 return true 66} 67 68// helper function to evaluate a function on two chunked object having possibly different 69// chunk layouts. the function passed in will be called for each corresponding slice of the 70// two chunked arrays and if the function returns false it will end the loop early. 71func chunkedBinaryApply(left, right *Chunked, fn func(left Interface, lbeg, lend int64, right Interface, rbeg, rend int64) bool) { 72 var ( 73 pos int64 74 length int64 = int64(left.length) 75 leftIdx, rightIdx int 76 leftPos, rightPos int64 77 ) 78 79 for pos < length { 80 var cleft, cright Interface 81 for { 82 cleft, cright = left.Chunk(leftIdx), right.Chunk(rightIdx) 83 if leftPos == int64(cleft.Len()) { 84 leftPos = 0 85 leftIdx++ 86 continue 87 } 88 if rightPos == int64(cright.Len()) { 89 rightPos = 0 90 rightIdx++ 91 continue 92 } 93 break 94 } 95 96 sz := int64(min(cleft.Len()-int(leftPos), cright.Len()-int(rightPos))) 97 pos += sz 98 if !fn(cleft, leftPos, leftPos+sz, cright, rightPos, rightPos+sz) { 99 return 100 } 101 102 leftPos += sz 103 rightPos += sz 104 } 105} 106 107// ChunkedEqual reports whether two chunked arrays are equal regardless of their chunkings 108func ChunkedEqual(left, right *Chunked) bool { 109 switch { 110 case left == right: 111 return true 112 case left.length != right.length: 113 return false 114 case left.nulls != right.nulls: 115 return false 116 case !arrow.TypeEqual(left.dtype, right.dtype): 117 return false 118 } 119 120 var isequal bool 121 chunkedBinaryApply(left, right, func(left Interface, lbeg, lend int64, right Interface, rbeg, rend int64) bool { 122 isequal = ArraySliceEqual(left, lbeg, lend, right, rbeg, rend) 123 return isequal 124 }) 125 126 return isequal 127} 128 129// ChunkedApproxEqual reports whether two chunked arrays are approximately equal regardless of their chunkings 130// for non-floating point arrays, this is equivalent to ChunkedEqual 131func ChunkedApproxEqual(left, right *Chunked, opts ...EqualOption) bool { 132 switch { 133 case left == right: 134 return true 135 case left.length != right.length: 136 return false 137 case left.nulls != right.nulls: 138 return false 139 case !arrow.TypeEqual(left.dtype, right.dtype): 140 return false 141 } 142 143 var isequal bool 144 chunkedBinaryApply(left, right, func(left Interface, lbeg, lend int64, right Interface, rbeg, rend int64) bool { 145 isequal = ArraySliceApproxEqual(left, lbeg, lend, right, rbeg, rend, opts...) 146 return isequal 147 }) 148 149 return isequal 150} 151 152// TableEqual returns if the two tables have the same data in the same schema 153func TableEqual(left, right Table) bool { 154 switch { 155 case left.NumCols() != right.NumCols(): 156 return false 157 case left.NumRows() != right.NumRows(): 158 return false 159 } 160 161 for i := 0; int64(i) < left.NumCols(); i++ { 162 lc := left.Column(i) 163 rc := right.Column(i) 164 if !lc.field.Equal(rc.field) { 165 return false 166 } 167 168 if !ChunkedEqual(lc.data, rc.data) { 169 return false 170 } 171 } 172 return true 173} 174 175// TableEqual returns if the two tables have the approximately equal data in the same schema 176func TableApproxEqual(left, right Table, opts ...EqualOption) bool { 177 switch { 178 case left.NumCols() != right.NumCols(): 179 return false 180 case left.NumRows() != right.NumRows(): 181 return false 182 } 183 184 for i := 0; int64(i) < left.NumCols(); i++ { 185 lc := left.Column(i) 186 rc := right.Column(i) 187 if !lc.field.Equal(rc.field) { 188 return false 189 } 190 191 if !ChunkedApproxEqual(lc.data, rc.data, opts...) { 192 return false 193 } 194 } 195 return true 196} 197 198// ArrayEqual reports whether the two provided arrays are equal. 199func ArrayEqual(left, right Interface) bool { 200 switch { 201 case !baseArrayEqual(left, right): 202 return false 203 case left.Len() == 0: 204 return true 205 case left.NullN() == left.Len(): 206 return true 207 } 208 209 // at this point, we know both arrays have same type, same length, same number of nulls 210 // and nulls at the same place. 211 // compare the values. 212 213 switch l := left.(type) { 214 case *Null: 215 return true 216 case *Boolean: 217 r := right.(*Boolean) 218 return arrayEqualBoolean(l, r) 219 case *FixedSizeBinary: 220 r := right.(*FixedSizeBinary) 221 return arrayEqualFixedSizeBinary(l, r) 222 case *Binary: 223 r := right.(*Binary) 224 return arrayEqualBinary(l, r) 225 case *String: 226 r := right.(*String) 227 return arrayEqualString(l, r) 228 case *Int8: 229 r := right.(*Int8) 230 return arrayEqualInt8(l, r) 231 case *Int16: 232 r := right.(*Int16) 233 return arrayEqualInt16(l, r) 234 case *Int32: 235 r := right.(*Int32) 236 return arrayEqualInt32(l, r) 237 case *Int64: 238 r := right.(*Int64) 239 return arrayEqualInt64(l, r) 240 case *Uint8: 241 r := right.(*Uint8) 242 return arrayEqualUint8(l, r) 243 case *Uint16: 244 r := right.(*Uint16) 245 return arrayEqualUint16(l, r) 246 case *Uint32: 247 r := right.(*Uint32) 248 return arrayEqualUint32(l, r) 249 case *Uint64: 250 r := right.(*Uint64) 251 return arrayEqualUint64(l, r) 252 case *Float16: 253 r := right.(*Float16) 254 return arrayEqualFloat16(l, r) 255 case *Float32: 256 r := right.(*Float32) 257 return arrayEqualFloat32(l, r) 258 case *Float64: 259 r := right.(*Float64) 260 return arrayEqualFloat64(l, r) 261 case *Decimal128: 262 r := right.(*Decimal128) 263 return arrayEqualDecimal128(l, r) 264 case *Date32: 265 r := right.(*Date32) 266 return arrayEqualDate32(l, r) 267 case *Date64: 268 r := right.(*Date64) 269 return arrayEqualDate64(l, r) 270 case *Time32: 271 r := right.(*Time32) 272 return arrayEqualTime32(l, r) 273 case *Time64: 274 r := right.(*Time64) 275 return arrayEqualTime64(l, r) 276 case *Timestamp: 277 r := right.(*Timestamp) 278 return arrayEqualTimestamp(l, r) 279 case *List: 280 r := right.(*List) 281 return arrayEqualList(l, r) 282 case *FixedSizeList: 283 r := right.(*FixedSizeList) 284 return arrayEqualFixedSizeList(l, r) 285 case *Struct: 286 r := right.(*Struct) 287 return arrayEqualStruct(l, r) 288 case *MonthInterval: 289 r := right.(*MonthInterval) 290 return arrayEqualMonthInterval(l, r) 291 case *DayTimeInterval: 292 r := right.(*DayTimeInterval) 293 return arrayEqualDayTimeInterval(l, r) 294 case *MonthDayNanoInterval: 295 r := right.(*MonthDayNanoInterval) 296 return arrayEqualMonthDayNanoInterval(l, r) 297 case *Duration: 298 r := right.(*Duration) 299 return arrayEqualDuration(l, r) 300 case *Map: 301 r := right.(*Map) 302 return arrayEqualMap(l, r) 303 case ExtensionArray: 304 r := right.(ExtensionArray) 305 return arrayEqualExtension(l, r) 306 default: 307 panic(xerrors.Errorf("arrow/array: unknown array type %T", l)) 308 } 309} 310 311// ArraySliceEqual reports whether slices left[lbeg:lend] and right[rbeg:rend] are equal. 312func ArraySliceEqual(left Interface, lbeg, lend int64, right Interface, rbeg, rend int64) bool { 313 l := NewSlice(left, lbeg, lend) 314 defer l.Release() 315 r := NewSlice(right, rbeg, rend) 316 defer r.Release() 317 318 return ArrayEqual(l, r) 319} 320 321// ArraySliceApproxEqual reports whether slices left[lbeg:lend] and right[rbeg:rend] are approximately equal. 322func ArraySliceApproxEqual(left Interface, lbeg, lend int64, right Interface, rbeg, rend int64, opts ...EqualOption) bool { 323 l := NewSlice(left, lbeg, lend) 324 defer l.Release() 325 r := NewSlice(right, rbeg, rend) 326 defer r.Release() 327 328 return ArrayApproxEqual(l, r, opts...) 329} 330 331const defaultAbsoluteTolerance = 1e-5 332 333type equalOption struct { 334 atol float64 // absolute tolerance 335 nansEq bool // whether NaNs are considered equal. 336} 337 338func (eq equalOption) f16(f1, f2 float16.Num) bool { 339 v1 := float64(f1.Float32()) 340 v2 := float64(f2.Float32()) 341 switch { 342 case eq.nansEq: 343 return math.Abs(v1-v2) <= eq.atol || (math.IsNaN(v1) && math.IsNaN(v2)) 344 default: 345 return math.Abs(v1-v2) <= eq.atol 346 } 347} 348 349func (eq equalOption) f32(f1, f2 float32) bool { 350 v1 := float64(f1) 351 v2 := float64(f2) 352 switch { 353 case eq.nansEq: 354 return math.Abs(v1-v2) <= eq.atol || (math.IsNaN(v1) && math.IsNaN(v2)) 355 default: 356 return math.Abs(v1-v2) <= eq.atol 357 } 358} 359 360func (eq equalOption) f64(v1, v2 float64) bool { 361 switch { 362 case eq.nansEq: 363 return math.Abs(v1-v2) <= eq.atol || (math.IsNaN(v1) && math.IsNaN(v2)) 364 default: 365 return math.Abs(v1-v2) <= eq.atol 366 } 367} 368 369func newEqualOption(opts ...EqualOption) equalOption { 370 eq := equalOption{ 371 atol: defaultAbsoluteTolerance, 372 nansEq: false, 373 } 374 for _, opt := range opts { 375 opt(&eq) 376 } 377 378 return eq 379} 380 381// EqualOption is a functional option type used to configure how Records and Arrays are compared. 382type EqualOption func(*equalOption) 383 384// WithNaNsEqual configures the comparison functions so that NaNs are considered equal. 385func WithNaNsEqual(v bool) EqualOption { 386 return func(o *equalOption) { 387 o.nansEq = v 388 } 389} 390 391// WithAbsTolerance configures the comparison functions so that 2 floating point values 392// v1 and v2 are considered equal if |v1-v2| <= atol. 393func WithAbsTolerance(atol float64) EqualOption { 394 return func(o *equalOption) { 395 o.atol = atol 396 } 397} 398 399// ArrayApproxEqual reports whether the two provided arrays are approximately equal. 400// For non-floating point arrays, it is equivalent to ArrayEqual. 401func ArrayApproxEqual(left, right Interface, opts ...EqualOption) bool { 402 opt := newEqualOption(opts...) 403 return arrayApproxEqual(left, right, opt) 404} 405 406func arrayApproxEqual(left, right Interface, opt equalOption) bool { 407 switch { 408 case !baseArrayEqual(left, right): 409 return false 410 case left.Len() == 0: 411 return true 412 case left.NullN() == left.Len(): 413 return true 414 } 415 416 // at this point, we know both arrays have same type, same length, same number of nulls 417 // and nulls at the same place. 418 // compare the values. 419 420 switch l := left.(type) { 421 case *Null: 422 return true 423 case *Boolean: 424 r := right.(*Boolean) 425 return arrayEqualBoolean(l, r) 426 case *FixedSizeBinary: 427 r := right.(*FixedSizeBinary) 428 return arrayEqualFixedSizeBinary(l, r) 429 case *Binary: 430 r := right.(*Binary) 431 return arrayEqualBinary(l, r) 432 case *String: 433 r := right.(*String) 434 return arrayEqualString(l, r) 435 case *Int8: 436 r := right.(*Int8) 437 return arrayEqualInt8(l, r) 438 case *Int16: 439 r := right.(*Int16) 440 return arrayEqualInt16(l, r) 441 case *Int32: 442 r := right.(*Int32) 443 return arrayEqualInt32(l, r) 444 case *Int64: 445 r := right.(*Int64) 446 return arrayEqualInt64(l, r) 447 case *Uint8: 448 r := right.(*Uint8) 449 return arrayEqualUint8(l, r) 450 case *Uint16: 451 r := right.(*Uint16) 452 return arrayEqualUint16(l, r) 453 case *Uint32: 454 r := right.(*Uint32) 455 return arrayEqualUint32(l, r) 456 case *Uint64: 457 r := right.(*Uint64) 458 return arrayEqualUint64(l, r) 459 case *Float16: 460 r := right.(*Float16) 461 return arrayApproxEqualFloat16(l, r, opt) 462 case *Float32: 463 r := right.(*Float32) 464 return arrayApproxEqualFloat32(l, r, opt) 465 case *Float64: 466 r := right.(*Float64) 467 return arrayApproxEqualFloat64(l, r, opt) 468 case *Decimal128: 469 r := right.(*Decimal128) 470 return arrayEqualDecimal128(l, r) 471 case *Date32: 472 r := right.(*Date32) 473 return arrayEqualDate32(l, r) 474 case *Date64: 475 r := right.(*Date64) 476 return arrayEqualDate64(l, r) 477 case *Time32: 478 r := right.(*Time32) 479 return arrayEqualTime32(l, r) 480 case *Time64: 481 r := right.(*Time64) 482 return arrayEqualTime64(l, r) 483 case *Timestamp: 484 r := right.(*Timestamp) 485 return arrayEqualTimestamp(l, r) 486 case *List: 487 r := right.(*List) 488 return arrayApproxEqualList(l, r, opt) 489 case *FixedSizeList: 490 r := right.(*FixedSizeList) 491 return arrayApproxEqualFixedSizeList(l, r, opt) 492 case *Struct: 493 r := right.(*Struct) 494 return arrayApproxEqualStruct(l, r, opt) 495 case *MonthInterval: 496 r := right.(*MonthInterval) 497 return arrayEqualMonthInterval(l, r) 498 case *DayTimeInterval: 499 r := right.(*DayTimeInterval) 500 return arrayEqualDayTimeInterval(l, r) 501 case *MonthDayNanoInterval: 502 r := right.(*MonthDayNanoInterval) 503 return arrayEqualMonthDayNanoInterval(l, r) 504 case *Duration: 505 r := right.(*Duration) 506 return arrayEqualDuration(l, r) 507 case *Map: 508 r := right.(*Map) 509 return arrayApproxEqualList(l.List, r.List, opt) 510 case ExtensionArray: 511 r := right.(ExtensionArray) 512 return arrayApproxEqualExtension(l, r, opt) 513 default: 514 panic(xerrors.Errorf("arrow/array: unknown array type %T", l)) 515 } 516 517 return false 518} 519 520func baseArrayEqual(left, right Interface) bool { 521 switch { 522 case left.Len() != right.Len(): 523 return false 524 case left.NullN() != right.NullN(): 525 return false 526 case !arrow.TypeEqual(left.DataType(), right.DataType()): // We do not check for metadata as in the C++ implementation. 527 return false 528 case !validityBitmapEqual(left, right): 529 return false 530 } 531 return true 532} 533 534func validityBitmapEqual(left, right Interface) bool { 535 // TODO(alexandreyc): make it faster by comparing byte slices of the validity bitmap? 536 n := left.Len() 537 if n != right.Len() { 538 return false 539 } 540 for i := 0; i < n; i++ { 541 if left.IsNull(i) != right.IsNull(i) { 542 return false 543 } 544 } 545 return true 546} 547 548func arrayApproxEqualFloat16(left, right *Float16, opt equalOption) bool { 549 for i := 0; i < left.Len(); i++ { 550 if left.IsNull(i) { 551 continue 552 } 553 if !opt.f16(left.Value(i), right.Value(i)) { 554 return false 555 } 556 } 557 return true 558} 559 560func arrayApproxEqualFloat32(left, right *Float32, opt equalOption) bool { 561 for i := 0; i < left.Len(); i++ { 562 if left.IsNull(i) { 563 continue 564 } 565 if !opt.f32(left.Value(i), right.Value(i)) { 566 return false 567 } 568 } 569 return true 570} 571 572func arrayApproxEqualFloat64(left, right *Float64, opt equalOption) bool { 573 for i := 0; i < left.Len(); i++ { 574 if left.IsNull(i) { 575 continue 576 } 577 if !opt.f64(left.Value(i), right.Value(i)) { 578 return false 579 } 580 } 581 return true 582} 583 584func arrayApproxEqualList(left, right *List, opt equalOption) bool { 585 for i := 0; i < left.Len(); i++ { 586 if left.IsNull(i) { 587 continue 588 } 589 o := func() bool { 590 l := left.newListValue(i) 591 defer l.Release() 592 r := right.newListValue(i) 593 defer r.Release() 594 return arrayApproxEqual(l, r, opt) 595 }() 596 if !o { 597 return false 598 } 599 } 600 return true 601} 602 603func arrayApproxEqualFixedSizeList(left, right *FixedSizeList, opt equalOption) bool { 604 for i := 0; i < left.Len(); i++ { 605 if left.IsNull(i) { 606 continue 607 } 608 o := func() bool { 609 l := left.newListValue(i) 610 defer l.Release() 611 r := right.newListValue(i) 612 defer r.Release() 613 return arrayApproxEqual(l, r, opt) 614 }() 615 if !o { 616 return false 617 } 618 } 619 return true 620} 621 622func arrayApproxEqualStruct(left, right *Struct, opt equalOption) bool { 623 for i, lf := range left.fields { 624 rf := right.fields[i] 625 if !arrayApproxEqual(lf, rf, opt) { 626 return false 627 } 628 } 629 return true 630} 631