1
2 // =================================================================================================
3 // This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
4 // project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
5 // width of 100 characters per line.
6 //
7 // Author(s):
8 // Cedric Nugteren <www.cedricnugteren.nl>
9 //
10 // This file implements the Database class (see the header for information about the class).
11 //
12 // =================================================================================================
13
14 #include <list>
15
16 #include "utilities/utilities.hpp"
17
18 #include "database/database.hpp"
19
20 #include "database/kernels/xaxpy/xaxpy.hpp"
21 #include "database/kernels/xdot/xdot.hpp"
22 #include "database/kernels/xgemv/xgemv.hpp"
23 #include "database/kernels/xgemv_fast/xgemv_fast.hpp"
24 #include "database/kernels/xgemv_fast_rot/xgemv_fast_rot.hpp"
25 #include "database/kernels/xger/xger.hpp"
26 #include "database/kernels/xgemm/xgemm.hpp"
27 #include "database/kernels/xgemm_direct/xgemm_direct.hpp"
28 #include "database/kernels/copy/copy.hpp"
29 #include "database/kernels/pad/pad.hpp"
30 #include "database/kernels/transpose/transpose.hpp"
31 #include "database/kernels/padtranspose/padtranspose.hpp"
32
33 #include "database/kernels/xtrsv.hpp"
34 #include "database/kernels/invert.hpp"
35 #include "database/apple_cpu_fallback.hpp"
36 #include "database/kernel_selection.hpp"
37
38 namespace clblast {
39 // =================================================================================================
40
41 // Initializes the databases
42 const std::vector<database::DatabaseEntry> Database::database = std::vector<database::DatabaseEntry>{
43 database::XaxpyHalf, database::XaxpySingle, database::XaxpyDouble, database::XaxpyComplexSingle, database::XaxpyComplexDouble,
44 database::XdotHalf, database::XdotSingle, database::XdotDouble, database::XdotComplexSingle, database::XdotComplexDouble,
45 database::XgemvHalf, database::XgemvSingle, database::XgemvDouble, database::XgemvComplexSingle, database::XgemvComplexDouble,
46 database::XgemvFastHalf, database::XgemvFastSingle, database::XgemvFastDouble, database::XgemvFastComplexSingle, database::XgemvFastComplexDouble,
47 database::XgemvFastRotHalf, database::XgemvFastRotSingle, database::XgemvFastRotDouble, database::XgemvFastRotComplexSingle, database::XgemvFastRotComplexDouble,
48 database::XgerHalf, database::XgerSingle, database::XgerDouble, database::XgerComplexSingle, database::XgerComplexDouble,
49 database::XtrsvHalf, database::XtrsvSingle, database::XtrsvDouble, database::XtrsvComplexSingle, database::XtrsvComplexDouble,
50 database::XgemmHalf, database::XgemmSingle, database::XgemmDouble, database::XgemmComplexSingle, database::XgemmComplexDouble,
51 database::XgemmDirectHalf, database::XgemmDirectSingle, database::XgemmDirectDouble, database::XgemmDirectComplexSingle, database::XgemmDirectComplexDouble,
52 database::CopyHalf, database::CopySingle, database::CopyDouble, database::CopyComplexSingle, database::CopyComplexDouble,
53 database::PadHalf, database::PadSingle, database::PadDouble, database::PadComplexSingle, database::PadComplexDouble,
54 database::TransposeHalf, database::TransposeSingle, database::TransposeDouble, database::TransposeComplexSingle, database::TransposeComplexDouble,
55 database::PadtransposeHalf, database::PadtransposeSingle, database::PadtransposeDouble, database::PadtransposeComplexSingle, database::PadtransposeComplexDouble,
56 database::InvertHalf, database::InvertSingle, database::InvertDouble, database::InvertComplexSingle, database::InvertComplexDouble,
57 database::KernelSelectionHalf, database::KernelSelectionSingle, database::KernelSelectionDouble, database::KernelSelectionComplexSingle, database::KernelSelectionComplexDouble
58 };
59 const std::vector<database::DatabaseEntry> Database::apple_cpu_fallback = std::vector<database::DatabaseEntry>{
60 database::XaxpyApple, database::XdotApple,
61 database::XgemvApple, database::XgemvFastApple, database::XgemvFastRotApple, database::XgerApple, database::XtrsvApple,
62 database::XgemmApple, database::XgemmDirectApple,
63 database::CopyApple, database::PadApple, database::TransposeApple, database::PadtransposeApple,
64 database::InvertApple
65 };
66
67 // The default values
68 const std::string Database::kDeviceVendorAll = "default";
69
70 // =================================================================================================
71
72 // Constructor, computing device properties and populating the parameter-vector from the database.
73 // This takes an optional overlay database in case of custom tuning or custom kernels.
Database(const Device & device,const std::string & kernel_name,const Precision precision,const std::vector<database::DatabaseEntry> & overlay)74 Database::Database(const Device &device, const std::string &kernel_name,
75 const Precision precision, const std::vector<database::DatabaseEntry> &overlay):
76 parameters_(std::make_shared<database::Parameters>()) {
77
78 // Finds device information
79 const auto device_type = GetDeviceType(device);
80 const auto device_vendor = GetDeviceVendor(device);
81 const auto device_architecture = GetDeviceArchitecture(device);
82 const auto device_name = GetDeviceName(device);
83
84 // Prints the obtained information in verbose mode
85 log_debug("Device type '" + device_type + "'; vendor '" + device_vendor + "'");
86 log_debug("Device name '" + device_name + "'; architecture '" + device_architecture + "'");
87
88 // Sets the databases to search through
89 auto databases = std::list<std::vector<database::DatabaseEntry>>{overlay, database};
90
91 // Special case: modifies the database if the device is a CPU with Apple OpenCL
92 #if defined(__APPLE__) || defined(__MACOSX)
93 if (device.Type() == "CPU") {
94 const auto extensions = device.Capabilities();
95 const auto is_apple = (extensions.find("cl_APPLE_SetMemObjectDestructor") == std::string::npos) ? false : true;
96 if (is_apple) {
97 databases.push_front(apple_cpu_fallback);
98 }
99 }
100 #endif
101
102 // Searches potentially multiple databases
103 auto search_result = database::Parameters();
104 for (auto &db: databases) {
105 search_result = Search(kernel_name, device_vendor, device_type,
106 device_name, device_architecture, precision, db);
107 if (search_result.size() != 0) {
108 parameters_->insert(search_result.begin(), search_result.end());
109 break;
110 }
111 }
112
113 if (search_result.size() == 0) { throw RuntimeErrorCode(StatusCode::kDatabaseError); }
114 }
115
116 // =================================================================================================
117
118 // Returns a list of OpenCL pre-processor defines in string form
GetDefines() const119 std::string Database::GetDefines() const {
120 std::string defines{};
121 for (auto ¶meter: *parameters_) {
122 defines += "#define "+parameter.first+" "+ToString(parameter.second)+"\n";
123 }
124 return defines;
125 }
126
127 // Retrieves the names of all the parameters
GetParameterNames() const128 std::vector<std::string> Database::GetParameterNames() const {
129 auto parameter_names = std::vector<std::string>();
130 for (auto ¶meter: *parameters_) {
131 parameter_names.push_back(parameter.first);
132 }
133 return parameter_names;
134 }
135
136 // =================================================================================================
137
138 // Searches a particular database for the right kernel and precision
Search(const std::string & this_kernel,const std::string & this_vendor,const std::string & this_type,const std::string & this_device,const std::string & this_architecture,const Precision this_precision,const std::vector<database::DatabaseEntry> & this_database) const139 database::Parameters Database::Search(const std::string &this_kernel,
140 const std::string &this_vendor, const std::string &this_type,
141 const std::string &this_device, const std::string &this_architecture,
142 const Precision this_precision,
143 const std::vector<database::DatabaseEntry> &this_database) const {
144
145 // Selects the right kernel
146 for (auto &db: this_database) {
147 if ((db.kernel == this_kernel) &&
148 (db.precision == this_precision || db.precision == Precision::kAny)) {
149
150 // Searches for the right vendor and device type, or selects the default if unavailable
151 const auto parameters = SearchVendorAndType(this_vendor, this_type, this_device, this_architecture,
152 db.vendors, db.parameter_names);
153 if (parameters.size() != 0) { return parameters; }
154 return SearchVendorAndType(kDeviceVendorAll, database::kDeviceTypeAll, this_device, this_architecture,
155 db.vendors, db.parameter_names);
156 }
157 }
158
159 // If we reached this point, the entry was not found in this database
160 return database::Parameters();
161 }
162
SearchVendorAndType(const std::string & target_vendor,const std::string & target_type,const std::string & this_device,const std::string & this_architecture,const std::vector<database::DatabaseVendor> & vendors,const std::vector<std::string> & parameter_names) const163 database::Parameters Database::SearchVendorAndType(const std::string &target_vendor, const std::string &target_type,
164 const std::string &this_device, const std::string &this_architecture,
165 const std::vector<database::DatabaseVendor> &vendors,
166 const std::vector<std::string> ¶meter_names) const {
167 for (auto &vendor: vendors) {
168 if ((vendor.name == target_vendor) && (vendor.type == target_type)) {
169 log_debug("Found architectures of vendor '" + target_vendor + "' and type '" + target_type + "'");
170
171 // Searches the architecture; if unavailable returns the vendor's default parameters
172 auto parameters = SearchArchitecture(this_architecture, this_device, vendor.architectures, parameter_names);
173 if (parameters.size() != 0) { return parameters; }
174 return SearchArchitecture("default", this_device, vendor.architectures, parameter_names);
175 }
176 }
177 return database::Parameters();
178 }
179
SearchArchitecture(const std::string & target_architecture,const std::string & this_device,const std::vector<database::DatabaseArchitecture> & architectures,const std::vector<std::string> & parameter_names) const180 database::Parameters Database::SearchArchitecture(const std::string &target_architecture,
181 const std::string &this_device,
182 const std::vector<database::DatabaseArchitecture> &architectures,
183 const std::vector<std::string> ¶meter_names) const {
184 for (auto &architecture: architectures) {
185 if (architecture.name == target_architecture) {
186 log_debug("Found devices of architecture type '" + target_architecture + "'");
187
188 // Searches the device; if unavailable returns the architecture's default parameters
189 auto parameters = SearchDevice(this_device, architecture.devices, parameter_names);
190 if (parameters.size() != 0) { return parameters; }
191 return SearchDevice("default", architecture.devices, parameter_names);
192 }
193 }
194 return database::Parameters();
195 }
196
SearchDevice(const std::string & target_device,const std::vector<database::DatabaseDevice> & devices,const std::vector<std::string> & parameter_names) const197 database::Parameters Database::SearchDevice(const std::string &target_device,
198 const std::vector<database::DatabaseDevice> &devices,
199 const std::vector<std::string> ¶meter_names) const {
200 for (auto &device: devices) {
201 const auto device_name = CharArrayToString(device.name);
202 // Cuts off 'target_device' string at 50 since the database cuts off as well
203 const auto target_device_cut_off = (target_device.length() > 50) ? target_device.substr(0, 50) : target_device;
204 if (device_name == target_device_cut_off) {
205 log_debug("Found parameters for device type '" + target_device_cut_off + "'");
206
207 // Sets the parameters accordingly
208 auto parameters = database::Parameters();
209 if (parameter_names.size() > device.parameters.size()) { return database::Parameters(); } // ERROR
210 for (auto i = size_t{0}; i < parameter_names.size(); ++i) {
211 parameters[parameter_names[i]] = static_cast<size_t>(device.parameters[i]);
212 }
213 return parameters;
214 }
215 }
216 return database::Parameters();
217 }
218
219 // Helper to convert from database format to proper types
CharArrayToString(const database::Name char_array) const220 std::string Database::CharArrayToString(const database::Name char_array) const {
221 auto result = std::string(char_array.data());
222 result.erase(result.find_last_not_of(" \t\n\r\f\v") + 1);
223 return result;
224 }
225
226 // =================================================================================================
227 } // namespace clblast
228