1 /* ************************************************************************
2  * Copyright 2013 Advanced Micro Devices, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  * ************************************************************************/
16 
17 #ifndef TRSV_DELTA_H_
18 #define TRSV_DELTA_H_
19 
#include <cassert>
#include <cstdlib>
#include <vector>

#include "delta.h"
21 
22 static size_t
trsvBlockSize(size_t elemSize)23 trsvBlockSize(size_t elemSize)
24 {
25     /* TODO: Right now TRSV generators use block size of 16 elements for the
26      *       double complex type, and of 32 elements for another types.
27      *       If this changes, we have to fetch block size from TRSV generator
28      *       somehow.
29      */
30     return (elemSize == sizeof(DoubleComplex)) ? 16 : 32;
31 }
32 
33 template <typename T>
34 void
trsvDelta(clblasOrder order,clblasUplo uplo,clblasTranspose transA,clblasDiag diag,size_t N,T * A,size_t lda,T * X,int incx,cl_double * delta)35 trsvDelta(
36     clblasOrder order,
37     clblasUplo uplo,
38     clblasTranspose transA,
39     clblasDiag diag,
40     size_t N,
41     T *A,
42     size_t lda,
43     T *X,
44 	int incx,
45     cl_double *delta)
46 {
47     cl_double *deltaCLBLAS, s;
48     int i, j, jStart, jEnd, idx;
49     int zinc;
50     size_t z = 0;
51     size_t bsize, lenX;
52     bool isUpper = false;
53 	size_t previncxi=0;
54     T v;
55 
56    	isUpper = ((uplo == clblasUpper) && (transA == clblasNoTrans)) ||
57              ((uplo == clblasLower) && (transA != clblasNoTrans));
58 	// incx = abs(incx);
59 	lenX = 1 + (N-1)*abs(incx);
60     deltaCLBLAS = new cl_double[lenX];
61     bsize = trsvBlockSize(sizeof(T));
62 
63         // Calculate delta of TRSV evaluated with the Gauss' method
64 
65             if (isUpper) {
66                 for (i = (int)N - 1; i >= 0; i--) {
67 					size_t incxi;
68 
69 					incxi = (incx > 0) ? (i*incx) : (N-1-i)*abs(incx);
70                     v = getElement<T>(clblasColumnMajor, clblasNoTrans, incxi, 0, X, lenX);
71                     if (diag == clblasNonUnit) {
72                         T tempA;
73                         if(lda > 0)
74                         {
75                             tempA = getElement<T>(order, transA, i, i, A, lda);
76                     }
77                         else
78                         {
79                             tempA = getElementPacked(order, clblasNoTrans, uplo, i, i, A, N);
80                         }
81                         v = v / tempA;
82                     }
83                     s = module(v) * DELTA_0<T>();
84                     if (i == (int)(N - 1)) {
85                         delta[ incxi ] = s;
86                     }
87                     else {
88                         delta[ incxi ] = s + delta[ previncxi ];
89                     }
90                     assert(delta[ incxi ] >= 0);
91 					previncxi = incxi;
92                 }
93             }
94             else {
95                 for (i = 0; i < (int)N; i++) {
96 					size_t incxi;
97 
98 					incxi = (incx > 0) ? (i*incx) : (N-1-i)*abs(incx);
99                     v = getElement<T>(clblasColumnMajor, clblasNoTrans, incxi, 0, X, lenX);
100                     if (diag == clblasNonUnit) {
101                         T tempA;
102                         if(lda > 0)
103                         {
104                             tempA = getElement<T>(order, transA, i, i, A, lda);
105                     }
106                         else
107                         {
108                             tempA = getElementPacked(order, clblasNoTrans, uplo, i, i, A, N);
109                         }
110                         v = v / tempA;
111                     }
112                     s = module(v) * DELTA_0<T>();
113                     if (i == 0) {
114                         delta[ incxi ] = s;
115                     }
116                     else {
117                         delta[ incxi ] = s + delta[ previncxi ];
118                     }
119                     assert(delta[ incxi ] >= 0);
120 					previncxi = incxi;
121                 }
122             }
123 
124         // Calculate clblas TRSV delta
125 
126             for (i = 0; i < (int)N; i++) {
127 				size_t incxi;
128                 s = 0.0;
129 
130                 /*
131                  *  For the upper triangular matrix the solving process proceeds
132                  *  from the bottom to the top, and the bottommost block's
133                  *  delta influents most of all. For the lower triangular matrix
134                  *  the situation is opposite.
135                  */
136                 if (isUpper) {
137                     jStart = i / (int)bsize;
138                     // index of the block just after the last matrix block
139                     jEnd = ((int)N + (int)bsize - 1) / (int)bsize;
140                     z = 1;
141                     zinc = 1;
142                 }
143                 else {
144                     jStart = 0;
145                     jEnd = i / (int)bsize + 1;
146                     z = jEnd - jStart;
147                     zinc = -1;
148                 }
149 
150                 for (j = jStart; j < jEnd; j++) {
151 					size_t incxi;
152 
153                     idx = j * (int)bsize + i % (int)bsize;
154                     if (idx >= (int)N) {
155                         continue;
156                     }
157 					incxi = (incx > 0) ? (idx*incx) : (N-1-idx)*abs(incx);
158                     s += z * delta[ incxi ];
159                     z += zinc;
160                 }
161 
162 				incxi = (incx > 0) ? (i*incx) : (N-1-i)*abs(incx);
163                 deltaCLBLAS[ incxi ] = s * bsize;
164                 assert(deltaCLBLAS[ incxi ] >= 0);
165             }
166 
167 			for (i = 0; i < (int)N; i++) {
168 				size_t incxi;
169 
170 				incxi = (incx > 0) ? (i*incx) : (N-1-i)*abs(incx);
171 				delta[ incxi ] += deltaCLBLAS[ incxi ];
172 			}
173 
174     delete[] deltaCLBLAS;
175 }
176 
177 template <typename T>
178 void
tbsvDelta(clblasOrder order,clblasUplo uplo,clblasTranspose transA,clblasDiag diag,size_t N,size_t K,T * A,size_t lda,T * X,int incx,cl_double * delta)179 tbsvDelta(
180     clblasOrder order,
181     clblasUplo uplo,
182     clblasTranspose transA,
183     clblasDiag diag,
184     size_t N,
185     size_t K,
186     T *A,
187     size_t lda,
188     T *X,
189     int incx,
190     cl_double *delta)
191 {
192     cl_double *deltaCLBLAS, s;
193     int i, j, jStart, jEnd, idx;
194     int zinc;
195     size_t z = 0;
196     size_t bsize, lenX;
197     bool isUpper = false;
198     size_t previncxi=0;
199     T v;
200 
201     isUpper = ((uplo == clblasUpper) && (transA == clblasNoTrans)) ||
202              ((uplo == clblasLower) && (transA != clblasNoTrans));
203     lenX = 1 + (N-1)*abs(incx);
204     deltaCLBLAS = new cl_double[lenX];
205     bsize = trsvBlockSize(sizeof(T));
206 
207         // Calculate delta of TRSV evaluated with the Gauss' method
208 
209             if (isUpper) {
210                 for (i = (int)N - 1; i >= 0; i--) {
211                     size_t incxi;
212 
213                     incxi = (incx > 0) ? (i*incx) : (N-1-i)*abs(incx);
214                     v = getElement<T>(clblasColumnMajor, clblasNoTrans, incxi, 0, X, lenX);
215                     if (diag == clblasNonUnit) {
216                         v = v / getElementBanded<T>(order, uplo, i, i, K, A, lda);
217                     }
218                     s = module(v) * DELTA_0<T>();
219                     if (i == (int)(N - 1)) {
220                         delta[ incxi ] = s;
221                     }
222                     else {
223                         delta[ incxi ] = s + delta[ previncxi ];
224                     }
225                     assert(delta[ incxi ] >= 0);
226                     previncxi = incxi;
227                 }
228             }
229             else {
230                 for (i = 0; i < (int)N; i++) {
231                     size_t incxi;
232 
233                     incxi = (incx > 0) ? (i*incx) : (N-1-i)*abs(incx);
234                     v = getElement<T>(clblasColumnMajor, clblasNoTrans, incxi, 0, X, lenX);
235                     if (diag == clblasNonUnit) {
236                         v = v / getElementBanded<T>(order, uplo, i, i, K, A, lda);
237                     }
238                     s = module(v) * DELTA_0<T>();
239                     if (i == 0) {
240                         delta[ incxi ] = s;
241                     }
242                     else {
243                         delta[ incxi ] = s + delta[ previncxi ];
244                     }
245                     assert(delta[ incxi ] >= 0);
246                     previncxi = incxi;
247                 }
248             }
249 
250         // Calculate clblas TRSV delta
251 
252             for (i = 0; i < (int)N; i++) {
253                 size_t incxi;
254                 s = 0.0;
255                 if (isUpper) {
256                     jStart = i / (int)bsize;
257                     // index of the block just after the last matrix block
258                     jEnd = ((int)N + (int)bsize - 1) / (int)bsize;
259                     z = 1;
260                     zinc = 1;
261                 }
262                 else {
263                     jStart = 0;
264                     jEnd = i / (int)bsize + 1;
265                     z = jEnd - jStart;
266                     zinc = -1;
267                 }
268 
269                 for (j = jStart; j < jEnd; j++) {
270                     size_t incxi;
271 
272                     idx = j * (int)bsize + i % (int)bsize;
273                     if (idx >= (int)N) {
274                         continue;
275                     }
276                     incxi = (incx > 0) ? (idx*incx) : (N-1-idx)*abs(incx);
277                     s += z * delta[ incxi ];
278                     z += zinc;
279                 }
280 
281                 incxi = (incx > 0) ? (i*incx) : (N-1-i)*abs(incx);
282                 deltaCLBLAS[ incxi ] = s * bsize;
283                 assert(deltaCLBLAS[ incxi ] >= 0);
284             }
285 
286             for (i = 0; i < (int)N; i++) {
287                 size_t incxi;
288 
289                 incxi = (incx > 0) ? (i*incx) : (N-1-i)*abs(incx);
290                 delta[ incxi ] += deltaCLBLAS[ incxi ];
291             }
292 
293     delete[] deltaCLBLAS;
294 }
295 #endif
296 
297