1 /*
2  * Copyright (C) 2020-2021 Intel Corporation
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  */
7 
8 #include <level_zero/ze_api.h>
9 
10 #include <cstring>
11 #include <fstream>
12 #include <iostream>
13 #include <limits>
14 #include <memory>
15 #include <string>
16 #include <vector>
17 
18 extern bool verbose;
19 
20 template <bool TerminateOnFailure, typename ResulT>
validate(ResulT result,const char * message)21 inline void validate(ResulT result, const char *message) {
22     if (result == ZE_RESULT_SUCCESS) {
23         if (verbose) {
24             std::cerr << "SUCCESS : " << message << std::endl;
25         }
26         return;
27     }
28 
29     if (verbose) {
30         std::cerr << (TerminateOnFailure ? "ERROR : " : "WARNING : ") << message << " : " << result
31                   << std::endl;
32     }
33 
34     if (TerminateOnFailure) {
35         std::terminate();
36     }
37 }
38 
// Convenience wrappers around validate<>(). Each one passes the stringified
// argument (#CALL / #FLAG) as the message, so the exact source text of the
// checked expression is what appears in verbose output.
//
// *_TERMINATE variants instantiate validate<true> (std::terminate on failure);
// *_WARNING variants instantiate validate<false> (report only).
//
// *_BOOL variants take a condition that is expected to be true: the flag is
// negated before being compared with ZE_RESULT_SUCCESS inside validate().
// NOTE(review): this relies on ZE_RESULT_SUCCESS having the value 0, as
// defined by the Level Zero headers.
#define SUCCESS_OR_TERMINATE(CALL) validate<true>(CALL, #CALL)
#define SUCCESS_OR_TERMINATE_BOOL(FLAG) validate<true>(!(FLAG), #FLAG)
#define SUCCESS_OR_WARNING(CALL) validate<false>(CALL, #CALL)
#define SUCCESS_OR_WARNING_BOOL(FLAG) validate<false>(!(FLAG), #FLAG)
43 
// Checks whether a boolean command-line switch is present.
//
// argc:      number of entries in argv.
// argv:      argument vector; argv[0] is never inspected.
// shortName: short spelling of the switch, e.g. "-v".
// longName:  long spelling of the switch, e.g. "--verbose".
//
// Returns true when any argument after argv[0] is exactly equal to shortName
// or longName, false otherwise.
inline bool isParamEnabled(int argc, char *argv[], const char *shortName, const char *longName) {
    // Index-based scan: when argc is 0 or 1 the loop body never runs. A pointer
    // walk from &argv[1] up to &argv[argc] would start beyond its end marker
    // for argc == 0 and read out of bounds.
    for (int i = 1; i < argc; ++i) {
        const char *current = argv[i];
        if (current == nullptr) {
            continue; // nothing to compare; strcmp on a null pointer is undefined
        }
        if ((0 == strcmp(current, shortName)) || (0 == strcmp(current, longName))) {
            return true;
        }
    }

    return false;
}
56 
// Reads the integer value that follows a command-line option.
//
// argc:         number of entries in argv.
// argv:         argument vector; argv[0] is never inspected.
// shortName:    short spelling of the option, e.g. "-n".
// longName:     long spelling of the option, e.g. "--count".
// defaultValue: returned when the option is absent or has no value after it.
//
// Returns atoi() of the argument that follows the first occurrence of the
// option (so a non-numeric value yields 0), or defaultValue when the option
// does not appear or is the last argument.
inline int getParamValue(int argc, char *argv[], const char *shortName, const char *longName, int defaultValue) {
    for (int i = 1; i < argc; ++i) {
        const char *current = argv[i];
        if (current == nullptr) {
            continue;
        }
        if ((0 != strcmp(current, shortName)) && (0 != strcmp(current, longName))) {
            continue;
        }

        // The option was found. If it is the last argument there is no value
        // to parse: argv[argc] is a null pointer and must not reach atoi().
        const int valueIndex = i + 1;
        if (valueIndex >= argc || argv[valueIndex] == nullptr) {
            return defaultValue;
        }
        return atoi(argv[valueIndex]);
    }

    return defaultValue;
}
70 
isVerbose(int argc,char * argv[])71 inline bool isVerbose(int argc, char *argv[]) {
72     bool enabled = isParamEnabled(argc, argv, "-v", "--verbose");
73     if (enabled == false) {
74         return false;
75     }
76 
77     std::cerr << "Verbose mode detected" << std::endl;
78 
79     return true;
80 }
81 
isSyncQueueEnabled(int argc,char * argv[])82 inline bool isSyncQueueEnabled(int argc, char *argv[]) {
83     bool enabled = isParamEnabled(argc, argv, "-s", "--sync");
84     if (enabled == false) {
85         std::cerr << "Async Queue detected" << std::endl;
86         return false;
87     }
88 
89     std::cerr << "Sync Queue detected" << std::endl;
90 
91     return true;
92 }
93 
getCommandQueueOrdinal(ze_device_handle_t & device)94 uint32_t getCommandQueueOrdinal(ze_device_handle_t &device) {
95     uint32_t numQueueGroups = 0;
96     SUCCESS_OR_TERMINATE(zeDeviceGetCommandQueueGroupProperties(device, &numQueueGroups, nullptr));
97     if (numQueueGroups == 0) {
98         std::cout << "No queue groups found!\n";
99         std::terminate();
100     }
101     std::vector<ze_command_queue_group_properties_t> queueProperties(numQueueGroups);
102     SUCCESS_OR_TERMINATE(zeDeviceGetCommandQueueGroupProperties(device, &numQueueGroups,
103                                                                 queueProperties.data()));
104     uint32_t computeQueueGroupOrdinal = numQueueGroups;
105     for (uint32_t i = 0; i < numQueueGroups; i++) {
106         if (queueProperties[i].flags & ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COMPUTE) {
107             computeQueueGroupOrdinal = i;
108             break;
109         }
110     }
111     return computeQueueGroupOrdinal;
112 }
113 
getCopyOnlyCommandQueueOrdinal(ze_device_handle_t & device)114 uint32_t getCopyOnlyCommandQueueOrdinal(ze_device_handle_t &device) {
115     uint32_t numQueueGroups = 0;
116     SUCCESS_OR_TERMINATE(zeDeviceGetCommandQueueGroupProperties(device, &numQueueGroups, nullptr));
117     if (numQueueGroups == 0) {
118         std::cout << "No queue groups found!\n";
119         std::terminate();
120     }
121     std::vector<ze_command_queue_group_properties_t> queueProperties(numQueueGroups);
122     SUCCESS_OR_TERMINATE(zeDeviceGetCommandQueueGroupProperties(device, &numQueueGroups,
123                                                                 queueProperties.data()));
124     uint32_t copyOnlyQueueGroupOrdinal = 0;
125     for (uint32_t i = 0; i < numQueueGroups; i++) {
126         if (!(queueProperties[i].flags & ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COMPUTE) && (queueProperties[i].flags & ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COPY)) {
127             copyOnlyQueueGroupOrdinal = i;
128             break;
129         }
130     }
131     return copyOnlyQueueGroupOrdinal;
132 }
133 
createCommandQueue(ze_context_handle_t & context,ze_device_handle_t & device,uint32_t * ordinal)134 ze_command_queue_handle_t createCommandQueue(ze_context_handle_t &context, ze_device_handle_t &device, uint32_t *ordinal) {
135     ze_command_queue_handle_t cmdQueue;
136     ze_command_queue_desc_t descriptor = {};
137     descriptor.stype = ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC;
138 
139     descriptor.pNext = nullptr;
140     descriptor.flags = 0;
141     descriptor.mode = ZE_COMMAND_QUEUE_MODE_DEFAULT;
142     descriptor.priority = ZE_COMMAND_QUEUE_PRIORITY_NORMAL;
143 
144     descriptor.ordinal = getCommandQueueOrdinal(device);
145     descriptor.index = 0;
146     SUCCESS_OR_TERMINATE(zeCommandQueueCreate(context, device, &descriptor, &cmdQueue));
147     if (ordinal != nullptr) {
148         *ordinal = descriptor.ordinal;
149     }
150     return cmdQueue;
151 }
152 
createCommandList(ze_context_handle_t & context,ze_device_handle_t & device,ze_command_list_handle_t & cmdList)153 ze_result_t createCommandList(ze_context_handle_t &context, ze_device_handle_t &device, ze_command_list_handle_t &cmdList) {
154     ze_command_list_desc_t descriptor = {};
155     descriptor.stype = ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC;
156 
157     descriptor.pNext = nullptr;
158     descriptor.flags = 0;
159     descriptor.commandQueueGroupOrdinal = getCommandQueueOrdinal(device);
160 
161     return zeCommandListCreate(context, device, &descriptor, &cmdList);
162 }
163 
createEventPoolAndEvents(ze_context_handle_t & context,ze_device_handle_t & device,ze_event_pool_handle_t & eventPool,ze_event_pool_flag_t poolFlag,uint32_t poolSize,ze_event_handle_t * events,ze_event_scope_flag_t signalScope,ze_event_scope_flag_t waitScope)164 void createEventPoolAndEvents(ze_context_handle_t &context,
165                               ze_device_handle_t &device,
166                               ze_event_pool_handle_t &eventPool,
167                               ze_event_pool_flag_t poolFlag,
168                               uint32_t poolSize,
169                               ze_event_handle_t *events,
170                               ze_event_scope_flag_t signalScope,
171                               ze_event_scope_flag_t waitScope) {
172     ze_event_pool_desc_t eventPoolDesc{ZE_STRUCTURE_TYPE_EVENT_POOL_DESC};
173     ze_event_desc_t eventDesc = {ZE_STRUCTURE_TYPE_EVENT_DESC};
174     eventPoolDesc.count = poolSize;
175     eventPoolDesc.flags = poolFlag;
176     SUCCESS_OR_TERMINATE(zeEventPoolCreate(context, &eventPoolDesc, 1, &device, &eventPool));
177 
178     for (uint32_t i = 0; i < poolSize; i++) {
179         eventDesc.index = i;
180         eventDesc.signal = signalScope;
181         eventDesc.wait = waitScope;
182         SUCCESS_OR_TERMINATE(zeEventCreate(eventPool, &eventDesc, events + i));
183     }
184 }
185 
zelloInitContextAndGetDevices(ze_context_handle_t & context,ze_driver_handle_t & driverHandle)186 std::vector<ze_device_handle_t> zelloInitContextAndGetDevices(ze_context_handle_t &context, ze_driver_handle_t &driverHandle) {
187     SUCCESS_OR_TERMINATE(zeInit(ZE_INIT_FLAG_GPU_ONLY));
188 
189     uint32_t driverCount = 0;
190     SUCCESS_OR_TERMINATE(zeDriverGet(&driverCount, nullptr));
191     if (driverCount == 0) {
192         std::cout << "No driver handle found!\n";
193         std::terminate();
194     }
195 
196     SUCCESS_OR_TERMINATE(zeDriverGet(&driverCount, &driverHandle));
197     ze_context_desc_t context_desc = {};
198     context_desc.stype = ZE_STRUCTURE_TYPE_CONTEXT_DESC;
199     SUCCESS_OR_TERMINATE(zeContextCreate(driverHandle, &context_desc, &context));
200 
201     uint32_t deviceCount = 0;
202     SUCCESS_OR_TERMINATE(zeDeviceGet(driverHandle, &deviceCount, nullptr));
203     if (deviceCount == 0) {
204         std::cout << "No device found!\n";
205         std::terminate();
206     }
207     std::vector<ze_device_handle_t> devices(deviceCount, nullptr);
208     SUCCESS_OR_TERMINATE(zeDeviceGet(driverHandle, &deviceCount, devices.data()));
209     return devices;
210 }
211 
zelloInitContextAndGetDevices(ze_context_handle_t & context)212 std::vector<ze_device_handle_t> zelloInitContextAndGetDevices(ze_context_handle_t &context) {
213     ze_driver_handle_t driverHandle;
214     return zelloInitContextAndGetDevices(context, driverHandle);
215 }
216 
initialize(ze_driver_handle_t & driver,ze_context_handle_t & context,ze_device_handle_t & device,ze_command_queue_handle_t & cmdQueue,uint32_t & ordinal)217 void initialize(ze_driver_handle_t &driver, ze_context_handle_t &context, ze_device_handle_t &device, ze_command_queue_handle_t &cmdQueue, uint32_t &ordinal) {
218     std::vector<ze_device_handle_t> devices;
219 
220     devices = zelloInitContextAndGetDevices(context, driver);
221     device = devices[0];
222     cmdQueue = createCommandQueue(context, device, &ordinal);
223 }
224 
teardown(ze_context_handle_t context,ze_command_queue_handle_t cmdQueue)225 static inline void teardown(ze_context_handle_t context, ze_command_queue_handle_t cmdQueue) {
226     SUCCESS_OR_TERMINATE(zeCommandQueueDestroy(cmdQueue));
227     SUCCESS_OR_TERMINATE(zeContextDestroy(context));
228 }
229