1 
2 // =================================================================================================
3 // This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
4 // project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
5 // width of 100 characters per line.
6 //
7 // Author(s):
8 //   Cedric Nugteren <www.cedricnugteren.nl>
9 //
10 // This file implements the tests for the OpenCL buffers (matrices and vectors). These tests are
11 // templated and thus header-only.
12 //
13 // =================================================================================================
14 
15 #ifndef CLBLAST_BUFFER_TEST_H_
16 #define CLBLAST_BUFFER_TEST_H_
17 
18 #include "clblast.h"
19 
20 namespace clblast {
21 // =================================================================================================
22 
23 // Tests matrix 'A' for validity
24 template <typename T>
TestMatrixA(const size_t one,const size_t two,const Buffer<T> & buffer,const size_t offset,const size_t ld,const bool test_lead_dim=true)25 void TestMatrixA(const size_t one, const size_t two, const Buffer<T> &buffer,
26                  const size_t offset, const size_t ld, const bool test_lead_dim = true) {
27   if (test_lead_dim && ld < one) { throw BLASError(StatusCode::kInvalidLeadDimA); }
28   try {
29     const auto required_size = (ld * (two - 1) + one + offset) * sizeof(T);
30     if (buffer.GetSize() < required_size) { throw BLASError(StatusCode::kInsufficientMemoryA); }
31   } catch (const Error<std::runtime_error> &e) { throw BLASError(StatusCode::kInvalidMatrixA, e.what()); }
32 }
33 
34 // Tests matrix 'B' for validity
35 template <typename T>
TestMatrixB(const size_t one,const size_t two,const Buffer<T> & buffer,const size_t offset,const size_t ld,const bool test_lead_dim=true)36 void TestMatrixB(const size_t one, const size_t two, const Buffer<T> &buffer,
37                  const size_t offset, const size_t ld, const bool test_lead_dim = true) {
38   if (test_lead_dim && ld < one) { throw BLASError(StatusCode::kInvalidLeadDimB); }
39   try {
40     const auto required_size = (ld * (two - 1) + one + offset) * sizeof(T);
41     if (buffer.GetSize() < required_size) { throw BLASError(StatusCode::kInsufficientMemoryB); }
42   } catch (const Error<std::runtime_error> &e) { throw BLASError(StatusCode::kInvalidMatrixB, e.what()); }
43 }
44 
45 // Tests matrix 'C' for validity
46 template <typename T>
TestMatrixC(const size_t one,const size_t two,const Buffer<T> & buffer,const size_t offset,const size_t ld)47 void TestMatrixC(const size_t one, const size_t two, const Buffer<T> &buffer,
48                  const size_t offset, const size_t ld) {
49   if (ld < one) { throw BLASError(StatusCode::kInvalidLeadDimC); }
50   try {
51     const auto required_size = (ld * (two - 1) + one + offset) * sizeof(T);
52     if (buffer.GetSize() < required_size) { throw BLASError(StatusCode::kInsufficientMemoryC); }
53   } catch (const Error<std::runtime_error> &e) { throw BLASError(StatusCode::kInvalidMatrixC, e.what()); }
54 }
55 
56 // Tests matrix 'AP' for validity
57 template <typename T>
TestMatrixAP(const size_t n,const Buffer<T> & buffer,const size_t offset)58 void TestMatrixAP(const size_t n, const Buffer<T> &buffer, const size_t offset) {
59   try {
60     const auto required_size = (((n * (n + 1)) / 2) + offset) * sizeof(T);
61     if (buffer.GetSize() < required_size) { throw BLASError(StatusCode::kInsufficientMemoryA); }
62   } catch (const Error<std::runtime_error> &e) { throw BLASError(StatusCode::kInvalidMatrixA, e.what()); }
63 }
64 
65 // =================================================================================================
66 
67 // Tests vector 'X' for validity
68 template <typename T>
TestVectorX(const size_t n,const Buffer<T> & buffer,const size_t offset,const size_t inc)69 void TestVectorX(const size_t n, const Buffer<T> &buffer, const size_t offset, const size_t inc) {
70   if (inc == 0) { throw BLASError(StatusCode::kInvalidIncrementX); }
71   try {
72     const auto required_size = ((n - 1) * inc + 1 + offset) * sizeof(T);
73     if (buffer.GetSize() < required_size) { throw BLASError(StatusCode::kInsufficientMemoryX); }
74   } catch (const Error<std::runtime_error> &e) { throw BLASError(StatusCode::kInvalidVectorX, e.what()); }
75 }
76 
77 // Tests vector 'Y' for validity
78 template <typename T>
TestVectorY(const size_t n,const Buffer<T> & buffer,const size_t offset,const size_t inc)79 void TestVectorY(const size_t n, const Buffer<T> &buffer, const size_t offset, const size_t inc) {
80   if (inc == 0) { throw BLASError(StatusCode::kInvalidIncrementY); }
81   try {
82     const auto required_size = ((n - 1) * inc + 1 + offset) * sizeof(T);
83     if (buffer.GetSize() < required_size) { throw BLASError(StatusCode::kInsufficientMemoryY); }
84   } catch (const Error<std::runtime_error> &e) { throw BLASError(StatusCode::kInvalidVectorY, e.what()); }
85 }
86 
87 // =================================================================================================
88 
89 // Tests vector 'scalar' for validity
90 template <typename T>
TestVectorScalar(const size_t n,const Buffer<T> & buffer,const size_t offset)91 void TestVectorScalar(const size_t n, const Buffer<T> &buffer, const size_t offset) {
92   try {
93     const auto required_size = (n + offset) * sizeof(T);
94     if (buffer.GetSize() < required_size) { throw BLASError(StatusCode::kInsufficientMemoryScalar); }
95   } catch (const Error<std::runtime_error> &e) { throw BLASError(StatusCode::kInvalidVectorScalar, e.what()); }
96 }
97 
98 // Tests vector 'index' for validity
99 template <typename T>
TestVectorIndex(const size_t n,const Buffer<T> & buffer,const size_t offset)100 void TestVectorIndex(const size_t n, const Buffer<T> &buffer, const size_t offset) {
101   try {
102     const auto required_size = (n + offset) * sizeof(T);
103     if (buffer.GetSize() < required_size) { throw BLASError(StatusCode::kInsufficientMemoryScalar); }
104   } catch (const Error<std::runtime_error> &e) { throw BLASError(StatusCode::kInvalidVectorScalar, e.what()); }
105 }
106 
107 // =================================================================================================
108 } // namespace clblast
109 
110 // CLBLAST_BUFFER_TEST_H_
111 #endif
112