1 /* ************************************************************************
2 * Copyright 2013 Advanced Micro Devices, Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * ************************************************************************/
16
17
18 #include <stdlib.h> // srand()
19 #include <string.h> // memcpy()
20 #include <gtest/gtest.h>
21 #include <clBLAS.h>
22
23 #include <common.h>
24 #include <blas-internal.h>
25 #include <blas-wrapper.h>
26 #include <clBLAS-wrapper.h>
27 #include <BlasBase.h>
28 #include <blas-random.h>
29 #include <gemm.h>
30
31 #include "tcase-filter.h"
32
33 static void
releaseMemObjects(cl_mem objA,cl_mem objB,cl_mem objC)34 releaseMemObjects(cl_mem objA, cl_mem objB, cl_mem objC)
35 {
36 clReleaseMemObject(objA);
37 clReleaseMemObject(objB);
38 clReleaseMemObject(objC);
39 }
40
/*
 * Free the four host-side matrices that a GEMM test allocated with new[]:
 * the inputs A and B, the reference result blasC and the clBLAS result
 * clblasC. They are released in that order; NULL entries are harmless
 * because delete[] on a null pointer does nothing.
 */
template <typename T> static void
deleteBuffers(T *A, T *B, T *blasC, T *clblasC)
{
    T *owned[] = { A, B, blasC, clblasC };
    const size_t count = sizeof(owned) / sizeof(owned[0]);

    for (size_t idx = 0; idx < count; ++idx) {
        delete[] owned[idx];
    }
}
49
50 template <typename T>
51 void
gemmCorrectnessTest(TestParams * params)52 gemmCorrectnessTest(TestParams *params)
53 {
54 cl_int err;
55 T *A, *B, *blasC, *clblasC;
56 T alpha, beta;
57 cl_mem bufA, bufB, bufC;
58 clMath::BlasBase *base;
59 bool useAlpha;
60 bool useBeta;
61 cl_event *events;
62 bool isComplex;
63
64 base = clMath::BlasBase::getInstance();
65 if ((typeid(T) == typeid(cl_double) ||
66 typeid(T) == typeid(DoubleComplex)) &&
67 !base->isDevSupportDoublePrecision()) {
68
69 std::cerr << ">> WARNING: The target device doesn't support native "
70 "double precision floating point arithmetic" <<
71 std::endl << ">> Test skipped" << std::endl;
72 SUCCEED();
73 return;
74 }
75
76 isComplex = ((typeid(T) == typeid(FloatComplex)) ||
77 (typeid(T) == typeid(DoubleComplex)));
78
79 if (canCaseBeSkipped(params, isComplex)) {
80 std::cerr << ">> Test is skipped because it has no importance for this "
81 "level of coverage" << std::endl;
82 SUCCEED();
83 return;
84 }
85
86 useAlpha = base->useAlpha();
87 useBeta = base->useBeta();
88 alpha = ZERO<T>();
89 beta = ZERO<T>();
90
91 events = new cl_event[params->numCommandQueues];
92 memset(events, 0, params->numCommandQueues * sizeof(cl_event));
93
94 A = new T[params->rowsA * params->columnsA];
95 B = new T[params->rowsB * params->columnsB];
96 blasC = new T[params->rowsC * params->columnsC];
97 clblasC = new T[params->rowsC * params->columnsC];
98
99 srand(params->seed);
100 if (useAlpha) {
101 alpha = convertMultiplier<T>(params->alpha);
102 }
103 if (useBeta) {
104 beta = convertMultiplier<T>(params->beta);
105 }
106
107 //::std::cerr << "Generating input data... ";
108 randomGemmMatrices<T>(params->order, params->transA, params->transB,
109 params->M, params->N, params->K, useAlpha, &alpha, A, params->lda,
110 B, params->ldb, useBeta, &beta, blasC, params->ldc);
111 memcpy(clblasC, blasC, params->rowsC * params->columnsC * sizeof(*blasC));
112 //::std::cerr << "Done" << ::std::endl;
113
114 //::std::cerr << "Calling reference xGEMM routine... ";
115 if (params->order == clblasColumnMajor) {
116 ::clMath::blas::gemm(clblasColumnMajor, params->transA, params->transB,
117 params->M, params->N, params->K, alpha, A,
118 params->lda, B, params->ldb, beta, blasC, params->ldc);
119 }
120 else {
121 T *reorderedA = new T[params->rowsA * params->columnsA];
122 T *reorderedB = new T[params->rowsB * params->columnsB];
123 T *reorderedC = new T[params->rowsC * params->columnsC];
124
125 reorderMatrix<T>(clblasRowMajor, params->rowsA, params->columnsA,
126 A, reorderedA);
127 reorderMatrix<T>(clblasRowMajor, params->rowsB, params->columnsB,
128 B, reorderedB);
129 reorderMatrix<T>(clblasRowMajor, params->rowsC, params->columnsC,
130 blasC, reorderedC);
131 ::clMath::blas::gemm(clblasColumnMajor, params->transA, params->transB,
132 params->M, params->N, params->K, alpha, reorderedA,
133 params->rowsA, reorderedB, params->rowsB,
134 beta, reorderedC, params->rowsC);
135 reorderMatrix<T>(clblasColumnMajor, params->rowsC, params->columnsC,
136 reorderedC, blasC);
137
138 delete[] reorderedC;
139 delete[] reorderedB;
140 delete[] reorderedA;
141 }
142 //::std::cerr << "Done" << ::std::endl;
143
144 bufA = base->createEnqueueBuffer(A, params->rowsA * params->columnsA *
145 sizeof(*A), params->offA * sizeof(*A),
146 CL_MEM_READ_ONLY);
147 bufB = base->createEnqueueBuffer(B, params->rowsB * params->columnsB *
148 sizeof(*B), params->offBX * sizeof(*B),
149 CL_MEM_READ_ONLY);
150 bufC = base->createEnqueueBuffer(clblasC, params->rowsC * params->columnsC *
151 sizeof(*clblasC),
152 params->offCY * sizeof(*clblasC),
153 CL_MEM_READ_WRITE);
154 if ((bufA == NULL) || (bufB == NULL) || (bufC == NULL)) {
155 /* Skip the test, the most probable reason is
156 * matrix too big for a device.
157 */
158 releaseMemObjects(bufA, bufB, bufC);
159 deleteBuffers<T>(A, B, blasC, clblasC);
160 delete[] events;
161 ::std::cerr << ">> Failed to create/enqueue buffer for a matrix."
162 << ::std::endl
163 << ">> Can't execute the test, because data is not transfered to GPU."
164 << ::std::endl
165 << ">> Test skipped." << ::std::endl;
166 SUCCEED();
167 return;
168 }
169
170 //::std::cerr << "Calling clblas xGEMM routine... ";
171 err = (cl_int)::clMath::clblas::gemm(params->order, params->transA,
172 params->transB, params->M, params->N, params->K, alpha, bufA,
173 params->offA, params->lda, bufB, params->offBX, params->ldb, beta,
174 bufC, params->offCY, params->ldc, params->numCommandQueues,
175 base->commandQueues(), 0, NULL, events);
176 if (err != CL_SUCCESS) {
177 releaseMemObjects(bufA, bufB, bufC);
178 deleteBuffers<T>(A, B, blasC, clblasC);
179 delete[] events;
180 ASSERT_EQ(CL_SUCCESS, err) << "::clMath::clblas::GEMM() failed";
181 }
182
183 err = waitForSuccessfulFinish(params->numCommandQueues,
184 base->commandQueues(), events);
185 if (err != CL_SUCCESS) {
186 releaseMemObjects(bufA, bufB, bufC);
187 deleteBuffers<T>(A, B, blasC, clblasC);
188 delete[] events;
189 ASSERT_EQ(CL_SUCCESS, err) << "waitForSuccessfulFinish()";
190 }
191 //::std::cerr << "Done" << ::std::endl;
192
193 clEnqueueReadBuffer(base->commandQueues()[0], bufC, CL_TRUE,
194 params->offCY * sizeof(*clblasC),
195 params->rowsC * params->columnsC * sizeof(*clblasC),
196 clblasC, 0, NULL, NULL);
197
198 releaseMemObjects(bufA, bufB, bufC);
199 compareMatrices<T>(params->order, params->M, params->N, blasC, clblasC,
200 params->ldc);
201 deleteBuffers<T>(A, B, blasC, clblasC);
202 delete[] events;
203 }
204
// Parameterized GEMM correctness tests, one per element type (s, d, c, z)
206
// Single-precision real GEMM (cl_float) on the current parameter set.
TEST_P(GEMM, sgemm) {
    TestParams params;

    getParams(&params);
    gemmCorrectnessTest<cl_float>(&params);
}
213
// Double-precision real GEMM (cl_double) on the current parameter set.
TEST_P(GEMM, dgemm) {
    TestParams params;

    getParams(&params);
    gemmCorrectnessTest<cl_double>(&params);
}
220
// Single-precision complex GEMM (FloatComplex) on the current parameter set.
TEST_P(GEMM, cgemm) {
    TestParams params;

    getParams(&params);
    gemmCorrectnessTest<FloatComplex>(&params);
}
227
// Double-precision complex GEMM (DoubleComplex) on the current parameter set.
TEST_P(GEMM, zgemm) {
    TestParams params;

    getParams(&params);
    gemmCorrectnessTest<DoubleComplex>(&params);
}
234