1 //============================================================================
2 // Copyright (c) Kitware, Inc.
3 // All rights reserved.
4 // See LICENSE.txt for details.
5 // This software is distributed WITHOUT ANY WARRANTY; without even
6 // the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
7 // PURPOSE. See the above copyright notice for more information.
8 //
9 // Copyright 2014 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
10 // Copyright 2014 UT-Battelle, LLC.
11 // Copyright 2014 Los Alamos National Security.
12 //
13 // Under the terms of Contract DE-NA0003525 with NTESS,
14 // the U.S. Government retains certain rights in this software.
15 //
16 // Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National
17 // Laboratory (LANL), the U.S. Government retains certain rights in
18 // this software.
19 //============================================================================
20
21 #include <vtkm/worklet/internal/DispatcherBase.h>
22
23 #include <vtkm/worklet/internal/WorkletBase.h>
24
25 #include <vtkm/cont/testing/Testing.h>
26
27 namespace
28 {
29
30 static constexpr vtkm::Id ARRAY_SIZE = 10;
31
32 struct TestExecObjectIn
33 {
34 VTKM_EXEC_CONT
TestExecObjectIn__anonc10a45f30111::TestExecObjectIn35 TestExecObjectIn()
36 : Array(nullptr)
37 {
38 }
39
40 VTKM_EXEC_CONT
TestExecObjectIn__anonc10a45f30111::TestExecObjectIn41 TestExecObjectIn(const vtkm::Id* array)
42 : Array(array)
43 {
44 }
45
46 const vtkm::Id* Array;
47 };
48
49 struct TestExecObjectOut
50 {
51 VTKM_EXEC_CONT
TestExecObjectOut__anonc10a45f30111::TestExecObjectOut52 TestExecObjectOut()
53 : Array(nullptr)
54 {
55 }
56
57 VTKM_EXEC_CONT
TestExecObjectOut__anonc10a45f30111::TestExecObjectOut58 TestExecObjectOut(vtkm::Id* array)
59 : Array(array)
60 {
61 }
62
63 vtkm::Id* Array;
64 };
65
66 template <typename Device>
67 struct ExecutionObject
68 {
69 vtkm::Id Value;
70 };
71
72 struct TestExecObjectType : vtkm::cont::ExecutionObjectBase
73 {
74 template <typename Functor, typename... Args>
CastAndCall__anonc10a45f30111::TestExecObjectType75 void CastAndCall(Functor f, Args&&... args) const
76 {
77 f(*this, std::forward<Args>(args)...);
78 }
79 template <typename Device>
PrepareForExecution__anonc10a45f30111::TestExecObjectType80 VTKM_CONT ExecutionObject<Device> PrepareForExecution(Device) const
81 {
82 ExecutionObject<Device> object;
83 object.Value = this->Value;
84 return object;
85 }
86 vtkm::Id Value;
87 };
88
// Deliberately invalid execution-object argument: it does not inherit from
// vtkm::cont::ExecutionObjectBase, so invoking a dispatcher with it is
// expected to fail (see TestInvokeWithBadDynamicType).
struct TestExecObjectTypeBad
{
  // Invokes the functor with this object followed by the forwarded arguments.
  template <typename Functor, typename... Args>
  void CastAndCall(Functor functor, Args&&... args) const
  {
    functor(*this, std::forward<Args>(args)...);
  }
};
97
// Empty tag types used to select the test-specific TypeCheck, Transport and
// Fetch specializations defined further down in this file.

// Selects the TypeCheck specialization that accepts std::vector<vtkm::Id>.
struct TestTypeCheckTag {};

// Selects the transport that yields a TestExecObjectIn.
struct TestTransportTagIn {};

// Selects the transport that yields a TestExecObjectOut.
struct TestTransportTagOut {};

// Selects the fetch that reads from a TestExecObjectIn.
struct TestFetchTagInput {};

// Selects the fetch that writes to a TestExecObjectOut.
struct TestFetchTagOutput {};
113
114 } // anonymous namespace
115
116 namespace vtkm
117 {
118 namespace cont
119 {
120 namespace arg
121 {
122
123 template <>
124 struct TypeCheck<TestTypeCheckTag, std::vector<vtkm::Id>>
125 {
126 static constexpr bool value = true;
127 };
128
129 template <typename Device>
130 struct Transport<TestTransportTagIn, std::vector<vtkm::Id>, Device>
131 {
132 using ExecObjectType = TestExecObjectIn;
133
134 VTKM_CONT
operator ()vtkm::cont::arg::Transport135 ExecObjectType operator()(const std::vector<vtkm::Id>& contData,
136 const std::vector<vtkm::Id>&,
137 vtkm::Id inputRange,
138 vtkm::Id outputRange) const
139 {
140 VTKM_TEST_ASSERT(inputRange == ARRAY_SIZE, "Got unexpected size in test transport.");
141 VTKM_TEST_ASSERT(outputRange == ARRAY_SIZE, "Got unexpected size in test transport.");
142 return ExecObjectType(contData.data());
143 }
144 };
145
146 template <typename Device>
147 struct Transport<TestTransportTagOut, std::vector<vtkm::Id>, Device>
148 {
149 using ExecObjectType = TestExecObjectOut;
150
151 VTKM_CONT
operator ()vtkm::cont::arg::Transport152 ExecObjectType operator()(const std::vector<vtkm::Id>& contData,
153 const std::vector<vtkm::Id>&,
154 vtkm::Id inputRange,
155 vtkm::Id outputRange) const
156 {
157 VTKM_TEST_ASSERT(inputRange == ARRAY_SIZE, "Got unexpected size in test transport.");
158 VTKM_TEST_ASSERT(outputRange == ARRAY_SIZE, "Got unexpected size in test transport.");
159 auto ptr = const_cast<vtkm::Id*>(contData.data());
160 return ExecObjectType(ptr);
161 }
162 };
163 }
164 }
165 } // namespace vtkm::cont::arg
166
167 namespace vtkm
168 {
169 namespace cont
170 {
171 namespace internal
172 {
173
174 template <>
175 struct DynamicTransformTraits<TestExecObjectType>
176 {
177 using DynamicTag = vtkm::cont::internal::DynamicTransformTagCastAndCall;
178 };
179 template <>
180 struct DynamicTransformTraits<TestExecObjectTypeBad>
181 {
182 using DynamicTag = vtkm::cont::internal::DynamicTransformTagCastAndCall;
183 };
184 }
185 }
186 } // namespace vtkm::cont::internal
187
188 namespace vtkm
189 {
190 namespace exec
191 {
192 namespace arg
193 {
194
195 template <>
196 struct Fetch<TestFetchTagInput,
197 vtkm::exec::arg::AspectTagDefault,
198 vtkm::exec::arg::ThreadIndicesBasic,
199 TestExecObjectIn>
200 {
201 using ValueType = vtkm::Id;
202
203 VTKM_EXEC
Loadvtkm::exec::arg::Fetch204 ValueType Load(const vtkm::exec::arg::ThreadIndicesBasic indices,
205 const TestExecObjectIn& execObject) const
206 {
207 return execObject.Array[indices.GetInputIndex()];
208 }
209
210 VTKM_EXEC
Storevtkm::exec::arg::Fetch211 void Store(const vtkm::exec::arg::ThreadIndicesBasic, const TestExecObjectIn&, ValueType) const
212 {
213 // No-op
214 }
215 };
216
217 template <>
218 struct Fetch<TestFetchTagOutput,
219 vtkm::exec::arg::AspectTagDefault,
220 vtkm::exec::arg::ThreadIndicesBasic,
221 TestExecObjectOut>
222 {
223 using ValueType = vtkm::Id;
224
225 VTKM_EXEC
Loadvtkm::exec::arg::Fetch226 ValueType Load(const vtkm::exec::arg::ThreadIndicesBasic&, const TestExecObjectOut&) const
227 {
228 // No-op
229 return ValueType();
230 }
231
232 VTKM_EXEC
Storevtkm::exec::arg::Fetch233 void Store(const vtkm::exec::arg::ThreadIndicesBasic& indices,
234 const TestExecObjectOut& execObject,
235 ValueType value) const
236 {
237 execObject.Array[indices.GetOutputIndex()] = value;
238 }
239 };
240 }
241 }
242 } // vtkm::exec::arg
243
244 namespace
245 {
246
247 static constexpr vtkm::Id EXPECTED_EXEC_OBJECT_VALUE = 123;
248
249 class TestWorkletBase : public vtkm::worklet::internal::WorkletBase
250 {
251 public:
252 struct TestIn : vtkm::cont::arg::ControlSignatureTagBase
253 {
254 using TypeCheckTag = TestTypeCheckTag;
255 using TransportTag = TestTransportTagIn;
256 using FetchTag = TestFetchTagInput;
257 };
258 struct TestOut : vtkm::cont::arg::ControlSignatureTagBase
259 {
260 using TypeCheckTag = TestTypeCheckTag;
261 using TransportTag = TestTransportTagOut;
262 using FetchTag = TestFetchTagOutput;
263 };
264 };
265
266 class TestWorklet : public TestWorkletBase
267 {
268 public:
269 using ControlSignature = void(TestIn, ExecObject, TestOut);
270 using ExecutionSignature = _3(_1, _2, WorkIndex);
271
272 template <typename ExecObjectType>
operator ()(vtkm::Id value,ExecObjectType execObject,vtkm::Id index) const273 VTKM_EXEC vtkm::Id operator()(vtkm::Id value, ExecObjectType execObject, vtkm::Id index) const
274 {
275 VTKM_TEST_ASSERT(value == TestValue(index, vtkm::Id()), "Got bad value in worklet.");
276 VTKM_TEST_ASSERT(execObject.Value == EXPECTED_EXEC_OBJECT_VALUE,
277 "Got bad exec object in worklet.");
278 return TestValue(index, vtkm::Id()) + 1000;
279 }
280 };
281
282 #define ERROR_MESSAGE "Expected worklet error."
283
284 class TestErrorWorklet : public TestWorkletBase
285 {
286 public:
287 using ControlSignature = void(TestIn, ExecObject, TestOut);
288 using ExecutionSignature = void(_1, _2, _3);
289
290 template <typename ExecObjectType>
operator ()(vtkm::Id,ExecObjectType,vtkm::Id) const291 VTKM_EXEC void operator()(vtkm::Id, ExecObjectType, vtkm::Id) const
292 {
293 this->RaiseError(ERROR_MESSAGE);
294 }
295 };
296
297 template <typename WorkletType>
298 class TestDispatcher : public vtkm::worklet::internal::DispatcherBase<TestDispatcher<WorkletType>,
299 WorkletType,
300 TestWorkletBase>
301 {
302 using Superclass = vtkm::worklet::internal::DispatcherBase<TestDispatcher<WorkletType>,
303 WorkletType,
304 TestWorkletBase>;
305 using ScatterType = typename Superclass::ScatterType;
306
307 public:
308 VTKM_CONT
TestDispatcher(const WorkletType & worklet=WorkletType (),const ScatterType & scatter=ScatterType ())309 TestDispatcher(const WorkletType& worklet = WorkletType(),
310 const ScatterType& scatter = ScatterType())
311 : Superclass(worklet, scatter)
312 {
313 }
314
315 VTKM_CONT
316 template <typename Invocation>
DoInvoke(Invocation && invocation) const317 void DoInvoke(Invocation&& invocation) const
318 {
319 std::cout << "In TestDispatcher::DoInvoke()" << std::endl;
320 this->BasicInvoke(invocation, ARRAY_SIZE);
321 }
322
323 private:
324 WorkletType Worklet;
325 };
326
TestBasicInvoke()327 void TestBasicInvoke()
328 {
329 std::cout << "Test basic invoke" << std::endl;
330 std::cout << " Set up data." << std::endl;
331 std::vector<vtkm::Id> inputArray(ARRAY_SIZE);
332 std::vector<vtkm::Id> outputArray(ARRAY_SIZE);
333 TestExecObjectType execObject;
334 execObject.Value = EXPECTED_EXEC_OBJECT_VALUE;
335
336 std::size_t i = 0;
337 for (vtkm::Id index = 0; index < ARRAY_SIZE; index++, i++)
338 {
339 inputArray[i] = TestValue(index, vtkm::Id());
340 outputArray[i] = static_cast<vtkm::Id>(0xDEADDEAD);
341 }
342
343 std::cout << " Create and run dispatcher." << std::endl;
344 TestDispatcher<TestWorklet> dispatcher;
345 dispatcher.Invoke(inputArray, execObject, &outputArray);
346
347 std::cout << " Check output of invoke." << std::endl;
348 i = 0;
349 for (vtkm::Id index = 0; index < ARRAY_SIZE; index++, i++)
350 {
351 VTKM_TEST_ASSERT(outputArray[i] == TestValue(index, vtkm::Id()) + 1000,
352 "Got bad value from testing.");
353 }
354 }
355
TestInvokeWithError()356 void TestInvokeWithError()
357 {
358 std::cout << "Test invoke with error raised" << std::endl;
359 std::cout << " Set up data." << std::endl;
360 std::vector<vtkm::Id> inputArray(ARRAY_SIZE);
361 std::vector<vtkm::Id> outputArray(ARRAY_SIZE);
362 TestExecObjectType execObject;
363 execObject.Value = EXPECTED_EXEC_OBJECT_VALUE;
364
365 std::size_t i = 0;
366 for (vtkm::Id index = 0; index < ARRAY_SIZE; index++, ++i)
367 {
368 inputArray[i] = TestValue(index, vtkm::Id());
369 outputArray[i] = static_cast<vtkm::Id>(0xDEADDEAD);
370 }
371
372 try
373 {
374 std::cout << " Create and run dispatcher that raises error." << std::endl;
375 TestDispatcher<TestErrorWorklet> dispatcher;
376 dispatcher.Invoke(&inputArray, execObject, outputArray);
377 VTKM_TEST_FAIL("Exception not thrown.");
378 }
379 catch (vtkm::cont::ErrorExecution& error)
380 {
381 std::cout << " Got expected exception." << std::endl;
382 std::cout << " Exception message: " << error.GetMessage() << std::endl;
383 VTKM_TEST_ASSERT(error.GetMessage() == ERROR_MESSAGE, "Got unexpected error message.");
384 }
385 }
386
TestInvokeWithBadDynamicType()387 void TestInvokeWithBadDynamicType()
388 {
389 std::cout << "Test invoke with bad type" << std::endl;
390
391 std::vector<vtkm::Id> inputArray(ARRAY_SIZE);
392 std::vector<vtkm::Id> outputArray(ARRAY_SIZE);
393 TestExecObjectTypeBad execObject;
394 TestDispatcher<TestWorklet> dispatcher;
395
396 try
397 {
398 std::cout << " Second argument bad." << std::endl;
399 dispatcher.Invoke(inputArray, execObject, outputArray);
400 VTKM_TEST_FAIL("Dispatcher did not throw expected error.");
401 }
402 catch (vtkm::cont::ErrorBadType& error)
403 {
404 std::cout << " Got expected exception." << std::endl;
405 std::cout << " " << error.GetMessage() << std::endl;
406 VTKM_TEST_ASSERT(error.GetMessage().find(" 2 ") != std::string::npos,
407 "Parameter index not named in error message.");
408 }
409 }
410
TestDispatcherBase()411 void TestDispatcherBase()
412 {
413 TestBasicInvoke();
414 TestInvokeWithError();
415 TestInvokeWithBadDynamicType();
416 }
417
418 } // anonymous namespace
419
UnitTestDispatcherBase(int,char * [])420 int UnitTestDispatcherBase(int, char* [])
421 {
422 return vtkm::cont::testing::Testing::Run(TestDispatcherBase);
423 }
424