1 #include "BSprivate.h"
2 
3 /*@ BSfor_solve - Forward triangular matrix solution on a
4                   single vector
5 
6     Input Parameters:
7 .   A - The sparse matrix
8 .   x - The rhs
9 .   comm - The communication structure for A
10 .   procinfo - the usual processor information
11 
12     Output Parameters:
13 .   x - on exit contains the solution vector
14 
15     Returns:
16     void
17 
18  @*/
BSfor_solve(BSpar_mat * A,FLOAT * x,BScomm * comm,BSprocinfo * procinfo)19 void BSfor_solve(BSpar_mat *A, FLOAT *x, BScomm *comm, BSprocinfo *procinfo)
20 {
21 	BMphase *to_phase, *from_phase;
22 	BMmsg *msg;
23 	int	i, j, k;
24 	int	cl_ind, in_ind, symmetric;
25 	int	count, size, length, ind, num_cols;
26 	int *row;
27 	FLOAT *nz;
28 	BScl_2_inode *clique2inode;
29 	BSnumbering *color2clique;
30 	BSinode *inodes;
31 	int	*data_ptr, msg_len;
32 	FLOAT *msg_buf, *matrix;
33 	FLOAT *work;
34 	char UP = 'U';
35 	char TR = 'T';
36 	char NTR = 'N';
37 	char ND = 'N';
38 	int	ione = 1;
39 	FLOAT one = 1.0;
40 	FLOAT zero = 0.0;
41 	int *gnum, *iperm;
42 
43 	/* Is the symmetric data structure used? */
44 	symmetric = A->icc_storage;
45 
46 	color2clique = A->color2clique;
47 	clique2inode = A->clique2inode;
48 	inodes = A->inodes->list;
49 	gnum = A->global_row_num->numbers;
50 	iperm = A->inv_perm->perm;
51 
52 	/* get some work space */
53 	MY_MALLOC(work,(FLOAT *),sizeof(FLOAT)*A->num_rows,1);
54 
55 	/* post for all messages */
56 	BMinit_comp_msg(comm->from_msg,procinfo); CHKERR(0);
57 
58 	/* now do this phase by phase */
59 	for (i=0;i<color2clique->length-1;i++) {
60 		if(symmetric) {
61 			/* find my portion of the solution using the cliques on the diagonal */
62 			for (cl_ind=color2clique->numbers[i];
63 				cl_ind<color2clique->numbers[i+1];cl_ind++) {
64 				if (procinfo->my_id == clique2inode->proc[cl_ind]) {
65 					/* first, multiply the clique */
66 					/* the clique is stored, inverted, in the upper triangle */
67 					size = clique2inode->d_mats[cl_ind].size;
68 					ind = clique2inode->d_mats[cl_ind].local_ind;
69 					matrix = clique2inode->d_mats[cl_ind].matrix;
70 #ifdef MY_BLAS_DTRMV_ON
71 					MY_DTRMV_T_U(size,matrix,size,&(x[ind]));
72 #else
73 					DTRMV(&UP,&TR,&ND,&size,matrix,&size,&(x[ind]),&ione);
74 #endif
75 				}
76 			}
77 		}
78 
79 		/* now send my messages */
80 		to_phase = BMget_phase(comm->to_msg,i); CHKERR(0);
81 		msg = NULL;
82 		while ((msg = BMnext_msg(to_phase,msg)) != NULL) {
83 			CHKERR(0);
84 			msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERR(0);
85 			data_ptr = BMget_user(msg,&msg_len); CHKERR(0);
86 			for (j=0;j<msg_len;j++) {
87 				msg_buf[j] = x[data_ptr[j]];
88 			}
89 			BMsendf_msg(msg,procinfo); CHKERR(0);
90 		}
91 		CHKERR(0);
92 
93 		/* do some local work */
94 		for (cl_ind=color2clique->numbers[i];
95 			cl_ind<color2clique->numbers[i+1];cl_ind++) {
96 			if (procinfo->my_id == clique2inode->proc[cl_ind]) {
97 				ind = clique2inode->d_mats[cl_ind].local_ind;
98 				/* multiply the inodes */
99 				for (in_ind=clique2inode->inode_index[cl_ind];
100 					in_ind<clique2inode->inode_index[cl_ind+1];in_ind++) {
101 					row = inodes[in_ind].row_num;
102 					nz = inodes[in_ind].nz;
103 					size = inodes[in_ind].length;
104 					num_cols = inodes[in_ind].num_cols;
105 					if(symmetric) {
106 						if (size > 0) {
107 #ifdef MY_BLAS_DGEMV_ON
108 							if (num_cols > DGEMV_UNROLL_LVL) {
109 								DGEMV(&NTR,&size,&num_cols,&one,nz,&size,&(x[ind]),
110 									&ione,&zero,work,&ione);
111 								for (k=0;k<size;k++) x[row[k]] -= work[k];
112 							} else {
113 								MY_DGEMVM1_N_1111(size,num_cols,nz,size,&(x[ind]),x,
114 									row);
115 							}
116 #else
117 							DGEMV(&NTR,&size,&num_cols,&one,nz,&size,&(x[ind]),
118 								&ione,&zero,work,&ione);
119 							for (k=0;k<size;k++) x[row[k]] -= work[k];
120 #endif
121 						}
122 				 	} else {
123 						length = inodes[in_ind].length;
124 						/* The following part is added to make sure the */
125 						/* nz are below pivot. (ILU)  */
126 						/*
127 						for (j=0; j<length; j++) {
128 							if (gnum[iperm[row[j]]] < inodes[in_ind].gcol_num) {
129 								nz++; size--;
130 							} else {
131 								break;
132 							}
133 						}
134 						if(size!=length-inodes[in_ind].below_diag) {
135 							printf("FS, L: size = %d, size2 = %d\n",size,
136 								length-inodes[in_ind].below_diag);
137 						}
138 						*/
139 						size -= inodes[in_ind].below_diag;
140 						nz += inodes[in_ind].below_diag;
141 						if (size > 0) {
142 #ifdef MY_BLAS_DGEMV_ON
143 							if (num_cols > DGEMV_UNROLL_LVL) {
144 								DGEMV(&NTR,&size,&num_cols,&one,nz,&length,&(x[ind]),
145 									&ione,&zero,work,&ione);
146 								for (k=0;k<size;k++) x[row[k+j]] -= work[k];
147 							} else {
148 								MY_DGEMVM1_N_1111(size,num_cols,nz,size,&(x[ind]),
149 									x,row);
150 							}
151 #else
152 							DGEMV(&NTR,&size,&num_cols,&one,nz,&length,&(x[ind]),
153 								&ione,&zero,work,&ione);
154 							for (k=0;k<size;k++) x[row[k+j]] -= work[k];
155 #endif
156 						}
157 					}
158 					ind += num_cols;
159 				}
160 			}
161 		}
162 
163 		/* receive my messages and do non-local work */
164 		from_phase = BMget_phase(comm->from_msg,i); CHKERR(0);
165 		while ((msg = BMrecv_msg(from_phase)) != NULL) {
166 			CHKERR(0);
167 			msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERR(0);
168 			data_ptr = BMget_user(msg,&msg_len); CHKERR(0);
169 			count = 0;
170 			for (cl_ind=data_ptr[0];cl_ind<=data_ptr[1];cl_ind++) {
171 				for (in_ind=clique2inode->inode_index[cl_ind];
172 					in_ind<clique2inode->inode_index[cl_ind+1];in_ind++) {
173 					row = inodes[in_ind].row_num;
174 					nz = inodes[in_ind].nz;
175 					size = inodes[in_ind].length;
176 					num_cols = inodes[in_ind].num_cols;
177 					if(symmetric) {
178 						if (size > 0) {
179 #ifdef MY_BLAS_DGEMV_ON
180 							if (num_cols > DGEMV_UNROLL_LVL) {
181 								DGEMV(&NTR,&size,&num_cols,&one,nz,&size,
182 									&(msg_buf[count]),&ione,&zero,work,&ione);
183 								for (k=0;k<size;k++) x[row[k]] -= work[k];
184 							} else {
185 								MY_DGEMVM1_N_1111(size,num_cols,nz,size,
186 									&(msg_buf[count]),x,row);
187 							}
188 #else
189 							DGEMV(&NTR,&size,&num_cols,&one,nz,&size,
190 								&(msg_buf[count]),&ione,&zero,work,&ione);
191 							for (k=0;k<size;k++) x[row[k]] -= work[k];
192 #endif
193 						}
194 					} else {
195 						length = inodes[in_ind].length;
196 						/* The following part is added to make sure the */
197 						/* nz are below pivot. (ILU) */
198 						/*
199 						for (j=0; j<length; j++) {
200 							if (gnum[iperm[row[j]]] < inodes[in_ind].gcol_num) {
201 								nz++; size--;
202 							} else {
203 								break;
204 							}
205 						}
206 						if(size!=length-inodes[in_ind].below_diag) {
207 							printf("FS, NL: size = %d, size2 = %d\n",size,
208 								length-inodes[in_ind].below_diag);
209 						}
210 						*/
211 						size -= inodes[in_ind].below_diag;
212 						nz += inodes[in_ind].below_diag;
213 						if (size > 0) {
214 #ifdef MY_BLAS_DGEMV_ON
215 							if (num_cols > DGEMV_UNROLL_LVL) {
216 								DGEMV(&NTR,&size,&num_cols,&one,nz,&length,
217 									&(msg_buf[count]),
218 									&ione,&zero,work,&ione);
219 								for (k=0;k<size;k++) x[row[k+j]] -= work[k];
220 							} else {
221 								MY_DGEMVM1_N_1111(size,num_cols,nz,size,
222 									&(msg_buf[count]),x,row);
223 							}
224 #else
225 							DGEMV(&NTR,&size,&num_cols,&one,nz,&length,
226 								&(msg_buf[count]),&ione,&zero,work,&ione);
227 							for (k=0;k<size;k++) x[row[k+j]] -= work[k];
228 #endif
229 						}
230 					}
231 					count += num_cols;
232 				}
233 			}
234 			BMfree_msg(msg); CHKERR(0);
235 		}
236 		CHKERR(0);
237 	}
238 	MY_FREE(work);
239 	/* wait for all of the sent messages to finish */
240 	BMfinish_comp_msg(comm->to_msg,procinfo); CHKERR(0);
241 	MLOG_flop((2*A->local_nnz));
242 }
243