1 /*
2  *
3  * Copyright (C) 2019-2021 Intel Corporation
4  *
5  * SPDX-License-Identifier: MIT
6  *
7  */
8 #include "ze_loader.h"
9 
10 #include "driver_discovery.h"
11 
12 namespace loader
13 {
14     ///////////////////////////////////////////////////////////////////////////////
    /// Loader context pointer shared through the `loader` namespace.
    /// It is defined here without an initializer; as a namespace-scope object it
    /// is zero-initialized, and nothing in this translation unit assigns it.
    context_t *context;
16 
check_drivers(ze_init_flags_t flags)17     ze_result_t context_t::check_drivers(ze_init_flags_t flags) {
18         bool return_first_driver_result=false;
19         if(drivers.size()==1) {
20             return_first_driver_result=true;
21         }
22 
23         for(auto it = drivers.begin(); it != drivers.end(); )
24         {
25             ze_result_t result = init_driver(*it, flags);
26             if(result != ZE_RESULT_SUCCESS) {
27                 FREE_DRIVER_LIBRARY(it->handle);
28                 it = drivers.erase(it);
29                 if(return_first_driver_result)
30                     return result;
31             }
32             else {
33                 it++;
34             }
35         }
36 
37         if(drivers.size() == 0)
38             return ZE_RESULT_ERROR_UNINITIALIZED;
39 
40         return ZE_RESULT_SUCCESS;
41     }
42 
init_driver(driver_t driver,ze_init_flags_t flags)43     ze_result_t context_t::init_driver(driver_t driver, ze_init_flags_t flags) {
44 
45         auto getTable = reinterpret_cast<ze_pfnGetGlobalProcAddrTable_t>(
46             GET_FUNCTION_PTR(driver.handle, "zeGetGlobalProcAddrTable"));
47         if(!getTable) {
48             return ZE_RESULT_ERROR_UNINITIALIZED;
49         }
50 
51         ze_global_dditable_t global;
52         auto getTableResult = getTable(ZE_API_VERSION_CURRENT, &global);
53         if(getTableResult != ZE_RESULT_SUCCESS) {
54             return ZE_RESULT_ERROR_UNINITIALIZED;
55         }
56 
57         if(nullptr == global.pfnInit) {
58             return ZE_RESULT_ERROR_UNINITIALIZED;
59         }
60 
61         if(nullptr != validationLayer) {
62             auto getTable = reinterpret_cast<ze_pfnGetGlobalProcAddrTable_t>(
63                 GET_FUNCTION_PTR(validationLayer, "zeGetGlobalProcAddrTable") );
64             if(!getTable)
65                 return ZE_RESULT_ERROR_UNINITIALIZED;
66             auto getTableResult = getTable( version, &global);
67             if(getTableResult != ZE_RESULT_SUCCESS) {
68                 return ZE_RESULT_ERROR_UNINITIALIZED;
69             }
70         }
71 
72         if(nullptr != tracingLayer) {
73             auto getTable = reinterpret_cast<ze_pfnGetGlobalProcAddrTable_t>(
74                 GET_FUNCTION_PTR(tracingLayer, "zeGetGlobalProcAddrTable") );
75             if(!getTable)
76                 return ZE_RESULT_ERROR_UNINITIALIZED;
77             auto getTableResult = getTable( version, &global);
78             if(getTableResult != ZE_RESULT_SUCCESS) {
79                 return ZE_RESULT_ERROR_UNINITIALIZED;
80             }
81         }
82 
83         auto pfnInit = global.pfnInit;
84         if(nullptr == pfnInit) {
85             return ZE_RESULT_ERROR_UNINITIALIZED;
86         }
87 
88         return pfnInit(flags);
89     }
90 
91     ///////////////////////////////////////////////////////////////////////////////
init()92     ze_result_t context_t::init()
93     {
94         auto discoveredDrivers = discoverEnabledDrivers();
95 
96         drivers.reserve( discoveredDrivers.size() + getenv_tobool( "ZE_ENABLE_NULL_DRIVER" ) );
97         if( getenv_tobool( "ZE_ENABLE_NULL_DRIVER" ) )
98         {
99             auto handle = LOAD_DRIVER_LIBRARY( MAKE_LIBRARY_NAME( "ze_null", L0_LOADER_VERSION ) );
100             if( NULL != handle )
101             {
102                 drivers.emplace_back();
103                 drivers.rbegin()->handle = handle;
104             }
105         }
106 
107         for( auto name : discoveredDrivers )
108         {
109             auto handle = LOAD_DRIVER_LIBRARY( name.c_str() );
110             if( NULL != handle )
111             {
112                 drivers.emplace_back();
113                 drivers.rbegin()->handle = handle;
114             }
115         }
116 
117         if(drivers.size()==0)
118             return ZE_RESULT_ERROR_UNINITIALIZED;
119 
120         add_loader_version();
121         std::string loaderLibraryPath;
122 #ifdef _WIN32
123         loaderLibraryPath = readLevelZeroLoaderLibraryPath();
124 #endif
125         typedef ze_result_t (ZE_APICALL *getVersion_t)(zel_component_version_t *version);
126         if( getenv_tobool( "ZE_ENABLE_VALIDATION_LAYER" ) )
127         {
128             std::string validationLayerLibraryPath = create_library_path(MAKE_LAYER_NAME( "ze_validation_layer" ), loaderLibraryPath.c_str());
129             validationLayer = LOAD_DRIVER_LIBRARY( validationLayerLibraryPath.c_str() );
130             if(validationLayer)
131             {
132                 auto getVersion = reinterpret_cast<getVersion_t>(
133                     GET_FUNCTION_PTR(validationLayer, "zelLoaderGetVersion"));
134                 zel_component_version_t version;
135                 if(getVersion && ZE_RESULT_SUCCESS == getVersion(&version))
136                 {
137                     compVersions.push_back(version);
138                 }
139             }
140         }
141 
142         if( getenv_tobool( "ZE_ENABLE_TRACING_LAYER" ) )
143         {
144             std::string tracingLayerLibraryPath = create_library_path(MAKE_LAYER_NAME( "ze_tracing_layer" ), loaderLibraryPath.c_str());
145             tracingLayer = LOAD_DRIVER_LIBRARY( tracingLayerLibraryPath.c_str() );
146             if(tracingLayer)
147             {
148                 auto getVersion = reinterpret_cast<getVersion_t>(
149                     GET_FUNCTION_PTR(tracingLayer, "zelLoaderGetVersion"));
150                 zel_component_version_t version;
151                 if(getVersion && ZE_RESULT_SUCCESS == getVersion(&version))
152                 {
153                     compVersions.push_back(version);
154                 }
155             }
156         }
157 
158         forceIntercept = getenv_tobool( "ZE_ENABLE_LOADER_INTERCEPT" );
159 
160         return ZE_RESULT_SUCCESS;
161     };
162 
163     ///////////////////////////////////////////////////////////////////////////////
~context_t()164     context_t::~context_t()
165     {
166         FREE_DRIVER_LIBRARY( validationLayer );
167         FREE_DRIVER_LIBRARY( tracingLayer );
168 
169         for( auto& drv : drivers )
170         {
171             FREE_DRIVER_LIBRARY( drv.handle );
172         }
173     };
174 
add_loader_version()175     void context_t::add_loader_version(){
176         zel_component_version_t version = {};
177         strncpy(version.component_name, LOADER_COMP_NAME, ZEL_COMPONENT_STRING_SIZE);
178         version.spec_version = ZE_API_VERSION_CURRENT;
179         version.component_lib_version.major = LOADER_VERSION_MAJOR;
180         version.component_lib_version.minor = LOADER_VERSION_MINOR;
181         version.component_lib_version.patch = LOADER_VERSION_PATCH;
182 
183         compVersions.push_back(version);
184     }
185 
186 }
187