1// Copyright ©2015 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package mat
6
7import (
8	"math"
9
10	"gonum.org/v1/gonum/blas"
11	"gonum.org/v1/gonum/blas/blas64"
12)
13
14var (
15	symDense *SymDense
16
17	_ Matrix           = symDense
18	_ allMatrix        = symDense
19	_ denseMatrix      = symDense
20	_ Symmetric        = symDense
21	_ RawSymmetricer   = symDense
22	_ MutableSymmetric = symDense
23)
24
25const (
26	badSymTriangle = "mat: blas64.Symmetric not upper"
27	badSymCap      = "mat: bad capacity for SymDense"
28)
29
30// SymDense is a symmetric matrix that uses dense storage. SymDense
31// matrices are stored in the upper triangle.
32type SymDense struct {
33	mat blas64.Symmetric
34	cap int
35}
36
37// Symmetric represents a symmetric matrix (where the element at {i, j} equals
38// the element at {j, i}). Symmetric matrices are always square.
39type Symmetric interface {
40	Matrix
41	// Symmetric returns the number of rows/columns in the matrix.
42	Symmetric() int
43}
44
45// A RawSymmetricer can return a view of itself as a BLAS Symmetric matrix.
46type RawSymmetricer interface {
47	RawSymmetric() blas64.Symmetric
48}
49
50// A MutableSymmetric can set elements of a symmetric matrix.
51type MutableSymmetric interface {
52	Symmetric
53	SetSym(i, j int, v float64)
54}
55
56// NewSymDense creates a new Symmetric matrix with n rows and columns. If data == nil,
57// a new slice is allocated for the backing slice. If len(data) == n*n, data is
58// used as the backing slice, and changes to the elements of the returned SymDense
59// will be reflected in data. If neither of these is true, NewSymDense will panic.
60// NewSymDense will panic if n is zero.
61//
62// The data must be arranged in row-major order, i.e. the (i*c + j)-th
63// element in the data slice is the {i, j}-th element in the matrix.
64// Only the values in the upper triangular portion of the matrix are used.
65func NewSymDense(n int, data []float64) *SymDense {
66	if n <= 0 {
67		if n == 0 {
68			panic(ErrZeroLength)
69		}
70		panic("mat: negative dimension")
71	}
72	if data != nil && n*n != len(data) {
73		panic(ErrShape)
74	}
75	if data == nil {
76		data = make([]float64, n*n)
77	}
78	return &SymDense{
79		mat: blas64.Symmetric{
80			N:      n,
81			Stride: n,
82			Data:   data,
83			Uplo:   blas.Upper,
84		},
85		cap: n,
86	}
87}
88
89// Dims returns the number of rows and columns in the matrix.
90func (s *SymDense) Dims() (r, c int) {
91	return s.mat.N, s.mat.N
92}
93
94// Caps returns the number of rows and columns in the backing matrix.
95func (s *SymDense) Caps() (r, c int) {
96	return s.cap, s.cap
97}
98
99// T returns the receiver, the transpose of a symmetric matrix.
100func (s *SymDense) T() Matrix {
101	return s
102}
103
104// Symmetric implements the Symmetric interface and returns the number of rows
105// and columns in the matrix.
106func (s *SymDense) Symmetric() int {
107	return s.mat.N
108}
109
110// RawSymmetric returns the matrix as a blas64.Symmetric. The returned
111// value must be stored in upper triangular format.
112func (s *SymDense) RawSymmetric() blas64.Symmetric {
113	return s.mat
114}
115
116// SetRawSymmetric sets the underlying blas64.Symmetric used by the receiver.
117// Changes to elements in the receiver following the call will be reflected
118// in the input.
119//
120// The supplied Symmetric must use blas.Upper storage format.
121func (s *SymDense) SetRawSymmetric(mat blas64.Symmetric) {
122	if mat.Uplo != blas.Upper {
123		panic(badSymTriangle)
124	}
125	s.cap = mat.N
126	s.mat = mat
127}
128
129// Reset empties the matrix so that it can be reused as the
130// receiver of a dimensionally restricted operation.
131//
132// Reset should not be used when the matrix shares backing data.
133// See the Reseter interface for more information.
134func (s *SymDense) Reset() {
135	// N and Stride must be zeroed in unison.
136	s.mat.N, s.mat.Stride = 0, 0
137	s.mat.Data = s.mat.Data[:0]
138}
139
140// ReuseAsSym changes the receiver if it IsEmpty() to be of size n×n.
141//
142// ReuseAsSym re-uses the backing data slice if it has sufficient capacity,
143// otherwise a new slice is allocated. The backing data is zero on return.
144//
145// ReuseAsSym panics if the receiver is not empty, and panics if
146// the input size is less than one. To empty the receiver for re-use,
147// Reset should be used.
148func (s *SymDense) ReuseAsSym(n int) {
149	if n <= 0 {
150		if n == 0 {
151			panic(ErrZeroLength)
152		}
153		panic(ErrNegativeDimension)
154	}
155	if !s.IsEmpty() {
156		panic(ErrReuseNonEmpty)
157	}
158	s.reuseAsZeroed(n)
159}
160
161// Zero sets all of the matrix elements to zero.
162func (s *SymDense) Zero() {
163	for i := 0; i < s.mat.N; i++ {
164		zero(s.mat.Data[i*s.mat.Stride+i : i*s.mat.Stride+s.mat.N])
165	}
166}
167
168// IsEmpty returns whether the receiver is empty. Empty matrices can be the
169// receiver for size-restricted operations. The receiver can be emptied using
170// Reset.
171func (s *SymDense) IsEmpty() bool {
172	// It must be the case that m.Dims() returns
173	// zeros in this case. See comment in Reset().
174	return s.mat.N == 0
175}
176
177// reuseAsNonZeroed resizes an empty matrix to a n×n matrix,
178// or checks that a non-empty matrix is n×n.
179func (s *SymDense) reuseAsNonZeroed(n int) {
180	// reuseAsNonZeroed must be kept in sync with reuseAsZeroed.
181	if n == 0 {
182		panic(ErrZeroLength)
183	}
184	if s.mat.N > s.cap {
185		panic(badSymCap)
186	}
187	if s.IsEmpty() {
188		s.mat = blas64.Symmetric{
189			N:      n,
190			Stride: n,
191			Data:   use(s.mat.Data, n*n),
192			Uplo:   blas.Upper,
193		}
194		s.cap = n
195		return
196	}
197	if s.mat.Uplo != blas.Upper {
198		panic(badSymTriangle)
199	}
200	if s.mat.N != n {
201		panic(ErrShape)
202	}
203}
204
205// reuseAsNonZeroed resizes an empty matrix to a n×n matrix,
206// or checks that a non-empty matrix is n×n. It then zeros the
207// elements of the matrix.
208func (s *SymDense) reuseAsZeroed(n int) {
209	// reuseAsZeroed must be kept in sync with reuseAsNonZeroed.
210	if n == 0 {
211		panic(ErrZeroLength)
212	}
213	if s.mat.N > s.cap {
214		panic(badSymCap)
215	}
216	if s.IsEmpty() {
217		s.mat = blas64.Symmetric{
218			N:      n,
219			Stride: n,
220			Data:   useZeroed(s.mat.Data, n*n),
221			Uplo:   blas.Upper,
222		}
223		s.cap = n
224		return
225	}
226	if s.mat.Uplo != blas.Upper {
227		panic(badSymTriangle)
228	}
229	if s.mat.N != n {
230		panic(ErrShape)
231	}
232	s.Zero()
233}
234
235func (s *SymDense) isolatedWorkspace(a Symmetric) (w *SymDense, restore func()) {
236	n := a.Symmetric()
237	if n == 0 {
238		panic(ErrZeroLength)
239	}
240	w = getWorkspaceSym(n, false)
241	return w, func() {
242		s.CopySym(w)
243		putWorkspaceSym(w)
244	}
245}
246
247// DiagView returns the diagonal as a matrix backed by the original data.
248func (s *SymDense) DiagView() Diagonal {
249	n := s.mat.N
250	return &DiagDense{
251		mat: blas64.Vector{
252			N:    n,
253			Inc:  s.mat.Stride + 1,
254			Data: s.mat.Data[:(n-1)*s.mat.Stride+n],
255		},
256	}
257}
258
259func (s *SymDense) AddSym(a, b Symmetric) {
260	n := a.Symmetric()
261	if n != b.Symmetric() {
262		panic(ErrShape)
263	}
264	s.reuseAsNonZeroed(n)
265
266	if a, ok := a.(RawSymmetricer); ok {
267		if b, ok := b.(RawSymmetricer); ok {
268			amat, bmat := a.RawSymmetric(), b.RawSymmetric()
269			if s != a {
270				s.checkOverlap(generalFromSymmetric(amat))
271			}
272			if s != b {
273				s.checkOverlap(generalFromSymmetric(bmat))
274			}
275			for i := 0; i < n; i++ {
276				btmp := bmat.Data[i*bmat.Stride+i : i*bmat.Stride+n]
277				stmp := s.mat.Data[i*s.mat.Stride+i : i*s.mat.Stride+n]
278				for j, v := range amat.Data[i*amat.Stride+i : i*amat.Stride+n] {
279					stmp[j] = v + btmp[j]
280				}
281			}
282			return
283		}
284	}
285
286	s.checkOverlapMatrix(a)
287	s.checkOverlapMatrix(b)
288	for i := 0; i < n; i++ {
289		stmp := s.mat.Data[i*s.mat.Stride : i*s.mat.Stride+n]
290		for j := i; j < n; j++ {
291			stmp[j] = a.At(i, j) + b.At(i, j)
292		}
293	}
294}
295
296func (s *SymDense) CopySym(a Symmetric) int {
297	n := a.Symmetric()
298	n = min(n, s.mat.N)
299	if n == 0 {
300		return 0
301	}
302	switch a := a.(type) {
303	case RawSymmetricer:
304		amat := a.RawSymmetric()
305		if amat.Uplo != blas.Upper {
306			panic(badSymTriangle)
307		}
308		for i := 0; i < n; i++ {
309			copy(s.mat.Data[i*s.mat.Stride+i:i*s.mat.Stride+n], amat.Data[i*amat.Stride+i:i*amat.Stride+n])
310		}
311	default:
312		for i := 0; i < n; i++ {
313			stmp := s.mat.Data[i*s.mat.Stride : i*s.mat.Stride+n]
314			for j := i; j < n; j++ {
315				stmp[j] = a.At(i, j)
316			}
317		}
318	}
319	return n
320}
321
322// SymRankOne performs a symmetric rank-one update to the matrix a with x,
323// which is treated as a column vector, and stores the result in the receiver
324//  s = a + alpha * x * xᵀ
325func (s *SymDense) SymRankOne(a Symmetric, alpha float64, x Vector) {
326	n := x.Len()
327	if a.Symmetric() != n {
328		panic(ErrShape)
329	}
330	s.reuseAsNonZeroed(n)
331
332	if s != a {
333		if rs, ok := a.(RawSymmetricer); ok {
334			s.checkOverlap(generalFromSymmetric(rs.RawSymmetric()))
335		}
336		s.CopySym(a)
337	}
338
339	xU, _ := untransposeExtract(x)
340	if rv, ok := xU.(*VecDense); ok {
341		r, c := xU.Dims()
342		xmat := rv.mat
343		s.checkOverlap(generalFromVector(xmat, r, c))
344		blas64.Syr(alpha, xmat, s.mat)
345		return
346	}
347
348	for i := 0; i < n; i++ {
349		for j := i; j < n; j++ {
350			s.set(i, j, s.at(i, j)+alpha*x.AtVec(i)*x.AtVec(j))
351		}
352	}
353}
354
355// SymRankK performs a symmetric rank-k update to the matrix a and stores the
356// result into the receiver. If a is zero, see SymOuterK.
357//  s = a + alpha * x * x'
358func (s *SymDense) SymRankK(a Symmetric, alpha float64, x Matrix) {
359	n := a.Symmetric()
360	r, _ := x.Dims()
361	if r != n {
362		panic(ErrShape)
363	}
364	xMat, aTrans := untransposeExtract(x)
365	var g blas64.General
366	if rm, ok := xMat.(*Dense); ok {
367		g = rm.mat
368	} else {
369		g = DenseCopyOf(x).mat
370		aTrans = false
371	}
372	if a != s {
373		if rs, ok := a.(RawSymmetricer); ok {
374			s.checkOverlap(generalFromSymmetric(rs.RawSymmetric()))
375		}
376		s.reuseAsNonZeroed(n)
377		s.CopySym(a)
378	}
379	t := blas.NoTrans
380	if aTrans {
381		t = blas.Trans
382	}
383	blas64.Syrk(t, alpha, g, 1, s.mat)
384}
385
386// SymOuterK calculates the outer product of x with itself and stores
387// the result into the receiver. It is equivalent to the matrix
388// multiplication
389//  s = alpha * x * x'.
390// In order to update an existing matrix, see SymRankOne.
391func (s *SymDense) SymOuterK(alpha float64, x Matrix) {
392	n, _ := x.Dims()
393	switch {
394	case s.IsEmpty():
395		s.mat = blas64.Symmetric{
396			N:      n,
397			Stride: n,
398			Data:   useZeroed(s.mat.Data, n*n),
399			Uplo:   blas.Upper,
400		}
401		s.cap = n
402		s.SymRankK(s, alpha, x)
403	case s.mat.Uplo != blas.Upper:
404		panic(badSymTriangle)
405	case s.mat.N == n:
406		if s == x {
407			w := getWorkspaceSym(n, true)
408			w.SymRankK(w, alpha, x)
409			s.CopySym(w)
410			putWorkspaceSym(w)
411		} else {
412			switch r := x.(type) {
413			case RawMatrixer:
414				s.checkOverlap(r.RawMatrix())
415			case RawSymmetricer:
416				s.checkOverlap(generalFromSymmetric(r.RawSymmetric()))
417			case RawTriangular:
418				s.checkOverlap(generalFromTriangular(r.RawTriangular()))
419			}
420			// Only zero the upper triangle.
421			for i := 0; i < n; i++ {
422				ri := i * s.mat.Stride
423				zero(s.mat.Data[ri+i : ri+n])
424			}
425			s.SymRankK(s, alpha, x)
426		}
427	default:
428		panic(ErrShape)
429	}
430}
431
432// RankTwo performs a symmetric rank-two update to the matrix a with the
433// vectors x and y, which are treated as column vectors, and stores the
434// result in the receiver
435//  m = a + alpha * (x * yᵀ + y * xᵀ)
436func (s *SymDense) RankTwo(a Symmetric, alpha float64, x, y Vector) {
437	n := s.mat.N
438	if x.Len() != n {
439		panic(ErrShape)
440	}
441	if y.Len() != n {
442		panic(ErrShape)
443	}
444
445	if s != a {
446		if rs, ok := a.(RawSymmetricer); ok {
447			s.checkOverlap(generalFromSymmetric(rs.RawSymmetric()))
448		}
449	}
450
451	var xmat, ymat blas64.Vector
452	fast := true
453	xU, _ := untransposeExtract(x)
454	if rv, ok := xU.(*VecDense); ok {
455		r, c := xU.Dims()
456		xmat = rv.mat
457		s.checkOverlap(generalFromVector(xmat, r, c))
458	} else {
459		fast = false
460	}
461	yU, _ := untransposeExtract(y)
462	if rv, ok := yU.(*VecDense); ok {
463		r, c := yU.Dims()
464		ymat = rv.mat
465		s.checkOverlap(generalFromVector(ymat, r, c))
466	} else {
467		fast = false
468	}
469
470	if s != a {
471		if rs, ok := a.(RawSymmetricer); ok {
472			s.checkOverlap(generalFromSymmetric(rs.RawSymmetric()))
473		}
474		s.reuseAsNonZeroed(n)
475		s.CopySym(a)
476	}
477
478	if fast {
479		if s != a {
480			s.reuseAsNonZeroed(n)
481			s.CopySym(a)
482		}
483		blas64.Syr2(alpha, xmat, ymat, s.mat)
484		return
485	}
486
487	for i := 0; i < n; i++ {
488		s.reuseAsNonZeroed(n)
489		for j := i; j < n; j++ {
490			s.set(i, j, a.At(i, j)+alpha*(x.AtVec(i)*y.AtVec(j)+y.AtVec(i)*x.AtVec(j)))
491		}
492	}
493}
494
495// ScaleSym multiplies the elements of a by f, placing the result in the receiver.
496func (s *SymDense) ScaleSym(f float64, a Symmetric) {
497	n := a.Symmetric()
498	s.reuseAsNonZeroed(n)
499	if a, ok := a.(RawSymmetricer); ok {
500		amat := a.RawSymmetric()
501		if s != a {
502			s.checkOverlap(generalFromSymmetric(amat))
503		}
504		for i := 0; i < n; i++ {
505			for j := i; j < n; j++ {
506				s.mat.Data[i*s.mat.Stride+j] = f * amat.Data[i*amat.Stride+j]
507			}
508		}
509		return
510	}
511	for i := 0; i < n; i++ {
512		for j := i; j < n; j++ {
513			s.mat.Data[i*s.mat.Stride+j] = f * a.At(i, j)
514		}
515	}
516}
517
518// SubsetSym extracts a subset of the rows and columns of the matrix a and stores
519// the result in-place into the receiver. The resulting matrix size is
520// len(set)×len(set). Specifically, at the conclusion of SubsetSym,
521// s.At(i, j) equals a.At(set[i], set[j]). Note that the supplied set does not
522// have to be a strict subset, dimension repeats are allowed.
523func (s *SymDense) SubsetSym(a Symmetric, set []int) {
524	n := len(set)
525	na := a.Symmetric()
526	s.reuseAsNonZeroed(n)
527	var restore func()
528	if a == s {
529		s, restore = s.isolatedWorkspace(a)
530		defer restore()
531	}
532
533	if a, ok := a.(RawSymmetricer); ok {
534		raw := a.RawSymmetric()
535		if s != a {
536			s.checkOverlap(generalFromSymmetric(raw))
537		}
538		for i := 0; i < n; i++ {
539			ssub := s.mat.Data[i*s.mat.Stride : i*s.mat.Stride+n]
540			r := set[i]
541			rsub := raw.Data[r*raw.Stride : r*raw.Stride+na]
542			for j := i; j < n; j++ {
543				c := set[j]
544				if r <= c {
545					ssub[j] = rsub[c]
546				} else {
547					ssub[j] = raw.Data[c*raw.Stride+r]
548				}
549			}
550		}
551		return
552	}
553	for i := 0; i < n; i++ {
554		for j := i; j < n; j++ {
555			s.mat.Data[i*s.mat.Stride+j] = a.At(set[i], set[j])
556		}
557	}
558}
559
560// SliceSym returns a new Matrix that shares backing data with the receiver.
561// The returned matrix starts at {i,i} of the receiver and extends k-i rows
562// and columns. The final row and column in the resulting matrix is k-1.
563// SliceSym panics with ErrIndexOutOfRange if the slice is outside the
564// capacity of the receiver.
565func (s *SymDense) SliceSym(i, k int) Symmetric {
566	return s.sliceSym(i, k)
567}
568
569func (s *SymDense) sliceSym(i, k int) *SymDense {
570	sz := s.cap
571	if i < 0 || sz < i || k < i || sz < k {
572		panic(ErrIndexOutOfRange)
573	}
574	v := *s
575	v.mat.Data = s.mat.Data[i*s.mat.Stride+i : (k-1)*s.mat.Stride+k]
576	v.mat.N = k - i
577	v.cap = s.cap - i
578	return &v
579}
580
581// Trace returns the trace of the matrix.
582func (s *SymDense) Trace() float64 {
583	// TODO(btracey): could use internal asm sum routine.
584	var v float64
585	for i := 0; i < s.mat.N; i++ {
586		v += s.mat.Data[i*s.mat.Stride+i]
587	}
588	return v
589}
590
591// GrowSym returns the receiver expanded by n rows and n columns. If the
592// dimensions of the expanded matrix are outside the capacity of the receiver
593// a new allocation is made, otherwise not. Note that the receiver itself is
594// not modified during the call to GrowSquare.
595func (s *SymDense) GrowSym(n int) Symmetric {
596	if n < 0 {
597		panic(ErrIndexOutOfRange)
598	}
599	if n == 0 {
600		return s
601	}
602	var v SymDense
603	n += s.mat.N
604	if n > s.cap {
605		v.mat = blas64.Symmetric{
606			N:      n,
607			Stride: n,
608			Uplo:   blas.Upper,
609			Data:   make([]float64, n*n),
610		}
611		v.cap = n
612		// Copy elements, including those not currently visible. Use a temporary
613		// structure to avoid modifying the receiver.
614		var tmp SymDense
615		tmp.mat = blas64.Symmetric{
616			N:      s.cap,
617			Stride: s.mat.Stride,
618			Data:   s.mat.Data,
619			Uplo:   s.mat.Uplo,
620		}
621		tmp.cap = s.cap
622		v.CopySym(&tmp)
623		return &v
624	}
625	v.mat = blas64.Symmetric{
626		N:      n,
627		Stride: s.mat.Stride,
628		Uplo:   blas.Upper,
629		Data:   s.mat.Data[:(n-1)*s.mat.Stride+n],
630	}
631	v.cap = s.cap
632	return &v
633}
634
635// PowPSD computes a^pow where a is a positive symmetric definite matrix.
636//
637// PowPSD returns an error if the matrix is not not positive symmetric definite
638// or the Eigen decomposition is not successful.
639func (s *SymDense) PowPSD(a Symmetric, pow float64) error {
640	dim := a.Symmetric()
641	s.reuseAsNonZeroed(dim)
642
643	var eigen EigenSym
644	ok := eigen.Factorize(a, true)
645	if !ok {
646		return ErrFailedEigen
647	}
648	values := eigen.Values(nil)
649	for i, v := range values {
650		if v <= 0 {
651			return ErrNotPSD
652		}
653		values[i] = math.Pow(v, pow)
654	}
655	var u Dense
656	eigen.VectorsTo(&u)
657
658	s.SymOuterK(values[0], u.ColView(0))
659
660	var v VecDense
661	for i := 1; i < dim; i++ {
662		v.ColViewOf(&u, i)
663		s.SymRankOne(s, values[i], &v)
664	}
665	return nil
666}
667