1 // Copyright (C) 2018 Intel Corporation
2 //
3 //
4 // SPDX-License-Identifier: Apache-2.0
5 //
6 
7 #include <vector>
8 #include <functional>
9 
10 #include <gtest/gtest.h>
11 
12 #include <ade/graph.hpp>
13 #include <ade/typed_graph.hpp>
14 
15 #include <ade/communication/comm_interface.hpp>
16 #include <ade/communication/comm_buffer.hpp>
17 
18 #include <ade/memory/memory_descriptor.hpp>
19 #include <ade/memory/memory_descriptor_view.hpp>
20 #include <ade/memory/memory_descriptor_ref.hpp>
21 
22 #include <ade/metatypes/metatypes.hpp>
23 
24 #include <ade/passes/communications.hpp>
25 
26 #include <ade/util/iota_range.hpp>
27 
28 //=================== Comm channels tests=======================================
29 
30 namespace
31 {
32 
33 class TestCommChannel : public ade::ICommChannel
34 {
35 public:
36     // ICommChannel interface
37     virtual BufferPrefs getBufferPrefs(const BufferDesc& desc) override;
38     virtual std::unique_ptr<ade::IDataBuffer> getBuffer(const BufferDesc& desc, const BufferPrefs& prefs) override;
39     virtual void setBuffer(const ade::DataBufferView& buffer, const BufferDesc& desc) override;
40 
41     ade::IDataBuffer* buff = nullptr;
42 };
43 
getBufferPrefs(const BufferDesc & desc)44 ade::ICommChannel::BufferPrefs TestCommChannel::getBufferPrefs(const BufferDesc& desc)
45 {
46     BufferPrefs ret;
47     ret.preferredAlignment.redim(desc.memoryRef.size().dims_count());
48     ade::util::fill(ret.preferredAlignment, 1);
49     return ret;
50 }
51 
getBuffer(const BufferDesc &,const BufferPrefs &)52 std::unique_ptr<ade::IDataBuffer> TestCommChannel::getBuffer(const BufferDesc& /*desc*/,
53                                                              const BufferPrefs& /*prefs*/)
54 {
55     return nullptr;
56 }
57 
setBuffer(const ade::DataBufferView & buffer,const BufferDesc & desc)58 void TestCommChannel::setBuffer(const ade::DataBufferView& buffer, const BufferDesc& desc)
59 {
60     EXPECT_EQ(nullptr, buff);
61     EXPECT_NE(nullptr, buffer.getBuffer());
62     buff = buffer.getBuffer();
63     EXPECT_EQ(desc.memoryRef.span(), buffer.getSpan());
64 }
65 
66 }
67 
68 
TEST(CommTest,CommChannelSimple)69 TEST(CommTest, CommChannelSimple)
70 {
71     // (src)->[img]->(comm)->[img]->(dst)
72 
73     ade::Graph srcGr;
74     using GraphT = ade::TypedGraph<ade::meta::NodeInfo,
75                                    ade::meta::DataObject,
76                                    ade::meta::CommNode,
77                                    ade::meta::CommChannel,
78                                    ade::meta::CommConsumerCallback,
79                                    ade::meta::CommProducerCallback,
80                                    ade::meta::Finalizers>;
81     GraphT gr(srcGr);
82 
83     auto srcNode  = gr.createNode();
84     auto srcImg   = gr.createNode();
85     auto dstNode  = gr.createNode();
86     auto dstImg   = gr.createNode();
87     auto commNode = gr.createNode();
88 
89     gr.link(srcNode, srcImg);
90     gr.link(srcImg, commNode);
91     gr.link(commNode, dstImg);
92     gr.link(dstImg, dstNode);
93 
94     gr.metadata(srcNode).set<ade::meta::NodeInfo>(ade::meta::NodeInfo());
95     gr.metadata(dstNode).set<ade::meta::NodeInfo>(ade::meta::NodeInfo());
96 
97     ade::MemoryDescriptor desc(1, {10,10});
98     ade::MemoryDescriptorView view(desc, {ade::util::Span(0,10),
99                                           ade::util::Span(0, 10)});
100 
101     gr.metadata(srcImg).set<ade::meta::DataObject>(ade::meta::DataObject());
102     gr.metadata(dstImg).set<ade::meta::DataObject>(ade::meta::DataObject());
103 
104     gr.metadata(srcImg).get<ade::meta::DataObject>().dataRef = view;
105     gr.metadata(dstImg).get<ade::meta::DataObject>().dataRef = view;
106 
107     gr.metadata(commNode).set<ade::meta::CommNode>(ade::meta::CommNode(1));
108 
109     ade::passes::PassContext ctx{srcGr};
110     ASSERT_THROW(ade::passes::ConnectCommChannels()(ctx), std::runtime_error);
111 
112     gr.metadata(srcImg).set<ade::meta::CommChannel>(ade::meta::CommChannel());
113     gr.metadata(dstImg).set<ade::meta::CommChannel>(ade::meta::CommChannel());
114 
115     ASSERT_THROW(ade::passes::ConnectCommChannels()(ctx), std::runtime_error);
116 
117     auto chan1 = std::make_shared<TestCommChannel>();
118     auto chan2 = std::make_shared<TestCommChannel>();
119 
120     gr.metadata(srcImg).get<ade::meta::CommChannel>().channel = chan1;
121     gr.metadata(dstImg).get<ade::meta::CommChannel>().channel = chan2;
122 
123     ASSERT_THROW(ade::passes::ConnectCommChannels()(ctx), std::runtime_error);
124 
125     gr.metadata(dstNode).set(ade::meta::CommConsumerCallback{});
126 
127     ASSERT_THROW(ade::passes::ConnectCommChannels()(ctx), std::runtime_error);
128 
129     int callbackCallCount = 0;
130 
131     gr.metadata(dstNode).get<ade::meta::CommConsumerCallback>().callback = [&]()
132     {
133         ++callbackCallCount;
134     };
135 
136     ade::passes::ConnectCommChannels()(ctx);
137 
138     ASSERT_TRUE(gr.metadata(srcNode).contains<ade::meta::CommProducerCallback>());
139     auto producerCallback = gr.metadata(srcNode).get<ade::meta::CommProducerCallback>().callback;
140     ASSERT_NE(nullptr, producerCallback);
141     ASSERT_EQ(0, callbackCallCount);
142 
143     auto finalizers = gr.metadata().get(ade::meta::Finalizers{}).finalizers;
144 
145     for (auto i: ade::util::iota(10))
146     {
147         (void)i;
148         callbackCallCount = 0;
149         producerCallback();
150         ASSERT_EQ(1, callbackCallCount);
151 
152         for (auto& fin: finalizers)
153         {
154             fin();
155         }
156     }
157 }
158 
TEST(CommTest,CommChannelComplexDeps)159 TEST(CommTest, CommChannelComplexDeps)
160 {
161     //
162     // (src1)->[img]
163     //              \
164     //                ->[img]->(comm)->[img]->(dst1)
165     //              /
166     // (src2)->[img]
167     //              \
168     //                ->[img]->(comm)->[img]->(dst2)
169     //              /
170     // (src3)->[img]
171     //
172     // (src4)->[img]->(comm)->[img]->(dst3)
173 
174     ade::Graph srcGr;
175     using GraphT = ade::TypedGraph<ade::meta::NodeInfo,
176                                    ade::meta::DataObject,
177                                    ade::meta::CommNode,
178                                    ade::meta::CommChannel,
179                                    ade::meta::CommConsumerCallback,
180                                    ade::meta::CommProducerCallback,
181                                    ade::meta::Finalizers>;
182     GraphT gr(srcGr);
183 
184     auto srcNode1 = gr.createNode();
185     auto srcNode2 = gr.createNode();
186     auto srcNode3 = gr.createNode();
187     auto srcNode4 = gr.createNode();
188 
189     auto srcImg1 = gr.createNode();
190     auto srcImg2 = gr.createNode();
191     auto srcImg3 = gr.createNode();
192     auto srcImg4 = gr.createNode();
193 
194     auto tempImg1 = gr.createNode();
195     auto tempImg2 = gr.createNode();
196 
197     auto commNode1 = gr.createNode();
198     auto commNode2 = gr.createNode();
199     auto commNode3 = gr.createNode();
200 
201     auto dstImg1 = gr.createNode();
202     auto dstImg2 = gr.createNode();
203     auto dstImg3 = gr.createNode();
204 
205     auto dstNode1 = gr.createNode();
206     auto dstNode2 = gr.createNode();
207     auto dstNode3 = gr.createNode();
208 
209     gr.link(srcNode1, srcImg1);
210     gr.link(srcNode2, srcImg2);
211     gr.link(srcNode3, srcImg3);
212     gr.link(srcNode4, srcImg4);
213 
214     gr.link(srcImg1, tempImg1);
215     gr.link(srcImg2, tempImg1);
216     gr.link(srcImg2, tempImg2);
217     gr.link(srcImg3, tempImg2);
218 
219     gr.link(tempImg1, commNode1);
220     gr.link(tempImg2, commNode2);
221     gr.link(srcImg4,  commNode3);
222 
223     gr.link(commNode1, dstImg1);
224     gr.link(commNode2, dstImg2);
225     gr.link(commNode3, dstImg3);
226 
227     gr.link(dstImg1, dstNode1);
228     gr.link(dstImg2, dstNode2);
229     gr.link(dstImg3, dstNode3);
230 
231     ade::MemoryDescriptor desc1(1, {10,30});
232     ade::MemoryDescriptorView srcView1(desc1, {ade::util::Span(0,10),
233                                                ade::util::Span(0,30)});
234     ade::MemoryDescriptor desc2(1, {20,20});
235     ade::MemoryDescriptorView srcView2(desc2, {ade::util::Span(0,20),
236                                                ade::util::Span(0,20)});
237 
238     ade::MemoryDescriptorView view1(srcView1, {ade::util::Span(0,10),
239                                                ade::util::Span(0, 10)});
240     ade::MemoryDescriptorView view2(srcView1, {ade::util::Span(0,10),
241                                                ade::util::Span(10,20)});
242     ade::MemoryDescriptorView view3(srcView1, {ade::util::Span(0,10),
243                                                ade::util::Span(20,30)});
244 
245     ade::MemoryDescriptorView view5(srcView1, {ade::util::Span(0,10),
246                                                ade::util::Span(5, 15)});
247     ade::MemoryDescriptorView view6(srcView1, {ade::util::Span(0,10),
248                                                ade::util::Span(15,25)});
249 
250     ade::MemoryDescriptorView view7(srcView2, {ade::util::Span(0,20),
251                                                ade::util::Span(0,20)});
252 
253     gr.metadata(srcNode1).set<ade::meta::NodeInfo>(ade::meta::NodeInfo());
254     gr.metadata(srcNode2).set<ade::meta::NodeInfo>(ade::meta::NodeInfo());
255     gr.metadata(srcNode3).set<ade::meta::NodeInfo>(ade::meta::NodeInfo());
256     gr.metadata(srcNode4).set<ade::meta::NodeInfo>(ade::meta::NodeInfo());
257 
258     gr.metadata(dstNode1).set<ade::meta::NodeInfo>(ade::meta::NodeInfo());
259     gr.metadata(dstNode2).set<ade::meta::NodeInfo>(ade::meta::NodeInfo());
260     gr.metadata(dstNode3).set<ade::meta::NodeInfo>(ade::meta::NodeInfo());
261 
262     gr.metadata(commNode1).set<ade::meta::CommNode>(ade::meta::CommNode(2));
263     gr.metadata(commNode2).set<ade::meta::CommNode>(ade::meta::CommNode(2));
264     gr.metadata(commNode3).set<ade::meta::CommNode>(ade::meta::CommNode(1));
265 
266     gr.metadata(srcImg1).set<ade::meta::DataObject>(ade::meta::DataObject());
267     gr.metadata(srcImg2).set<ade::meta::DataObject>(ade::meta::DataObject());
268     gr.metadata(srcImg3).set<ade::meta::DataObject>(ade::meta::DataObject());
269     gr.metadata(srcImg4).set<ade::meta::DataObject>(ade::meta::DataObject());
270     gr.metadata(tempImg1).set<ade::meta::DataObject>(ade::meta::DataObject());
271     gr.metadata(tempImg2).set<ade::meta::DataObject>(ade::meta::DataObject());
272     gr.metadata(dstImg1).set<ade::meta::DataObject>(ade::meta::DataObject());
273     gr.metadata(dstImg2).set<ade::meta::DataObject>(ade::meta::DataObject());
274     gr.metadata(dstImg3).set<ade::meta::DataObject>(ade::meta::DataObject());
275 
276     gr.metadata(srcImg1).get<ade::meta::DataObject>().dataRef = view1;
277     gr.metadata(srcImg2).get<ade::meta::DataObject>().dataRef = view2;
278     gr.metadata(srcImg3).get<ade::meta::DataObject>().dataRef = view3;
279 
280     gr.metadata(tempImg1).get<ade::meta::DataObject>().dataRef = view5;
281     gr.metadata(tempImg2).get<ade::meta::DataObject>().dataRef = view6;
282 
283     gr.metadata(dstImg1).get<ade::meta::DataObject>().dataRef = view5;
284     gr.metadata(dstImg2).get<ade::meta::DataObject>().dataRef = view6;
285 
286     gr.metadata(srcImg4).get<ade::meta::DataObject>().dataRef = view7;
287     gr.metadata(dstImg3).get<ade::meta::DataObject>().dataRef = view7;
288 
289     ade::passes::PassContext ctx{srcGr};
290     ASSERT_THROW(ade::passes::ConnectCommChannels()(ctx), std::runtime_error);
291 
292     auto chan1 = std::make_shared<TestCommChannel>();
293     auto chan2 = std::make_shared<TestCommChannel>();
294     auto chan3 = std::make_shared<TestCommChannel>();
295     auto chan4 = std::make_shared<TestCommChannel>();
296     auto chan5 = std::make_shared<TestCommChannel>();
297     auto chan6 = std::make_shared<TestCommChannel>();
298     auto chan7 = std::make_shared<TestCommChannel>();
299 
300     gr.metadata(srcImg1).set<ade::meta::CommChannel>(ade::meta::CommChannel{chan1});
301     gr.metadata(srcImg2).set<ade::meta::CommChannel>(ade::meta::CommChannel{chan2});
302     gr.metadata(srcImg3).set<ade::meta::CommChannel>(ade::meta::CommChannel{chan3});
303     gr.metadata(srcImg4).set<ade::meta::CommChannel>(ade::meta::CommChannel{chan4});
304     gr.metadata(dstImg1).set<ade::meta::CommChannel>(ade::meta::CommChannel{chan5});
305     gr.metadata(dstImg2).set<ade::meta::CommChannel>(ade::meta::CommChannel{chan6});
306     gr.metadata(dstImg3).set<ade::meta::CommChannel>(ade::meta::CommChannel{chan7});
307 
308     int consumerCallbackCalled1 = 0;
309     int consumerCallbackCalled2 = 0;
310     int consumerCallbackCalled3 = 0;
311 
312     gr.metadata(dstNode1).set(ade::meta::CommConsumerCallback{[&]()
313     {
314         ++consumerCallbackCalled1;
315     }});
316 
317     gr.metadata(dstNode2).set(ade::meta::CommConsumerCallback{[&]()
318     {
319         ++consumerCallbackCalled2;
320     }});
321 
322     gr.metadata(dstNode3).set(ade::meta::CommConsumerCallback{[&]()
323     {
324         ++consumerCallbackCalled3;
325     }});
326 
327     ade::passes::ConnectCommChannels()(ctx);
328 
329     auto buff1 = chan1->buff;
330     auto buff2 = chan2->buff;
331     auto buff3 = chan3->buff;
332     auto buff4 = chan4->buff;
333 
334     auto buff5 = chan5->buff;
335     auto buff6 = chan6->buff;
336     auto buff7 = chan7->buff;
337 
338     // First group
339     EXPECT_NE(nullptr, buff1);
340     EXPECT_EQ(buff1, buff2);
341     EXPECT_EQ(buff1, buff3);
342     EXPECT_EQ(buff1, buff5);
343     EXPECT_EQ(buff1, buff6);
344 
345     // Second group
346     EXPECT_NE(nullptr, buff4);
347     EXPECT_EQ(buff4, buff7);
348 
349     EXPECT_NE(buff1, buff4);
350 
351     ASSERT_TRUE(gr.metadata(srcNode1).contains<ade::meta::CommProducerCallback>());
352     ASSERT_TRUE(gr.metadata(srcNode2).contains<ade::meta::CommProducerCallback>());
353     ASSERT_TRUE(gr.metadata(srcNode3).contains<ade::meta::CommProducerCallback>());
354     ASSERT_TRUE(gr.metadata(srcNode4).contains<ade::meta::CommProducerCallback>());
355 
356     auto producerCallback1 = gr.metadata(srcNode1).get<ade::meta::CommProducerCallback>().callback;
357     auto producerCallback2 = gr.metadata(srcNode2).get<ade::meta::CommProducerCallback>().callback;
358     auto producerCallback3 = gr.metadata(srcNode3).get<ade::meta::CommProducerCallback>().callback;
359     auto producerCallback4 = gr.metadata(srcNode4).get<ade::meta::CommProducerCallback>().callback;
360 
361     ASSERT_NE(nullptr, producerCallback1);
362     ASSERT_NE(nullptr, producerCallback2);
363     ASSERT_NE(nullptr, producerCallback3);
364     ASSERT_NE(nullptr, producerCallback4);
365 
366     ASSERT_EQ(0, consumerCallbackCalled1);
367     ASSERT_EQ(0, consumerCallbackCalled2);
368     ASSERT_EQ(0, consumerCallbackCalled3);
369 
370     auto finalizers = gr.metadata().get(ade::meta::Finalizers{}).finalizers;
371 
372     ASSERT_TRUE(!finalizers.empty());
373 
374     for (auto i: ade::util::iota(10))
375     {
376         (void)i;
377         consumerCallbackCalled1 = 0;
378         consumerCallbackCalled2 = 0;
379         consumerCallbackCalled3 = 0;
380 
381         producerCallback1();
382 
383         ASSERT_EQ(0, consumerCallbackCalled1);
384         ASSERT_EQ(0, consumerCallbackCalled2);
385         ASSERT_EQ(0, consumerCallbackCalled3);
386 
387         producerCallback2();
388 
389         ASSERT_EQ(1, consumerCallbackCalled1);
390         ASSERT_EQ(0, consumerCallbackCalled2);
391         ASSERT_EQ(0, consumerCallbackCalled3);
392 
393         producerCallback3();
394 
395         ASSERT_EQ(1, consumerCallbackCalled1);
396         ASSERT_EQ(1, consumerCallbackCalled2);
397         ASSERT_EQ(0, consumerCallbackCalled3);
398 
399         producerCallback4();
400 
401         ASSERT_EQ(1, consumerCallbackCalled1);
402         ASSERT_EQ(1, consumerCallbackCalled2);
403         ASSERT_EQ(1, consumerCallbackCalled3);
404 
405         for (auto& fin: finalizers)
406         {
407             fin();
408         }
409     }
410 }
411