1// Copyright ©2015 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package mat
6
7import (
8	"math"
9
10	"gonum.org/v1/gonum/blas"
11	"gonum.org/v1/gonum/blas/blas64"
12	"gonum.org/v1/gonum/lapack/lapack64"
13)
14
15var (
16	triDense *TriDense
17	_        Matrix            = triDense
18	_        allMatrix         = triDense
19	_        denseMatrix       = triDense
20	_        Triangular        = triDense
21	_        RawTriangular     = triDense
22	_        MutableTriangular = triDense
23
24	_ NonZeroDoer    = triDense
25	_ RowNonZeroDoer = triDense
26	_ ColNonZeroDoer = triDense
27)
28
// badTriCap is the panic message used by the reuse paths when a TriDense's
// dimension (mat.N) exceeds its recorded capacity (cap).
const badTriCap = "mat: bad capacity for TriDense"

// TriDense represents an upper or lower triangular matrix in dense storage
// format.
type TriDense struct {
	// mat holds the dimension, stride, row-major backing data, orientation
	// (Uplo) and diagonal kind of the matrix.
	mat blas64.Triangular
	// cap is the dimension the receiver's storage was last sized for. The
	// constructor, SetRawTriangular and the reuse paths set it alongside
	// mat.N; the reuse paths panic with badTriCap if mat.N exceeds it.
	cap int
}
37
// Triangular represents a triangular matrix. Triangular matrices are always square.
type Triangular interface {
	Matrix
	// Triangle returns the number of rows/columns in the matrix and its
	// orientation. A kind of Upper means the upper triangle (including the
	// diagonal) holds the matrix elements; Lower means the lower triangle.
	Triangle() (n int, kind TriKind)

	// TTri is the equivalent of the T() method in the Matrix interface but
	// guarantees the transpose is of triangular type.
	TTri() Triangular
}

// A RawTriangular can return a blas64.Triangular representation of the receiver.
// Changes to the blas64.Triangular.Data slice will be reflected in the original
// matrix, changes to the N, Stride, Uplo and Diag fields will not.
type RawTriangular interface {
	RawTriangular() blas64.Triangular
}

// A MutableTriangular can set elements of a triangular matrix.
type MutableTriangular interface {
	Triangular
	// SetTri sets the element at row i, column j to v. Presumably (i, j)
	// must lie within the matrix's stored triangle — confirm against the
	// implementations.
	SetTri(i, j int, v float64)
}
62
63var (
64	_ Matrix           = TransposeTri{}
65	_ Triangular       = TransposeTri{}
66	_ UntransposeTrier = TransposeTri{}
67)
68
69// TransposeTri is a type for performing an implicit transpose of a Triangular
70// matrix. It implements the Triangular interface, returning values from the
71// transpose of the matrix within.
72type TransposeTri struct {
73	Triangular Triangular
74}
75
76// At returns the value of the element at row i and column j of the transposed
77// matrix, that is, row j and column i of the Triangular field.
78func (t TransposeTri) At(i, j int) float64 {
79	return t.Triangular.At(j, i)
80}
81
82// Dims returns the dimensions of the transposed matrix. Triangular matrices are
83// square and thus this is the same size as the original Triangular.
84func (t TransposeTri) Dims() (r, c int) {
85	c, r = t.Triangular.Dims()
86	return r, c
87}
88
89// T performs an implicit transpose by returning the Triangular field.
90func (t TransposeTri) T() Matrix {
91	return t.Triangular
92}
93
94// Triangle returns the number of rows/columns in the matrix and its orientation.
95func (t TransposeTri) Triangle() (int, TriKind) {
96	n, upper := t.Triangular.Triangle()
97	return n, !upper
98}
99
100// TTri performs an implicit transpose by returning the Triangular field.
101func (t TransposeTri) TTri() Triangular {
102	return t.Triangular
103}
104
105// Untranspose returns the Triangular field.
106func (t TransposeTri) Untranspose() Matrix {
107	return t.Triangular
108}
109
110func (t TransposeTri) UntransposeTri() Triangular {
111	return t.Triangular
112}
113
114// NewTriDense creates a new Triangular matrix with n rows and columns. If data == nil,
115// a new slice is allocated for the backing slice. If len(data) == n*n, data is
116// used as the backing slice, and changes to the elements of the returned TriDense
117// will be reflected in data. If neither of these is true, NewTriDense will panic.
118// NewTriDense will panic if n is zero.
119//
120// The data must be arranged in row-major order, i.e. the (i*c + j)-th
121// element in the data slice is the {i, j}-th element in the matrix.
122// Only the values in the triangular portion corresponding to kind are used.
123func NewTriDense(n int, kind TriKind, data []float64) *TriDense {
124	if n <= 0 {
125		if n == 0 {
126			panic(ErrZeroLength)
127		}
128		panic("mat: negative dimension")
129	}
130	if data != nil && len(data) != n*n {
131		panic(ErrShape)
132	}
133	if data == nil {
134		data = make([]float64, n*n)
135	}
136	uplo := blas.Lower
137	if kind == Upper {
138		uplo = blas.Upper
139	}
140	return &TriDense{
141		mat: blas64.Triangular{
142			N:      n,
143			Stride: n,
144			Data:   data,
145			Uplo:   uplo,
146			Diag:   blas.NonUnit,
147		},
148		cap: n,
149	}
150}
151
152func (t *TriDense) Dims() (r, c int) {
153	return t.mat.N, t.mat.N
154}
155
156// Triangle returns the dimension of t and its orientation. The returned
157// orientation is only valid when n is not empty.
158func (t *TriDense) Triangle() (n int, kind TriKind) {
159	return t.mat.N, t.triKind()
160}
161
162func (t *TriDense) isUpper() bool {
163	return isUpperUplo(t.mat.Uplo)
164}
165
166func (t *TriDense) triKind() TriKind {
167	return TriKind(isUpperUplo(t.mat.Uplo))
168}
169
170func isUpperUplo(u blas.Uplo) bool {
171	switch u {
172	case blas.Upper:
173		return true
174	case blas.Lower:
175		return false
176	default:
177		panic(badTriangle)
178	}
179}
180
181// asSymBlas returns the receiver restructured as a blas64.Symmetric with the
182// same backing memory. Panics if the receiver is unit.
183// This returns a blas64.Symmetric and not a *SymDense because SymDense can only
184// be upper triangular.
185func (t *TriDense) asSymBlas() blas64.Symmetric {
186	if t.mat.Diag == blas.Unit {
187		panic("mat: cannot convert unit TriDense into blas64.Symmetric")
188	}
189	return blas64.Symmetric{
190		N:      t.mat.N,
191		Stride: t.mat.Stride,
192		Data:   t.mat.Data,
193		Uplo:   t.mat.Uplo,
194	}
195}
196
197// T performs an implicit transpose by returning the receiver inside a Transpose.
198func (t *TriDense) T() Matrix {
199	return Transpose{t}
200}
201
202// TTri performs an implicit transpose by returning the receiver inside a TransposeTri.
203func (t *TriDense) TTri() Triangular {
204	return TransposeTri{t}
205}
206
207func (t *TriDense) RawTriangular() blas64.Triangular {
208	return t.mat
209}
210
211// SetRawTriangular sets the underlying blas64.Triangular used by the receiver.
212// Changes to elements in the receiver following the call will be reflected
213// in the input.
214//
215// The supplied Triangular must not use blas.Unit storage format.
216func (t *TriDense) SetRawTriangular(mat blas64.Triangular) {
217	if mat.Diag == blas.Unit {
218		panic("mat: cannot set TriDense with Unit storage format")
219	}
220	t.cap = mat.N
221	t.mat = mat
222}
223
224// Reset empties the matrix so that it can be reused as the
225// receiver of a dimensionally restricted operation.
226//
227// Reset should not be used when the matrix shares backing data.
228// See the Reseter interface for more information.
229func (t *TriDense) Reset() {
230	// N and Stride must be zeroed in unison.
231	t.mat.N, t.mat.Stride = 0, 0
232	// Defensively zero Uplo to ensure
233	// it is set correctly later.
234	t.mat.Uplo = 0
235	t.mat.Data = t.mat.Data[:0]
236}
237
238// Zero sets all of the matrix elements to zero.
239func (t *TriDense) Zero() {
240	if t.isUpper() {
241		for i := 0; i < t.mat.N; i++ {
242			zero(t.mat.Data[i*t.mat.Stride+i : i*t.mat.Stride+t.mat.N])
243		}
244		return
245	}
246	for i := 0; i < t.mat.N; i++ {
247		zero(t.mat.Data[i*t.mat.Stride : i*t.mat.Stride+i+1])
248	}
249}
250
251// IsEmpty returns whether the receiver is empty. Empty matrices can be the
252// receiver for size-restricted operations. The receiver can be emptied using
253// Reset.
254func (t *TriDense) IsEmpty() bool {
255	// It must be the case that t.Dims() returns
256	// zeros in this case. See comment in Reset().
257	return t.mat.Stride == 0
258}
259
260// untranspose untransposes a matrix if applicable. If a is an Untransposer, then
261// untranspose returns the underlying matrix and true. If it is not, then it returns
262// the input matrix and false.
263func untransposeTri(a Triangular) (Triangular, bool) {
264	if ut, ok := a.(UntransposeTrier); ok {
265		return ut.UntransposeTri(), true
266	}
267	return a, false
268}
269
270// ReuseAsTri changes the receiver if it IsEmpty() to be of size n×n.
271//
272// ReuseAsTri re-uses the backing data slice if it has sufficient capacity,
273// otherwise a new slice is allocated. The backing data is zero on return.
274//
275// ReuseAsTri panics if the receiver is not empty, and panics if
276// the input size is less than one. To empty the receiver for re-use,
277// Reset should be used.
278func (t *TriDense) ReuseAsTri(n int, kind TriKind) {
279	if n <= 0 {
280		if n == 0 {
281			panic(ErrZeroLength)
282		}
283		panic(ErrNegativeDimension)
284	}
285	if !t.IsEmpty() {
286		panic(ErrReuseNonEmpty)
287	}
288	t.reuseAsZeroed(n, kind)
289}
290
291// reuseAsNonZeroed resizes a zero receiver to an n×n triangular matrix with the given
292// orientation. If the receiver is non-zero, reuseAsNonZeroed checks that the receiver
293// is the correct size and orientation.
294func (t *TriDense) reuseAsNonZeroed(n int, kind TriKind) {
295	// reuseAsNonZeroed must be kept in sync with reuseAsZeroed.
296	if n == 0 {
297		panic(ErrZeroLength)
298	}
299	ul := blas.Lower
300	if kind == Upper {
301		ul = blas.Upper
302	}
303	if t.mat.N > t.cap {
304		panic(badTriCap)
305	}
306	if t.IsEmpty() {
307		t.mat = blas64.Triangular{
308			N:      n,
309			Stride: n,
310			Diag:   blas.NonUnit,
311			Data:   use(t.mat.Data, n*n),
312			Uplo:   ul,
313		}
314		t.cap = n
315		return
316	}
317	if t.mat.N != n {
318		panic(ErrShape)
319	}
320	if t.mat.Uplo != ul {
321		panic(ErrTriangle)
322	}
323}
324
325// reuseAsZeroed resizes a zero receiver to an n×n triangular matrix with the given
326// orientation. If the receiver is non-zero, reuseAsZeroed checks that the receiver
327// is the correct size and orientation. It then zeros out the matrix data.
328func (t *TriDense) reuseAsZeroed(n int, kind TriKind) {
329	// reuseAsZeroed must be kept in sync with reuseAsNonZeroed.
330	if n == 0 {
331		panic(ErrZeroLength)
332	}
333	ul := blas.Lower
334	if kind == Upper {
335		ul = blas.Upper
336	}
337	if t.mat.N > t.cap {
338		panic(badTriCap)
339	}
340	if t.IsEmpty() {
341		t.mat = blas64.Triangular{
342			N:      n,
343			Stride: n,
344			Diag:   blas.NonUnit,
345			Data:   useZeroed(t.mat.Data, n*n),
346			Uplo:   ul,
347		}
348		t.cap = n
349		return
350	}
351	if t.mat.N != n {
352		panic(ErrShape)
353	}
354	if t.mat.Uplo != ul {
355		panic(ErrTriangle)
356	}
357	t.Zero()
358}
359
360// isolatedWorkspace returns a new TriDense matrix w with the size of a and
361// returns a callback to defer which performs cleanup at the return of the call.
362// This should be used when a method receiver is the same pointer as an input argument.
363func (t *TriDense) isolatedWorkspace(a Triangular) (w *TriDense, restore func()) {
364	n, kind := a.Triangle()
365	if n == 0 {
366		panic(ErrZeroLength)
367	}
368	w = getWorkspaceTri(n, kind, false)
369	return w, func() {
370		t.Copy(w)
371		putWorkspaceTri(w)
372	}
373}
374
375// DiagView returns the diagonal as a matrix backed by the original data.
376func (t *TriDense) DiagView() Diagonal {
377	if t.mat.Diag == blas.Unit {
378		panic("mat: cannot take view of Unit diagonal")
379	}
380	n := t.mat.N
381	return &DiagDense{
382		mat: blas64.Vector{
383			N:    n,
384			Inc:  t.mat.Stride + 1,
385			Data: t.mat.Data[:(n-1)*t.mat.Stride+n],
386		},
387	}
388}
389
390// Copy makes a copy of elements of a into the receiver. It is similar to the
391// built-in copy; it copies as much as the overlap between the two matrices and
392// returns the number of rows and columns it copied. Only elements within the
393// receiver's non-zero triangle are set.
394//
395// See the Copier interface for more information.
396func (t *TriDense) Copy(a Matrix) (r, c int) {
397	r, c = a.Dims()
398	r = min(r, t.mat.N)
399	c = min(c, t.mat.N)
400	if r == 0 || c == 0 {
401		return 0, 0
402	}
403
404	switch a := a.(type) {
405	case RawMatrixer:
406		amat := a.RawMatrix()
407		if t.isUpper() {
408			for i := 0; i < r; i++ {
409				copy(t.mat.Data[i*t.mat.Stride+i:i*t.mat.Stride+c], amat.Data[i*amat.Stride+i:i*amat.Stride+c])
410			}
411		} else {
412			for i := 0; i < r; i++ {
413				copy(t.mat.Data[i*t.mat.Stride:i*t.mat.Stride+i+1], amat.Data[i*amat.Stride:i*amat.Stride+i+1])
414			}
415		}
416	case RawTriangular:
417		amat := a.RawTriangular()
418		aIsUpper := isUpperUplo(amat.Uplo)
419		tIsUpper := t.isUpper()
420		switch {
421		case tIsUpper && aIsUpper:
422			for i := 0; i < r; i++ {
423				copy(t.mat.Data[i*t.mat.Stride+i:i*t.mat.Stride+c], amat.Data[i*amat.Stride+i:i*amat.Stride+c])
424			}
425		case !tIsUpper && !aIsUpper:
426			for i := 0; i < r; i++ {
427				copy(t.mat.Data[i*t.mat.Stride:i*t.mat.Stride+i+1], amat.Data[i*amat.Stride:i*amat.Stride+i+1])
428			}
429		default:
430			for i := 0; i < r; i++ {
431				t.set(i, i, amat.Data[i*amat.Stride+i])
432			}
433		}
434	default:
435		isUpper := t.isUpper()
436		for i := 0; i < r; i++ {
437			if isUpper {
438				for j := i; j < c; j++ {
439					t.set(i, j, a.At(i, j))
440				}
441			} else {
442				for j := 0; j <= i; j++ {
443					t.set(i, j, a.At(i, j))
444				}
445			}
446		}
447	}
448
449	return r, c
450}
451
452// InverseTri computes the inverse of the triangular matrix a, storing the result
453// into the receiver. If a is ill-conditioned, a Condition error will be returned.
454// Note that matrix inversion is numerically unstable, and should generally be
455// avoided where possible, for example by using the Solve routines.
456func (t *TriDense) InverseTri(a Triangular) error {
457	t.checkOverlapMatrix(a)
458	n, _ := a.Triangle()
459	t.reuseAsNonZeroed(a.Triangle())
460	t.Copy(a)
461	work := getFloats(3*n, false)
462	iwork := getInts(n, false)
463	cond := lapack64.Trcon(CondNorm, t.mat, work, iwork)
464	putFloats(work)
465	putInts(iwork)
466	if math.IsInf(cond, 1) {
467		return Condition(cond)
468	}
469	ok := lapack64.Trtri(t.mat)
470	if !ok {
471		return Condition(math.Inf(1))
472	}
473	if cond > ConditionTolerance {
474		return Condition(cond)
475	}
476	return nil
477}
478
// MulTri takes the product of triangular matrices a and b and places the result
// in the receiver. The size of a and b must match, and they both must have the
// same TriKind, or MulTri will panic. If the receiver is not empty, it must
// already have that size and TriKind, otherwise MulTri will panic.
func (t *TriDense) MulTri(a, b Triangular) {
	n, kind := a.Triangle()
	nb, kindb := b.Triangle()
	if n != nb {
		panic(ErrShape)
	}
	if kind != kindb {
		panic(ErrTriangle)
	}

	// Strip any implicit transpose so that the overlap and aliasing checks
	// look at the underlying storage.
	aU, _ := untransposeTri(a)
	bU, _ := untransposeTri(b)
	t.checkOverlapMatrix(bU)
	t.checkOverlapMatrix(aU)
	t.reuseAsNonZeroed(n, kind)
	// If the receiver is one of the operands, compute into a workspace
	// instead so the operand is not overwritten while it is still being
	// read; the deferred restore copies the result back into the receiver.
	var restore func()
	if t == aU {
		t, restore = t.isolatedWorkspace(aU)
		defer restore()
	} else if t == bU {
		t, restore = t.isolatedWorkspace(bU)
		defer restore()
	}

	// Inspect types here, helps keep the loops later clean(er).
	_, aDiag := aU.(Diagonal)
	_, bDiag := bU.(Diagonal)
	// If they are both diagonal only need 1 loop.
	// All diagonal matrices are Upper.
	// TODO: Add fast paths for DiagDense.
	if aDiag && bDiag {
		// Only the diagonal of the product can be non-zero.
		t.Zero()
		for i := 0; i < n; i++ {
			t.SetTri(i, i, a.At(i, i)*b.At(i, i))
		}
		return
	}

	// Now we know at least one matrix is non-diagonal.
	// And all diagonal matrices are all Upper.
	// The both-diagonal case is handled above.
	// TODO: Add fast paths for Dense variants.
	if kind == Upper {
		for i := 0; i < n; i++ {
			for j := i; j < n; j++ {
				switch {
				case aDiag:
					// (a*b)(i, j) = a(i, i)*b(i, j) for diagonal a.
					t.SetTri(i, j, a.At(i, i)*b.At(i, j))
				case bDiag:
					// (a*b)(i, j) = a(i, j)*b(j, j) for diagonal b.
					t.SetTri(i, j, a.At(i, j)*b.At(j, j))
				default:
					// a(i, k) is zero for k < i and b(k, j) is zero
					// for k > j, so only k in [i, j] contributes.
					var v float64
					for k := i; k <= j; k++ {
						v += a.At(i, k) * b.At(k, j)
					}
					t.SetTri(i, j, v)
				}
			}
		}
		return
	}
	// Lower case: a(i, k) is zero for k > i and b(k, j) is zero for k < j,
	// so only k in [j, i] contributes to element (i, j).
	for i := 0; i < n; i++ {
		for j := 0; j <= i; j++ {
			var v float64
			for k := j; k <= i; k++ {
				v += a.At(i, k) * b.At(k, j)
			}
			t.SetTri(i, j, v)
		}
	}
}
553
554// ScaleTri multiplies the elements of a by f, placing the result in the receiver.
555// If the receiver is non-zero, the size and kind of the receiver must match
556// the input, or ScaleTri will panic.
557func (t *TriDense) ScaleTri(f float64, a Triangular) {
558	n, kind := a.Triangle()
559	t.reuseAsNonZeroed(n, kind)
560
561	// TODO(btracey): Improve the set of fast-paths.
562	switch a := a.(type) {
563	case RawTriangular:
564		amat := a.RawTriangular()
565		if t != a {
566			t.checkOverlap(generalFromTriangular(amat))
567		}
568		if kind == Upper {
569			for i := 0; i < n; i++ {
570				ts := t.mat.Data[i*t.mat.Stride+i : i*t.mat.Stride+n]
571				as := amat.Data[i*amat.Stride+i : i*amat.Stride+n]
572				for i, v := range as {
573					ts[i] = v * f
574				}
575			}
576			return
577		}
578		for i := 0; i < n; i++ {
579			ts := t.mat.Data[i*t.mat.Stride : i*t.mat.Stride+i+1]
580			as := amat.Data[i*amat.Stride : i*amat.Stride+i+1]
581			for i, v := range as {
582				ts[i] = v * f
583			}
584		}
585		return
586	default:
587		t.checkOverlapMatrix(a)
588		isUpper := kind == Upper
589		for i := 0; i < n; i++ {
590			if isUpper {
591				for j := i; j < n; j++ {
592					t.set(i, j, f*a.At(i, j))
593				}
594			} else {
595				for j := 0; j <= i; j++ {
596					t.set(i, j, f*a.At(i, j))
597				}
598			}
599		}
600	}
601}
602
603// Trace returns the trace of the matrix.
604func (t *TriDense) Trace() float64 {
605	// TODO(btracey): could use internal asm sum routine.
606	var v float64
607	for i := 0; i < t.mat.N; i++ {
608		v += t.mat.Data[i*t.mat.Stride+i]
609	}
610	return v
611}
612
613// copySymIntoTriangle copies a symmetric matrix into a TriDense
614func copySymIntoTriangle(t *TriDense, s Symmetric) {
615	n, upper := t.Triangle()
616	ns := s.Symmetric()
617	if n != ns {
618		panic("mat: triangle size mismatch")
619	}
620	ts := t.mat.Stride
621	if rs, ok := s.(RawSymmetricer); ok {
622		sd := rs.RawSymmetric()
623		ss := sd.Stride
624		if upper {
625			if sd.Uplo == blas.Upper {
626				for i := 0; i < n; i++ {
627					copy(t.mat.Data[i*ts+i:i*ts+n], sd.Data[i*ss+i:i*ss+n])
628				}
629				return
630			}
631			for i := 0; i < n; i++ {
632				for j := i; j < n; j++ {
633					t.mat.Data[i*ts+j] = sd.Data[j*ss+i]
634				}
635			}
636			return
637		}
638		if sd.Uplo == blas.Upper {
639			for i := 0; i < n; i++ {
640				for j := 0; j <= i; j++ {
641					t.mat.Data[i*ts+j] = sd.Data[j*ss+i]
642				}
643			}
644			return
645		}
646		for i := 0; i < n; i++ {
647			copy(t.mat.Data[i*ts:i*ts+i+1], sd.Data[i*ss:i*ss+i+1])
648		}
649		return
650	}
651	if upper {
652		for i := 0; i < n; i++ {
653			for j := i; j < n; j++ {
654				t.mat.Data[i*ts+j] = s.At(i, j)
655			}
656		}
657		return
658	}
659	for i := 0; i < n; i++ {
660		for j := 0; j <= i; j++ {
661			t.mat.Data[i*ts+j] = s.At(i, j)
662		}
663	}
664}
665
666// DoNonZero calls the function fn for each of the non-zero elements of t. The function fn
667// takes a row/column index and the element value of t at (i, j).
668func (t *TriDense) DoNonZero(fn func(i, j int, v float64)) {
669	if t.isUpper() {
670		for i := 0; i < t.mat.N; i++ {
671			for j := i; j < t.mat.N; j++ {
672				v := t.at(i, j)
673				if v != 0 {
674					fn(i, j, v)
675				}
676			}
677		}
678		return
679	}
680	for i := 0; i < t.mat.N; i++ {
681		for j := 0; j <= i; j++ {
682			v := t.at(i, j)
683			if v != 0 {
684				fn(i, j, v)
685			}
686		}
687	}
688}
689
690// DoRowNonZero calls the function fn for each of the non-zero elements of row i of t. The function fn
691// takes a row/column index and the element value of t at (i, j).
692func (t *TriDense) DoRowNonZero(i int, fn func(i, j int, v float64)) {
693	if i < 0 || t.mat.N <= i {
694		panic(ErrRowAccess)
695	}
696	if t.isUpper() {
697		for j := i; j < t.mat.N; j++ {
698			v := t.at(i, j)
699			if v != 0 {
700				fn(i, j, v)
701			}
702		}
703		return
704	}
705	for j := 0; j <= i; j++ {
706		v := t.at(i, j)
707		if v != 0 {
708			fn(i, j, v)
709		}
710	}
711}
712
713// DoColNonZero calls the function fn for each of the non-zero elements of column j of t. The function fn
714// takes a row/column index and the element value of t at (i, j).
715func (t *TriDense) DoColNonZero(j int, fn func(i, j int, v float64)) {
716	if j < 0 || t.mat.N <= j {
717		panic(ErrColAccess)
718	}
719	if t.isUpper() {
720		for i := 0; i <= j; i++ {
721			v := t.at(i, j)
722			if v != 0 {
723				fn(i, j, v)
724			}
725		}
726		return
727	}
728	for i := j; i < t.mat.N; i++ {
729		v := t.at(i, j)
730		if v != 0 {
731			fn(i, j, v)
732		}
733	}
734}
735