1 /* ************************************************************************ 2 * Copyright 2013 Advanced Micro Devices, Inc. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 * ************************************************************************/ 16 17 18 #ifndef SYRK_H_ 19 #define SYRK_H_ 20 21 #include <gtest/gtest.h> 22 #include <clBLAS.h> 23 #include <common.h> 24 #include <BlasBase.h> 25 #include <ExtraTestSizes.h> 26 #include <common.h> 27 28 using namespace clMath; 29 using ::testing::TestWithParam; 30 31 class SYRK : public TestWithParam< 32 ::std::tr1::tuple< 33 clblasOrder, // order 34 clblasUplo, // uplo 35 clblasTranspose, // transA 36 int, // N 37 int, // K 38 ExtraTestSizes, 39 int // numCommandQueues 40 > > { 41 public: getParams(TestParams * params)42 void getParams(TestParams *params) 43 { 44 memset(params, 0, sizeof(TestParams)); 45 46 params->order = order; 47 params->uplo = uplo; 48 params->transA = transA; 49 params->seed = seed; 50 params->N = N; 51 params->K = K; 52 params->offA = offA; 53 params->offCY = offC; 54 params->lda = lda; 55 params->ldc = ldc; 56 params->rowsA = rowsA; 57 params->columnsA = columnsA; 58 params->rowsC = rowsC; 59 params->columnsC = columnsC; 60 params->numCommandQueues = numCommandQueues; 61 } 62 63 protected: SetUp()64 virtual void SetUp() 65 { 66 ExtraTestSizes extra; 67 68 order = ::std::tr1::get<0>(GetParam()); 69 uplo = ::std::tr1::get<1>(GetParam()); 70 transA = ::std::tr1::get<2>(GetParam()); 71 N = ::std::tr1::get<3>(GetParam()); 72 K = ::std::tr1::get<4>(GetParam()); 73 extra = ::std::tr1::get<5>(GetParam()); 74 offA = extra.offA; 75 offC = extra.offCY; 76 lda = extra.strideA.ld; 77 ldc = extra.strideCY.ld; 78 numCommandQueues = ::std::tr1::get<6>(GetParam()); 79 80 base = ::clMath::BlasBase::getInstance(); 81 seed = base->seed(); 82 83 useNumCommandQueues = base->useNumCommandQueues(); 84 if (useNumCommandQueues) { 85 numCommandQueues = base->numCommandQueues(); 86 } 87 88 useAlpha = base->useAlpha(); 89 if (useAlpha != 0) { 90 paramAlpha = base->alpha(); 91 } 92 useBeta = base->useBeta(); 93 if (useBeta != 0) { 94 paramBeta = base->beta(); 95 } 96 if (base->useN()) { 97 N = base->N(); 98 } 99 if (base->useK()) { 100 K = base->K(); 101 } 102 103 if (transA == clblasNoTrans) { 104 rowsA = N; 105 columnsA = K; 106 } 107 else { 108 rowsA = K; 109 columnsA = N; 110 } 111 rowsC = N; 112 columnsC = N; 113 114 switch (order) { 115 case clblasRowMajor: 116 lda = ::std::max(lda, columnsA); 117 columnsA = lda; 118 ldc = ::std::max(ldc, columnsC); 119 columnsC = ldc; 120 break; 121 case clblasColumnMajor: 122 lda = ::std::max(lda, rowsA); 123 rowsA = lda; 124 ldc = ::std::max(ldc, rowsC); 125 rowsC = ldc; 126 break; 127 } 128 129 printTestParams(order, uplo, transA, N, K, useAlpha, base->alpha(), 130 offA, lda, useBeta, base->beta(), offC, ldc); 131 ::std::cerr << "seed = " << seed << ::std::endl; 132 ::std::cerr << "queues = " << numCommandQueues << ::std::endl; 133 } 134 135 clblasOrder order; 136 clblasUplo uplo; 137 clblasTranspose transA; 138 size_t N, K; 139 size_t offA, offC; 140 size_t lda, ldc; 141 unsigned int seed; 142 143 bool useAlpha, useBeta; 144 ComplexLong paramAlpha, paramBeta; 145 146 size_t rowsA, columnsA; 147 size_t rowsC, columnsC; 148 149 ::clMath::BlasBase *base; 150 151 bool useNumCommandQueues; 152 cl_uint numCommandQueues; 153 }; 154 155 #endif // SYRK_H_ 156