1 #include "BSprivate.h"
2
3 /*@ BSback_solve - Backward triangular matrix solution on a
4 single vector
5
6 Input Parameters:
7 . A - The sparse matrix
8 . x - The rhs
9 . comm - The communication structure for A
10 . procinfo - the usual processor information
11
12 Output Parameters:
13 . x - on exit contains the solution vector
14
15 Returns:
16 void
17
18 @*/
BSback_solve(BSpar_mat * A,FLOAT * x,BScomm * comm,BSprocinfo * procinfo)19 void BSback_solve(BSpar_mat *A, FLOAT *x, BScomm *comm, BSprocinfo *procinfo)
20 {
21 BMcomp_msg *from_msg, *to_msg;
22 BMphase *to_phase, *from_phase;
23 BMmsg *msg;
24 int i, j, k;
25 int cl_ind, in_ind;
26 int count, size, ind, num_cols;
27 int *row;
28 FLOAT *nz;
29 BScl_2_inode *clique2inode = A->clique2inode;
30 BSnumbering *color2clique = A->color2clique;
31 BSinode *inodes = A->inodes->list;
32 int *in_index = clique2inode->inode_index;
33 int *proc = clique2inode->proc;
34 BSdense *d_mats = clique2inode->d_mats;
35 int *data_ptr, msg_len;
36 FLOAT *msg_buf, *matrix;
37 int my_id = procinfo->my_id;
38 FLOAT *work;
39 char UP = 'U';
40 char TR = 'T';
41 char NTR = 'N';
42 char ND = 'N';
43 int *col2cl = color2clique->numbers;
44 int length = color2clique->length;
45 int start, finish, symmetric;
46 int ione = 1;
47 FLOAT one = 1.0;
48 FLOAT zero = 0.0;
49 FLOAT minus_one = -1.0;
50 FLOAT DDOT();
51 int *gnum = A->global_row_num->numbers;
52 int *iperm = A->inv_perm->perm;
53
54 /* Is the symmetric data structure used? */
55 symmetric = A->icc_storage;
56
57 if(symmetric) {
58 from_msg = comm->to_msg; /* we do mean to switch these */
59 to_msg = comm->from_msg;
60 } else {
61 from_msg = comm->from_msg; /* do not switch for ILU case */
62 to_msg = comm->to_msg;
63 }
64
65 /* get some work space */
66 MY_MALLOC(work,(FLOAT *),sizeof(FLOAT)*A->num_rows,1);
67
68 /* post for all messages */
69 BMinit_comp_msg(from_msg,procinfo); CHKERR(0);
70
71 /* now do this phase by phase */
72 for (i=length-2;i>=0;i--) {
73 start = col2cl[i];
74 finish = col2cl[i+1];
75
76 if(!symmetric) {
77 /* invert the diagonals and find the answers */
78 for (cl_ind=start;cl_ind<finish;cl_ind++) {
79 if (my_id == proc[cl_ind]) {
80 size = clique2inode->d_mats[cl_ind].size;
81 ind = clique2inode->d_mats[cl_ind].local_ind;
82 matrix = clique2inode->d_mats[cl_ind].matrix;
83 /* can't do much better (likely) on this DGEMV */
84 DGEMV(&NTR,&size,&size,&one,matrix,&size,&(x[ind]),&ione,&zero,
85 work,&ione);
86 for (k=0; k<size; k++) x[ind+k] = work[k];
87 }
88 }
89 }
90
91 /* first send my messages */
92 /* this will involve computing partial sums */
93 to_phase = BMget_phase(to_msg,i); CHKERR(0);
94 msg = NULL;
95 while ((msg = BMnext_msg(to_phase,msg)) != NULL) {
96 CHKERR(0);
97 msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERR(0);
98 data_ptr = BMget_user(msg,&msg_len); CHKERR(0);
99 if(symmetric) {
100 count = 0;
101 for (cl_ind=data_ptr[0];cl_ind<=data_ptr[1];cl_ind++) {
102 for (in_ind=in_index[cl_ind];
103 in_ind<in_index[cl_ind+1];in_ind++) {
104 row = inodes[in_ind].row_num;
105 nz = inodes[in_ind].nz;
106 size = inodes[in_ind].length;
107 num_cols = inodes[in_ind].num_cols;
108 if (size > 0) {
109 #ifdef MY_BLAS_DGEMV_ON
110 if (num_cols > DGEMV_UNROLL_LVL) {
111 for (k=0;k<size;k++) work[k] = x[row[k]];
112 DGEMV(&TR,&size,&num_cols,&one,nz,&size,
113 work,&ione,&zero,&(msg_buf[count]),&ione);
114 } else {
115 MY_DGEMV_Y_1101(size,num_cols,nz,size,x,row,
116 &(msg_buf[count]));
117 }
118 #else
119 for (k=0;k<size;k++) work[k] = x[row[k]];
120 DGEMV(&TR,&size,&num_cols,&one,nz,&size,
121 work,&ione,&zero,&(msg_buf[count]),&ione);
122 #endif
123 }
124 count += num_cols;
125 }
126 }
127 } else {
128 for (j=0; j<msg_len; j++)
129 msg_buf[j] = x[data_ptr[j]];
130 }
131 BMsendf_msg(msg,procinfo); CHKERR(0);
132 }
133 CHKERR(0);
134
135 /* do some local work, multiply by the i-nodes */
136 for (cl_ind=start;cl_ind<finish;cl_ind++) {
137 if (my_id == proc[cl_ind]) {
138 ind = d_mats[cl_ind].local_ind;
139 for (in_ind=in_index[cl_ind];
140 in_ind<in_index[cl_ind+1];in_ind++) {
141 size = inodes[in_ind].length;
142 num_cols = inodes[in_ind].num_cols;
143 row = inodes[in_ind].row_num;
144 nz = inodes[in_ind].nz;
145 if(symmetric) {
146 if (size > 0) {
147 #ifdef MY_BLAS_DGEMV_ON
148 if (num_cols > DGEMV_UNROLL_LVL) {
149 for (k=0;k<size;k++) work[k] = x[row[k]];
150 DGEMV(&TR,&size,&num_cols,&minus_one,nz,&size,
151 work,&ione,&one,&(x[ind]),&ione);
152 } else {
153 MY_DGEMVM1_Y_1111(size,num_cols,nz,size,x,row,
154 &(x[ind]));
155 }
156 #else
157 for (k=0;k<size;k++) work[k] = x[row[k]];
158 DGEMV(&TR,&size,&num_cols,&minus_one,nz,&size,
159 work,&ione,&one,&(x[ind]),&ione);
160 #endif
161 }
162 } else {
163 /* The following part is added to make sure the nz are */
164 /* above pivot. (ILU) */
165 length = size;
166 size = inodes[in_ind].below_diag;
167 /*
168 for (j=length-1; j>=0; j--) {
169 if (gnum[iperm[row[j]]] > inodes[in_ind].gcol_num)
170 size--;
171 else
172 break;
173 }
174 if(size!=inodes[in_ind].below_diag) {
175 printf("BS, L: size = %d, size2 = %d\n",size,
176 inodes[in_ind].below_diag);
177 }
178 */
179 if (size > 0) {
180 #ifdef MY_BLAS_DGEMV_ON
181 if (num_cols > DGEMV_UNROLL_LVL) {
182 DGEMV(&NTR,&size,&num_cols,&one,nz,&length,&(x[ind]),
183 &ione,&zero,work,&ione);
184 for (k=0;k<size;k++) x[row[k]] -= work[k];
185 } else {
186 MY_DGEMVM1_N_1111(size,num_cols,nz,size,&(x[ind]),
187 x,row);
188 }
189 #else
190 DGEMV(&NTR,&size,&num_cols,&one,nz,&length,&(x[ind]),
191 &ione,&zero,work,&ione);
192 for (k=0;k<size;k++) x[row[k]] -= work[k];
193 #endif
194 }
195 }
196 ind += num_cols;
197 }
198 }
199 }
200
201 /* receive my messages and update my rhs */
202 from_phase = BMget_phase(from_msg,i); CHKERR(0);
203 while ((msg = BMrecv_msg(from_phase)) != NULL) {
204 CHKERR(0);
205 msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERR(0);
206 data_ptr = BMget_user(msg,&msg_len); CHKERR(0);
207 if(symmetric) {
208 msg_len = BMget_msg_size(msg); CHKERR(0);
209 for (j=0;j<msg_len;j++) x[data_ptr[j]] -= msg_buf[j];
210 } else {
211 count = 0;
212 for (cl_ind=data_ptr[0]; cl_ind<=data_ptr[1]; cl_ind++) {
213 for (in_ind=clique2inode->inode_index[cl_ind];
214 in_ind<clique2inode->inode_index[cl_ind+1]; in_ind++) {
215 row = inodes[in_ind].row_num;
216 nz = inodes[in_ind].nz;
217 /*size = inodes[in_ind].length;*/
218 length = inodes[in_ind].length;
219 num_cols = inodes[in_ind].num_cols;
220
221 size = inodes[in_ind].below_diag;
222 /*
223 for (j=length-1; j>=0; j--) {
224 if (gnum[iperm[row[j]]] > inodes[in_ind].gcol_num)
225 size--;
226 else
227 break;
228 }
229 if(size!=inodes[in_ind].below_diag) {
230 printf("NL: size = %d, size2 = %d\n",size,
231 inodes[in_ind].below_diag);
232 }
233 */
234 if (size > 0) {
235 #ifdef MY_BLAS_DGEMV_ON
236 if (num_cols > DGEMV_UNROLL_LVL) {
237 DGEMV(&NTR,&size,&num_cols,&one,nz,&length,&(msg_buf[count]),
238 &ione,&zero,work,&ione);
239 for (k=0;k<size;k++) x[row[k]] -= work[k];
240 } else {
241 MY_DGEMVM1_N_1111(size,num_cols,nz,length,
242 &(msg_buf[count]),x,row);
243 }
244 #else
245 DGEMV(&NTR,&size,&num_cols,&one,nz,&length,&(msg_buf[count]),
246 &ione,&zero,work,&ione);
247 for (k=0;k<size;k++) x[row[k]] -= work[k];
248 #endif
249 }
250 count += num_cols;
251 }
252 }
253 }
254 BMfree_msg(msg); CHKERR(0);
255 }
256 CHKERR(0);
257
258 if(symmetric) {
259 /* invert the diagonals and find the answers */
260 for (cl_ind=start;cl_ind<finish;cl_ind++) {
261 if (my_id == proc[cl_ind]) {
262 /* first, multiply the clique */
263 /* only do the strictly upper triangular part */
264 /* we ASSUME the diagonal is all 1's */
265 size = clique2inode->d_mats[cl_ind].size;
266 #ifdef MY_BLAS_DTRMV_ON
267 MY_DTRMV_N_U(size,d_mats[cl_ind].matrix,size,
268 &(x[d_mats[cl_ind].local_ind]),work);
269 #else
270 DTRMV(&UP,&NTR,&ND,&size,d_mats[cl_ind].matrix,&size,
271 &(x[d_mats[cl_ind].local_ind]),&ione);
272 #endif
273 }
274 }
275 }
276
277 }
278 MY_FREE(work);
279 /* wait for all of the sent messages to finish */
280 BMfinish_comp_msg(to_msg,procinfo); CHKERR(0);
281 MLOG_flop((2*A->local_nnz));
282 }
283