1 #include "BSprivate.h"
2 
3 /*@ BSback_solve - Backward triangular matrix solution on a
4                    single vector
5 
6     Input Parameters:
7 .   A - The sparse matrix
8 .   x - The rhs
9 .   comm - The communication structure for A
10 .   procinfo - the usual processor information
11 
12     Output Parameters:
13 .   x - on exit contains the solution vector
14 
15     Returns:
16     void
17 
18  @*/
BSback_solve(BSpar_mat * A,FLOAT * x,BScomm * comm,BSprocinfo * procinfo)19 void BSback_solve(BSpar_mat *A, FLOAT *x, BScomm *comm, BSprocinfo *procinfo)
20 {
21 	BMcomp_msg *from_msg, *to_msg;
22 	BMphase *to_phase, *from_phase;
23 	BMmsg *msg;
24 	int	i, j, k;
25 	int	cl_ind, in_ind;
26 	int	count, size, ind, num_cols;
27 	int *row;
28 	FLOAT *nz;
29 	BScl_2_inode *clique2inode = A->clique2inode;
30 	BSnumbering *color2clique = A->color2clique;
31 	BSinode *inodes = A->inodes->list;
32 	int	*in_index = clique2inode->inode_index;
33 	int	*proc = clique2inode->proc;
34 	BSdense	*d_mats = clique2inode->d_mats;
35 	int	*data_ptr, msg_len;
36 	FLOAT *msg_buf, *matrix;
37 	int	my_id = procinfo->my_id;
38 	FLOAT *work;
39 	char UP = 'U';
40 	char TR = 'T';
41 	char NTR = 'N';
42 	char ND = 'N';
43 	int	*col2cl = color2clique->numbers;
44 	int	length = color2clique->length;
45 	int	start, finish, symmetric;
46 	int	ione = 1;
47 	FLOAT one = 1.0;
48 	FLOAT zero = 0.0;
49 	FLOAT minus_one = -1.0;
50 	FLOAT DDOT();
51 	int *gnum = A->global_row_num->numbers;
52 	int *iperm = A->inv_perm->perm;
53 
54 	/* Is the symmetric data structure used? */
55 	symmetric = A->icc_storage;
56 
57 	if(symmetric) {
58 		from_msg = comm->to_msg; /* we do mean to switch these */
59 		to_msg = comm->from_msg;
60 	} else {
61 		from_msg = comm->from_msg; /* do not switch for ILU case */
62 		to_msg = comm->to_msg;
63 	}
64 
65 	/* get some work space */
66 	MY_MALLOC(work,(FLOAT *),sizeof(FLOAT)*A->num_rows,1);
67 
68 	/* post for all messages */
69 	BMinit_comp_msg(from_msg,procinfo); CHKERR(0);
70 
71 	/* now do this phase by phase */
72 	for (i=length-2;i>=0;i--) {
73 		start = col2cl[i];
74 		finish = col2cl[i+1];
75 
76 		if(!symmetric) {
77 			/* invert the diagonals and find the answers */
78 			for (cl_ind=start;cl_ind<finish;cl_ind++) {
79 				if (my_id == proc[cl_ind]) {
80 					size = clique2inode->d_mats[cl_ind].size;
81 					ind = clique2inode->d_mats[cl_ind].local_ind;
82 					matrix = clique2inode->d_mats[cl_ind].matrix;
83 					/* can't do much better (likely) on this DGEMV */
84 					DGEMV(&NTR,&size,&size,&one,matrix,&size,&(x[ind]),&ione,&zero,
85 						work,&ione);
86 					for (k=0; k<size; k++) x[ind+k] = work[k];
87 				}
88 			}
89 		}
90 
91 		/* first send my messages */
92 		/* this will involve computing partial sums */
93 		to_phase = BMget_phase(to_msg,i); CHKERR(0);
94 		msg = NULL;
95 		while ((msg = BMnext_msg(to_phase,msg)) != NULL) {
96 			CHKERR(0);
97 			msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERR(0);
98 			data_ptr = BMget_user(msg,&msg_len); CHKERR(0);
99 			if(symmetric) {
100 				count = 0;
101 				for (cl_ind=data_ptr[0];cl_ind<=data_ptr[1];cl_ind++) {
102 					for (in_ind=in_index[cl_ind];
103 						in_ind<in_index[cl_ind+1];in_ind++) {
104 						row = inodes[in_ind].row_num;
105 						nz = inodes[in_ind].nz;
106 						size = inodes[in_ind].length;
107 						num_cols = inodes[in_ind].num_cols;
108 						if (size > 0) {
109 #ifdef MY_BLAS_DGEMV_ON
110 							if (num_cols > DGEMV_UNROLL_LVL) {
111 								for (k=0;k<size;k++) work[k] = x[row[k]];
112 								DGEMV(&TR,&size,&num_cols,&one,nz,&size,
113 									work,&ione,&zero,&(msg_buf[count]),&ione);
114 							} else {
115 								MY_DGEMV_Y_1101(size,num_cols,nz,size,x,row,
116 									&(msg_buf[count]));
117 							}
118 #else
119 							for (k=0;k<size;k++) work[k] = x[row[k]];
120 							DGEMV(&TR,&size,&num_cols,&one,nz,&size,
121 								work,&ione,&zero,&(msg_buf[count]),&ione);
122 #endif
123 						}
124 						count += num_cols;
125 					}
126 				}
127 			} else {
128 				for (j=0; j<msg_len; j++)
129 					msg_buf[j] = x[data_ptr[j]];
130 			}
131 			BMsendf_msg(msg,procinfo); CHKERR(0);
132 		}
133 		CHKERR(0);
134 
135 		/* do some local work, multiply by the i-nodes */
136 		for (cl_ind=start;cl_ind<finish;cl_ind++) {
137 			if (my_id == proc[cl_ind]) {
138 				ind = d_mats[cl_ind].local_ind;
139 				for (in_ind=in_index[cl_ind];
140 					in_ind<in_index[cl_ind+1];in_ind++) {
141 					size = inodes[in_ind].length;
142 					num_cols = inodes[in_ind].num_cols;
143 					row = inodes[in_ind].row_num;
144 					nz = inodes[in_ind].nz;
145 					if(symmetric) {
146 						if (size > 0) {
147 #ifdef MY_BLAS_DGEMV_ON
148 							if (num_cols > DGEMV_UNROLL_LVL) {
149 								for (k=0;k<size;k++) work[k] = x[row[k]];
150 								DGEMV(&TR,&size,&num_cols,&minus_one,nz,&size,
151 									work,&ione,&one,&(x[ind]),&ione);
152 							} else {
153 								MY_DGEMVM1_Y_1111(size,num_cols,nz,size,x,row,
154 									&(x[ind]));
155 							}
156 #else
157 							for (k=0;k<size;k++) work[k] = x[row[k]];
158 							DGEMV(&TR,&size,&num_cols,&minus_one,nz,&size,
159 								work,&ione,&one,&(x[ind]),&ione);
160 #endif
161 						}
162 					} else {
163 						/* The following part is added to make sure the nz are */
164 						/* above pivot. (ILU)                                  */
165 						length = size;
166 						size = inodes[in_ind].below_diag;
167 						/*
168 						for (j=length-1; j>=0; j--) {
169 						if (gnum[iperm[row[j]]] > inodes[in_ind].gcol_num)
170 							size--;
171 						else
172 							break;
173 						}
174 						if(size!=inodes[in_ind].below_diag) {
175 							printf("BS, L: size = %d, size2 = %d\n",size,
176 								inodes[in_ind].below_diag);
177 						}
178 						*/
179 						if (size > 0) {
180 #ifdef MY_BLAS_DGEMV_ON
181 							if (num_cols > DGEMV_UNROLL_LVL) {
182 								DGEMV(&NTR,&size,&num_cols,&one,nz,&length,&(x[ind]),
183 									&ione,&zero,work,&ione);
184 								for (k=0;k<size;k++) x[row[k]] -= work[k];
185 							} else {
186 								MY_DGEMVM1_N_1111(size,num_cols,nz,size,&(x[ind]),
187 									x,row);
188 							}
189 #else
190 							DGEMV(&NTR,&size,&num_cols,&one,nz,&length,&(x[ind]),
191 								&ione,&zero,work,&ione);
192 							for (k=0;k<size;k++) x[row[k]] -= work[k];
193 #endif
194 						}
195 					}
196 					ind += num_cols;
197 				}
198 			}
199 		}
200 
201 		/* receive my messages and update my rhs */
202 		from_phase = BMget_phase(from_msg,i); CHKERR(0);
203 		while ((msg = BMrecv_msg(from_phase)) != NULL) {
204 			CHKERR(0);
205 			msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERR(0);
206 			data_ptr = BMget_user(msg,&msg_len); CHKERR(0);
207 			if(symmetric) {
208 				msg_len = BMget_msg_size(msg); CHKERR(0);
209 				for (j=0;j<msg_len;j++) x[data_ptr[j]] -= msg_buf[j];
210 			} else {
211 				count = 0;
212 				for (cl_ind=data_ptr[0]; cl_ind<=data_ptr[1]; cl_ind++) {
213 					for (in_ind=clique2inode->inode_index[cl_ind];
214 					in_ind<clique2inode->inode_index[cl_ind+1]; in_ind++) {
215 						row = inodes[in_ind].row_num;
216 						nz = inodes[in_ind].nz;
217 						/*size = inodes[in_ind].length;*/
218 						length = inodes[in_ind].length;
219 						num_cols = inodes[in_ind].num_cols;
220 
221 						size = inodes[in_ind].below_diag;
222 						/*
223 						for (j=length-1; j>=0; j--) {
224 							if (gnum[iperm[row[j]]] > inodes[in_ind].gcol_num)
225 								size--;
226 							else
227 								break;
228 						}
229 						if(size!=inodes[in_ind].below_diag) {
230 							printf("NL: size = %d, size2 = %d\n",size,
231 								inodes[in_ind].below_diag);
232 						}
233 						*/
234 						if (size > 0) {
235 #ifdef MY_BLAS_DGEMV_ON
236 							if (num_cols > DGEMV_UNROLL_LVL) {
237 								DGEMV(&NTR,&size,&num_cols,&one,nz,&length,&(msg_buf[count]),
238 									&ione,&zero,work,&ione);
239 								for (k=0;k<size;k++) x[row[k]] -= work[k];
240 							} else {
241 								MY_DGEMVM1_N_1111(size,num_cols,nz,length,
242 									&(msg_buf[count]),x,row);
243 							}
244 #else
245 							DGEMV(&NTR,&size,&num_cols,&one,nz,&length,&(msg_buf[count]),
246 								&ione,&zero,work,&ione);
247 							for (k=0;k<size;k++) x[row[k]] -= work[k];
248 #endif
249 						}
250 						count += num_cols;
251 					}
252 				}
253 			}
254 			BMfree_msg(msg); CHKERR(0);
255 		}
256 		CHKERR(0);
257 
258 		if(symmetric) {
259 			/* invert the diagonals and find the answers */
260 			for (cl_ind=start;cl_ind<finish;cl_ind++) {
261 				if (my_id == proc[cl_ind]) {
262 					/* first, multiply the clique */
263 					/* only do the strictly upper triangular part */
264 					/* we ASSUME the diagonal is all 1's */
265 					size = clique2inode->d_mats[cl_ind].size;
266 #ifdef MY_BLAS_DTRMV_ON
267 					MY_DTRMV_N_U(size,d_mats[cl_ind].matrix,size,
268 						&(x[d_mats[cl_ind].local_ind]),work);
269 #else
270 					DTRMV(&UP,&NTR,&ND,&size,d_mats[cl_ind].matrix,&size,
271 						&(x[d_mats[cl_ind].local_ind]),&ione);
272 #endif
273 				}
274 			}
275 		}
276 
277 	}
278 	MY_FREE(work);
279 	/* wait for all of the sent messages to finish */
280 	BMfinish_comp_msg(to_msg,procinfo); CHKERR(0);
281 	MLOG_flop((2*A->local_nnz));
282 }
283