1 #include "chainerx/testing/util.h"
2 
3 #include <atomic>
4 #include <string>
5 
6 #include <absl/types/optional.h>
7 #include <gtest/gtest.h>
8 
9 #include "chainerx/backend.h"
10 #include "chainerx/context.h"
11 #include "chainerx/error.h"
12 #include "chainerx/macro.h"
13 #include "chainerx/util.h"
14 
15 namespace chainerx {
16 namespace testing {
17 namespace {
18 
// Number of tests skipped so far for the "native" backend; incremented by SkipIfDeviceUnavailable.
std::atomic<int> g_skipped_native_test_count(0);
// Number of tests skipped so far for the "cuda" backend; incremented by SkipIfDeviceUnavailable.
std::atomic<int> g_skipped_cuda_test_count(0);
21 
GetNativeDeviceLimit(Backend & backend)22 int GetNativeDeviceLimit(Backend& backend) {
23     CHAINERX_ASSERT(backend.GetName() == "native");
24     static int limit = -1;
25     if (limit >= 0) {
26         return limit;
27     }
28     if (absl::optional<std::string> env = GetEnv("CHAINERX_TEST_NATIVE_DEVICE_LIMIT")) {
29         try {
30             limit = std::stoi(*env);
31         } catch (const std::exception&) {
32             limit = -1;
33         }
34         if (limit < 0) {
35             throw ChainerxError{"CHAINERX_TEST_NATIVE_DEVICE_LIMIT must be non-negative integer: ", *env};
36         }
37     } else {
38         limit = backend.GetDeviceCount();
39     }
40     return limit;
41 }
42 
GetCudaDeviceLimit(Backend & backend)43 int GetCudaDeviceLimit(Backend& backend) {
44     CHAINERX_ASSERT(backend.GetName() == "cuda");
45     static int limit = -1;
46     if (limit >= 0) {
47         return limit;
48     }
49     if (absl::optional<std::string> env = GetEnv("CHAINERX_TEST_CUDA_DEVICE_LIMIT")) {
50         try {
51             limit = std::stoi(*env);
52         } catch (const std::exception&) {
53             limit = -1;
54         }
55         if (limit < 0) {
56             throw ChainerxError{"CHAINERX_TEST_CUDA_DEVICE_LIMIT must be non-negative integer: ", *env};
57         }
58     } else {
59         limit = backend.GetDeviceCount();
60     }
61     return limit;
62 }
63 
64 }  // namespace
65 
66 namespace testing_internal {
67 
GetSkippedNativeTestCount()68 int GetSkippedNativeTestCount() { return g_skipped_native_test_count; }
69 
GetSkippedCudaTestCount()70 int GetSkippedCudaTestCount() { return g_skipped_cuda_test_count; }
71 
GetDeviceLimit(Backend & backend)72 int GetDeviceLimit(Backend& backend) {
73     if (backend.GetName() == "native") {
74         return GetNativeDeviceLimit(backend);
75     }
76     if (backend.GetName() == "cuda") {
77         return GetCudaDeviceLimit(backend);
78     }
79     throw BackendError{"invalid backend: ", backend.GetName()};
80 }
81 
SkipIfDeviceUnavailable(Backend & backend,int required_num)82 bool SkipIfDeviceUnavailable(Backend& backend, int required_num) {
83     if (GetDeviceLimit(backend) >= required_num) {
84         return false;
85     }
86     const ::testing::TestInfo* const test_info = ::testing::UnitTest::GetInstance()->current_test_info();
87     std::cout << "[     SKIP ] " << test_info->test_case_name() << "." << test_info->name() << std::endl;
88 
89     if (backend.GetName() == "native") {
90         ++g_skipped_native_test_count;
91     } else if (backend.GetName() == "cuda") {
92         ++g_skipped_cuda_test_count;
93     } else {
94         throw BackendError{"invalid backend: ", backend.GetName()};
95     }
96     return true;
97 }
98 
SkipIfDeviceUnavailable(const std::string & backend_name,int required_num)99 bool SkipIfDeviceUnavailable(const std::string& backend_name, int required_num) {
100     return SkipIfDeviceUnavailable(Context{}.GetBackend(backend_name), required_num);
101 }
102 
103 }  // namespace testing_internal
104 }  // namespace testing
105 }  // namespace chainerx
106