1 /* 2 * 3 * Copyright (C) 2019-2021 Intel Corporation 4 * 5 * SPDX-License-Identifier: MIT 6 * 7 */ 8 #include "ze_loader.h" 9 10 #include "driver_discovery.h" 11 12 namespace loader 13 { 14 /////////////////////////////////////////////////////////////////////////////// 15 context_t *context; 16 check_drivers(ze_init_flags_t flags)17 ze_result_t context_t::check_drivers(ze_init_flags_t flags) { 18 bool return_first_driver_result=false; 19 if(drivers.size()==1) { 20 return_first_driver_result=true; 21 } 22 23 for(auto it = drivers.begin(); it != drivers.end(); ) 24 { 25 ze_result_t result = init_driver(*it, flags); 26 if(result != ZE_RESULT_SUCCESS) { 27 FREE_DRIVER_LIBRARY(it->handle); 28 it = drivers.erase(it); 29 if(return_first_driver_result) 30 return result; 31 } 32 else { 33 it++; 34 } 35 } 36 37 if(drivers.size() == 0) 38 return ZE_RESULT_ERROR_UNINITIALIZED; 39 40 return ZE_RESULT_SUCCESS; 41 } 42 init_driver(driver_t driver,ze_init_flags_t flags)43 ze_result_t context_t::init_driver(driver_t driver, ze_init_flags_t flags) { 44 45 auto getTable = reinterpret_cast<ze_pfnGetGlobalProcAddrTable_t>( 46 GET_FUNCTION_PTR(driver.handle, "zeGetGlobalProcAddrTable")); 47 if(!getTable) { 48 return ZE_RESULT_ERROR_UNINITIALIZED; 49 } 50 51 ze_global_dditable_t global; 52 auto getTableResult = getTable(ZE_API_VERSION_CURRENT, &global); 53 if(getTableResult != ZE_RESULT_SUCCESS) { 54 return ZE_RESULT_ERROR_UNINITIALIZED; 55 } 56 57 if(nullptr == global.pfnInit) { 58 return ZE_RESULT_ERROR_UNINITIALIZED; 59 } 60 61 if(nullptr != validationLayer) { 62 auto getTable = reinterpret_cast<ze_pfnGetGlobalProcAddrTable_t>( 63 GET_FUNCTION_PTR(validationLayer, "zeGetGlobalProcAddrTable") ); 64 if(!getTable) 65 return ZE_RESULT_ERROR_UNINITIALIZED; 66 auto getTableResult = getTable( version, &global); 67 if(getTableResult != ZE_RESULT_SUCCESS) { 68 return ZE_RESULT_ERROR_UNINITIALIZED; 69 } 70 } 71 72 if(nullptr != tracingLayer) { 73 auto getTable = reinterpret_cast<ze_pfnGetGlobalProcAddrTable_t>( 74 GET_FUNCTION_PTR(tracingLayer, "zeGetGlobalProcAddrTable") ); 75 if(!getTable) 76 return ZE_RESULT_ERROR_UNINITIALIZED; 77 auto getTableResult = getTable( version, &global); 78 if(getTableResult != ZE_RESULT_SUCCESS) { 79 return ZE_RESULT_ERROR_UNINITIALIZED; 80 } 81 } 82 83 auto pfnInit = global.pfnInit; 84 if(nullptr == pfnInit) { 85 return ZE_RESULT_ERROR_UNINITIALIZED; 86 } 87 88 return pfnInit(flags); 89 } 90 91 /////////////////////////////////////////////////////////////////////////////// init()92 ze_result_t context_t::init() 93 { 94 auto discoveredDrivers = discoverEnabledDrivers(); 95 96 drivers.reserve( discoveredDrivers.size() + getenv_tobool( "ZE_ENABLE_NULL_DRIVER" ) ); 97 if( getenv_tobool( "ZE_ENABLE_NULL_DRIVER" ) ) 98 { 99 auto handle = LOAD_DRIVER_LIBRARY( MAKE_LIBRARY_NAME( "ze_null", L0_LOADER_VERSION ) ); 100 if( NULL != handle ) 101 { 102 drivers.emplace_back(); 103 drivers.rbegin()->handle = handle; 104 } 105 } 106 107 for( auto name : discoveredDrivers ) 108 { 109 auto handle = LOAD_DRIVER_LIBRARY( name.c_str() ); 110 if( NULL != handle ) 111 { 112 drivers.emplace_back(); 113 drivers.rbegin()->handle = handle; 114 } 115 } 116 117 if(drivers.size()==0) 118 return ZE_RESULT_ERROR_UNINITIALIZED; 119 120 add_loader_version(); 121 std::string loaderLibraryPath; 122 #ifdef _WIN32 123 loaderLibraryPath = readLevelZeroLoaderLibraryPath(); 124 #endif 125 typedef ze_result_t (ZE_APICALL *getVersion_t)(zel_component_version_t *version); 126 if( getenv_tobool( "ZE_ENABLE_VALIDATION_LAYER" ) ) 127 { 128 std::string validationLayerLibraryPath = create_library_path(MAKE_LAYER_NAME( "ze_validation_layer" ), loaderLibraryPath.c_str()); 129 validationLayer = LOAD_DRIVER_LIBRARY( validationLayerLibraryPath.c_str() ); 130 if(validationLayer) 131 { 132 auto getVersion = reinterpret_cast<getVersion_t>( 133 GET_FUNCTION_PTR(validationLayer, "zelLoaderGetVersion")); 134 zel_component_version_t version; 135 if(getVersion && ZE_RESULT_SUCCESS == getVersion(&version)) 136 { 137 compVersions.push_back(version); 138 } 139 } 140 } 141 142 if( getenv_tobool( "ZE_ENABLE_TRACING_LAYER" ) ) 143 { 144 std::string tracingLayerLibraryPath = create_library_path(MAKE_LAYER_NAME( "ze_tracing_layer" ), loaderLibraryPath.c_str()); 145 tracingLayer = LOAD_DRIVER_LIBRARY( tracingLayerLibraryPath.c_str() ); 146 if(tracingLayer) 147 { 148 auto getVersion = reinterpret_cast<getVersion_t>( 149 GET_FUNCTION_PTR(tracingLayer, "zelLoaderGetVersion")); 150 zel_component_version_t version; 151 if(getVersion && ZE_RESULT_SUCCESS == getVersion(&version)) 152 { 153 compVersions.push_back(version); 154 } 155 } 156 } 157 158 forceIntercept = getenv_tobool( "ZE_ENABLE_LOADER_INTERCEPT" ); 159 160 return ZE_RESULT_SUCCESS; 161 }; 162 163 /////////////////////////////////////////////////////////////////////////////// ~context_t()164 context_t::~context_t() 165 { 166 FREE_DRIVER_LIBRARY( validationLayer ); 167 FREE_DRIVER_LIBRARY( tracingLayer ); 168 169 for( auto& drv : drivers ) 170 { 171 FREE_DRIVER_LIBRARY( drv.handle ); 172 } 173 }; 174 add_loader_version()175 void context_t::add_loader_version(){ 176 zel_component_version_t version = {}; 177 strncpy(version.component_name, LOADER_COMP_NAME, ZEL_COMPONENT_STRING_SIZE); 178 version.spec_version = ZE_API_VERSION_CURRENT; 179 version.component_lib_version.major = LOADER_VERSION_MAJOR; 180 version.component_lib_version.minor = LOADER_VERSION_MINOR; 181 version.component_lib_version.patch = LOADER_VERSION_PATCH; 182 183 compVersions.push_back(version); 184 } 185 186 } 187