1 // Copyright (C) 2018 Intel Corporation
2 //
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6
7 #include <vector>
8 #include <functional>
9
10 #include <gtest/gtest.h>
11
12 #include <ade/graph.hpp>
13 #include <ade/typed_graph.hpp>
14
15 #include <ade/communication/comm_interface.hpp>
16 #include <ade/communication/comm_buffer.hpp>
17
18 #include <ade/memory/memory_descriptor.hpp>
19 #include <ade/memory/memory_descriptor_view.hpp>
20 #include <ade/memory/memory_descriptor_ref.hpp>
21
22 #include <ade/metatypes/metatypes.hpp>
23
24 #include <ade/passes/communications.hpp>
25
26 #include <ade/util/iota_range.hpp>
27
28 //=================== Comm channels tests=======================================
29
30 namespace
31 {
32
33 class TestCommChannel : public ade::ICommChannel
34 {
35 public:
36 // ICommChannel interface
37 virtual BufferPrefs getBufferPrefs(const BufferDesc& desc) override;
38 virtual std::unique_ptr<ade::IDataBuffer> getBuffer(const BufferDesc& desc, const BufferPrefs& prefs) override;
39 virtual void setBuffer(const ade::DataBufferView& buffer, const BufferDesc& desc) override;
40
41 ade::IDataBuffer* buff = nullptr;
42 };
43
getBufferPrefs(const BufferDesc & desc)44 ade::ICommChannel::BufferPrefs TestCommChannel::getBufferPrefs(const BufferDesc& desc)
45 {
46 BufferPrefs ret;
47 ret.preferredAlignment.redim(desc.memoryRef.size().dims_count());
48 ade::util::fill(ret.preferredAlignment, 1);
49 return ret;
50 }
51
getBuffer(const BufferDesc &,const BufferPrefs &)52 std::unique_ptr<ade::IDataBuffer> TestCommChannel::getBuffer(const BufferDesc& /*desc*/,
53 const BufferPrefs& /*prefs*/)
54 {
55 return nullptr;
56 }
57
setBuffer(const ade::DataBufferView & buffer,const BufferDesc & desc)58 void TestCommChannel::setBuffer(const ade::DataBufferView& buffer, const BufferDesc& desc)
59 {
60 EXPECT_EQ(nullptr, buff);
61 EXPECT_NE(nullptr, buffer.getBuffer());
62 buff = buffer.getBuffer();
63 EXPECT_EQ(desc.memoryRef.span(), buffer.getSpan());
64 }
65
66 }
67
68
// Exercises ConnectCommChannels on a single producer/consumer chain:
//
//   (src)->[img]->(comm)->[img]->(dst)
//
// The pass is invoked after each incremental piece of metadata is attached and
// is expected to throw std::runtime_error while any of the following is still
// missing:
//   1. CommChannel metadata on the data objects,
//   2. actual channel objects inside that metadata,
//   3. CommConsumerCallback metadata on the consumer node,
//   4. a non-empty callback inside that metadata.
// Once everything is in place the pass must succeed, hand a buffer to both
// channels, install a producer callback on the source node, and every
// invocation of that producer callback must trigger the consumer callback
// exactly once.
TEST(CommTest, CommChannelSimple)
{
    ade::Graph srcGr;
    using GraphT = ade::TypedGraph<ade::meta::NodeInfo,
                                   ade::meta::DataObject,
                                   ade::meta::CommNode,
                                   ade::meta::CommChannel,
                                   ade::meta::CommConsumerCallback,
                                   ade::meta::CommProducerCallback,
                                   ade::meta::Finalizers>;
    GraphT gr(srcGr);

    auto srcNode  = gr.createNode();
    auto srcImg   = gr.createNode();
    auto dstNode  = gr.createNode();
    auto dstImg   = gr.createNode();
    auto commNode = gr.createNode();

    gr.link(srcNode,  srcImg);
    gr.link(srcImg,   commNode);
    gr.link(commNode, dstImg);
    gr.link(dstImg,   dstNode);

    gr.metadata(srcNode).set<ade::meta::NodeInfo>(ade::meta::NodeInfo());
    gr.metadata(dstNode).set<ade::meta::NodeInfo>(ade::meta::NodeInfo());

    // Both data objects reference the same 10x10 region of one descriptor.
    ade::MemoryDescriptor desc(1, {10, 10});
    ade::MemoryDescriptorView view(desc, {ade::util::Span(0, 10),
                                          ade::util::Span(0, 10)});

    gr.metadata(srcImg).set<ade::meta::DataObject>(ade::meta::DataObject());
    gr.metadata(dstImg).set<ade::meta::DataObject>(ade::meta::DataObject());

    gr.metadata(srcImg).get<ade::meta::DataObject>().dataRef = view;
    gr.metadata(dstImg).get<ade::meta::DataObject>().dataRef = view;

    gr.metadata(commNode).set<ade::meta::CommNode>(ade::meta::CommNode(1));

    ade::passes::PassContext ctx{srcGr};

    // Stage 1: no CommChannel metadata at all.
    ASSERT_THROW(ade::passes::ConnectCommChannels()(ctx), std::runtime_error);

    gr.metadata(srcImg).set<ade::meta::CommChannel>(ade::meta::CommChannel());
    gr.metadata(dstImg).set<ade::meta::CommChannel>(ade::meta::CommChannel());

    // Stage 2: CommChannel metadata present, but the channels are not set.
    ASSERT_THROW(ade::passes::ConnectCommChannels()(ctx), std::runtime_error);

    auto chan1 = std::make_shared<TestCommChannel>();
    auto chan2 = std::make_shared<TestCommChannel>();

    gr.metadata(srcImg).get<ade::meta::CommChannel>().channel = chan1;
    gr.metadata(dstImg).get<ade::meta::CommChannel>().channel = chan2;

    // Stage 3: channels set, but the consumer has no callback metadata.
    ASSERT_THROW(ade::passes::ConnectCommChannels()(ctx), std::runtime_error);

    gr.metadata(dstNode).set(ade::meta::CommConsumerCallback{});

    // Stage 4: callback metadata present, but the callback itself is empty.
    ASSERT_THROW(ade::passes::ConnectCommChannels()(ctx), std::runtime_error);

    int callbackCallCount = 0;

    gr.metadata(dstNode).get<ade::meta::CommConsumerCallback>().callback = [&]()
    {
        ++callbackCallCount;
    };

    // Everything is configured now: the pass must complete without throwing.
    ade::passes::ConnectCommChannels()(ctx);

    // Both channels must have been handed the same, non-null buffer. This is
    // the same expectation CommChannelComplexDeps places on its equivalent
    // single (src)->[img]->(comm)->[img]->(dst) chain.
    EXPECT_NE(nullptr, chan1->buff);
    EXPECT_EQ(chan1->buff, chan2->buff);

    ASSERT_TRUE(gr.metadata(srcNode).contains<ade::meta::CommProducerCallback>());
    auto producerCallback = gr.metadata(srcNode).get<ade::meta::CommProducerCallback>().callback;
    ASSERT_NE(nullptr, producerCallback);

    // Connecting the channels must not invoke the consumer by itself.
    ASSERT_EQ(0, callbackCallCount);

    // Finalizers metadata may be absent; get() is given a default instance.
    auto finalizers = gr.metadata().get(ade::meta::Finalizers{}).finalizers;

    // Repeat the produce/finalize cycle several times to make sure the
    // connection keeps working after the finalizers have run.
    for (auto i: ade::util::iota(10))
    {
        (void)i;
        callbackCallCount = 0;
        producerCallback();
        ASSERT_EQ(1, callbackCallCount);

        for (auto& fin: finalizers)
        {
            fin();
        }
    }
}
158
// Exercises ConnectCommChannels on a graph where consumers depend on several
// producers through shared intermediate data objects:
//
//   (src1)->[img]--+
//                  +->[img]->(comm)->[img]->(dst1)
//   (src2)->[img]--+
//                  +->[img]->(comm)->[img]->(dst2)
//   (src3)->[img]--+
//
//   (src4)->[img]->(comm)->[img]->(dst3)
//
// Expectations checked below:
//  - the pass throws std::runtime_error while no channels are attached;
//  - channels whose data objects view the first descriptor all receive one
//    shared buffer, channels on the second descriptor receive another one;
//  - dst1 is notified only after src1 and src2 have produced, dst2 only after
//    src2 and src3 have produced, dst3 right after src4 has produced;
//  - running the finalizers lets the same sequence be repeated.
TEST(CommTest, CommChannelComplexDeps)
{
    ade::Graph rawGraph;
    using GraphT = ade::TypedGraph<ade::meta::NodeInfo,
                                   ade::meta::DataObject,
                                   ade::meta::CommNode,
                                   ade::meta::CommChannel,
                                   ade::meta::CommConsumerCallback,
                                   ade::meta::CommProducerCallback,
                                   ade::meta::Finalizers>;
    GraphT graph(rawGraph);

    // ---- Topology ----------------------------------------------------------
    auto producer1 = graph.createNode();
    auto producer2 = graph.createNode();
    auto producer3 = graph.createNode();
    auto producer4 = graph.createNode();

    auto producedImg1 = graph.createNode();
    auto producedImg2 = graph.createNode();
    auto producedImg3 = graph.createNode();
    auto producedImg4 = graph.createNode();

    auto mergedImg1 = graph.createNode();
    auto mergedImg2 = graph.createNode();

    auto comm1 = graph.createNode();
    auto comm2 = graph.createNode();
    auto comm3 = graph.createNode();

    auto consumedImg1 = graph.createNode();
    auto consumedImg2 = graph.createNode();
    auto consumedImg3 = graph.createNode();

    auto consumer1 = graph.createNode();
    auto consumer2 = graph.createNode();
    auto consumer3 = graph.createNode();

    using NodeT = decltype(producer1);

    // Producers to their outputs.
    graph.link(producer1, producedImg1);
    graph.link(producer2, producedImg2);
    graph.link(producer3, producedImg3);
    graph.link(producer4, producedImg4);

    // Outputs to the merged intermediates (producedImg2 feeds both).
    graph.link(producedImg1, mergedImg1);
    graph.link(producedImg2, mergedImg1);
    graph.link(producedImg2, mergedImg2);
    graph.link(producedImg3, mergedImg2);

    // Inputs of the communication nodes.
    graph.link(mergedImg1,   comm1);
    graph.link(mergedImg2,   comm2);
    graph.link(producedImg4, comm3);

    // Outputs of the communication nodes.
    graph.link(comm1, consumedImg1);
    graph.link(comm2, consumedImg2);
    graph.link(comm3, consumedImg3);

    // Consumers.
    graph.link(consumedImg1, consumer1);
    graph.link(consumedImg2, consumer2);
    graph.link(consumedImg3, consumer3);

    // ---- Memory layout -----------------------------------------------------
    ade::MemoryDescriptor wideDesc(1, {10, 30});
    ade::MemoryDescriptorView wideView(wideDesc, {ade::util::Span(0, 10),
                                                  ade::util::Span(0, 30)});
    ade::MemoryDescriptor squareDesc(1, {20, 20});
    ade::MemoryDescriptorView squareRoot(squareDesc, {ade::util::Span(0, 20),
                                                      ade::util::Span(0, 20)});

    // Three adjacent slices of the wide view, one per producer 1..3.
    ade::MemoryDescriptorView slice1(wideView, {ade::util::Span(0, 10),
                                                ade::util::Span(0, 10)});
    ade::MemoryDescriptorView slice2(wideView, {ade::util::Span(0, 10),
                                                ade::util::Span(10, 20)});
    ade::MemoryDescriptorView slice3(wideView, {ade::util::Span(0, 10),
                                                ade::util::Span(20, 30)});

    // Slices straddling the boundaries between the producer slices.
    ade::MemoryDescriptorView straddle1(wideView, {ade::util::Span(0, 10),
                                                   ade::util::Span(5, 15)});
    ade::MemoryDescriptorView straddle2(wideView, {ade::util::Span(0, 10),
                                                   ade::util::Span(15, 25)});

    // Whole area of the second descriptor.
    ade::MemoryDescriptorView squareView(squareRoot, {ade::util::Span(0, 20),
                                                      ade::util::Span(0, 20)});

    // ---- Metadata ----------------------------------------------------------
    for (const auto& node : {producer1, producer2, producer3, producer4,
                             consumer1, consumer2, consumer3})
    {
        graph.metadata(node).set<ade::meta::NodeInfo>(ade::meta::NodeInfo());
    }

    graph.metadata(comm1).set<ade::meta::CommNode>(ade::meta::CommNode(2));
    graph.metadata(comm2).set<ade::meta::CommNode>(ade::meta::CommNode(2));
    graph.metadata(comm3).set<ade::meta::CommNode>(ade::meta::CommNode(1));

    // Marks a node as a data object and points it at the given view.
    const auto bindData = [&graph](const NodeT& node, ade::MemoryDescriptorView& view)
    {
        graph.metadata(node).set<ade::meta::DataObject>(ade::meta::DataObject());
        graph.metadata(node).get<ade::meta::DataObject>().dataRef = view;
    };

    bindData(producedImg1, slice1);
    bindData(producedImg2, slice2);
    bindData(producedImg3, slice3);

    bindData(mergedImg1, straddle1);
    bindData(mergedImg2, straddle2);

    bindData(consumedImg1, straddle1);
    bindData(consumedImg2, straddle2);

    bindData(producedImg4, squareView);
    bindData(consumedImg3, squareView);

    ade::passes::PassContext ctx{rawGraph};

    // No channels are attached yet, so the pass must refuse to run.
    ASSERT_THROW(ade::passes::ConnectCommChannels()(ctx), std::runtime_error);

    // ---- Channels ----------------------------------------------------------
    const auto producedChan1 = std::make_shared<TestCommChannel>();
    const auto producedChan2 = std::make_shared<TestCommChannel>();
    const auto producedChan3 = std::make_shared<TestCommChannel>();
    const auto producedChan4 = std::make_shared<TestCommChannel>();
    const auto consumedChan1 = std::make_shared<TestCommChannel>();
    const auto consumedChan2 = std::make_shared<TestCommChannel>();
    const auto consumedChan3 = std::make_shared<TestCommChannel>();

    const auto attachChannel = [&graph](const NodeT& node,
                                        const std::shared_ptr<TestCommChannel>& chan)
    {
        graph.metadata(node).set<ade::meta::CommChannel>(ade::meta::CommChannel{chan});
    };

    attachChannel(producedImg1, producedChan1);
    attachChannel(producedImg2, producedChan2);
    attachChannel(producedImg3, producedChan3);
    attachChannel(producedImg4, producedChan4);
    attachChannel(consumedImg1, consumedChan1);
    attachChannel(consumedImg2, consumedChan2);
    attachChannel(consumedImg3, consumedChan3);

    // ---- Consumer callbacks ------------------------------------------------
    int consumer1Calls = 0;
    int consumer2Calls = 0;
    int consumer3Calls = 0;

    // Installs a consumer callback that bumps the given counter on every call.
    const auto countCalls = [&graph](const NodeT& node, int* counter)
    {
        graph.metadata(node).set(ade::meta::CommConsumerCallback{[counter]()
        {
            ++*counter;
        }});
    };

    countCalls(consumer1, &consumer1Calls);
    countCalls(consumer2, &consumer2Calls);
    countCalls(consumer3, &consumer3Calls);

    // ---- Run the pass ------------------------------------------------------
    ade::passes::ConnectCommChannels()(ctx);

    // First group: every channel on the wide descriptor shares one buffer.
    ade::IDataBuffer* const wideBuffer = producedChan1->buff;
    EXPECT_NE(nullptr, wideBuffer);
    EXPECT_EQ(wideBuffer, producedChan2->buff);
    EXPECT_EQ(wideBuffer, producedChan3->buff);
    EXPECT_EQ(wideBuffer, consumedChan1->buff);
    EXPECT_EQ(wideBuffer, consumedChan2->buff);

    // Second group: both channels on the square descriptor share another one.
    ade::IDataBuffer* const squareBuffer = producedChan4->buff;
    EXPECT_NE(nullptr, squareBuffer);
    EXPECT_EQ(squareBuffer, consumedChan3->buff);

    // The two groups must not share a buffer.
    EXPECT_NE(wideBuffer, squareBuffer);

    // Every producer must have received a callback from the pass.
    for (const auto& node : {producer1, producer2, producer3, producer4})
    {
        ASSERT_TRUE(graph.metadata(node).contains<ade::meta::CommProducerCallback>());
    }

    const auto produce1 = graph.metadata(producer1).get<ade::meta::CommProducerCallback>().callback;
    const auto produce2 = graph.metadata(producer2).get<ade::meta::CommProducerCallback>().callback;
    const auto produce3 = graph.metadata(producer3).get<ade::meta::CommProducerCallback>().callback;
    const auto produce4 = graph.metadata(producer4).get<ade::meta::CommProducerCallback>().callback;

    ASSERT_NE(nullptr, produce1);
    ASSERT_NE(nullptr, produce2);
    ASSERT_NE(nullptr, produce3);
    ASSERT_NE(nullptr, produce4);

    // Connecting the channels must not notify any consumer by itself.
    ASSERT_EQ(0, consumer1Calls);
    ASSERT_EQ(0, consumer2Calls);
    ASSERT_EQ(0, consumer3Calls);

    const auto finalizers = graph.metadata().get(ade::meta::Finalizers{}).finalizers;

    ASSERT_FALSE(finalizers.empty());

    const int roundsCount = 10;
    for (int round = 0; round < roundsCount; ++round)
    {
        consumer1Calls = 0;
        consumer2Calls = 0;
        consumer3Calls = 0;

        // src1 alone satisfies nobody.
        produce1();

        ASSERT_EQ(0, consumer1Calls);
        ASSERT_EQ(0, consumer2Calls);
        ASSERT_EQ(0, consumer3Calls);

        // src1 + src2 complete the inputs of dst1.
        produce2();

        ASSERT_EQ(1, consumer1Calls);
        ASSERT_EQ(0, consumer2Calls);
        ASSERT_EQ(0, consumer3Calls);

        // src2 + src3 complete the inputs of dst2.
        produce3();

        ASSERT_EQ(1, consumer1Calls);
        ASSERT_EQ(1, consumer2Calls);
        ASSERT_EQ(0, consumer3Calls);

        // src4 is the only input of dst3.
        produce4();

        ASSERT_EQ(1, consumer1Calls);
        ASSERT_EQ(1, consumer2Calls);
        ASSERT_EQ(1, consumer3Calls);

        // Run the finalizers before starting the next round.
        for (const auto& finalize : finalizers)
        {
            finalize();
        }
    }
}
411