1 //===--- amdgpu/impl/internal.h ----------------------------------- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #ifndef SRC_RUNTIME_INCLUDE_INTERNAL_H_
9 #define SRC_RUNTIME_INCLUDE_INTERNAL_H_
10 #include <inttypes.h>
11 #include <pthread.h>
12 #include <stddef.h>
13 #include <stdint.h>
14 #include <stdio.h>
15 #include <stdlib.h>
16 
17 #include <cstring>
18 #include <map>
19 #include <queue>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "hsa_api.h"
25 
26 #include "atmi.h"
27 #include "atmi_runtime.h"
28 #include "rt.h"
29 
30 #define MAX_NUM_KERNELS (1024 * 16)
31 
// Record of implicit arguments. The member order and types determine the
// in-memory layout, so fields must not be reordered or resized.
// NOTE(review): presumably this layout is shared with device-side code --
// confirm against the code that fills and reads this struct.
struct atmi_implicit_args_s {
  // Per-dimension offsets.
  unsigned long offset_x;
  unsigned long offset_y;
  unsigned long offset_z;

  // Pointer-sized slots are stored as integers rather than typed pointers.
  unsigned long hostcall_ptr;

  // Queue counts are single bytes, each followed by a pointer-sized slot.
  char num_gpu_queues;
  unsigned long gpu_queue_ptr;
  char num_cpu_queues;
  unsigned long cpu_worker_signals;
  unsigned long cpu_queue_ptr;

  unsigned long kernarg_template_ptr;
};
typedef struct atmi_implicit_args_s atmi_implicit_args_t;
44 
extern "C" {

// DEBUG_PRINT(fmt, ...): printf-style diagnostic.
//
// With DEBUG defined, the message is written to stderr prefixed with
// "[file:line] ", but only when the runtime singleton reports that debug
// mode is on. Without DEBUG the macro expands to an empty statement and its
// arguments are not evaluated.
//
// Both variants are wrapped in do { } while (false) so that the macro always
// behaves as exactly one statement and must be followed by ';'. A bare
// `if (...) { ... }` expansion would capture a following `else` or break
// `if (c) DEBUG_PRINT(...); else ...`.
#ifdef DEBUG
#define DEBUG_PRINT(fmt, ...)                                                  \
  do {                                                                         \
    if (core::Runtime::getInstance().getDebugMode()) {                         \
      fprintf(stderr, "[%s:%d] " fmt, __FILE__, __LINE__, ##__VA_ARGS__);      \
    }                                                                          \
  } while (false)
#else
#define DEBUG_PRINT(...)                                                       \
  do {                                                                         \
  } while (false)
#endif

// Fallback definition of the signal handle type, used only when the include
// guard macro HSA_RUNTIME_INC_HSA_H_ has not been defined.
#ifndef HSA_RUNTIME_INC_HSA_H_
typedef struct hsa_signal_s {
  uint64_t handle;
} hsa_signal_t;
#endif

}
65 
66 /* ---------------------------------------------------------------------------------
67  * Simulated CPU Data Structures and API
68  * ---------------------------------------------------------------------------------
69  */
70 
71 #define ATMI_WAIT_STATE HSA_WAIT_STATE_BLOCKED
72 
73 // ---------------------- Kernel Start -------------
// Per-kernel record: a 64-bit kernel object value, a set of 32-bit size and
// register-count fields, and one entry per argument in each of the three
// argument vectors.
// NOTE(review): presumably arg_alignments, arg_offsets and arg_sizes each
// hold num_args entries -- confirm against the code that populates them.
struct atl_kernel_info_s {
  uint64_t kernel_object;

  // Segment sizes and register usage.
  uint32_t group_segment_size;
  uint32_t private_segment_size;
  uint32_t sgpr_count;
  uint32_t vgpr_count;
  uint32_t sgpr_spill_count;
  uint32_t vgpr_spill_count;
  uint32_t kernel_segment_size;

  // Argument description.
  uint32_t num_args;
  std::vector<uint64_t> arg_alignments;
  std::vector<uint64_t> arg_offsets;
  std::vector<uint64_t> arg_sizes;
};
typedef struct atl_kernel_info_s atl_kernel_info_t;
88 
// Symbol record: a 64-bit address paired with a 32-bit size.
struct atl_symbol_info_s {
  uint64_t addr;
  uint32_t size;
};
typedef struct atl_symbol_info_s atl_symbol_info_t;
93 
94 // ---------------------- Kernel End -------------
95 
// Forward declarations only: these names can be used here as incomplete types
// (pointers and references); their definitions are not part of this section.
namespace core {
class Kernel;
class KernelImpl;
class TaskImpl;
class TaskgroupImpl;
} // namespace core
102 
103 struct SignalPoolT {
SignalPoolTSignalPoolT104   SignalPoolT() {}
105   SignalPoolT(const SignalPoolT &) = delete;
106   SignalPoolT(SignalPoolT &&) = delete;
~SignalPoolTSignalPoolT107   ~SignalPoolT() {
108     size_t N = state.size();
109     for (size_t i = 0; i < N; i++) {
110       hsa_signal_t signal = state.front();
111       state.pop();
112       hsa_status_t rc = hsa_signal_destroy(signal);
113       if (rc != HSA_STATUS_SUCCESS) {
114         DEBUG_PRINT("Signal pool destruction failed\n");
115       }
116     }
117   }
sizeSignalPoolT118   size_t size() {
119     lock l(&mutex);
120     return state.size();
121   }
pushSignalPoolT122   void push(hsa_signal_t s) {
123     lock l(&mutex);
124     state.push(s);
125   }
popSignalPoolT126   hsa_signal_t pop(void) {
127     lock l(&mutex);
128     if (!state.empty()) {
129       hsa_signal_t res = state.front();
130       state.pop();
131       return res;
132     }
133 
134     // Pool empty, attempt to create another signal
135     hsa_signal_t new_signal;
136     hsa_status_t err = hsa_signal_create(0, 0, NULL, &new_signal);
137     if (err == HSA_STATUS_SUCCESS) {
138       return new_signal;
139     }
140 
141     // Fail
142     return {0};
143   }
144 
145 private:
146   static pthread_mutex_t mutex;
147   std::queue<hsa_signal_t> state;
148   struct lock {
lockSignalPoolT::lock149     lock(pthread_mutex_t *m) : m(m) { pthread_mutex_lock(m); }
~lockSignalPoolT::lock150     ~lock() { pthread_mutex_unlock(m); }
151     pthread_mutex_t *m;
152   };
153 };
154 
155 namespace core {
156 hsa_status_t atl_init_gpu_context();
157 
158 hsa_status_t init_hsa();
159 hsa_status_t finalize_hsa();
160 /*
161  * Generic utils
162  */
// Rounds an integer `value` down to the nearest multiple of `alignment`.
// `alignment` must be a non-zero power of two: the result is computed by
// masking off the low bits with ~(alignment - 1).
template <typename T> inline T alignDown(T value, size_t alignment) {
  return (T)(value & ~(alignment - 1));
}

// Rounds a pointer down to the nearest address that is a multiple of
// `alignment` bytes (power of two). The arithmetic is done on the integer
// address, so `alignment` is always in bytes regardless of sizeof(T).
template <typename T> inline T *alignDown(T *value, size_t alignment) {
  const uintptr_t address = reinterpret_cast<uintptr_t>(value);
  return reinterpret_cast<T *>(alignDown(address, alignment));
}

// Rounds an integer `value` up to the nearest multiple of `alignment`
// (power of two). A value that is already aligned is returned unchanged.
template <typename T> inline T alignUp(T value, size_t alignment) {
  return alignDown((T)(value + alignment - 1), alignment);
}

// Rounds a pointer up to the nearest address that is a multiple of
// `alignment` bytes (power of two).
//
// The bias of (alignment - 1) is added to the integer address, not to the
// typed pointer: adding it to a T * would advance by
// (alignment - 1) * sizeof(T) bytes and overshoot for any T larger than one
// byte.
template <typename T> inline T *alignUp(T *value, size_t alignment) {
  const uintptr_t address = reinterpret_cast<uintptr_t>(value);
  return reinterpret_cast<T *>(alignDown(address + alignment - 1, alignment));
}
179 
180 extern bool atl_is_atmi_initialized();
181 
182 bool handle_group_signal(hsa_signal_value_t value, void *arg);
183 
184 hsa_status_t allow_access_to_all_gpu_agents(void *ptr);
185 } // namespace core
186 
187 const char *get_error_string(hsa_status_t err);
188 
189 #endif // SRC_RUNTIME_INCLUDE_INTERNAL_H_
190