1 //============================================================================
2 // Copyright (c) Kitware, Inc.
3 // All rights reserved.
4 // See LICENSE.txt for details.
5 //
6 // This software is distributed WITHOUT ANY WARRANTY; without even
7 // the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
8 // PURPOSE. See the above copyright notice for more information.
9 //============================================================================
10
#include <vtkm/cont/RuntimeDeviceTracker.h>

// Include all backends.
#include <vtkm/cont/cuda/DeviceAdapterCuda.h>
#include <vtkm/cont/kokkos/DeviceAdapterKokkos.h>
#include <vtkm/cont/openmp/DeviceAdapterOpenMP.h>
#include <vtkm/cont/serial/DeviceAdapterSerial.h>
#include <vtkm/cont/tbb/DeviceAdapterTBB.h>

#include <vtkm/cont/testing/Testing.h>

#include <algorithm>
#include <array>
#include <cstddef>
23
24 namespace
25 {
26
27 template <typename DeviceAdapterTag>
verify_state(DeviceAdapterTag tag,std::array<bool,VTKM_MAX_DEVICE_ADAPTER_ID> & defaults)28 void verify_state(DeviceAdapterTag tag, std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID>& defaults)
29 {
30 auto& tracker = vtkm::cont::GetRuntimeDeviceTracker();
31 // presumable all other devices match the defaults
32 for (vtkm::Int8 i = 1; i < VTKM_MAX_DEVICE_ADAPTER_ID; ++i)
33 {
34 const auto deviceId = vtkm::cont::make_DeviceAdapterId(i);
35 if (deviceId != tag)
36 {
37 VTKM_TEST_ASSERT(defaults[static_cast<std::size_t>(i)] == tracker.CanRunOn(deviceId),
38 "ScopedRuntimeDeviceTracker didn't properly setup state correctly");
39 }
40 }
41 }
42
43 template <typename DeviceAdapterTag>
verify_srdt_support(DeviceAdapterTag tag,std::array<bool,VTKM_MAX_DEVICE_ADAPTER_ID> & force,std::array<bool,VTKM_MAX_DEVICE_ADAPTER_ID> & enable,std::array<bool,VTKM_MAX_DEVICE_ADAPTER_ID> & disable)44 void verify_srdt_support(DeviceAdapterTag tag,
45 std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID>& force,
46 std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID>& enable,
47 std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID>& disable)
48 {
49 vtkm::cont::RuntimeDeviceInformation runtime;
50 const bool haveSupport = runtime.Exists(tag);
51 if (haveSupport)
52 {
53 vtkm::cont::ScopedRuntimeDeviceTracker tracker(tag,
54 vtkm::cont::RuntimeDeviceTrackerMode::Force);
55 VTKM_TEST_ASSERT(tracker.CanRunOn(tag) == haveSupport, "");
56 verify_state(tag, force);
57 }
58
59 if (haveSupport)
60 {
61 vtkm::cont::ScopedRuntimeDeviceTracker tracker(tag,
62 vtkm::cont::RuntimeDeviceTrackerMode::Enable);
63 VTKM_TEST_ASSERT(tracker.CanRunOn(tag) == haveSupport);
64 verify_state(tag, enable);
65 }
66
67 {
68 vtkm::cont::ScopedRuntimeDeviceTracker tracker(tag,
69 vtkm::cont::RuntimeDeviceTrackerMode::Disable);
70 VTKM_TEST_ASSERT(tracker.CanRunOn(tag) == false, "");
71 verify_state(tag, disable);
72 }
73 }
74
VerifyScopedRuntimeDeviceTracker()75 void VerifyScopedRuntimeDeviceTracker()
76 {
77 std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID> all_off;
78 std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID> all_on;
79 std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID> defaults;
80
81 all_off.fill(false);
82 vtkm::cont::RuntimeDeviceInformation runtime;
83 auto& tracker = vtkm::cont::GetRuntimeDeviceTracker();
84 for (vtkm::Int8 i = 1; i < VTKM_MAX_DEVICE_ADAPTER_ID; ++i)
85 {
86 auto deviceId = vtkm::cont::make_DeviceAdapterId(i);
87 defaults[static_cast<std::size_t>(i)] = tracker.CanRunOn(deviceId);
88 all_on[static_cast<std::size_t>(i)] = runtime.Exists(deviceId);
89 }
90
91 using SerialTag = ::vtkm::cont::DeviceAdapterTagSerial;
92 using OpenMPTag = ::vtkm::cont::DeviceAdapterTagOpenMP;
93 using TBBTag = ::vtkm::cont::DeviceAdapterTagTBB;
94 using CudaTag = ::vtkm::cont::DeviceAdapterTagCuda;
95 using KokkosTag = ::vtkm::cont::DeviceAdapterTagKokkos;
96 using AnyTag = ::vtkm::cont::DeviceAdapterTagAny;
97
98 //Verify that for each device adapter we compile code for, that it
99 //has valid runtime support.
100 verify_srdt_support(SerialTag(), all_off, all_on, defaults);
101 verify_srdt_support(OpenMPTag(), all_off, all_on, defaults);
102 verify_srdt_support(CudaTag(), all_off, all_on, defaults);
103 verify_srdt_support(TBBTag(), all_off, all_on, defaults);
104 verify_srdt_support(KokkosTag(), all_off, all_on, defaults);
105
106 // Verify that all the ScopedRuntimeDeviceTracker changes
107 // have been reverted
108 verify_state(AnyTag(), defaults);
109
110
111 verify_srdt_support(AnyTag(), all_on, all_on, all_off);
112
113 // Verify that all the ScopedRuntimeDeviceTracker changes
114 // have been reverted
115 verify_state(AnyTag(), defaults);
116 }
117
118 } // anonymous namespace
119
UnitTestScopedRuntimeDeviceTracker(int argc,char * argv[])120 int UnitTestScopedRuntimeDeviceTracker(int argc, char* argv[])
121 {
122 return vtkm::cont::testing::Testing::Run(VerifyScopedRuntimeDeviceTracker, argc, argv);
123 }
124