1 #include "chainerx/testing/util.h" 2 3 #include <atomic> 4 #include <string> 5 6 #include <absl/types/optional.h> 7 #include <gtest/gtest.h> 8 9 #include "chainerx/backend.h" 10 #include "chainerx/context.h" 11 #include "chainerx/error.h" 12 #include "chainerx/macro.h" 13 #include "chainerx/util.h" 14 15 namespace chainerx { 16 namespace testing { 17 namespace { 18 19 std::atomic<int> g_skipped_native_test_count{0}; 20 std::atomic<int> g_skipped_cuda_test_count{0}; 21 GetNativeDeviceLimit(Backend & backend)22int GetNativeDeviceLimit(Backend& backend) { 23 CHAINERX_ASSERT(backend.GetName() == "native"); 24 static int limit = -1; 25 if (limit >= 0) { 26 return limit; 27 } 28 if (absl::optional<std::string> env = GetEnv("CHAINERX_TEST_NATIVE_DEVICE_LIMIT")) { 29 try { 30 limit = std::stoi(*env); 31 } catch (const std::exception&) { 32 limit = -1; 33 } 34 if (limit < 0) { 35 throw ChainerxError{"CHAINERX_TEST_NATIVE_DEVICE_LIMIT must be non-negative integer: ", *env}; 36 } 37 } else { 38 limit = backend.GetDeviceCount(); 39 } 40 return limit; 41 } 42 GetCudaDeviceLimit(Backend & backend)43int GetCudaDeviceLimit(Backend& backend) { 44 CHAINERX_ASSERT(backend.GetName() == "cuda"); 45 static int limit = -1; 46 if (limit >= 0) { 47 return limit; 48 } 49 if (absl::optional<std::string> env = GetEnv("CHAINERX_TEST_CUDA_DEVICE_LIMIT")) { 50 try { 51 limit = std::stoi(*env); 52 } catch (const std::exception&) { 53 limit = -1; 54 } 55 if (limit < 0) { 56 throw ChainerxError{"CHAINERX_TEST_CUDA_DEVICE_LIMIT must be non-negative integer: ", *env}; 57 } 58 } else { 59 limit = backend.GetDeviceCount(); 60 } 61 return limit; 62 } 63 64 } // namespace 65 66 namespace testing_internal { 67 GetSkippedNativeTestCount()68int GetSkippedNativeTestCount() { return g_skipped_native_test_count; } 69 GetSkippedCudaTestCount()70int GetSkippedCudaTestCount() { return g_skipped_cuda_test_count; } 71 GetDeviceLimit(Backend & backend)72int GetDeviceLimit(Backend& backend) { 73 if (backend.GetName() == "native") { 74 return GetNativeDeviceLimit(backend); 75 } 76 if (backend.GetName() == "cuda") { 77 return GetCudaDeviceLimit(backend); 78 } 79 throw BackendError{"invalid backend: ", backend.GetName()}; 80 } 81 SkipIfDeviceUnavailable(Backend & backend,int required_num)82bool SkipIfDeviceUnavailable(Backend& backend, int required_num) { 83 if (GetDeviceLimit(backend) >= required_num) { 84 return false; 85 } 86 const ::testing::TestInfo* const test_info = ::testing::UnitTest::GetInstance()->current_test_info(); 87 std::cout << "[ SKIP ] " << test_info->test_case_name() << "." << test_info->name() << std::endl; 88 89 if (backend.GetName() == "native") { 90 ++g_skipped_native_test_count; 91 } else if (backend.GetName() == "cuda") { 92 ++g_skipped_cuda_test_count; 93 } else { 94 throw BackendError{"invalid backend: ", backend.GetName()}; 95 } 96 return true; 97 } 98 SkipIfDeviceUnavailable(const std::string & backend_name,int required_num)99bool SkipIfDeviceUnavailable(const std::string& backend_name, int required_num) { 100 return SkipIfDeviceUnavailable(Context{}.GetBackend(backend_name), required_num); 101 } 102 103 } // namespace testing_internal 104 } // namespace testing 105 } // namespace chainerx 106