1 /* ************************************************************************
2  * Copyright 2013 Advanced Micro Devices, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  * ************************************************************************/
16 
17 
// Problem iterator used to split solving into successive steps while passing over matrix A
19 
20 #include <assert.h>
21 #include <sys/types.h>
22 #include <clblas_stddef.h>
23 
24 #include "matrix_dims.h"
25 #include "problem_iter.h"
26 
27 void VISIBILITY_HIDDEN
initProblemIterator(ProblemIterator * iter,BlasFunctionID funcID,MatrixRole mrole,CLBlasKargs * kargs,size_t maxPanels,size_t maxBlocks,SubproblemDim * topDim)28 initProblemIterator(
29     ProblemIterator *iter,
30     BlasFunctionID funcID,
31     MatrixRole mrole,
32     CLBlasKargs *kargs,
33     size_t maxPanels,
34     size_t maxBlocks,
35     SubproblemDim *topDim)
36 {
37     SubproblemDim tmp;
38 
39     iter->mrole = mrole;
40     iter->funcID = funcID;
41     kargsToProbDims(&tmp, funcID, kargs, false);
42     iter->size = matrBlockHeight(&tmp, mrole, kargs->side);
43     iter->globPitch = matrBlockPitch(&tmp, mrole, kargs->dtype, kargs->side);
44     iter->maxPanels = maxPanels;
45     iter->maxBlocks = maxBlocks;
46     iter->uplo = kargs->uplo;
47     iter->side = kargs->side;
48     iter->dtype = kargs->dtype;
49     iter->bpitch = matrBlockPitch(topDim, mrole, kargs->dtype, kargs->side);
50     iter->bheight = matrBlockHeight(topDim, mrole, kargs->side);
51     iteratorReset(iter);
52 }
53 
54 void VISIBILITY_HIDDEN
iteratorReset(ProblemIterator * iter)55 iteratorReset(ProblemIterator *iter)
56 {
57     if (isIterBackward(iter)) {
58         iter->pos = iter->size;
59         iter->prevPos = iter->size;
60     }
61     else {
62         iter->pos = 0;
63         iter->prevPos = 0;
64     }
65 }
66 
67 bool VISIBILITY_HIDDEN
isIterBackward(ProblemIterator * iter)68 isIterBackward(ProblemIterator *iter)
69 {
70     bool ret = false;
71 
72     if (iter->funcID != CLBLAS_GEMM) {
73         ret = (iter->side == clblasLeft && iter->uplo == clblasLower) ||
74               (iter->side == clblasRight && iter->uplo == clblasUpper);
75         if (iter->funcID == CLBLAS_TRSM) {
76             ret = !ret;
77         }
78     }
79 
80     return ret;
81 }
82 
83 int VISIBILITY_HIDDEN
iterateProblem(ProblemIterator * iter)84 iterateProblem(ProblemIterator *iter)
85 {
86     bool backward;
87     size_t dy = 0;
88 
89     backward = isIterBackward(iter);
90 
91     if (((iter->funcID != CLBLAS_TRSM) && (!iter->maxPanels)) ||
92             ((iter->funcID == CLBLAS_TRSM) && (!iter->maxBlocks))) {
93         iter->pos = (backward) ? 0 : iter->size;
94         return 1;
95     }
96 
97     iter->prevPos = iter->pos;
98 
99     if ((iter->funcID != CLBLAS_TRSM)) {
100         dy = iter->maxPanels * iter->bheight;
101         assert(dy != 0);
102     }
103     if (backward) {
104         dy = szmin(iter->pos, dy);
105         iter->pos -= dy;
106     }
107     else {
108         dy = szmin(dy, iter->size - iter->pos);
109         iter->pos += dy;
110     }
111 
112     return (int)(backward && iter->pos == 0) ||
113                 (!backward && iter->pos == iter->size);
114 }
115 
116 size_t VISIBILITY_HIDDEN
iterLastOffset(ProblemIterator * iter)117 iterLastOffset(ProblemIterator *iter)
118 {
119     return (iter->pos > iter->prevPos) ? (iter->pos - iter->prevPos) :
120            (iter->prevPos - iter->pos);
121 }
122