1// Copyright ©2013 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package mat
6
7import (
8	"math"
9
10	"gonum.org/v1/gonum/blas"
11	"gonum.org/v1/gonum/blas/blas64"
12	"gonum.org/v1/gonum/floats/scalar"
13	"gonum.org/v1/gonum/lapack"
14	"gonum.org/v1/gonum/lapack/lapack64"
15)
16
// Matrix is the basic matrix interface type.
type Matrix interface {
	// Dims reports the number of rows r and columns c of the Matrix.
	Dims() (r, c int)

	// At returns the element stored at row i, column j. Implementations
	// panic when i or j falls outside the bounds of the matrix.
	At(i, j int) float64

	// T returns the transpose of the Matrix. Each implementation decides
	// whether the result copies the underlying data. A convenient way to
	// implement T is to return a Transpose value, which provides an
	// implicit matrix transpose.
	T() Matrix
}
32
33// allMatrix represents the extra set of methods that all mat Matrix types
34// should satisfy. This is used to enforce compile-time consistency between the
35// Dense types, especially helpful when adding new features.
36type allMatrix interface {
37	Reseter
38	IsEmpty() bool
39	Zero()
40}
41
42// denseMatrix represents the extra set of methods that all Dense Matrix types
43// should satisfy. This is used to enforce compile-time consistency between the
44// Dense types, especially helpful when adding new features.
45type denseMatrix interface {
46	DiagView() Diagonal
47	Tracer
48}
49
50var (
51	_ Matrix       = Transpose{}
52	_ Untransposer = Transpose{}
53)
54
55// Transpose is a type for performing an implicit matrix transpose. It implements
56// the Matrix interface, returning values from the transpose of the matrix within.
57type Transpose struct {
58	Matrix Matrix
59}
60
61// At returns the value of the element at row i and column j of the transposed
62// matrix, that is, row j and column i of the Matrix field.
63func (t Transpose) At(i, j int) float64 {
64	return t.Matrix.At(j, i)
65}
66
67// Dims returns the dimensions of the transposed matrix. The number of rows returned
68// is the number of columns in the Matrix field, and the number of columns is
69// the number of rows in the Matrix field.
70func (t Transpose) Dims() (r, c int) {
71	c, r = t.Matrix.Dims()
72	return r, c
73}
74
75// T performs an implicit transpose by returning the Matrix field.
76func (t Transpose) T() Matrix {
77	return t.Matrix
78}
79
80// Untranspose returns the Matrix field.
81func (t Transpose) Untranspose() Matrix {
82	return t.Matrix
83}
84
85// Untransposer is a type that can undo an implicit transpose.
86type Untransposer interface {
87	// Note: This interface is needed to unify all of the Transpose types. In
88	// the mat methods, we need to test if the Matrix has been implicitly
89	// transposed. If this is checked by testing for the specific Transpose type
90	// then the behavior will be different if the user uses T() or TTri() for a
91	// triangular matrix.
92
93	// Untranspose returns the underlying Matrix stored for the implicit transpose.
94	Untranspose() Matrix
95}
96
97// UntransposeBander is a type that can undo an implicit band transpose.
98type UntransposeBander interface {
99	// Untranspose returns the underlying Banded stored for the implicit transpose.
100	UntransposeBand() Banded
101}
102
103// UntransposeTrier is a type that can undo an implicit triangular transpose.
104type UntransposeTrier interface {
105	// Untranspose returns the underlying Triangular stored for the implicit transpose.
106	UntransposeTri() Triangular
107}
108
109// UntransposeTriBander is a type that can undo an implicit triangular banded
110// transpose.
111type UntransposeTriBander interface {
112	// Untranspose returns the underlying Triangular stored for the implicit transpose.
113	UntransposeTriBand() TriBanded
114}
115
116// Mutable is a matrix interface type that allows elements to be altered.
117type Mutable interface {
118	// Set alters the matrix element at row i, column j to v.
119	// It will panic if i or j are out of bounds for the matrix.
120	Set(i, j int, v float64)
121
122	Matrix
123}
124
125// A RowViewer can return a Vector reflecting a row that is backed by the matrix
126// data. The Vector returned will have length equal to the number of columns.
127type RowViewer interface {
128	RowView(i int) Vector
129}
130
131// A RawRowViewer can return a slice of float64 reflecting a row that is backed by the matrix
132// data.
133type RawRowViewer interface {
134	RawRowView(i int) []float64
135}
136
137// A ColViewer can return a Vector reflecting a column that is backed by the matrix
138// data. The Vector returned will have length equal to the number of rows.
139type ColViewer interface {
140	ColView(j int) Vector
141}
142
143// A RawColViewer can return a slice of float64 reflecting a column that is backed by the matrix
144// data.
145type RawColViewer interface {
146	RawColView(j int) []float64
147}
148
149// A ClonerFrom can make a copy of a into the receiver, overwriting the previous value of the
150// receiver. The clone operation does not make any restriction on shape and will not cause
151// shadowing.
152type ClonerFrom interface {
153	CloneFrom(a Matrix)
154}
155
156// A Reseter can reset the matrix so that it can be reused as the receiver of a dimensionally
157// restricted operation. This is commonly used when the matrix is being used as a workspace
158// or temporary matrix.
159//
160// If the matrix is a view, using Reset may result in data corruption in elements outside
161// the view. Similarly, if the matrix shares backing data with another variable, using
162// Reset may lead to unexpected changes in data values.
163type Reseter interface {
164	Reset()
165}
166
167// A Copier can make a copy of elements of a into the receiver. The submatrix copied
168// starts at row and column 0 and has dimensions equal to the minimum dimensions of
169// the two matrices. The number of row and columns copied is returned.
170// Copy will copy from a source that aliases the receiver unless the source is transposed;
171// an aliasing transpose copy will panic with the exception for a special case when
172// the source data has a unitary increment or stride.
173type Copier interface {
174	Copy(a Matrix) (r, c int)
175}
176
177// A Grower can grow the size of the represented matrix by the given number of rows and columns.
178// Growing beyond the size given by the Caps method will result in the allocation of a new
179// matrix and copying of the elements. If Grow is called with negative increments it will
180// panic with ErrIndexOutOfRange.
181type Grower interface {
182	Caps() (r, c int)
183	Grow(r, c int) Matrix
184}
185
186// A BandWidther represents a banded matrix and can return the left and right half-bandwidths, k1 and
187// k2.
188type BandWidther interface {
189	BandWidth() (k1, k2 int)
190}
191
192// A RawMatrixSetter can set the underlying blas64.General used by the receiver. There is no restriction
193// on the shape of the receiver. Changes to the receiver's elements will be reflected in the blas64.General.Data.
194type RawMatrixSetter interface {
195	SetRawMatrix(a blas64.General)
196}
197
198// A RawMatrixer can return a blas64.General representation of the receiver. Changes to the blas64.General.Data
199// slice will be reflected in the original matrix, changes to the Rows, Cols and Stride fields will not.
200type RawMatrixer interface {
201	RawMatrix() blas64.General
202}
203
204// A RawVectorer can return a blas64.Vector representation of the receiver. Changes to the blas64.Vector.Data
205// slice will be reflected in the original matrix, changes to the Inc field will not.
206type RawVectorer interface {
207	RawVector() blas64.Vector
208}
209
// A NonZeroDoer calls fn for each non-zero element of the receiver, passing
// the element's row index, column index and value.
type NonZeroDoer interface {
	DoNonZero(fn func(i, j int, v float64))
}

// A RowNonZeroDoer calls fn for each non-zero element in row i of the
// receiver, passing the element's indices and value.
type RowNonZeroDoer interface {
	DoRowNonZero(i int, fn func(i, j int, v float64))
}

// A ColNonZeroDoer calls fn for each non-zero element in column j of the
// receiver, passing the element's indices and value.
type ColNonZeroDoer interface {
	DoColNonZero(j int, fn func(i, j int, v float64))
}
227
228// untranspose untransposes a matrix if applicable. If a is an Untransposer, then
229// untranspose returns the underlying matrix and true. If it is not, then it returns
230// the input matrix and false.
231func untranspose(a Matrix) (Matrix, bool) {
232	if ut, ok := a.(Untransposer); ok {
233		return ut.Untranspose(), true
234	}
235	return a, false
236}
237
// untransposeExtract returns an untransposed matrix in a built-in matrix type.
//
// The untransposed matrix is returned unaltered if it is a built-in matrix type.
// Otherwise, if it implements a Raw method, an appropriate built-in type value
// is returned holding the raw matrix value of the input. If neither of these
// is possible, the untransposed matrix is returned.
//
// The boolean result is the one reported by untranspose: true when a was an
// Untransposer and has been unwrapped.
//
// Some raw layouts are not converted and are returned unchanged:
//   - symmetric and symmetric band storage other than blas.Upper
//   - triangular and triangular band storage with a blas.Unit diagonal
//
// The cases of the type switch are tried in order, so a value implementing
// several Raw interfaces is converted according to the first case it matches.
func untransposeExtract(a Matrix) (Matrix, bool) {
	ut, trans := untranspose(a)
	switch m := ut.(type) {
	case *DiagDense, *SymBandDense, *TriBandDense, *BandDense, *TriDense, *SymDense, *Dense, *VecDense:
		// Already a built-in type; nothing to convert.
		return m, trans
	// TODO(btracey): Add here if we ever have an equivalent of RawDiagDense.
	case RawSymBander:
		rsb := m.RawSymBand()
		// Only Upper storage is converted; other layouts are returned as is.
		if rsb.Uplo != blas.Upper {
			return ut, trans
		}
		var sb SymBandDense
		sb.SetRawSymBand(rsb)
		return &sb, trans
	case RawTriBander:
		rtb := m.RawTriBand()
		// Unit-diagonal storage is returned as is rather than converted.
		if rtb.Diag == blas.Unit {
			return ut, trans
		}
		var tb TriBandDense
		tb.SetRawTriBand(rtb)
		return &tb, trans
	case RawBander:
		var b BandDense
		b.SetRawBand(m.RawBand())
		return &b, trans
	case RawTriangular:
		rt := m.RawTriangular()
		// Unit-diagonal storage is returned as is rather than converted.
		if rt.Diag == blas.Unit {
			return ut, trans
		}
		var t TriDense
		t.SetRawTriangular(rt)
		return &t, trans
	case RawSymmetricer:
		rs := m.RawSymmetric()
		// Only Upper storage is converted; other layouts are returned as is.
		if rs.Uplo != blas.Upper {
			return ut, trans
		}
		var s SymDense
		s.SetRawSymmetric(rs)
		return &s, trans
	case RawMatrixer:
		var d Dense
		d.SetRawMatrix(m.RawMatrix())
		return &d, trans
	case RawVectorer:
		var v VecDense
		v.SetRawVector(m.RawVector())
		return &v, trans
	default:
		// No built-in representation is available.
		return ut, trans
	}
}
298
299// TODO(btracey): Consider adding CopyCol/CopyRow if the behavior seems useful.
300// TODO(btracey): Add in fast paths to Row/Col for the other concrete types
301// (TriDense, etc.) as well as relevant interfaces (RowColer, RawRowViewer, etc.)
302
303// Col copies the elements in the jth column of the matrix into the slice dst.
304// The length of the provided slice must equal the number of rows, unless the
305// slice is nil in which case a new slice is first allocated.
306func Col(dst []float64, j int, a Matrix) []float64 {
307	r, c := a.Dims()
308	if j < 0 || j >= c {
309		panic(ErrColAccess)
310	}
311	if dst == nil {
312		dst = make([]float64, r)
313	} else {
314		if len(dst) != r {
315			panic(ErrColLength)
316		}
317	}
318	aU, aTrans := untranspose(a)
319	if rm, ok := aU.(RawMatrixer); ok {
320		m := rm.RawMatrix()
321		if aTrans {
322			copy(dst, m.Data[j*m.Stride:j*m.Stride+m.Cols])
323			return dst
324		}
325		blas64.Copy(blas64.Vector{N: r, Inc: m.Stride, Data: m.Data[j:]},
326			blas64.Vector{N: r, Inc: 1, Data: dst},
327		)
328		return dst
329	}
330	for i := 0; i < r; i++ {
331		dst[i] = a.At(i, j)
332	}
333	return dst
334}
335
336// Row copies the elements in the ith row of the matrix into the slice dst.
337// The length of the provided slice must equal the number of columns, unless the
338// slice is nil in which case a new slice is first allocated.
339func Row(dst []float64, i int, a Matrix) []float64 {
340	r, c := a.Dims()
341	if i < 0 || i >= r {
342		panic(ErrColAccess)
343	}
344	if dst == nil {
345		dst = make([]float64, c)
346	} else {
347		if len(dst) != c {
348			panic(ErrRowLength)
349		}
350	}
351	aU, aTrans := untranspose(a)
352	if rm, ok := aU.(RawMatrixer); ok {
353		m := rm.RawMatrix()
354		if aTrans {
355			blas64.Copy(blas64.Vector{N: c, Inc: m.Stride, Data: m.Data[i:]},
356				blas64.Vector{N: c, Inc: 1, Data: dst},
357			)
358			return dst
359		}
360		copy(dst, m.Data[i*m.Stride:i*m.Stride+m.Cols])
361		return dst
362	}
363	for j := 0; j < c; j++ {
364		dst[j] = a.At(i, j)
365	}
366	return dst
367}
368
369// Cond returns the condition number of the given matrix under the given norm.
370// The condition number must be based on the 1-norm, 2-norm or ∞-norm.
371// Cond will panic with matrix.ErrShape if the matrix has zero size.
372//
373// BUG(btracey): The computation of the 1-norm and ∞-norm for non-square matrices
374// is inaccurate, although is typically the right order of magnitude. See
375// https://github.com/xianyi/OpenBLAS/issues/636. While the value returned will
376// change with the resolution of this bug, the result from Cond will match the
377// condition number used internally.
378func Cond(a Matrix, norm float64) float64 {
379	m, n := a.Dims()
380	if m == 0 || n == 0 {
381		panic(ErrShape)
382	}
383	var lnorm lapack.MatrixNorm
384	switch norm {
385	default:
386		panic("mat: bad norm value")
387	case 1:
388		lnorm = lapack.MaxColumnSum
389	case 2:
390		var svd SVD
391		ok := svd.Factorize(a, SVDNone)
392		if !ok {
393			return math.Inf(1)
394		}
395		return svd.Cond()
396	case math.Inf(1):
397		lnorm = lapack.MaxRowSum
398	}
399
400	if m == n {
401		// Use the LU decomposition to compute the condition number.
402		var lu LU
403		lu.factorize(a, lnorm)
404		return lu.Cond()
405	}
406	if m > n {
407		// Use the QR factorization to compute the condition number.
408		var qr QR
409		qr.factorize(a, lnorm)
410		return qr.Cond()
411	}
412	// Use the LQ factorization to compute the condition number.
413	var lq LQ
414	lq.factorize(a, lnorm)
415	return lq.Cond()
416}
417
418// Det returns the determinant of the matrix a. In many expressions using LogDet
419// will be more numerically stable.
420func Det(a Matrix) float64 {
421	det, sign := LogDet(a)
422	return math.Exp(det) * sign
423}
424
425// Dot returns the sum of the element-wise product of a and b.
426// Dot panics if the matrix sizes are unequal.
427func Dot(a, b Vector) float64 {
428	la := a.Len()
429	lb := b.Len()
430	if la != lb {
431		panic(ErrShape)
432	}
433	if arv, ok := a.(RawVectorer); ok {
434		if brv, ok := b.(RawVectorer); ok {
435			return blas64.Dot(arv.RawVector(), brv.RawVector())
436		}
437	}
438	var sum float64
439	for i := 0; i < la; i++ {
440		sum += a.At(i, 0) * b.At(i, 0)
441	}
442	return sum
443}
444
445// Equal returns whether the matrices a and b have the same size
446// and are element-wise equal.
447func Equal(a, b Matrix) bool {
448	ar, ac := a.Dims()
449	br, bc := b.Dims()
450	if ar != br || ac != bc {
451		return false
452	}
453	aU, aTrans := untranspose(a)
454	bU, bTrans := untranspose(b)
455	if rma, ok := aU.(RawMatrixer); ok {
456		if rmb, ok := bU.(RawMatrixer); ok {
457			ra := rma.RawMatrix()
458			rb := rmb.RawMatrix()
459			if aTrans == bTrans {
460				for i := 0; i < ra.Rows; i++ {
461					for j := 0; j < ra.Cols; j++ {
462						if ra.Data[i*ra.Stride+j] != rb.Data[i*rb.Stride+j] {
463							return false
464						}
465					}
466				}
467				return true
468			}
469			for i := 0; i < ra.Rows; i++ {
470				for j := 0; j < ra.Cols; j++ {
471					if ra.Data[i*ra.Stride+j] != rb.Data[j*rb.Stride+i] {
472						return false
473					}
474				}
475			}
476			return true
477		}
478	}
479	if rma, ok := aU.(RawSymmetricer); ok {
480		if rmb, ok := bU.(RawSymmetricer); ok {
481			ra := rma.RawSymmetric()
482			rb := rmb.RawSymmetric()
483			// Symmetric matrices are always upper and equal to their transpose.
484			for i := 0; i < ra.N; i++ {
485				for j := i; j < ra.N; j++ {
486					if ra.Data[i*ra.Stride+j] != rb.Data[i*rb.Stride+j] {
487						return false
488					}
489				}
490			}
491			return true
492		}
493	}
494	if ra, ok := aU.(*VecDense); ok {
495		if rb, ok := bU.(*VecDense); ok {
496			// If the raw vectors are the same length they must either both be
497			// transposed or both not transposed (or have length 1).
498			for i := 0; i < ra.mat.N; i++ {
499				if ra.mat.Data[i*ra.mat.Inc] != rb.mat.Data[i*rb.mat.Inc] {
500					return false
501				}
502			}
503			return true
504		}
505	}
506	for i := 0; i < ar; i++ {
507		for j := 0; j < ac; j++ {
508			if a.At(i, j) != b.At(i, j) {
509				return false
510			}
511		}
512	}
513	return true
514}
515
516// EqualApprox returns whether the matrices a and b have the same size and contain all equal
517// elements with tolerance for element-wise equality specified by epsilon. Matrices
518// with non-equal shapes are not equal.
519func EqualApprox(a, b Matrix, epsilon float64) bool {
520	ar, ac := a.Dims()
521	br, bc := b.Dims()
522	if ar != br || ac != bc {
523		return false
524	}
525	aU, aTrans := untranspose(a)
526	bU, bTrans := untranspose(b)
527	if rma, ok := aU.(RawMatrixer); ok {
528		if rmb, ok := bU.(RawMatrixer); ok {
529			ra := rma.RawMatrix()
530			rb := rmb.RawMatrix()
531			if aTrans == bTrans {
532				for i := 0; i < ra.Rows; i++ {
533					for j := 0; j < ra.Cols; j++ {
534						if !scalar.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[i*rb.Stride+j], epsilon, epsilon) {
535							return false
536						}
537					}
538				}
539				return true
540			}
541			for i := 0; i < ra.Rows; i++ {
542				for j := 0; j < ra.Cols; j++ {
543					if !scalar.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[j*rb.Stride+i], epsilon, epsilon) {
544						return false
545					}
546				}
547			}
548			return true
549		}
550	}
551	if rma, ok := aU.(RawSymmetricer); ok {
552		if rmb, ok := bU.(RawSymmetricer); ok {
553			ra := rma.RawSymmetric()
554			rb := rmb.RawSymmetric()
555			// Symmetric matrices are always upper and equal to their transpose.
556			for i := 0; i < ra.N; i++ {
557				for j := i; j < ra.N; j++ {
558					if !scalar.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[i*rb.Stride+j], epsilon, epsilon) {
559						return false
560					}
561				}
562			}
563			return true
564		}
565	}
566	if ra, ok := aU.(*VecDense); ok {
567		if rb, ok := bU.(*VecDense); ok {
568			// If the raw vectors are the same length they must either both be
569			// transposed or both not transposed (or have length 1).
570			for i := 0; i < ra.mat.N; i++ {
571				if !scalar.EqualWithinAbsOrRel(ra.mat.Data[i*ra.mat.Inc], rb.mat.Data[i*rb.mat.Inc], epsilon, epsilon) {
572					return false
573				}
574			}
575			return true
576		}
577	}
578	for i := 0; i < ar; i++ {
579		for j := 0; j < ac; j++ {
580			if !scalar.EqualWithinAbsOrRel(a.At(i, j), b.At(i, j), epsilon, epsilon) {
581				return false
582			}
583		}
584	}
585	return true
586}
587
588// LogDet returns the log of the determinant and the sign of the determinant
589// for the matrix that has been factorized. Numerical stability in product and
590// division expressions is generally improved by working in log space.
591func LogDet(a Matrix) (det float64, sign float64) {
592	// TODO(btracey): Add specialized routines for TriDense, etc.
593	var lu LU
594	lu.Factorize(a)
595	return lu.LogDet()
596}
597
598// Max returns the largest element value of the matrix A.
599// Max will panic with matrix.ErrShape if the matrix has zero size.
600func Max(a Matrix) float64 {
601	r, c := a.Dims()
602	if r == 0 || c == 0 {
603		panic(ErrShape)
604	}
605	// Max(A) = Max(Aᵀ)
606	aU, _ := untranspose(a)
607	switch m := aU.(type) {
608	case RawMatrixer:
609		rm := m.RawMatrix()
610		max := math.Inf(-1)
611		for i := 0; i < rm.Rows; i++ {
612			for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] {
613				if v > max {
614					max = v
615				}
616			}
617		}
618		return max
619	case RawTriangular:
620		rm := m.RawTriangular()
621		// The max of a triangular is at least 0 unless the size is 1.
622		if rm.N == 1 {
623			return rm.Data[0]
624		}
625		max := 0.0
626		if rm.Uplo == blas.Upper {
627			for i := 0; i < rm.N; i++ {
628				for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
629					if v > max {
630						max = v
631					}
632				}
633			}
634			return max
635		}
636		for i := 0; i < rm.N; i++ {
637			for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+i+1] {
638				if v > max {
639					max = v
640				}
641			}
642		}
643		return max
644	case RawSymmetricer:
645		rm := m.RawSymmetric()
646		if rm.Uplo != blas.Upper {
647			panic(badSymTriangle)
648		}
649		max := math.Inf(-1)
650		for i := 0; i < rm.N; i++ {
651			for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
652				if v > max {
653					max = v
654				}
655			}
656		}
657		return max
658	default:
659		r, c := aU.Dims()
660		max := math.Inf(-1)
661		for i := 0; i < r; i++ {
662			for j := 0; j < c; j++ {
663				v := aU.At(i, j)
664				if v > max {
665					max = v
666				}
667			}
668		}
669		return max
670	}
671}
672
673// Min returns the smallest element value of the matrix A.
674// Min will panic with matrix.ErrShape if the matrix has zero size.
675func Min(a Matrix) float64 {
676	r, c := a.Dims()
677	if r == 0 || c == 0 {
678		panic(ErrShape)
679	}
680	// Min(A) = Min(Aᵀ)
681	aU, _ := untranspose(a)
682	switch m := aU.(type) {
683	case RawMatrixer:
684		rm := m.RawMatrix()
685		min := math.Inf(1)
686		for i := 0; i < rm.Rows; i++ {
687			for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] {
688				if v < min {
689					min = v
690				}
691			}
692		}
693		return min
694	case RawTriangular:
695		rm := m.RawTriangular()
696		// The min of a triangular is at most 0 unless the size is 1.
697		if rm.N == 1 {
698			return rm.Data[0]
699		}
700		min := 0.0
701		if rm.Uplo == blas.Upper {
702			for i := 0; i < rm.N; i++ {
703				for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
704					if v < min {
705						min = v
706					}
707				}
708			}
709			return min
710		}
711		for i := 0; i < rm.N; i++ {
712			for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+i+1] {
713				if v < min {
714					min = v
715				}
716			}
717		}
718		return min
719	case RawSymmetricer:
720		rm := m.RawSymmetric()
721		if rm.Uplo != blas.Upper {
722			panic(badSymTriangle)
723		}
724		min := math.Inf(1)
725		for i := 0; i < rm.N; i++ {
726			for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
727				if v < min {
728					min = v
729				}
730			}
731		}
732		return min
733	default:
734		r, c := aU.Dims()
735		min := math.Inf(1)
736		for i := 0; i < r; i++ {
737			for j := 0; j < c; j++ {
738				v := aU.At(i, j)
739				if v < min {
740					min = v
741				}
742			}
743		}
744		return min
745	}
746}
747
748// Norm returns the specified norm of the matrix A. Valid norms are:
749//  1 - The maximum absolute column sum
750//  2 - The Frobenius norm, the square root of the sum of the squares of the elements
751//  Inf - The maximum absolute row sum
752//
753// Norm will panic with ErrNormOrder if an illegal norm order is specified and
754// with ErrShape if the matrix has zero size.
755func Norm(a Matrix, norm float64) float64 {
756	r, c := a.Dims()
757	if r == 0 || c == 0 {
758		panic(ErrShape)
759	}
760	aU, aTrans := untranspose(a)
761	var work []float64
762	switch rma := aU.(type) {
763	case RawMatrixer:
764		rm := rma.RawMatrix()
765		n := normLapack(norm, aTrans)
766		if n == lapack.MaxColumnSum {
767			work = getFloats(rm.Cols, false)
768			defer putFloats(work)
769		}
770		return lapack64.Lange(n, rm, work)
771	case RawTriangular:
772		rm := rma.RawTriangular()
773		n := normLapack(norm, aTrans)
774		if n == lapack.MaxRowSum || n == lapack.MaxColumnSum {
775			work = getFloats(rm.N, false)
776			defer putFloats(work)
777		}
778		return lapack64.Lantr(n, rm, work)
779	case RawSymmetricer:
780		rm := rma.RawSymmetric()
781		n := normLapack(norm, aTrans)
782		if n == lapack.MaxRowSum || n == lapack.MaxColumnSum {
783			work = getFloats(rm.N, false)
784			defer putFloats(work)
785		}
786		return lapack64.Lansy(n, rm, work)
787	case *VecDense:
788		rv := rma.RawVector()
789		switch norm {
790		default:
791			panic(ErrNormOrder)
792		case 1:
793			if aTrans {
794				imax := blas64.Iamax(rv)
795				return math.Abs(rma.At(imax, 0))
796			}
797			return blas64.Asum(rv)
798		case 2:
799			return blas64.Nrm2(rv)
800		case math.Inf(1):
801			if aTrans {
802				return blas64.Asum(rv)
803			}
804			imax := blas64.Iamax(rv)
805			return math.Abs(rma.At(imax, 0))
806		}
807	}
808	switch norm {
809	default:
810		panic(ErrNormOrder)
811	case 1:
812		var max float64
813		for j := 0; j < c; j++ {
814			var sum float64
815			for i := 0; i < r; i++ {
816				sum += math.Abs(a.At(i, j))
817			}
818			if sum > max {
819				max = sum
820			}
821		}
822		return max
823	case 2:
824		var sum float64
825		for i := 0; i < r; i++ {
826			for j := 0; j < c; j++ {
827				v := a.At(i, j)
828				sum += v * v
829			}
830		}
831		return math.Sqrt(sum)
832	case math.Inf(1):
833		var max float64
834		for i := 0; i < r; i++ {
835			var sum float64
836			for j := 0; j < c; j++ {
837				sum += math.Abs(a.At(i, j))
838			}
839			if sum > max {
840				max = sum
841			}
842		}
843		return max
844	}
845}
846
847// normLapack converts the float64 norm input in Norm to a lapack.MatrixNorm.
848func normLapack(norm float64, aTrans bool) lapack.MatrixNorm {
849	switch norm {
850	case 1:
851		n := lapack.MaxColumnSum
852		if aTrans {
853			n = lapack.MaxRowSum
854		}
855		return n
856	case 2:
857		return lapack.Frobenius
858	case math.Inf(1):
859		n := lapack.MaxRowSum
860		if aTrans {
861			n = lapack.MaxColumnSum
862		}
863		return n
864	default:
865		panic(ErrNormOrder)
866	}
867}
868
869// Sum returns the sum of the elements of the matrix.
870func Sum(a Matrix) float64 {
871
872	var sum float64
873	aU, _ := untranspose(a)
874	switch rma := aU.(type) {
875	case RawSymmetricer:
876		rm := rma.RawSymmetric()
877		for i := 0; i < rm.N; i++ {
878			// Diagonals count once while off-diagonals count twice.
879			sum += rm.Data[i*rm.Stride+i]
880			var s float64
881			for _, v := range rm.Data[i*rm.Stride+i+1 : i*rm.Stride+rm.N] {
882				s += v
883			}
884			sum += 2 * s
885		}
886		return sum
887	case RawTriangular:
888		rm := rma.RawTriangular()
889		var startIdx, endIdx int
890		for i := 0; i < rm.N; i++ {
891			// Start and end index for this triangle-row.
892			switch rm.Uplo {
893			case blas.Upper:
894				startIdx = i
895				endIdx = rm.N
896			case blas.Lower:
897				startIdx = 0
898				endIdx = i + 1
899			default:
900				panic(badTriangle)
901			}
902			for _, v := range rm.Data[i*rm.Stride+startIdx : i*rm.Stride+endIdx] {
903				sum += v
904			}
905		}
906		return sum
907	case RawMatrixer:
908		rm := rma.RawMatrix()
909		for i := 0; i < rm.Rows; i++ {
910			for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] {
911				sum += v
912			}
913		}
914		return sum
915	case *VecDense:
916		rm := rma.RawVector()
917		for i := 0; i < rm.N; i++ {
918			sum += rm.Data[i*rm.Inc]
919		}
920		return sum
921	default:
922		r, c := a.Dims()
923		for i := 0; i < r; i++ {
924			for j := 0; j < c; j++ {
925				sum += a.At(i, j)
926			}
927		}
928		return sum
929	}
930}
931
// A Tracer computes the trace of a matrix. Implementations of Trace must
// panic when the matrix is not square.
type Tracer interface {
	Trace() float64
}
937
938// Trace returns the trace of the matrix. Trace will panic if the
939// matrix is not square. If a is a Tracer, its Trace method will be
940// used to calculate the matrix trace.
941func Trace(a Matrix) float64 {
942	m, _ := untransposeExtract(a)
943	if t, ok := m.(Tracer); ok {
944		return t.Trace()
945	}
946	r, c := a.Dims()
947	if r != c {
948		panic(ErrSquare)
949	}
950	var v float64
951	for i := 0; i < r; i++ {
952		v += a.At(i, i)
953	}
954	return v
955}
956
// min returns the smaller of the two ints a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
963
// max returns the larger of the two ints a and b.
func max(a, b int) int {
	if b > a {
		return b
	}
	return a
}
970
// use returns a float64 slice of length l. The backing array of f is reused
// when its capacity suffices; otherwise a new slice is allocated. Reused
// elements keep whatever values f held.
func use(f []float64, l int) []float64 {
	if l > cap(f) {
		return make([]float64, l)
	}
	return f[:l]
}
979
// useZeroed returns a float64 slice of length l whose elements are all zero.
// The backing array of f is reused (and cleared) when its capacity suffices;
// otherwise a new, already zeroed slice is allocated.
func useZeroed(f []float64, l int) []float64 {
	if l > cap(f) {
		return make([]float64, l)
	}
	out := f[:l]
	zero(out)
	return out
}

// zero sets every element of f to 0.
func zero(f []float64) {
	for i, n := 0, len(f); i < n; i++ {
		f[i] = 0
	}
}
998
// useInt returns an int slice of length l. The backing array of i is reused
// when its capacity suffices; otherwise a new slice is allocated.
func useInt(i []int, l int) []int {
	if l > cap(i) {
		return make([]int, l)
	}
	return i[:l]
}
1007