1 //============================================================================
2 //  Copyright (c) Kitware, Inc.
3 //  All rights reserved.
4 //  See LICENSE.txt for details.
5 //  This software is distributed WITHOUT ANY WARRANTY; without even
6 //  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
7 //  PURPOSE.  See the above copyright notice for more information.
8 //
9 //  Copyright 2014 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
10 //  Copyright 2014 UT-Battelle, LLC.
11 //  Copyright 2014 Los Alamos National Security.
12 //
13 //  Under the terms of Contract DE-NA0003525 with NTESS,
14 //  the U.S. Government retains certain rights in this software.
15 //
16 //  Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National
17 //  Laboratory (LANL), the U.S. Government retains certain rights in
18 //  this software.
19 //============================================================================
20 
21 #ifndef vtk_m_worklet_KdTree3DConstruction_h
22 #define vtk_m_worklet_KdTree3DConstruction_h
23 
24 #include <vtkm/Math.h>
25 #include <vtkm/cont/ArrayHandle.h>
26 #include <vtkm/cont/ArrayHandleCounting.h>
27 #include <vtkm/cont/ArrayHandleReverse.h>
28 #include <vtkm/cont/DeviceAdapter.h>
29 #include <vtkm/cont/DeviceAdapterAlgorithm.h>
30 #include <vtkm/cont/arg/ControlSignatureTagBase.h>
31 #include <vtkm/cont/serial/DeviceAdapterSerial.h>
32 #include <vtkm/cont/testing/Testing.h>
33 #include <vtkm/worklet/DispatcherMapField.h>
34 #include <vtkm/worklet/WorkletMapField.h>
35 #include <vtkm/worklet/internal/DispatcherBase.h>
36 #include <vtkm/worklet/internal/WorkletBase.h>
37 
38 namespace vtkm
39 {
40 namespace worklet
41 {
42 namespace spatialstructure
43 {
44 
45 class KdTree3DConstruction
46 {
47 public:
48   ////////// General WORKLET for Kd-tree  //////
49   class ComputeFlag : public vtkm::worklet::WorkletMapField
50   {
51   public:
52     using ControlSignature = void(FieldIn<> rank, FieldIn<> pointCountInSeg, FieldOut<> flag);
53     using ExecutionSignature = void(_1, _2, _3);
54 
55     VTKM_CONT
ComputeFlag()56     ComputeFlag() {}
57 
58     template <typename T>
operator()59     VTKM_EXEC void operator()(const T& rank, const T& pointCountInSeg, T& flag) const
60     {
61       if (static_cast<float>(rank) >= static_cast<float>(pointCountInSeg) / 2.0f)
62         flag = 1; //right subtree
63       else
64         flag = 0; //left subtree
65     }
66   };
67 
68   class InverseArray : public vtkm::worklet::WorkletMapField
69   { //only for 0/1 array
70   public:
71     using ControlSignature = void(FieldIn<> in, FieldOut<> out);
72     using ExecutionSignature = void(_1, _2);
73 
74     VTKM_CONT
InverseArray()75     InverseArray() {}
76 
77     template <typename T>
operator()78     VTKM_EXEC void operator()(const T& in, T& out) const
79     {
80       if (in == 0)
81         out = 1;
82       else
83         out = 0;
84     }
85   };
86 
87   class SegmentedSplitTransform : public vtkm::worklet::WorkletMapField
88   {
89   public:
90     using ControlSignature =
91       void(FieldIn<> B, FieldIn<> D, FieldIn<> F, FieldIn<> G, FieldIn<> H, FieldOut<> I);
92     using ExecutionSignature = void(_1, _2, _3, _4, _5, _6);
93 
94     VTKM_CONT
SegmentedSplitTransform()95     SegmentedSplitTransform() {}
96 
97     template <typename T>
operator()98     VTKM_EXEC void operator()(const T& B, const T& D, const T& F, const T& G, const T& H, T& I)
99       const
100     {
101       if (B == 1)
102       {
103         I = F + H + D;
104       }
105       else
106       {
107         I = F + G - 1;
108       }
109     }
110   };
111 
112   class ScatterArray : public vtkm::worklet::WorkletMapField
113   {
114   public:
115     using ControlSignature = void(FieldIn<> in, FieldIn<> index, WholeArrayOut<> out);
116     using ExecutionSignature = void(_1, _2, _3);
117 
118     VTKM_CONT
ScatterArray()119     ScatterArray() {}
120 
121     template <typename T, typename OutputArrayPortalType>
operator()122     VTKM_EXEC void operator()(const T& in,
123                               const T& index,
124                               const OutputArrayPortalType& outputPortal) const
125     {
126       outputPortal.Set(index, in);
127     }
128   };
129 
130   class NewSegmentId : public vtkm::worklet::WorkletMapField
131   {
132   public:
133     using ControlSignature = void(FieldIn<> inSegmentId, FieldIn<> flag, FieldOut<> outSegmentId);
134     using ExecutionSignature = void(_1, _2, _3);
135 
136     VTKM_CONT
NewSegmentId()137     NewSegmentId() {}
138 
139     template <typename T>
operator()140     VTKM_EXEC void operator()(const T& oldSegId, const T& flag, T& newSegId) const
141     {
142       if (flag == 0)
143         newSegId = oldSegId * 2;
144       else
145         newSegId = oldSegId * 2 + 1;
146     }
147   };
148 
149   class SaveSplitPointId : public vtkm::worklet::WorkletMapField
150   {
151   public:
152     using ControlSignature = void(FieldIn<> pointId,
153                                   FieldIn<> flag,
154                                   FieldIn<> oldSplitPointId,
155                                   FieldOut<> newSplitPointId);
156     using ExecutionSignature = void(_1, _2, _3, _4);
157 
158     VTKM_CONT
SaveSplitPointId()159     SaveSplitPointId() {}
160 
161     template <typename T>
operator()162     VTKM_EXEC void operator()(const T& pointId,
163                               const T& flag,
164                               const T& oldSplitPointId,
165                               T& newSplitPointId) const
166     {
167       if (flag == 0) //do not change
168         newSplitPointId = oldSplitPointId;
169       else //split point id
170         newSplitPointId = pointId;
171     }
172   };
173 
174   class FindSplitPointId : public vtkm::worklet::WorkletMapField
175   {
176   public:
177     using ControlSignature = void(FieldIn<> pointId, FieldIn<> rank, FieldOut<> splitIdInsegment);
178     using ExecutionSignature = void(_1, _2, _3);
179 
180     VTKM_CONT
FindSplitPointId()181     FindSplitPointId() {}
182 
183     template <typename T>
operator()184     VTKM_EXEC void operator()(const T& pointId, const T& rank, T& splitIdInsegment) const
185     {
186       if (rank == 0) //do not change
187         splitIdInsegment = pointId;
188       else                     //split point id
189         splitIdInsegment = -1; //indicate this is not split point
190     }
191   };
192 
193   class ArrayAdd : public vtkm::worklet::WorkletMapField
194   {
195   public:
196     using ControlSignature = void(FieldIn<> inArray0, FieldIn<> inArray1, FieldOut<> outArray);
197     using ExecutionSignature = void(_1, _2, _3);
198 
199     VTKM_CONT
ArrayAdd()200     ArrayAdd() {}
201 
202     template <typename T>
operator()203     VTKM_EXEC void operator()(const T& in0, const T& in1, T& out) const
204     {
205       out = in0 + in1;
206     }
207   };
208 
209   class SeprateVec3AryHandle : public vtkm::worklet::WorkletMapField
210   {
211   public:
212     using ControlSignature = void(FieldIn<> inVec3,
213                                   FieldOut<> out0,
214                                   FieldOut<> out1,
215                                   FieldOut<> out2);
216     using ExecutionSignature = void(_1, _2, _3, _4);
217 
218     VTKM_CONT
SeprateVec3AryHandle()219     SeprateVec3AryHandle() {}
220 
221     template <typename T>
operator()222     VTKM_EXEC void operator()(const Vec<T, 3>& inVec3, T& out0, T& out1, T& out2) const
223     {
224       out0 = inVec3[0];
225       out1 = inVec3[1];
226       out2 = inVec3[2];
227     }
228   };
229 
230   ////////// General worklet WRAPPER for Kd-tree //////
231   template <typename T, class BinaryFunctor, typename DeviceAdapter>
ReverseScanInclusiveByKey(vtkm::cont::ArrayHandle<T> & keyHandle,vtkm::cont::ArrayHandle<T> & dataHandle,BinaryFunctor binary_functor,DeviceAdapter vtkmNotUsed (device))232   vtkm::cont::ArrayHandle<T> ReverseScanInclusiveByKey(vtkm::cont::ArrayHandle<T>& keyHandle,
233                                                        vtkm::cont::ArrayHandle<T>& dataHandle,
234                                                        BinaryFunctor binary_functor,
235                                                        DeviceAdapter vtkmNotUsed(device))
236   {
237     using Algorithm = typename vtkm::cont::DeviceAdapterAlgorithm<DeviceAdapter>;
238 
239     vtkm::cont::ArrayHandle<T> resultHandle;
240 
241     auto reversedResultHandle = vtkm::cont::make_ArrayHandleReverse(resultHandle);
242 
243     Algorithm::ScanInclusiveByKey(vtkm::cont::make_ArrayHandleReverse(keyHandle),
244                                   vtkm::cont::make_ArrayHandleReverse(dataHandle),
245                                   reversedResultHandle,
246                                   binary_functor);
247 
248     return resultHandle;
249   }
250 
251   template <typename T, typename DeviceAdapter>
Inverse01ArrayWrapper(vtkm::cont::ArrayHandle<T> & inputHandle,DeviceAdapter vtkmNotUsed (device))252   vtkm::cont::ArrayHandle<T> Inverse01ArrayWrapper(vtkm::cont::ArrayHandle<T>& inputHandle,
253                                                    DeviceAdapter vtkmNotUsed(device))
254   {
255     vtkm::cont::ArrayHandle<T> InverseHandle;
256     InverseArray invWorklet;
257     vtkm::worklet::DispatcherMapField<InverseArray> inverseArrayDispatcher(invWorklet);
258     inverseArrayDispatcher.SetDevice(DeviceAdapter());
259     inverseArrayDispatcher.Invoke(inputHandle, InverseHandle);
260     return InverseHandle;
261   }
262 
263   template <typename T, typename DeviceAdapter>
ScatterArrayWrapper(vtkm::cont::ArrayHandle<T> & inputHandle,vtkm::cont::ArrayHandle<T> & indexHandle,DeviceAdapter vtkmNotUsed (device))264   vtkm::cont::ArrayHandle<T> ScatterArrayWrapper(vtkm::cont::ArrayHandle<T>& inputHandle,
265                                                  vtkm::cont::ArrayHandle<T>& indexHandle,
266                                                  DeviceAdapter vtkmNotUsed(device))
267   {
268     vtkm::cont::ArrayHandle<T> outputHandle;
269     outputHandle.Allocate(inputHandle.GetNumberOfValues());
270     ScatterArray scatterWorklet;
271     vtkm::worklet::DispatcherMapField<ScatterArray> scatterArrayDispatcher(scatterWorklet);
272     scatterArrayDispatcher.SetDevice(DeviceAdapter());
273     scatterArrayDispatcher.Invoke(inputHandle, indexHandle, outputHandle);
274     return outputHandle;
275   }
276 
277   template <typename T, typename DeviceAdapter>
NewKeyWrapper(vtkm::cont::ArrayHandle<T> & oldSegIdHandle,vtkm::cont::ArrayHandle<T> & flagHandle,DeviceAdapter vtkmNotUsed (device))278   vtkm::cont::ArrayHandle<T> NewKeyWrapper(vtkm::cont::ArrayHandle<T>& oldSegIdHandle,
279                                            vtkm::cont::ArrayHandle<T>& flagHandle,
280                                            DeviceAdapter vtkmNotUsed(device))
281   {
282     vtkm::cont::ArrayHandle<T> newSegIdHandle;
283     NewSegmentId newsegidWorklet;
284     vtkm::worklet::DispatcherMapField<NewSegmentId> newSegIdDispatcher(newsegidWorklet);
285     newSegIdDispatcher.SetDevice(DeviceAdapter());
286     newSegIdDispatcher.Invoke(oldSegIdHandle, flagHandle, newSegIdHandle);
287     return newSegIdHandle;
288   }
289 
290   template <typename T, typename DeviceAdapter>
SaveSplitPointIdWrapper(vtkm::cont::ArrayHandle<T> & pointIdHandle,vtkm::cont::ArrayHandle<T> & flagHandle,vtkm::cont::ArrayHandle<T> & rankHandle,vtkm::cont::ArrayHandle<T> & oldSplitIdHandle,DeviceAdapter device)291   vtkm::cont::ArrayHandle<T> SaveSplitPointIdWrapper(vtkm::cont::ArrayHandle<T>& pointIdHandle,
292                                                      vtkm::cont::ArrayHandle<T>& flagHandle,
293                                                      vtkm::cont::ArrayHandle<T>& rankHandle,
294                                                      vtkm::cont::ArrayHandle<T>& oldSplitIdHandle,
295                                                      DeviceAdapter device)
296   {
297     vtkm::cont::ArrayHandle<T> splitIdInSegmentHandle;
298     FindSplitPointId findSplitPointIdWorklet;
299     vtkm::worklet::DispatcherMapField<FindSplitPointId> findSplitPointIdWorkletDispatcher(
300       findSplitPointIdWorklet);
301     findSplitPointIdWorkletDispatcher.SetDevice(DeviceAdapter());
302     findSplitPointIdWorkletDispatcher.Invoke(pointIdHandle, rankHandle, splitIdInSegmentHandle);
303 
304     vtkm::cont::ArrayHandle<T> splitIdInSegmentByScanHandle =
305       ReverseScanInclusiveByKey(flagHandle, splitIdInSegmentHandle, vtkm::Maximum(), device);
306 
307     vtkm::cont::ArrayHandle<T> splitIdHandle;
308     SaveSplitPointId saveSplitPointIdWorklet;
309     vtkm::worklet::DispatcherMapField<SaveSplitPointId> saveSplitPointIdWorkletDispatcher(
310       saveSplitPointIdWorklet);
311     saveSplitPointIdWorkletDispatcher.SetDevice(DeviceAdapter());
312     saveSplitPointIdWorkletDispatcher.Invoke(
313       splitIdInSegmentByScanHandle, flagHandle, oldSplitIdHandle, splitIdHandle);
314 
315     return splitIdHandle;
316   }
317 
318   template <typename T, typename DeviceAdapter>
ArrayAddWrapper(vtkm::cont::ArrayHandle<T> & array0Handle,vtkm::cont::ArrayHandle<T> & array1Handle,DeviceAdapter vtkmNotUsed (device))319   vtkm::cont::ArrayHandle<T> ArrayAddWrapper(vtkm::cont::ArrayHandle<T>& array0Handle,
320                                              vtkm::cont::ArrayHandle<T>& array1Handle,
321                                              DeviceAdapter vtkmNotUsed(device))
322   {
323     vtkm::cont::ArrayHandle<T> resultHandle;
324     ArrayAdd arrayAddWorklet;
325     vtkm::worklet::DispatcherMapField<ArrayAdd> arrayAddDispatcher(arrayAddWorklet);
326     arrayAddDispatcher.SetDevice(DeviceAdapter());
327     arrayAddDispatcher.Invoke(array0Handle, array1Handle, resultHandle);
328     return resultHandle;
329   }
330 
331   ///////////////////////////////////////////////////
332   ////////General Kd tree function //////////////////
333   ///////////////////////////////////////////////////
334   template <typename T, typename DeviceAdapter>
ComputeFlagProcedure(vtkm::cont::ArrayHandle<T> & rankHandle,vtkm::cont::ArrayHandle<T> & segIdHandle,DeviceAdapter device)335   vtkm::cont::ArrayHandle<T> ComputeFlagProcedure(vtkm::cont::ArrayHandle<T>& rankHandle,
336                                                   vtkm::cont::ArrayHandle<T>& segIdHandle,
337                                                   DeviceAdapter device)
338   {
339     using Algorithm = typename vtkm::cont::DeviceAdapterAlgorithm<DeviceAdapter>;
340 
341     vtkm::cont::ArrayHandle<T> segCountAryHandle;
342     {
343       vtkm::cont::ArrayHandle<T> tmpAryHandle;
344       vtkm::cont::ArrayHandleConstant<T> constHandle(1, rankHandle.GetNumberOfValues());
345       Algorithm::ScanInclusiveByKey(
346         segIdHandle, constHandle, tmpAryHandle, vtkm::Add()); //compute ttl segs in segment
347 
348       segCountAryHandle =
349         ReverseScanInclusiveByKey(segIdHandle, tmpAryHandle, vtkm::Maximum(), device);
350     }
351 
352     vtkm::cont::ArrayHandle<T> flagHandle;
353     vtkm::worklet::DispatcherMapField<ComputeFlag> computeFlagDispatcher;
354     computeFlagDispatcher.SetDevice(DeviceAdapter());
355     computeFlagDispatcher.Invoke(rankHandle, segCountAryHandle, flagHandle);
356 
357     return flagHandle;
358   }
359 
360   template <typename T, typename DeviceAdapter>
SegmentedSplitProcedure(vtkm::cont::ArrayHandle<T> & A_Handle,vtkm::cont::ArrayHandle<T> & B_Handle,vtkm::cont::ArrayHandle<T> & C_Handle,DeviceAdapter device)361   vtkm::cont::ArrayHandle<T> SegmentedSplitProcedure(vtkm::cont::ArrayHandle<T>& A_Handle,
362                                                      vtkm::cont::ArrayHandle<T>& B_Handle,
363                                                      vtkm::cont::ArrayHandle<T>& C_Handle,
364                                                      DeviceAdapter device)
365   {
366     using Algorithm = typename vtkm::cont::DeviceAdapterAlgorithm<DeviceAdapter>;
367 
368     vtkm::cont::ArrayHandle<T> D_Handle;
369     T initValue = 0;
370     Algorithm::ScanExclusiveByKey(C_Handle, B_Handle, D_Handle, initValue, vtkm::Add());
371 
372     vtkm::cont::ArrayHandleCounting<T> Ecouting_Handle(0, 1, A_Handle.GetNumberOfValues());
373     vtkm::cont::ArrayHandle<T> E_Handle;
374     Algorithm::Copy(Ecouting_Handle, E_Handle);
375 
376     vtkm::cont::ArrayHandle<T> F_Handle;
377     Algorithm::ScanInclusiveByKey(C_Handle, E_Handle, F_Handle, vtkm::Minimum());
378 
379     vtkm::cont::ArrayHandle<T> InvB_Handle = Inverse01ArrayWrapper(B_Handle, device);
380     vtkm::cont::ArrayHandle<T> G_Handle;
381     Algorithm::ScanInclusiveByKey(C_Handle, InvB_Handle, G_Handle, vtkm::Add());
382 
383     vtkm::cont::ArrayHandle<T> H_Handle =
384       ReverseScanInclusiveByKey(C_Handle, G_Handle, vtkm::Maximum(), device);
385 
386     vtkm::cont::ArrayHandle<T> I_Handle;
387     SegmentedSplitTransform sstWorklet;
388     vtkm::worklet::DispatcherMapField<SegmentedSplitTransform> segmentedSplitTransformDispatcher(
389       sstWorklet);
390     segmentedSplitTransformDispatcher.SetDevice(DeviceAdapter());
391     segmentedSplitTransformDispatcher.Invoke(
392       B_Handle, D_Handle, F_Handle, G_Handle, H_Handle, I_Handle);
393 
394     return ScatterArrayWrapper(A_Handle, I_Handle, device);
395   }
396 
397   template <typename T, typename DeviceAdapter>
RenumberRanksProcedure(vtkm::cont::ArrayHandle<T> & A_Handle,vtkm::cont::ArrayHandle<T> & B_Handle,vtkm::cont::ArrayHandle<T> & C_Handle,vtkm::cont::ArrayHandle<T> & D_Handle,DeviceAdapter device)398   void RenumberRanksProcedure(vtkm::cont::ArrayHandle<T>& A_Handle,
399                               vtkm::cont::ArrayHandle<T>& B_Handle,
400                               vtkm::cont::ArrayHandle<T>& C_Handle,
401                               vtkm::cont::ArrayHandle<T>& D_Handle,
402                               DeviceAdapter device)
403   {
404     using Algorithm = typename vtkm::cont::DeviceAdapterAlgorithm<DeviceAdapter>;
405 
406     vtkm::Id nPoints = A_Handle.GetNumberOfValues();
407 
408     vtkm::cont::ArrayHandleCounting<T> Ecouting_Handle(0, 1, nPoints);
409     vtkm::cont::ArrayHandle<T> E_Handle;
410     Algorithm::Copy(Ecouting_Handle, E_Handle);
411 
412     vtkm::cont::ArrayHandle<T> F_Handle;
413     Algorithm::ScanInclusiveByKey(D_Handle, E_Handle, F_Handle, vtkm::Minimum());
414 
415     vtkm::cont::ArrayHandle<T> G_Handle;
416     G_Handle = ArrayAddWrapper(A_Handle, F_Handle, device);
417 
418     vtkm::cont::ArrayHandleConstant<T> HConstant_Handle(1, nPoints);
419     vtkm::cont::ArrayHandle<T> H_Handle;
420     Algorithm::Copy(HConstant_Handle, H_Handle);
421 
422     vtkm::cont::ArrayHandle<T> I_Handle;
423     T initValue = 0;
424     Algorithm::ScanExclusiveByKey(C_Handle, H_Handle, I_Handle, initValue, vtkm::Add());
425 
426     vtkm::cont::ArrayHandle<T> J_Handle;
427     J_Handle = ScatterArrayWrapper(I_Handle, G_Handle, device);
428 
429     vtkm::cont::ArrayHandle<T> K_Handle;
430     K_Handle = ScatterArrayWrapper(B_Handle, G_Handle, device);
431 
432     vtkm::cont::ArrayHandle<T> L_Handle;
433     L_Handle = SegmentedSplitProcedure(J_Handle, K_Handle, D_Handle, device);
434 
435     vtkm::cont::ArrayHandle<T> M_Handle;
436     Algorithm::ScanInclusiveByKey(C_Handle, E_Handle, M_Handle, vtkm::Minimum());
437 
438     vtkm::cont::ArrayHandle<T> N_Handle;
439     N_Handle = ArrayAddWrapper(L_Handle, M_Handle, device);
440 
441     A_Handle = ScatterArrayWrapper(I_Handle, N_Handle, device);
442   }
443 
444   /////////////3D construction      /////////////////////
445   /// \brief Segmented split for 3D x, y, z coordinates
446   ///
447   /// Split \c pointId_Handle, \c X_Handle, \c Y_Handle and \c Z_Handle within each segment
448   /// as indicated by \c segId_Handle according to flags in \c flag_Handle.
449   ///
450   /// \tparam T
451   /// \tparam DeviceAdapter
452   /// \param pointId_Handle
453   /// \param flag_Handle
454   /// \param segId_Handle
455   /// \param X_Handle
456   /// \param Y_Handle
457   /// \param Z_Handle
458   /// \param device
459   template <typename T, typename DeviceAdapter>
SegmentedSplitProcedure3D(vtkm::cont::ArrayHandle<T> & pointId_Handle,vtkm::cont::ArrayHandle<T> & flag_Handle,vtkm::cont::ArrayHandle<T> & segId_Handle,vtkm::cont::ArrayHandle<T> & X_Handle,vtkm::cont::ArrayHandle<T> & Y_Handle,vtkm::cont::ArrayHandle<T> & Z_Handle,DeviceAdapter device)460   void SegmentedSplitProcedure3D(vtkm::cont::ArrayHandle<T>& pointId_Handle,
461                                  vtkm::cont::ArrayHandle<T>& flag_Handle,
462                                  vtkm::cont::ArrayHandle<T>& segId_Handle,
463                                  vtkm::cont::ArrayHandle<T>& X_Handle,
464                                  vtkm::cont::ArrayHandle<T>& Y_Handle,
465                                  vtkm::cont::ArrayHandle<T>& Z_Handle,
466                                  DeviceAdapter device)
467   {
468     using Algorithm = typename vtkm::cont::DeviceAdapterAlgorithm<DeviceAdapter>;
469 
470     vtkm::cont::ArrayHandle<T> D_Handle;
471     T initValue = 0;
472     Algorithm::ScanExclusiveByKey(segId_Handle, flag_Handle, D_Handle, initValue, vtkm::Add());
473 
474     vtkm::cont::ArrayHandleCounting<T> Ecouting_Handle(0, 1, pointId_Handle.GetNumberOfValues());
475     vtkm::cont::ArrayHandle<T> E_Handle;
476     Algorithm::Copy(Ecouting_Handle, E_Handle);
477 
478     vtkm::cont::ArrayHandle<T> F_Handle;
479     Algorithm::ScanInclusiveByKey(segId_Handle, E_Handle, F_Handle, vtkm::Minimum());
480 
481     vtkm::cont::ArrayHandle<T> InvB_Handle = Inverse01ArrayWrapper(flag_Handle, device);
482     vtkm::cont::ArrayHandle<T> G_Handle;
483     Algorithm::ScanInclusiveByKey(segId_Handle, InvB_Handle, G_Handle, vtkm::Add());
484 
485     vtkm::cont::ArrayHandle<T> H_Handle =
486       ReverseScanInclusiveByKey(segId_Handle, G_Handle, vtkm::Maximum(), device);
487 
488     vtkm::cont::ArrayHandle<T> I_Handle;
489     SegmentedSplitTransform sstWorklet;
490     vtkm::worklet::DispatcherMapField<SegmentedSplitTransform> segmentedSplitTransformDispatcher(
491       sstWorklet);
492     segmentedSplitTransformDispatcher.SetDevice(DeviceAdapter());
493     segmentedSplitTransformDispatcher.Invoke(
494       flag_Handle, D_Handle, F_Handle, G_Handle, H_Handle, I_Handle);
495 
496     pointId_Handle = ScatterArrayWrapper(pointId_Handle, I_Handle, device);
497 
498     flag_Handle = ScatterArrayWrapper(flag_Handle, I_Handle, device);
499 
500     X_Handle = ScatterArrayWrapper(X_Handle, I_Handle, device);
501 
502     Y_Handle = ScatterArrayWrapper(Y_Handle, I_Handle, device);
503 
504     Z_Handle = ScatterArrayWrapper(Z_Handle, I_Handle, device);
505   }
506 
507   /// \brief Perform one level of KD-Tree construction
508   ///
509   /// Construct a level of KD-Tree by segemeted splits (partitioning) of \c pointId_Handle,
510   /// \c xrank_Handle, \c yrank_Handle and \c zrank_Handle according to the medium element
511   /// in each segment as indicated by \c segId_Handle alone the axis determined by \c level.
512   /// The split point of each segment will be updated in \c splitId_Handle.
513   template <typename T, typename DeviceAdapter>
OneLevelSplit3D(vtkm::cont::ArrayHandle<T> & pointId_Handle,vtkm::cont::ArrayHandle<T> & xrank_Handle,vtkm::cont::ArrayHandle<T> & yrank_Handle,vtkm::cont::ArrayHandle<T> & zrank_Handle,vtkm::cont::ArrayHandle<T> & segId_Handle,vtkm::cont::ArrayHandle<T> & splitId_Handle,vtkm::Int32 level,DeviceAdapter device)514   void OneLevelSplit3D(vtkm::cont::ArrayHandle<T>& pointId_Handle,
515                        vtkm::cont::ArrayHandle<T>& xrank_Handle,
516                        vtkm::cont::ArrayHandle<T>& yrank_Handle,
517                        vtkm::cont::ArrayHandle<T>& zrank_Handle,
518                        vtkm::cont::ArrayHandle<T>& segId_Handle,
519                        vtkm::cont::ArrayHandle<T>& splitId_Handle,
520                        vtkm::Int32 level,
521                        DeviceAdapter device)
522   {
523     using Algorithm = typename vtkm::cont::DeviceAdapterAlgorithm<DeviceAdapter>;
524 
525     vtkm::cont::ArrayHandle<T> flag_Handle;
526     if (level % 3 == 0)
527     {
528       flag_Handle = ComputeFlagProcedure(xrank_Handle, segId_Handle, device);
529     }
530     else if (level % 3 == 1)
531     {
532       flag_Handle = ComputeFlagProcedure(yrank_Handle, segId_Handle, device);
533     }
534     else
535     {
536       flag_Handle = ComputeFlagProcedure(zrank_Handle, segId_Handle, device);
537     }
538 
539     SegmentedSplitProcedure3D(
540       pointId_Handle, flag_Handle, segId_Handle, xrank_Handle, yrank_Handle, zrank_Handle, device);
541 
542     vtkm::cont::ArrayHandle<T> segIdOld_Handle;
543     Algorithm::Copy(segId_Handle, segIdOld_Handle);
544     segId_Handle = NewKeyWrapper(segIdOld_Handle, flag_Handle, device);
545 
546     RenumberRanksProcedure(xrank_Handle, flag_Handle, segId_Handle, segIdOld_Handle, device);
547     RenumberRanksProcedure(yrank_Handle, flag_Handle, segId_Handle, segIdOld_Handle, device);
548     RenumberRanksProcedure(zrank_Handle, flag_Handle, segId_Handle, segIdOld_Handle, device);
549 
550     if (level % 3 == 0)
551     {
552       splitId_Handle =
553         SaveSplitPointIdWrapper(pointId_Handle, flag_Handle, xrank_Handle, splitId_Handle, device);
554     }
555     else if (level % 3 == 1)
556     {
557       splitId_Handle =
558         SaveSplitPointIdWrapper(pointId_Handle, flag_Handle, yrank_Handle, splitId_Handle, device);
559     }
560     else
561     {
562       splitId_Handle =
563         SaveSplitPointIdWrapper(pointId_Handle, flag_Handle, zrank_Handle, splitId_Handle, device);
564     }
565   }
566 
567   /// \brief Construct KdTree from x y z coordinate vector.
568   ///
569   /// This method constructs an array based KD-Tree from x, y, z coordinates of points in \c
570   /// coordi_Handle. The method rotates between x, y and z axis and splits input points into
571   /// equal halves with respect to the split axis at each level of construction. The indices to
572   /// the leaf nodes are returned in \c pointId_Handle and indices to internal nodes (splits)
573   /// are returned in splitId_handle.
574   ///
575   /// \param coordi_Handle (in) x, y, z coordinates of input points
576   /// \param pointId_Handle (out) returns indices to leaf nodes of the KD-tree
577   /// \param splitId_Handle (out) returns indices to internal nodes of the KD-tree
578   /// \param device the device to run the construction on
579   // Leaf Node vector and internal node (split) vectpr
580   template <typename CoordType, typename CoordStorageTag, typename DeviceAdapter>
Run(const vtkm::cont::ArrayHandle<vtkm::Vec<CoordType,3>,CoordStorageTag> & coordi_Handle,vtkm::cont::ArrayHandle<vtkm::Id> & pointId_Handle,vtkm::cont::ArrayHandle<vtkm::Id> & splitId_Handle,DeviceAdapter device)581   void Run(const vtkm::cont::ArrayHandle<vtkm::Vec<CoordType, 3>, CoordStorageTag>& coordi_Handle,
582            vtkm::cont::ArrayHandle<vtkm::Id>& pointId_Handle,
583            vtkm::cont::ArrayHandle<vtkm::Id>& splitId_Handle,
584            DeviceAdapter device)
585   {
586     using Algorithm = typename vtkm::cont::DeviceAdapterAlgorithm<DeviceAdapter>;
587 
588     vtkm::Id nTrainingPoints = coordi_Handle.GetNumberOfValues();
589     vtkm::cont::ArrayHandleCounting<vtkm::Id> counting_Handle(0, 1, nTrainingPoints);
590     Algorithm::Copy(counting_Handle, pointId_Handle);
591     vtkm::cont::ArrayHandle<vtkm::Id> xorder_Handle;
592     Algorithm::Copy(counting_Handle, xorder_Handle);
593     vtkm::cont::ArrayHandle<vtkm::Id> yorder_Handle;
594     Algorithm::Copy(counting_Handle, yorder_Handle);
595     vtkm::cont::ArrayHandle<vtkm::Id> zorder_Handle;
596     Algorithm::Copy(counting_Handle, zorder_Handle);
597 
598     splitId_Handle.Allocate(nTrainingPoints);
599 
600     vtkm::cont::ArrayHandle<CoordType> xcoordi_Handle;
601     vtkm::cont::ArrayHandle<CoordType> ycoordi_Handle;
602     vtkm::cont::ArrayHandle<CoordType> zcoordi_Handle;
603 
604     SeprateVec3AryHandle sepVec3Worklet;
605     vtkm::worklet::DispatcherMapField<SeprateVec3AryHandle> sepVec3Dispatcher(sepVec3Worklet);
606     sepVec3Dispatcher.SetDevice(DeviceAdapter());
607     sepVec3Dispatcher.Invoke(coordi_Handle, xcoordi_Handle, ycoordi_Handle, zcoordi_Handle);
608 
609     Algorithm::SortByKey(xcoordi_Handle, xorder_Handle);
610     vtkm::cont::ArrayHandle<vtkm::Id> xrank_Handle =
611       ScatterArrayWrapper(pointId_Handle, xorder_Handle, device);
612 
613     Algorithm::SortByKey(ycoordi_Handle, yorder_Handle);
614     vtkm::cont::ArrayHandle<vtkm::Id> yrank_Handle =
615       ScatterArrayWrapper(pointId_Handle, yorder_Handle, device);
616 
617     Algorithm::SortByKey(zcoordi_Handle, zorder_Handle);
618     vtkm::cont::ArrayHandle<vtkm::Id> zrank_Handle =
619       ScatterArrayWrapper(pointId_Handle, zorder_Handle, device);
620 
621     vtkm::cont::ArrayHandle<vtkm::Id> segId_Handle;
622     vtkm::cont::ArrayHandleConstant<vtkm::Id> constHandle(0, nTrainingPoints);
623     Algorithm::Copy(constHandle, segId_Handle);
624 
625     ///// build kd tree /////
626     vtkm::Int32 maxLevel = static_cast<vtkm::Int32>(ceil(vtkm::Log2(nTrainingPoints) + 1));
627     for (vtkm::Int32 i = 0; i < maxLevel - 1; i++)
628     {
629       OneLevelSplit3D(pointId_Handle,
630                       xrank_Handle,
631                       yrank_Handle,
632                       zrank_Handle,
633                       segId_Handle,
634                       splitId_Handle,
635                       i,
636                       device);
637     }
638   }
639 };
640 }
641 }
642 } // namespace vtkm::worklet
643 
644 #endif // vtk_m_worklet_KdTree3DConstruction_h
645