1 #include "BSprivate.h"
2 
3 /*@ BSforward - Forward triangular matrix multiplication on a single vector
4 
5     Input Parameters:
6 .   A - The sparse matrix
7 .   x - The rhs
8 .   comm - The communication structure for A
9 .   procinfo - the usual processor information
10 
11     Output Parameters:
12 .   b - on exit contains A*x
13 
14     Returns:
15     void
16 
17  @*/
BSforward(BSpar_mat * A,FLOAT * x,FLOAT * b,BScomm * comm,BSprocinfo * procinfo)18 void BSforward(BSpar_mat *A, FLOAT *x, FLOAT *b,
19 		BScomm *comm, BSprocinfo *procinfo)
20 {
21 	BMphase *to_phase, *from_phase;
22 	BMmsg *msg;
23 	int	i, j, k;
24 	int	cl_ind, in_ind, symmetric;
25 	int	count, size, ind, num_cols;
26 	int *row;
27 	FLOAT *nz;
28 	BScl_2_inode *clique2inode;
29 	BSnumbering *color2clique;
30 	BSinode *inodes;
31 	int	*data_ptr, msg_len;
32 	FLOAT *msg_buf, *matrix;
33 	FLOAT *work;
34 	char UP = 'L';
35 	char TR = 'N';
36 	char ND = 'N';
37 	int	ione = 1;
38 	FLOAT one = 1.0;
39 	FLOAT zero = 0.0;
40 
41 	/* Is the symmetric data structure used? */
42 	symmetric = A->icc_storage;
43 
44 	color2clique = A->color2clique;
45 	clique2inode = A->clique2inode;
46 	inodes = A->inodes->list;
47 
48 	/* get some work space */
49 	MY_MALLOC(work,(FLOAT *),sizeof(FLOAT)*A->num_rows,1);
50 
51 	/* post for all messages */
52 	BMinit_comp_msg(comm->from_msg,procinfo); CHKERR(0);
53 
54 	if(symmetric) {
55 		if (A->save_diag == NULL) {
56 			/* because we know the diagonal is ones, initialize b to x */
57 			for (i=0;i<A->num_rows;i++) b[i] = x[i];
58 		} else {
59 			for (i=0;i<A->num_rows;i++) b[i] = A->save_diag[i]*x[i];
60 		}
61 	} else {
62 		/* The following part is modified (ILU) for nonsymmetric version */
63 		for (i=0;i<color2clique->length-1;i++) {
64 			for (cl_ind=color2clique->numbers[i];
65 				cl_ind<color2clique->numbers[i+1];cl_ind++) {
66 				if (procinfo->my_id == clique2inode->proc[cl_ind]) {
67 					size = clique2inode->d_mats[cl_ind].size;
68 					ind = clique2inode->d_mats[cl_ind].local_ind;
69 					matrix = clique2inode->d_mats[cl_ind].matrix;
70 
71 					/* Much simplied in the case of ILU */
72 					if (size > 1) {
73 						DGEMV(&TR,&size,&size,&one,matrix,&size,
74 						&(x[ind]),&ione,&zero,&(b[ind]),&ione);
75 					} else if (size == 1) {
76 						if (A->scale_diag == NULL)
77 							b[ind] = A->diag[ind]*x[ind];
78 						else
79 							if (A->diag[ind] > 0.0)
80 								b[ind] = x[ind];
81 							else if (A->diag[ind] < 0.0)
82 								b[ind] = -x[ind];
83 							else
84 								b[ind] = 0.0;
85 					}
86 				}
87 			}
88 		}
89 	}
90 
91 	/* first send my messages */
92 	for (i=0;i<color2clique->length-1;i++) {
93 		to_phase = BMget_phase(comm->to_msg,i); CHKERR(0);
94 		msg = NULL;
95 		while ((msg = BMnext_msg(to_phase,msg)) != NULL) {
96 			CHKERR(0);
97 			msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERR(0);
98 			data_ptr = BMget_user(msg,&msg_len); CHKERR(0);
99 			for (j=0;j<msg_len;j++) {
100 				msg_buf[j] = x[data_ptr[j]];
101 			}
102 			BMsendf_msg(msg,procinfo); CHKERR(0);
103 		}
104 		CHKERR(0);
105 	}
106 
107 	/* do some local work */
108 	for (i=0;i<color2clique->length-1;i++) {
109 		for (cl_ind=color2clique->numbers[i];
110 				cl_ind<color2clique->numbers[i+1];cl_ind++) {
111 			ind = clique2inode->d_mats[cl_ind].local_ind;
112 			if (procinfo->my_id == clique2inode->proc[cl_ind]) {
113 				if(symmetric) {
114 					/* first, multiply the clique */
115 					/* only do the strictly lower triangular part */
116 					/* we ASSUME the diagonal is all 1's */
117 					size = clique2inode->d_mats[cl_ind].size;
118 					matrix = clique2inode->d_mats[cl_ind].matrix;
119 					if (size > 1) {
120 						j = size-1;
121 						matrix++;
122 #ifdef MY_BLAS_DTRMV_ON
123 						MY_DTRMV_N_L(j,matrix,size,&(x[ind]),&(b[ind+1]));
124 #else
125 						DCOPY(&j,&(x[ind]),&ione,work,&ione);
126 						DTRMV(&UP,&TR,&ND,&j,matrix,&size,work,&ione);
127 						DAXPY(&j,&one,work,&ione,&(b[ind+1]),&ione);
128 #endif
129 					}
130 				}
131 
132 				/* now, multiply the inodes */
133 				for (in_ind=clique2inode->inode_index[cl_ind];
134 					in_ind<clique2inode->inode_index[cl_ind+1];in_ind++) {
135 					row = inodes[in_ind].row_num;
136 					nz = inodes[in_ind].nz;
137 					size = inodes[in_ind].length;
138 					num_cols = inodes[in_ind].num_cols;
139 					if (size > 0) {
140 #ifdef MY_BLAS_DGEMV_ON
141 						if (num_cols > DGEMV_UNROLL_LVL) {
142 							DGEMV(&TR,&size,&num_cols,&one,
143 								nz,&size,&(x[ind]),&ione,&zero,work,&ione);
144 							for (k=0;k<size;k++) b[row[k]] += work[k];
145 						} else {
146 							MY_DGEMV_N_1111(size,num_cols,nz,size,&(x[ind]),b,row);
147 						}
148 #else
149 						DGEMV(&TR,&size,&num_cols,&one,
150 							nz,&size,&(x[ind]),&ione,&zero,work,&ione);
151 						for (k=0;k<size;k++) b[row[k]] += work[k];
152 #endif
153 					}
154 					ind += num_cols;
155 				}
156 			}
157 		}
158 	}
159 
160 	/* receive my messages and do non-local work */
161 	for (i=0;i<color2clique->length-1;i++) {
162 		from_phase = BMget_phase(comm->from_msg,i); CHKERR(0);
163 		while ((msg = BMrecv_msg(from_phase)) != NULL) {
164 			CHKERR(0);
165 			msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERR(0);
166 			data_ptr = BMget_user(msg,&msg_len); CHKERR(0);
167 			count = 0;
168 			for (cl_ind=data_ptr[0];cl_ind<=data_ptr[1];cl_ind++) {
169 				for (in_ind=clique2inode->inode_index[cl_ind];
170 					in_ind<clique2inode->inode_index[cl_ind+1];in_ind++) {
171 					row = inodes[in_ind].row_num;
172 					nz = inodes[in_ind].nz;
173 					size = inodes[in_ind].length;
174 					num_cols = inodes[in_ind].num_cols;
175 					if (size > 0) {
176 #ifdef MY_BLAS_DGEMV_ON
177 						if (num_cols > DGEMV_UNROLL_LVL) {
178 							DGEMV(&TR,&size,&num_cols,&one,
179 								nz,&size,&(msg_buf[count]),&ione,&zero,work,&ione);
180 							for (k=0;k<size;k++) b[row[k]] += work[k];
181 						} else {
182 							MY_DGEMV_N_1111(size,num_cols,nz,size,&(msg_buf[count]),
183 								b,row);
184 						}
185 #else
186 						DGEMV(&TR,&size,&num_cols,&one,
187 							nz,&size,&(msg_buf[count]),&ione,&zero,work,&ione);
188 						for (k=0;k<size;k++) b[row[k]] += work[k];
189 #endif
190 					}
191 					count += num_cols;
192 				}
193 			}
194 			BMfree_msg(msg); CHKERR(0);
195 		}
196 		CHKERR(0);
197 	}
198 	MY_FREE(work);
199 	/* wait for all of the sent messages to finish */
200 	BMfinish_comp_msg(comm->to_msg,procinfo); CHKERR(0);
201 	if(symmetric) {
202 		MLOG_flop((2*A->local_nnz));
203 	} else {
204 		MLOG_flop((4*A->local_nnz-2*A->num_rows));
205 	}
206 }
207