1 #include "acxxel.h"
2 #include "config.h"
3 #include "gtest/gtest.h"
4 
5 namespace {
6 
7 using PlatformGetter = acxxel::Expected<acxxel::Platform *> (*)();
8 class MultiDeviceTest : public ::testing::TestWithParam<PlatformGetter> {};
9 
// Copies a small host buffer to each of the first two devices and back again,
// using one stream per device, and checks that both round trips reproduce the
// original data.
TEST_P(MultiDeviceTest, AsyncCopy) {
  acxxel::Platform *Platform = GetParam()().takeValue();
  int DeviceCount = Platform->getDeviceCount().getValue();
  // Devices 0 and 1 are addressed explicitly below. Use a fatal ASSERT so the
  // test stops here on a single-device system instead of continuing into
  // device-1 calls whose results are unwrapped with takeValue() unchecked.
  ASSERT_GE(DeviceCount, 2) << "MultiDeviceTest requires at least two devices";

  const int Length = 3;
  auto A = std::unique_ptr<int[]>(new int[Length]);
  auto B0 = std::unique_ptr<int[]>(new int[Length]);
  auto B1 = std::unique_ptr<int[]>(new int[Length]);

  auto ASpan = acxxel::Span<int>(A.get(), Length);
  auto B0Span = acxxel::Span<int>(B0.get(), Length);
  auto B1Span = acxxel::Span<int>(B1.get(), Length);

  for (int I = 0; I < Length; ++I) {
    A[I] = I;
    // Seed the destinations with a value no source element has. A copy that
    // fails or does nothing then shows up as a mismatch below, rather than as
    // a read of indeterminate memory.
    B0[I] = -1;
    B1[I] = -1;
  }

  auto AsyncA = Platform->registerHostMem(ASpan).takeValue();
  auto AsyncB0 = Platform->registerHostMem(B0Span).takeValue();
  auto AsyncB1 = Platform->registerHostMem(B1Span).takeValue();

  acxxel::Stream Stream0 = Platform->createStream(0).takeValue();
  acxxel::Stream Stream1 = Platform->createStream(1).takeValue();
  auto Device0 = Platform->mallocD<int>(Length, 0).takeValue();
  auto Device1 = Platform->mallocD<int>(Length, 1).takeValue();

  // Host -> device 0 -> host, then wait for the stream.
  EXPECT_FALSE(Stream0.asyncCopyHToD(AsyncA, Device0, Length)
                   .asyncCopyDToH(Device0, AsyncB0, Length)
                   .sync()
                   .isError());

  // Same round trip through device 1, reusing the same registered source.
  EXPECT_FALSE(Stream1.asyncCopyHToD(AsyncA, Device1, Length)
                   .asyncCopyDToH(Device1, AsyncB1, Length)
                   .sync()
                   .isError());

  for (int I = 0; I < Length; ++I) {
    EXPECT_EQ(B0[I], I);
    EXPECT_EQ(B1[I], I);
  }
}
51 
// Enqueues one event on a stream of each of the first two devices and waits
// for each stream. Then checks that both events report completion and can be
// synchronized on without error.
TEST_P(MultiDeviceTest, Events) {
  acxxel::Platform *Platform = GetParam()().takeValue();
  int DeviceCount = Platform->getDeviceCount().getValue();
  // Devices 0 and 1 are used explicitly below. Use a fatal ASSERT so the test
  // stops here on a single-device system instead of unwrapping device-1
  // results with takeValue() unchecked.
  ASSERT_GE(DeviceCount, 2) << "MultiDeviceTest requires at least two devices";

  acxxel::Stream Stream0 = Platform->createStream(0).takeValue();
  acxxel::Stream Stream1 = Platform->createStream(1).takeValue();
  acxxel::Event Event0 = Platform->createEvent(0).takeValue();
  acxxel::Event Event1 = Platform->createEvent(1).takeValue();

  EXPECT_FALSE(Stream0.enqueueEvent(Event0).sync().isError());
  EXPECT_FALSE(Stream1.enqueueEvent(Event1).sync().isError());

  // Each stream has been synced, so the events recorded on it should be done.
  EXPECT_TRUE(Event0.isDone());
  EXPECT_TRUE(Event1.isDone());

  EXPECT_FALSE(Event0.sync().isError());
  EXPECT_FALSE(Event1.sync().isError());
}
71 
// Instantiate the suite once per enabled platform backend. Each platform
// combination gets its own complete macro invocation. Placing #ifdef/#endif
// inside the argument list of a function-like macro such as
// INSTANTIATE_TEST_CASE_P is undefined behavior in standard C++ and is not
// handled portably by all preprocessors.
#if defined(ACXXEL_ENABLE_CUDA) && defined(ACXXEL_ENABLE_OPENCL)
INSTANTIATE_TEST_CASE_P(BothPlatformTest, MultiDeviceTest,
                        ::testing::Values(acxxel::getCUDAPlatform,
                                          acxxel::getOpenCLPlatform));
#elif defined(ACXXEL_ENABLE_CUDA)
INSTANTIATE_TEST_CASE_P(BothPlatformTest, MultiDeviceTest,
                        ::testing::Values(acxxel::getCUDAPlatform));
#elif defined(ACXXEL_ENABLE_OPENCL)
INSTANTIATE_TEST_CASE_P(BothPlatformTest, MultiDeviceTest,
                        ::testing::Values(acxxel::getOpenCLPlatform));
#endif
86 
87 } // namespace
88