1// Copyright ©2013 The Gonum Authors. All rights reserved.
2// Use of this code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package cmplxs
6
7import (
8	"errors"
9	"math"
10	"math/cmplx"
11
12	"gonum.org/v1/gonum/cmplxs/cscalar"
13	"gonum.org/v1/gonum/internal/asm/c128"
14)
15
// zeroLength is the panic message used by functions that need a non-empty slice.
const zeroLength = "cmplxs: zero length slice"

// shortSpan is the panic message used by Span when dst has fewer than two elements.
const shortSpan = "cmplxs: slice length less than 2"

// badLength is the panic message used when two operand slices differ in length.
const badLength = "cmplxs: slice lengths do not match"

// badDstLength is the panic message used when a destination slice does not
// have the same length as its input.
const badDstLength = "cmplxs: destination slice length does not match input"
22
23// Abs calculates the absolute values of the elements of s, and stores them in dst.
24// It panics if the argument lengths do not match.
25func Abs(dst []float64, s []complex128) {
26	if len(dst) != len(s) {
27		panic(badDstLength)
28	}
29	for i, v := range s {
30		dst[i] = cmplx.Abs(v)
31	}
32}
33
34// Add adds, element-wise, the elements of s and dst, and stores the result in dst.
35// It panics if the argument lengths do not match.
36func Add(dst, s []complex128) {
37	if len(dst) != len(s) {
38		panic(badLength)
39	}
40	c128.AxpyUnitaryTo(dst, 1, s, dst)
41}
42
43// AddTo adds, element-wise, the elements of s and t and
44// stores the result in dst.
45// It panics if the argument lengths do not match.
46func AddTo(dst, s, t []complex128) []complex128 {
47	if len(s) != len(t) {
48		panic(badLength)
49	}
50	if len(dst) != len(s) {
51		panic(badDstLength)
52	}
53	c128.AxpyUnitaryTo(dst, 1, s, t)
54	return dst
55}
56
57// AddConst adds the scalar c to all of the values in dst.
58func AddConst(c complex128, dst []complex128) {
59	c128.AddConst(c, dst)
60}
61
62// AddScaled performs dst = dst + alpha * s.
63// It panics if the slice argument lengths do not match.
64func AddScaled(dst []complex128, alpha complex128, s []complex128) {
65	if len(dst) != len(s) {
66		panic(badLength)
67	}
68	c128.AxpyUnitaryTo(dst, alpha, s, dst)
69}
70
71// AddScaledTo performs dst = y + alpha * s, where alpha is a scalar,
72// and dst, y and s are all slices.
73// It panics if the slice argument lengths do not match.
74//
75// At the return of the function, dst[i] = y[i] + alpha * s[i]
76func AddScaledTo(dst, y []complex128, alpha complex128, s []complex128) []complex128 {
77	if len(s) != len(y) {
78		panic(badLength)
79	}
80	if len(dst) != len(y) {
81		panic(badDstLength)
82	}
83	c128.AxpyUnitaryTo(dst, alpha, s, y)
84	return dst
85}
86
// Count reports how many elements of s satisfy the predicate f.
// f is called exactly once per element, in increasing index order.
func Count(f func(complex128) bool, s []complex128) int {
	matches := 0
	for i := range s {
		if !f(s[i]) {
			continue
		}
		matches++
	}
	return matches
}
98
99// Complex fills each of the elements of dst with the complex number
100// constructed from the corresponding elements of real and imag.
101// It panics if the argument lengths do not match.
102func Complex(dst []complex128, real, imag []float64) []complex128 {
103	if len(real) != len(imag) {
104		panic(badLength)
105	}
106	if len(dst) != len(real) {
107		panic(badDstLength)
108	}
109	if len(dst) == 0 {
110		return dst
111	}
112	for i, r := range real {
113		dst[i] = complex(r, imag[i])
114	}
115	return dst
116}
117
118// CumProd finds the cumulative product of elements of s and store it in
119// place into dst so that
120//  dst[i] = s[i] * s[i-1] * s[i-2] * ... * s[0]
121// It panics if the argument lengths do not match.
122func CumProd(dst, s []complex128) []complex128 {
123	if len(dst) != len(s) {
124		panic(badDstLength)
125	}
126	if len(dst) == 0 {
127		return dst
128	}
129	return c128.CumProd(dst, s)
130}
131
132// CumSum finds the cumulative sum of elements of s and stores it in place
133// into dst so that
134//  dst[i] = s[i] + s[i-1] + s[i-2] + ... + s[0]
135// It panics if the argument lengths do not match.
136func CumSum(dst, s []complex128) []complex128 {
137	if len(dst) != len(s) {
138		panic(badDstLength)
139	}
140	if len(dst) == 0 {
141		return dst
142	}
143	return c128.CumSum(dst, s)
144}
145
146// Distance computes the L-norm of s - t. See Norm for special cases.
147// It panics if the slice argument lengths do not match.
148func Distance(s, t []complex128, L float64) float64 {
149	if len(s) != len(t) {
150		panic(badLength)
151	}
152	if len(s) == 0 {
153		return 0
154	}
155
156	var norm float64
157	switch {
158	case L == 2:
159		return c128.L2DistanceUnitary(s, t)
160	case L == 1:
161		for i, v := range s {
162			norm += cmplx.Abs(t[i] - v)
163		}
164		return norm
165	case math.IsInf(L, 1):
166		for i, v := range s {
167			absDiff := cmplx.Abs(t[i] - v)
168			if absDiff > norm {
169				norm = absDiff
170			}
171		}
172		return norm
173	default:
174		for i, v := range s {
175			norm += math.Pow(cmplx.Abs(t[i]-v), L)
176		}
177		return math.Pow(norm, 1/L)
178	}
179}
180
181// Div performs element-wise division dst / s
182// and stores the result in dst.
183// It panics if the argument lengths do not match.
184func Div(dst, s []complex128) {
185	if len(dst) != len(s) {
186		panic(badLength)
187	}
188	c128.Div(dst, s)
189}
190
191// DivTo performs element-wise division s / t
192// and stores the result in dst.
193// It panics if the argument lengths do not match.
194func DivTo(dst, s, t []complex128) []complex128 {
195	if len(s) != len(t) {
196		panic(badLength)
197	}
198	if len(dst) != len(s) {
199		panic(badDstLength)
200	}
201	return c128.DivTo(dst, s, t)
202}
203
204// Dot computes the dot product of s1 and s2, i.e.
205// sum_{i = 1}^N conj(s1[i])*s2[i].
206// It panics if the argument lengths do not match.
207func Dot(s1, s2 []complex128) complex128 {
208	if len(s1) != len(s2) {
209		panic(badLength)
210	}
211	return c128.DotUnitary(s1, s2)
212}
213
// Equal reports whether s1 and s2 have the same length and every pair of
// corresponding elements compares equal with ==. Because NaN != NaN, slices
// containing NaN are never Equal; Same offers a NaN-tolerant comparison.
func Equal(s1, s2 []complex128) bool {
	if len(s1) != len(s2) {
		return false
	}
	for i := range s2 {
		if s1[i] != s2[i] {
			return false
		}
	}
	return true
}
227
228// EqualApprox returns true when the slices have equal lengths and
229// all element pairs have an absolute tolerance less than tol or a
230// relative tolerance less than tol.
231func EqualApprox(s1, s2 []complex128, tol float64) bool {
232	if len(s1) != len(s2) {
233		return false
234	}
235	for i, a := range s1 {
236		if !cscalar.EqualWithinAbsOrRel(a, s2[i], tol, tol) {
237			return false
238		}
239	}
240	return true
241}
242
// EqualFunc reports whether s1 and s2 have the same length and f(s1[i], s2[i])
// is true for every index i. Pairs are visited in increasing index order, and
// the scan stops at the first pair for which f returns false.
func EqualFunc(s1, s2 []complex128, f func(complex128, complex128) bool) bool {
	if len(s1) != len(s2) {
		return false
	}
	for i := 0; i < len(s1); i++ {
		if !f(s1[i], s2[i]) {
			return false
		}
	}
	return true
}
256
// EqualLengths reports whether every slice passed to it has the same length.
// It returns true when called with no slices at all.
func EqualLengths(slices ...[]complex128) bool {
	// Guard against indexing slices[0] when no arguments were passed:
	// http://play.golang.org/p/sdty6YiLhM
	if len(slices) == 0 {
		return true
	}
	want := len(slices[0])
	for _, sl := range slices[1:] {
		if len(sl) != want {
			return false
		}
	}
	return true
}
272
// Find calls f on the elements of s in index order and collects the indices
// for which f returns true.
//
// inds is truncated to zero length and the matching indices are appended to
// it, so its backing array may be reused. Use the returned slice: on return
// the caller's inds is in an undetermined state. Passing a nil inds is fine.
//
// The meaning of k:
//   - k == 0: no elements are examined, and an empty result is returned.
//   - k < 0: the indices of all matching elements are returned.
//   - k > 0: the indices of the first k matching elements are returned. If
//     fewer than k elements match, every found index is returned together
//     with a non-nil error.
func Find(inds []int, f func(complex128) bool, s []complex128, k int) ([]int, error) {
	inds = inds[:0]
	if k == 0 {
		return inds, nil
	}
	for i, v := range s {
		if !f(v) {
			continue
		}
		inds = append(inds, i)
		// inds started empty, so its length is the number of matches so far.
		if k > 0 && len(inds) == k {
			return inds, nil
		}
	}
	if k < 0 {
		return inds, nil
	}
	// The scan finished without reaching k matches.
	return inds, errors.New("cmplxs: insufficient elements found")
}
316
// HasNaN reports whether any element of s is NaN as judged by cmplx.IsNaN.
// cmplx.IsNaN does not count a value with an infinite component as NaN, so
// such elements do not make HasNaN return true.
func HasNaN(s []complex128) bool {
	for i := range s {
		if cmplx.IsNaN(s[i]) {
			return true
		}
	}
	return false
}
327
328// Imag places the imaginary components of src into dst.
329// It panics if the argument lengths do not match.
330func Imag(dst []float64, src []complex128) []float64 {
331	if len(dst) != len(src) {
332		panic(badDstLength)
333	}
334	if len(dst) == 0 {
335		return dst
336	}
337	for i, z := range src {
338		dst[i] = imag(z)
339	}
340	return dst
341}
342
343// LogSpan returns a set of n equally spaced points in log space between,
344// l and u where N is equal to len(dst). The first element of the
345// resulting dst will be l and the final element of dst will be u.
346// Panics if len(dst) < 2
347// Note that this call will return NaNs if either l or u are negative, and
348// will return all zeros if l or u is zero.
349// Also returns the mutated slice dst, so that it can be used in range, like:
350//
351//     for i, x := range LogSpan(dst, l, u) { ... }
352func LogSpan(dst []complex128, l, u complex128) []complex128 {
353	Span(dst, cmplx.Log(l), cmplx.Log(u))
354	for i := range dst {
355		dst[i] = cmplx.Exp(dst[i])
356	}
357	return dst
358}
359
360// MaxAbs returns the maximum absolute value in the input slice.
361// It panics if s is zero length.
362func MaxAbs(s []complex128) complex128 {
363	return s[MaxAbsIdx(s)]
364}
365
366// MaxAbsIdx returns the index of the maximum absolute value in the input slice.
367// If several entries have the maximum absolute value, the first such index is
368// returned.
369// It panics if s is zero length.
370func MaxAbsIdx(s []complex128) int {
371	if len(s) == 0 {
372		panic(zeroLength)
373	}
374	max := math.NaN()
375	var ind int
376	for i, v := range s {
377		if cmplx.IsNaN(v) {
378			continue
379		}
380		if a := cmplx.Abs(v); a > max || math.IsNaN(max) {
381			max = a
382			ind = i
383		}
384	}
385	return ind
386}
387
388// MinAbs returns the minimum absolute value in the input slice.
389// It panics if s is zero length.
390func MinAbs(s []complex128) complex128 {
391	return s[MinAbsIdx(s)]
392}
393
394// MinAbsIdx returns the index of the minimum absolute value in the input slice. If several
395// entries have the minimum absolute value, the first such index is returned.
396// It panics if s is zero length.
397func MinAbsIdx(s []complex128) int {
398	if len(s) == 0 {
399		panic(zeroLength)
400	}
401	min := math.NaN()
402	var ind int
403	for i, v := range s {
404		if cmplx.IsNaN(v) {
405			continue
406		}
407		if a := cmplx.Abs(v); a < min || math.IsNaN(min) {
408			min = a
409			ind = i
410		}
411	}
412	return ind
413}
414
415// Mul performs element-wise multiplication between dst
416// and s and stores the result in dst.
417// It panics if the argument lengths do not match.
418func Mul(dst, s []complex128) {
419	if len(dst) != len(s) {
420		panic(badLength)
421	}
422	for i, val := range s {
423		dst[i] *= val
424	}
425}
426
427// MulTo performs element-wise multiplication between s
428// and t and stores the result in dst.
429// It panics if the argument lengths do not match.
430func MulTo(dst, s, t []complex128) []complex128 {
431	if len(s) != len(t) {
432		panic(badLength)
433	}
434	if len(dst) != len(s) {
435		panic(badDstLength)
436	}
437	for i, val := range t {
438		dst[i] = val * s[i]
439	}
440	return dst
441}
442
443// NearestIdx returns the index of the element in s
444// whose value is nearest to v. If several such
445// elements exist, the lowest index is returned.
446// It panics if s is zero length.
447func NearestIdx(s []complex128, v complex128) int {
448	if len(s) == 0 {
449		panic(zeroLength)
450	}
451	switch {
452	case cmplx.IsNaN(v):
453		return 0
454	case cmplx.IsInf(v):
455		return MaxAbsIdx(s)
456	}
457	var ind int
458	dist := math.NaN()
459	for i, val := range s {
460		newDist := cmplx.Abs(v - val)
461		// A NaN distance will not be closer.
462		if math.IsNaN(newDist) {
463			continue
464		}
465		if newDist < dist || math.IsNaN(dist) {
466			dist = newDist
467			ind = i
468		}
469	}
470	return ind
471}
472
473// Norm returns the L-norm of the slice S, defined as
474// (sum_{i=1}^N abs(s[i])^L)^{1/L}
475// Special cases:
476// L = math.Inf(1) gives the maximum absolute value.
477// Does not correctly compute the zero norm (use Count).
478func Norm(s []complex128, L float64) float64 {
479	// Should this complain if L is not positive?
480	// Should this be done in log space for better numerical stability?
481	//	would be more cost
482	//	maybe only if L is high?
483	if len(s) == 0 {
484		return 0
485	}
486	var norm float64
487	switch {
488	case L == 2:
489		return c128.L2NormUnitary(s)
490	case L == 1:
491		for _, v := range s {
492			norm += cmplx.Abs(v)
493		}
494		return norm
495	case math.IsInf(L, 1):
496		for _, v := range s {
497			norm = math.Max(norm, cmplx.Abs(v))
498		}
499		return norm
500	default:
501		for _, v := range s {
502			norm += math.Pow(cmplx.Abs(v), L)
503		}
504		return math.Pow(norm, 1/L)
505	}
506}
507
// Prod returns the product of all elements of s, or 1 when s is empty.
func Prod(s []complex128) complex128 {
	var p complex128 = 1
	for i := range s {
		p *= s[i]
	}
	return p
}
517
518// Real places the real components of src into dst.
519// It panics if the argument lengths do not match.
520func Real(dst []float64, src []complex128) []float64 {
521	if len(dst) != len(src) {
522		panic(badDstLength)
523	}
524	if len(dst) == 0 {
525		return dst
526	}
527	for i, z := range src {
528		dst[i] = real(z)
529	}
530	return dst
531}
532
// Reverse reverses the order of the elements of s in place.
func Reverse(s []complex128) {
	n := len(s)
	for i := 0; i < n/2; i++ {
		s[i], s[n-1-i] = s[n-1-i], s[i]
	}
}
539
// Same reports whether s and t have the same length and every pair of
// corresponding elements is either equal under == or both NaN according to
// cmplx.IsNaN.
func Same(s, t []complex128) bool {
	if len(s) != len(t) {
		return false
	}
	for i := range s {
		a, b := s[i], t[i]
		bothNaN := cmplx.IsNaN(a) && cmplx.IsNaN(b)
		if a == b || bothNaN {
			continue
		}
		return false
	}
	return true
}
554
555// Scale multiplies every element in dst by the scalar c.
556func Scale(c complex128, dst []complex128) {
557	if len(dst) > 0 {
558		c128.ScalUnitary(c, dst)
559	}
560}
561
562// ScaleTo multiplies the elements in s by c and stores the result in dst.
563// It panics if the slice argument lengths do not match.
564func ScaleTo(dst []complex128, c complex128, s []complex128) []complex128 {
565	if len(dst) != len(s) {
566		panic(badDstLength)
567	}
568	if len(dst) > 0 {
569		c128.ScalUnitaryTo(dst, c, s)
570	}
571	return dst
572}
573
574// Span returns a set of N equally spaced points between l and u, where N
575// is equal to the length of the destination. The first element of the destination
576// is l, the final element of the destination is u.
577// It panics if the length of dst is less than 2.
578//
579// Span also returns the mutated slice dst, so that it can be used in range expressions,
580// like:
581//
582//     for i, x := range Span(dst, l, u) { ... }
583func Span(dst []complex128, l, u complex128) []complex128 {
584	n := len(dst)
585	if n < 2 {
586		panic(shortSpan)
587	}
588
589	// Special cases for Inf and NaN.
590	switch {
591	case cmplx.IsNaN(l):
592		for i := range dst[:len(dst)-1] {
593			dst[i] = cmplx.NaN()
594		}
595		dst[len(dst)-1] = u
596		return dst
597	case cmplx.IsNaN(u):
598		for i := range dst[1:] {
599			dst[i+1] = cmplx.NaN()
600		}
601		dst[0] = l
602		return dst
603	case cmplx.IsInf(l) && cmplx.IsInf(u):
604		for i := range dst {
605			dst[i] = cmplx.Inf()
606		}
607		return dst
608	case cmplx.IsInf(l):
609		for i := range dst[:len(dst)-1] {
610			dst[i] = l
611		}
612		dst[len(dst)-1] = u
613		return dst
614	case cmplx.IsInf(u):
615		for i := range dst[1:] {
616			dst[i+1] = u
617		}
618		dst[0] = l
619		return dst
620	}
621
622	step := (u - l) / complex(float64(n-1), 0)
623	for i := range dst {
624		dst[i] = l + step*complex(float64(i), 0)
625	}
626	return dst
627}
628
629// Sub subtracts, element-wise, the elements of s from dst.
630// It panics if the argument lengths do not match.
631func Sub(dst, s []complex128) {
632	if len(dst) != len(s) {
633		panic(badLength)
634	}
635	c128.AxpyUnitaryTo(dst, -1, s, dst)
636}
637
638// SubTo subtracts, element-wise, the elements of t from s and
639// stores the result in dst.
640// It panics if the argument lengths do not match.
641func SubTo(dst, s, t []complex128) []complex128 {
642	if len(s) != len(t) {
643		panic(badLength)
644	}
645	if len(dst) != len(s) {
646		panic(badDstLength)
647	}
648	c128.AxpyUnitaryTo(dst, -1, t, s)
649	return dst
650}
651
652// Sum returns the sum of the elements of the slice.
653func Sum(s []complex128) complex128 {
654	return c128.Sum(s)
655}
656