1 //============================================================================
2 // Copyright (c) Kitware, Inc.
3 // All rights reserved.
4 // See LICENSE.txt for details.
5 // This software is distributed WITHOUT ANY WARRANTY; without even
6 // the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
7 // PURPOSE. See the above copyright notice for more information.
8 //
9 // Copyright 2014 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
10 // Copyright 2014 UT-Battelle, LLC.
11 // Copyright 2014 Los Alamos National Security.
12 //
13 // Under the terms of Contract DE-NA0003525 with NTESS,
14 // the U.S. Government retains certain rights in this software.
15 //
16 // Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National
17 // Laboratory (LANL), the U.S. Government retains certain rights in
18 // this software.
19 //============================================================================
20 #ifndef vtk_m_worklet_internal_DispatcherBase_h
21 #define vtk_m_worklet_internal_DispatcherBase_h
22
23 #include <vtkm/StaticAssert.h>
24
25 #include <vtkm/internal/FunctionInterface.h>
26 #include <vtkm/internal/Invocation.h>
27
28 #include <vtkm/cont/DeviceAdapter.h>
29 #include <vtkm/cont/ErrorBadType.h>
30 #include <vtkm/cont/Logging.h>
31 #include <vtkm/cont/TryExecute.h>
32
33 #include <vtkm/cont/arg/ControlSignatureTagBase.h>
34 #include <vtkm/cont/arg/Transport.h>
35 #include <vtkm/cont/arg/TypeCheck.h>
36 #include <vtkm/cont/internal/DynamicTransform.h>
37
38 #include <vtkm/exec/arg/ExecutionSignatureTagBase.h>
39
40 #include <vtkm/internal/brigand.hpp>
41
42 #include <vtkm/worklet/internal/WorkletBase.h>
43
44 #include <sstream>
45
46 namespace vtkm
47 {
48 namespace cont
49 {
50
51 // Forward declaration.
52 template <typename CellSetList>
53 class DynamicCellSetBase;
54 }
55 }
56
57 namespace vtkm
58 {
59 namespace worklet
60 {
61 namespace internal
62 {
63
/// Returns the scheduling range of an input domain held by reference.
///
/// The range is whatever the domain's `GetNumberOfValues()` method reports;
/// the return type is deduced from that method.
template <typename Domain>
inline auto scheduling_range(const Domain& inputDomain)
  -> decltype(inputDomain.GetNumberOfValues())
{
  return inputDomain.GetNumberOfValues();
}
69
/// Returns the scheduling range of an input domain held by pointer.
///
/// The pointer is dereferenced without a null check and the result of the
/// pointee's `GetNumberOfValues()` method is returned.
template <typename Domain>
inline auto scheduling_range(const Domain* const inputDomain)
  -> decltype(inputDomain->GetNumberOfValues())
{
  return inputDomain->GetNumberOfValues();
}
76
/// Returns the scheduling range of an input domain held by reference, for a
/// particular kind of range.
///
/// \param inputDomain Object providing `GetSchedulingRange(type)`.
/// \param type        Value forwarded to `GetSchedulingRange` to select the
///                    range being asked for.
template <typename Domain, typename SchedulingRangeType>
inline auto scheduling_range(const Domain& inputDomain, SchedulingRangeType type)
  -> decltype(inputDomain.GetSchedulingRange(type))
{
  return inputDomain.GetSchedulingRange(type);
}
83
/// Returns the scheduling range of an input domain held by pointer, for a
/// particular kind of range.
///
/// \param inputDomain Pointer (dereferenced without a null check) to an object
///                    providing `GetSchedulingRange(type)`.
/// \param type        Value forwarded to `GetSchedulingRange` to select the
///                    range being asked for.
template <typename Domain, typename SchedulingRangeType>
inline auto scheduling_range(const Domain* const inputDomain, SchedulingRangeType type)
  -> decltype(inputDomain->GetSchedulingRange(type))
{
  return inputDomain->GetSchedulingRange(type);
}
90
91 namespace detail
92 {
93
94 // This code is actually taking an error found at compile-time and not
95 // reporting it until run-time. This seems strange at first, but this
96 // behavior is actually important. With dynamic arrays and similar dynamic
97 // classes, there may be types that are technically possible (such as using a
98 // vector where a scalar is expected) but in reality never happen. Thus, for
99 // these unsupported combinations we just silently halt the compiler from
100 // attempting to create code for these errant conditions and throw a run-time
101 // error if one every tries to create one.
PrintFailureMessage(int index)102 inline void PrintFailureMessage(int index)
103 {
104 std::stringstream message;
105 message << "Encountered bad type for parameter " << index
106 << " when calling Invoke on a dispatcher.";
107 throw vtkm::cont::ErrorBadType(message.str());
108 }
109
PrintNullPtrMessage(int index,int mode)110 inline void PrintNullPtrMessage(int index, int mode)
111 {
112 std::stringstream message;
113 if (mode == 0)
114 {
115 message << "Encountered nullptr for parameter " << index;
116 }
117 else
118 {
119 message << "Encountered nullptr for " << index << " from last parameter ";
120 }
121 message << " when calling Invoke on a dispatcher.";
122 throw vtkm::cont::ErrorBadValue(message.str());
123 }
124
125 template <typename T>
126 inline void not_nullptr(T* ptr, int index, int mode = 0)
127 {
128 if (!ptr)
129 {
130 PrintNullPtrMessage(index, mode);
131 }
132 }
// Non-pointer overload: there is no pointer to check, so this does nothing.
// It exists so callers can invoke not_nullptr uniformly on any argument.
template <typename T>
inline void not_nullptr(T&&, int, int mode = 0)
{
  // Silence "unused parameter" warnings for the defaulted argument.
  static_cast<void>(mode);
}
138
// Pointer overload: turns a pointer into a reference to its pointee.
// The pointer is dereferenced unconditionally, so it must not be null
// (see not_nullptr).
template <typename T>
inline T& as_ref(T* ptr)
{
  T& pointee = *ptr;
  return pointee;
}
// Non-pointer overload: the argument already is a reference, so it is handed
// back unchanged with its value category preserved.
template <typename T>
inline T&& as_ref(T&& t)
{
  return std::forward<T>(t);
}
149
150
// Diagnostic helper used inside static_assert. Only the `noError == true`
// case is defined (and evaluates to true). Instantiating it with `false`
// names an incomplete type, which makes the compiler print `T` in the
// resulting error message.
template <typename T, bool noError>
struct ReportTypeOnError;

template <typename T>
struct ReportTypeOnError<T, true> : std::true_type
{
};
157
// Diagnostic helper used inside static_assert. Only the `noError == true`
// case is defined (and evaluates to true). Instantiating it with `false`
// names an incomplete type, which makes the compiler print the integer
// `Value` in the resulting error message.
template <int Value, bool noError>
struct ReportValueOnError;

template <int Value>
struct ReportValueOnError<Value, true> : std::true_type
{
};
164
// Type trait: first decays T (dropping references and cv-qualifiers), then
// strips one level of pointer from the result. The outcome is exposed as the
// inherited `type` member.
template <typename T>
struct remove_pointer_and_decay
  : std::remove_pointer<typename std::decay<T>::type>
{
};
169
170 // Is designed as a brigand fold operation.
171 template <typename Type, typename State>
172 struct DetermineIfHasDynamicParameter
173 {
174 using T = typename std::remove_pointer<Type>::type;
175 using DynamicTag = typename vtkm::cont::internal::DynamicTransformTraits<T>::DynamicTag;
176 using isDynamic =
177 typename std::is_same<DynamicTag, vtkm::cont::internal::DynamicTransformTagCastAndCall>::type;
178
179 using type = std::integral_constant<bool, (State::value || isDynamic::value)>;
180 };
181
182
183 // Is designed as a brigand fold operation.
184 template <typename WorkletType>
185 struct DetermineHasCorrectParameters
186 {
187 template <typename Type, typename State, typename SigTypes>
188 struct Functor
189 {
190 //T is the type of the Param at the current index
191 //State if the index to use to fetch the control signature tag
192 using ControlSignatureTag = typename brigand::at_c<SigTypes, State::value>;
193 using TypeCheckTag = typename ControlSignatureTag::TypeCheckTag;
194
195 using T = typename std::remove_pointer<Type>::type;
196 static constexpr bool isCorrect = vtkm::cont::arg::TypeCheck<TypeCheckTag, T>::value;
197
198 // If you get an error on the line below, that means that your code has called the
199 // Invoke method on a dispatcher, and one of the arguments of the Invoke is the wrong
200 // type. Each argument of Invoke corresponds to a tag in the arguments of the
201 // ControlSignature of the worklet. If there is a mismatch, then you get an error here
202 // (instead of where you called the dispatcher). For example, if the worklet has a
203 // control signature as ControlSignature(CellSetIn, ...) and the first argument passed
204 // to Invoke is an ArrayHandle, you will get an error here because you cannot use an
205 // ArrayHandle in place of a CellSetIn argument. (You need to use a CellSet.) See a few
206 // lines later for some diagnostics to help you trace where the error occurred.
207 VTKM_READ_THE_SOURCE_CODE_FOR_HELP(isCorrect);
208
209 // If you are getting the error described above, the following lines will give you some
210 // diagnostics (in the form of compile errors). Each one will result in a compile error
211 // reporting an undefined type for ReportTypeOnError (or ReportValueOnError). What we are
212 // really reporting is the first template argument, which is one of the types or values that
213 // should help pinpoint where the error is. The comment for static_assert provides the
214 // type/value being reported. (Note that some compilers report better types than others. If
215 // your compiler is giving unhelpful types like "T" or "WorkletType", you may need to try a
216 // different compiler.)
217 static_assert(ReportTypeOnError<T, isCorrect>::value, "Type passed to Invoke");
218 static_assert(ReportTypeOnError<WorkletType, isCorrect>::value, "Worklet being invoked.");
219 static_assert(ReportValueOnError<State::value, isCorrect>::value, "Index of Invoke parameter");
220 static_assert(ReportTypeOnError<TypeCheckTag, isCorrect>::value, "Type check tag used");
221
222 // This final static_assert gives a human-readable error message. Ideally, this would be
223 // placed first, but some compilers will suppress further errors when a static_assert
224 // fails, so you would not see the other diagnostic error messages.
225 static_assert(isCorrect,
226 "The type of one of the arguments to the dispatcher's Invoke method is "
227 "incompatible with the corresponding tag in the worklet's ControlSignature.");
228
229 using type = std::integral_constant<std::size_t, State::value + 1>;
230 };
231 };
232
233 // Checks that an argument in a ControlSignature is a valid control signature
234 // tag. Causes a compile error otherwise.
235 struct DispatcherBaseControlSignatureTagCheck
236 {
237 template <typename ControlSignatureTag, vtkm::IdComponent Index>
238 struct ReturnType
239 {
240 // If you get a compile error here, it means there is something that is
241 // not a valid control signature tag in a worklet's ControlSignature.
242 VTKM_IS_CONTROL_SIGNATURE_TAG(ControlSignatureTag);
243 using type = ControlSignatureTag;
244 };
245 };
246
247 // Checks that an argument in a ExecutionSignature is a valid execution
248 // signature tag. Causes a compile error otherwise.
249 struct DispatcherBaseExecutionSignatureTagCheck
250 {
251 template <typename ExecutionSignatureTag, vtkm::IdComponent Index>
252 struct ReturnType
253 {
254 // If you get a compile error here, it means there is something that is not
255 // a valid execution signature tag in a worklet's ExecutionSignature.
256 VTKM_IS_EXECUTION_SIGNATURE_TAG(ExecutionSignatureTag);
257 using type = ExecutionSignatureTag;
258 };
259 };
260
261 struct DispatcherBaseTryExecuteFunctor
262 {
263 template <typename Device, typename DispatcherBaseType, typename Invocation, typename RangeType>
264 VTKM_CONT bool operator()(Device device,
265 const DispatcherBaseType* self,
266 Invocation& invocation,
267 const RangeType& dimensions)
268 {
269 self->InvokeTransportParameters(
270 invocation, dimensions, self->Scatter.GetOutputRange(dimensions), device);
271 return true;
272 }
273 };
274
275 // A look up helper used by DispatcherBaseTransportFunctor to determine
276 //the types independent of the device we are templated on.
277 template <typename ControlInterface, vtkm::IdComponent Index>
278 struct DispatcherBaseTransportInvokeTypes
279 {
280 //Moved out of DispatcherBaseTransportFunctor to reduce code generation
281 using ControlSignatureTag = typename ControlInterface::template ParameterType<Index>::type;
282 using TransportTag = typename ControlSignatureTag::TransportTag;
283 };
284
285 VTKM_CONT
286 inline vtkm::Id FlatRange(vtkm::Id range)
287 {
288 return range;
289 }
290
291 VTKM_CONT
292 inline vtkm::Id FlatRange(const vtkm::Id3& range)
293 {
294 return range[0] * range[1] * range[2];
295 }
296
297 // A functor used in a StaticCast of a FunctionInterface to transport arguments
298 // from the control environment to the execution environment.
299 template <typename ControlInterface, typename InputDomainType, typename Device>
300 struct DispatcherBaseTransportFunctor
301 {
302 const InputDomainType& InputDomain; // Warning: this is a reference
303 vtkm::Id InputRange;
304 vtkm::Id OutputRange;
305
306 // TODO: We need to think harder about how scheduling on 3D arrays works.
307 // Chances are we need to allow the transport for each argument to manage
308 // 3D indices (for example, allocate a 3D array instead of a 1D array).
309 // But for now, just treat all transports as 1D arrays.
310 template <typename InputRangeType, typename OutputRangeType>
311 VTKM_CONT DispatcherBaseTransportFunctor(const InputDomainType& inputDomain,
312 const InputRangeType& inputRange,
313 const OutputRangeType& outputRange)
314 : InputDomain(inputDomain)
315 , InputRange(FlatRange(inputRange))
316 , OutputRange(FlatRange(outputRange))
317 {
318 }
319
320
321 template <typename ControlParameter, vtkm::IdComponent Index>
322 struct ReturnType
323 {
324 using TransportTag =
325 typename DispatcherBaseTransportInvokeTypes<ControlInterface, Index>::TransportTag;
326 using T = typename remove_pointer_and_decay<ControlParameter>::type;
327 using TransportType = typename vtkm::cont::arg::Transport<TransportTag, T, Device>;
328 using type = typename TransportType::ExecObjectType;
329 };
330
331 // template<typename ControlParameter, vtkm::IdComponent Index>
332 // VTKM_CONT typename ReturnType<ControlParameter, Index>::type operator()(
333 // ControlParameter const& invokeData,
334 // vtkm::internal::IndexTag<Index>) const
335 // {
336 // using TransportTag =
337 // typename DispatcherBaseTransportInvokeTypes<ControlInterface, Index>::TransportTag;
338 // using T = typename remove_pointer_and_decay<ControlParameter>::type;
339 // vtkm::cont::arg::Transport<TransportTag, T, Device> transport;
340 // return transport(invokeData, as_ref(this->InputDomain), this->InputRange, this->OutputRange);
341 // }
342
343 template <typename ControlParameter, vtkm::IdComponent Index>
344 VTKM_CONT typename ReturnType<ControlParameter, Index>::type operator()(
345 ControlParameter&& invokeData,
346 vtkm::internal::IndexTag<Index>) const
347 {
348 using TransportTag =
349 typename DispatcherBaseTransportInvokeTypes<ControlInterface, Index>::TransportTag;
350 using T = typename remove_pointer_and_decay<ControlParameter>::type;
351 vtkm::cont::arg::Transport<TransportTag, T, Device> transport;
352
353 not_nullptr(invokeData, Index);
354 return transport(
355 as_ref(invokeData), as_ref(this->InputDomain), this->InputRange, this->OutputRange);
356 }
357
358
359
360 private:
361 void operator=(const DispatcherBaseTransportFunctor&) = delete;
362 };
363
364 //forward declares
365 template <std::size_t LeftToProcess>
366 struct for_each_dynamic_arg;
367
368 template <std::size_t LeftToProcess, typename TypeCheckTag>
369 struct convert_arg_wrapper
370 {
371 template <typename T, typename... Args>
372 void operator()(T&& t, Args&&... args) const
373 {
374 using Type = typename std::decay<T>::type;
375 using valid =
376 std::integral_constant<bool, vtkm::cont::arg::TypeCheck<TypeCheckTag, Type>::value>;
377 this->WillContinue(valid(), std::forward<T>(t), std::forward<Args>(args)...);
378 }
379 template <typename T, typename... Args>
380 void WillContinue(std::true_type, T&& t, Args&&... args) const
381 {
382 for_each_dynamic_arg<LeftToProcess - 1>()(std::forward<Args>(args)..., std::forward<T>(t));
383 }
384 template <typename... Args>
385 void WillContinue(std::false_type, Args&&...) const
386 {
387 vtkm::worklet::internal::detail::PrintFailureMessage(LeftToProcess);
388 }
389 };
390
391 template <std::size_t LeftToProcess,
392 typename T,
393 typename ContParams,
394 typename Trampoline,
395 typename... Args>
396 inline void convert_arg(vtkm::cont::internal::DynamicTransformTagStatic,
397 T&& t,
398 const ContParams&,
399 const Trampoline& trampoline,
400 Args&&... args)
401 { //This is a static array, so just push it to the back
402 using popped_sig = brigand::pop_front<ContParams>;
403 for_each_dynamic_arg<LeftToProcess - 1>()(
404 trampoline, popped_sig(), std::forward<Args>(args)..., std::forward<T>(t));
405 }
406
407 template <std::size_t LeftToProcess,
408 typename T,
409 typename ContParams,
410 typename Trampoline,
411 typename... Args>
412 inline void convert_arg(vtkm::cont::internal::DynamicTransformTagCastAndCall,
413 T&& t,
414 const ContParams&,
415 const Trampoline& trampoline,
416 Args&&... args)
417 { //This is something dynamic so cast and call
418 using tag_check = typename brigand::at_c<ContParams, 0>::TypeCheckTag;
419 using popped_sig = brigand::pop_front<ContParams>;
420
421 not_nullptr(t, LeftToProcess, 1);
422 vtkm::cont::CastAndCall(as_ref(t),
423 convert_arg_wrapper<LeftToProcess, tag_check>(),
424 trampoline,
425 popped_sig(),
426 std::forward<Args>(args)...);
427 }
428
429 template <std::size_t LeftToProcess>
430 struct for_each_dynamic_arg
431 {
432 template <typename Trampoline, typename ContParams, typename T, typename... Args>
433 void operator()(const Trampoline& trampoline, ContParams&& sig, T&& t, Args&&... args) const
434 {
435 //Determine that state of T when it is either a `cons&` or a `* const&`
436 using Type = typename std::remove_pointer<typename std::decay<T>::type>::type;
437 using tag = typename vtkm::cont::internal::DynamicTransformTraits<Type>::DynamicTag;
438 //convert the first item to a known type
439 convert_arg<LeftToProcess>(
440 tag(), std::forward<T>(t), sig, trampoline, std::forward<Args>(args)...);
441 }
442 };
443
444 template <>
445 struct for_each_dynamic_arg<0>
446 {
447 template <typename Trampoline, typename ContParams, typename... Args>
448 void operator()(const Trampoline& trampoline, ContParams&&, Args&&... args) const
449 {
450 trampoline.StartInvokeDynamic(std::false_type(), std::forward<Args>(args)...);
451 }
452 };
453
454 template <typename Trampoline, typename ContParams, typename... Args>
455 inline void deduce(Trampoline&& trampoline, ContParams&& sig, Args&&... args)
456 {
457 for_each_dynamic_arg<sizeof...(Args)>()(std::forward<Trampoline>(trampoline), sig, args...);
458 }
459
460
461 #if defined(VTKM_MSVC)
462 #pragma warning(push)
463 #pragma warning(disable : 4068) //unknown pragma
464 #endif
465 #if defined(__NVCC__) && defined(__CUDACC_VER_MAJOR__)
466 // Disable warning "calling a __host__ function from a __host__ __device__"
467 // In some cases nv_exec_check_disable doesn't work and therefore you need
468 // to use the following suppressions
469 #pragma push
470
471 #if (__CUDACC_VER_MAJOR__ < 8)
472 #pragma diag_suppress 2670
473 #pragma diag_suppress 2668
474 #endif
475
476 #if (__CUDACC_VER_MAJOR__ >= 8)
477 #pragma diag_suppress 2735
478 #pragma diag_suppress 2737
479 #pragma diag_suppress 2739
480 #endif
481
482 #if (__CUDACC_VER_MAJOR__ >= 9)
483 #pragma diag_suppress 2828
484 #pragma diag_suppress 2864
485 #pragma diag_suppress 2867
486 #pragma diag_suppress 2885
487 #endif
488
489 #if (__CUDACC_VER_MAJOR__ >= 10)
490 #pragma diag_suppress 2905
491 #endif
492
493 #endif
494 //This is a separate function as the pragma guards can cause nvcc
495 //to have an internal compiler error (codegen #3028)
496 template <typename... Args>
497 inline auto make_funcIFace(Args&&... args) -> decltype(
498 vtkm::internal::make_FunctionInterface<void, typename std::decay<Args>::type...>(args...))
499 {
500 return vtkm::internal::make_FunctionInterface<void, typename std::decay<Args>::type...>(args...);
501 }
502 #if defined(__NVCC__) && defined(__CUDACC_VER_MAJOR__)
503 #pragma pop
504 #endif
505 #if defined(VTKM_MSVC)
506 #pragma warning(pop)
507 #endif
508
509
510 } // namespace detail
511
512 /// This is a help struct to detect out of bound placeholders defined in the
513 /// execution signature at compile time
514 template <vtkm::IdComponent MaxIndexAllowed>
515 struct PlaceholderValidator
516 {
517 PlaceholderValidator() {}
518
519 // An overload operator to detect possible out of bound placeholder
520 template <int N>
521 void operator()(brigand::type_<vtkm::placeholders::Arg<N>>) const
522 {
523 static_assert(N <= MaxIndexAllowed,
524 "An argument in the execution signature"
525 " (usually _2, _3, _4, etc.) refers to a control signature argument that"
526 " does not exist. For example, you will get this error if you have _3 (or"
527 " _4 or _5 or so on) as one of the execution signature arguments, but you"
528 " have fewer than 3 (or 4 or 5 or so on) arguments in the control signature.");
529 }
530
531 template <typename DerivedType>
532 void operator()(brigand::type_<DerivedType>) const
533 {
534 }
535 };
536
537 /// Base class for all dispatcher classes. Every worklet type should have its
538 /// own dispatcher.
539 ///
540 template <typename DerivedClass, typename WorkletType, typename BaseWorkletType>
541 class DispatcherBase
542 {
543 private:
544 using MyType = DispatcherBase<DerivedClass, WorkletType, BaseWorkletType>;
545
546 friend struct detail::for_each_dynamic_arg<0>;
547
548 protected:
549 using ControlInterface =
550 vtkm::internal::FunctionInterface<typename WorkletType::ControlSignature>;
551 using ExecutionInterface =
552 vtkm::internal::FunctionInterface<typename WorkletType::ExecutionSignature>;
553
554 static constexpr vtkm::IdComponent NUM_INVOKE_PARAMS = ControlInterface::ARITY;
555
556 private:
557 // We don't really need these types, but declaring them checks the arguments
558 // of the control and execution signatures.
559 using ControlSignatureCheck = typename ControlInterface::template StaticTransformType<
560 detail::DispatcherBaseControlSignatureTagCheck>::type;
561 using ExecutionSignatureCheck = typename ExecutionInterface::template StaticTransformType<
562 detail::DispatcherBaseExecutionSignatureTagCheck>::type;
563
564 template <typename... Args>
565 VTKM_CONT void StartInvoke(Args&&... args) const
566 {
567 using ParameterInterface =
568 vtkm::internal::FunctionInterface<void(typename std::decay<Args>::type...)>;
569
570 VTKM_STATIC_ASSERT_MSG(ParameterInterface::ARITY == NUM_INVOKE_PARAMS,
571 "Dispatcher Invoke called with wrong number of arguments.");
572
573 static_assert(
574 std::is_base_of<BaseWorkletType, WorkletType>::value,
575 "The worklet being scheduled by this dispatcher doesn't match the type of the dispatcher");
576
577 // Check if the placeholders defined in the execution environment exceed the max bound
578 // defined in the control environment by throwing a nice compile error.
579 using ComponentSig = typename ExecutionInterface::ComponentSig;
580 brigand::for_each<ComponentSig>(PlaceholderValidator<NUM_INVOKE_PARAMS>{});
581
582 //We need to determine if we have the need to do any dynamic
583 //transforms. This is fairly simple of a query. We just need to check
584 //everything in the FunctionInterface and see if any of them have the
585 //proper dynamic trait. Doing this, allows us to generate zero dynamic
586 //check & convert code when we already know all the types. This results
587 //in smaller executables and libraries.
588 using ParamTypes = typename ParameterInterface::ParameterSig;
589 using HasDynamicTypes =
590 brigand::fold<ParamTypes,
591 std::false_type,
592 detail::DetermineIfHasDynamicParameter<brigand::_element, brigand::_state>>;
593
594 this->StartInvokeDynamic(HasDynamicTypes(), std::forward<Args>(args)...);
595 }
596
597 template <typename... Args>
598 VTKM_CONT void StartInvokeDynamic(std::true_type, Args&&... args) const
599 {
600 // As we do the dynamic transform, we are also going to check the static
601 // type against the TypeCheckTag in the ControlSignature tags. To do this,
602 // the check needs access to both the parameter (in the parameters
603 // argument) and the ControlSignature tags (in the ControlInterface type).
604 using ContParamsInfo =
605 vtkm::internal::detail::FunctionSigInfo<typename WorkletType::ControlSignature>;
606 typename ContParamsInfo::Parameters parameters;
607 detail::deduce(*this, parameters, std::forward<Args>(args)...);
608 }
609
610 template <typename... Args>
611 VTKM_CONT void StartInvokeDynamic(std::false_type, Args&&... args) const
612 {
613 using ParameterInterface =
614 vtkm::internal::FunctionInterface<void(typename std::decay<Args>::type...)>;
615
616 //Nothing requires a conversion from dynamic to static types, so
617 //next we need to verify that each argument's type is correct. If not
618 //we need to throw a nice compile time error
619 using ParamTypes = typename ParameterInterface::ParameterSig;
620 using ContSigTypes = typename vtkm::internal::detail::FunctionSigInfo<
621 typename WorkletType::ControlSignature>::Parameters;
622
623 //isAllValid will throw a compile error if everything doesn't match
624 using isAllValid = brigand::fold<
625 ParamTypes,
626 std::integral_constant<std::size_t, 0>,
627 typename detail::DetermineHasCorrectParameters<WorkletType>::
628 template Functor<brigand::_element, brigand::_state, brigand::pin<ContSigTypes>>>;
629
630 //this warning exists so that we don't get a warning from not using isAllValid
631 using expectedLen = std::integral_constant<std::size_t, sizeof...(Args)>;
632 static_assert(isAllValid::value == expectedLen::value,
633 "All arguments failed the TypeCheck pass");
634
635 //This is a separate function as the pragma guards can cause nvcc
636 //to have an internal compiler error (codegen #3028)
637 auto fi = detail::make_funcIFace(std::forward<Args>(args)...);
638
639 auto ivc = vtkm::internal::Invocation<ParameterInterface,
640 ControlInterface,
641 ExecutionInterface,
642 WorkletType::InputDomain::INDEX,
643 vtkm::internal::NullType,
644 vtkm::internal::NullType>(
645 fi, vtkm::internal::NullType{}, vtkm::internal::NullType{});
646 static_cast<const DerivedClass*>(this)->DoInvoke(ivc);
647 }
648
649 public:
650 //@{
651 /// Setting the device ID will force the execute to happen on a particular device. If no device
652 /// is specified (or the device ID is set to any), then a device will automatically be chosen
653 /// based on the runtime device tracker.
654 ///
655 VTKM_CONT
656 void SetDevice(vtkm::cont::DeviceAdapterId device) { this->Device = device; }
657
658 VTKM_CONT vtkm::cont::DeviceAdapterId GetDevice() const { return this->Device; }
659 //@}
660
661 using ScatterType = typename WorkletType::ScatterType;
662
663 template <typename... Args>
664 VTKM_CONT void Invoke(Args&&... args) const
665 {
666 VTKM_LOG_SCOPE(vtkm::cont::LogLevel::Perf,
667 "Invoking Worklet: '%s'",
668 vtkm::cont::TypeName<WorkletType>().c_str());
669 this->StartInvoke(std::forward<Args>(args)...);
670 }
671
672 protected:
673 VTKM_CONT
674 DispatcherBase(const WorkletType& worklet, const ScatterType& scatter)
675 : Worklet(worklet)
676 , Scatter(scatter)
677 , Device(vtkm::cont::DeviceAdapterTagAny())
678 {
679 }
680
681 friend struct internal::detail::DispatcherBaseTryExecuteFunctor;
682
683 template <typename Invocation>
684 VTKM_CONT void BasicInvoke(Invocation& invocation, vtkm::Id numInstances) const
685 {
686 bool success =
687 vtkm::cont::TryExecuteOnDevice(this->Device,
688 internal::detail::DispatcherBaseTryExecuteFunctor(),
689 this,
690 invocation,
691 numInstances);
692 if (!success)
693 {
694 throw vtkm::cont::ErrorExecution("Failed to execute worklet on any device.");
695 }
696 }
697
698 template <typename Invocation>
699 VTKM_CONT void BasicInvoke(Invocation& invocation, vtkm::Id2 dimensions) const
700 {
701 this->BasicInvoke(invocation, vtkm::Id3(dimensions[0], dimensions[1], 1));
702 }
703
704 template <typename Invocation>
705 VTKM_CONT void BasicInvoke(Invocation& invocation, vtkm::Id3 dimensions) const
706 {
707 bool success =
708 vtkm::cont::TryExecuteOnDevice(this->Device,
709 internal::detail::DispatcherBaseTryExecuteFunctor(),
710 this,
711 invocation,
712 dimensions);
713 if (!success)
714 {
715 throw vtkm::cont::ErrorExecution("Failed to execute worklet on any device.");
716 }
717 }
718
719 WorkletType Worklet;
720 ScatterType Scatter;
721
722 private:
723 // Dispatchers cannot be copied
724 DispatcherBase(const MyType&) = delete;
725 void operator=(const MyType&) = delete;
726
727 vtkm::cont::DeviceAdapterId Device;
728
729 template <typename Invocation,
730 typename InputRangeType,
731 typename OutputRangeType,
732 typename DeviceAdapter>
733 VTKM_CONT void InvokeTransportParameters(Invocation& invocation,
734 const InputRangeType& inputRange,
735 OutputRangeType&& outputRange,
736 DeviceAdapter device) const
737 {
738 // The first step in invoking a worklet is to transport the arguments to
739 // the execution environment. The invocation object passed to this function
740 // contains the parameters passed to Invoke in the control environment. We
741 // will use the template magic in the FunctionInterface class to invoke the
742 // appropriate Transport class on each parameter and get a list of
743 // execution objects (corresponding to the arguments of the Invoke in the
744 // control environment) in a FunctionInterface. Specifically, we use a
745 // static transform of the FunctionInterface to call the transport on each
746 // argument and return the corresponding execution environment object.
747 using ParameterInterfaceType = typename Invocation::ParameterInterface;
748 ParameterInterfaceType& parameters = invocation.Parameters;
749
750 using TransportFunctorType =
751 detail::DispatcherBaseTransportFunctor<typename Invocation::ControlInterface,
752 typename Invocation::InputDomainType,
753 DeviceAdapter>;
754 using ExecObjectParameters =
755 typename ParameterInterfaceType::template StaticTransformType<TransportFunctorType>::type;
756
757 ExecObjectParameters execObjectParameters = parameters.StaticTransformCont(
758 TransportFunctorType(invocation.GetInputDomain(), inputRange, outputRange));
759
760 // Get the arrays used for scattering input to output.
761 typename WorkletType::ScatterType::OutputToInputMapType outputToInputMap =
762 this->Scatter.GetOutputToInputMap(inputRange);
763 typename WorkletType::ScatterType::VisitArrayType visitArray =
764 this->Scatter.GetVisitArray(inputRange);
765
766 // Replace the parameters in the invocation with the execution object and
767 // pass to next step of Invoke. Also add the scatter information.
768 this->InvokeSchedule(invocation.ChangeParameters(execObjectParameters)
769 .ChangeOutputToInputMap(outputToInputMap.PrepareForInput(device))
770 .ChangeVisitArray(visitArray.PrepareForInput(device)),
771 outputRange,
772 device);
773 }
774
775 template <typename Invocation, typename RangeType, typename DeviceAdapter>
776 VTKM_CONT void InvokeSchedule(const Invocation& invocation, RangeType range, DeviceAdapter) const
777 {
778 using Algorithm = vtkm::cont::DeviceAdapterAlgorithm<DeviceAdapter>;
779 using TaskTypes = typename vtkm::cont::DeviceTaskTypes<DeviceAdapter>;
780
781 // The TaskType class handles the magic of fetching values
782 // for each instance and calling the worklet's function.
783 // The TaskType will evaluate to one of the following classes:
784 //
785 // vtkm::exec::internal::TaskSingular
786 // vtkm::exec::internal::TaskTiling1D
787 // vtkm::exec::internal::TaskTiling3D
788 auto task = TaskTypes::MakeTask(this->Worklet, invocation, range);
789 Algorithm::ScheduleTask(task, range);
790 }
791 };
792 }
793 }
794 } // namespace vtkm::worklet::internal
795
796 #endif //vtk_m_worklet_internal_DispatcherBase_h
797