1 //============================================================================
2 // Copyright (c) Kitware, Inc.
3 // All rights reserved.
4 // See LICENSE.txt for details.
5 //
6 // This software is distributed WITHOUT ANY WARRANTY; without even
7 // the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
8 // PURPOSE. See the above copyright notice for more information.
9 //============================================================================
10
#include <vtkm/cont/ArrayCopy.h>
#include <vtkm/cont/testing/Testing.h>
#include <vtkm/worklet/connectivities/InnerJoin.h>

#include <algorithm>
#include <vector>
14
15
16 class TestInnerJoin
17 {
18 public:
TestJoinedValues(const vtkm::cont::ArrayHandle<vtkm::Id> & computedValuesArray,const vtkm::cont::ArrayHandle<vtkm::Id> & expectedValuesArray,const vtkm::cont::ArrayHandle<vtkm::Id> & originalKeysArray)19 static bool TestJoinedValues(const vtkm::cont::ArrayHandle<vtkm::Id>& computedValuesArray,
20 const vtkm::cont::ArrayHandle<vtkm::Id>& expectedValuesArray,
21 const vtkm::cont::ArrayHandle<vtkm::Id>& originalKeysArray)
22 {
23 auto computedValues = computedValuesArray.ReadPortal();
24 auto expectedValues = expectedValuesArray.ReadPortal();
25 auto originalKeys = originalKeysArray.ReadPortal();
26 if (computedValues.GetNumberOfValues() != expectedValues.GetNumberOfValues())
27 {
28 return false;
29 }
30
31 for (vtkm::Id valueIndex = 0; valueIndex < computedValues.GetNumberOfValues(); ++valueIndex)
32 {
33 vtkm::Id computed = computedValues.Get(valueIndex);
34 vtkm::Id expected = expectedValues.Get(valueIndex);
35
36 // The join algorithm uses some key/value sorts that are unstable. Thus, for keys
37 // that are repeated in the original input, the computed and expected values may be
38 // swapped in the results associated with those keys. To test correctly, the values
39 // we computed for are actually indices into the original keys array. Thus, if both
40 // computed and expected are different indices that point to the same original key,
41 // then the algorithm is still correct.
42 vtkm::Id computedKey = originalKeys.Get(computed);
43 vtkm::Id expectedKey = originalKeys.Get(expected);
44 if (computedKey != expectedKey)
45 {
46 return false;
47 }
48 }
49
50 return true;
51 }
52
TestTwoArrays() const53 void TestTwoArrays() const
54 {
55 vtkm::cont::ArrayHandle<vtkm::Id> keysAOriginal =
56 vtkm::cont::make_ArrayHandle<vtkm::Id>({ 8, 3, 6, 8, 9, 5, 12, 10, 14 });
57 vtkm::cont::ArrayHandle<vtkm::Id> keysBOriginal =
58 vtkm::cont::make_ArrayHandle<vtkm::Id>({ 7, 11, 9, 8, 5, 1, 0, 5 });
59
60 vtkm::cont::ArrayHandle<vtkm::Id> keysA;
61 vtkm::cont::ArrayHandle<vtkm::Id> keysB;
62 vtkm::cont::ArrayHandle<vtkm::Id> valuesA;
63 vtkm::cont::ArrayHandle<vtkm::Id> valuesB;
64
65 vtkm::cont::ArrayCopy(keysAOriginal, keysA);
66 vtkm::cont::ArrayCopy(keysBOriginal, keysB);
67 vtkm::cont::ArrayCopy(vtkm::cont::ArrayHandleIndex(keysA.GetNumberOfValues()), valuesA);
68 vtkm::cont::ArrayCopy(vtkm::cont::ArrayHandleIndex(keysB.GetNumberOfValues()), valuesB);
69
70 vtkm::cont::ArrayHandle<vtkm::Id> joinedIndex;
71 vtkm::cont::ArrayHandle<vtkm::Id> outA;
72 vtkm::cont::ArrayHandle<vtkm::Id> outB;
73
74 vtkm::worklet::connectivity::InnerJoin().Run(
75 keysA, valuesA, keysB, valuesB, joinedIndex, outA, outB);
76
77 vtkm::cont::ArrayHandle<vtkm::Id> expectedIndex =
78 vtkm::cont::make_ArrayHandle<vtkm::Id>({ 5, 5, 8, 8, 9 });
79 VTKM_TEST_ASSERT(test_equal_portals(joinedIndex.ReadPortal(), expectedIndex.ReadPortal()));
80
81 vtkm::cont::ArrayHandle<vtkm::Id> expectedOutA =
82 vtkm::cont::make_ArrayHandle<vtkm::Id>({ 5, 5, 0, 3, 4 });
83 VTKM_TEST_ASSERT(TestJoinedValues(outA, expectedOutA, keysAOriginal));
84
85 vtkm::cont::ArrayHandle<vtkm::Id> expectedOutB =
86 vtkm::cont::make_ArrayHandle<vtkm::Id>({ 4, 7, 3, 3, 2 });
87 VTKM_TEST_ASSERT(TestJoinedValues(outB, expectedOutB, keysBOriginal));
88 }
89
operator ()() const90 void operator()() const { this->TestTwoArrays(); }
91 };
92
UnitTestInnerJoin(int argc,char * argv[])93 int UnitTestInnerJoin(int argc, char* argv[])
94 {
95 return vtkm::cont::testing::Testing::Run(TestInnerJoin(), argc, argv);
96 }
97