1// Copyright ©2013 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package mat
6
7import (
8	"math"
9
10	"gonum.org/v1/gonum/blas"
11	"gonum.org/v1/gonum/blas/blas64"
12	"gonum.org/v1/gonum/floats"
13	"gonum.org/v1/gonum/lapack"
14	"gonum.org/v1/gonum/lapack/lapack64"
15)
16
// Matrix is the basic matrix interface type.
type Matrix interface {
	// Dims reports the number of rows r and columns c of the matrix.
	Dims() (r, c int)

	// At returns the element stored at row i and column j. Implementations
	// panic when i or j lies outside the bounds of the matrix.
	At(i, j int) float64

	// T returns the transpose of the matrix. Implementations may or may not
	// copy the underlying data; the Transpose type can be used to provide an
	// implicit transpose.
	T() Matrix
}
32
33var (
34	_ Matrix       = Transpose{}
35	_ Untransposer = Transpose{}
36)
37
38// Transpose is a type for performing an implicit matrix transpose. It implements
39// the Matrix interface, returning values from the transpose of the matrix within.
40type Transpose struct {
41	Matrix Matrix
42}
43
44// At returns the value of the element at row i and column j of the transposed
45// matrix, that is, row j and column i of the Matrix field.
46func (t Transpose) At(i, j int) float64 {
47	return t.Matrix.At(j, i)
48}
49
50// Dims returns the dimensions of the transposed matrix. The number of rows returned
51// is the number of columns in the Matrix field, and the number of columns is
52// the number of rows in the Matrix field.
53func (t Transpose) Dims() (r, c int) {
54	c, r = t.Matrix.Dims()
55	return r, c
56}
57
58// T performs an implicit transpose by returning the Matrix field.
59func (t Transpose) T() Matrix {
60	return t.Matrix
61}
62
63// Untranspose returns the Matrix field.
64func (t Transpose) Untranspose() Matrix {
65	return t.Matrix
66}
67
68// Untransposer is a type that can undo an implicit transpose.
69type Untransposer interface {
70	// Note: This interface is needed to unify all of the Transpose types. In
71	// the mat methods, we need to test if the Matrix has been implicitly
72	// transposed. If this is checked by testing for the specific Transpose type
73	// then the behavior will be different if the user uses T() or TTri() for a
74	// triangular matrix.
75
76	// Untranspose returns the underlying Matrix stored for the implicit transpose.
77	Untranspose() Matrix
78}
79
80// UntransposeBander is a type that can undo an implicit band transpose.
81type UntransposeBander interface {
82	// Untranspose returns the underlying Banded stored for the implicit transpose.
83	UntransposeBand() Banded
84}
85
86// UntransposeTrier is a type that can undo an implicit triangular transpose.
87type UntransposeTrier interface {
88	// Untranspose returns the underlying Triangular stored for the implicit transpose.
89	UntransposeTri() Triangular
90}
91
92// UntransposeTriBander is a type that can undo an implicit triangular banded
93// transpose.
94type UntransposeTriBander interface {
95	// Untranspose returns the underlying Triangular stored for the implicit transpose.
96	UntransposeTriBand() TriBanded
97}
98
99// Mutable is a matrix interface type that allows elements to be altered.
100type Mutable interface {
101	// Set alters the matrix element at row i, column j to v.
102	// It will panic if i or j are out of bounds for the matrix.
103	Set(i, j int, v float64)
104
105	Matrix
106}
107
108// A RowViewer can return a Vector reflecting a row that is backed by the matrix
109// data. The Vector returned will have length equal to the number of columns.
110type RowViewer interface {
111	RowView(i int) Vector
112}
113
114// A RawRowViewer can return a slice of float64 reflecting a row that is backed by the matrix
115// data.
116type RawRowViewer interface {
117	RawRowView(i int) []float64
118}
119
120// A ColViewer can return a Vector reflecting a column that is backed by the matrix
121// data. The Vector returned will have length equal to the number of rows.
122type ColViewer interface {
123	ColView(j int) Vector
124}
125
126// A RawColViewer can return a slice of float64 reflecting a column that is backed by the matrix
127// data.
128type RawColViewer interface {
129	RawColView(j int) []float64
130}
131
132// A Cloner can make a copy of a into the receiver, overwriting the previous value of the
133// receiver. The clone operation does not make any restriction on shape and will not cause
134// shadowing.
135type Cloner interface {
136	Clone(a Matrix)
137}
138
139// A Reseter can reset the matrix so that it can be reused as the receiver of a dimensionally
140// restricted operation. This is commonly used when the matrix is being used as a workspace
141// or temporary matrix.
142//
143// If the matrix is a view, using the reset matrix may result in data corruption in elements
144// outside the view.
145type Reseter interface {
146	Reset()
147}
148
149// A Copier can make a copy of elements of a into the receiver. The submatrix copied
150// starts at row and column 0 and has dimensions equal to the minimum dimensions of
151// the two matrices. The number of row and columns copied is returned.
152// Copy will copy from a source that aliases the receiver unless the source is transposed;
153// an aliasing transpose copy will panic with the exception for a special case when
154// the source data has a unitary increment or stride.
155type Copier interface {
156	Copy(a Matrix) (r, c int)
157}
158
159// A Grower can grow the size of the represented matrix by the given number of rows and columns.
160// Growing beyond the size given by the Caps method will result in the allocation of a new
161// matrix and copying of the elements. If Grow is called with negative increments it will
162// panic with ErrIndexOutOfRange.
163type Grower interface {
164	Caps() (r, c int)
165	Grow(r, c int) Matrix
166}
167
168// A BandWidther represents a banded matrix and can return the left and right half-bandwidths, k1 and
169// k2.
170type BandWidther interface {
171	BandWidth() (k1, k2 int)
172}
173
174// A RawMatrixSetter can set the underlying blas64.General used by the receiver. There is no restriction
175// on the shape of the receiver. Changes to the receiver's elements will be reflected in the blas64.General.Data.
176type RawMatrixSetter interface {
177	SetRawMatrix(a blas64.General)
178}
179
180// A RawMatrixer can return a blas64.General representation of the receiver. Changes to the blas64.General.Data
181// slice will be reflected in the original matrix, changes to the Rows, Cols and Stride fields will not.
182type RawMatrixer interface {
183	RawMatrix() blas64.General
184}
185
186// A RawVectorer can return a blas64.Vector representation of the receiver. Changes to the blas64.Vector.Data
187// slice will be reflected in the original matrix, changes to the Inc field will not.
188type RawVectorer interface {
189	RawVector() blas64.Vector
190}
191
192// A NonZeroDoer can call a function for each non-zero element of the receiver.
193// The parameters of the function are the element indices and its value.
194type NonZeroDoer interface {
195	DoNonZero(func(i, j int, v float64))
196}
197
198// A RowNonZeroDoer can call a function for each non-zero element of a row of the receiver.
199// The parameters of the function are the element indices and its value.
200type RowNonZeroDoer interface {
201	DoRowNonZero(i int, fn func(i, j int, v float64))
202}
203
204// A ColNonZeroDoer can call a function for each non-zero element of a column of the receiver.
205// The parameters of the function are the element indices and its value.
206type ColNonZeroDoer interface {
207	DoColNonZero(j int, fn func(i, j int, v float64))
208}
209
210// untranspose untransposes a matrix if applicable. If a is an Untransposer, then
211// untranspose returns the underlying matrix and true. If it is not, then it returns
212// the input matrix and false.
213func untranspose(a Matrix) (Matrix, bool) {
214	if ut, ok := a.(Untransposer); ok {
215		return ut.Untranspose(), true
216	}
217	return a, false
218}
219
// untransposeExtract returns an untransposed matrix in a built-in matrix type.
//
// The untransposed matrix is returned unaltered if it is a built-in matrix type.
// Otherwise, if it implements a Raw method, an appropriate built-in type value
// is returned holding the raw matrix value of the input. If neither of these
// is possible, the untransposed matrix is returned.
//
// The bool result is the one reported by untranspose: true when a was an
// Untransposer (an implicit transpose) and false otherwise.
//
// The order of the type-switch cases is significant: a value implementing
// several Raw interfaces is handled by the first matching case.
func untransposeExtract(a Matrix) (Matrix, bool) {
	ut, trans := untranspose(a)
	switch m := ut.(type) {
	case *DiagDense, *SymBandDense, *TriBandDense, *BandDense, *TriDense, *SymDense, *Dense:
		return m, trans
	// TODO(btracey): Add here if we ever have an equivalent of RawDiagDense.
	case RawSymBander:
		// Only upper-storage symmetric band data is wrapped in a SymBandDense;
		// any other storage is returned unchanged.
		rsb := m.RawSymBand()
		if rsb.Uplo != blas.Upper {
			return ut, trans
		}
		var sb SymBandDense
		sb.SetRawSymBand(rsb)
		return &sb, trans
	case RawTriBander:
		// Unit-diagonal triangular band data is returned unchanged rather than
		// wrapped in a TriBandDense.
		rtb := m.RawTriBand()
		if rtb.Diag == blas.Unit {
			return ut, trans
		}
		var tb TriBandDense
		tb.SetRawTriBand(rtb)
		return &tb, trans
	case RawBander:
		var b BandDense
		b.SetRawBand(m.RawBand())
		return &b, trans
	case RawTriangular:
		// Unit-diagonal triangular data is returned unchanged rather than
		// wrapped in a TriDense.
		rt := m.RawTriangular()
		if rt.Diag == blas.Unit {
			return ut, trans
		}
		var t TriDense
		t.SetRawTriangular(rt)
		return &t, trans
	case RawSymmetricer:
		// Only upper-storage symmetric data is wrapped in a SymDense; any other
		// storage is returned unchanged.
		rs := m.RawSymmetric()
		if rs.Uplo != blas.Upper {
			return ut, trans
		}
		var s SymDense
		s.SetRawSymmetric(rs)
		return &s, trans
	case RawMatrixer:
		var d Dense
		d.SetRawMatrix(m.RawMatrix())
		return &d, trans
	default:
		return ut, trans
	}
}
276
277// TODO(btracey): Consider adding CopyCol/CopyRow if the behavior seems useful.
278// TODO(btracey): Add in fast paths to Row/Col for the other concrete types
279// (TriDense, etc.) as well as relevant interfaces (RowColer, RawRowViewer, etc.)
280
281// Col copies the elements in the jth column of the matrix into the slice dst.
282// The length of the provided slice must equal the number of rows, unless the
283// slice is nil in which case a new slice is first allocated.
284func Col(dst []float64, j int, a Matrix) []float64 {
285	r, c := a.Dims()
286	if j < 0 || j >= c {
287		panic(ErrColAccess)
288	}
289	if dst == nil {
290		dst = make([]float64, r)
291	} else {
292		if len(dst) != r {
293			panic(ErrColLength)
294		}
295	}
296	aU, aTrans := untranspose(a)
297	if rm, ok := aU.(RawMatrixer); ok {
298		m := rm.RawMatrix()
299		if aTrans {
300			copy(dst, m.Data[j*m.Stride:j*m.Stride+m.Cols])
301			return dst
302		}
303		blas64.Copy(blas64.Vector{N: r, Inc: m.Stride, Data: m.Data[j:]},
304			blas64.Vector{N: r, Inc: 1, Data: dst},
305		)
306		return dst
307	}
308	for i := 0; i < r; i++ {
309		dst[i] = a.At(i, j)
310	}
311	return dst
312}
313
314// Row copies the elements in the ith row of the matrix into the slice dst.
315// The length of the provided slice must equal the number of columns, unless the
316// slice is nil in which case a new slice is first allocated.
317func Row(dst []float64, i int, a Matrix) []float64 {
318	r, c := a.Dims()
319	if i < 0 || i >= r {
320		panic(ErrColAccess)
321	}
322	if dst == nil {
323		dst = make([]float64, c)
324	} else {
325		if len(dst) != c {
326			panic(ErrRowLength)
327		}
328	}
329	aU, aTrans := untranspose(a)
330	if rm, ok := aU.(RawMatrixer); ok {
331		m := rm.RawMatrix()
332		if aTrans {
333			blas64.Copy(blas64.Vector{N: c, Inc: m.Stride, Data: m.Data[i:]},
334				blas64.Vector{N: c, Inc: 1, Data: dst},
335			)
336			return dst
337		}
338		copy(dst, m.Data[i*m.Stride:i*m.Stride+m.Cols])
339		return dst
340	}
341	for j := 0; j < c; j++ {
342		dst[j] = a.At(i, j)
343	}
344	return dst
345}
346
347// Cond returns the condition number of the given matrix under the given norm.
348// The condition number must be based on the 1-norm, 2-norm or ∞-norm.
349// Cond will panic with matrix.ErrShape if the matrix has zero size.
350//
351// BUG(btracey): The computation of the 1-norm and ∞-norm for non-square matrices
352// is innacurate, although is typically the right order of magnitude. See
353// https://github.com/xianyi/OpenBLAS/issues/636. While the value returned will
354// change with the resolution of this bug, the result from Cond will match the
355// condition number used internally.
356func Cond(a Matrix, norm float64) float64 {
357	m, n := a.Dims()
358	if m == 0 || n == 0 {
359		panic(ErrShape)
360	}
361	var lnorm lapack.MatrixNorm
362	switch norm {
363	default:
364		panic("mat: bad norm value")
365	case 1:
366		lnorm = lapack.MaxColumnSum
367	case 2:
368		var svd SVD
369		ok := svd.Factorize(a, SVDNone)
370		if !ok {
371			return math.Inf(1)
372		}
373		return svd.Cond()
374	case math.Inf(1):
375		lnorm = lapack.MaxRowSum
376	}
377
378	if m == n {
379		// Use the LU decomposition to compute the condition number.
380		var lu LU
381		lu.factorize(a, lnorm)
382		return lu.Cond()
383	}
384	if m > n {
385		// Use the QR factorization to compute the condition number.
386		var qr QR
387		qr.factorize(a, lnorm)
388		return qr.Cond()
389	}
390	// Use the LQ factorization to compute the condition number.
391	var lq LQ
392	lq.factorize(a, lnorm)
393	return lq.Cond()
394}
395
396// Det returns the determinant of the matrix a. In many expressions using LogDet
397// will be more numerically stable.
398func Det(a Matrix) float64 {
399	det, sign := LogDet(a)
400	return math.Exp(det) * sign
401}
402
403// Dot returns the sum of the element-wise product of a and b.
404// Dot panics if the matrix sizes are unequal.
405func Dot(a, b Vector) float64 {
406	la := a.Len()
407	lb := b.Len()
408	if la != lb {
409		panic(ErrShape)
410	}
411	if arv, ok := a.(RawVectorer); ok {
412		if brv, ok := b.(RawVectorer); ok {
413			return blas64.Dot(arv.RawVector(), brv.RawVector())
414		}
415	}
416	var sum float64
417	for i := 0; i < la; i++ {
418		sum += a.At(i, 0) * b.At(i, 0)
419	}
420	return sum
421}
422
423// Equal returns whether the matrices a and b have the same size
424// and are element-wise equal.
425func Equal(a, b Matrix) bool {
426	ar, ac := a.Dims()
427	br, bc := b.Dims()
428	if ar != br || ac != bc {
429		return false
430	}
431	aU, aTrans := untranspose(a)
432	bU, bTrans := untranspose(b)
433	if rma, ok := aU.(RawMatrixer); ok {
434		if rmb, ok := bU.(RawMatrixer); ok {
435			ra := rma.RawMatrix()
436			rb := rmb.RawMatrix()
437			if aTrans == bTrans {
438				for i := 0; i < ra.Rows; i++ {
439					for j := 0; j < ra.Cols; j++ {
440						if ra.Data[i*ra.Stride+j] != rb.Data[i*rb.Stride+j] {
441							return false
442						}
443					}
444				}
445				return true
446			}
447			for i := 0; i < ra.Rows; i++ {
448				for j := 0; j < ra.Cols; j++ {
449					if ra.Data[i*ra.Stride+j] != rb.Data[j*rb.Stride+i] {
450						return false
451					}
452				}
453			}
454			return true
455		}
456	}
457	if rma, ok := aU.(RawSymmetricer); ok {
458		if rmb, ok := bU.(RawSymmetricer); ok {
459			ra := rma.RawSymmetric()
460			rb := rmb.RawSymmetric()
461			// Symmetric matrices are always upper and equal to their transpose.
462			for i := 0; i < ra.N; i++ {
463				for j := i; j < ra.N; j++ {
464					if ra.Data[i*ra.Stride+j] != rb.Data[i*rb.Stride+j] {
465						return false
466					}
467				}
468			}
469			return true
470		}
471	}
472	if ra, ok := aU.(*VecDense); ok {
473		if rb, ok := bU.(*VecDense); ok {
474			// If the raw vectors are the same length they must either both be
475			// transposed or both not transposed (or have length 1).
476			for i := 0; i < ra.mat.N; i++ {
477				if ra.mat.Data[i*ra.mat.Inc] != rb.mat.Data[i*rb.mat.Inc] {
478					return false
479				}
480			}
481			return true
482		}
483	}
484	for i := 0; i < ar; i++ {
485		for j := 0; j < ac; j++ {
486			if a.At(i, j) != b.At(i, j) {
487				return false
488			}
489		}
490	}
491	return true
492}
493
494// EqualApprox returns whether the matrices a and b have the same size and contain all equal
495// elements with tolerance for element-wise equality specified by epsilon. Matrices
496// with non-equal shapes are not equal.
497func EqualApprox(a, b Matrix, epsilon float64) bool {
498	ar, ac := a.Dims()
499	br, bc := b.Dims()
500	if ar != br || ac != bc {
501		return false
502	}
503	aU, aTrans := untranspose(a)
504	bU, bTrans := untranspose(b)
505	if rma, ok := aU.(RawMatrixer); ok {
506		if rmb, ok := bU.(RawMatrixer); ok {
507			ra := rma.RawMatrix()
508			rb := rmb.RawMatrix()
509			if aTrans == bTrans {
510				for i := 0; i < ra.Rows; i++ {
511					for j := 0; j < ra.Cols; j++ {
512						if !floats.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[i*rb.Stride+j], epsilon, epsilon) {
513							return false
514						}
515					}
516				}
517				return true
518			}
519			for i := 0; i < ra.Rows; i++ {
520				for j := 0; j < ra.Cols; j++ {
521					if !floats.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[j*rb.Stride+i], epsilon, epsilon) {
522						return false
523					}
524				}
525			}
526			return true
527		}
528	}
529	if rma, ok := aU.(RawSymmetricer); ok {
530		if rmb, ok := bU.(RawSymmetricer); ok {
531			ra := rma.RawSymmetric()
532			rb := rmb.RawSymmetric()
533			// Symmetric matrices are always upper and equal to their transpose.
534			for i := 0; i < ra.N; i++ {
535				for j := i; j < ra.N; j++ {
536					if !floats.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[i*rb.Stride+j], epsilon, epsilon) {
537						return false
538					}
539				}
540			}
541			return true
542		}
543	}
544	if ra, ok := aU.(*VecDense); ok {
545		if rb, ok := bU.(*VecDense); ok {
546			// If the raw vectors are the same length they must either both be
547			// transposed or both not transposed (or have length 1).
548			for i := 0; i < ra.mat.N; i++ {
549				if !floats.EqualWithinAbsOrRel(ra.mat.Data[i*ra.mat.Inc], rb.mat.Data[i*rb.mat.Inc], epsilon, epsilon) {
550					return false
551				}
552			}
553			return true
554		}
555	}
556	for i := 0; i < ar; i++ {
557		for j := 0; j < ac; j++ {
558			if !floats.EqualWithinAbsOrRel(a.At(i, j), b.At(i, j), epsilon, epsilon) {
559				return false
560			}
561		}
562	}
563	return true
564}
565
566// LogDet returns the log of the determinant and the sign of the determinant
567// for the matrix that has been factorized. Numerical stability in product and
568// division expressions is generally improved by working in log space.
569func LogDet(a Matrix) (det float64, sign float64) {
570	// TODO(btracey): Add specialized routines for TriDense, etc.
571	var lu LU
572	lu.Factorize(a)
573	return lu.LogDet()
574}
575
576// Max returns the largest element value of the matrix A.
577// Max will panic with matrix.ErrShape if the matrix has zero size.
578func Max(a Matrix) float64 {
579	r, c := a.Dims()
580	if r == 0 || c == 0 {
581		panic(ErrShape)
582	}
583	// Max(A) = Max(A^T)
584	aU, _ := untranspose(a)
585	switch m := aU.(type) {
586	case RawMatrixer:
587		rm := m.RawMatrix()
588		max := math.Inf(-1)
589		for i := 0; i < rm.Rows; i++ {
590			for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] {
591				if v > max {
592					max = v
593				}
594			}
595		}
596		return max
597	case RawTriangular:
598		rm := m.RawTriangular()
599		// The max of a triangular is at least 0 unless the size is 1.
600		if rm.N == 1 {
601			return rm.Data[0]
602		}
603		max := 0.0
604		if rm.Uplo == blas.Upper {
605			for i := 0; i < rm.N; i++ {
606				for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
607					if v > max {
608						max = v
609					}
610				}
611			}
612			return max
613		}
614		for i := 0; i < rm.N; i++ {
615			for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+i+1] {
616				if v > max {
617					max = v
618				}
619			}
620		}
621		return max
622	case RawSymmetricer:
623		rm := m.RawSymmetric()
624		if rm.Uplo != blas.Upper {
625			panic(badSymTriangle)
626		}
627		max := math.Inf(-1)
628		for i := 0; i < rm.N; i++ {
629			for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
630				if v > max {
631					max = v
632				}
633			}
634		}
635		return max
636	default:
637		r, c := aU.Dims()
638		max := math.Inf(-1)
639		for i := 0; i < r; i++ {
640			for j := 0; j < c; j++ {
641				v := aU.At(i, j)
642				if v > max {
643					max = v
644				}
645			}
646		}
647		return max
648	}
649}
650
651// Min returns the smallest element value of the matrix A.
652// Min will panic with matrix.ErrShape if the matrix has zero size.
653func Min(a Matrix) float64 {
654	r, c := a.Dims()
655	if r == 0 || c == 0 {
656		panic(ErrShape)
657	}
658	// Min(A) = Min(A^T)
659	aU, _ := untranspose(a)
660	switch m := aU.(type) {
661	case RawMatrixer:
662		rm := m.RawMatrix()
663		min := math.Inf(1)
664		for i := 0; i < rm.Rows; i++ {
665			for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] {
666				if v < min {
667					min = v
668				}
669			}
670		}
671		return min
672	case RawTriangular:
673		rm := m.RawTriangular()
674		// The min of a triangular is at most 0 unless the size is 1.
675		if rm.N == 1 {
676			return rm.Data[0]
677		}
678		min := 0.0
679		if rm.Uplo == blas.Upper {
680			for i := 0; i < rm.N; i++ {
681				for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
682					if v < min {
683						min = v
684					}
685				}
686			}
687			return min
688		}
689		for i := 0; i < rm.N; i++ {
690			for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+i+1] {
691				if v < min {
692					min = v
693				}
694			}
695		}
696		return min
697	case RawSymmetricer:
698		rm := m.RawSymmetric()
699		if rm.Uplo != blas.Upper {
700			panic(badSymTriangle)
701		}
702		min := math.Inf(1)
703		for i := 0; i < rm.N; i++ {
704			for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
705				if v < min {
706					min = v
707				}
708			}
709		}
710		return min
711	default:
712		r, c := aU.Dims()
713		min := math.Inf(1)
714		for i := 0; i < r; i++ {
715			for j := 0; j < c; j++ {
716				v := aU.At(i, j)
717				if v < min {
718					min = v
719				}
720			}
721		}
722		return min
723	}
724}
725
726// Norm returns the specified (induced) norm of the matrix a. See
727// https://en.wikipedia.org/wiki/Matrix_norm for the definition of an induced norm.
728//
729// Valid norms are:
730//    1 - The maximum absolute column sum
731//    2 - Frobenius norm, the square root of the sum of the squares of the elements.
732//  Inf - The maximum absolute row sum.
733// Norm will panic with ErrNormOrder if an illegal norm order is specified and
734// with matrix.ErrShape if the matrix has zero size.
735func Norm(a Matrix, norm float64) float64 {
736	r, c := a.Dims()
737	if r == 0 || c == 0 {
738		panic(ErrShape)
739	}
740	aU, aTrans := untranspose(a)
741	var work []float64
742	switch rma := aU.(type) {
743	case RawMatrixer:
744		rm := rma.RawMatrix()
745		n := normLapack(norm, aTrans)
746		if n == lapack.MaxColumnSum {
747			work = getFloats(rm.Cols, false)
748			defer putFloats(work)
749		}
750		return lapack64.Lange(n, rm, work)
751	case RawTriangular:
752		rm := rma.RawTriangular()
753		n := normLapack(norm, aTrans)
754		if n == lapack.MaxRowSum || n == lapack.MaxColumnSum {
755			work = getFloats(rm.N, false)
756			defer putFloats(work)
757		}
758		return lapack64.Lantr(n, rm, work)
759	case RawSymmetricer:
760		rm := rma.RawSymmetric()
761		n := normLapack(norm, aTrans)
762		if n == lapack.MaxRowSum || n == lapack.MaxColumnSum {
763			work = getFloats(rm.N, false)
764			defer putFloats(work)
765		}
766		return lapack64.Lansy(n, rm, work)
767	case *VecDense:
768		rv := rma.RawVector()
769		switch norm {
770		default:
771			panic("unreachable")
772		case 1:
773			if aTrans {
774				imax := blas64.Iamax(rv)
775				return math.Abs(rma.At(imax, 0))
776			}
777			return blas64.Asum(rv)
778		case 2:
779			return blas64.Nrm2(rv)
780		case math.Inf(1):
781			if aTrans {
782				return blas64.Asum(rv)
783			}
784			imax := blas64.Iamax(rv)
785			return math.Abs(rma.At(imax, 0))
786		}
787	}
788	switch norm {
789	default:
790		panic("unreachable")
791	case 1:
792		var max float64
793		for j := 0; j < c; j++ {
794			var sum float64
795			for i := 0; i < r; i++ {
796				sum += math.Abs(a.At(i, j))
797			}
798			if sum > max {
799				max = sum
800			}
801		}
802		return max
803	case 2:
804		var sum float64
805		for i := 0; i < r; i++ {
806			for j := 0; j < c; j++ {
807				v := a.At(i, j)
808				sum += v * v
809			}
810		}
811		return math.Sqrt(sum)
812	case math.Inf(1):
813		var max float64
814		for i := 0; i < r; i++ {
815			var sum float64
816			for j := 0; j < c; j++ {
817				sum += math.Abs(a.At(i, j))
818			}
819			if sum > max {
820				max = sum
821			}
822		}
823		return max
824	}
825}
826
827// normLapack converts the float64 norm input in Norm to a lapack.MatrixNorm.
828func normLapack(norm float64, aTrans bool) lapack.MatrixNorm {
829	switch norm {
830	case 1:
831		n := lapack.MaxColumnSum
832		if aTrans {
833			n = lapack.MaxRowSum
834		}
835		return n
836	case 2:
837		return lapack.Frobenius
838	case math.Inf(1):
839		n := lapack.MaxRowSum
840		if aTrans {
841			n = lapack.MaxColumnSum
842		}
843		return n
844	default:
845		panic(ErrNormOrder)
846	}
847}
848
849// Sum returns the sum of the elements of the matrix.
850func Sum(a Matrix) float64 {
851	// TODO(btracey): Add a fast path for the other supported matrix types.
852
853	r, c := a.Dims()
854	var sum float64
855	aU, _ := untranspose(a)
856	if rma, ok := aU.(RawMatrixer); ok {
857		rm := rma.RawMatrix()
858		for i := 0; i < rm.Rows; i++ {
859			for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] {
860				sum += v
861			}
862		}
863		return sum
864	}
865	for i := 0; i < r; i++ {
866		for j := 0; j < c; j++ {
867			sum += a.At(i, j)
868		}
869	}
870	return sum
871}
872
873// A Tracer can compute the trace of the matrix. Trace must panic if the
874// matrix is not square.
875type Tracer interface {
876	Trace() float64
877}
878
879// Trace returns the trace of the matrix. Trace will panic if the
880// matrix is not square.
881func Trace(a Matrix) float64 {
882	m, _ := untransposeExtract(a)
883	if t, ok := m.(Tracer); ok {
884		return t.Trace()
885	}
886	r, c := a.Dims()
887	if r != c {
888		panic(ErrSquare)
889	}
890	var v float64
891	for i := 0; i < r; i++ {
892		v += a.At(i, i)
893	}
894	return v
895}
896
// min returns the smaller of the two ints a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
903
// max returns the larger of the two ints a and b.
func max(a, b int) int {
	if b > a {
		return b
	}
	return a
}
910
// use returns a float64 slice of length l. The backing array of f is reused
// when its capacity suffices; otherwise a fresh slice is allocated. Reused
// elements keep whatever values they held.
func use(f []float64, l int) []float64 {
	if cap(f) < l {
		return make([]float64, l)
	}
	return f[:l]
}
919
// useZeroed returns a float64 slice of length l whose elements are all zero.
// The backing array of f is reused (and cleared over the first l elements)
// when its capacity suffices; otherwise a fresh slice is allocated.
func useZeroed(f []float64, l int) []float64 {
	if cap(f) < l {
		return make([]float64, l)
	}
	out := f[:l]
	for i := 0; i < len(out); i++ {
		out[i] = 0
	}
	return out
}
931
// zero sets every element of f to zero in place.
func zero(f []float64) {
	for i, n := 0, len(f); i < n; i++ {
		f[i] = 0
	}
}
938
// useInt returns an int slice of length l. The backing array of i is reused
// when its capacity suffices; otherwise a fresh slice is allocated. Reused
// elements keep whatever values they held.
func useInt(i []int, l int) []int {
	if cap(i) < l {
		return make([]int, l)
	}
	return i[:l]
}
947