1 //============================================================================
2 // Copyright (c) Kitware, Inc.
3 // All rights reserved.
4 // See LICENSE.txt for details.
5 // This software is distributed WITHOUT ANY WARRANTY; without even
6 // the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
7 // PURPOSE. See the above copyright notice for more information.
8 //
9 // Copyright 2014 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
10 // Copyright 2014 UT-Battelle, LLC.
11 // Copyright 2014 Los Alamos National Security.
12 //
13 // Under the terms of Contract DE-NA0003525 with NTESS,
14 // the U.S. Government retains certain rights in this software.
15 //
16 // Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National
17 // Laboratory (LANL), the U.S. Government retains certain rights in
18 // this software.
19 //============================================================================
20 #include <vtkm/cont/ArrayCopy.h>
21 #include <vtkm/cont/ArrayHandle.h>
22 #include <vtkm/cont/DynamicArrayHandle.h>
23
24 #include <vtkm/worklet/DispatcherMapField.h>
25 #include <vtkm/worklet/WorkletMapField.h>
26
27 #include <vtkm/cont/testing/Testing.h>
28
29 class TestMapFieldWorklet : public vtkm::worklet::WorkletMapField
30 {
31 public:
32 using ControlSignature = void(FieldIn<>, FieldOut<>, FieldInOut<>);
33 using ExecutionSignature = void(_1, _2, _3, WorkIndex);
34
35 template <typename T>
operator ()(const T & in,T & out,T & inout,vtkm::Id workIndex) const36 VTKM_EXEC void operator()(const T& in, T& out, T& inout, vtkm::Id workIndex) const
37 {
38 if (!test_equal(in, TestValue(workIndex, T()) + T(100)))
39 {
40 this->RaiseError("Got wrong input value.");
41 }
42 out = in - T(100);
43 if (!test_equal(inout, TestValue(workIndex, T()) + T(100)))
44 {
45 this->RaiseError("Got wrong in-out value.");
46 }
47 inout = inout - T(100);
48 }
49
50 template <typename T1, typename T2, typename T3>
operator ()(const T1 &,const T2 &,const T3 &,vtkm::Id) const51 VTKM_EXEC void operator()(const T1&, const T2&, const T3&, vtkm::Id) const
52 {
53 this->RaiseError("Cannot call this worklet with different types.");
54 }
55 };
56
57 class TestMapFieldWorkletLimitedTypes : public vtkm::worklet::WorkletMapField
58 {
59 public:
60 using ControlSignature = void(FieldIn<ScalarAll>, FieldOut<ScalarAll>, FieldInOut<ScalarAll>);
61 using ExecutionSignature = _2(_1, _3, WorkIndex);
62
63 template <typename T1, typename T3>
operator ()(const T1 & in,T3 & inout,vtkm::Id workIndex) const64 VTKM_EXEC T1 operator()(const T1& in, T3& inout, vtkm::Id workIndex) const
65 {
66 if (!test_equal(in, TestValue(workIndex, T1()) + T1(100)))
67 {
68 this->RaiseError("Got wrong input value.");
69 }
70
71 if (!test_equal(inout, TestValue(workIndex, T3()) + T3(100)))
72 {
73 this->RaiseError("Got wrong in-out value.");
74 }
75 inout = inout - T3(100);
76
77 return in - T1(100);
78 }
79 };
80
81 namespace mapfield
82 {
83 static constexpr vtkm::Id ARRAY_SIZE = 10;
84
85 template <typename WorkletType>
86 struct DoStaticTestWorklet
87 {
88 template <typename T>
operator ()mapfield::DoStaticTestWorklet89 VTKM_CONT void operator()(T) const
90 {
91 std::cout << "Set up data." << std::endl;
92 T inputArray[ARRAY_SIZE];
93
94 for (vtkm::Id index = 0; index < ARRAY_SIZE; index++)
95 {
96 inputArray[index] = TestValue(index, T()) + T(100);
97 }
98
99 vtkm::cont::ArrayHandle<T> inputHandle = vtkm::cont::make_ArrayHandle(inputArray, ARRAY_SIZE);
100 vtkm::cont::ArrayHandle<T> outputHandle, outputHandleAsPtr;
101 vtkm::cont::ArrayHandle<T> inoutHandle, inoutHandleAsPtr;
102
103 vtkm::cont::ArrayCopy(inputHandle, inoutHandle, VTKM_DEFAULT_DEVICE_ADAPTER_TAG());
104 vtkm::cont::ArrayCopy(inputHandle, inoutHandleAsPtr, VTKM_DEFAULT_DEVICE_ADAPTER_TAG());
105
106 std::cout << "Create and run dispatchers." << std::endl;
107 vtkm::worklet::DispatcherMapField<WorkletType> dispatcher;
108 dispatcher.Invoke(inputHandle, outputHandle, inoutHandle);
109 dispatcher.Invoke(&inputHandle, &outputHandleAsPtr, &inoutHandleAsPtr);
110
111 std::cout << "Check results." << std::endl;
112 CheckPortal(outputHandle.GetPortalConstControl());
113 CheckPortal(inoutHandle.GetPortalConstControl());
114 CheckPortal(outputHandleAsPtr.GetPortalConstControl());
115 CheckPortal(inoutHandleAsPtr.GetPortalConstControl());
116
117 std::cout << "Try to invoke with an input array of the wrong size." << std::endl;
118 inputHandle.Shrink(ARRAY_SIZE / 2);
119 bool exceptionThrown = false;
120 try
121 {
122 dispatcher.Invoke(inputHandle, outputHandle, inoutHandle);
123 }
124 catch (vtkm::cont::ErrorBadValue& error)
125 {
126 std::cout << " Caught expected error: " << error.GetMessage() << std::endl;
127 exceptionThrown = true;
128 }
129 VTKM_TEST_ASSERT(exceptionThrown, "Dispatcher did not throw expected exception.");
130 }
131 };
132
133 template <typename WorkletType>
134 struct DoDynamicTestWorklet
135 {
136 template <typename T>
operator ()mapfield::DoDynamicTestWorklet137 VTKM_CONT void operator()(T) const
138 {
139 std::cout << "Set up data." << std::endl;
140 T inputArray[ARRAY_SIZE];
141
142 for (vtkm::Id index = 0; index < ARRAY_SIZE; index++)
143 {
144 inputArray[index] = TestValue(index, T()) + T(100);
145 }
146
147 vtkm::cont::ArrayHandle<T> inputHandle = vtkm::cont::make_ArrayHandle(inputArray, ARRAY_SIZE);
148 vtkm::cont::ArrayHandle<T> outputHandle;
149 vtkm::cont::ArrayHandle<T> inoutHandle;
150
151
152 std::cout << "Create and run dispatcher with dynamic arrays." << std::endl;
153 vtkm::worklet::DispatcherMapField<WorkletType> dispatcher;
154
155 vtkm::cont::DynamicArrayHandle inputDynamic(inputHandle);
156
157 { //Verify we can pass by value
158 vtkm::cont::ArrayCopy(inputHandle, inoutHandle, VTKM_DEFAULT_DEVICE_ADAPTER_TAG());
159 vtkm::cont::DynamicArrayHandle outputDynamic(outputHandle);
160 vtkm::cont::DynamicArrayHandle inoutDynamic(inoutHandle);
161 dispatcher.Invoke(inputDynamic, outputDynamic, inoutDynamic);
162 CheckPortal(outputHandle.GetPortalConstControl());
163 CheckPortal(inoutHandle.GetPortalConstControl());
164 }
165
166 { //Verify we can pass by pointer
167 vtkm::cont::ArrayCopy(inputHandle, inoutHandle, VTKM_DEFAULT_DEVICE_ADAPTER_TAG());
168 vtkm::cont::DynamicArrayHandle outputDynamic(outputHandle);
169 vtkm::cont::DynamicArrayHandle inoutDynamic(inoutHandle);
170 dispatcher.Invoke(&inputDynamic, &outputDynamic, &inoutDynamic);
171 CheckPortal(outputHandle.GetPortalConstControl());
172 CheckPortal(inoutHandle.GetPortalConstControl());
173 }
174 }
175 };
176
177 template <typename WorkletType>
178 struct DoTestWorklet
179 {
180 template <typename T>
operator ()mapfield::DoTestWorklet181 VTKM_CONT void operator()(T t) const
182 {
183 DoStaticTestWorklet<WorkletType> sw;
184 sw(t);
185 DoDynamicTestWorklet<WorkletType> dw;
186 dw(t);
187 }
188 };
189
TestWorkletMapField()190 void TestWorkletMapField()
191 {
192 using DeviceAdapterTraits = vtkm::cont::DeviceAdapterTraits<VTKM_DEFAULT_DEVICE_ADAPTER_TAG>;
193 std::cout << "Testing Map Field on device adapter: " << DeviceAdapterTraits::GetName()
194 << std::endl;
195
196 std::cout << "--- Worklet accepting all types." << std::endl;
197 vtkm::testing::Testing::TryTypes(mapfield::DoTestWorklet<TestMapFieldWorklet>(),
198 vtkm::TypeListTagCommon());
199
200 std::cout << "--- Worklet accepting some types." << std::endl;
201 vtkm::testing::Testing::TryTypes(mapfield::DoTestWorklet<TestMapFieldWorkletLimitedTypes>(),
202 vtkm::TypeListTagFieldScalar());
203
204 std::cout << "--- Sending bad type to worklet." << std::endl;
205 try
206 {
207 //can only test with dynamic arrays, as static arrays will fail to compile
208 DoDynamicTestWorklet<TestMapFieldWorkletLimitedTypes> badWorkletTest;
209 badWorkletTest(vtkm::Vec<vtkm::Float32, 3>());
210 VTKM_TEST_FAIL("Did not throw expected error.");
211 }
212 catch (vtkm::cont::ErrorBadType& error)
213 {
214 std::cout << "Got expected error: " << error.GetMessage() << std::endl;
215 }
216 }
217
218 } // mapfield namespace
219
UnitTestWorkletMapField(int,char * [])220 int UnitTestWorkletMapField(int, char* [])
221 {
222 return vtkm::cont::testing::Testing::Run(mapfield::TestWorkletMapField);
223 }
224