1 //============================================================================
2 //  Copyright (c) Kitware, Inc.
3 //  All rights reserved.
4 //  See LICENSE.txt for details.
5 //
6 //  This software is distributed WITHOUT ANY WARRANTY; without even
7 //  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
8 //  PURPOSE.  See the above copyright notice for more information.
9 //============================================================================
10 #ifndef vtk_m_worklet_internal_DispatcherBase_h
11 #define vtk_m_worklet_internal_DispatcherBase_h
12 
13 #include <vtkm/StaticAssert.h>
14 
15 #include <vtkm/internal/FunctionInterface.h>
16 #include <vtkm/internal/Invocation.h>
17 
18 #include <vtkm/cont/CastAndCall.h>
19 #include <vtkm/cont/ErrorBadType.h>
20 #include <vtkm/cont/Logging.h>
21 #include <vtkm/cont/TryExecute.h>
22 
23 #include <vtkm/cont/arg/ControlSignatureTagBase.h>
24 #include <vtkm/cont/arg/Transport.h>
25 #include <vtkm/cont/arg/TypeCheck.h>
26 
27 #include <vtkm/exec/arg/ExecutionSignatureTagBase.h>
28 
29 #include <vtkm/internal/brigand.hpp>
30 
31 #include <vtkm/internal/DecayHelpers.h>
32 
33 #include <vtkm/worklet/internal/WorkletBase.h>
34 
35 #include <vtkmstd/is_trivial.h>
36 
37 #include <sstream>
38 
39 namespace vtkm
40 {
41 namespace cont
42 {
43 
44 // Forward declaration.
45 template <typename CellSetList>
46 class DynamicCellSetBase;
47 }
48 }
49 
50 namespace vtkm
51 {
52 namespace worklet
53 {
54 namespace internal
55 {
56 
/// Reports the scheduling range of an input domain held by reference, which
/// is simply whatever the domain's \c GetNumberOfValues() method returns.
///
/// The return type is spelled with \c decltype on the call expression, so this
/// overload drops out of overload resolution for any \c Domain that has no
/// such method.
template <typename Domain>
inline auto SchedulingRange(const Domain& inputDomain)
  -> decltype(inputDomain.GetNumberOfValues())
{
  return inputDomain.GetNumberOfValues();
}
62 
/// Pointer flavor of \c SchedulingRange: dereferences \c inputDomain and
/// returns the result of its \c GetNumberOfValues() method. The pointer is
/// not checked for null here.
template <typename Domain>
inline auto SchedulingRange(const Domain* const inputDomain)
  -> decltype((*inputDomain).GetNumberOfValues())
{
  return (*inputDomain).GetNumberOfValues();
}
69 
/// Reports the scheduling range of an input domain held by reference for a
/// particular kind of range. The \c type argument is handed straight to the
/// domain's \c GetSchedulingRange() method, and its result is returned.
///
/// As with the single-argument overloads, the \c decltype return type removes
/// this overload for domains without a matching \c GetSchedulingRange().
template <typename Domain, typename SchedulingRangeType>
inline auto SchedulingRange(const Domain& inputDomain, SchedulingRangeType type)
  -> decltype(inputDomain.GetSchedulingRange(type))
{
  return inputDomain.GetSchedulingRange(type);
}
76 
/// Pointer flavor of the two-argument \c SchedulingRange: dereferences
/// \c inputDomain and forwards \c type to its \c GetSchedulingRange() method.
/// The pointer is not checked for null here.
template <typename Domain, typename SchedulingRangeType>
inline auto SchedulingRange(const Domain* const inputDomain, SchedulingRangeType type)
  -> decltype((*inputDomain).GetSchedulingRange(type))
{
  return (*inputDomain).GetSchedulingRange(type);
}
83 
84 namespace detail
85 {
86 
87 // This code is actually taking an error found at compile-time and not
88 // reporting it until run-time. This seems strange at first, but this
89 // behavior is actually important. With dynamic arrays and similar dynamic
90 // classes, there may be types that are technically possible (such as using a
91 // vector where a scalar is expected) but in reality never happen. Thus, for
92 // these unsupported combinations we just silently halt the compiler from
93 // attempting to create code for these errant conditions and throw a run-time
// error if one ever tries to create one.
PrintFailureMessage(int index)95 inline void PrintFailureMessage(int index)
96 {
97   std::stringstream message;
98   message << "Encountered bad type for parameter " << index
99           << " when calling Invoke on a dispatcher.";
100   throw vtkm::cont::ErrorBadType(message.str());
101 }
102 
PrintNullPtrMessage(int index,int mode)103 inline void PrintNullPtrMessage(int index, int mode)
104 {
105   std::stringstream message;
106   if (mode == 0)
107   {
108     message << "Encountered nullptr for parameter " << index;
109   }
110   else
111   {
112     message << "Encountered nullptr for " << index << " from last parameter ";
113   }
114   message << " when calling Invoke on a dispatcher.";
115   throw vtkm::cont::ErrorBadValue(message.str());
116 }
117 
118 template <typename T>
119 inline void not_nullptr(T* ptr, int index, int mode = 0)
120 {
121   if (!ptr)
122   {
123     PrintNullPtrMessage(index, mode);
124   }
125 }
/// Catch-all overload taking a forwarding reference. It performs no check and
/// ignores every argument.
template <typename T>
inline void not_nullptr(T&&, int, int = 0)
{
}
131 
/// Pointer overload: yields a reference to the object \c ptr points at.
/// The pointer is dereferenced without any null check.
template <typename T>
inline T& as_ref(T* ptr)
{
  T& pointee = *ptr;
  return pointee;
}
/// Forwarding-reference overload: hands the argument back with its original
/// value category (lvalues stay lvalues, rvalues stay rvalues).
template <typename T>
inline T&& as_ref(T&& t)
{
  // Equivalent to std::forward<T>(t).
  return static_cast<T&&>(t);
}
142 
143 
// Diagnostic helper: only the <T, true> case is defined (and is true).
// Naming ReportTypeOnError<T, false> therefore hits an incomplete type, and
// the resulting compile error spells out T.
template <typename T, bool noError>
struct ReportTypeOnError;

template <typename T>
struct ReportTypeOnError<T, true> : std::integral_constant<bool, true>
{
};
150 
// Diagnostic helper: only the <Value, true> case is defined (and is true).
// Naming ReportValueOnError<Value, false> therefore hits an incomplete type,
// and the resulting compile error spells out Value.
template <int Value, bool noError>
struct ReportValueOnError;

template <int Value>
struct ReportValueOnError<Value, true> : std::integral_constant<bool, true>
{
};
157 
158 // Is designed as a brigand fold operation.
159 template <typename Type, typename State>
160 struct DetermineIfHasDynamicParameter
161 {
162   using T = vtkm::internal::remove_pointer_and_decay<Type>;
163   using DynamicTag = typename vtkm::cont::internal::DynamicTransformTraits<T>::DynamicTag;
164   using isDynamic =
165     typename std::is_same<DynamicTag, vtkm::cont::internal::DynamicTransformTagCastAndCall>::type;
166 
167   using type = std::integral_constant<bool, (State::value || isDynamic::value)>;
168 };
169 
170 
171 // Is designed as a brigand fold operation.
172 template <typename WorkletType>
173 struct DetermineHasCorrectParameters
174 {
175   template <typename Type, typename State, typename SigTypes>
176   struct Functor
177   {
178     //T is the type of the Param at the current index
179     //State if the index to use to fetch the control signature tag
180     using ControlSignatureTag = typename brigand::at_c<SigTypes, State::value>;
181     using TypeCheckTag = typename ControlSignatureTag::TypeCheckTag;
182 
183     using T = typename std::remove_pointer<Type>::type;
184     static constexpr bool isCorrect = vtkm::cont::arg::TypeCheck<TypeCheckTag, T>::value;
185 
186     // If you get an error on the line below, that means that your code has called the
187     // Invoke method on a dispatcher, and one of the arguments of the Invoke is the wrong
188     // type. Each argument of Invoke corresponds to a tag in the arguments of the
189     // ControlSignature of the worklet. If there is a mismatch, then you get an error here
190     // (instead of where you called the dispatcher). For example, if the worklet has a
191     // control signature as ControlSignature(CellSetIn, ...) and the first argument passed
192     // to Invoke is an ArrayHandle, you will get an error here because you cannot use an
193     // ArrayHandle in place of a CellSetIn argument. (You need to use a CellSet.) See a few
194     // lines later for some diagnostics to help you trace where the error occurred.
195     VTKM_READ_THE_SOURCE_CODE_FOR_HELP(isCorrect);
196 
197     // If you are getting the error described above, the following lines will give you some
198     // diagnostics (in the form of compile errors). Each one will result in a compile error
199     // reporting an undefined type for ReportTypeOnError (or ReportValueOnError). What we are
200     // really reporting is the first template argument, which is one of the types or values that
201     // should help pinpoint where the error is. The comment for static_assert provides the
202     // type/value being reported. (Note that some compilers report better types than others. If
203     // your compiler is giving unhelpful types like "T" or "WorkletType", you may need to try a
204     // different compiler.)
205     static_assert(ReportTypeOnError<T, isCorrect>::value, "Type passed to Invoke");
206     static_assert(ReportTypeOnError<WorkletType, isCorrect>::value, "Worklet being invoked.");
207     static_assert(ReportValueOnError<State::value, isCorrect>::value, "Index of Invoke parameter");
208     static_assert(ReportTypeOnError<TypeCheckTag, isCorrect>::value, "Type check tag used");
209 
210     // This final static_assert gives a human-readable error message. Ideally, this would be
211     // placed first, but some compilers will suppress further errors when a static_assert
212     // fails, so you would not see the other diagnostic error messages.
213     static_assert(isCorrect,
214                   "The type of one of the arguments to the dispatcher's Invoke method is "
215                   "incompatible with the corresponding tag in the worklet's ControlSignature.");
216 
217     using type = std::integral_constant<std::size_t, State::value + 1>;
218   };
219 };
220 
221 // Checks that an argument in a ControlSignature is a valid control signature
222 // tag. Causes a compile error otherwise.
223 struct DispatcherBaseControlSignatureTagCheck
224 {
225   template <typename ControlSignatureTag, vtkm::IdComponent Index>
226   struct ReturnType
227   {
228     // If you get a compile error here, it means there is something that is
229     // not a valid control signature tag in a worklet's ControlSignature.
230     VTKM_IS_CONTROL_SIGNATURE_TAG(ControlSignatureTag);
231     using type = ControlSignatureTag;
232   };
233 };
234 
235 // Checks that an argument in a ExecutionSignature is a valid execution
236 // signature tag. Causes a compile error otherwise.
237 struct DispatcherBaseExecutionSignatureTagCheck
238 {
239   template <typename ExecutionSignatureTag, vtkm::IdComponent Index>
240   struct ReturnType
241   {
242     // If you get a compile error here, it means there is something that is not
243     // a valid execution signature tag in a worklet's ExecutionSignature.
244     VTKM_IS_EXECUTION_SIGNATURE_TAG(ExecutionSignatureTag);
245     using type = ExecutionSignatureTag;
246   };
247 };
248 
249 struct DispatcherBaseTryExecuteFunctor
250 {
251   template <typename Device, typename DispatcherBaseType, typename Invocation, typename RangeType>
252   VTKM_CONT bool operator()(Device device,
253                             const DispatcherBaseType* self,
254                             Invocation& invocation,
255                             const RangeType& dimensions)
256   {
257     auto outputRange = self->Scatter.GetOutputRange(dimensions);
258     self->InvokeTransportParameters(
259       invocation, dimensions, outputRange, self->Mask.GetThreadRange(outputRange), device);
260     return true;
261   }
262 };
263 
264 // A look up helper used by DispatcherBaseTransportFunctor to determine
265 //the types independent of the device we are templated on.
266 template <typename ControlInterface, vtkm::IdComponent Index>
267 struct DispatcherBaseTransportInvokeTypes
268 {
269   //Moved out of DispatcherBaseTransportFunctor to reduce code generation
270   using ControlSignatureTag = typename ControlInterface::template ParameterType<Index>::type;
271   using TransportTag = typename ControlSignatureTag::TransportTag;
272 };
273 
274 VTKM_CONT
275 inline vtkm::Id FlatRange(vtkm::Id range)
276 {
277   return range;
278 }
279 
280 VTKM_CONT
281 inline vtkm::Id FlatRange(const vtkm::Id3& range)
282 {
283   return range[0] * range[1] * range[2];
284 }
285 
286 // A functor used in a StaticCast of a FunctionInterface to transport arguments
287 // from the control environment to the execution environment.
288 template <typename ControlInterface, typename InputDomainType, typename Device>
289 struct DispatcherBaseTransportFunctor
290 {
291   const InputDomainType& InputDomain; // Warning: this is a reference
292   vtkm::Id InputRange;
293   vtkm::Id OutputRange;
294   vtkm::cont::Token& Token; // Warning: this is a reference
295 
296   // TODO: We need to think harder about how scheduling on 3D arrays works.
297   // Chances are we need to allow the transport for each argument to manage
298   // 3D indices (for example, allocate a 3D array instead of a 1D array).
299   // But for now, just treat all transports as 1D arrays.
300   template <typename InputRangeType, typename OutputRangeType>
301   VTKM_CONT DispatcherBaseTransportFunctor(const InputDomainType& inputDomain,
302                                            const InputRangeType& inputRange,
303                                            const OutputRangeType& outputRange,
304                                            vtkm::cont::Token& token)
305     : InputDomain(inputDomain)
306     , InputRange(FlatRange(inputRange))
307     , OutputRange(FlatRange(outputRange))
308     , Token(token)
309   {
310   }
311 
312 
313   template <typename ControlParameter, vtkm::IdComponent Index>
314   struct ReturnType
315   {
316     using TransportTag =
317       typename DispatcherBaseTransportInvokeTypes<ControlInterface, Index>::TransportTag;
318     using T = vtkm::internal::remove_pointer_and_decay<ControlParameter>;
319     using TransportType = typename vtkm::cont::arg::Transport<TransportTag, T, Device>;
320     using type = typename std::decay<typename TransportType::ExecObjectType>::type;
321 
322     // If you get a compile error here, it means that an execution object type is not
323     // trivially copyable. This is strictly disallowed. All execution objects must be
324     // trivially copyable so that the can be memcpy-ed between host and devices.
325     // Note that it is still legal for execution objects to have pointers or other
326     // references to resources on a particular device. It is up to the generating code
327     // to ensure that all referenced resources are valid on the target device.
328     VTKM_IS_TRIVIALLY_COPYABLE(type);
329   };
330 
331   template <typename ControlParameter, vtkm::IdComponent Index>
332   VTKM_CONT typename ReturnType<ControlParameter, Index>::type operator()(
333     ControlParameter&& invokeData,
334     vtkm::internal::IndexTag<Index>) const
335   {
336     using TransportTag =
337       typename DispatcherBaseTransportInvokeTypes<ControlInterface, Index>::TransportTag;
338     using T = vtkm::internal::remove_pointer_and_decay<ControlParameter>;
339     vtkm::cont::arg::Transport<TransportTag, T, Device> transport;
340 
341     not_nullptr(invokeData, Index);
342     return transport(as_ref(invokeData),
343                      as_ref(this->InputDomain),
344                      this->InputRange,
345                      this->OutputRange,
346                      this->Token);
347   }
348 
349 
350 
351 private:
352   void operator=(const DispatcherBaseTransportFunctor&) = delete;
353 };
354 
355 //forward declares
356 template <std::size_t LeftToProcess>
357 struct for_each_dynamic_arg;
358 
359 template <std::size_t LeftToProcess, typename TypeCheckTag>
360 struct convert_arg_wrapper
361 {
362   template <typename T, typename... Args>
363   void operator()(T&& t, Args&&... args) const
364   {
365     using Type = typename std::decay<T>::type;
366     using valid =
367       std::integral_constant<bool, vtkm::cont::arg::TypeCheck<TypeCheckTag, Type>::value>;
368     this->WillContinue(valid(), std::forward<T>(t), std::forward<Args>(args)...);
369   }
370   template <typename T, typename... Args>
371   void WillContinue(std::true_type, T&& t, Args&&... args) const
372   {
373     for_each_dynamic_arg<LeftToProcess - 1>()(std::forward<Args>(args)..., std::forward<T>(t));
374   }
375   template <typename... Args>
376   void WillContinue(std::false_type, Args&&...) const
377   {
378     vtkm::worklet::internal::detail::PrintFailureMessage(LeftToProcess);
379   }
380 };
381 
382 template <std::size_t LeftToProcess,
383           typename T,
384           typename ContParams,
385           typename Trampoline,
386           typename... Args>
387 inline void convert_arg(vtkm::cont::internal::DynamicTransformTagStatic,
388                         T&& t,
389                         const ContParams&,
390                         const Trampoline& trampoline,
391                         Args&&... args)
392 { //This is a static array, so just push it to the back
393   using popped_sig = brigand::pop_front<ContParams>;
394   for_each_dynamic_arg<LeftToProcess - 1>()(
395     trampoline, popped_sig(), std::forward<Args>(args)..., std::forward<T>(t));
396 }
397 
398 template <std::size_t LeftToProcess,
399           typename T,
400           typename ContParams,
401           typename Trampoline,
402           typename... Args>
403 inline void convert_arg(vtkm::cont::internal::DynamicTransformTagCastAndCall,
404                         T&& t,
405                         const ContParams&,
406                         const Trampoline& trampoline,
407                         Args&&... args)
408 { //This is something dynamic so cast and call
409   using tag_check = typename brigand::at_c<ContParams, 0>::TypeCheckTag;
410   using popped_sig = brigand::pop_front<ContParams>;
411 
412   not_nullptr(t, LeftToProcess, 1);
413   vtkm::cont::CastAndCall(as_ref(t),
414                           convert_arg_wrapper<LeftToProcess, tag_check>(),
415                           trampoline,
416                           popped_sig(),
417                           std::forward<Args>(args)...);
418 }
419 
420 template <std::size_t LeftToProcess>
421 struct for_each_dynamic_arg
422 {
423   template <typename Trampoline, typename ContParams, typename T, typename... Args>
424   void operator()(const Trampoline& trampoline, ContParams&& sig, T&& t, Args&&... args) const
425   {
426     //Determine that state of T when it is either a `cons&` or a `* const&`
427     using Type = vtkm::internal::remove_pointer_and_decay<T>;
428     using tag = typename vtkm::cont::internal::DynamicTransformTraits<Type>::DynamicTag;
429     //convert the first item to a known type
430     convert_arg<LeftToProcess>(
431       tag(), std::forward<T>(t), sig, trampoline, std::forward<Args>(args)...);
432   }
433 };
434 
435 template <>
436 struct for_each_dynamic_arg<0>
437 {
438   template <typename Trampoline, typename ContParams, typename... Args>
439   void operator()(const Trampoline& trampoline, ContParams&&, Args&&... args) const
440   {
441     trampoline.StartInvokeDynamic(std::false_type(), std::forward<Args>(args)...);
442   }
443 };
444 
445 template <typename Trampoline, typename ContParams, typename... Args>
446 inline void deduce(Trampoline&& trampoline, ContParams&& sig, Args&&... args)
447 {
448   for_each_dynamic_arg<sizeof...(Args)>()(std::forward<Trampoline>(trampoline), sig, args...);
449 }
450 
451 } // namespace detail
452 
453 /// This is a help struct to detect out of bound placeholders defined in the
454 /// execution signature at compile time
455 template <vtkm::IdComponent MaxIndexAllowed>
456 struct PlaceholderValidator
457 {
458   PlaceholderValidator() {}
459 
460   // An overload operator to detect possible out of bound placeholder
461   template <int N>
462   void operator()(brigand::type_<vtkm::placeholders::Arg<N>>) const
463   {
464     static_assert(N <= MaxIndexAllowed,
465                   "An argument in the execution signature"
466                   " (usually _2, _3, _4, etc.) refers to a control signature argument that"
467                   " does not exist. For example, you will get this error if you have _3 (or"
468                   " _4 or _5 or so on) as one of the execution signature arguments, but you"
469                   " have fewer than 3 (or 4 or 5 or so on) arguments in the control signature.");
470   }
471 
472   template <typename DerivedType>
473   void operator()(brigand::type_<DerivedType>) const
474   {
475   }
476 };
477 
478 /// Base class for all dispatcher classes. Every worklet type should have its
479 /// own dispatcher.
480 ///
481 template <typename DerivedClass, typename WorkletType, typename BaseWorkletType>
482 class DispatcherBase
483 {
484 private:
485   using MyType = DispatcherBase<DerivedClass, WorkletType, BaseWorkletType>;
486 
487   friend struct detail::for_each_dynamic_arg<0>;
488 
489 protected:
490   using ControlInterface =
491     vtkm::internal::FunctionInterface<typename WorkletType::ControlSignature>;
492 
493   // We go through the GetExecSig as that generates a default ExecutionSignature
494   // if one doesn't exist on the worklet
495   using ExecutionSignature =
496     typename vtkm::placeholders::GetExecSig<WorkletType>::ExecutionSignature;
497   using ExecutionInterface = vtkm::internal::FunctionInterface<ExecutionSignature>;
498 
499   static constexpr vtkm::IdComponent NUM_INVOKE_PARAMS = ControlInterface::ARITY;
500 
501 private:
502   // We don't really need these types, but declaring them checks the arguments
503   // of the control and execution signatures.
504   using ControlSignatureCheck = typename ControlInterface::template StaticTransformType<
505     detail::DispatcherBaseControlSignatureTagCheck>::type;
506   using ExecutionSignatureCheck = typename ExecutionInterface::template StaticTransformType<
507     detail::DispatcherBaseExecutionSignatureTagCheck>::type;
508 
509   template <typename... Args>
510   VTKM_CONT void StartInvoke(Args&&... args) const
511   {
512     using ParameterInterface =
513       vtkm::internal::FunctionInterface<void(vtkm::internal::remove_cvref<Args>...)>;
514 
515     VTKM_STATIC_ASSERT_MSG(ParameterInterface::ARITY == NUM_INVOKE_PARAMS,
516                            "Dispatcher Invoke called with wrong number of arguments.");
517 
518     static_assert(
519       std::is_base_of<BaseWorkletType, WorkletType>::value,
520       "The worklet being scheduled by this dispatcher doesn't match the type of the dispatcher");
521 
522     // Check if the placeholders defined in the execution environment exceed the max bound
523     // defined in the control environment by throwing a nice compile error.
524     using ComponentSig = typename ExecutionInterface::ComponentSig;
525     brigand::for_each<ComponentSig>(PlaceholderValidator<NUM_INVOKE_PARAMS>{});
526 
527     //We need to determine if we have the need to do any dynamic
528     //transforms. This is fairly simple of a query. We just need to check
529     //everything in the FunctionInterface and see if any of them have the
530     //proper dynamic trait. Doing this, allows us to generate zero dynamic
531     //check & convert code when we already know all the types. This results
532     //in smaller executables and libraries.
533     using ParamTypes = typename ParameterInterface::ParameterSig;
534     using HasDynamicTypes =
535       brigand::fold<ParamTypes,
536                     std::false_type,
537                     detail::DetermineIfHasDynamicParameter<brigand::_element, brigand::_state>>;
538 
539     this->StartInvokeDynamic(HasDynamicTypes(), std::forward<Args>(args)...);
540   }
541 
542   template <typename... Args>
543   VTKM_CONT void StartInvokeDynamic(std::true_type, Args&&... args) const
544   {
545     // As we do the dynamic transform, we are also going to check the static
546     // type against the TypeCheckTag in the ControlSignature tags. To do this,
547     // the check needs access to both the parameter (in the parameters
548     // argument) and the ControlSignature tags (in the ControlInterface type).
549     using ContParamsInfo =
550       vtkm::internal::detail::FunctionSigInfo<typename WorkletType::ControlSignature>;
551     typename ContParamsInfo::Parameters parameters;
552     detail::deduce(*this, parameters, std::forward<Args>(args)...);
553   }
554 
555   template <typename... Args>
556   VTKM_CONT void StartInvokeDynamic(std::false_type, Args&&... args) const
557   {
558     using ParameterInterface =
559       vtkm::internal::FunctionInterface<void(vtkm::internal::remove_cvref<Args>...)>;
560 
561     //Nothing requires a conversion from dynamic to static types, so
562     //next we need to verify that each argument's type is correct. If not
563     //we need to throw a nice compile time error
564     using ParamTypes = typename ParameterInterface::ParameterSig;
565     using ContSigTypes = typename vtkm::internal::detail::FunctionSigInfo<
566       typename WorkletType::ControlSignature>::Parameters;
567 
568     //isAllValid will throw a compile error if everything doesn't match
569     using isAllValid = brigand::fold<
570       ParamTypes,
571       std::integral_constant<std::size_t, 0>,
572       typename detail::DetermineHasCorrectParameters<WorkletType>::
573         template Functor<brigand::_element, brigand::_state, brigand::pin<ContSigTypes>>>;
574 
575     //this warning exists so that we don't get a warning from not using isAllValid
576     using expectedLen = std::integral_constant<std::size_t, sizeof...(Args)>;
577     static_assert(isAllValid::value == expectedLen::value,
578                   "All arguments failed the TypeCheck pass");
579 
580     auto fi =
581       vtkm::internal::make_FunctionInterface<void, vtkm::internal::remove_cvref<Args>...>(args...);
582     auto ivc = vtkm::internal::Invocation<ParameterInterface,
583                                           ControlInterface,
584                                           ExecutionInterface,
585                                           WorkletType::InputDomain::INDEX,
586                                           vtkm::internal::NullType,
587                                           vtkm::internal::NullType>(
588       fi, vtkm::internal::NullType{}, vtkm::internal::NullType{});
589     static_cast<const DerivedClass*>(this)->DoInvoke(ivc);
590   }
591 
592 public:
593   //@{
594   /// Setting the device ID will force the execute to happen on a particular device. If no device
595   /// is specified (or the device ID is set to any), then a device will automatically be chosen
596   /// based on the runtime device tracker.
597   ///
598   VTKM_CONT
599   void SetDevice(vtkm::cont::DeviceAdapterId device) { this->Device = device; }
600 
601   VTKM_CONT vtkm::cont::DeviceAdapterId GetDevice() const { return this->Device; }
602   //@}
603 
604   using ScatterType = typename WorkletType::ScatterType;
605   using MaskType = typename WorkletType::MaskType;
606 
607   template <typename... Args>
608   VTKM_CONT void Invoke(Args&&... args) const
609   {
610     VTKM_LOG_SCOPE(vtkm::cont::LogLevel::Perf,
611                    "Invoking Worklet: '%s'",
612                    vtkm::cont::TypeToString<DerivedClass>().c_str());
613     this->StartInvoke(std::forward<Args>(args)...);
614   }
615 
616 protected:
617   // If you get a compile error here about there being no appropriate constructor for ScatterType
618   // or MapType, then that probably means that the worklet you are trying to execute has defined a
619   // custom ScatterType or MaskType and that you need to create one (because there is no default
620   // way to construct the scatter or mask).
621   VTKM_CONT
622   DispatcherBase(const WorkletType& worklet = WorkletType(),
623                  const ScatterType& scatter = ScatterType(),
624                  const MaskType& mask = MaskType())
625     : Worklet(worklet)
626     , Scatter(scatter)
627     , Mask(mask)
628     , Device(vtkm::cont::DeviceAdapterTagAny())
629   {
630   }
631 
632   // If you get a compile error here about there being no appropriate constructor for MaskType,
633   // then that probably means that the worklet you are trying to execute has defined a custom
634   // MaskType and that you need to create one (because there is no default way to construct the
635   // mask).
636   VTKM_CONT
637   DispatcherBase(const ScatterType& scatter, const MaskType& mask = MaskType())
638     : Worklet(WorkletType())
639     , Scatter(scatter)
640     , Mask(mask)
641     , Device(vtkm::cont::DeviceAdapterTagAny())
642   {
643   }
644 
645   // If you get a compile error here about there being no appropriate constructor for ScatterType,
646   // then that probably means that the worklet you are trying to execute has defined a custom
647   // ScatterType and that you need to create one (because there is no default way to construct the
648   // scatter).
649   VTKM_CONT
650   DispatcherBase(const WorkletType& worklet,
651                  const MaskType& mask,
652                  const ScatterType& scatter = ScatterType())
653     : Worklet(worklet)
654     , Scatter(scatter)
655     , Mask(mask)
656     , Device(vtkm::cont::DeviceAdapterTagAny())
657   {
658   }
659 
660   // If you get a compile error here about there being no appropriate constructor for ScatterType,
661   // then that probably means that the worklet you are trying to execute has defined a custom
662   // ScatterType and that you need to create one (because there is no default way to construct the
663   // scatter).
664   VTKM_CONT
665   DispatcherBase(const MaskType& mask, const ScatterType& scatter = ScatterType())
666     : Worklet(WorkletType())
667     , Scatter(scatter)
668     , Mask(mask)
669     , Device(vtkm::cont::DeviceAdapterTagAny())
670   {
671   }
672 
673   friend struct internal::detail::DispatcherBaseTryExecuteFunctor;
674 
675   template <typename Invocation>
676   VTKM_CONT void BasicInvoke(Invocation& invocation, vtkm::Id numInstances) const
677   {
678     bool success =
679       vtkm::cont::TryExecuteOnDevice(this->Device,
680                                      internal::detail::DispatcherBaseTryExecuteFunctor(),
681                                      this,
682                                      invocation,
683                                      numInstances);
684     if (!success)
685     {
686       throw vtkm::cont::ErrorExecution("Failed to execute worklet on any device.");
687     }
688   }
689 
690   template <typename Invocation>
691   VTKM_CONT void BasicInvoke(Invocation& invocation, vtkm::Id2 dimensions) const
692   {
693     this->BasicInvoke(invocation, vtkm::Id3(dimensions[0], dimensions[1], 1));
694   }
695 
696   template <typename Invocation>
697   VTKM_CONT void BasicInvoke(Invocation& invocation, vtkm::Id3 dimensions) const
698   {
699     bool success =
700       vtkm::cont::TryExecuteOnDevice(this->Device,
701                                      internal::detail::DispatcherBaseTryExecuteFunctor(),
702                                      this,
703                                      invocation,
704                                      dimensions);
705     if (!success)
706     {
707       throw vtkm::cont::ErrorExecution("Failed to execute worklet on any device.");
708     }
709   }
710 
711   WorkletType Worklet;
712   ScatterType Scatter;
713   MaskType Mask;
714 
715 private:
716   // Dispatchers cannot be copied
717   DispatcherBase(const MyType&) = delete;
718   void operator=(const MyType&) = delete;
719 
720   vtkm::cont::DeviceAdapterId Device;
721 
722   template <typename Invocation,
723             typename InputRangeType,
724             typename OutputRangeType,
725             typename ThreadRangeType,
726             typename DeviceAdapter>
727   VTKM_CONT void InvokeTransportParameters(Invocation& invocation,
728                                            const InputRangeType& inputRange,
729                                            OutputRangeType&& outputRange,
730                                            ThreadRangeType&& threadRange,
731                                            DeviceAdapter device) const
732   {
733     // This token represents the scope of the execution objects. It should
734     // exist as long as things run on the device.
735     vtkm::cont::Token token;
736 
737     // The first step in invoking a worklet is to transport the arguments to
738     // the execution environment. The invocation object passed to this function
739     // contains the parameters passed to Invoke in the control environment. We
740     // will use the template magic in the FunctionInterface class to invoke the
741     // appropriate Transport class on each parameter and get a list of
742     // execution objects (corresponding to the arguments of the Invoke in the
743     // control environment) in a FunctionInterface. Specifically, we use a
744     // static transform of the FunctionInterface to call the transport on each
745     // argument and return the corresponding execution environment object.
746     using ParameterInterfaceType = typename Invocation::ParameterInterface;
747     ParameterInterfaceType& parameters = invocation.Parameters;
748 
749     using TransportFunctorType =
750       detail::DispatcherBaseTransportFunctor<typename Invocation::ControlInterface,
751                                              typename Invocation::InputDomainType,
752                                              DeviceAdapter>;
753     using ExecObjectParameters =
754       typename ParameterInterfaceType::template StaticTransformType<TransportFunctorType>::type;
755 
756     ExecObjectParameters execObjectParameters = parameters.StaticTransformCont(
757       TransportFunctorType(invocation.GetInputDomain(), inputRange, outputRange, token));
758 
759     // Get the arrays used for scattering input to output.
760     typename ScatterType::OutputToInputMapType outputToInputMap =
761       this->Scatter.GetOutputToInputMap(inputRange);
762     typename ScatterType::VisitArrayType visitArray = this->Scatter.GetVisitArray(inputRange);
763 
764     // Get the arrays used for masking output elements.
765     typename MaskType::ThreadToOutputMapType threadToOutputMap =
766       this->Mask.GetThreadToOutputMap(inputRange);
767 
768     // Replace the parameters in the invocation with the execution object and
769     // pass to next step of Invoke. Also add the scatter information.
770     vtkm::internal::Invocation<ExecObjectParameters,
771                                typename Invocation::ControlInterface,
772                                typename Invocation::ExecutionInterface,
773                                Invocation::InputDomainIndex,
774                                decltype(outputToInputMap.PrepareForInput(device, token)),
775                                decltype(visitArray.PrepareForInput(device, token)),
776                                decltype(threadToOutputMap.PrepareForInput(device, token)),
777                                DeviceAdapter>
778       changedInvocation(execObjectParameters,
779                         outputToInputMap.PrepareForInput(device, token),
780                         visitArray.PrepareForInput(device, token),
781                         threadToOutputMap.PrepareForInput(device, token));
782 
783     this->InvokeSchedule(changedInvocation, threadRange, device);
784   }
785 
786   template <typename Invocation, typename RangeType, typename DeviceAdapter>
787   VTKM_CONT void InvokeSchedule(const Invocation& invocation, RangeType range, DeviceAdapter) const
788   {
789     using Algorithm = vtkm::cont::DeviceAdapterAlgorithm<DeviceAdapter>;
790     using TaskTypes = typename vtkm::cont::DeviceTaskTypes<DeviceAdapter>;
791 
792     // The TaskType class handles the magic of fetching values
793     // for each instance and calling the worklet's function.
794     // The TaskType will evaluate to one of the following classes:
795     //
796     // vtkm::exec::internal::TaskSingular
797     // vtkm::exec::internal::TaskTiling1D
798     // vtkm::exec::internal::TaskTiling3D
799     auto task = TaskTypes::MakeTask(this->Worklet, invocation, range);
800     Algorithm::ScheduleTask(task, range);
801   }
802 };
803 }
804 }
805 } // namespace vtkm::worklet::internal
806 
807 #endif //vtk_m_worklet_internal_DispatcherBase_h
808