1 /* ========================================================================== */
2 /* === UMF_utsolve ========================================================== */
3 /* ========================================================================== */
4 
5 /* -------------------------------------------------------------------------- */
6 /* Copyright (c) 2005-2012 by Timothy A. Davis, http://www.suitesparse.com.   */
7 /* All Rights Reserved.  See ../Doc/License.txt for License.                  */
8 /* -------------------------------------------------------------------------- */
9 
10 /*  solves U'x = b or U.'x=b, where U is the upper triangular factor of a */
11 /*  matrix.  B is overwritten with the solution X. */
12 /*  Returns the floating point operation count */
13 
14 #include "umf_internal.h"
15 #include "umf_utsolve.h"
16 
17 GLOBAL double
18 #ifdef CONJUGATE_SOLVE
UMF_uhsolve(NumericType * Numeric,Entry X[],Int Pattern[])19 UMF_uhsolve			/* solve U'x=b  (complex conjugate transpose) */
20 #else
21 UMF_utsolve			/* solve U.'x=b (array transpose) */
22 #endif
23 (
24     NumericType *Numeric,
25     Entry X [ ],		/* b on input, solution x on output */
26     Int Pattern [ ]		/* a work array of size n */
27 )
28 {
29     /* ---------------------------------------------------------------------- */
30     /* local variables */
31     /* ---------------------------------------------------------------------- */
32 
33     Entry xk ;
34     Entry *xp, *D, *Uval ;
35     Int k, deg, j, *ip, col, *Upos, *Uilen, kstart, kend, up,
36 	*Uip, n, uhead, ulen, pos, npiv, n1, *Ui ;
37 
38     /* ---------------------------------------------------------------------- */
39     /* get parameters */
40     /* ---------------------------------------------------------------------- */
41 
42     if (Numeric->n_row != Numeric->n_col) return (0.) ;
43     n = Numeric->n_row ;
44     npiv = Numeric->npiv ;
45     Upos = Numeric->Upos ;
46     Uilen = Numeric->Uilen ;
47     Uip = Numeric->Uip ;
48     D = Numeric->D ;
49     kend = 0 ;
50     n1 = Numeric->n1 ;
51 
52 #ifndef NDEBUG
53     DEBUG4 (("Utsolve start: npiv "ID" n "ID"\n", npiv, n)) ;
54     for (j = 0 ; j < n ; j++)
55     {
56 	DEBUG4 (("Utsolve start "ID": ", j)) ;
57 	EDEBUG4 (X [j]) ;
58 	DEBUG4 (("\n")) ;
59     }
60 #endif
61 
62     /* ---------------------------------------------------------------------- */
63     /* singletons */
64     /* ---------------------------------------------------------------------- */
65 
66     for (k = 0 ; k < n1 ; k++)
67     {
68 	DEBUG4 (("Singleton k "ID"\n", k)) ;
69 
70 #ifndef NO_DIVIDE_BY_ZERO
71 	/* Go ahead and divide by zero if D [k] is zero. */
72 #ifdef CONJUGATE_SOLVE
73 	/* xk = X [k] / conjugate (D [k]) ; */
74 	DIV_CONJ (xk, X [k], D [k]) ;
75 #else
76 	/* xk = X [k] / D [k] ; */
77 	DIV (xk, X [k], D [k]) ;
78 #endif
79 #else
80 	/* Do not divide by zero */
81 	if (IS_NONZERO (D [k]))
82 	{
83 #ifdef CONJUGATE_SOLVE
84 	    /* xk = X [k] / conjugate (D [k]) ; */
85 	    DIV_CONJ (xk, X [k], D [k]) ;
86 #else
87 	    /* xk = X [k] / D [k] ; */
88 	    DIV (xk, X [k], D [k]) ;
89 #endif
90 	}
91 #endif
92 
93 	X [k] = xk ;
94 	deg = Uilen [k] ;
95 	if (deg > 0 && IS_NONZERO (xk))
96 	{
97 	    up = Uip [k] ;
98 	    Ui = (Int *) (Numeric->Memory + up) ;
99 	    up += UNITS (Int, deg) ;
100 	    Uval = (Entry *) (Numeric->Memory + up) ;
101 	    for (j = 0 ; j < deg ; j++)
102 	    {
103 		DEBUG4 (("  k "ID" col "ID" value", k, Ui [j])) ;
104 		EDEBUG4 (Uval [j]) ;
105 		DEBUG4 (("\n")) ;
106 #ifdef CONJUGATE_SOLVE
107 		/* X [Ui [j]] -= xk * conjugate (Uval [j]) ; */
108 		MULT_SUB_CONJ (X [Ui [j]], xk, Uval [j]) ;
109 #else
110 		/* X [Ui [j]] -= xk * Uval [j] ; */
111 		MULT_SUB (X [Ui [j]], xk, Uval [j]) ;
112 #endif
113 	    }
114 	}
115     }
116 
117     /* ---------------------------------------------------------------------- */
118     /* nonsingletons */
119     /* ---------------------------------------------------------------------- */
120 
121     for (kstart = n1 ; kstart < npiv ; kstart = kend + 1)
122     {
123 
124 	/* ------------------------------------------------------------------ */
125 	/* find the end of this Uchain */
126 	/* ------------------------------------------------------------------ */
127 
128 	DEBUG4 (("kstart "ID" kend "ID"\n", kstart, kend)) ;
129 	/* for (kend = kstart ; kend < npiv && Uip [kend+1] > 0 ; kend++) ; */
130 	kend = kstart ;
131 	while (kend < npiv && Uip [kend+1] > 0)
132 	{
133 	    kend++ ;
134 	}
135 
136 	/* ------------------------------------------------------------------ */
137 	/* scan the whole Uchain to find the pattern of the first row of U */
138 	/* ------------------------------------------------------------------ */
139 
140 	k = kend+1 ;
141 	DEBUG4 (("\nKend "ID" K "ID"\n", kend, k)) ;
142 
143 	/* ------------------------------------------------------------------ */
144 	/* start with last row in Uchain of U in Pattern [0..deg-1] */
145 	/* ------------------------------------------------------------------ */
146 
147 	if (k == npiv)
148 	{
149 	    deg = Numeric->ulen ;
150 	    if (deg > 0)
151 	    {
152 		/* :: make last pivot row of U (singular matrices only) :: */
153 		for (j = 0 ; j < deg ; j++)
154 		{
155 		    Pattern [j] = Numeric->Upattern [j] ;
156 		}
157 	    }
158 	}
159 	else
160 	{
161 	    ASSERT (k >= 0 && k < npiv) ;
162 	    up = -Uip [k] ;
163 	    ASSERT (up > 0) ;
164 	    deg = Uilen [k] ;
165 	    DEBUG4 (("end of chain for row of U "ID" deg "ID"\n", k-1, deg)) ;
166 	    ip = (Int *) (Numeric->Memory + up) ;
167 	    for (j = 0 ; j < deg ; j++)
168 	    {
169 		col = *ip++ ;
170 		DEBUG4 (("  k "ID" col "ID"\n", k-1, col)) ;
171 		ASSERT (k <= col) ;
172 		Pattern [j] = col ;
173 	    }
174 	}
175 
176 	/* empty the stack at the bottom of Pattern */
177 	uhead = n ;
178 
179 	for (k = kend ; k > kstart ; k--)
180 	{
181 	    /* Pattern [0..deg-1] is the pattern of row k of U */
182 
183 	    /* -------------------------------------------------------------- */
184 	    /* make row k-1 of U in Pattern [0..deg-1] */
185 	    /* -------------------------------------------------------------- */
186 
187 	    ASSERT (k >= 0 && k < npiv) ;
188 	    ulen = Uilen [k] ;
189 	    /* delete, and push on the stack */
190 	    for (j = 0 ; j < ulen ; j++)
191 	    {
192 		ASSERT (uhead >= deg) ;
193 		Pattern [--uhead] = Pattern [--deg] ;
194 	    }
195 	    DEBUG4 (("middle of chain for row of U "ID" deg "ID"\n", k, deg)) ;
196 	    ASSERT (deg >= 0) ;
197 
198 	    pos = Upos [k] ;
199 	    if (pos != EMPTY)
200 	    {
201 		/* add the pivot column */
202 		DEBUG4 (("k "ID" add pivot entry at position "ID"\n", k, pos)) ;
203 		ASSERT (pos >= 0 && pos <= deg) ;
204 		Pattern [deg++] = Pattern [pos] ;
205 		Pattern [pos] = k ;
206 	    }
207 	}
208 
209 	/* Pattern [0..deg-1] is now the pattern of the first row in Uchain */
210 
211 	/* ------------------------------------------------------------------ */
212 	/* solve using this Uchain, in reverse order */
213 	/* ------------------------------------------------------------------ */
214 
215 	DEBUG4 (("Unwinding Uchain\n")) ;
216 	for (k = kstart ; k <= kend ; k++)
217 	{
218 
219 	    /* -------------------------------------------------------------- */
220 	    /* construct row k */
221 	    /* -------------------------------------------------------------- */
222 
223 	    ASSERT (k >= 0 && k < npiv) ;
224 	    pos = Upos [k] ;
225 	    if (pos != EMPTY)
226 	    {
227 		/* remove the pivot column */
228 		DEBUG4 (("k "ID" add pivot entry at position "ID"\n", k, pos)) ;
229 		ASSERT (k > kstart) ;
230 		ASSERT (pos >= 0 && pos < deg) ;
231 		ASSERT (Pattern [pos] == k) ;
232 		Pattern [pos] = Pattern [--deg] ;
233 	    }
234 
235 	    up = Uip [k] ;
236 	    ulen = Uilen [k] ;
237 	    if (k > kstart)
238 	    {
239 		/* concatenate the deleted pattern; pop from the stack */
240 		for (j = 0 ; j < ulen ; j++)
241 		{
242 		    ASSERT (deg <= uhead && uhead < n) ;
243 		    Pattern [deg++] = Pattern [uhead++] ;
244 		}
245 		DEBUG4 (("middle of chain, row of U "ID" deg "ID"\n", k, deg)) ;
246 		ASSERT (deg >= 0) ;
247 	    }
248 
249 	    /* -------------------------------------------------------------- */
250 	    /* use row k of U */
251 	    /* -------------------------------------------------------------- */
252 
253 #ifndef NO_DIVIDE_BY_ZERO
254 	    /* Go ahead and divide by zero if D [k] is zero. */
255 #ifdef CONJUGATE_SOLVE
256 	    /* xk = X [k] / conjugate (D [k]) ; */
257 	    DIV_CONJ (xk, X [k], D [k]) ;
258 #else
259 	    /* xk = X [k] / D [k] ; */
260 	    DIV (xk, X [k], D [k]) ;
261 #endif
262 #else
263 	    /* Do not divide by zero */
264 	    if (IS_NONZERO (D [k]))
265 	    {
266 #ifdef CONJUGATE_SOLVE
267 		/* xk = X [k] / conjugate (D [k]) ; */
268 		DIV_CONJ (xk, X [k], D [k]) ;
269 #else
270 		/* xk = X [k] / D [k] ; */
271 		DIV (xk, X [k], D [k]) ;
272 #endif
273 	    }
274 #endif
275 
276 	    X [k] = xk ;
277 	    if (IS_NONZERO (xk))
278 	    {
279 		if (k == kstart)
280 		{
281 		    up = -up ;
282 		    xp = (Entry *) (Numeric->Memory + up + UNITS (Int, ulen)) ;
283 		}
284 		else
285 		{
286 		    xp = (Entry *) (Numeric->Memory + up) ;
287 		}
288 		for (j = 0 ; j < deg ; j++)
289 		{
290 		    DEBUG4 (("  k "ID" col "ID" value", k, Pattern [j])) ;
291 		    EDEBUG4 (*xp) ;
292 		    DEBUG4 (("\n")) ;
293 #ifdef CONJUGATE_SOLVE
294 		    /* X [Pattern [j]] -= xk * conjugate (*xp) ; */
295 		    MULT_SUB_CONJ (X [Pattern [j]], xk, *xp) ;
296 #else
297 		    /* X [Pattern [j]] -= xk * (*xp) ; */
298 		    MULT_SUB (X [Pattern [j]], xk, *xp) ;
299 #endif
300 		    xp++ ;
301 		}
302 	    }
303 	}
304 	ASSERT (uhead == n) ;
305     }
306 
307 #ifndef NO_DIVIDE_BY_ZERO
308     for (k = npiv ; k < n ; k++)
309     {
310 	/* This is an *** intentional *** divide-by-zero, to get Inf or Nan,
311 	 * as appropriate.  It is not a bug. */
312 	ASSERT (IS_ZERO (D [k])) ;
313 	/* For conjugate solve, D [k] == conjugate (D [k]), in this case */
314 	/* xk = X [k] / D [k] ; */
315 	DIV (xk, X [k], D [k]) ;
316 	X [k] = xk ;
317     }
318 #endif
319 
320 #ifndef NDEBUG
321     for (j = 0 ; j < n ; j++)
322     {
323 	DEBUG4 (("Utsolve done "ID": ", j)) ;
324 	EDEBUG4 (X [j]) ;
325 	DEBUG4 (("\n")) ;
326     }
327     DEBUG4 (("Utsolve done.\n")) ;
328 #endif
329 
330     return (DIV_FLOPS * ((double) n) + MULTSUB_FLOPS * ((double) Numeric->unz));
331 }
332