1 /*******************************************************************************
2 * Copyright 2019-2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef GPU_COMPUTE_KERNEL_CTX_HPP
18 #define GPU_COMPUTE_KERNEL_CTX_HPP
19 
20 #include <cassert>
21 #ifdef DEBUG_PRINT
22 #include <iostream>
23 #endif
24 #include <map>
25 #include <set>
26 #include <sstream>
27 #include <string>
28 #include <type_traits>
29 
30 #include "common/bit_cast.hpp"
31 
32 namespace dnnl {
33 namespace impl {
34 namespace gpu {
35 namespace compute {
36 
37 class kernel_ctx_t {
38 public:
kernel_ctx_t()39     kernel_ctx_t() { set_default_options(); }
40 
options() const41     std::string options() const {
42         std::ostringstream oss;
43         for (auto &opt : option_set_)
44             oss << " " << opt;
45 
46         for (auto &int_var : int_var_map_)
47             oss << " -D" << int_var.first << "=" << int_var.second;
48 
49         for (auto &float_var : float_var_map_) {
50             oss << " -D" << float_var.first << "=as_float(0x" << std::hex
51                 << utils::bit_cast<uint32_t>(float_var.second) << ")";
52         }
53         return oss.str();
54     }
55 
define_int(const char * variable,int64_t value)56     void define_int(const char *variable, int64_t value) {
57         int_var_map_.insert({variable, value});
58     }
59 
define_int(const std::string & variable,int64_t value)60     void define_int(const std::string &variable, int64_t value) {
61         define_int(variable.c_str(), value);
62     }
63 
64     // TODO: should be removed, any float values should be passed in
65     // kernel parameters
define_float(const char * variable,float value)66     void define_float(const char *variable, float value) {
67         float_var_map_.insert({variable, value});
68     }
69 
add_option(const char * option)70     void add_option(const char *option) { option_set_.insert(option); }
add_option(const std::string & option)71     void add_option(const std::string &option) { add_option(option.c_str()); }
72 
has_macro(const char * name) const73     bool has_macro(const char *name) const {
74         std::string opt_start = std::string("-D") + name + "=";
75         for (auto &opt : option_set_)
76             if (opt.find(opt_start) != std::string::npos) return true;
77 
78         return int_var_map_.count(name) != 0 || float_var_map_.count(name) != 0;
79     }
has_macro(const std::string & name) const80     bool has_macro(const std::string &name) const {
81         return has_macro(name.c_str());
82     }
83 
set_data_type(data_type_t dt)84     void set_data_type(data_type_t dt) {
85         switch (dt) {
86             case data_type::bf16: define_int("DT_BF16", 1); break;
87             case data_type::f16: define_int("DT_F16", 1); break;
88             case data_type::f32: define_int("DT_F32", 1); break;
89             case data_type::s8: define_int("DT_S8", 1); break;
90             case data_type::u8: define_int("DT_U8", 1); break;
91             case data_type::s32: define_int("DT_S32", 1); break;
92             default: assert(!"unknown data type"); break;
93         }
94     }
95 
print_options() const96     void print_options() const {
97 #ifdef DEBUG_PRINT
98         std::cout << "OPT:\n" << options() << std::endl;
99 #endif
100     }
101 
102     template <typename T>
get_scalar(const std::string & s) const103     T get_scalar(const std::string &s) const {
104         UNUSED(s);
105         static_assert(!std::is_same<T, T>::value, "not expected");
106         return {};
107     }
108 
data_type() const109     std::string data_type() const {
110         if (int_var_map_.count("DT_F16") != 0) return "f16";
111 
112         if (int_var_map_.count("DT_F32") != 0) return "f32";
113 
114         if (int_var_map_.count("DT_S8") != 0) return "s8";
115 
116         return "";
117     }
118 
119 private:
set_default_options()120     void set_default_options() {
121         // By default fp32 division and sqrt are not IEEE-compliant
122         add_option("-cl-fp32-correctly-rounded-divide-sqrt");
123     }
124 
125     std::map<std::string, int64_t> int_var_map_;
126     std::map<std::string, float> float_var_map_;
127     std::set<std::string> option_set_;
128 };
129 
130 template <>
get_scalar(const std::string & name) const131 inline int64_t kernel_ctx_t::get_scalar(const std::string &name) const {
132     assert(int_var_map_.count(name) != 0 && "not expected");
133     return int_var_map_.at(name);
134 }
135 
136 } // namespace compute
137 } // namespace gpu
138 } // namespace impl
139 } // namespace dnnl
140 
141 #endif // GPU_COMPUTE_KERNEL_CTX_HPP
142