1 /* 2 * Copyright (C) 2018-2021 Intel Corporation 3 * 4 * SPDX-License-Identifier: MIT 5 * 6 */ 7 8 #pragma once 9 #include "shared/source/helpers/abort.h" 10 #include "shared/source/helpers/debug_helpers.h" 11 #include "shared/source/utilities/reference_tracked_object.h" 12 13 #include "opencl/source/api/dispatch.h" 14 15 #include "CL/cl.h" 16 17 #include <atomic> 18 #include <condition_variable> 19 #include <iostream> 20 #include <mutex> 21 #include <thread> 22 23 namespace NEO { 24 25 #if defined(__clang__) 26 #define NO_SANITIZE __attribute__((no_sanitize("undefined"))) 27 #else 28 #define NO_SANITIZE 29 #endif 30 template <typename Type> 31 struct OpenCLObjectMapper { 32 }; 33 34 template <typename T> 35 using DerivedType_t = typename OpenCLObjectMapper<T>::DerivedType; 36 37 template <typename DerivedType> castToObject(typename DerivedType::BaseType * object)38NO_SANITIZE inline DerivedType *castToObject(typename DerivedType::BaseType *object) { 39 if (object == nullptr) { 40 return nullptr; 41 } 42 43 auto derivedObject = static_cast<DerivedType *>(object); 44 if (((derivedObject->getMagic() & DerivedType::maskMagic) == DerivedType::objectMagic) && 45 (derivedObject->dispatch.icdDispatch == &icdGlobalDispatchTable)) { 46 return derivedObject; 47 } 48 49 return nullptr; 50 } 51 52 template <typename DerivedType> castToObjectOrAbort(typename DerivedType::BaseType * object)53inline DerivedType *castToObjectOrAbort(typename DerivedType::BaseType *object) { 54 auto derivedObject = castToObject<DerivedType>(object); 55 if (derivedObject == nullptr) { 56 abortExecution(); 57 } else { 58 return derivedObject; 59 } 60 } 61 62 template <typename DerivedType> castToObject(const typename DerivedType::BaseType * object)63inline const DerivedType *castToObject(const typename DerivedType::BaseType *object) { 64 return castToObject<DerivedType>(const_cast<typename DerivedType::BaseType *>(object)); 65 } 66 67 template <typename DerivedType> castToObject(const void * object)68inline DerivedType *castToObject(const void *object) { 69 cl_mem clMem = const_cast<cl_mem>(static_cast<const _cl_mem *>(object)); 70 return castToObject<DerivedType>(clMem); 71 } 72 73 extern std::thread::id invalidThreadID; 74 75 class ConditionVariableWithCounter { 76 public: ConditionVariableWithCounter()77 ConditionVariableWithCounter() { 78 waitersCount = 0; 79 } 80 template <typename... Args> wait(Args &&...args)81 void wait(Args &&...args) { 82 ++waitersCount; 83 cond.wait(std::forward<Args>(args)...); 84 --waitersCount; 85 } 86 notify_one()87 void notify_one() { // NOLINT 88 cond.notify_one(); 89 } 90 peekNumWaiters()91 uint32_t peekNumWaiters() { 92 return waitersCount.load(); 93 } 94 95 private: 96 std::atomic_uint waitersCount; 97 std::condition_variable cond; 98 }; 99 100 template <typename T> 101 class TakeOwnershipWrapper { 102 public: TakeOwnershipWrapper(T & obj)103 TakeOwnershipWrapper(T &obj) 104 : obj(obj) { 105 lock(); 106 } TakeOwnershipWrapper(T & obj,bool lockImmediately)107 TakeOwnershipWrapper(T &obj, bool lockImmediately) 108 : obj(obj) { 109 if (lockImmediately) { 110 lock(); 111 } 112 } ~TakeOwnershipWrapper()113 ~TakeOwnershipWrapper() { 114 unlock(); 115 } unlock()116 void unlock() { 117 if (locked) { 118 obj.releaseOwnership(); 119 locked = false; 120 } 121 } 122 lock()123 void lock() { 124 if (!locked) { 125 obj.takeOwnership(); 126 locked = true; 127 } 128 } 129 130 private: 131 T &obj; 132 bool locked = false; 133 }; 134 135 // This class should act as a base class for all CL objects. It will handle the 136 // MT safe and reference things for every CL object. 137 template <typename B> 138 class BaseObject : public B, public ReferenceTrackedObject<DerivedType_t<B>> { 139 public: 140 typedef BaseObject<B> ThisType; 141 typedef B BaseType; 142 typedef DerivedType_t<B> DerivedType; 143 144 const static cl_ulong maskMagic = 0xFFFFFFFFFFFFFFFFLL; 145 const static cl_ulong deadMagic = 0xFFFFFFFFFFFFFFFFLL; 146 147 BaseObject(const BaseObject &) = delete; 148 BaseObject &operator=(const BaseObject &) = delete; 149 150 protected: 151 cl_long magic; 152 153 mutable std::mutex mtx; 154 mutable ConditionVariableWithCounter cond; 155 mutable std::thread::id owner; 156 mutable uint32_t recursiveOwnageCounter = 0; 157 BaseObject()158 BaseObject() 159 : magic(DerivedType::objectMagic) { 160 this->incRefApi(); 161 } 162 ~BaseObject()163 ~BaseObject() override { 164 magic = deadMagic; 165 } 166 isValid()167 bool isValid() const { 168 return (magic & DerivedType::maskMagic) == DerivedType::objectMagic; 169 } 170 convertToInternalObject()171 void convertToInternalObject() { 172 this->incRefInternal(); 173 this->decRefApi(); 174 } 175 176 public: 177 NO_SANITIZE getMagic()178 cl_ulong getMagic() const { 179 return this->magic; 180 } 181 retain()182 virtual void retain() { 183 DEBUG_BREAK_IF(!isValid()); 184 this->incRefApi(); 185 } 186 release()187 virtual unique_ptr_if_unused<DerivedType> release() { 188 DEBUG_BREAK_IF(!isValid()); 189 return this->decRefApi(); 190 } 191 getReference()192 cl_int getReference() const { 193 DEBUG_BREAK_IF(!isValid()); 194 return this->getRefApiCount(); 195 } 196 takeOwnership()197 MOCKABLE_VIRTUAL void takeOwnership() const { 198 DEBUG_BREAK_IF(!isValid()); 199 200 std::unique_lock<std::mutex> theLock(mtx); 201 std::thread::id self = std::this_thread::get_id(); 202 203 if (owner == invalidThreadID) { 204 owner = self; 205 return; 206 } 207 208 if (owner == self) { 209 ++recursiveOwnageCounter; 210 return; 211 } 212 213 cond.wait(theLock, [&] { return owner == invalidThreadID; }); 214 owner = self; 215 recursiveOwnageCounter = 0; 216 } 217 releaseOwnership()218 MOCKABLE_VIRTUAL void releaseOwnership() const { 219 DEBUG_BREAK_IF(!isValid()); 220 221 std::unique_lock<std::mutex> theLock(mtx); 222 223 if (hasOwnership() == false) { 224 DEBUG_BREAK_IF(true); 225 return; 226 } 227 228 if (recursiveOwnageCounter > 0) { 229 --recursiveOwnageCounter; 230 return; 231 } 232 owner = invalidThreadID; 233 cond.notify_one(); 234 } 235 236 // checks whether current thread owns object mutex hasOwnership()237 bool hasOwnership() const { 238 DEBUG_BREAK_IF(!isValid()); 239 240 return (owner == std::this_thread::get_id()); 241 } 242 getCond()243 ConditionVariableWithCounter &getCond() { 244 return this->cond; 245 } 246 }; 247 248 } // namespace NEO 249