1 /* ========================================================================== */
2 /* === UMF_utsolve ========================================================== */
3 /* ========================================================================== */
4 
5 /* -------------------------------------------------------------------------- */
6 /* UMFPACK Copyright (c) Timothy A. Davis, CISE,                              */
7 /* Univ. of Florida.  All Rights Reserved.  See ../Doc/License for License.   */
8 /* web: http://www.cise.ufl.edu/research/sparse/umfpack                       */
9 /* -------------------------------------------------------------------------- */
10 
11 /*  solves U'x = b or U.'x=b, where U is the upper triangular factor of a */
12 /*  matrix.  B is overwritten with the solution X. */
13 /*  Returns the floating point operation count */
14 
15 #include "umf_internal.h"
16 #include "umf_utsolve.h"
17 
18 GLOBAL double
19 #ifdef CONJUGATE_SOLVE
UMF_uhsolve(NumericType * Numeric,Entry X[],Int Pattern[])20 UMF_uhsolve			/* solve U'x=b  (complex conjugate transpose) */
21 #else
22 UMF_utsolve			/* solve U.'x=b (array transpose) */
23 #endif
24 (
25     NumericType *Numeric,
26     Entry X [ ],		/* b on input, solution x on output */
27     Int Pattern [ ]		/* a work array of size n */
28 )
29 {
30     /* ---------------------------------------------------------------------- */
31     /* local variables */
32     /* ---------------------------------------------------------------------- */
33 
34     Entry xk ;
35     Entry *xp, *D, *Uval ;
36     Int k, deg, j, *ip, col, *Upos, *Uilen, kstart, kend, up,
37 	*Uip, n, uhead, ulen, pos, npiv, n1, *Ui ;
38 
39     /* ---------------------------------------------------------------------- */
40     /* get parameters */
41     /* ---------------------------------------------------------------------- */
42 
43     if (Numeric->n_row != Numeric->n_col) return (0.) ;
44     n = Numeric->n_row ;
45     npiv = Numeric->npiv ;
46     Upos = Numeric->Upos ;
47     Uilen = Numeric->Uilen ;
48     Uip = Numeric->Uip ;
49     D = Numeric->D ;
50     kend = 0 ;
51     n1 = Numeric->n1 ;
52 
53 #ifndef NDEBUG
54     DEBUG4 (("Utsolve start: npiv "ID" n "ID"\n", npiv, n)) ;
55     for (j = 0 ; j < n ; j++)
56     {
57 	DEBUG4 (("Utsolve start "ID": ", j)) ;
58 	EDEBUG4 (X [j]) ;
59 	DEBUG4 (("\n")) ;
60     }
61 #endif
62 
63     /* ---------------------------------------------------------------------- */
64     /* singletons */
65     /* ---------------------------------------------------------------------- */
66 
67     for (k = 0 ; k < n1 ; k++)
68     {
69 	DEBUG4 (("Singleton k "ID"\n", k)) ;
70 
71 #ifndef NO_DIVIDE_BY_ZERO
72 	/* Go ahead and divide by zero if D [k] is zero. */
73 #ifdef CONJUGATE_SOLVE
74 	/* xk = X [k] / conjugate (D [k]) ; */
75 	DIV_CONJ (xk, X [k], D [k]) ;
76 #else
77 	/* xk = X [k] / D [k] ; */
78 	DIV (xk, X [k], D [k]) ;
79 #endif
80 #else
81 	/* Do not divide by zero */
82 	if (IS_NONZERO (D [k]))
83 	{
84 #ifdef CONJUGATE_SOLVE
85 	    /* xk = X [k] / conjugate (D [k]) ; */
86 	    DIV_CONJ (xk, X [k], D [k]) ;
87 #else
88 	    /* xk = X [k] / D [k] ; */
89 	    DIV (xk, X [k], D [k]) ;
90 #endif
91 	}
92 #endif
93 
94 	X [k] = xk ;
95 	deg = Uilen [k] ;
96 	if (deg > 0 && IS_NONZERO (xk))
97 	{
98 	    up = Uip [k] ;
99 	    Ui = (Int *) (Numeric->Memory + up) ;
100 	    up += UNITS (Int, deg) ;
101 	    Uval = (Entry *) (Numeric->Memory + up) ;
102 	    for (j = 0 ; j < deg ; j++)
103 	    {
104 		DEBUG4 (("  k "ID" col "ID" value", k, Ui [j])) ;
105 		EDEBUG4 (Uval [j]) ;
106 		DEBUG4 (("\n")) ;
107 #ifdef CONJUGATE_SOLVE
108 		/* X [Ui [j]] -= xk * conjugate (Uval [j]) ; */
109 		MULT_SUB_CONJ (X [Ui [j]], xk, Uval [j]) ;
110 #else
111 		/* X [Ui [j]] -= xk * Uval [j] ; */
112 		MULT_SUB (X [Ui [j]], xk, Uval [j]) ;
113 #endif
114 	    }
115 	}
116     }
117 
118     /* ---------------------------------------------------------------------- */
119     /* nonsingletons */
120     /* ---------------------------------------------------------------------- */
121 
122     for (kstart = n1 ; kstart < npiv ; kstart = kend + 1)
123     {
124 
125 	/* ------------------------------------------------------------------ */
126 	/* find the end of this Uchain */
127 	/* ------------------------------------------------------------------ */
128 
129 	DEBUG4 (("kstart "ID" kend "ID"\n", kstart, kend)) ;
130 	/* for (kend = kstart ; kend < npiv && Uip [kend+1] > 0 ; kend++) ; */
131 	kend = kstart ;
132 	while (kend < npiv && Uip [kend+1] > 0)
133 	{
134 	    kend++ ;
135 	}
136 
137 	/* ------------------------------------------------------------------ */
138 	/* scan the whole Uchain to find the pattern of the first row of U */
139 	/* ------------------------------------------------------------------ */
140 
141 	k = kend+1 ;
142 	DEBUG4 (("\nKend "ID" K "ID"\n", kend, k)) ;
143 
144 	/* ------------------------------------------------------------------ */
145 	/* start with last row in Uchain of U in Pattern [0..deg-1] */
146 	/* ------------------------------------------------------------------ */
147 
148 	if (k == npiv)
149 	{
150 	    deg = Numeric->ulen ;
151 	    if (deg > 0)
152 	    {
153 		/* :: make last pivot row of U (singular matrices only) :: */
154 		for (j = 0 ; j < deg ; j++)
155 		{
156 		    Pattern [j] = Numeric->Upattern [j] ;
157 		}
158 	    }
159 	}
160 	else
161 	{
162 	    ASSERT (k >= 0 && k < npiv) ;
163 	    up = -Uip [k] ;
164 	    ASSERT (up > 0) ;
165 	    deg = Uilen [k] ;
166 	    DEBUG4 (("end of chain for row of U "ID" deg "ID"\n", k-1, deg)) ;
167 	    ip = (Int *) (Numeric->Memory + up) ;
168 	    for (j = 0 ; j < deg ; j++)
169 	    {
170 		col = *ip++ ;
171 		DEBUG4 (("  k "ID" col "ID"\n", k-1, col)) ;
172 		ASSERT (k <= col) ;
173 		Pattern [j] = col ;
174 	    }
175 	}
176 
177 	/* empty the stack at the bottom of Pattern */
178 	uhead = n ;
179 
180 	for (k = kend ; k > kstart ; k--)
181 	{
182 	    /* Pattern [0..deg-1] is the pattern of row k of U */
183 
184 	    /* -------------------------------------------------------------- */
185 	    /* make row k-1 of U in Pattern [0..deg-1] */
186 	    /* -------------------------------------------------------------- */
187 
188 	    ASSERT (k >= 0 && k < npiv) ;
189 	    ulen = Uilen [k] ;
190 	    /* delete, and push on the stack */
191 	    for (j = 0 ; j < ulen ; j++)
192 	    {
193 		ASSERT (uhead >= deg) ;
194 		Pattern [--uhead] = Pattern [--deg] ;
195 	    }
196 	    DEBUG4 (("middle of chain for row of U "ID" deg "ID"\n", k, deg)) ;
197 	    ASSERT (deg >= 0) ;
198 
199 	    pos = Upos [k] ;
200 	    if (pos != EMPTY)
201 	    {
202 		/* add the pivot column */
203 		DEBUG4 (("k "ID" add pivot entry at position "ID"\n", k, pos)) ;
204 		ASSERT (pos >= 0 && pos <= deg) ;
205 		Pattern [deg++] = Pattern [pos] ;
206 		Pattern [pos] = k ;
207 	    }
208 	}
209 
210 	/* Pattern [0..deg-1] is now the pattern of the first row in Uchain */
211 
212 	/* ------------------------------------------------------------------ */
213 	/* solve using this Uchain, in reverse order */
214 	/* ------------------------------------------------------------------ */
215 
216 	DEBUG4 (("Unwinding Uchain\n")) ;
217 	for (k = kstart ; k <= kend ; k++)
218 	{
219 
220 	    /* -------------------------------------------------------------- */
221 	    /* construct row k */
222 	    /* -------------------------------------------------------------- */
223 
224 	    ASSERT (k >= 0 && k < npiv) ;
225 	    pos = Upos [k] ;
226 	    if (pos != EMPTY)
227 	    {
228 		/* remove the pivot column */
229 		DEBUG4 (("k "ID" add pivot entry at position "ID"\n", k, pos)) ;
230 		ASSERT (k > kstart) ;
231 		ASSERT (pos >= 0 && pos < deg) ;
232 		ASSERT (Pattern [pos] == k) ;
233 		Pattern [pos] = Pattern [--deg] ;
234 	    }
235 
236 	    up = Uip [k] ;
237 	    ulen = Uilen [k] ;
238 	    if (k > kstart)
239 	    {
240 		/* concatenate the deleted pattern; pop from the stack */
241 		for (j = 0 ; j < ulen ; j++)
242 		{
243 		    ASSERT (deg <= uhead && uhead < n) ;
244 		    Pattern [deg++] = Pattern [uhead++] ;
245 		}
246 		DEBUG4 (("middle of chain, row of U "ID" deg "ID"\n", k, deg)) ;
247 		ASSERT (deg >= 0) ;
248 	    }
249 
250 	    /* -------------------------------------------------------------- */
251 	    /* use row k of U */
252 	    /* -------------------------------------------------------------- */
253 
254 #ifndef NO_DIVIDE_BY_ZERO
255 	    /* Go ahead and divide by zero if D [k] is zero. */
256 #ifdef CONJUGATE_SOLVE
257 	    /* xk = X [k] / conjugate (D [k]) ; */
258 	    DIV_CONJ (xk, X [k], D [k]) ;
259 #else
260 	    /* xk = X [k] / D [k] ; */
261 	    DIV (xk, X [k], D [k]) ;
262 #endif
263 #else
264 	    /* Do not divide by zero */
265 	    if (IS_NONZERO (D [k]))
266 	    {
267 #ifdef CONJUGATE_SOLVE
268 		/* xk = X [k] / conjugate (D [k]) ; */
269 		DIV_CONJ (xk, X [k], D [k]) ;
270 #else
271 		/* xk = X [k] / D [k] ; */
272 		DIV (xk, X [k], D [k]) ;
273 #endif
274 	    }
275 #endif
276 
277 	    X [k] = xk ;
278 	    if (IS_NONZERO (xk))
279 	    {
280 		if (k == kstart)
281 		{
282 		    up = -up ;
283 		    xp = (Entry *) (Numeric->Memory + up + UNITS (Int, ulen)) ;
284 		}
285 		else
286 		{
287 		    xp = (Entry *) (Numeric->Memory + up) ;
288 		}
289 		for (j = 0 ; j < deg ; j++)
290 		{
291 		    DEBUG4 (("  k "ID" col "ID" value", k, Pattern [j])) ;
292 		    EDEBUG4 (*xp) ;
293 		    DEBUG4 (("\n")) ;
294 #ifdef CONJUGATE_SOLVE
295 		    /* X [Pattern [j]] -= xk * conjugate (*xp) ; */
296 		    MULT_SUB_CONJ (X [Pattern [j]], xk, *xp) ;
297 #else
298 		    /* X [Pattern [j]] -= xk * (*xp) ; */
299 		    MULT_SUB (X [Pattern [j]], xk, *xp) ;
300 #endif
301 		    xp++ ;
302 		}
303 	    }
304 	}
305 	ASSERT (uhead == n) ;
306     }
307 
308 #ifndef NO_DIVIDE_BY_ZERO
309     for (k = npiv ; k < n ; k++)
310     {
311 	/* This is an *** intentional *** divide-by-zero, to get Inf or Nan,
312 	 * as appropriate.  It is not a bug. */
313 	ASSERT (IS_ZERO (D [k])) ;
314 	/* For conjugate solve, D [k] == conjugate (D [k]), in this case */
315 	/* xk = X [k] / D [k] ; */
316 	DIV (xk, X [k], D [k]) ;
317 	X [k] = xk ;
318     }
319 #endif
320 
321 #ifndef NDEBUG
322     for (j = 0 ; j < n ; j++)
323     {
324 	DEBUG4 (("Utsolve done "ID": ", j)) ;
325 	EDEBUG4 (X [j]) ;
326 	DEBUG4 (("\n")) ;
327     }
328     DEBUG4 (("Utsolve done.\n")) ;
329 #endif
330 
331     return (DIV_FLOPS * ((double) n) + MULTSUB_FLOPS * ((double) Numeric->unz));
332 }
333