1 /* ************************************************************************
2 * Copyright 2013 Advanced Micro Devices, Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 * ************************************************************************/
16
17
// Problem iterator used to split (scatter) the solving process into steps while passing over matrix A
19
#include <assert.h>
#include <stdint.h>
#include <sys/types.h>
#include <clblas_stddef.h>

#include "matrix_dims.h"
#include "problem_iter.h"
26
27 void VISIBILITY_HIDDEN
initProblemIterator(ProblemIterator * iter,BlasFunctionID funcID,MatrixRole mrole,CLBlasKargs * kargs,size_t maxPanels,size_t maxBlocks,SubproblemDim * topDim)28 initProblemIterator(
29 ProblemIterator *iter,
30 BlasFunctionID funcID,
31 MatrixRole mrole,
32 CLBlasKargs *kargs,
33 size_t maxPanels,
34 size_t maxBlocks,
35 SubproblemDim *topDim)
36 {
37 SubproblemDim tmp;
38
39 iter->mrole = mrole;
40 iter->funcID = funcID;
41 kargsToProbDims(&tmp, funcID, kargs, false);
42 iter->size = matrBlockHeight(&tmp, mrole, kargs->side);
43 iter->globPitch = matrBlockPitch(&tmp, mrole, kargs->dtype, kargs->side);
44 iter->maxPanels = maxPanels;
45 iter->maxBlocks = maxBlocks;
46 iter->uplo = kargs->uplo;
47 iter->side = kargs->side;
48 iter->dtype = kargs->dtype;
49 iter->bpitch = matrBlockPitch(topDim, mrole, kargs->dtype, kargs->side);
50 iter->bheight = matrBlockHeight(topDim, mrole, kargs->side);
51 iteratorReset(iter);
52 }
53
54 void VISIBILITY_HIDDEN
iteratorReset(ProblemIterator * iter)55 iteratorReset(ProblemIterator *iter)
56 {
57 if (isIterBackward(iter)) {
58 iter->pos = iter->size;
59 iter->prevPos = iter->size;
60 }
61 else {
62 iter->pos = 0;
63 iter->prevPos = 0;
64 }
65 }
66
67 bool VISIBILITY_HIDDEN
isIterBackward(ProblemIterator * iter)68 isIterBackward(ProblemIterator *iter)
69 {
70 bool ret = false;
71
72 if (iter->funcID != CLBLAS_GEMM) {
73 ret = (iter->side == clblasLeft && iter->uplo == clblasLower) ||
74 (iter->side == clblasRight && iter->uplo == clblasUpper);
75 if (iter->funcID == CLBLAS_TRSM) {
76 ret = !ret;
77 }
78 }
79
80 return ret;
81 }
82
83 int VISIBILITY_HIDDEN
iterateProblem(ProblemIterator * iter)84 iterateProblem(ProblemIterator *iter)
85 {
86 bool backward;
87 size_t dy = 0;
88
89 backward = isIterBackward(iter);
90
91 if (((iter->funcID != CLBLAS_TRSM) && (!iter->maxPanels)) ||
92 ((iter->funcID == CLBLAS_TRSM) && (!iter->maxBlocks))) {
93 iter->pos = (backward) ? 0 : iter->size;
94 return 1;
95 }
96
97 iter->prevPos = iter->pos;
98
99 if ((iter->funcID != CLBLAS_TRSM)) {
100 dy = iter->maxPanels * iter->bheight;
101 assert(dy != 0);
102 }
103 if (backward) {
104 dy = szmin(iter->pos, dy);
105 iter->pos -= dy;
106 }
107 else {
108 dy = szmin(dy, iter->size - iter->pos);
109 iter->pos += dy;
110 }
111
112 return (int)(backward && iter->pos == 0) ||
113 (!backward && iter->pos == iter->size);
114 }
115
116 size_t VISIBILITY_HIDDEN
iterLastOffset(ProblemIterator * iter)117 iterLastOffset(ProblemIterator *iter)
118 {
119 return (iter->pos > iter->prevPos) ? (iter->pos - iter->prevPos) :
120 (iter->prevPos - iter->pos);
121 }
122