1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4 //
5 // Copyright (C) 2017, Intel Corporation, all rights reserved.
6 // Third party copyrights are property of their respective owners.
7 
8 #include "precomp.hpp"
9 #include "halide_scheduler.hpp"
10 #include "op_halide.hpp"
11 
12 namespace cv
13 {
14 namespace dnn
15 {
16 
17 #ifdef HAVE_HALIDE
applySplit(const FileNode & directive,Halide::Func & func,const FileNode & params)18 static void applySplit(const FileNode& directive, Halide::Func& func,
19                        const FileNode& params)
20 {
21     for (const auto& varNode : directive)
22     {
23         const std::string varName = varNode.name();
24         const std::string factorName = (std::string)varNode;
25         Halide::Var var(varName);
26         Halide::Var outerVar(varName + "o");
27         Halide::Var innerVar(varName + "i");
28         // If split factor is integer or parameters map has parameter value.
29         CV_Assert(varNode.isString() && !params[factorName].empty() ||
30                   varNode.isInt());
31         int factor = (int)(varNode.isInt() ? varNode : params[factorName]);
32         func.split(var, outerVar, innerVar, factor);
33     }
34 }
35 
applyReorder(const FileNode & directive,Halide::Func & func)36 static void applyReorder(const FileNode& directive, Halide::Func& func)
37 {
38     std::string varName;
39     const int numVars = directive.size();
40     std::vector<Halide::VarOrRVar> reorderedVars;
41     reorderedVars.reserve(numVars);
42     for (int i = 0; i < numVars; ++i)
43     {
44         directive[i] >> varName;
45         reorderedVars.push_back(Halide::Var(varName));
46     }
47     func.reorder(reorderedVars);
48 }
49 
applyFuse(const FileNode & directive,Halide::Func & func)50 static void applyFuse(const FileNode& directive, Halide::Func& func)
51 {
52     CV_Assert(directive["src"].size() >= 2);
53     CV_Assert(directive["dst"].size() == 1);
54 
55     std::string str;
56     directive["src"][0] >> str;
57     Halide::Var firstVar(str);
58     directive["src"][1] >> str;
59     Halide::Var secondVar(str);
60     directive["dst"] >> str;
61     Halide::Var dstVar(str);
62 
63     func.fuse(firstVar, secondVar, dstVar);
64     for (int i = 2, n = directive["src"].size(); i < n; ++i)
65     {
66         directive["src"][i] >> str;
67         func.fuse(Halide::Var(str), dstVar, dstVar);
68     }
69 }
70 
applyParallel(const FileNode & directive,Halide::Func & func)71 static void applyParallel(const FileNode& directive, Halide::Func& func)
72 {
73     std::string varName;
74     for (int i = 0, n = directive.size(); i < n; ++i)
75     {
76         directive[i] >> varName;
77         func.parallel(Halide::Var(varName));
78     }
79 }
80 
applyUnroll(const FileNode & directive,Halide::Func & func)81 static void applyUnroll(const FileNode& directive, Halide::Func& func)
82 {
83     std::string varName;
84     for (int i = 0, n = directive.size(); i < n; ++i)
85     {
86         directive[i] >> varName;
87         func.unroll(Halide::Var(varName));
88     }
89 }
90 
applyVectorize(const FileNode & directive,Halide::Func & func,const FileNode & params)91 static void applyVectorize(const FileNode& directive, Halide::Func& func,
92                            const FileNode& params)
93 {
94     for (const auto& varNode : directive)
95     {
96         const std::string varName = varNode.name();
97         const std::string factorName = (std::string)varNode;
98         // If split factor is integer or parameters map has parameter value.
99         CV_Assert(varNode.isString() && !params[factorName].empty() ||
100                   varNode.isInt());
101         int factor = (int)(varNode.isInt() ? varNode : params[factorName]);
102         Halide::Var var(varName);
103         Halide::Var inner(varName + "v");
104         func.split(var, var, inner, factor);
105         func.vectorize(inner);
106     }
107 }
108 
applyStoreAt(const FileNode & directive,Halide::Func & func,std::map<std::string,Halide::Func> & funcsMap)109 static void applyStoreAt(const FileNode& directive, Halide::Func& func,
110                          std::map<std::string, Halide::Func>& funcsMap)
111 {
112     for (const auto& funcNode : directive)
113     {
114         const std::string targetFuncName = funcNode.name();
115         if (funcsMap.find(targetFuncName) == funcsMap.end())
116             CV_Error(cv::Error::StsParseError, "Function " + targetFuncName +
117                      " is not represented in Halide pipeline");
118         Halide::Func targetFunc = funcsMap[targetFuncName];
119         func.store_at(targetFunc, (std::string)funcNode);
120         break;
121     }
122 }
123 
applyComputeAt(const FileNode & directive,Halide::Func & func,std::map<std::string,Halide::Func> & funcsMap)124 static void applyComputeAt(const FileNode& directive, Halide::Func& func,
125                            std::map<std::string, Halide::Func>& funcsMap)
126 {
127     for (const auto& funcNode : directive)
128     {
129         const std::string targetFuncName = funcNode.name();
130         if (funcsMap.find(targetFuncName) == funcsMap.end())
131             CV_Error(cv::Error::StsParseError, "Function " + targetFuncName +
132                      " is not represented in Halide pipeline");
133         Halide::Func targetFunc = funcsMap[targetFuncName];
134         func.compute_at(targetFunc, (std::string)funcNode);
135         break;
136     }
137 }
138 
applyComputeRoot(const FileNode & directive,Halide::Func & func)139 static void applyComputeRoot(const FileNode& directive, Halide::Func& func)
140 {
141     bool compute_root;
142     directive >> compute_root;
143     if (compute_root)
144         func.compute_root();
145 }
146 
applyGpuBlocks(const FileNode & directive,Halide::Func & func)147 static void applyGpuBlocks(const FileNode& directive, Halide::Func& func)
148 {
149     std::string varName;
150     for (int i = 0, n = directive.size(); i < n; ++i)
151     {
152         directive[i] >> varName;
153         func.gpu_blocks(Halide::Var(varName));
154     }
155 }
156 
applyGpuThreads(const FileNode & directive,Halide::Func & func)157 static void applyGpuThreads(const FileNode& directive, Halide::Func& func)
158 {
159     std::string varName;
160     for (int i = 0, n = directive.size(); i < n; ++i)
161     {
162         directive[i] >> varName;
163         func.gpu_threads(Halide::Var(varName));
164     }
165 }
166 
apply(const FileNode & directives,Halide::Func & func,std::map<std::string,Halide::Func> & funcsMap,const FileNode & params)167 static void apply(const FileNode& directives, Halide::Func& func,
168                   std::map<std::string, Halide::Func>& funcsMap,
169                   const FileNode& params)
170 {
171     for (const auto& directive : directives)
172     {
173         if (directive.name() == "split")
174             applySplit(directive, func, params);
175         else if (directive.name() == "reorder")
176             applyReorder(directive, func);
177         else if (directive.name() == "fuse")
178             applyFuse(directive, func);
179         else if (directive.name() == "parallel")
180             applyParallel(directive, func);
181         else if (directive.name() == "unroll")
182             applyUnroll(directive, func);
183         else if (directive.name() == "vectorize")
184             applyVectorize(directive, func, params);
185         else if (directive.name() == "store_at")
186             applyStoreAt(directive, func, funcsMap);
187         else if (directive.name() == "compute_at")
188             applyComputeAt(directive, func, funcsMap);
189         else if (directive.name() == "compute_root")
190             applyComputeRoot(directive, func);
191         else if (directive.name() == "gpu_blocks")
192             applyGpuBlocks(directive, func);
193         else if (directive.name() == "gpu_threads")
194             applyGpuThreads(directive, func);
195         else
196             CV_Error(Error::StsNotImplemented, "Scheduling directive " +
197                      directive.name() + " is not implemented.");
198     }
199 }
200 
// Removes every '$' together with the run of decimal digits that immediately
// follows it, undoing the unique-name suffixes Halide appends to functions
// with equal names: "func$1" -> "func", "a$12b$3c" -> "abc".
// A '$' that is not followed by digits is removed on its own.
static std::string Deunique(std::string str)
{
    // Use size_type/npos throughout: storing npos in an int (and computing a
    // length from npos - pos) relies on implementation-defined conversions.
    std::string::size_type pos = 0;
    while ((pos = str.find('$', pos)) != std::string::npos)
    {
        const std::string::size_type end = str.find_first_not_of("0123456789", pos + 1);
        // Erasing never introduces a new '$' before `pos`, so the search
        // can resume from `pos` instead of restarting at the beginning.
        str.erase(pos, end == std::string::npos ? std::string::npos : end - pos);
    }
    return str;
}
217 #endif  // HAVE_HALIDE
218 
HalideScheduler(const std::string & configFile)219 HalideScheduler::HalideScheduler(const std::string& configFile)
220 {
221     if (!configFile.empty())
222         fs = FileStorage(configFile, FileStorage::READ);
223 }
224 
~HalideScheduler()225 HalideScheduler::~HalideScheduler()
226 {
227     if (fs.isOpened())
228         fs.release();
229 }
230 
process(Ptr<BackendNode> & node)231 bool HalideScheduler::process(Ptr<BackendNode>& node)
232 {
233 #ifdef HAVE_HALIDE
234     if (!fs.isOpened())
235         return false;
236 
237     const FileNode& scheduleNode = fs["scheduling"];
238     if (scheduleNode.empty())
239         CV_Error(cv::Error::StsParseError, "Scheduling file should has scheduling node");
240 
241     std::string str;
242     std::map<std::string, Halide::Func> funcsMap;  // Scheduled functions.
243     // For every function, from top to bottom, we try to find a scheduling node.
244     // Scheduling is successful (return true) if for the first function (top)
245     // node is represented.
246     CV_Assert(!node.empty());
247     std::vector<Halide::Func>& funcs = node.dynamicCast<HalideBackendNode>()->funcs;
248     for (int i = funcs.size() - 1; i >= 0; --i)
249     {
250         Halide::Func& func = funcs[i];
251         // For functions with the same name Halide generates unique names
252         // for example func, func$1, func$2.
253         // They are always formed with '$' and number.
254         std::string funcName = Deunique(func.name());
255 
256         const FileNode& funcNode = scheduleNode[funcName];
257         if (!funcNode.empty())
258         {
259             if (!funcNode["pattern"].empty())
260             {
261                 funcNode["pattern"] >> str;
262                 if (fs["patterns"][str].empty())
263                     CV_Error(cv::Error::StsParseError, "Scheduling pattern " + str +
264                                                        " is not defined");
265                 apply(fs["patterns"][str], func, funcsMap, funcNode["params"]);
266             }
267             else
268             {
269                 apply(funcNode, func, funcsMap, funcNode["params"]);
270             }
271         }
272         else
273         {
274             if (funcsMap.empty())
275                 return false;
276         }
277         funcsMap[funcName] = func;
278     }
279     return true;
280 #endif  // HAVE_HALIDE
281     return false;
282 }
283 
284 }  // namespace dnn
285 }  // namespace cv
286