1 /*
2  * Copyright (C) 2018-2021 Intel Corporation
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  */
7 
8 #pragma once
9 #include "shared/source/helpers/abort.h"
10 #include "shared/source/helpers/debug_helpers.h"
11 #include "shared/source/utilities/reference_tracked_object.h"
12 
13 #include "opencl/source/api/dispatch.h"
14 
15 #include "CL/cl.h"
16 
17 #include <atomic>
18 #include <condition_variable>
19 #include <iostream>
20 #include <mutex>
21 #include <thread>
22 
23 namespace NEO {
24 
// Function annotation that, under Clang only, disables UndefinedBehaviorSanitizer
// ("undefined") checks for the annotated function; with other compilers it expands
// to nothing. In this header it is applied to castToObject() and getMagic(),
// presumably because they read through handles that have not been validated yet.
#ifdef __clang__
#define NO_SANITIZE __attribute__((no_sanitize("undefined")))
#else
#define NO_SANITIZE
#endif
// Primary template that maps an API base type to the class implementing it.
// It is intentionally empty: DerivedType_t reads a nested `DerivedType` from it,
// so each supported type presumably has a specialization elsewhere that provides it.
template <typename ApiType>
struct OpenCLObjectMapper {};
33 
34 template <typename T>
35 using DerivedType_t = typename OpenCLObjectMapper<T>::DerivedType;
36 
37 template <typename DerivedType>
castToObject(typename DerivedType::BaseType * object)38 NO_SANITIZE inline DerivedType *castToObject(typename DerivedType::BaseType *object) {
39     if (object == nullptr) {
40         return nullptr;
41     }
42 
43     auto derivedObject = static_cast<DerivedType *>(object);
44     if (((derivedObject->getMagic() & DerivedType::maskMagic) == DerivedType::objectMagic) &&
45         (derivedObject->dispatch.icdDispatch == &icdGlobalDispatchTable)) {
46         return derivedObject;
47     }
48 
49     return nullptr;
50 }
51 
52 template <typename DerivedType>
castToObjectOrAbort(typename DerivedType::BaseType * object)53 inline DerivedType *castToObjectOrAbort(typename DerivedType::BaseType *object) {
54     auto derivedObject = castToObject<DerivedType>(object);
55     if (derivedObject == nullptr) {
56         abortExecution();
57     } else {
58         return derivedObject;
59     }
60 }
61 
// Overload of castToObject for pointer-to-const handles.
//
// Const is cast away only to reuse the non-const implementation. The result is
// returned as a pointer-to-const.
template <typename DerivedType>
inline const DerivedType *castToObject(const typename DerivedType::BaseType *object) {
    using MutableBase = typename DerivedType::BaseType;
    MutableBase *mutableObject = const_cast<MutableBase *>(object);
    return castToObject<DerivedType>(mutableObject);
}
66 
67 template <typename DerivedType>
castToObject(const void * object)68 inline DerivedType *castToObject(const void *object) {
69     cl_mem clMem = const_cast<cl_mem>(static_cast<const _cl_mem *>(object));
70     return castToObject<DerivedType>(clMem);
71 }
72 
// Sentinel thread id meaning "no thread owns this object". BaseObject stores it
// in `owner` on final release and compares `owner` against it when acquiring.
// Only declared here; the definition lives in another translation unit.
extern ::std::thread::id invalidThreadID;
74 
// std::condition_variable wrapper that also counts, in an atomic, how many
// threads are currently inside wait().
class ConditionVariableWithCounter {
  public:
    ConditionVariableWithCounter() : waitersCount(0) {}

    // Forwards every argument to std::condition_variable::wait.
    // The counter is raised before blocking and lowered after wait() returns,
    // so a predicate passed in observes itself as a waiter.
    template <typename... Args>
    void wait(Args &&...args) {
        waitersCount.fetch_add(1);
        cond.wait(std::forward<Args>(args)...);
        waitersCount.fetch_sub(1);
    }

    // Wakes one thread blocked on the underlying condition variable.
    void notify_one() { // NOLINT
        cond.notify_one();
    }

    // Returns a snapshot of the number of threads currently inside wait().
    uint32_t peekNumWaiters() {
        return waitersCount;
    }

  private:
    std::atomic_uint waitersCount;
    std::condition_variable cond;
};
99 
// Scope guard around an object's takeOwnership()/releaseOwnership() pair.
//
// A `locked` flag ensures each takeOwnership() is matched by exactly one
// releaseOwnership(). Calling lock() while already locked, or unlock() while
// not locked, does nothing. Ownership still held at destruction is released
// by the destructor.
template <typename T>
class TakeOwnershipWrapper {
  public:
    // Acquires ownership of obj right away.
    TakeOwnershipWrapper(T &obj)
        : TakeOwnershipWrapper(obj, true) {
    }

    // Acquires ownership only if lockImmediately is true; otherwise the caller
    // must call lock() later.
    TakeOwnershipWrapper(T &obj, bool lockImmediately)
        : obj(obj) {
        if (lockImmediately) {
            lock();
        }
    }

    ~TakeOwnershipWrapper() {
        unlock();
    }

    void unlock() {
        if (!locked) {
            return;
        }
        obj.releaseOwnership();
        locked = false;
    }

    void lock() {
        if (locked) {
            return;
        }
        obj.takeOwnership();
        locked = true;
    }

  private:
    T &obj;
    bool locked = false;
};
134 
135 // This class should act as a base class for all CL objects. It will handle the
136 // MT safe and reference things for every CL object.
137 template <typename B>
138 class BaseObject : public B, public ReferenceTrackedObject<DerivedType_t<B>> {
139   public:
140     typedef BaseObject<B> ThisType;
141     typedef B BaseType;
142     typedef DerivedType_t<B> DerivedType;
143 
144     const static cl_ulong maskMagic = 0xFFFFFFFFFFFFFFFFLL;
145     const static cl_ulong deadMagic = 0xFFFFFFFFFFFFFFFFLL;
146 
147     BaseObject(const BaseObject &) = delete;
148     BaseObject &operator=(const BaseObject &) = delete;
149 
150   protected:
151     cl_long magic;
152 
153     mutable std::mutex mtx;
154     mutable ConditionVariableWithCounter cond;
155     mutable std::thread::id owner;
156     mutable uint32_t recursiveOwnageCounter = 0;
157 
BaseObject()158     BaseObject()
159         : magic(DerivedType::objectMagic) {
160         this->incRefApi();
161     }
162 
~BaseObject()163     ~BaseObject() override {
164         magic = deadMagic;
165     }
166 
isValid()167     bool isValid() const {
168         return (magic & DerivedType::maskMagic) == DerivedType::objectMagic;
169     }
170 
convertToInternalObject()171     void convertToInternalObject() {
172         this->incRefInternal();
173         this->decRefApi();
174     }
175 
176   public:
177     NO_SANITIZE
getMagic()178     cl_ulong getMagic() const {
179         return this->magic;
180     }
181 
retain()182     virtual void retain() {
183         DEBUG_BREAK_IF(!isValid());
184         this->incRefApi();
185     }
186 
release()187     virtual unique_ptr_if_unused<DerivedType> release() {
188         DEBUG_BREAK_IF(!isValid());
189         return this->decRefApi();
190     }
191 
getReference()192     cl_int getReference() const {
193         DEBUG_BREAK_IF(!isValid());
194         return this->getRefApiCount();
195     }
196 
takeOwnership()197     MOCKABLE_VIRTUAL void takeOwnership() const {
198         DEBUG_BREAK_IF(!isValid());
199 
200         std::unique_lock<std::mutex> theLock(mtx);
201         std::thread::id self = std::this_thread::get_id();
202 
203         if (owner == invalidThreadID) {
204             owner = self;
205             return;
206         }
207 
208         if (owner == self) {
209             ++recursiveOwnageCounter;
210             return;
211         }
212 
213         cond.wait(theLock, [&] { return owner == invalidThreadID; });
214         owner = self;
215         recursiveOwnageCounter = 0;
216     }
217 
releaseOwnership()218     MOCKABLE_VIRTUAL void releaseOwnership() const {
219         DEBUG_BREAK_IF(!isValid());
220 
221         std::unique_lock<std::mutex> theLock(mtx);
222 
223         if (hasOwnership() == false) {
224             DEBUG_BREAK_IF(true);
225             return;
226         }
227 
228         if (recursiveOwnageCounter > 0) {
229             --recursiveOwnageCounter;
230             return;
231         }
232         owner = invalidThreadID;
233         cond.notify_one();
234     }
235 
236     // checks whether current thread owns object mutex
hasOwnership()237     bool hasOwnership() const {
238         DEBUG_BREAK_IF(!isValid());
239 
240         return (owner == std::this_thread::get_id());
241     }
242 
getCond()243     ConditionVariableWithCounter &getCond() {
244         return this->cond;
245     }
246 };
247 
248 } // namespace NEO
249