1// Copyright ©2013 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package mat
6
7import (
8	"math"
9	"math/cmplx"
10
11	"gonum.org/v1/gonum/blas/cblas128"
12	"gonum.org/v1/gonum/floats/scalar"
13)
14
// CMatrix is the basic matrix interface type for complex matrices.
type CMatrix interface {
	// Dims returns the number of rows and columns of the matrix.
	Dims() (r, c int)

	// At returns the element at row i and column j of the matrix.
	// Implementations panic if i or j lies outside the matrix bounds.
	At(i, j int) complex128

	// H returns the conjugate transpose of the matrix. Implementations
	// decide whether the result shares or copies the underlying data.
	// The ConjTranspose type provides an implicit conjugate transpose
	// that may be returned here.
	H() CMatrix

	// T returns the transpose, without conjugation, of the matrix.
	// Implementations decide whether the result shares or copies the
	// underlying data. The CTranspose type provides an implicit transpose
	// that may be returned here.
	T() CMatrix
}
36
37// A RawCMatrixer can return a cblas128.General representation of the receiver. Changes to the cblas128.General.Data
38// slice will be reflected in the original matrix, changes to the Rows, Cols and Stride fields will not.
39type RawCMatrixer interface {
40	RawCMatrix() cblas128.General
41}
42
43var (
44	_ CMatrix          = ConjTranspose{}
45	_ UnConjTransposer = ConjTranspose{}
46)
47
48// ConjTranspose is a type for performing an implicit matrix conjugate transpose.
49// It implements the CMatrix interface, returning values from the conjugate
50// transpose of the matrix within.
51type ConjTranspose struct {
52	CMatrix CMatrix
53}
54
55// At returns the value of the element at row i and column j of the conjugate
56// transposed matrix, that is, row j and column i of the CMatrix field.
57func (t ConjTranspose) At(i, j int) complex128 {
58	z := t.CMatrix.At(j, i)
59	return cmplx.Conj(z)
60}
61
62// Dims returns the dimensions of the transposed matrix. The number of rows returned
63// is the number of columns in the CMatrix field, and the number of columns is
64// the number of rows in the CMatrix field.
65func (t ConjTranspose) Dims() (r, c int) {
66	c, r = t.CMatrix.Dims()
67	return r, c
68}
69
70// H performs an implicit conjugate transpose by returning the CMatrix field.
71func (t ConjTranspose) H() CMatrix {
72	return t.CMatrix
73}
74
75// T performs an implicit transpose by returning the receiver inside a
76// CTranspose.
77func (t ConjTranspose) T() CMatrix {
78	return CTranspose{t}
79}
80
81// UnConjTranspose returns the CMatrix field.
82func (t ConjTranspose) UnConjTranspose() CMatrix {
83	return t.CMatrix
84}
85
86// CTranspose is a type for performing an implicit matrix conjugate transpose.
87// It implements the CMatrix interface, returning values from the conjugate
88// transpose of the matrix within.
89type CTranspose struct {
90	CMatrix CMatrix
91}
92
93// At returns the value of the element at row i and column j of the conjugate
94// transposed matrix, that is, row j and column i of the CMatrix field.
95func (t CTranspose) At(i, j int) complex128 {
96	return t.CMatrix.At(j, i)
97}
98
99// Dims returns the dimensions of the transposed matrix. The number of rows returned
100// is the number of columns in the CMatrix field, and the number of columns is
101// the number of rows in the CMatrix field.
102func (t CTranspose) Dims() (r, c int) {
103	c, r = t.CMatrix.Dims()
104	return r, c
105}
106
107// H performs an implicit transpose by returning the receiver inside a
108// ConjTranspose.
109func (t CTranspose) H() CMatrix {
110	return ConjTranspose{t}
111}
112
113// T performs an implicit conjugate transpose by returning the CMatrix field.
114func (t CTranspose) T() CMatrix {
115	return t.CMatrix
116}
117
118// Untranspose returns the CMatrix field.
119func (t CTranspose) Untranspose() CMatrix {
120	return t.CMatrix
121}
122
123// UnConjTransposer is a type that can undo an implicit conjugate transpose.
124type UnConjTransposer interface {
125	// UnConjTranspose returns the underlying CMatrix stored for the implicit
126	// conjugate transpose.
127	UnConjTranspose() CMatrix
128
129	// Note: This interface is needed to unify all of the Conjugate types. In
130	// the cmat128 methods, we need to test if the CMatrix has been implicitly
131	// transposed. If this is checked by testing for the specific Conjugate type
132	// then the behavior will be different if the user uses H() or HTri() for a
133	// triangular matrix.
134}
135
136// CUntransposer is a type that can undo an implicit transpose.
137type CUntransposer interface {
138	// Untranspose returns the underlying CMatrix stored for the implicit
139	// transpose.
140	Untranspose() CMatrix
141
142	// Note: This interface is needed to unify all of the CTranspose types. In
143	// the cmat128 methods, we need to test if the CMatrix has been implicitly
144	// transposed. If this is checked by testing for the specific CTranspose type
145	// then the behavior will be different if the user uses T() or TTri() for a
146	// triangular matrix.
147}
148
// useC returns a complex128 slice of length l. When c has enough capacity it
// is resliced and returned, sharing c's backing array; reused elements keep
// whatever values they already held. Otherwise a new slice is allocated.
func useC(c []complex128, l int) []complex128 {
	if cap(c) < l {
		return make([]complex128, l)
	}
	return c[:l]
}
157
158// useZeroedC returns a complex128 slice with l elements, using c if it
159// has the necessary capacity, otherwise creating a new slice. The
160// elements of the returned slice are guaranteed to be zero.
161func useZeroedC(c []complex128, l int) []complex128 {
162	if l <= cap(c) {
163		c = c[:l]
164		zeroC(c)
165		return c
166	}
167	return make([]complex128, l)
168}
169
// zeroC sets every element of c to zero in place. The length and capacity of
// c are left unchanged, and a nil or empty slice is a no-op.
func zeroC(c []complex128) {
	for idx := range c {
		c[idx] = 0
	}
}
176
177// untransposeCmplx untransposes a matrix if applicable. If a is an CUntransposer
178// or an UnConjTransposer, then untranspose returns the underlying matrix and true for
179// the kind of transpose (potentially both).
180// If it is not, then it returns the input matrix and false for trans and conj.
181func untransposeCmplx(a CMatrix) (u CMatrix, trans, conj bool) {
182	switch ut := a.(type) {
183	case CUntransposer:
184		trans = true
185		u := ut.Untranspose()
186		if uc, ok := u.(UnConjTransposer); ok {
187			return uc.UnConjTranspose(), trans, true
188		}
189		return u, trans, false
190	case UnConjTransposer:
191		conj = true
192		u := ut.UnConjTranspose()
193		if ut, ok := u.(CUntransposer); ok {
194			return ut.Untranspose(), true, conj
195		}
196		return u, false, conj
197	default:
198		return a, false, false
199	}
200}
201
202// untransposeExtractCmplx returns an untransposed matrix in a built-in matrix type.
203//
204// The untransposed matrix is returned unaltered if it is a built-in matrix type.
205// Otherwise, if it implements a Raw method, an appropriate built-in type value
206// is returned holding the raw matrix value of the input. If neither of these
207// is possible, the untransposed matrix is returned.
208func untransposeExtractCmplx(a CMatrix) (u CMatrix, trans, conj bool) {
209	ut, trans, conj := untransposeCmplx(a)
210	switch m := ut.(type) {
211	case *CDense:
212		return m, trans, conj
213	case RawCMatrixer:
214		var d CDense
215		d.SetRawCMatrix(m.RawCMatrix())
216		return &d, trans, conj
217	default:
218		return ut, trans, conj
219	}
220}
221
222// CEqual returns whether the matrices a and b have the same size
223// and are element-wise equal.
224func CEqual(a, b CMatrix) bool {
225	ar, ac := a.Dims()
226	br, bc := b.Dims()
227	if ar != br || ac != bc {
228		return false
229	}
230	// TODO(btracey): Add in fast-paths.
231	for i := 0; i < ar; i++ {
232		for j := 0; j < ac; j++ {
233			if a.At(i, j) != b.At(i, j) {
234				return false
235			}
236		}
237	}
238	return true
239}
240
241// CEqualApprox returns whether the matrices a and b have the same size and contain all equal
242// elements with tolerance for element-wise equality specified by epsilon. Matrices
243// with non-equal shapes are not equal.
244func CEqualApprox(a, b CMatrix, epsilon float64) bool {
245	// TODO(btracey):
246	ar, ac := a.Dims()
247	br, bc := b.Dims()
248	if ar != br || ac != bc {
249		return false
250	}
251	for i := 0; i < ar; i++ {
252		for j := 0; j < ac; j++ {
253			if !cEqualWithinAbsOrRel(a.At(i, j), b.At(i, j), epsilon, epsilon) {
254				return false
255			}
256		}
257	}
258	return true
259}
260
261// TODO(btracey): Move these into a cmplxs if/when we have one.
262
263func cEqualWithinAbsOrRel(a, b complex128, absTol, relTol float64) bool {
264	if cEqualWithinAbs(a, b, absTol) {
265		return true
266	}
267	return cEqualWithinRel(a, b, relTol)
268}
269
// cEqualWithinAbs returns true if a and b are identical or the magnitude of
// their difference, |a-b|, is less than or equal to tol.
//
// The explicit a == b check lets equal infinities compare equal: their
// difference would be NaN, and NaN never satisfies the tolerance comparison.
func cEqualWithinAbs(a, b complex128, tol float64) bool {
	return a == b || cmplx.Abs(a-b) <= tol
}
275
// minNormalFloat64 is the smallest positive normal float64 value, 2^-1022
// (approximately 2.2250738585072014e-308).
const minNormalFloat64 = 0x1p-1022
277
// cEqualWithinRel returns true if a and b are equal to within a relative
// tolerance: the magnitude of their difference, |a-b|, must not exceed tol
// times the larger of |a| and |b|.
//
// Special cases, in the order they are checked:
//   - identical values (a == b) are always equal;
//   - if cmplx.IsNaN reports either value as NaN, the result is false;
//   - if cmplx.IsInf reports a as infinite, b must also be infinite. If real(a)
//     is infinite it must equal real(b) exactly and the imaginary parts are
//     compared with scalar.EqualWithinRel; otherwise imag(a) must equal
//     imag(b) exactly and the real parts are compared with
//     scalar.EqualWithinRel. An infinite b with a finite a is never equal;
//   - if |a-b| is no larger than minNormalFloat64, the difference is compared
//     against tol*minNormalFloat64 (an absolute scale) instead of the ratio.
//
// NOTE(review): cmplx.Abs appears able to overflow to +Inf for finite inputs
// with very large components, in which case the final ratio can be 0 and the
// values reported as equal. Confirm whether such inputs matter to callers.
func cEqualWithinRel(a, b complex128, tol float64) bool {
	if a == b {
		return true
	}
	if cmplx.IsNaN(a) || cmplx.IsNaN(b) {
		return false
	}
	// Cannot play the same trick as in floats/scalar because there are multiple
	// possible infinities.
	if cmplx.IsInf(a) {
		if !cmplx.IsInf(b) {
			return false
		}
		ra := real(a)
		// Whichever component of a is infinite must match b's exactly; the
		// other component is then compared with a scalar relative tolerance.
		if math.IsInf(ra, 0) {
			if ra == real(b) {
				return scalar.EqualWithinRel(imag(a), imag(b), tol)
			}
			return false
		}
		if imag(a) == imag(b) {
			return scalar.EqualWithinRel(ra, real(b), tol)
		}
		return false
	}
	// a is finite here, so an infinite b can never be within tolerance.
	if cmplx.IsInf(b) {
		return false
	}

	delta := cmplx.Abs(a - b)
	// For tiny differences, compare on an absolute scale rather than dividing
	// by magnitudes that may themselves be near zero.
	if delta <= minNormalFloat64 {
		return delta <= tol*minNormalFloat64
	}
	return delta/math.Max(cmplx.Abs(a), cmplx.Abs(b)) <= tol
}
315