1 //============================================================================
2 //  Copyright (c) Kitware, Inc.
3 //  All rights reserved.
4 //  See LICENSE.txt for details.
5 //
6 //  This software is distributed WITHOUT ANY WARRANTY; without even
7 //  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
8 //  PURPOSE.  See the above copyright notice for more information.
9 //============================================================================
10 
11 #include <vtkm/cont/RuntimeDeviceTracker.h>
12 
13 //include all backends
14 #include <vtkm/cont/cuda/DeviceAdapterCuda.h>
15 #include <vtkm/cont/openmp/DeviceAdapterOpenMP.h>
16 #include <vtkm/cont/serial/DeviceAdapterSerial.h>
17 #include <vtkm/cont/tbb/DeviceAdapterTBB.h>
18 
19 #include <vtkm/cont/testing/Testing.h>
20 
21 #include <algorithm>
22 #include <array>
23 
24 namespace
25 {
26 
27 template <typename DeviceAdapterTag>
verify_state(DeviceAdapterTag tag,std::array<bool,VTKM_MAX_DEVICE_ADAPTER_ID> & defaults)28 void verify_state(DeviceAdapterTag tag, std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID>& defaults)
29 {
30   auto& tracker = vtkm::cont::GetRuntimeDeviceTracker();
31   // presumable all other devices match the defaults
32   for (vtkm::Int8 i = 1; i < VTKM_MAX_DEVICE_ADAPTER_ID; ++i)
33   {
34     const auto deviceId = vtkm::cont::make_DeviceAdapterId(i);
35     if (deviceId != tag)
36     {
37       VTKM_TEST_ASSERT(defaults[static_cast<std::size_t>(i)] == tracker.CanRunOn(deviceId),
38                        "ScopedRuntimeDeviceTracker didn't properly setup state correctly");
39     }
40   }
41 }
42 
43 template <typename DeviceAdapterTag>
verify_srdt_support(DeviceAdapterTag tag,std::array<bool,VTKM_MAX_DEVICE_ADAPTER_ID> & force,std::array<bool,VTKM_MAX_DEVICE_ADAPTER_ID> & enable,std::array<bool,VTKM_MAX_DEVICE_ADAPTER_ID> & disable)44 void verify_srdt_support(DeviceAdapterTag tag,
45                          std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID>& force,
46                          std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID>& enable,
47                          std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID>& disable)
48 {
49   vtkm::cont::RuntimeDeviceInformation runtime;
50   const bool haveSupport = runtime.Exists(tag);
51   if (haveSupport)
52   {
53     vtkm::cont::ScopedRuntimeDeviceTracker tracker(tag,
54                                                    vtkm::cont::RuntimeDeviceTrackerMode::Force);
55     VTKM_TEST_ASSERT(tracker.CanRunOn(tag) == haveSupport, "");
56     verify_state(tag, force);
57   }
58 
59   if (haveSupport)
60   {
61     vtkm::cont::ScopedRuntimeDeviceTracker tracker(tag,
62                                                    vtkm::cont::RuntimeDeviceTrackerMode::Enable);
63     VTKM_TEST_ASSERT(tracker.CanRunOn(tag) == haveSupport);
64     verify_state(tag, enable);
65   }
66 
67   {
68     vtkm::cont::ScopedRuntimeDeviceTracker tracker(tag,
69                                                    vtkm::cont::RuntimeDeviceTrackerMode::Disable);
70     VTKM_TEST_ASSERT(tracker.CanRunOn(tag) == false, "");
71     verify_state(tag, disable);
72   }
73 }
74 
VerifyScopedRuntimeDeviceTracker()75 void VerifyScopedRuntimeDeviceTracker()
76 {
77   std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID> all_off;
78   std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID> all_on;
79   std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID> defaults;
80 
81   all_off.fill(false);
82   vtkm::cont::RuntimeDeviceInformation runtime;
83   auto& tracker = vtkm::cont::GetRuntimeDeviceTracker();
84   for (vtkm::Int8 i = 1; i < VTKM_MAX_DEVICE_ADAPTER_ID; ++i)
85   {
86     auto deviceId = vtkm::cont::make_DeviceAdapterId(i);
87     defaults[static_cast<std::size_t>(i)] = tracker.CanRunOn(deviceId);
88     all_on[static_cast<std::size_t>(i)] = runtime.Exists(deviceId);
89   }
90 
91   using SerialTag = ::vtkm::cont::DeviceAdapterTagSerial;
92   using OpenMPTag = ::vtkm::cont::DeviceAdapterTagOpenMP;
93   using TBBTag = ::vtkm::cont::DeviceAdapterTagTBB;
94   using CudaTag = ::vtkm::cont::DeviceAdapterTagCuda;
95   using KokkosTag = ::vtkm::cont::DeviceAdapterTagKokkos;
96   using AnyTag = ::vtkm::cont::DeviceAdapterTagAny;
97 
98   //Verify that for each device adapter we compile code for, that it
99   //has valid runtime support.
100   verify_srdt_support(SerialTag(), all_off, all_on, defaults);
101   verify_srdt_support(OpenMPTag(), all_off, all_on, defaults);
102   verify_srdt_support(CudaTag(), all_off, all_on, defaults);
103   verify_srdt_support(TBBTag(), all_off, all_on, defaults);
104   verify_srdt_support(KokkosTag(), all_off, all_on, defaults);
105 
106   // Verify that all the ScopedRuntimeDeviceTracker changes
107   // have been reverted
108   verify_state(AnyTag(), defaults);
109 
110 
111   verify_srdt_support(AnyTag(), all_on, all_on, all_off);
112 
113   // Verify that all the ScopedRuntimeDeviceTracker changes
114   // have been reverted
115   verify_state(AnyTag(), defaults);
116 }
117 
118 } // anonymous namespace
119 
UnitTestScopedRuntimeDeviceTracker(int argc,char * argv[])120 int UnitTestScopedRuntimeDeviceTracker(int argc, char* argv[])
121 {
122   return vtkm::cont::testing::Testing::Run(VerifyScopedRuntimeDeviceTracker, argc, argv);
123 }
124