1#include "BSprivate.h"
2
3/*@ BSback_solve - Backward triangular matrix solution on a
4                   single vector
5
6    Input Parameters:
7.   A - The sparse matrix
8.   x - The rhs
9.   comm - The communication structure for A
10.   procinfo - the usual processor information
11
12    Output Parameters:
13.   x - on exit contains the solution vector
14
15    Returns:
16    void
17
18 @*/
19void BSback_solve(BSpar_mat *A, FLOAT *x, BScomm *comm, BSprocinfo *procinfo)
20{
21	BMcomp_msg *from_msg, *to_msg;
22	BMphase *to_phase, *from_phase;
23	BMmsg *msg;
24	int	i, j, k;
25	int	cl_ind, in_ind;
26	int	count, size, ind, num_cols;
27	int *row;
28	FLOAT *nz;
29	BScl_2_inode *clique2inode = A->clique2inode;
30	BSnumbering *color2clique = A->color2clique;
31	BSinode *inodes = A->inodes->list;
32	int	*in_index = clique2inode->inode_index;
33	int	*proc = clique2inode->proc;
34	BSdense	*d_mats = clique2inode->d_mats;
35	int	*data_ptr, msg_len;
36	FLOAT *msg_buf, *matrix;
37	int	my_id = procinfo->my_id;
38	FLOAT *work;
39	char UP = 'U';
40	char TR = 'T';
41	char NTR = 'N';
42	char ND = 'N';
43	int	*col2cl = color2clique->numbers;
44	int	length = color2clique->length;
45	int	start, finish, symmetric;
46	int	ione = 1;
47	FLOAT one = 1.0;
48	FLOAT zero = 0.0;
49	FLOAT minus_one = -1.0;
50	FLOAT DDOT();
51	int *gnum = A->global_row_num->numbers;
52	int *iperm = A->inv_perm->perm;
53	FLOAT time0, gmev_time, trmv_time, other_time;
54
55	gmev_time = 0.0;
56	trmv_time = 0.0;
57	other_time = 0.0;
58	time0 = MPI_Wtime();
59
60	/* Is the symmetric data structure used? */
61	symmetric = A->icc_storage;
62
63	if(symmetric) {
64		from_msg = comm->to_msg; /* we do mean to switch these */
65		to_msg = comm->from_msg;
66	} else {
67		from_msg = comm->from_msg; /* do not switch for ILU case */
68		to_msg = comm->to_msg;
69	}
70
71	/* get some work space */
72	MY_MALLOC(work,(FLOAT *),sizeof(FLOAT)*A->num_rows,1);
73
74	/* post for all messages */
75	BMinit_comp_msg(from_msg,procinfo); CHKERR(0);
76
77	/* now do this phase by phase */
78	for (i=length-2;i>=0;i--) {
79		start = col2cl[i];
80		finish = col2cl[i+1];
81
82		if(!symmetric) {
83			/* invert the diagonals and find the answers */
84			for (cl_ind=start;cl_ind<finish;cl_ind++) {
85				if (my_id == proc[cl_ind]) {
86					size = clique2inode->d_mats[cl_ind].size;
87					ind = clique2inode->d_mats[cl_ind].local_ind;
88					matrix = clique2inode->d_mats[cl_ind].matrix;
89					/* can't do much better (likely) on this DGEMV */
90					DGEMV(&NTR,&size,&size,&one,matrix,&size,&(x[ind]),&ione,&zero,
91						work,&ione);
92					for (k=0; k<size; k++) x[ind+k] = work[k];
93				}
94			}
95		}
96
97		/* first send my messages */
98		/* this will involve computing partial sums */
99		to_phase = BMget_phase(to_msg,i); CHKERR(0);
100		msg = NULL;
101		while ((msg = BMnext_msg(to_phase,msg)) != NULL) {
102			CHKERR(0);
103			msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERR(0);
104			data_ptr = BMget_user(msg,&msg_len); CHKERR(0);
105			if(symmetric) {
106				count = 0;
107				for (cl_ind=data_ptr[0];cl_ind<=data_ptr[1];cl_ind++) {
108					for (in_ind=in_index[cl_ind];
109						in_ind<in_index[cl_ind+1];in_ind++) {
110						row = inodes[in_ind].row_num;
111						nz = inodes[in_ind].nz;
112						size = inodes[in_ind].length;
113						num_cols = inodes[in_ind].num_cols;
114						if (size > 0) {
115other_time += MPI_Wtime() - time0;
116time0 = MPI_Wtime();
117#ifdef MY_BLAS_DGEMV_ON
118							if (num_cols > DGEMV_UNROLL_LVL) {
119								for (k=0;k<size;k++) work[k] = x[row[k]];
120								DGEMV(&TR,&size,&num_cols,&one,nz,&size,
121									work,&ione,&zero,&(msg_buf[count]),&ione);
122							} else {
123								MY_DGEMV_Y_1101(size,num_cols,nz,size,x,row,
124									&(msg_buf[count]));
125							}
126#else
127							for (k=0;k<size;k++) work[k] = x[row[k]];
128							DGEMV(&TR,&size,&num_cols,&one,nz,&size,
129								work,&ione,&zero,&(msg_buf[count]),&ione);
130#endif
131gmev_time += MPI_Wtime() - time0;
132time0 = MPI_Wtime();
133						}
134						count += num_cols;
135					}
136				}
137			} else {
138				for (j=0; j<msg_len; j++)
139					msg_buf[j] = x[data_ptr[j]];
140			}
141			BMsendf_msg(msg,procinfo); CHKERR(0);
142		}
143		CHKERR(0);
144
145		/* do some local work, multiply by the i-nodes */
146		for (cl_ind=start;cl_ind<finish;cl_ind++) {
147			if (my_id == proc[cl_ind]) {
148				ind = d_mats[cl_ind].local_ind;
149				for (in_ind=in_index[cl_ind];
150					in_ind<in_index[cl_ind+1];in_ind++) {
151					size = inodes[in_ind].length;
152					num_cols = inodes[in_ind].num_cols;
153					row = inodes[in_ind].row_num;
154					nz = inodes[in_ind].nz;
155					if(symmetric) {
156						if (size > 0) {
157other_time += MPI_Wtime() - time0;
158time0 = MPI_Wtime();
159#ifdef MY_BLAS_DGEMV_ON
160							if (num_cols > DGEMV_UNROLL_LVL) {
161								for (k=0;k<size;k++) work[k] = x[row[k]];
162								DGEMV(&TR,&size,&num_cols,&minus_one,nz,&size,
163									work,&ione,&one,&(x[ind]),&ione);
164							} else {
165								MY_DGEMVM1_Y_1111(size,num_cols,nz,size,x,row,
166									&(x[ind]));
167							}
168#else
169							for (k=0;k<size;k++) work[k] = x[row[k]];
170							DGEMV(&TR,&size,&num_cols,&minus_one,nz,&size,
171								work,&ione,&one,&(x[ind]),&ione);
172#endif
173gmev_time += MPI_Wtime() - time0;
174time0 = MPI_Wtime();
175						}
176					} else {
177						/* The following part is added to make sure the nz are */
178						/* above pivot. (ILU)                                  */
179						length = size;
180						for (j=length-1; j>=0; j--) {
181						if (gnum[iperm[row[j]]] > inodes[in_ind].gcol_num)
182							size--;
183						else
184							break;
185						}
186						if (size > 0) {
187other_time += MPI_Wtime() - time0;
188time0 = MPI_Wtime();
189#ifdef MY_BLAS_DGEMV_ON
190							if (num_cols > DGEMV_UNROLL_LVL) {
191								DGEMV(&NTR,&size,&num_cols,&one,nz,&length,&(x[ind]),
192									&ione,&zero,work,&ione);
193								for (k=0;k<size;k++) x[row[k]] -= work[k];
194							} else {
195								MY_DGEMVM1_N_1111(size,num_cols,nz,size,&(x[ind]),
196									x,row);
197							}
198#else
199							DGEMV(&NTR,&size,&num_cols,&one,nz,&length,&(x[ind]),
200								&ione,&zero,work,&ione);
201							for (k=0;k<size;k++) x[row[k]] -= work[k];
202#endif
203gmev_time += MPI_Wtime() - time0;
204time0 = MPI_Wtime();
205						}
206					}
207					ind += num_cols;
208				}
209			}
210		}
211
212		/* receive my messages and update my rhs */
213		from_phase = BMget_phase(from_msg,i); CHKERR(0);
214		while ((msg = BMrecv_msg(from_phase)) != NULL) {
215			CHKERR(0);
216			msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERR(0);
217			data_ptr = BMget_user(msg,&msg_len); CHKERR(0);
218			if(symmetric) {
219				msg_len = BMget_msg_size(msg); CHKERR(0);
220				for (j=0;j<msg_len;j++) x[data_ptr[j]] -= msg_buf[j];
221			} else {
222				count = 0;
223				for (cl_ind=data_ptr[0]; cl_ind<=data_ptr[1]; cl_ind++) {
224					for (in_ind=clique2inode->inode_index[cl_ind];
225					in_ind<clique2inode->inode_index[cl_ind+1]; in_ind++) {
226						row = inodes[in_ind].row_num;
227						nz = inodes[in_ind].nz;
228						size = inodes[in_ind].length;
229						length = inodes[in_ind].length;
230						num_cols = inodes[in_ind].num_cols;
231
232						for (j=length-1; j>=0; j--) {
233							if (gnum[iperm[row[j]]] > inodes[in_ind].gcol_num)
234								size--;
235							else
236								break;
237						}
238						if (size > 0) {
239other_time += MPI_Wtime() - time0;
240time0 = MPI_Wtime();
241#ifdef MY_BLAS_DGEMV_ON
242							if (num_cols > DGEMV_UNROLL_LVL) {
243								DGEMV(&NTR,&size,&num_cols,&one,nz,&length,&(msg_buf[count]),
244									&ione,&zero,work,&ione);
245								for (k=0;k<size;k++) x[row[k]] -= work[k];
246							} else {
247								MY_DGEMVM1_N_1111(size,num_cols,nz,length,
248									&(msg_buf[count]),x,row);
249							}
250#else
251							DGEMV(&NTR,&size,&num_cols,&one,nz,&length,&(msg_buf[count]),
252								&ione,&zero,work,&ione);
253							for (k=0;k<size;k++) x[row[k]] -= work[k];
254#endif
255gmev_time += MPI_Wtime() - time0;
256time0 = MPI_Wtime();
257						}
258						count += num_cols;
259					}
260				}
261			}
262			BMfree_msg(msg); CHKERR(0);
263		}
264		CHKERR(0);
265
266		if(symmetric) {
267			/* invert the diagonals and find the answers */
268			for (cl_ind=start;cl_ind<finish;cl_ind++) {
269				if (my_id == proc[cl_ind]) {
270					/* first, multiply the clique */
271					/* only do the strictly upper triangular part */
272					/* we ASSUME the diagonal is all 1's */
273					size = clique2inode->d_mats[cl_ind].size;
274other_time += MPI_Wtime() - time0;
275time0 = MPI_Wtime();
276#ifdef MY_BLAS_DTRMV_ON
277					MY_DTRMV_N_U(size,d_mats[cl_ind].matrix,size,
278						&(x[d_mats[cl_ind].local_ind]),work);
279#else
280					DTRMV(&UP,&NTR,&ND,&size,d_mats[cl_ind].matrix,&size,
281						&(x[d_mats[cl_ind].local_ind]),&ione);
282#endif
283trmv_time += MPI_Wtime() - time0;
284time0 = MPI_Wtime();
285				}
286			}
287		}
288
289	}
290	MY_FREE(work);
291	/* wait for all of the sent messages to finish */
292	BMfinish_comp_msg(to_msg,procinfo); CHKERR(0);
293	MLOG_flop((2*A->local_nnz));
294other_time += MPI_Wtime() - time0;
295time0 = MPI_Wtime();
296printf("BSb_solve: other = %e, gmev = %e, trmv = %e\n",
297	other_time,gmev_time,trmv_time);
298}
299