1 /* ************************************************************************
2  * Copyright 2013 Advanced Micro Devices, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  * ************************************************************************/
16 
17 
18 #include <stdlib.h>             // srand()
19 #include <string.h>             // memcpy()
20 #include <gtest/gtest.h>
21 #include <clBLAS.h>
22 
23 #include <common.h>
24 #include <blas-internal.h>
25 #include <blas-wrapper.h>
26 #include <clBLAS-wrapper.h>
27 #include <BlasBase.h>
28 #include <blas-random.h>
29 #include <gemm.h>
30 
31 #include "tcase-filter.h"
32 
33 static void
releaseMemObjects(cl_mem objA,cl_mem objB,cl_mem objC)34 releaseMemObjects(cl_mem objA, cl_mem objB, cl_mem objC)
35 {
36     clReleaseMemObject(objA);
37     clReleaseMemObject(objB);
38     clReleaseMemObject(objC);
39 }
40 
/*
 * Frees the four host-side arrays of a GEMM test case.
 *
 * Each pointer must have been obtained from new[] (or be null); the arrays
 * are released in argument order.
 */
template <typename T> static void
deleteBuffers(T *A, T *B, T *blasC, T *clblasC)
{
    T *const hostArrays[] = { A, B, blasC, clblasC };
    const size_t count = sizeof(hostArrays) / sizeof(hostArrays[0]);

    for (size_t idx = 0; idx < count; ++idx) {
        delete[] hostArrays[idx];
    }
}
49 
50 template <typename T>
51 void
gemmCorrectnessTest(TestParams * params)52 gemmCorrectnessTest(TestParams *params)
53 {
54     cl_int err;
55     T *A, *B, *blasC, *clblasC;
56     T alpha, beta;
57     cl_mem bufA, bufB, bufC;
58     clMath::BlasBase *base;
59     bool useAlpha;
60     bool useBeta;
61     cl_event *events;
62     bool isComplex;
63 
64     base = clMath::BlasBase::getInstance();
65     if ((typeid(T) == typeid(cl_double) ||
66          typeid(T) == typeid(DoubleComplex)) &&
67         !base->isDevSupportDoublePrecision()) {
68 
69         std::cerr << ">> WARNING: The target device doesn't support native "
70                      "double precision floating point arithmetic" <<
71                      std::endl << ">> Test skipped" << std::endl;
72         SUCCEED();
73         return;
74     }
75 
76     isComplex = ((typeid(T) == typeid(FloatComplex)) ||
77                  (typeid(T) == typeid(DoubleComplex)));
78 
79     if (canCaseBeSkipped(params, isComplex)) {
80         std::cerr << ">> Test is skipped because it has no importance for this "
81                      "level of coverage" << std::endl;
82         SUCCEED();
83         return;
84     }
85 
86     useAlpha = base->useAlpha();
87     useBeta = base->useBeta();
88     alpha = ZERO<T>();
89     beta = ZERO<T>();
90 
91     events = new cl_event[params->numCommandQueues];
92     memset(events, 0, params->numCommandQueues * sizeof(cl_event));
93 
94     A = new T[params->rowsA * params->columnsA];
95     B = new T[params->rowsB * params->columnsB];
96     blasC = new T[params->rowsC * params->columnsC];
97     clblasC = new T[params->rowsC * params->columnsC];
98 
99     srand(params->seed);
100     if (useAlpha) {
101         alpha = convertMultiplier<T>(params->alpha);
102     }
103     if (useBeta) {
104         beta = convertMultiplier<T>(params->beta);
105     }
106 
107     //::std::cerr << "Generating input data... ";
108     randomGemmMatrices<T>(params->order, params->transA, params->transB,
109         params->M, params->N, params->K, useAlpha, &alpha, A, params->lda,
110         B, params->ldb, useBeta, &beta, blasC, params->ldc);
111     memcpy(clblasC, blasC, params->rowsC * params->columnsC * sizeof(*blasC));
112     //::std::cerr << "Done" << ::std::endl;
113 
114     //::std::cerr << "Calling reference xGEMM routine... ";
115     if (params->order == clblasColumnMajor) {
116         ::clMath::blas::gemm(clblasColumnMajor, params->transA, params->transB,
117                           params->M, params->N, params->K, alpha, A,
118                           params->lda, B, params->ldb, beta, blasC, params->ldc);
119     }
120     else {
121         T *reorderedA = new T[params->rowsA * params->columnsA];
122         T *reorderedB = new T[params->rowsB * params->columnsB];
123         T *reorderedC = new T[params->rowsC * params->columnsC];
124 
125         reorderMatrix<T>(clblasRowMajor, params->rowsA, params->columnsA,
126                          A, reorderedA);
127         reorderMatrix<T>(clblasRowMajor, params->rowsB, params->columnsB,
128                          B, reorderedB);
129         reorderMatrix<T>(clblasRowMajor, params->rowsC, params->columnsC,
130                          blasC, reorderedC);
131         ::clMath::blas::gemm(clblasColumnMajor, params->transA, params->transB,
132                           params->M, params->N, params->K, alpha, reorderedA,
133                           params->rowsA, reorderedB, params->rowsB,
134                           beta, reorderedC, params->rowsC);
135         reorderMatrix<T>(clblasColumnMajor, params->rowsC, params->columnsC,
136                          reorderedC, blasC);
137 
138         delete[] reorderedC;
139         delete[] reorderedB;
140         delete[] reorderedA;
141     }
142     //::std::cerr << "Done" << ::std::endl;
143 
144     bufA = base->createEnqueueBuffer(A, params->rowsA * params->columnsA *
145                                         sizeof(*A), params->offA * sizeof(*A),
146                                      CL_MEM_READ_ONLY);
147     bufB = base->createEnqueueBuffer(B, params->rowsB * params->columnsB *
148                                         sizeof(*B), params->offBX * sizeof(*B),
149                                      CL_MEM_READ_ONLY);
150     bufC = base->createEnqueueBuffer(clblasC, params->rowsC * params->columnsC *
151                                               sizeof(*clblasC),
152                                      params->offCY * sizeof(*clblasC),
153                                      CL_MEM_READ_WRITE);
154     if ((bufA == NULL) || (bufB == NULL) || (bufC == NULL)) {
155         /* Skip the test, the most probable reason is
156          *     matrix too big for a device.
157          */
158         releaseMemObjects(bufA, bufB, bufC);
159         deleteBuffers<T>(A, B, blasC, clblasC);
160         delete[] events;
161         ::std::cerr << ">> Failed to create/enqueue buffer for a matrix."
162             << ::std::endl
163             << ">> Can't execute the test, because data is not transfered to GPU."
164             << ::std::endl
165             << ">> Test skipped." << ::std::endl;
166         SUCCEED();
167         return;
168     }
169 
170     //::std::cerr << "Calling clblas xGEMM routine... ";
171     err = (cl_int)::clMath::clblas::gemm(params->order, params->transA,
172         params->transB, params->M, params->N, params->K, alpha, bufA,
173         params->offA, params->lda, bufB, params->offBX, params->ldb, beta,
174         bufC, params->offCY, params->ldc, params->numCommandQueues,
175         base->commandQueues(), 0, NULL, events);
176     if (err != CL_SUCCESS) {
177         releaseMemObjects(bufA, bufB, bufC);
178         deleteBuffers<T>(A, B, blasC, clblasC);
179         delete[] events;
180         ASSERT_EQ(CL_SUCCESS, err) << "::clMath::clblas::GEMM() failed";
181     }
182 
183     err = waitForSuccessfulFinish(params->numCommandQueues,
184         base->commandQueues(), events);
185     if (err != CL_SUCCESS) {
186         releaseMemObjects(bufA, bufB, bufC);
187         deleteBuffers<T>(A, B, blasC, clblasC);
188         delete[] events;
189         ASSERT_EQ(CL_SUCCESS, err) << "waitForSuccessfulFinish()";
190     }
191     //::std::cerr << "Done" << ::std::endl;
192 
193     clEnqueueReadBuffer(base->commandQueues()[0], bufC, CL_TRUE,
194                         params->offCY * sizeof(*clblasC),
195                         params->rowsC * params->columnsC * sizeof(*clblasC),
196                         clblasC, 0, NULL, NULL);
197 
198     releaseMemObjects(bufA, bufB, bufC);
199     compareMatrices<T>(params->order, params->M, params->N, blasC, clblasC,
200                        params->ldc);
201     deleteBuffers<T>(A, B, blasC, clblasC);
202     delete[] events;
203 }
204 
// Instantiate the correctness test for each GEMM element type
206 
// GEMM correctness check with cl_float elements.
TEST_P(GEMM, sgemm) {
    TestParams testCase;
    TestParams *const casePtr = &testCase;

    getParams(casePtr);
    gemmCorrectnessTest<cl_float>(casePtr);
}
213 
// GEMM correctness check with cl_double elements.
TEST_P(GEMM, dgemm) {
    TestParams testCase;
    TestParams *const casePtr = &testCase;

    getParams(casePtr);
    gemmCorrectnessTest<cl_double>(casePtr);
}
220 
// GEMM correctness check with FloatComplex elements.
TEST_P(GEMM, cgemm) {
    TestParams testCase;
    TestParams *const casePtr = &testCase;

    getParams(casePtr);
    gemmCorrectnessTest<FloatComplex>(casePtr);
}
227 
// GEMM correctness check with DoubleComplex elements.
TEST_P(GEMM, zgemm) {
    TestParams testCase;
    TestParams *const casePtr = &testCase;

    getParams(casePtr);
    gemmCorrectnessTest<DoubleComplex>(casePtr);
}
234