1#include "BSprivate.h" 2 3/*@ BSback_solve - Backward triangular matrix solution on a 4 single vector 5 6 Input Parameters: 7. A - The sparse matrix 8. x - The rhs 9. comm - The communication structure for A 10. procinfo - the usual processor information 11 12 Output Parameters: 13. x - on exit contains the solution vector 14 15 Returns: 16 void 17 18 @*/ 19void BSback_solve(BSpar_mat *A, FLOAT *x, BScomm *comm, BSprocinfo *procinfo) 20{ 21 BMcomp_msg *from_msg, *to_msg; 22 BMphase *to_phase, *from_phase; 23 BMmsg *msg; 24 int i, j, k; 25 int cl_ind, in_ind; 26 int count, size, ind, num_cols; 27 int *row; 28 FLOAT *nz; 29 BScl_2_inode *clique2inode = A->clique2inode; 30 BSnumbering *color2clique = A->color2clique; 31 BSinode *inodes = A->inodes->list; 32 int *in_index = clique2inode->inode_index; 33 int *proc = clique2inode->proc; 34 BSdense *d_mats = clique2inode->d_mats; 35 int *data_ptr, msg_len; 36 FLOAT *msg_buf, *matrix; 37 int my_id = procinfo->my_id; 38 FLOAT *work; 39 char UP = 'U'; 40 char TR = 'T'; 41 char NTR = 'N'; 42 char ND = 'N'; 43 int *col2cl = color2clique->numbers; 44 int length = color2clique->length; 45 int start, finish, symmetric; 46 int ione = 1; 47 FLOAT one = 1.0; 48 FLOAT zero = 0.0; 49 FLOAT minus_one = -1.0; 50 FLOAT DDOT(); 51 int *gnum = A->global_row_num->numbers; 52 int *iperm = A->inv_perm->perm; 53 FLOAT time0, gmev_time, trmv_time, other_time; 54 55 gmev_time = 0.0; 56 trmv_time = 0.0; 57 other_time = 0.0; 58 time0 = MPI_Wtime(); 59 60 /* Is the symmetric data structure used? */ 61 symmetric = A->icc_storage; 62 63 if(symmetric) { 64 from_msg = comm->to_msg; /* we do mean to switch these */ 65 to_msg = comm->from_msg; 66 } else { 67 from_msg = comm->from_msg; /* do not switch for ILU case */ 68 to_msg = comm->to_msg; 69 } 70 71 /* get some work space */ 72 MY_MALLOC(work,(FLOAT *),sizeof(FLOAT)*A->num_rows,1); 73 74 /* post for all messages */ 75 BMinit_comp_msg(from_msg,procinfo); CHKERR(0); 76 77 /* now do this phase by phase */ 78 for (i=length-2;i>=0;i--) { 79 start = col2cl[i]; 80 finish = col2cl[i+1]; 81 82 if(!symmetric) { 83 /* invert the diagonals and find the answers */ 84 for (cl_ind=start;cl_ind<finish;cl_ind++) { 85 if (my_id == proc[cl_ind]) { 86 size = clique2inode->d_mats[cl_ind].size; 87 ind = clique2inode->d_mats[cl_ind].local_ind; 88 matrix = clique2inode->d_mats[cl_ind].matrix; 89 /* can't do much better (likely) on this DGEMV */ 90 DGEMV(&NTR,&size,&size,&one,matrix,&size,&(x[ind]),&ione,&zero, 91 work,&ione); 92 for (k=0; k<size; k++) x[ind+k] = work[k]; 93 } 94 } 95 } 96 97 /* first send my messages */ 98 /* this will involve computing partial sums */ 99 to_phase = BMget_phase(to_msg,i); CHKERR(0); 100 msg = NULL; 101 while ((msg = BMnext_msg(to_phase,msg)) != NULL) { 102 CHKERR(0); 103 msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERR(0); 104 data_ptr = BMget_user(msg,&msg_len); CHKERR(0); 105 if(symmetric) { 106 count = 0; 107 for (cl_ind=data_ptr[0];cl_ind<=data_ptr[1];cl_ind++) { 108 for (in_ind=in_index[cl_ind]; 109 in_ind<in_index[cl_ind+1];in_ind++) { 110 row = inodes[in_ind].row_num; 111 nz = inodes[in_ind].nz; 112 size = inodes[in_ind].length; 113 num_cols = inodes[in_ind].num_cols; 114 if (size > 0) { 115other_time += MPI_Wtime() - time0; 116time0 = MPI_Wtime(); 117#ifdef MY_BLAS_DGEMV_ON 118 if (num_cols > DGEMV_UNROLL_LVL) { 119 for (k=0;k<size;k++) work[k] = x[row[k]]; 120 DGEMV(&TR,&size,&num_cols,&one,nz,&size, 121 work,&ione,&zero,&(msg_buf[count]),&ione); 122 } else { 123 MY_DGEMV_Y_1101(size,num_cols,nz,size,x,row, 124 &(msg_buf[count])); 125 } 126#else 127 for (k=0;k<size;k++) work[k] = x[row[k]]; 128 DGEMV(&TR,&size,&num_cols,&one,nz,&size, 129 work,&ione,&zero,&(msg_buf[count]),&ione); 130#endif 131gmev_time += MPI_Wtime() - time0; 132time0 = MPI_Wtime(); 133 } 134 count += num_cols; 135 } 136 } 137 } else { 138 for (j=0; j<msg_len; j++) 139 msg_buf[j] = x[data_ptr[j]]; 140 } 141 BMsendf_msg(msg,procinfo); CHKERR(0); 142 } 143 CHKERR(0); 144 145 /* do some local work, multiply by the i-nodes */ 146 for (cl_ind=start;cl_ind<finish;cl_ind++) { 147 if (my_id == proc[cl_ind]) { 148 ind = d_mats[cl_ind].local_ind; 149 for (in_ind=in_index[cl_ind]; 150 in_ind<in_index[cl_ind+1];in_ind++) { 151 size = inodes[in_ind].length; 152 num_cols = inodes[in_ind].num_cols; 153 row = inodes[in_ind].row_num; 154 nz = inodes[in_ind].nz; 155 if(symmetric) { 156 if (size > 0) { 157other_time += MPI_Wtime() - time0; 158time0 = MPI_Wtime(); 159#ifdef MY_BLAS_DGEMV_ON 160 if (num_cols > DGEMV_UNROLL_LVL) { 161 for (k=0;k<size;k++) work[k] = x[row[k]]; 162 DGEMV(&TR,&size,&num_cols,&minus_one,nz,&size, 163 work,&ione,&one,&(x[ind]),&ione); 164 } else { 165 MY_DGEMVM1_Y_1111(size,num_cols,nz,size,x,row, 166 &(x[ind])); 167 } 168#else 169 for (k=0;k<size;k++) work[k] = x[row[k]]; 170 DGEMV(&TR,&size,&num_cols,&minus_one,nz,&size, 171 work,&ione,&one,&(x[ind]),&ione); 172#endif 173gmev_time += MPI_Wtime() - time0; 174time0 = MPI_Wtime(); 175 } 176 } else { 177 /* The following part is added to make sure the nz are */ 178 /* above pivot. (ILU) */ 179 length = size; 180 for (j=length-1; j>=0; j--) { 181 if (gnum[iperm[row[j]]] > inodes[in_ind].gcol_num) 182 size--; 183 else 184 break; 185 } 186 if (size > 0) { 187other_time += MPI_Wtime() - time0; 188time0 = MPI_Wtime(); 189#ifdef MY_BLAS_DGEMV_ON 190 if (num_cols > DGEMV_UNROLL_LVL) { 191 DGEMV(&NTR,&size,&num_cols,&one,nz,&length,&(x[ind]), 192 &ione,&zero,work,&ione); 193 for (k=0;k<size;k++) x[row[k]] -= work[k]; 194 } else { 195 MY_DGEMVM1_N_1111(size,num_cols,nz,size,&(x[ind]), 196 x,row); 197 } 198#else 199 DGEMV(&NTR,&size,&num_cols,&one,nz,&length,&(x[ind]), 200 &ione,&zero,work,&ione); 201 for (k=0;k<size;k++) x[row[k]] -= work[k]; 202#endif 203gmev_time += MPI_Wtime() - time0; 204time0 = MPI_Wtime(); 205 } 206 } 207 ind += num_cols; 208 } 209 } 210 } 211 212 /* receive my messages and update my rhs */ 213 from_phase = BMget_phase(from_msg,i); CHKERR(0); 214 while ((msg = BMrecv_msg(from_phase)) != NULL) { 215 CHKERR(0); 216 msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERR(0); 217 data_ptr = BMget_user(msg,&msg_len); CHKERR(0); 218 if(symmetric) { 219 msg_len = BMget_msg_size(msg); CHKERR(0); 220 for (j=0;j<msg_len;j++) x[data_ptr[j]] -= msg_buf[j]; 221 } else { 222 count = 0; 223 for (cl_ind=data_ptr[0]; cl_ind<=data_ptr[1]; cl_ind++) { 224 for (in_ind=clique2inode->inode_index[cl_ind]; 225 in_ind<clique2inode->inode_index[cl_ind+1]; in_ind++) { 226 row = inodes[in_ind].row_num; 227 nz = inodes[in_ind].nz; 228 size = inodes[in_ind].length; 229 length = inodes[in_ind].length; 230 num_cols = inodes[in_ind].num_cols; 231 232 for (j=length-1; j>=0; j--) { 233 if (gnum[iperm[row[j]]] > inodes[in_ind].gcol_num) 234 size--; 235 else 236 break; 237 } 238 if (size > 0) { 239other_time += MPI_Wtime() - time0; 240time0 = MPI_Wtime(); 241#ifdef MY_BLAS_DGEMV_ON 242 if (num_cols > DGEMV_UNROLL_LVL) { 243 DGEMV(&NTR,&size,&num_cols,&one,nz,&length,&(msg_buf[count]), 244 &ione,&zero,work,&ione); 245 for (k=0;k<size;k++) x[row[k]] -= work[k]; 246 } else { 247 MY_DGEMVM1_N_1111(size,num_cols,nz,length, 248 &(msg_buf[count]),x,row); 249 } 250#else 251 DGEMV(&NTR,&size,&num_cols,&one,nz,&length,&(msg_buf[count]), 252 &ione,&zero,work,&ione); 253 for (k=0;k<size;k++) x[row[k]] -= work[k]; 254#endif 255gmev_time += MPI_Wtime() - time0; 256time0 = MPI_Wtime(); 257 } 258 count += num_cols; 259 } 260 } 261 } 262 BMfree_msg(msg); CHKERR(0); 263 } 264 CHKERR(0); 265 266 if(symmetric) { 267 /* invert the diagonals and find the answers */ 268 for (cl_ind=start;cl_ind<finish;cl_ind++) { 269 if (my_id == proc[cl_ind]) { 270 /* first, multiply the clique */ 271 /* only do the strictly upper triangular part */ 272 /* we ASSUME the diagonal is all 1's */ 273 size = clique2inode->d_mats[cl_ind].size; 274other_time += MPI_Wtime() - time0; 275time0 = MPI_Wtime(); 276#ifdef MY_BLAS_DTRMV_ON 277 MY_DTRMV_N_U(size,d_mats[cl_ind].matrix,size, 278 &(x[d_mats[cl_ind].local_ind]),work); 279#else 280 DTRMV(&UP,&NTR,&ND,&size,d_mats[cl_ind].matrix,&size, 281 &(x[d_mats[cl_ind].local_ind]),&ione); 282#endif 283trmv_time += MPI_Wtime() - time0; 284time0 = MPI_Wtime(); 285 } 286 } 287 } 288 289 } 290 MY_FREE(work); 291 /* wait for all of the sent messages to finish */ 292 BMfinish_comp_msg(to_msg,procinfo); CHKERR(0); 293 MLOG_flop((2*A->local_nnz)); 294other_time += MPI_Wtime() - time0; 295time0 = MPI_Wtime(); 296printf("BSb_solve: other = %e, gmev = %e, trmv = %e\n", 297 other_time,gmev_time,trmv_time); 298} 299