1 #include "BSprivate.h"
2 
3 /*@ BSb_forward - Forward triangular matrix multiplication on a
4                   block of vectors
5 
6     Input Parameters:
7 .   A - The sparse matrix
8 .   x - The contiguous block of input vectors
9 .   comm - The communication structure for A
10 .   block_size - the number of input vectors
11 .   procinfo - the usual processor information
12 
13     Output Parameters:
14 .   b - on exit these vectors contain A*x
15 
16     Returns:
17     void
18 
19  @*/
BSb_forward(BSpar_mat * A,FLOAT * x,FLOAT * b,BScomm * comm,int block_size,BSprocinfo * procinfo)20 void BSb_forward(BSpar_mat *A, FLOAT *x, FLOAT *b, BScomm *comm,
21 		int block_size, BSprocinfo *procinfo)
22 {
23 	BMphase *to_phase, *from_phase;
24 	BMmsg *msg;
25 	int	i, j, k, n;
26 	int	cl_ind, in_ind;
27 	int	count, size, ind, num_cols;
28 	int *row;
29 	FLOAT *nz;
30 	BScl_2_inode *clique2inode;
31 	BSnumbering *color2clique;
32 	BSinode *inodes;
33 	int	*data_ptr, msg_len;
34 	FLOAT *msg_buf, *matrix;
35 	FLOAT *work;
36 	FLOAT *bptr, *xptr, *wptr;
37 	FLOAT **boff, **xoff;
38 	char UP = 'L';
39 	char TR = 'N';
40 	char ND = 'N';
41 	char SIDE = 'L';
42 	int	ione = 1;
43 	FLOAT one = 1.0;
44 	FLOAT zero = 0.0;
45 
46 	if((!A->icc_storage)||(procinfo->single)) {
47 		/* No ILU version or single version so call BSforward BS times */
48 		n = A->num_rows;
49 		for (i=0;i<block_size;i++) {
50 			if(procinfo->single) {
51 				BSforward1(A,&(x[n*i]),&(b[n*i]),comm,procinfo); CHKERR(0);
52 			} else {
53 				BSforward(A,&(x[n*i]),&(b[n*i]),comm,procinfo); CHKERR(0);
54 			}
55 		}
56 		return;
57 	}
58 
59 	color2clique = A->color2clique;
60 	clique2inode = A->clique2inode;
61 	inodes = A->inodes->list;
62 
63 	/* get some work space */
64 	MY_MALLOC(work,(FLOAT *),sizeof(FLOAT)*A->num_rows*block_size,1);
65 
66 	/* calculate b and x offsets */
67 	MY_MALLOC(boff,(FLOAT **),sizeof(FLOAT *)*block_size,1);
68 	MY_MALLOC(xoff,(FLOAT **),sizeof(FLOAT *)*block_size,1);
69 	for (i=0;i<block_size;i++) {
70 		boff[i] = &(b[i*A->num_rows]);
71 		xoff[i] = &(x[i*A->num_rows]);
72 	}
73 
74 	/* post for all messages */
75 	BMinit_comp_msg(comm->from_msg,procinfo); CHKERR(0);
76 
77 	if (A->save_diag == NULL) {
78 		/* because we know the diagonal is ones, initialize b to x */
79 		for (i=0;i<block_size;i++) {
80 			bptr = boff[i];
81 			xptr = xoff[i];
82 			for (j=0;j<A->num_rows;j++) bptr[j] = xptr[j];
83 		}
84 	} else {
85 		for (i=0;i<block_size;i++) {
86 			bptr = boff[i];
87 			xptr = xoff[i];
88 			for (j=0;j<A->num_rows;j++) bptr[j] = A->save_diag[j]*xptr[j];
89 		}
90 	}
91 
92 	/* now do this phase by phase */
93 	for (i=0;i<color2clique->length-1;i++) {
94 		/* first send my messages */
95 		to_phase = BMget_phase(comm->to_msg,i); CHKERR(0);
96 		msg = NULL;
97 		while ((msg = BMnext_msg(to_phase,msg)) != NULL) {
98 			CHKERR(0);
99 			msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERR(0);
100 			data_ptr = BMget_user(msg,&msg_len); CHKERR(0);
101 			count = 0;
102 			for (j=0;j<block_size;j++) {
103 				wptr = &(msg_buf[j*msg_len]);
104 				xptr = xoff[j];
105 				for (k=0;k<msg_len;k++) {
106 					wptr[k] = xptr[data_ptr[k]];
107 				}
108 			}
109 			BMsendf_msg(msg,procinfo); CHKERR(0);
110 		}
111 		CHKERR(0);
112 	}
113 
114 	/* do some local work */
115 	for (i=0;i<color2clique->length-1;i++) {
116 		for (cl_ind=color2clique->numbers[i];
117 			cl_ind<color2clique->numbers[i+1];cl_ind++) {
118 			if (procinfo->my_id == clique2inode->proc[cl_ind]) {
119 				/* first, multiply the clique */
120 				/* only do the strictly lower triangular part */
121 				/* we ASSUME the diagonal is all 1's */
122 				size = clique2inode->d_mats[cl_ind].size;
123 				ind = clique2inode->d_mats[cl_ind].local_ind;
124 				matrix = clique2inode->d_mats[cl_ind].matrix;
125 				j = size-1;
126 				matrix++;
127 				if (size > 1) {
128 					nz = work;
129 					for (k=0;k<block_size;k++) {
130 						DCOPY(&j,&(xoff[k][ind]),&ione,nz,&ione);
131 						nz += j;
132 					}
133 					DTRMM(&SIDE,&UP,&TR,&ND,&j,&block_size,&one,matrix,&size,
134 						work,&j);
135 					nz = work;
136 					for (k=0;k<block_size;k++) {
137 						DAXPY(&j,&one,nz,&ione,&(boff[k][ind+1]),&ione);
138 						nz += j;
139 					}
140 				}
141 
142 				/* now, multiply the inodes */
143 				for (in_ind=clique2inode->inode_index[cl_ind];
144 					in_ind<clique2inode->inode_index[cl_ind+1];in_ind++) {
145 					row = inodes[in_ind].row_num;
146 					nz = inodes[in_ind].nz;
147 					size = inodes[in_ind].length;
148 					num_cols = inodes[in_ind].num_cols;
149 					if (size > 0) {
150 						DGEMM(&TR,&TR,&size,&block_size,&num_cols,&one,
151 							nz,&size,&(x[ind]),&(A->num_rows),&zero,work,&size);
152 						for (j=0;j<block_size;j++) {
153 							bptr = boff[j];
154 							wptr = &(work[j*size]);
155 							for (k=0;k<size;k++) {
156 								bptr[row[k]] += wptr[k];
157 							}
158 						}
159 					}
160 					ind += num_cols;
161 				}
162 			}
163 		}
164 	}
165 
166 	/* receive my messages and do non-local work */
167 	for (i=0;i<color2clique->length-1;i++) {
168 		from_phase = BMget_phase(comm->from_msg,i); CHKERR(0);
169 		while ((msg = BMrecv_msg(from_phase)) != NULL) {
170 			CHKERR(0);
171 			msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERR(0);
172 			data_ptr = BMget_user(msg,&msg_len); CHKERR(0);
173 			msg_len = BMget_msg_size(msg); CHKERR(0);
174 			msg_len /= block_size;
175 			count = 0;
176 			for (cl_ind=data_ptr[0];cl_ind<=data_ptr[1];cl_ind++) {
177 				for (in_ind=clique2inode->inode_index[cl_ind];
178 					in_ind<clique2inode->inode_index[cl_ind+1];in_ind++) {
179 					row = inodes[in_ind].row_num;
180 					nz = inodes[in_ind].nz;
181 					size = inodes[in_ind].length;
182 					num_cols = inodes[in_ind].num_cols;
183 					if (size > 0) {
184 						DGEMM(&TR,&TR,&size,&block_size,&num_cols,&one,
185 							nz,&size,&(msg_buf[count]),&msg_len,
186 							&zero,work,&size);
187 						for (j=0;j<block_size;j++) {
188 							bptr = boff[j];
189 							wptr = &(work[j*size]);
190 							for (k=0;k<size;k++) {
191 								bptr[row[k]] += wptr[k];
192 							}
193 						}
194 					}
195 					count += num_cols;
196 				}
197 			}
198 			BMfree_msg(msg); CHKERR(0);
199 		}
200 		CHKERR(0);
201 	}
202 
203 	MY_FREE(xoff);
204 	MY_FREE(boff);
205 	MY_FREE(work);
206 
207 	/* wait for all of the sent messages to finish */
208 	BMfinish_comp_msg(comm->to_msg,procinfo); CHKERR(0);
209 	MLOG_flop((2*block_size*A->local_nnz));
210 }
211