1// Copyright ©2013 The Gonum Authors. All rights reserved.
2// Use of this code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package floats
6
7import (
8	"errors"
9	"math"
10	"sort"
11
12	"gonum.org/v1/gonum/floats/scalar"
13	"gonum.org/v1/gonum/internal/asm/f64"
14)
15
// Panic messages shared by the functions in this package.
const (
	// badLength reports that two input slices differ in length.
	badLength = "floats: slice lengths do not match"

	// badDstLength reports that the destination slice differs in length
	// from the input slice(s).
	badDstLength = "floats: destination slice length does not match input"

	// shortSpan reports a slice (or span length) shorter than two.
	shortSpan = "floats: slice length less than 2"

	// zeroLength reports an empty slice where at least one element is required.
	zeroLength = "floats: zero length slice"
)
22
23// Add adds, element-wise, the elements of s and dst, and stores the result in dst.
24// It panics if the argument lengths do not match.
25func Add(dst, s []float64) {
26	if len(dst) != len(s) {
27		panic(badDstLength)
28	}
29	f64.AxpyUnitaryTo(dst, 1, s, dst)
30}
31
32// AddTo adds, element-wise, the elements of s and t and
33// stores the result in dst.
34// It panics if the argument lengths do not match.
35func AddTo(dst, s, t []float64) []float64 {
36	if len(s) != len(t) {
37		panic(badLength)
38	}
39	if len(dst) != len(s) {
40		panic(badDstLength)
41	}
42	f64.AxpyUnitaryTo(dst, 1, s, t)
43	return dst
44}
45
46// AddConst adds the scalar c to all of the values in dst.
47func AddConst(c float64, dst []float64) {
48	f64.AddConst(c, dst)
49}
50
51// AddScaled performs dst = dst + alpha * s.
52// It panics if the slice argument lengths do not match.
53func AddScaled(dst []float64, alpha float64, s []float64) {
54	if len(dst) != len(s) {
55		panic(badLength)
56	}
57	f64.AxpyUnitaryTo(dst, alpha, s, dst)
58}
59
60// AddScaledTo performs dst = y + alpha * s, where alpha is a scalar,
61// and dst, y and s are all slices.
62// It panics if the slice argument lengths do not match.
63//
64// At the return of the function, dst[i] = y[i] + alpha * s[i]
65func AddScaledTo(dst, y []float64, alpha float64, s []float64) []float64 {
66	if len(s) != len(y) {
67		panic(badLength)
68	}
69	if len(dst) != len(y) {
70		panic(badDstLength)
71	}
72	f64.AxpyUnitaryTo(dst, alpha, s, y)
73	return dst
74}
75
// argsort is a helper that implements sort.Interface, as used by
// Argsort. It orders s increasingly and mirrors every swap into inds,
// so that inds records the original position of each element of s.
type argsort struct {
	s    []float64
	inds []int
}

// Len returns the number of elements being sorted.
func (a argsort) Len() int {
	return len(a.s)
}

// Less reports whether s[i] sorts before s[j].
//
// NaN values are ordered before every other value, matching
// sort.Float64Slice. A bare < comparison is false for every pair that
// involves NaN, which breaks the strict weak ordering sort.Sort relies on
// and can leave the non-NaN values out of order.
func (a argsort) Less(i, j int) bool {
	x, y := a.s[i], a.s[j]
	return x < y || (math.IsNaN(x) && !math.IsNaN(y))
}

// Swap exchanges elements i and j of both s and inds.
func (a argsort) Swap(i, j int) {
	a.s[i], a.s[j] = a.s[j], a.s[i]
	a.inds[i], a.inds[j] = a.inds[j], a.inds[i]
}
95
96// Argsort sorts the elements of dst while tracking their original order.
97// At the conclusion of Argsort, dst will contain the original elements of dst
98// but sorted in increasing order, and inds will contain the original position
99// of the elements in the slice such that dst[i] = origDst[inds[i]].
100// It panics if the argument lengths do not match.
101func Argsort(dst []float64, inds []int) {
102	if len(dst) != len(inds) {
103		panic(badDstLength)
104	}
105	for i := range dst {
106		inds[i] = i
107	}
108
109	a := argsort{s: dst, inds: inds}
110	sort.Sort(a)
111}
112
// Count returns how many elements of s satisfy f, calling f once per
// element in order.
func Count(f func(float64) bool, s []float64) int {
	n := 0
	for i := range s {
		if !f(s[i]) {
			continue
		}
		n++
	}
	return n
}
124
125// CumProd finds the cumulative product of the first i elements in
126// s and puts them in place into the ith element of the
127// destination dst.
128// It panics if the argument lengths do not match.
129//
130// At the return of the function, dst[i] = s[i] * s[i-1] * s[i-2] * ...
131func CumProd(dst, s []float64) []float64 {
132	if len(dst) != len(s) {
133		panic(badDstLength)
134	}
135	if len(dst) == 0 {
136		return dst
137	}
138	return f64.CumProd(dst, s)
139}
140
141// CumSum finds the cumulative sum of the first i elements in
142// s and puts them in place into the ith element of the
143// destination dst.
144// It panics if the argument lengths do not match.
145//
146// At the return of the function, dst[i] = s[i] + s[i-1] + s[i-2] + ...
147func CumSum(dst, s []float64) []float64 {
148	if len(dst) != len(s) {
149		panic(badDstLength)
150	}
151	if len(dst) == 0 {
152		return dst
153	}
154	return f64.CumSum(dst, s)
155}
156
157// Distance computes the L-norm of s - t. See Norm for special cases.
158// It panics if the slice argument lengths do not match.
159func Distance(s, t []float64, L float64) float64 {
160	if len(s) != len(t) {
161		panic(badLength)
162	}
163	if len(s) == 0 {
164		return 0
165	}
166	if L == 2 {
167		return f64.L2DistanceUnitary(s, t)
168	}
169	var norm float64
170	if L == 1 {
171		for i, v := range s {
172			norm += math.Abs(t[i] - v)
173		}
174		return norm
175	}
176	if math.IsInf(L, 1) {
177		for i, v := range s {
178			absDiff := math.Abs(t[i] - v)
179			if absDiff > norm {
180				norm = absDiff
181			}
182		}
183		return norm
184	}
185	for i, v := range s {
186		norm += math.Pow(math.Abs(t[i]-v), L)
187	}
188	return math.Pow(norm, 1/L)
189}
190
191// Div performs element-wise division dst / s
192// and stores the value in dst.
193// It panics if the argument lengths do not match.
194func Div(dst, s []float64) {
195	if len(dst) != len(s) {
196		panic(badLength)
197	}
198	f64.Div(dst, s)
199}
200
201// DivTo performs element-wise division s / t
202// and stores the value in dst.
203// It panics if the argument lengths do not match.
204func DivTo(dst, s, t []float64) []float64 {
205	if len(s) != len(t) {
206		panic(badLength)
207	}
208	if len(dst) != len(s) {
209		panic(badDstLength)
210	}
211	return f64.DivTo(dst, s, t)
212}
213
214// Dot computes the dot product of s1 and s2, i.e.
215// sum_{i = 1}^N s1[i]*s2[i].
216// It panics if the argument lengths do not match.
217func Dot(s1, s2 []float64) float64 {
218	if len(s1) != len(s2) {
219		panic(badLength)
220	}
221	return f64.DotUnitary(s1, s2)
222}
223
// Equal reports whether s1 and s2 have the same length and every pair of
// corresponding elements compares equal with ==. Because NaN != NaN, slices
// containing NaN are never Equal; Same treats NaNs as matching instead.
func Equal(s1, s2 []float64) bool {
	if len(s1) != len(s2) {
		return false
	}
	for i := range s1 {
		if s1[i] != s2[i] {
			return false
		}
	}
	return true
}
237
238// EqualApprox returns true when the slices have equal lengths and
239// all element pairs have an absolute tolerance less than tol or a
240// relative tolerance less than tol.
241func EqualApprox(s1, s2 []float64, tol float64) bool {
242	if len(s1) != len(s2) {
243		return false
244	}
245	for i, a := range s1 {
246		if !scalar.EqualWithinAbsOrRel(a, s2[i], tol, tol) {
247			return false
248		}
249	}
250	return true
251}
252
// EqualFunc reports whether s1 and s2 have the same length and f returns
// true for every pair (s1[i], s2[i]). It stops at the first pair for
// which f returns false.
func EqualFunc(s1, s2 []float64, f func(float64, float64) bool) bool {
	if len(s1) != len(s2) {
		return false
	}
	for i := range s1 {
		if ok := f(s1[i], s2[i]); !ok {
			return false
		}
	}
	return true
}
266
// EqualLengths reports whether every slice passed has the same length.
// It returns true when called with no slices.
func EqualLengths(slices ...[]float64) bool {
	// Guard against indexing slices[0] when there are no arguments.
	if len(slices) == 0 {
		return true
	}
	want := len(slices[0])
	for _, sl := range slices[1:] {
		if len(sl) != want {
			return false
		}
	}
	return true
}
282
// Find applies f to the elements of s in order and returns the indices of
// the first k elements for which f returns true, or of all such elements
// if k < 0.
// Find reslices inds to zero length and appends the found indices to it,
// so inds may be nil. If k > 0 and fewer than k elements satisfy f, all
// found indices are returned together with an error.
// At the return of the function, the input inds will be in an undetermined state.
func Find(inds []int, f func(float64) bool, s []float64, k int) ([]int, error) {
	// Reuse the caller's backing array; the result is returned so that a
	// nil inds works too.
	inds = inds[:0]

	switch {
	case k == 0:
		return inds, nil
	case k < 0:
		for i, v := range s {
			if f(v) {
				inds = append(inds, i)
			}
		}
		return inds, nil
	}

	for i, v := range s {
		if !f(v) {
			continue
		}
		inds = append(inds, i)
		// inds started empty, so its length is the number found so far.
		if len(inds) == k {
			return inds, nil
		}
	}
	// The whole slice was scanned without collecting k matches.
	return inds, errors.New("floats: insufficient elements found")
}
326
// HasNaN reports whether any element of s is NaN.
func HasNaN(s []float64) bool {
	for i := range s {
		if math.IsNaN(s[i]) {
			return true
		}
	}
	return false
}
337
338// LogSpan returns a set of n equally spaced points in log space between,
339// l and u where N is equal to len(dst). The first element of the
340// resulting dst will be l and the final element of dst will be u.
341// It panics if the length of dst is less than 2.
342// Note that this call will return NaNs if either l or u are negative, and
343// will return all zeros if l or u is zero.
344// Also returns the mutated slice dst, so that it can be used in range, like:
345//
346//     for i, x := range LogSpan(dst, l, u) { ... }
347func LogSpan(dst []float64, l, u float64) []float64 {
348	Span(dst, math.Log(l), math.Log(u))
349	for i := range dst {
350		dst[i] = math.Exp(dst[i])
351	}
352	return dst
353}
354
355// LogSumExp returns the log of the sum of the exponentials of the values in s.
356// Panics if s is an empty slice.
357func LogSumExp(s []float64) float64 {
358	// Want to do this in a numerically stable way which avoids
359	// overflow and underflow
360	// First, find the maximum value in the slice.
361	maxval := Max(s)
362	if math.IsInf(maxval, 0) {
363		// If it's infinity either way, the logsumexp will be infinity as well
364		// returning now avoids NaNs
365		return maxval
366	}
367	var lse float64
368	// Compute the sumexp part
369	for _, val := range s {
370		lse += math.Exp(val - maxval)
371	}
372	// Take the log and add back on the constant taken out
373	return math.Log(lse) + maxval
374}
375
376// Max returns the maximum value in the input slice. If the slice is empty, Max will panic.
377func Max(s []float64) float64 {
378	return s[MaxIdx(s)]
379}
380
381// MaxIdx returns the index of the maximum value in the input slice. If several
382// entries have the maximum value, the first such index is returned.
383// It panics if s is zero length.
384func MaxIdx(s []float64) int {
385	if len(s) == 0 {
386		panic(zeroLength)
387	}
388	max := math.NaN()
389	var ind int
390	for i, v := range s {
391		if math.IsNaN(v) {
392			continue
393		}
394		if v > max || math.IsNaN(max) {
395			max = v
396			ind = i
397		}
398	}
399	return ind
400}
401
402// Min returns the minimum value in the input slice.
403// It panics if s is zero length.
404func Min(s []float64) float64 {
405	return s[MinIdx(s)]
406}
407
408// MinIdx returns the index of the minimum value in the input slice. If several
409// entries have the minimum value, the first such index is returned.
410// It panics if s is zero length.
411func MinIdx(s []float64) int {
412	if len(s) == 0 {
413		panic(zeroLength)
414	}
415	min := math.NaN()
416	var ind int
417	for i, v := range s {
418		if math.IsNaN(v) {
419			continue
420		}
421		if v < min || math.IsNaN(min) {
422			min = v
423			ind = i
424		}
425	}
426	return ind
427}
428
429// Mul performs element-wise multiplication between dst
430// and s and stores the value in dst.
431// It panics if the argument lengths do not match.
432func Mul(dst, s []float64) {
433	if len(dst) != len(s) {
434		panic(badLength)
435	}
436	for i, val := range s {
437		dst[i] *= val
438	}
439}
440
441// MulTo performs element-wise multiplication between s
442// and t and stores the value in dst.
443// It panics if the argument lengths do not match.
444func MulTo(dst, s, t []float64) []float64 {
445	if len(s) != len(t) {
446		panic(badLength)
447	}
448	if len(dst) != len(s) {
449		panic(badDstLength)
450	}
451	for i, val := range t {
452		dst[i] = val * s[i]
453	}
454	return dst
455}
456
457// NearestIdx returns the index of the element in s
458// whose value is nearest to v. If several such
459// elements exist, the lowest index is returned.
460// It panics if s is zero length.
461func NearestIdx(s []float64, v float64) int {
462	if len(s) == 0 {
463		panic(zeroLength)
464	}
465	switch {
466	case math.IsNaN(v):
467		return 0
468	case math.IsInf(v, 1):
469		return MaxIdx(s)
470	case math.IsInf(v, -1):
471		return MinIdx(s)
472	}
473	var ind int
474	dist := math.NaN()
475	for i, val := range s {
476		newDist := math.Abs(v - val)
477		// A NaN distance will not be closer.
478		if math.IsNaN(newDist) {
479			continue
480		}
481		if newDist < dist || math.IsNaN(dist) {
482			dist = newDist
483			ind = i
484		}
485	}
486	return ind
487}
488
// NearestIdxForSpan returns the index of a hypothetical vector created
// by Span with length n and bounds l and u whose value is closest
// to v. That is, NearestIdxForSpan(n, l, u, v) is intended to match
// NearestIdx(Span(make([]float64, n), l, u), v) without an allocation.
// It panics if n is less than two.
//
// NOTE(review): when l and u are infinities of opposite sign and n is even,
// a finite v yields 0 or n/2 depending on the sign of v, whereas NearestIdx
// on the spanned slice would return 0 because every distance is +Inf.
// Confirm which behavior is intended.
func NearestIdxForSpan(n int, l, u float64, v float64) int {
	if n < 2 {
		panic(shortSpan)
	}
	// A NaN target has no nearest element; report the first index, as
	// NearestIdx does.
	if math.IsNaN(v) {
		return 0
	}

	// Special cases for Inf and NaN.
	switch {
	case math.IsNaN(l) && !math.IsNaN(u):
		// Span makes every element but the last NaN, so only the final
		// element (u) can be nearest.
		return n - 1
	case math.IsNaN(u):
		// Span makes every element but the first NaN.
		return 0
	case math.IsInf(l, 0) && math.IsInf(u, 0):
		// Span puts l in the first half and u in the second half; for odd
		// n the middle element is 0 when l != u (and l otherwise).
		if l == u {
			return 0
		}
		if n%2 == 1 {
			// The middle 0 is the only finite element, so it is nearest
			// to any finite v.
			if !math.IsInf(v, 0) {
				return n / 2
			}
			if math.Copysign(1, v) == math.Copysign(1, l) {
				return 0
			}
			// First element of the u half.
			return n/2 + 1
		}
		if math.Copysign(1, v) == math.Copysign(1, l) {
			return 0
		}
		// First element of the u half.
		return n / 2
	case math.IsInf(l, 0):
		// Only l is infinite: Span sets every element but the last to l.
		if v == l {
			return 0
		}
		return n - 1
	case math.IsInf(u, 0):
		// Only u is infinite: Span sets every element but the first to u.
		if v == u {
			return n - 1
		}
		return 0
	case math.IsInf(v, -1):
		// Finite bounds: -Inf is nearest to the smaller bound.
		if l <= u {
			return 0
		}
		return n - 1
	case math.IsInf(v, 1):
		// Finite bounds: +Inf is nearest to the larger bound.
		if u <= l {
			return 0
		}
		return n - 1
	}

	// Special cases for v outside (l, u) and (u, l).
	switch {
	case l < u:
		if v <= l {
			return 0
		}
		if v >= u {
			return n - 1
		}
	case l > u:
		if v >= l {
			return 0
		}
		if v <= u {
			return n - 1
		}
	default:
		// l == u: every element of the span is equal.
		return 0
	}

	// Can't guarantee anything about exactly halfway between
	// because of floating point weirdness.
	// v lies strictly between the bounds here, so the scaled position is
	// non-negative and adding 0.5 before truncation rounds to nearest.
	// NOTE(review): u-l can overflow to ±Inf for finite bounds that are far
	// apart (e.g. ±math.MaxFloat64); the expression then degenerates — confirm.
	return int((float64(n)-1)/(u-l)*(v-l) + 0.5)
}
571
572// Norm returns the L norm of the slice S, defined as
573// (sum_{i=1}^N s[i]^L)^{1/L}
574// Special cases:
575// L = math.Inf(1) gives the maximum absolute value.
576// Does not correctly compute the zero norm (use Count).
577func Norm(s []float64, L float64) float64 {
578	// Should this complain if L is not positive?
579	// Should this be done in log space for better numerical stability?
580	//	would be more cost
581	//	maybe only if L is high?
582	if len(s) == 0 {
583		return 0
584	}
585	if L == 2 {
586		return f64.L2NormUnitary(s)
587	}
588	var norm float64
589	if L == 1 {
590		for _, val := range s {
591			norm += math.Abs(val)
592		}
593		return norm
594	}
595	if math.IsInf(L, 1) {
596		for _, val := range s {
597			norm = math.Max(norm, math.Abs(val))
598		}
599		return norm
600	}
601	for _, val := range s {
602		norm += math.Pow(math.Abs(val), L)
603	}
604	return math.Pow(norm, 1/L)
605}
606
// Prod returns the product of the elements of s, multiplied left to right.
// The product of an empty slice is 1.
func Prod(s []float64) float64 {
	p := float64(1)
	for i := range s {
		p *= s[i]
	}
	return p
}
616
// Reverse reverses the order of the elements of s in place.
func Reverse(s []float64) {
	n := len(s)
	for i := 0; i < n/2; i++ {
		j := n - 1 - i
		s[i], s[j] = s[j], s[i]
	}
}
623
// Same reports whether s and t have the same length and every pair of
// corresponding elements is equal, treating two NaNs as equal to each other.
func Same(s, t []float64) bool {
	if len(s) != len(t) {
		return false
	}
	for i := range s {
		a, b := s[i], t[i]
		if a == b {
			continue
		}
		// Unequal values only match when both are NaN.
		if !math.IsNaN(a) || !math.IsNaN(b) {
			return false
		}
	}
	return true
}
638
639// Scale multiplies every element in dst by the scalar c.
640func Scale(c float64, dst []float64) {
641	if len(dst) > 0 {
642		f64.ScalUnitary(c, dst)
643	}
644}
645
646// ScaleTo multiplies the elements in s by c and stores the result in dst.
647// It panics if the slice argument lengths do not match.
648func ScaleTo(dst []float64, c float64, s []float64) []float64 {
649	if len(dst) != len(s) {
650		panic(badDstLength)
651	}
652	if len(dst) > 0 {
653		f64.ScalUnitaryTo(dst, c, s)
654	}
655	return dst
656}
657
658// Span returns a set of N equally spaced points between l and u, where N
659// is equal to the length of the destination. The first element of the destination
660// is l, the final element of the destination is u.
661// It panics if the length of dst is less than 2.
662//
663// Span also returns the mutated slice dst, so that it can be used in range expressions,
664// like:
665//
666//     for i, x := range Span(dst, l, u) { ... }
667func Span(dst []float64, l, u float64) []float64 {
668	n := len(dst)
669	if n < 2 {
670		panic(shortSpan)
671	}
672
673	// Special cases for Inf and NaN.
674	switch {
675	case math.IsNaN(l):
676		for i := range dst[:len(dst)-1] {
677			dst[i] = math.NaN()
678		}
679		dst[len(dst)-1] = u
680		return dst
681	case math.IsNaN(u):
682		for i := range dst[1:] {
683			dst[i+1] = math.NaN()
684		}
685		dst[0] = l
686		return dst
687	case math.IsInf(l, 0) && math.IsInf(u, 0):
688		for i := range dst[:len(dst)/2] {
689			dst[i] = l
690			dst[len(dst)-i-1] = u
691		}
692		if len(dst)%2 == 1 {
693			if l != u {
694				dst[len(dst)/2] = 0
695			} else {
696				dst[len(dst)/2] = l
697			}
698		}
699		return dst
700	case math.IsInf(l, 0):
701		for i := range dst[:len(dst)-1] {
702			dst[i] = l
703		}
704		dst[len(dst)-1] = u
705		return dst
706	case math.IsInf(u, 0):
707		for i := range dst[1:] {
708			dst[i+1] = u
709		}
710		dst[0] = l
711		return dst
712	}
713
714	step := (u - l) / float64(n-1)
715	for i := range dst {
716		dst[i] = l + step*float64(i)
717	}
718	return dst
719}
720
721// Sub subtracts, element-wise, the elements of s from dst.
722// It panics if the argument lengths do not match.
723func Sub(dst, s []float64) {
724	if len(dst) != len(s) {
725		panic(badLength)
726	}
727	f64.AxpyUnitaryTo(dst, -1, s, dst)
728}
729
730// SubTo subtracts, element-wise, the elements of t from s and
731// stores the result in dst.
732// It panics if the argument lengths do not match.
733func SubTo(dst, s, t []float64) []float64 {
734	if len(s) != len(t) {
735		panic(badLength)
736	}
737	if len(dst) != len(s) {
738		panic(badDstLength)
739	}
740	f64.AxpyUnitaryTo(dst, -1, t, s)
741	return dst
742}
743
744// Sum returns the sum of the elements of the slice.
745func Sum(s []float64) float64 {
746	return f64.Sum(s)
747}
748
749// Within returns the first index i where s[i] <= v < s[i+1]. Within panics if:
750//  - len(s) < 2
751//  - s is not sorted
752func Within(s []float64, v float64) int {
753	if len(s) < 2 {
754		panic(shortSpan)
755	}
756	if !sort.Float64sAreSorted(s) {
757		panic("floats: input slice not sorted")
758	}
759	if v < s[0] || v >= s[len(s)-1] || math.IsNaN(v) {
760		return -1
761	}
762	for i, f := range s[1:] {
763		if v < f {
764			return i
765		}
766	}
767	return -1
768}
769
// SumCompensated returns the sum of the elements of s, computed more
// accurately than Sum at the cost of additional arithmetic per element.
func SumCompensated(s []float64) float64 {
	// Neumaier's variant of Kahan compensated summation; see
	// https://en.wikipedia.org/wiki/Kahan_summation_algorithm for details.
	var total, comp float64
	for _, x := range s {
		// The explicit conversion keeps a sufficiently smart compiler from
		// optimising away the compensation arithmetic.
		next := float64(total + x)
		if math.Abs(x) <= math.Abs(total) {
			// total dominates: recover the low-order bits of x lost in next.
			comp += (total - next) + x
		} else {
			// x dominates: recover the low-order bits of total instead.
			comp += (x - next) + total
		}
		total = next
	}
	return total + comp
}
790