1 /*******************************************************************************
2 * Copyright 2019-2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef GPU_COMPUTE_KERNEL_CTX_HPP
18 #define GPU_COMPUTE_KERNEL_CTX_HPP
19
20 #include <cassert>
21 #ifdef DEBUG_PRINT
22 #include <iostream>
23 #endif
24 #include <map>
25 #include <set>
26 #include <sstream>
27 #include <string>
28 #include <type_traits>
29
30 #include "common/bit_cast.hpp"
31
32 namespace dnnl {
33 namespace impl {
34 namespace gpu {
35 namespace compute {
36
37 class kernel_ctx_t {
38 public:
kernel_ctx_t()39 kernel_ctx_t() { set_default_options(); }
40
options() const41 std::string options() const {
42 std::ostringstream oss;
43 for (auto &opt : option_set_)
44 oss << " " << opt;
45
46 for (auto &int_var : int_var_map_)
47 oss << " -D" << int_var.first << "=" << int_var.second;
48
49 for (auto &float_var : float_var_map_) {
50 oss << " -D" << float_var.first << "=as_float(0x" << std::hex
51 << utils::bit_cast<uint32_t>(float_var.second) << ")";
52 }
53 return oss.str();
54 }
55
define_int(const char * variable,int64_t value)56 void define_int(const char *variable, int64_t value) {
57 int_var_map_.insert({variable, value});
58 }
59
define_int(const std::string & variable,int64_t value)60 void define_int(const std::string &variable, int64_t value) {
61 define_int(variable.c_str(), value);
62 }
63
64 // TODO: should be removed, any float values should be passed in
65 // kernel parameters
define_float(const char * variable,float value)66 void define_float(const char *variable, float value) {
67 float_var_map_.insert({variable, value});
68 }
69
add_option(const char * option)70 void add_option(const char *option) { option_set_.insert(option); }
add_option(const std::string & option)71 void add_option(const std::string &option) { add_option(option.c_str()); }
72
has_macro(const char * name) const73 bool has_macro(const char *name) const {
74 std::string opt_start = std::string("-D") + name + "=";
75 for (auto &opt : option_set_)
76 if (opt.find(opt_start) != std::string::npos) return true;
77
78 return int_var_map_.count(name) != 0 || float_var_map_.count(name) != 0;
79 }
has_macro(const std::string & name) const80 bool has_macro(const std::string &name) const {
81 return has_macro(name.c_str());
82 }
83
set_data_type(data_type_t dt)84 void set_data_type(data_type_t dt) {
85 switch (dt) {
86 case data_type::bf16: define_int("DT_BF16", 1); break;
87 case data_type::f16: define_int("DT_F16", 1); break;
88 case data_type::f32: define_int("DT_F32", 1); break;
89 case data_type::s8: define_int("DT_S8", 1); break;
90 case data_type::u8: define_int("DT_U8", 1); break;
91 case data_type::s32: define_int("DT_S32", 1); break;
92 default: assert(!"unknown data type"); break;
93 }
94 }
95
print_options() const96 void print_options() const {
97 #ifdef DEBUG_PRINT
98 std::cout << "OPT:\n" << options() << std::endl;
99 #endif
100 }
101
102 template <typename T>
get_scalar(const std::string & s) const103 T get_scalar(const std::string &s) const {
104 UNUSED(s);
105 static_assert(!std::is_same<T, T>::value, "not expected");
106 return {};
107 }
108
data_type() const109 std::string data_type() const {
110 if (int_var_map_.count("DT_F16") != 0) return "f16";
111
112 if (int_var_map_.count("DT_F32") != 0) return "f32";
113
114 if (int_var_map_.count("DT_S8") != 0) return "s8";
115
116 return "";
117 }
118
119 private:
set_default_options()120 void set_default_options() {
121 // By default fp32 division and sqrt are not IEEE-compliant
122 add_option("-cl-fp32-correctly-rounded-divide-sqrt");
123 }
124
125 std::map<std::string, int64_t> int_var_map_;
126 std::map<std::string, float> float_var_map_;
127 std::set<std::string> option_set_;
128 };
129
130 template <>
get_scalar(const std::string & name) const131 inline int64_t kernel_ctx_t::get_scalar(const std::string &name) const {
132 assert(int_var_map_.count(name) != 0 && "not expected");
133 return int_var_map_.at(name);
134 }
135
136 } // namespace compute
137 } // namespace gpu
138 } // namespace impl
139 } // namespace dnnl
140
141 #endif // GPU_COMPUTE_KERNEL_CTX_HPP
142