1 /* ************************************************************************
2  * Copyright 2013 Advanced Micro Devices, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  * ************************************************************************/
16 
17 
18 #ifndef SYRK_H_
19 #define SYRK_H_
20 
21 #include <gtest/gtest.h>
22 #include <clBLAS.h>
23 #include <common.h>
24 #include <BlasBase.h>
25 #include <ExtraTestSizes.h>
26 #include <common.h>
27 
28 using namespace clMath;
29 using ::testing::TestWithParam;
30 
31 class SYRK : public TestWithParam<
32     ::std::tr1::tuple<
33         clblasOrder,     // order
34         clblasUplo,      // uplo
35         clblasTranspose, // transA
36         int,                // N
37         int,                // K
38         ExtraTestSizes,
39         int                 // numCommandQueues
40         > > {
41 public:
getParams(TestParams * params)42     void getParams(TestParams *params)
43     {
44         memset(params, 0, sizeof(TestParams));
45 
46         params->order = order;
47         params->uplo = uplo;
48         params->transA = transA;
49         params->seed = seed;
50         params->N = N;
51         params->K = K;
52         params->offA = offA;
53         params->offCY = offC;
54         params->lda = lda;
55         params->ldc = ldc;
56         params->rowsA = rowsA;
57         params->columnsA = columnsA;
58         params->rowsC = rowsC;
59         params->columnsC = columnsC;
60         params->numCommandQueues = numCommandQueues;
61     }
62 
63 protected:
SetUp()64     virtual void SetUp()
65     {
66         ExtraTestSizes extra;
67 
68         order = ::std::tr1::get<0>(GetParam());
69         uplo = ::std::tr1::get<1>(GetParam());
70         transA = ::std::tr1::get<2>(GetParam());
71         N = ::std::tr1::get<3>(GetParam());
72         K = ::std::tr1::get<4>(GetParam());
73         extra = ::std::tr1::get<5>(GetParam());
74         offA = extra.offA;
75         offC = extra.offCY;
76         lda = extra.strideA.ld;
77         ldc = extra.strideCY.ld;
78         numCommandQueues = ::std::tr1::get<6>(GetParam());
79 
80         base = ::clMath::BlasBase::getInstance();
81         seed = base->seed();
82 
83         useNumCommandQueues = base->useNumCommandQueues();
84         if (useNumCommandQueues) {
85             numCommandQueues = base->numCommandQueues();
86         }
87 
88         useAlpha = base->useAlpha();
89         if (useAlpha != 0) {
90             paramAlpha = base->alpha();
91         }
92         useBeta = base->useBeta();
93         if (useBeta != 0) {
94             paramBeta = base->beta();
95         }
96         if (base->useN()) {
97             N = base->N();
98         }
99         if (base->useK()) {
100             K = base->K();
101         }
102 
103         if (transA == clblasNoTrans) {
104             rowsA = N;
105             columnsA = K;
106         }
107         else {
108             rowsA = K;
109             columnsA = N;
110         }
111         rowsC = N;
112         columnsC = N;
113 
114         switch (order) {
115         case clblasRowMajor:
116             lda = ::std::max(lda, columnsA);
117             columnsA = lda;
118             ldc = ::std::max(ldc, columnsC);
119             columnsC = ldc;
120             break;
121         case clblasColumnMajor:
122             lda = ::std::max(lda, rowsA);
123             rowsA = lda;
124             ldc = ::std::max(ldc, rowsC);
125             rowsC = ldc;
126             break;
127         }
128 
129         printTestParams(order, uplo, transA, N, K, useAlpha, base->alpha(),
130                         offA, lda, useBeta, base->beta(), offC, ldc);
131         ::std::cerr << "seed = " << seed << ::std::endl;
132         ::std::cerr << "queues = " << numCommandQueues << ::std::endl;
133     }
134 
135     clblasOrder order;
136     clblasUplo uplo;
137     clblasTranspose transA;
138     size_t N, K;
139     size_t offA, offC;
140     size_t lda, ldc;
141     unsigned int seed;
142 
143     bool useAlpha, useBeta;
144     ComplexLong paramAlpha, paramBeta;
145 
146     size_t rowsA, columnsA;
147     size_t rowsC, columnsC;
148 
149     ::clMath::BlasBase *base;
150 
151     bool useNumCommandQueues;
152     cl_uint numCommandQueues;
153 };
154 
155 #endif  // SYRK_H_
156