1 //===--- opencl_test.cpp - Tests for OpenCL and the Acxxel API ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "acxxel.h"
10 #include "gtest/gtest.h"
11 
12 #include <array>
13 #include <cstring>
14 
15 namespace {
16 
// OpenCL C source for the kernel under test. For every global work-item index
// I below N, saxpyKernel overwrites X[I] with A * X[I] + Y[I].
//
// The enclosing anonymous namespace already keeps this symbol local to the
// translation unit.
const char *SaxpyKernelSource = R"(
__kernel void saxpyKernel(float A, __global float *X, __global float *Y, int N) {
  int I = get_global_id(0);
  if (I < N)
    X[I] = A * X[I] + Y[I];
}
)";
24 
// Builds saxpyKernel from SaxpyKernelSource, launches it over three elements
// on the OpenCL platform, and checks that the X buffer read back from the
// device equals A * X + Y.
TEST(OpenCL, Saxpy) {
  constexpr size_t ElementCount = 3;

  // Host-side inputs, and the contents expected in HostX after the kernel
  // has run (2 * {0, 1, 2} + {3, 4, 5}).
  float Scale = 2.f;
  std::array<float, ElementCount> HostX = {{0.f, 1.f, 2.f}};
  std::array<float, ElementCount> HostY = {{3.f, 4.f, 5.f}};
  const std::array<float, ElementCount> Want = {{3.f, 6.f, 9.f}};

  // Platform, stream, and one device buffer per host array.
  acxxel::Platform *CLPlatform = acxxel::getOpenCLPlatform().getValue();
  acxxel::Stream WorkStream = CLPlatform->createStream().takeValue();
  auto DeviceX = CLPlatform->mallocD<float>(ElementCount).takeValue();
  auto DeviceY = CLPlatform->mallocD<float>(ElementCount).takeValue();
  WorkStream.syncCopyHToD(HostX, DeviceX);
  WorkStream.syncCopyHToD(HostY, DeviceY);

  // Create the program from the source text (length excludes the trailing
  // NUL) and look the kernel up by name.
  acxxel::Span<const char> SourceText(SaxpyKernelSource,
                                      std::strlen(SaxpyKernelSource));
  acxxel::Program Program =
      CLPlatform->createProgramFromSource(SourceText).takeValue();
  acxxel::Kernel Kernel = Program.createKernel("saxpyKernel").takeValue();

  // Kernel arguments are handed over as two parallel arrays: the address of
  // each argument value and the size in bytes of that value. The order
  // matches the kernel signature (A, X, Y, N).
  float *XPtr = static_cast<float *>(DeviceX);
  float *YPtr = static_cast<float *>(DeviceY);
  int Count = static_cast<int>(ElementCount);
  void *Arguments[] = {&Scale, &XPtr, &YPtr, &Count};
  size_t ArgumentSizes[] = {sizeof(Scale), sizeof(XPtr), sizeof(YPtr),
                            sizeof(Count)};

  auto LaunchStatus =
      WorkStream
          .asyncKernelLaunch(Kernel, ElementCount, Arguments, ArgumentSizes)
          .takeStatus();
  EXPECT_FALSE(LaunchStatus.isError());

  // Read X back, then make sure the stream finished without error before
  // comparing against the expected values.
  WorkStream.syncCopyDToH(DeviceX, HostX);
  EXPECT_FALSE(WorkStream.sync().isError());

  EXPECT_EQ(HostX, Want);
}
60 
61 } // namespace
62