1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4 //
5 // Copyright (C) 2017, Intel Corporation, all rights reserved.
6 // Third party copyrights are property of their respective owners.
7
8 #include "precomp.hpp"
9 #include "halide_scheduler.hpp"
10 #include "op_halide.hpp"
11
12 namespace cv
13 {
14 namespace dnn
15 {
16
17 #ifdef HAVE_HALIDE
applySplit(const FileNode & directive,Halide::Func & func,const FileNode & params)18 static void applySplit(const FileNode& directive, Halide::Func& func,
19 const FileNode& params)
20 {
21 for (const auto& varNode : directive)
22 {
23 const std::string varName = varNode.name();
24 const std::string factorName = (std::string)varNode;
25 Halide::Var var(varName);
26 Halide::Var outerVar(varName + "o");
27 Halide::Var innerVar(varName + "i");
28 // If split factor is integer or parameters map has parameter value.
29 CV_Assert(varNode.isString() && !params[factorName].empty() ||
30 varNode.isInt());
31 int factor = (int)(varNode.isInt() ? varNode : params[factorName]);
32 func.split(var, outerVar, innerVar, factor);
33 }
34 }
35
applyReorder(const FileNode & directive,Halide::Func & func)36 static void applyReorder(const FileNode& directive, Halide::Func& func)
37 {
38 std::string varName;
39 const int numVars = directive.size();
40 std::vector<Halide::VarOrRVar> reorderedVars;
41 reorderedVars.reserve(numVars);
42 for (int i = 0; i < numVars; ++i)
43 {
44 directive[i] >> varName;
45 reorderedVars.push_back(Halide::Var(varName));
46 }
47 func.reorder(reorderedVars);
48 }
49
applyFuse(const FileNode & directive,Halide::Func & func)50 static void applyFuse(const FileNode& directive, Halide::Func& func)
51 {
52 CV_Assert(directive["src"].size() >= 2);
53 CV_Assert(directive["dst"].size() == 1);
54
55 std::string str;
56 directive["src"][0] >> str;
57 Halide::Var firstVar(str);
58 directive["src"][1] >> str;
59 Halide::Var secondVar(str);
60 directive["dst"] >> str;
61 Halide::Var dstVar(str);
62
63 func.fuse(firstVar, secondVar, dstVar);
64 for (int i = 2, n = directive["src"].size(); i < n; ++i)
65 {
66 directive["src"][i] >> str;
67 func.fuse(Halide::Var(str), dstVar, dstVar);
68 }
69 }
70
applyParallel(const FileNode & directive,Halide::Func & func)71 static void applyParallel(const FileNode& directive, Halide::Func& func)
72 {
73 std::string varName;
74 for (int i = 0, n = directive.size(); i < n; ++i)
75 {
76 directive[i] >> varName;
77 func.parallel(Halide::Var(varName));
78 }
79 }
80
applyUnroll(const FileNode & directive,Halide::Func & func)81 static void applyUnroll(const FileNode& directive, Halide::Func& func)
82 {
83 std::string varName;
84 for (int i = 0, n = directive.size(); i < n; ++i)
85 {
86 directive[i] >> varName;
87 func.unroll(Halide::Var(varName));
88 }
89 }
90
applyVectorize(const FileNode & directive,Halide::Func & func,const FileNode & params)91 static void applyVectorize(const FileNode& directive, Halide::Func& func,
92 const FileNode& params)
93 {
94 for (const auto& varNode : directive)
95 {
96 const std::string varName = varNode.name();
97 const std::string factorName = (std::string)varNode;
98 // If split factor is integer or parameters map has parameter value.
99 CV_Assert(varNode.isString() && !params[factorName].empty() ||
100 varNode.isInt());
101 int factor = (int)(varNode.isInt() ? varNode : params[factorName]);
102 Halide::Var var(varName);
103 Halide::Var inner(varName + "v");
104 func.split(var, var, inner, factor);
105 func.vectorize(inner);
106 }
107 }
108
applyStoreAt(const FileNode & directive,Halide::Func & func,std::map<std::string,Halide::Func> & funcsMap)109 static void applyStoreAt(const FileNode& directive, Halide::Func& func,
110 std::map<std::string, Halide::Func>& funcsMap)
111 {
112 for (const auto& funcNode : directive)
113 {
114 const std::string targetFuncName = funcNode.name();
115 if (funcsMap.find(targetFuncName) == funcsMap.end())
116 CV_Error(cv::Error::StsParseError, "Function " + targetFuncName +
117 " is not represented in Halide pipeline");
118 Halide::Func targetFunc = funcsMap[targetFuncName];
119 func.store_at(targetFunc, (std::string)funcNode);
120 break;
121 }
122 }
123
applyComputeAt(const FileNode & directive,Halide::Func & func,std::map<std::string,Halide::Func> & funcsMap)124 static void applyComputeAt(const FileNode& directive, Halide::Func& func,
125 std::map<std::string, Halide::Func>& funcsMap)
126 {
127 for (const auto& funcNode : directive)
128 {
129 const std::string targetFuncName = funcNode.name();
130 if (funcsMap.find(targetFuncName) == funcsMap.end())
131 CV_Error(cv::Error::StsParseError, "Function " + targetFuncName +
132 " is not represented in Halide pipeline");
133 Halide::Func targetFunc = funcsMap[targetFuncName];
134 func.compute_at(targetFunc, (std::string)funcNode);
135 break;
136 }
137 }
138
applyComputeRoot(const FileNode & directive,Halide::Func & func)139 static void applyComputeRoot(const FileNode& directive, Halide::Func& func)
140 {
141 bool compute_root;
142 directive >> compute_root;
143 if (compute_root)
144 func.compute_root();
145 }
146
applyGpuBlocks(const FileNode & directive,Halide::Func & func)147 static void applyGpuBlocks(const FileNode& directive, Halide::Func& func)
148 {
149 std::string varName;
150 for (int i = 0, n = directive.size(); i < n; ++i)
151 {
152 directive[i] >> varName;
153 func.gpu_blocks(Halide::Var(varName));
154 }
155 }
156
applyGpuThreads(const FileNode & directive,Halide::Func & func)157 static void applyGpuThreads(const FileNode& directive, Halide::Func& func)
158 {
159 std::string varName;
160 for (int i = 0, n = directive.size(); i < n; ++i)
161 {
162 directive[i] >> varName;
163 func.gpu_threads(Halide::Var(varName));
164 }
165 }
166
apply(const FileNode & directives,Halide::Func & func,std::map<std::string,Halide::Func> & funcsMap,const FileNode & params)167 static void apply(const FileNode& directives, Halide::Func& func,
168 std::map<std::string, Halide::Func>& funcsMap,
169 const FileNode& params)
170 {
171 for (const auto& directive : directives)
172 {
173 if (directive.name() == "split")
174 applySplit(directive, func, params);
175 else if (directive.name() == "reorder")
176 applyReorder(directive, func);
177 else if (directive.name() == "fuse")
178 applyFuse(directive, func);
179 else if (directive.name() == "parallel")
180 applyParallel(directive, func);
181 else if (directive.name() == "unroll")
182 applyUnroll(directive, func);
183 else if (directive.name() == "vectorize")
184 applyVectorize(directive, func, params);
185 else if (directive.name() == "store_at")
186 applyStoreAt(directive, func, funcsMap);
187 else if (directive.name() == "compute_at")
188 applyComputeAt(directive, func, funcsMap);
189 else if (directive.name() == "compute_root")
190 applyComputeRoot(directive, func);
191 else if (directive.name() == "gpu_blocks")
192 applyGpuBlocks(directive, func);
193 else if (directive.name() == "gpu_threads")
194 applyGpuThreads(directive, func);
195 else
196 CV_Error(Error::StsNotImplemented, "Scheduling directive " +
197 directive.name() + " is not implemented.");
198 }
199 }
200
// Removes every '$' sign together with the run of decimal digits that
// immediately follows it, e.g. "func$12" -> "func".
//
// A '$' that is not followed by digits is removed on its own. Positions are
// kept as std::string::size_type and compared against npos explicitly, so a
// digit run that reaches the end of the string is handled without relying
// on integer narrowing.
static std::string Deunique(std::string str)
{
    static const char* const kDigits = "0123456789";

    std::string::size_type pos = str.find('$');
    while (pos != std::string::npos)
    {
        const std::string::size_type digitsEnd = str.find_first_not_of(kDigits, pos + 1);
        if (digitsEnd == std::string::npos)
            str.erase(pos);                   // Suffix runs to the end of the string.
        else
            str.erase(pos, digitsEnd - pos);  // Drop '$' and the digits after it.

        // Nothing before `pos` contains '$', so resume the search there.
        pos = str.find('$', pos);
    }
    return str;
}
217 #endif // HAVE_HALIDE
218
HalideScheduler(const std::string & configFile)219 HalideScheduler::HalideScheduler(const std::string& configFile)
220 {
221 if (!configFile.empty())
222 fs = FileStorage(configFile, FileStorage::READ);
223 }
224
~HalideScheduler()225 HalideScheduler::~HalideScheduler()
226 {
227 if (fs.isOpened())
228 fs.release();
229 }
230
process(Ptr<BackendNode> & node)231 bool HalideScheduler::process(Ptr<BackendNode>& node)
232 {
233 #ifdef HAVE_HALIDE
234 if (!fs.isOpened())
235 return false;
236
237 const FileNode& scheduleNode = fs["scheduling"];
238 if (scheduleNode.empty())
239 CV_Error(cv::Error::StsParseError, "Scheduling file should has scheduling node");
240
241 std::string str;
242 std::map<std::string, Halide::Func> funcsMap; // Scheduled functions.
243 // For every function, from top to bottom, we try to find a scheduling node.
244 // Scheduling is successful (return true) if for the first function (top)
245 // node is represented.
246 CV_Assert(!node.empty());
247 std::vector<Halide::Func>& funcs = node.dynamicCast<HalideBackendNode>()->funcs;
248 for (int i = funcs.size() - 1; i >= 0; --i)
249 {
250 Halide::Func& func = funcs[i];
251 // For functions with the same name Halide generates unique names
252 // for example func, func$1, func$2.
253 // They are always formed with '$' and number.
254 std::string funcName = Deunique(func.name());
255
256 const FileNode& funcNode = scheduleNode[funcName];
257 if (!funcNode.empty())
258 {
259 if (!funcNode["pattern"].empty())
260 {
261 funcNode["pattern"] >> str;
262 if (fs["patterns"][str].empty())
263 CV_Error(cv::Error::StsParseError, "Scheduling pattern " + str +
264 " is not defined");
265 apply(fs["patterns"][str], func, funcsMap, funcNode["params"]);
266 }
267 else
268 {
269 apply(funcNode, func, funcsMap, funcNode["params"]);
270 }
271 }
272 else
273 {
274 if (funcsMap.empty())
275 return false;
276 }
277 funcsMap[funcName] = func;
278 }
279 return true;
280 #endif // HAVE_HALIDE
281 return false;
282 }
283
284 } // namespace dnn
285 } // namespace cv
286