1 //============================================================================
2 //  Copyright (c) Kitware, Inc.
3 //  All rights reserved.
4 //  See LICENSE.txt for details.
5 //
6 //  This software is distributed WITHOUT ANY WARRANTY; without even
7 //  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
8 //  PURPOSE.  See the above copyright notice for more information.
9 //============================================================================
10 
11 #include <vtkm/worklet/DispatcherMapField.h>
12 #include <vtkm/worklet/DotProduct.h>
13 
14 #include <vtkm/cont/testing/Testing.h>
15 
16 namespace
17 {
18 
19 template <typename T>
normalizedVector(T v)20 T normalizedVector(T v)
21 {
22   T vN = vtkm::Normal(v);
23   return vN;
24 }
25 
26 template <typename T>
createVectors(std::vector<vtkm::Vec<T,3>> & vecs1,std::vector<vtkm::Vec<T,3>> & vecs2,std::vector<T> & result)27 void createVectors(std::vector<vtkm::Vec<T, 3>>& vecs1,
28                    std::vector<vtkm::Vec<T, 3>>& vecs2,
29                    std::vector<T>& result)
30 {
31   vecs1.push_back(normalizedVector(vtkm::make_Vec(T(1), T(0), T(0))));
32   vecs2.push_back(normalizedVector(vtkm::make_Vec(T(1), T(0), T(0))));
33   result.push_back(1);
34 
35   vecs1.push_back(normalizedVector(vtkm::make_Vec(T(1), T(0), T(0))));
36   vecs2.push_back(normalizedVector(vtkm::make_Vec(T(-1), T(0), T(0))));
37   result.push_back(-1);
38 
39   vecs1.push_back(normalizedVector(vtkm::make_Vec(T(1), T(0), T(0))));
40   vecs2.push_back(normalizedVector(vtkm::make_Vec(T(0), T(1), T(0))));
41   result.push_back(0);
42 
43   vecs1.push_back(normalizedVector(vtkm::make_Vec(T(1), T(0), T(0))));
44   vecs2.push_back(normalizedVector(vtkm::make_Vec(T(0), T(-1), T(0))));
45   result.push_back(0);
46 
47   vecs1.push_back(normalizedVector(vtkm::make_Vec(T(1), T(0), T(0))));
48   vecs2.push_back(normalizedVector(vtkm::make_Vec(T(1), T(1), T(0))));
49   result.push_back(T(1.0 / vtkm::Sqrt(2.0)));
50 
51   vecs1.push_back(normalizedVector(vtkm::make_Vec(T(1), T(1), T(0))));
52   vecs2.push_back(normalizedVector(vtkm::make_Vec(T(1), T(0), T(0))));
53   result.push_back(T(1.0 / vtkm::Sqrt(2.0)));
54 
55   vecs1.push_back(normalizedVector(vtkm::make_Vec(T(-1), T(0), T(0))));
56   vecs2.push_back(normalizedVector(vtkm::make_Vec(T(1), T(1), T(0))));
57   result.push_back(-T(1.0 / vtkm::Sqrt(2.0)));
58 
59   vecs1.push_back(normalizedVector(vtkm::make_Vec(T(0), T(1), T(0))));
60   vecs2.push_back(normalizedVector(vtkm::make_Vec(T(1), T(1), T(0))));
61   result.push_back(T(1.0 / vtkm::Sqrt(2.0)));
62 }
63 
64 template <typename T>
TestDotProduct()65 void TestDotProduct()
66 {
67   std::vector<vtkm::Vec<T, 3>> inputVecs1, inputVecs2;
68   std::vector<T> answer;
69   createVectors(inputVecs1, inputVecs2, answer);
70 
71   vtkm::cont::ArrayHandle<vtkm::Vec<T, 3>> inputArray1, inputArray2;
72   vtkm::cont::ArrayHandle<T> outputArray;
73   inputArray1 = vtkm::cont::make_ArrayHandle(inputVecs1, vtkm::CopyFlag::Off);
74   inputArray2 = vtkm::cont::make_ArrayHandle(inputVecs2, vtkm::CopyFlag::Off);
75 
76   vtkm::worklet::DotProduct dotProductWorklet;
77   vtkm::worklet::DispatcherMapField<vtkm::worklet::DotProduct> dispatcherDotProduct(
78     dotProductWorklet);
79   dispatcherDotProduct.Invoke(inputArray1, inputArray2, outputArray);
80 
81   VTKM_TEST_ASSERT(outputArray.GetNumberOfValues() == inputArray1.GetNumberOfValues(),
82                    "Wrong number of results for DotProduct worklet");
83 
84   for (vtkm::Id i = 0; i < inputArray1.GetNumberOfValues(); i++)
85   {
86     vtkm::Vec<T, 3> v1 = inputArray1.ReadPortal().Get(i);
87     vtkm::Vec<T, 3> v2 = inputArray2.ReadPortal().Get(i);
88     T ans = answer[static_cast<std::size_t>(i)];
89 
90     VTKM_TEST_ASSERT(test_equal(ans, vtkm::Dot(v1, v2)), "Wrong result for dot product");
91   }
92 }
93 
TestDotProductWorklets()94 void TestDotProductWorklets()
95 {
96   std::cout << "Testing DotProduct Worklet" << std::endl;
97   TestDotProduct<vtkm::Float32>();
98   //  TestDotProduct<vtkm::Float64>();
99 }
100 }
101 
UnitTestDotProduct(int argc,char * argv[])102 int UnitTestDotProduct(int argc, char* argv[])
103 {
104   return vtkm::cont::testing::Testing::Run(TestDotProductWorklets, argc, argv);
105 }
106