1 #pragma once
2
#include <cassert>
#include <functional>
#include <memory>
#include <ostream>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>

#include <CL/sycl.hpp>

#include "sycl_meta.hpp"
6
7 namespace tf {
8
9 // ----------------------------------------------------------------------------
10 // syclGraph class
11 // ----------------------------------------------------------------------------
12
13 // class: syclGraph
14 class syclGraph : public CustomGraphBase {
15
16 friend class syclNode;
17 friend class syclTask;
18 friend class syclFlow;
19 friend class Taskflow;
20 friend class Executor;
21
22 constexpr static int OFFLOADED = 0x01;
23 constexpr static int TOPOLOGY_CHANGED = 0x02;
24
25 public:
26
27 syclGraph() = default;
28 ~syclGraph() = default;
29
30 syclGraph(const syclGraph&) = delete;
31 syclGraph(syclGraph&&);
32
33 syclGraph& operator = (const syclGraph&) = delete;
34 syclGraph& operator = (syclGraph&&);
35
36 template <typename... ArgsT>
37 syclNode* emplace_back(ArgsT&&...);
38
39 bool empty() const;
40
41 void clear();
42 void dump(std::ostream&, const void*, const std::string&) const override final;
43
44 private:
45
46 int _state {0};
47
48 std::vector<std::unique_ptr<syclNode>> _nodes;
49 };
50
51 // ----------------------------------------------------------------------------
// syclNode class
53 // ----------------------------------------------------------------------------
54
55 // class: syclNode
56 class syclNode {
57
58 friend class syclGraph;
59 friend class syclTask;
60 friend class syclFlow;
61 friend class Taskflow;
62 friend class Executor;
63
64 struct CommandGroupHandler {
65
66 std::function<void(sycl::handler&)> work;
67
68 template <typename F>
CommandGroupHandlertf::syclNode::CommandGroupHandler69 CommandGroupHandler(F&& func) : work {std::forward<F>(func)} {}
70 };
71
72 struct DependentSubmit {
73 std::function<sycl::event(sycl::queue&, std::vector<sycl::event>)> work;
74
75 template <typename F>
DependentSubmittf::syclNode::DependentSubmit76 DependentSubmit(F&& func) : work {std::forward<F>(func)} {}
77 };
78
79 using handle_t = std::variant<
80 CommandGroupHandler,
81 DependentSubmit
82 >;
83
84 public:
85
86 // variant index
87 constexpr static auto COMMAND_GROUP_HANDLER =
88 get_index_v<CommandGroupHandler, handle_t>;
89
90 constexpr static auto DEPENDENT_SUBMIT =
91 get_index_v<DependentSubmit, handle_t>;
92
93 syclNode() = delete;
94
95 template <typename... ArgsT>
96 syclNode(syclGraph&, ArgsT&&...);
97
98 private:
99
100 syclGraph& _graph;
101
102 std::string _name;
103
104 int _level;
105
106 sycl::event _event;
107
108 handle_t _handle;
109
110 SmallVector<syclNode*> _successors;
111 SmallVector<syclNode*> _dependents;
112
113 void _precede(syclNode*);
114 };
115
116 // ----------------------------------------------------------------------------
117 // syclNode definitions
118 // ----------------------------------------------------------------------------
119
120 // Constructor
121 template <typename... ArgsT>
syclNode(syclGraph & g,ArgsT &&...args)122 syclNode::syclNode(syclGraph& g, ArgsT&&... args) :
123 _graph {g},
124 _handle {std::forward<ArgsT>(args)...} {
125 }
126
127 // Procedure: _precede
_precede(syclNode * v)128 inline void syclNode::_precede(syclNode* v) {
129 _graph._state |= syclGraph::TOPOLOGY_CHANGED;
130 _successors.push_back(v);
131 v->_dependents.push_back(this);
132 }
133
134 // ----------------------------------------------------------------------------
135 // syclGraph definitions
136 // ----------------------------------------------------------------------------
137
138 // Move constructor
syclGraph(syclGraph && g)139 inline syclGraph::syclGraph(syclGraph&& g) :
140 _nodes {std::move(g._nodes)} {
141
142 assert(g._nodes.empty());
143 }
144
145 // Move assignment
operator =(syclGraph && rhs)146 inline syclGraph& syclGraph::operator = (syclGraph&& rhs) {
147
148 // lhs
149 _nodes = std::move(rhs._nodes);
150
151 assert(rhs._nodes.empty());
152
153 return *this;
154 }
155
156 // Function: empty
empty() const157 inline bool syclGraph::empty() const {
158 return _nodes.empty();
159 }
160
161 // Procedure: clear
clear()162 inline void syclGraph::clear() {
163 _state = syclGraph::TOPOLOGY_CHANGED;
164 _nodes.clear();
165 }
166
167 // Function: emplace_back
168 template <typename... ArgsT>
emplace_back(ArgsT &&...args)169 syclNode* syclGraph::emplace_back(ArgsT&&... args) {
170
171 _state |= syclGraph::TOPOLOGY_CHANGED;
172
173 auto node = std::make_unique<syclNode>(std::forward<ArgsT>(args)...);
174 _nodes.emplace_back(std::move(node));
175 return _nodes.back().get();
176
177 // TODO: object pool
178
179 //auto node = new syclNode(std::forward<ArgsT>(args)...);
180 //_nodes.push_back(node);
181 //return node;
182 }
183
184 // Procedure: dump the graph to a DOT format
dump(std::ostream & os,const void * root,const std::string & root_name) const185 inline void syclGraph::dump(
186 std::ostream& os, const void* root, const std::string& root_name
187 ) const {
188
189 // recursive dump with stack
190 std::stack<std::tuple<const syclGraph*, const syclNode*, int>> stack;
191 stack.push(std::make_tuple(this, nullptr, 1));
192
193 int pl = 0;
194
195 while(!stack.empty()) {
196
197 auto [graph, parent, l] = stack.top();
198 stack.pop();
199
200 for(int i=0; i<pl-l+1; i++) {
201 os << "}\n";
202 }
203
204 if(parent == nullptr) {
205 if(root) {
206 os << "subgraph cluster_p" << root << " {\nlabel=\"syclFlow: ";
207 if(root_name.empty()) os << 'p' << root;
208 else os << root_name;
209 os << "\";\n" << "color=\"red\"\n";
210 }
211 else {
212 os << "digraph syclFlow {\n";
213 }
214 }
215 else {
216 os << "subgraph cluster_p" << parent << " {\nlabel=\"syclSubflow: ";
217 if(parent->_name.empty()) os << 'p' << parent;
218 else os << parent->_name;
219 os << "\";\n" << "color=\"purple\"\n";
220 }
221
222 for(auto& v : graph->_nodes) {
223
224 os << 'p' << v.get() << "[label=\"";
225 if(v->_name.empty()) {
226 os << 'p' << v.get() << "\"";
227 }
228 else {
229 os << v->_name << "\"";
230 }
231 os << "];\n";
232
233 for(const auto s : v->_successors) {
234 os << 'p' << v.get() << " -> " << 'p' << s << ";\n";
235 }
236
237 if(v->_successors.size() == 0) {
238 if(parent == nullptr) {
239 if(root) {
240 os << 'p' << v.get() << " -> p" << root << ";\n";
241 }
242 }
243 else {
244 os << 'p' << v.get() << " -> p" << parent << ";\n";
245 }
246 }
247 }
248
249 // set the previous level
250 pl = l;
251 }
252
253 for(int i=0; i<pl; i++) {
254 os << "}\n";
255 }
256
257 }
258
259
260 } // end of namespace tf -----------------------------------------------------
261
262
263