1
2 // =================================================================================================
3 // This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
4 // project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
5 // width of 100 characters per line.
6 //
7 // Author(s):
8 // Cedric Nugteren <www.cedricnugteren.nl>
9 //
10 // This file implements the tests for the OpenCL buffers (matrices and vectors). These tests are
11 // templated and thus header-only.
12 //
13 // =================================================================================================
14
15 #ifndef CLBLAST_BUFFER_TEST_H_
16 #define CLBLAST_BUFFER_TEST_H_
17
18 #include "clblast.h"
19
20 namespace clblast {
21 // =================================================================================================
22
23 // Tests matrix 'A' for validity
24 template <typename T>
TestMatrixA(const size_t one,const size_t two,const Buffer<T> & buffer,const size_t offset,const size_t ld,const bool test_lead_dim=true)25 void TestMatrixA(const size_t one, const size_t two, const Buffer<T> &buffer,
26 const size_t offset, const size_t ld, const bool test_lead_dim = true) {
27 if (test_lead_dim && ld < one) { throw BLASError(StatusCode::kInvalidLeadDimA); }
28 try {
29 const auto required_size = (ld * (two - 1) + one + offset) * sizeof(T);
30 if (buffer.GetSize() < required_size) { throw BLASError(StatusCode::kInsufficientMemoryA); }
31 } catch (const Error<std::runtime_error> &e) { throw BLASError(StatusCode::kInvalidMatrixA, e.what()); }
32 }
33
34 // Tests matrix 'B' for validity
35 template <typename T>
TestMatrixB(const size_t one,const size_t two,const Buffer<T> & buffer,const size_t offset,const size_t ld,const bool test_lead_dim=true)36 void TestMatrixB(const size_t one, const size_t two, const Buffer<T> &buffer,
37 const size_t offset, const size_t ld, const bool test_lead_dim = true) {
38 if (test_lead_dim && ld < one) { throw BLASError(StatusCode::kInvalidLeadDimB); }
39 try {
40 const auto required_size = (ld * (two - 1) + one + offset) * sizeof(T);
41 if (buffer.GetSize() < required_size) { throw BLASError(StatusCode::kInsufficientMemoryB); }
42 } catch (const Error<std::runtime_error> &e) { throw BLASError(StatusCode::kInvalidMatrixB, e.what()); }
43 }
44
45 // Tests matrix 'C' for validity
46 template <typename T>
TestMatrixC(const size_t one,const size_t two,const Buffer<T> & buffer,const size_t offset,const size_t ld)47 void TestMatrixC(const size_t one, const size_t two, const Buffer<T> &buffer,
48 const size_t offset, const size_t ld) {
49 if (ld < one) { throw BLASError(StatusCode::kInvalidLeadDimC); }
50 try {
51 const auto required_size = (ld * (two - 1) + one + offset) * sizeof(T);
52 if (buffer.GetSize() < required_size) { throw BLASError(StatusCode::kInsufficientMemoryC); }
53 } catch (const Error<std::runtime_error> &e) { throw BLASError(StatusCode::kInvalidMatrixC, e.what()); }
54 }
55
56 // Tests matrix 'AP' for validity
57 template <typename T>
TestMatrixAP(const size_t n,const Buffer<T> & buffer,const size_t offset)58 void TestMatrixAP(const size_t n, const Buffer<T> &buffer, const size_t offset) {
59 try {
60 const auto required_size = (((n * (n + 1)) / 2) + offset) * sizeof(T);
61 if (buffer.GetSize() < required_size) { throw BLASError(StatusCode::kInsufficientMemoryA); }
62 } catch (const Error<std::runtime_error> &e) { throw BLASError(StatusCode::kInvalidMatrixA, e.what()); }
63 }
64
65 // =================================================================================================
66
67 // Tests vector 'X' for validity
68 template <typename T>
TestVectorX(const size_t n,const Buffer<T> & buffer,const size_t offset,const size_t inc)69 void TestVectorX(const size_t n, const Buffer<T> &buffer, const size_t offset, const size_t inc) {
70 if (inc == 0) { throw BLASError(StatusCode::kInvalidIncrementX); }
71 try {
72 const auto required_size = ((n - 1) * inc + 1 + offset) * sizeof(T);
73 if (buffer.GetSize() < required_size) { throw BLASError(StatusCode::kInsufficientMemoryX); }
74 } catch (const Error<std::runtime_error> &e) { throw BLASError(StatusCode::kInvalidVectorX, e.what()); }
75 }
76
77 // Tests vector 'Y' for validity
78 template <typename T>
TestVectorY(const size_t n,const Buffer<T> & buffer,const size_t offset,const size_t inc)79 void TestVectorY(const size_t n, const Buffer<T> &buffer, const size_t offset, const size_t inc) {
80 if (inc == 0) { throw BLASError(StatusCode::kInvalidIncrementY); }
81 try {
82 const auto required_size = ((n - 1) * inc + 1 + offset) * sizeof(T);
83 if (buffer.GetSize() < required_size) { throw BLASError(StatusCode::kInsufficientMemoryY); }
84 } catch (const Error<std::runtime_error> &e) { throw BLASError(StatusCode::kInvalidVectorY, e.what()); }
85 }
86
87 // =================================================================================================
88
89 // Tests vector 'scalar' for validity
90 template <typename T>
TestVectorScalar(const size_t n,const Buffer<T> & buffer,const size_t offset)91 void TestVectorScalar(const size_t n, const Buffer<T> &buffer, const size_t offset) {
92 try {
93 const auto required_size = (n + offset) * sizeof(T);
94 if (buffer.GetSize() < required_size) { throw BLASError(StatusCode::kInsufficientMemoryScalar); }
95 } catch (const Error<std::runtime_error> &e) { throw BLASError(StatusCode::kInvalidVectorScalar, e.what()); }
96 }
97
98 // Tests vector 'index' for validity
99 template <typename T>
TestVectorIndex(const size_t n,const Buffer<T> & buffer,const size_t offset)100 void TestVectorIndex(const size_t n, const Buffer<T> &buffer, const size_t offset) {
101 try {
102 const auto required_size = (n + offset) * sizeof(T);
103 if (buffer.GetSize() < required_size) { throw BLASError(StatusCode::kInsufficientMemoryScalar); }
104 } catch (const Error<std::runtime_error> &e) { throw BLASError(StatusCode::kInvalidVectorScalar, e.what()); }
105 }
106
107 // =================================================================================================
108 } // namespace clblast
109
110 // CLBLAST_BUFFER_TEST_H_
111 #endif
112