1 #include "BSprivate.h"
2
3 /*@ BSfor_solve - Forward triangular matrix solution on a
4 single vector
5
6 Input Parameters:
7 . A - The sparse matrix
8 . x - The rhs
9 . comm - The communication structure for A
10 . procinfo - the usual processor information
11
12 Output Parameters:
13 . x - on exit contains the solution vector
14
15 Returns:
16 void
17
18 @*/
BSfor_solve(BSpar_mat * A,FLOAT * x,BScomm * comm,BSprocinfo * procinfo)19 void BSfor_solve(BSpar_mat *A, FLOAT *x, BScomm *comm, BSprocinfo *procinfo)
20 {
21 BMphase *to_phase, *from_phase;
22 BMmsg *msg;
23 int i, j, k;
24 int cl_ind, in_ind, symmetric;
25 int count, size, length, ind, num_cols;
26 int *row;
27 FLOAT *nz;
28 BScl_2_inode *clique2inode;
29 BSnumbering *color2clique;
30 BSinode *inodes;
31 int *data_ptr, msg_len;
32 FLOAT *msg_buf, *matrix;
33 FLOAT *work;
34 char UP = 'U';
35 char TR = 'T';
36 char NTR = 'N';
37 char ND = 'N';
38 int ione = 1;
39 FLOAT one = 1.0;
40 FLOAT zero = 0.0;
41 int *gnum, *iperm;
42
43 /* Is the symmetric data structure used? */
44 symmetric = A->icc_storage;
45
46 color2clique = A->color2clique;
47 clique2inode = A->clique2inode;
48 inodes = A->inodes->list;
49 gnum = A->global_row_num->numbers;
50 iperm = A->inv_perm->perm;
51
52 /* get some work space */
53 MY_MALLOC(work,(FLOAT *),sizeof(FLOAT)*A->num_rows,1);
54
55 /* post for all messages */
56 BMinit_comp_msg(comm->from_msg,procinfo); CHKERR(0);
57
58 /* now do this phase by phase */
59 for (i=0;i<color2clique->length-1;i++) {
60 if(symmetric) {
61 /* find my portion of the solution using the cliques on the diagonal */
62 for (cl_ind=color2clique->numbers[i];
63 cl_ind<color2clique->numbers[i+1];cl_ind++) {
64 if (procinfo->my_id == clique2inode->proc[cl_ind]) {
65 /* first, multiply the clique */
66 /* the clique is stored, inverted, in the upper triangle */
67 size = clique2inode->d_mats[cl_ind].size;
68 ind = clique2inode->d_mats[cl_ind].local_ind;
69 matrix = clique2inode->d_mats[cl_ind].matrix;
70 #ifdef MY_BLAS_DTRMV_ON
71 MY_DTRMV_T_U(size,matrix,size,&(x[ind]));
72 #else
73 DTRMV(&UP,&TR,&ND,&size,matrix,&size,&(x[ind]),&ione);
74 #endif
75 }
76 }
77 }
78
79 /* now send my messages */
80 to_phase = BMget_phase(comm->to_msg,i); CHKERR(0);
81 msg = NULL;
82 while ((msg = BMnext_msg(to_phase,msg)) != NULL) {
83 CHKERR(0);
84 msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERR(0);
85 data_ptr = BMget_user(msg,&msg_len); CHKERR(0);
86 for (j=0;j<msg_len;j++) {
87 msg_buf[j] = x[data_ptr[j]];
88 }
89 BMsendf_msg(msg,procinfo); CHKERR(0);
90 }
91 CHKERR(0);
92
93 /* do some local work */
94 for (cl_ind=color2clique->numbers[i];
95 cl_ind<color2clique->numbers[i+1];cl_ind++) {
96 if (procinfo->my_id == clique2inode->proc[cl_ind]) {
97 ind = clique2inode->d_mats[cl_ind].local_ind;
98 /* multiply the inodes */
99 for (in_ind=clique2inode->inode_index[cl_ind];
100 in_ind<clique2inode->inode_index[cl_ind+1];in_ind++) {
101 row = inodes[in_ind].row_num;
102 nz = inodes[in_ind].nz;
103 size = inodes[in_ind].length;
104 num_cols = inodes[in_ind].num_cols;
105 if(symmetric) {
106 if (size > 0) {
107 #ifdef MY_BLAS_DGEMV_ON
108 if (num_cols > DGEMV_UNROLL_LVL) {
109 DGEMV(&NTR,&size,&num_cols,&one,nz,&size,&(x[ind]),
110 &ione,&zero,work,&ione);
111 for (k=0;k<size;k++) x[row[k]] -= work[k];
112 } else {
113 MY_DGEMVM1_N_1111(size,num_cols,nz,size,&(x[ind]),x,
114 row);
115 }
116 #else
117 DGEMV(&NTR,&size,&num_cols,&one,nz,&size,&(x[ind]),
118 &ione,&zero,work,&ione);
119 for (k=0;k<size;k++) x[row[k]] -= work[k];
120 #endif
121 }
122 } else {
123 length = inodes[in_ind].length;
124 /* The following part is added to make sure the */
125 /* nz are below pivot. (ILU) */
126 /*
127 for (j=0; j<length; j++) {
128 if (gnum[iperm[row[j]]] < inodes[in_ind].gcol_num) {
129 nz++; size--;
130 } else {
131 break;
132 }
133 }
134 if(size!=length-inodes[in_ind].below_diag) {
135 printf("FS, L: size = %d, size2 = %d\n",size,
136 length-inodes[in_ind].below_diag);
137 }
138 */
139 size -= inodes[in_ind].below_diag;
140 nz += inodes[in_ind].below_diag;
141 if (size > 0) {
142 #ifdef MY_BLAS_DGEMV_ON
143 if (num_cols > DGEMV_UNROLL_LVL) {
144 DGEMV(&NTR,&size,&num_cols,&one,nz,&length,&(x[ind]),
145 &ione,&zero,work,&ione);
146 for (k=0;k<size;k++) x[row[k+j]] -= work[k];
147 } else {
148 MY_DGEMVM1_N_1111(size,num_cols,nz,size,&(x[ind]),
149 x,row);
150 }
151 #else
152 DGEMV(&NTR,&size,&num_cols,&one,nz,&length,&(x[ind]),
153 &ione,&zero,work,&ione);
154 for (k=0;k<size;k++) x[row[k+j]] -= work[k];
155 #endif
156 }
157 }
158 ind += num_cols;
159 }
160 }
161 }
162
163 /* receive my messages and do non-local work */
164 from_phase = BMget_phase(comm->from_msg,i); CHKERR(0);
165 while ((msg = BMrecv_msg(from_phase)) != NULL) {
166 CHKERR(0);
167 msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERR(0);
168 data_ptr = BMget_user(msg,&msg_len); CHKERR(0);
169 count = 0;
170 for (cl_ind=data_ptr[0];cl_ind<=data_ptr[1];cl_ind++) {
171 for (in_ind=clique2inode->inode_index[cl_ind];
172 in_ind<clique2inode->inode_index[cl_ind+1];in_ind++) {
173 row = inodes[in_ind].row_num;
174 nz = inodes[in_ind].nz;
175 size = inodes[in_ind].length;
176 num_cols = inodes[in_ind].num_cols;
177 if(symmetric) {
178 if (size > 0) {
179 #ifdef MY_BLAS_DGEMV_ON
180 if (num_cols > DGEMV_UNROLL_LVL) {
181 DGEMV(&NTR,&size,&num_cols,&one,nz,&size,
182 &(msg_buf[count]),&ione,&zero,work,&ione);
183 for (k=0;k<size;k++) x[row[k]] -= work[k];
184 } else {
185 MY_DGEMVM1_N_1111(size,num_cols,nz,size,
186 &(msg_buf[count]),x,row);
187 }
188 #else
189 DGEMV(&NTR,&size,&num_cols,&one,nz,&size,
190 &(msg_buf[count]),&ione,&zero,work,&ione);
191 for (k=0;k<size;k++) x[row[k]] -= work[k];
192 #endif
193 }
194 } else {
195 length = inodes[in_ind].length;
196 /* The following part is added to make sure the */
197 /* nz are below pivot. (ILU) */
198 /*
199 for (j=0; j<length; j++) {
200 if (gnum[iperm[row[j]]] < inodes[in_ind].gcol_num) {
201 nz++; size--;
202 } else {
203 break;
204 }
205 }
206 if(size!=length-inodes[in_ind].below_diag) {
207 printf("FS, NL: size = %d, size2 = %d\n",size,
208 length-inodes[in_ind].below_diag);
209 }
210 */
211 size -= inodes[in_ind].below_diag;
212 nz += inodes[in_ind].below_diag;
213 if (size > 0) {
214 #ifdef MY_BLAS_DGEMV_ON
215 if (num_cols > DGEMV_UNROLL_LVL) {
216 DGEMV(&NTR,&size,&num_cols,&one,nz,&length,
217 &(msg_buf[count]),
218 &ione,&zero,work,&ione);
219 for (k=0;k<size;k++) x[row[k+j]] -= work[k];
220 } else {
221 MY_DGEMVM1_N_1111(size,num_cols,nz,size,
222 &(msg_buf[count]),x,row);
223 }
224 #else
225 DGEMV(&NTR,&size,&num_cols,&one,nz,&length,
226 &(msg_buf[count]),&ione,&zero,work,&ione);
227 for (k=0;k<size;k++) x[row[k+j]] -= work[k];
228 #endif
229 }
230 }
231 count += num_cols;
232 }
233 }
234 BMfree_msg(msg); CHKERR(0);
235 }
236 CHKERR(0);
237 }
238 MY_FREE(work);
239 /* wait for all of the sent messages to finish */
240 BMfinish_comp_msg(comm->to_msg,procinfo); CHKERR(0);
241 MLOG_flop((2*A->local_nnz));
242 }
243