1 /*
2  * Copyright (C) 2021 Intel Corporation
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  */
7 
8 #pragma once
9 #include "opencl/source/kernel/kernel.h"
10 
11 namespace NEO {
12 template <>
13 struct OpenCLObjectMapper<_cl_kernel> {
14     typedef class MultiDeviceKernel DerivedType;
15 };
16 
// Container of Kernel pointers; MultiDeviceKernel indexes it by root device index.
using KernelVectorType = StackVec<Kernel *, 4>;
// Container of KernelInfo pointers; MultiDeviceKernel::create indexes it by root device index.
using KernelInfoContainer = StackVec<const KernelInfo *, 4>;
19 
20 class MultiDeviceKernel : public BaseObject<_cl_kernel> {
21   public:
22     static const cl_ulong objectMagic = 0x3284ADC8EA0AFE25LL;
23 
24     ~MultiDeviceKernel() override;
25     MultiDeviceKernel(KernelVectorType kernelVector, const KernelInfoContainer kernelInfosArg);
26 
27     Kernel *getKernel(uint32_t rootDeviceIndex) const { return kernels[rootDeviceIndex]; }
28     Kernel *getDefaultKernel() const { return defaultKernel; }
29 
30     template <typename kernel_t = Kernel, typename program_t = Program, typename multi_device_kernel_t = MultiDeviceKernel>
31     static multi_device_kernel_t *create(program_t *program, const KernelInfoContainer &kernelInfos, cl_int *errcodeRet) {
32         KernelVectorType kernels{};
33         kernels.resize(program->getMaxRootDeviceIndex() + 1);
34 
35         for (auto &pDevice : program->getDevicesInProgram()) {
36             auto rootDeviceIndex = pDevice->getRootDeviceIndex();
37             if (kernels[rootDeviceIndex]) {
38                 continue;
39             }
40             kernels[rootDeviceIndex] = Kernel::create<kernel_t, program_t>(program, *kernelInfos[rootDeviceIndex], *pDevice, errcodeRet);
41             if (!kernels[rootDeviceIndex]) {
42                 return nullptr;
43             }
44         }
45         auto pMultiDeviceKernel = new multi_device_kernel_t(std::move(kernels), kernelInfos);
46 
47         return pMultiDeviceKernel;
48     }
49 
50     cl_int cloneKernel(MultiDeviceKernel *pSourceMultiDeviceKernel);
51     const std::vector<Kernel::SimpleKernelArgInfo> &getKernelArguments() const;
52     cl_int checkCorrectImageAccessQualifier(cl_uint argIndex, size_t argSize, const void *argValue) const;
53     void unsetArg(uint32_t argIndex);
54     cl_int setArg(uint32_t argIndex, size_t argSize, const void *argVal);
55     cl_int getInfo(cl_kernel_info paramName, size_t paramValueSize, void *paramValue, size_t *paramValueSizeRet) const;
56     cl_int getArgInfo(cl_uint argIndx, cl_kernel_arg_info paramName, size_t paramValueSize, void *paramValue, size_t *paramValueSizeRet) const;
57     const ClDeviceVector &getDevices() const;
58     size_t getKernelArgsNumber() const;
59     Context &getContext() const;
60     cl_int setArgSvmAlloc(uint32_t argIndex, void *svmPtr, MultiGraphicsAllocation *svmAllocs);
61     bool getHasIndirectAccess() const;
62     void setUnifiedMemoryProperty(cl_kernel_exec_info infoType, bool infoValue);
63     void setSvmKernelExecInfo(const MultiGraphicsAllocation &argValue);
64     void clearSvmKernelExecInfo();
65     void setUnifiedMemoryExecInfo(const MultiGraphicsAllocation &argValue);
66     void clearUnifiedMemoryExecInfo();
67     int setKernelThreadArbitrationPolicy(uint32_t propertyValue);
68     cl_int setKernelExecutionType(cl_execution_info_kernel_type_intel executionType);
69     int32_t setAdditionalKernelExecInfoWithParam(uint32_t paramName, size_t paramValueSize, const void *paramValue);
70     Program *getProgram() const { return program; }
71     const KernelInfoContainer &getKernelInfos() const { return kernelInfos; }
72 
73   protected:
74     template <typename FuncType, typename... Args>
75     cl_int getResultFromEachKernel(FuncType function, Args &&...args) const {
76         cl_int retVal = CL_INVALID_VALUE;
77 
78         for (auto &pKernel : kernels) {
79             if (pKernel) {
80                 retVal = (pKernel->*function)(std::forward<Args>(args)...);
81                 if (CL_SUCCESS != retVal) {
82                     break;
83                 }
84             }
85         }
86         return retVal;
87     }
88     template <typename FuncType, typename... Args>
89     void callOnEachKernel(FuncType function, Args &&...args) {
90         for (auto &pKernel : kernels) {
91             if (pKernel) {
92                 (pKernel->*function)(std::forward<Args>(args)...);
93             }
94         }
95     }
96     static Kernel *determineDefaultKernel(KernelVectorType &kernelVector);
97     KernelVectorType kernels;
98     Kernel *defaultKernel = nullptr;
99     Program *program = nullptr;
100     const KernelInfoContainer kernelInfos;
101 };
102 
103 } // namespace NEO
104