1 //============================================================================
2 //  Copyright (c) Kitware, Inc.
3 //  All rights reserved.
4 //  See LICENSE.txt for details.
5 //  This software is distributed WITHOUT ANY WARRANTY; without even
6 //  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
7 //  PURPOSE.  See the above copyright notice for more information.
8 //
9 //  Copyright 2014 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
10 //  Copyright 2014 UT-Battelle, LLC.
11 //  Copyright 2014 Los Alamos National Security.
12 //
13 //  Under the terms of Contract DE-NA0003525 with NTESS,
14 //  the U.S. Government retains certain rights in this software.
15 //
16 //  Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National
17 //  Laboratory (LANL), the U.S. Government retains certain rights in
18 //  this software.
19 //============================================================================
20 #include <vtkm/cont/ArrayCopy.h>
21 #include <vtkm/cont/ArrayHandle.h>
22 #include <vtkm/cont/DynamicArrayHandle.h>
23 
24 #include <vtkm/worklet/DispatcherMapField.h>
25 #include <vtkm/worklet/WorkletMapField.h>
26 
27 #include <vtkm/cont/testing/Testing.h>
28 
29 class TestMapFieldWorklet : public vtkm::worklet::WorkletMapField
30 {
31 public:
32   using ControlSignature = void(FieldIn<>, FieldOut<>, FieldInOut<>);
33   using ExecutionSignature = void(_1, _2, _3, WorkIndex);
34 
35   template <typename T>
operator ()(const T & in,T & out,T & inout,vtkm::Id workIndex) const36   VTKM_EXEC void operator()(const T& in, T& out, T& inout, vtkm::Id workIndex) const
37   {
38     if (!test_equal(in, TestValue(workIndex, T()) + T(100)))
39     {
40       this->RaiseError("Got wrong input value.");
41     }
42     out = in - T(100);
43     if (!test_equal(inout, TestValue(workIndex, T()) + T(100)))
44     {
45       this->RaiseError("Got wrong in-out value.");
46     }
47     inout = inout - T(100);
48   }
49 
50   template <typename T1, typename T2, typename T3>
operator ()(const T1 &,const T2 &,const T3 &,vtkm::Id) const51   VTKM_EXEC void operator()(const T1&, const T2&, const T3&, vtkm::Id) const
52   {
53     this->RaiseError("Cannot call this worklet with different types.");
54   }
55 };
56 
57 class TestMapFieldWorkletLimitedTypes : public vtkm::worklet::WorkletMapField
58 {
59 public:
60   using ControlSignature = void(FieldIn<ScalarAll>, FieldOut<ScalarAll>, FieldInOut<ScalarAll>);
61   using ExecutionSignature = _2(_1, _3, WorkIndex);
62 
63   template <typename T1, typename T3>
operator ()(const T1 & in,T3 & inout,vtkm::Id workIndex) const64   VTKM_EXEC T1 operator()(const T1& in, T3& inout, vtkm::Id workIndex) const
65   {
66     if (!test_equal(in, TestValue(workIndex, T1()) + T1(100)))
67     {
68       this->RaiseError("Got wrong input value.");
69     }
70 
71     if (!test_equal(inout, TestValue(workIndex, T3()) + T3(100)))
72     {
73       this->RaiseError("Got wrong in-out value.");
74     }
75     inout = inout - T3(100);
76 
77     return in - T1(100);
78   }
79 };
80 
81 namespace mapfield
82 {
83 static constexpr vtkm::Id ARRAY_SIZE = 10;
84 
85 template <typename WorkletType>
86 struct DoStaticTestWorklet
87 {
88   template <typename T>
operator ()mapfield::DoStaticTestWorklet89   VTKM_CONT void operator()(T) const
90   {
91     std::cout << "Set up data." << std::endl;
92     T inputArray[ARRAY_SIZE];
93 
94     for (vtkm::Id index = 0; index < ARRAY_SIZE; index++)
95     {
96       inputArray[index] = TestValue(index, T()) + T(100);
97     }
98 
99     vtkm::cont::ArrayHandle<T> inputHandle = vtkm::cont::make_ArrayHandle(inputArray, ARRAY_SIZE);
100     vtkm::cont::ArrayHandle<T> outputHandle, outputHandleAsPtr;
101     vtkm::cont::ArrayHandle<T> inoutHandle, inoutHandleAsPtr;
102 
103     vtkm::cont::ArrayCopy(inputHandle, inoutHandle, VTKM_DEFAULT_DEVICE_ADAPTER_TAG());
104     vtkm::cont::ArrayCopy(inputHandle, inoutHandleAsPtr, VTKM_DEFAULT_DEVICE_ADAPTER_TAG());
105 
106     std::cout << "Create and run dispatchers." << std::endl;
107     vtkm::worklet::DispatcherMapField<WorkletType> dispatcher;
108     dispatcher.Invoke(inputHandle, outputHandle, inoutHandle);
109     dispatcher.Invoke(&inputHandle, &outputHandleAsPtr, &inoutHandleAsPtr);
110 
111     std::cout << "Check results." << std::endl;
112     CheckPortal(outputHandle.GetPortalConstControl());
113     CheckPortal(inoutHandle.GetPortalConstControl());
114     CheckPortal(outputHandleAsPtr.GetPortalConstControl());
115     CheckPortal(inoutHandleAsPtr.GetPortalConstControl());
116 
117     std::cout << "Try to invoke with an input array of the wrong size." << std::endl;
118     inputHandle.Shrink(ARRAY_SIZE / 2);
119     bool exceptionThrown = false;
120     try
121     {
122       dispatcher.Invoke(inputHandle, outputHandle, inoutHandle);
123     }
124     catch (vtkm::cont::ErrorBadValue& error)
125     {
126       std::cout << "  Caught expected error: " << error.GetMessage() << std::endl;
127       exceptionThrown = true;
128     }
129     VTKM_TEST_ASSERT(exceptionThrown, "Dispatcher did not throw expected exception.");
130   }
131 };
132 
133 template <typename WorkletType>
134 struct DoDynamicTestWorklet
135 {
136   template <typename T>
operator ()mapfield::DoDynamicTestWorklet137   VTKM_CONT void operator()(T) const
138   {
139     std::cout << "Set up data." << std::endl;
140     T inputArray[ARRAY_SIZE];
141 
142     for (vtkm::Id index = 0; index < ARRAY_SIZE; index++)
143     {
144       inputArray[index] = TestValue(index, T()) + T(100);
145     }
146 
147     vtkm::cont::ArrayHandle<T> inputHandle = vtkm::cont::make_ArrayHandle(inputArray, ARRAY_SIZE);
148     vtkm::cont::ArrayHandle<T> outputHandle;
149     vtkm::cont::ArrayHandle<T> inoutHandle;
150 
151 
152     std::cout << "Create and run dispatcher with dynamic arrays." << std::endl;
153     vtkm::worklet::DispatcherMapField<WorkletType> dispatcher;
154 
155     vtkm::cont::DynamicArrayHandle inputDynamic(inputHandle);
156 
157     { //Verify we can pass by value
158       vtkm::cont::ArrayCopy(inputHandle, inoutHandle, VTKM_DEFAULT_DEVICE_ADAPTER_TAG());
159       vtkm::cont::DynamicArrayHandle outputDynamic(outputHandle);
160       vtkm::cont::DynamicArrayHandle inoutDynamic(inoutHandle);
161       dispatcher.Invoke(inputDynamic, outputDynamic, inoutDynamic);
162       CheckPortal(outputHandle.GetPortalConstControl());
163       CheckPortal(inoutHandle.GetPortalConstControl());
164     }
165 
166     { //Verify we can pass by pointer
167       vtkm::cont::ArrayCopy(inputHandle, inoutHandle, VTKM_DEFAULT_DEVICE_ADAPTER_TAG());
168       vtkm::cont::DynamicArrayHandle outputDynamic(outputHandle);
169       vtkm::cont::DynamicArrayHandle inoutDynamic(inoutHandle);
170       dispatcher.Invoke(&inputDynamic, &outputDynamic, &inoutDynamic);
171       CheckPortal(outputHandle.GetPortalConstControl());
172       CheckPortal(inoutHandle.GetPortalConstControl());
173     }
174   }
175 };
176 
177 template <typename WorkletType>
178 struct DoTestWorklet
179 {
180   template <typename T>
operator ()mapfield::DoTestWorklet181   VTKM_CONT void operator()(T t) const
182   {
183     DoStaticTestWorklet<WorkletType> sw;
184     sw(t);
185     DoDynamicTestWorklet<WorkletType> dw;
186     dw(t);
187   }
188 };
189 
TestWorkletMapField()190 void TestWorkletMapField()
191 {
192   using DeviceAdapterTraits = vtkm::cont::DeviceAdapterTraits<VTKM_DEFAULT_DEVICE_ADAPTER_TAG>;
193   std::cout << "Testing Map Field on device adapter: " << DeviceAdapterTraits::GetName()
194             << std::endl;
195 
196   std::cout << "--- Worklet accepting all types." << std::endl;
197   vtkm::testing::Testing::TryTypes(mapfield::DoTestWorklet<TestMapFieldWorklet>(),
198                                    vtkm::TypeListTagCommon());
199 
200   std::cout << "--- Worklet accepting some types." << std::endl;
201   vtkm::testing::Testing::TryTypes(mapfield::DoTestWorklet<TestMapFieldWorkletLimitedTypes>(),
202                                    vtkm::TypeListTagFieldScalar());
203 
204   std::cout << "--- Sending bad type to worklet." << std::endl;
205   try
206   {
207     //can only test with dynamic arrays, as static arrays will fail to compile
208     DoDynamicTestWorklet<TestMapFieldWorkletLimitedTypes> badWorkletTest;
209     badWorkletTest(vtkm::Vec<vtkm::Float32, 3>());
210     VTKM_TEST_FAIL("Did not throw expected error.");
211   }
212   catch (vtkm::cont::ErrorBadType& error)
213   {
214     std::cout << "Got expected error: " << error.GetMessage() << std::endl;
215   }
216 }
217 
218 } // mapfield namespace
219 
UnitTestWorkletMapField(int,char * [])220 int UnitTestWorkletMapField(int, char* [])
221 {
222   return vtkm::cont::testing::Testing::Run(mapfield::TestWorkletMapField);
223 }
224