1 // Written in the D programming language.
2 
3 /**
4 This module is a port of a growing fragment of the $(D_PARAM numeric)
5 header in Alexander Stepanov's $(LINK2 http://sgi.com/tech/stl,
6 Standard Template Library), with a few additions.
7 
8 Macros:
9 Copyright: Copyright Andrei Alexandrescu 2008 - 2009.
10 License:   $(HTTP www.boost.org/LICENSE_1_0.txt, Boost License 1.0).
11 Authors:   $(HTTP erdani.org, Andrei Alexandrescu),
12                    Don Clugston, Robert Jacques, Ilya Yaroshenko
13 Source:    $(PHOBOSSRC std/_numeric.d)
14 */
15 /*
16          Copyright Andrei Alexandrescu 2008 - 2009.
17 Distributed under the Boost Software License, Version 1.0.
18    (See accompanying file LICENSE_1_0.txt or copy at
19          http://www.boost.org/LICENSE_1_0.txt)
20 */
21 module std.numeric;
22 
23 import std.complex;
24 import std.math;
25 import std.range.primitives;
26 import std.traits;
27 import std.typecons;
28 
version(unittest)29 version (unittest)
30 {
31     import std.stdio;
32 }
/// Format flags for CustomFloat.
public enum CustomFloatFlags
{
    /// Adds a sign bit to allow for signed numbers.
    signed = 1 << 0,

    /**
     * Store values in normalized form by default. The actual precision of the
     * significand is extended by 1 bit by assuming an implicit leading bit of 1
     * instead of 0, i.e. $(D 1.nnnn) instead of $(D 0.nnnn).
     * True for all $(LINK2 https://en.wikipedia.org/wiki/IEEE_floating_point, IEEE754) types.
     */
    storeNormalized = 1 << 1,

    /**
     * Stores the significand in $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Denormalized_numbers,
     * IEEE754 denormalized) form when the exponent is 0. Required to express the value 0.
     */
    allowDenorm = 1 << 2,

    /**
     * Allows the storage of $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Positive_and_negative_infinity,
     * IEEE754 _infinity) values.
     */
    infinity = 1 << 3,

    /// Allows the storage of $(LINK2 https://en.wikipedia.org/wiki/NaN, IEEE754 Not a Number) values.
    nan = 1 << 4,

    /**
     * If set, select an exponent bias such that max_exp = 1,
     * i.e. so that the maximum value is >= 1.0 and < 2.0.
     * Ignored if the exponent bias is manually specified.
     */
    probability = 1 << 5,

    /// If set, unsigned custom floats are assumed to be negative.
    negativeUnsigned = 1 << 6,

    /**
     * If set, 0 is the only allowed $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Denormalized_numbers,
     * IEEE754 denormalized) number.
     * This flag's value includes allowDenorm and storeNormalized.
     */
    allowDenormZeroOnly = (1 << 7) | allowDenorm | storeNormalized,

    /// Include _all of the $(LINK2 https://en.wikipedia.org/wiki/IEEE_floating_point, IEEE754) options.
    ieee = signed | storeNormalized | allowDenorm | infinity | nan,

    /// Include none of the above options.
    none = 0
}
84 
private template CustomFloatParams(uint bits)
{
    // Every supported width uses the full IEEE flag set, except the 80-bit
    // layout, which drops storeNormalized (its leading significand bit is
    // stored explicitly rather than implied).
    enum CustomFloatFlags flags = (bits == 80)
        ? CustomFloatFlags.ieee ^ CustomFloatFlags.storeNormalized
        : CustomFloatFlags.ieee;

    // (precision, exponentWidth) for each supported total width.
    static if      (bits ==  8) alias CustomFloatParams = CustomFloatParams!( 4,  3, flags);
    else static if (bits == 16) alias CustomFloatParams = CustomFloatParams!(10,  5, flags);
    else static if (bits == 32) alias CustomFloatParams = CustomFloatParams!(23,  8, flags);
    else static if (bits == 64) alias CustomFloatParams = CustomFloatParams!(52, 11, flags);
    else static if (bits == 80) alias CustomFloatParams = CustomFloatParams!(64, 15, flags);
}
95 
private template CustomFloatParams(uint precision, uint exponentWidth, CustomFloatFlags flags)
{
    import std.meta : AliasSeq;

    // True when the `probability` flag requests a bias that puts max_exp at 1.
    enum bool isProbability = (flags & CustomFloatFlags.probability) != 0;
    // True when the top exponent value is reserved for infinity/NaN.
    enum bool reservesTop = (flags & (CustomFloatFlags.nan | CustomFloatFlags.infinity)) != 0;

    // Result: (precision, exponentWidth, flags, default exponent bias).
    // The bias starts at 2^^(exponentWidth - 1), or 2^^exponentWidth for
    // probability formats. It is reduced by 1 when the top exponent is reserved
    // and by 1 more for probability formats.
    alias CustomFloatParams = AliasSeq!(
        precision,
        exponentWidth,
        flags,
        (1 << (exponentWidth - !isProbability)) - reservesTop - isProbability
    );
}
108 
/**
 * Allows user code to define custom floating-point formats. These formats are
 * for storage only; all operations on them are performed by first implicitly
 * extracting them to $(D real). After the operation is completed, the result
 * can be stored in a custom floating-point value via assignment.
 *
 * `CustomFloat!bits` selects an IEEE-style layout by its total size (8, 16,
 * 32, 64 or 80 bits). `CustomFloat!(precision, exponentWidth, flags)` describes
 * the layout explicitly and derives a default exponent bias. The sign bit (if
 * `flags` includes `signed`), precision and exponent width must add up to a
 * non-zero multiple of 8 bits.
 */
template CustomFloat(uint bits)
if (bits == 80 || bits == 64 || bits == 32 || bits == 16 || bits == 8)
{
    alias CustomFloat = CustomFloat!(CustomFloatParams!bits);
}

/// ditto
template CustomFloat(uint precision, uint exponentWidth, CustomFloatFlags flags = CustomFloatFlags.ieee)
if (precision + exponentWidth > 0 &&
    (precision + exponentWidth + (flags & flags.signed)) % 8 == 0)
{
    alias CustomFloat = CustomFloat!(CustomFloatParams!(precision, exponentWidth, flags));
}
127 
///
@safe unittest
{
    import std.math : sin, cos;

    // Four ways to declare the same 16-bit floating-point type
    CustomFloat!16                                 x; // by total number of bits
    CustomFloat!(10, 5)                            y; // by precision and exponent width
    CustomFloat!(10, 5, CustomFloatFlags.ieee)     z; // ... plus format flags
    CustomFloat!(10, 5, CustomFloatFlags.ieee, 15) w; // ... plus an explicit exponent bias

    // Use the 16-bit floats mostly like built-in numbers
    w = x * y - 1;

    // Function calls need an explicit conversion first
    z = sin(+x) + cos(+y);                       // unary plus converts to real
    z = sin(x.get!float) + cos(y.get!float);     // so does get!T
    z = sin(cast(float) x) + cos(cast(float) y); // as does cast(T)

    // An 8-bit custom float for storing probabilities
    alias Probability = CustomFloat!(4, 4,
        CustomFloatFlags.ieee ^ CustomFloatFlags.probability ^ CustomFloatFlags.signed);
    auto p = Probability(0.5);
}
151 
152 /// ditto
struct CustomFloat(uint             precision,      // fraction bits (23 for float)
                   uint             exponentWidth,  // exponent bits (8 for float)
                   CustomFloatFlags flags,
                   uint             bias)           // exponent bias (127 for float)
if (((flags & flags.signed)  + precision + exponentWidth) % 8 == 0 &&
    precision + exponentWidth > 0)
{
    import std.bitmanip : bitfields;
    import std.meta : staticIndexOf;
private:
    // Unsigned storage type for a `bits`-wide bitfield: size_t when it fits,
    // otherwise ulong (so fields wider than 32 bits work on 32-bit targets).
    template uType(uint bits)
    {
        static if (bits <= size_t.sizeof*8)  alias uType = size_t;
        else                                alias uType = ulong ;
    }

    // Signed counterpart of uType; holds unbiased (possibly negative) exponents.
    template sType(uint bits)
    {
        static if (bits <= ptrdiff_t.sizeof*8-1) alias sType = ptrdiff_t;
        else                                    alias sType = long;
    }

    alias T_sig = uType!precision;            // significand field type
    alias T_exp = uType!exponentWidth;        // biased exponent field type
    alias T_signed_exp = sType!exponentWidth; // unbiased exponent type

    alias Flags = CustomFloatFlags;

    // Overlays a built-in floating-point value of type F with the CustomFloat
    // layout of the same width, exposing its raw sign/exponent/significand.
    union ToBinary(F)
        if (is(typeof(CustomFloatParams!(F.sizeof*8))) || is(F == real))
    {
        F set;

        // `real` may occupy more than 80 bits of storage (padding); only the
        // first 80 bits are mapped onto the custom layout.
        import std.algorithm.comparison : min;
        CustomFloat!(CustomFloatParams!(min(F.sizeof*8, 80))) get;

        // Reinterpret `value` as its binary CustomFloat representation.
        static typeof(get) opCall(F value)
        {
            ToBinary r;
            r.set = value;
            return r.get;
        }
        alias get this;
    }

    // Shift `sig` right by `shift` bits, rounding to nearest with ties to even.
    // Only called with shift > 0.
    void roundedShift(T,U)(ref T sig, U shift)
    {
        if (sig << (T.sizeof*8 - shift) == cast(T) 1uL << (T.sizeof*8 - 1))
        {
            // The discarded bits are exactly one half: round to even.
            sig >>= shift;
            sig  += sig & 1;
        }
        else
        {
            // Keep one extra bit, add it as the rounding increment, then drop it.
            sig >>= shift - 1;
            sig  += sig & 1;
            sig >>= 1;
        }
    }

    // Decode this value into a format-independent form:
    //  - exp: the unbiased exponent (U.max for inf/NaN, U.min for zero);
    //  - sig: the fraction bits left-aligned in T with the leading 1 removed.
    //    For inf/NaN, sig holds the stored payload bits instead; when the format
    //    stores normalized values, an explicit leading 1 is added to them.
    void toNormalized(T,U)(ref T sig, ref U exp)
    {
        sig = significand;
        auto shift = (T.sizeof*8) - precision;
        exp = exponent;
        static if (flags&(Flags.infinity|Flags.nan))
        {
            // Handle inf or nan
            if (exp == exponent_max)
            {
                exp = exp.max;
                sig <<= shift;
                static if (flags&Flags.storeNormalized)
                {
                    // Save inf/nan in denormalized format
                    sig >>= 1;
                    sig  += cast(T) 1uL << (T.sizeof*8 - 1);
                }
                return;
            }
        }
        if ((~flags&Flags.storeNormalized) ||
            // Convert denormalized form to normalized form
            ((flags&Flags.allowDenorm) && exp == 0))
        {
            if (sig > 0)
            {
                import core.bitop : bsr;
                // Shift so the leading 1 falls off the top; adjust exp to match.
                auto shift2 = precision - bsr(sig);
                exp  -= shift2-1;
                shift += shift2;
            }
            else                                // value = 0.0
            {
                exp = exp.min;
                return;
            }
        }
        sig <<= shift;
        exp -= bias;
    }

    // Inverse of toNormalized. Converts (sig, exp) in place into this format's
    // stored significand and biased exponent, rounding to nearest-even. It
    // handles inf/NaN, zero, denormals and overflow. It does not write the
    // fields of `this`; the caller stores the results.
    void fromNormalized(T,U)(ref T sig, ref U exp)
    {
        auto shift = (T.sizeof*8) - precision;
        if (exp == exp.max)
        {
            // infinity or nan
            exp = exponent_max;
            static if (flags & Flags.storeNormalized)
                sig <<= 1;

            // convert back to normalized form
            static if (~flags & Flags.infinity)
                // No infinity support?
                assert(sig != 0, "Infinity floating point value assigned to a "
                        ~ typeof(this).stringof ~ " (no infinity support).");

            static if (~flags & Flags.nan)  // No NaN support?
                assert(sig == 0, "NaN floating point value assigned to a " ~
                        typeof(this).stringof ~ " (no nan support).");
            sig >>= shift;
            return;
        }
        if (exp == exp.min)     // 0.0
        {
             exp = 0;
             sig = 0;
             return;
        }

        exp += bias;
        if (exp <= 0)
        {
            static if ((flags&Flags.allowDenorm) ||
                       // Convert from normalized form to denormalized
                       (~flags&Flags.storeNormalized))
            {
                shift += -exp;
                roundedShift(sig,1);
                // Add the leading 1
                sig   += cast(T) 1uL << (T.sizeof*8 - 1);
                exp    = 0;
            }
            else
                assert((flags&Flags.storeNormalized) && exp == 0,
                    "Underflow occured assigning to a " ~
                    typeof(this).stringof ~ " (no denormal support).");
        }
        else
        {
            static if (~flags&Flags.storeNormalized)
            {
                // Convert from normalized form to denormalized:
                // add the leading 1 explicitly.
                roundedShift(sig,1);
                sig  += cast(T) 1uL << (T.sizeof*8 - 1);
            }
        }

        if (shift > 0)
            roundedShift(sig,shift);
        if (sig > significand_max)
        {
            // handle significand overflow (should only be 1 bit)
            static if (~flags&Flags.storeNormalized)
            {
                sig >>= 1;
            }
            else
                sig &= significand_max;
            exp++;
        }
        static if ((flags&Flags.allowDenormZeroOnly)==Flags.allowDenormZeroOnly)
        {
            // disallow non-zero denormals
            if (exp == 0)
            {
                sig <<= 1;
                if (sig > significand_max && (sig&significand_max) > 0)
                    // Check and round to even
                    exp++;
                sig = 0;
            }
        }

        if (exp >= exponent_max)
        {
            static if (flags&(Flags.infinity|Flags.nan))
            {
                sig         = 0;
                exp         = exponent_max;
                static if (~flags&(Flags.infinity))
                    assert(0, "Overflow occured assigning to a " ~
                        typeof(this).stringof ~ " (no infinity support).");
            }
            else
                assert(exp == exponent_max, "Overflow occured assigning to a "
                    ~ typeof(this).stringof ~ " (no infinity support).");
        }
    }

public:
    static if (precision == 64) // CustomFloat!80 support hack
    {
        // A 64-bit significand is a plain ulong field; bitfields only packs
        // the exponent and the (optional) sign bit.
        ulong significand;
        enum ulong significand_max = ulong.max;
        mixin(bitfields!(
            T_exp , "exponent", exponentWidth,
            bool  , "sign"    , flags & flags.signed ));
    }
    else
    {
        mixin(bitfields!(
            T_sig, "significand", precision,
            T_exp, "exponent"   , exponentWidth,
            bool , "sign"       , flags & flags.signed ));
    }

    /// Returns: infinity value
    static if (flags & Flags.infinity)
        static @property CustomFloat infinity()
        {
            CustomFloat value;
            static if (flags & Flags.signed)
                value.sign      = 0;
            value.significand   = 0;
            value.exponent      = exponent_max;
            return value;
        }

    /// Returns: NaN value
    static if (flags & Flags.nan)
        static @property CustomFloat nan()
        {
            CustomFloat value;
            static if (flags & Flags.signed)
                value.sign      = 0;
            value.significand   = cast(typeof(significand_max)) 1L << (precision-1);
            value.exponent      = exponent_max;
            return value;
        }

    /// Returns: number of decimal digits of precision, i.e.
    /// floor((mant_dig - 1) * log10(2)) (6 for the 32-bit and 15 for the
    /// 64-bit layout, as for `float` and `double`).
    static @property size_t dig()
    {
        enum uint shiftcnt = mant_dig > 0 ? mant_dig - 1 : 0;
        // Build 2^^shiftcnt as a real: an integer `1uL << 64` would overflow
        // for 65-bit significands.
        return cast(size_t) log10(ldexp(1.0L, cast(int) shiftcnt));
    }

    /// Returns: smallest increment to the value 1, i.e. 2<sup>-(mant_dig-1)</sup>
    static @property CustomFloat epsilon()
    {
        CustomFloat value;
        static if (flags & Flags.signed)
            value.sign   = 0;
        // Convert to the signed type before negating: negating the unsigned
        // value directly would wrap around to a huge positive exponent.
        T_signed_exp exp = -cast(T_signed_exp) (mant_dig - 1);
        T_sig        sig = 0; // normalized form of 1.0 (leading 1 is implicit)

        value.fromNormalized(sig,exp);
        if (exp == 0 && sig == 0) // underflowed to zero
        {
            static if ((flags&Flags.allowDenorm) ||
                       (~flags&Flags.storeNormalized))
                sig = 1;
            else
                sig = cast(T_sig) 1uL << (precision - 1);
        }
        value.exponent     = cast(value.T_exp) exp;
        value.significand  = cast(value.T_sig) sig;
        return value;
    }

    /// the number of bits in mantissa
    enum mant_dig = precision + ((flags&Flags.storeNormalized) != 0);

    /// Returns: maximum int value such that 10<sup>max_10_exp</sup> is representable
    static @property int max_10_exp(){ return cast(int) log10( +max ); }

    /// maximum int value such that 2<sup>max_exp-1</sup> is representable
    enum max_exp = exponent_max-bias+((~flags&(Flags.infinity|flags.nan))!=0);

    /// Returns: minimum int value such that 10<sup>min_10_exp</sup> is representable
    static @property int min_10_exp(){ return cast(int) log10( +min_normal ); }

    // `bias` is unsigned: cast to the signed exponent type before negating.
    /// minimum int value such that 2<sup>min_exp-1</sup> is representable as a normalized value
    enum min_exp = -cast(T_signed_exp) bias + 1 + ((flags&Flags.allowDenorm)!=0);

    /// Returns: largest representable value that's not infinity
    static @property CustomFloat max()
    {
        CustomFloat value;
        static if (flags & Flags.signed)
            value.sign    = 0;
        // The top exponent is reserved when infinity or NaN are supported.
        value.exponent    = exponent_max - ((flags&(flags.infinity|flags.nan)) != 0);
        value.significand = significand_max;
        return value;
    }

    /// Returns: smallest representable normalized value that's not 0
    static @property CustomFloat min_normal() {
        CustomFloat value;
        static if (flags & Flags.signed)
            value.sign    = 0;
        value.exponent    = 1;
        static if (flags&Flags.storeNormalized)
            value.significand = 0;
        else
            value.significand = cast(T_sig) 1uL << (precision - 1);
        return value;
    }

    /// Returns: real part
    @property CustomFloat re() { return this; }

    /// Returns: imaginary part
    static @property CustomFloat im() { return CustomFloat(0.0f); }

    /// Initialize from any $(D real) compatible type.
    this(F)(F input) if (__traits(compiles, cast(real) input ))
    {
        this = input;
    }

    /// Self assignment
    void opAssign(F:CustomFloat)(F input)
    {
        static if (flags & Flags.signed)
            sign    = input.sign;
        exponent    = input.exponent;
        significand = input.significand;
    }

    /// Assigns from any $(D real) compatible type.
    void opAssign(F)(F input)
        if (__traits(compiles, cast(real) input))
    {
        import std.conv : text;

        // float/double/real are decoded directly; anything else goes via real.
        static if (staticIndexOf!(Unqual!F, float, double, real) >= 0)
            auto value = ToBinary!(Unqual!F)(input);
        else
            auto value = ToBinary!(real    )(input);

        // Assign the sign bit
        static if (~flags & Flags.signed)
            assert((!value.sign) ^ ((flags&flags.negativeUnsigned) > 0),
                "Incorrectly signed floating point value assigned to a " ~
                typeof(this).stringof ~ " (no sign support).");
        else
            sign = value.sign;

        CommonType!(T_signed_exp ,value.T_signed_exp) exp = value.exponent;
        CommonType!(T_sig,        value.T_sig       ) sig = value.significand;

        // Decode from the source layout, then re-encode into this layout.
        value.toNormalized(sig,exp);
        fromNormalized(sig,exp);

        assert(exp <= exponent_max,    text(typeof(this).stringof ~
            " exponent too large: "   ,exp," > ",exponent_max,   "\t",input,"\t",sig));
        assert(sig <= significand_max, text(typeof(this).stringof ~
            " significand too large: ",sig," > ",significand_max,
            "\t",input,"\t",exp," ",exponent_max));
        exponent    = cast(T_exp) exp;
        significand = cast(T_sig) sig;
    }

    /// Fetches the stored value either as a $(D float), $(D double) or $(D real).
    @property F get(F)()
        if (staticIndexOf!(Unqual!F, float, double, real) >= 0)
    {
        import std.conv : text;

        ToBinary!F result;

        static if (flags&Flags.signed)
            result.sign = sign;
        else
            result.sign = (flags&flags.negativeUnsigned) > 0;

        CommonType!(T_signed_exp ,result.get.T_signed_exp ) exp = exponent; // Assign the exponent and fraction
        CommonType!(T_sig,        result.get.T_sig        ) sig = significand;

        // Decode from this layout, then re-encode into F's layout.
        toNormalized(sig,exp);
        result.fromNormalized(sig,exp);
        assert(exp <= result.exponent_max,    text("get exponent too large: "   ,exp," > ",result.exponent_max) );
        assert(sig <= result.significand_max, text("get significand too large: ",sig," > ",result.significand_max) );
        result.exponent     = cast(result.get.T_exp) exp;
        result.significand  = cast(result.get.T_sig) sig;
        return result.set;
    }

    ///ditto
    T opCast(T)() if (__traits(compiles, get!T )) { return get!T; }

    /// Convert the CustomFloat to a real and perform the relevant operator on the result.
    /// For `++`/`--`, this value is updated and the new value is returned as a real.
    real opUnary(string op)()
        if (__traits(compiles, mixin(op~`(get!real)`)) || op=="++" || op=="--")
    {
        static if (op=="++" || op=="--")
        {
            auto result = get!real;
            this = mixin(op~`result`);
            return result;
        }
        else
            return mixin(op~`get!real`);
    }

    /// ditto
    real opBinary(string op,T)(T b)
        if (__traits(compiles, mixin(`get!real`~op~`b`)))
    {
        return mixin(`get!real`~op~`b`);
    }

    /// ditto
    // NOTE(review): `b` is not a parameter here (the operand is `a`), so the
    // second clause never compiles and is therefore always true.
    real opBinaryRight(string op,T)(T a)
        if ( __traits(compiles, mixin(`a`~op~`get!real`)) &&
            !__traits(compiles, mixin(`get!real`~op~`b`)))
    {
        return mixin(`a`~op~`get!real`);
    }

    /// ditto
    /// Returns -1, 0 or 1; returns 0 when either side is NaN.
    int opCmp(T)(auto ref T b)
        if (__traits(compiles, cast(real) b))
    {
        auto x = get!real;
        auto y = cast(real) b;
        return  (x >= y)-(x <= y);
    }

    /// ditto
    void opOpAssign(string op, T)(auto ref T b)
        if (__traits(compiles, mixin(`get!real`~op~`cast(real) b`)))
    {
        return mixin(`this = this `~op~` cast(real) b`);
    }

    /// ditto
    template toString()
    {
        import std.format : FormatSpec, formatValue;
        // Needs to be a template because of DMD @@BUG@@ 13737.
        // Formats the value as a real using the given format spec.
        void toString()(scope void delegate(const(char)[]) sink, FormatSpec!char fmt)
        {
            sink.formatValue(get!real, fmt);
        }
    }
}
616 
@safe unittest
{
    import std.meta;

    alias FPTypes = AliasSeq!(
        CustomFloat!(5, 10),
        CustomFloat!(5, 11, CustomFloatFlags.ieee ^ CustomFloatFlags.signed),
        CustomFloat!(1, 15, CustomFloatFlags.ieee ^ CustomFloatFlags.signed),
        CustomFloat!(4, 3, CustomFloatFlags.ieee | CustomFloatFlags.probability ^ CustomFloatFlags.signed)
    );

    // Every expected value below is an exact power of two, so exact
    // comparison against both the float and double readings is valid.
    static void expectValue(float asFloat, double asDouble, double expected) @safe
    {
        assert(asFloat == expected);
        assert(asDouble == expected);
    }

    foreach (F; FPTypes)
    {
        auto v = F(0.125);
        expectValue(v.get!float, v.get!double, 0.125);

        v -= 0.0625;
        expectValue(v.get!float, v.get!double, 0.0625);

        v *= 2;
        expectValue(v.get!float, v.get!double, 0.125);

        v /= 4;
        expectValue(v.get!float, v.get!double, 0.03125);

        v = 0.5;
        v ^^= 4;
        expectValue(v.get!float, v.get!double, 1 / 16.0);
    }
}
652 
@system unittest
{
    // @system due to to!string(CustomFloat)
    import std.conv;

    alias F = CustomFloat!(5, 10);
    F y = F(0.125);
    immutable s = y.to!string;
    assert(s == "0.125");
}
660 
/**
Defines the fastest type to use when storing temporaries of a
calculation intended to ultimately yield a result of type $(D F)
(where $(D F) must be one of $(D float), $(D double), or $(D
real)). When doing a multi-step computation, you may want to store
intermediate results as $(D FPTemporary!F).

The necessity of $(D FPTemporary) stems from the optimized
floating-point operations and registers present in virtually all
processors. When adding numbers in the example below, the addition may
in fact be done in $(D real) precision internally. In that case,
storing the intermediate $(D result) in $(D double) format is not only
less precise, it is also (surprisingly) slower, because a conversion
from $(D real) to $(D double) is performed every pass through the
loop. This being a lose-lose situation, $(D FPTemporary!F) has been
defined as the $(I fastest) type to use for calculations at precision
$(D F). There is no need to define a type for the $(I most accurate)
calculations, as that is always $(D real).

Currently $(D FPTemporary!F) is $(D real) when compiling for 32-bit x86
(`version (X86)`), and $(D F) with its type qualifiers removed otherwise.

Finally, there is no guarantee that using $(D FPTemporary!F) will
always be fastest, as the speed of floating-point calculations depends
on very many factors.
 */
template FPTemporary(F)
if (isFloatingPoint!F)
{
    version (X86)
        enum bool useReal = true;
    else
        enum bool useReal = false;

    static if (useReal)
        alias FPTemporary = real;
    else
        alias FPTemporary = Unqual!F;
}
692 
///
@safe unittest
{
    import std.math : approxEqual;

    // Average the numbers in an array, keeping the running sum in an
    // FPTemporary!double
    double avg(in double[] a)
    {
        if (!a.length)
            return 0;
        FPTemporary!double result = 0;
        foreach (e; a)
            result += e;
        return result / a.length;
    }

    immutable mean = avg([1.0, 2.0, 3.0]);
    assert(approxEqual(mean, 2));
}
710 
/**
Implements the $(HTTP tinyurl.com/2zb9yr, secant method) for finding a
root of the function $(D fun) starting from points $(D [xn_1, xn])
(ideally close to the root). $(D Num) may be $(D float), $(D double),
or $(D real).

Params:
    fun  = function whose root is sought; anything accepted by
           $(D std.functional.unaryFun), including string lambdas
    xn_1 = first starting point
    xn   = second starting point

Returns: the last iterate, taken once the secant step satisfies
$(D approxEqual(step, 0)) or is no longer finite.
*/
template secantMethod(alias fun)
{
    import std.functional : unaryFun;
    Num secantMethod(Num)(Num xn_1, Num xn)
    {
        alias g = unaryFun!fun;

        auto fCur = g(xn_1);
        auto step = xn_1 - xn;
        typeof(fCur) fPrev;

        Num x = xn_1;
        for (;;)
        {
            // Stop once the step is negligible, or became inf/NaN
            // (e.g. a zero denominator below).
            if (approxEqual(step, 0) || !isFinite(step))
                break;
            x -= step;
            fPrev = fCur;
            fCur = g(x);
            step *= -fCur / (fCur - fPrev);
        }
        return x;
    }
}
737 
///
@safe unittest
{
    import std.math : approxEqual, cos;

    // Root of cos(x) = x^3 between 0 and 1
    float f(float x)
    {
        return cos(x) - x * x * x;
    }

    immutable root = secantMethod!(f)(0f, 1f);
    assert(approxEqual(root, 0.865474));
}
750 
@system unittest
{
    // @system because of __gshared stderr
    scope(failure) stderr.writeln("Failure testing secantMethod");

    float g(float t)
    {
        return cos(t) - t * t * t;
    }
    enum expected = 0.865474;

    // Pass the nested function directly...
    immutable viaFunction = secantMethod!(g)(0f, 1f);
    assert(approxEqual(viaFunction, expected));

    // ...and through a delegate variable.
    auto dg = &g;
    immutable viaDelegate = secantMethod!(dg)(0f, 1f);
    assert(approxEqual(viaDelegate, expected));
}
765 
766 
/**
 * Return true if a and b have opposite sign, judged by their sign bits
 * (so -0.0 and +0.0 count as opposite).
 */
private bool oppositeSigns(T1, T2)(T1 a, T2 b)
{
    immutable aSign = signbit(a);
    immutable bSign = signbit(b);
    return aSign != bSign;
}
774 
775 public:
776 
/**  Find a real root of a real function f(x) via bracketing.
 *
 * Given a function `f` and a range `[a .. b]` such that `f(a)`
 * and `f(b)` have opposite signs or at least one of them equals ±0,
 * returns the value of `x` in
 * the range which is closest to a root of `f(x)`.  If `f(x)`
 * has more than one root in the range, one will be chosen
 * arbitrarily.  If `f(x)` returns NaN, NaN will be returned;
 * otherwise, this algorithm is guaranteed to succeed.
 *
 * Uses an algorithm based on TOMS748, which uses inverse cubic
 * interpolation whenever possible, otherwise reverting to parabolic
 * or secant interpolation. Compared to TOMS748, this implementation
 * improves worst-case performance by a factor of more than 100, and
 * typical performance by a factor of 2. For 80-bit reals, most
 * problems require 8 to 15 calls to `f(x)` to achieve full machine
 * precision. The worst-case performance (pathological cases) is
 * approximately twice the number of bits.
 *
 * Params:
 * f = Function to be analyzed
 *
 * a = One end of the initial range known to contain the root
 *
 * b = The other end of the initial range
 *
 * tolerance = Early-termination predicate. It receives the current lower and
 *             upper bounds on the root and returns $(D true) when they are
 *             acceptable. Omitting it, or always returning $(D false),
 *             gives full machine precision.
 *
 * Returns: `a` or `b` if `f` is exactly zero there (`a` is tested first).
 * Otherwise, the endpoint of the final bracket whose function value has the
 * smaller magnitude. The first endpoint wins ties and cases where either
 * value is NaN.
 *
 * References: "On Enclosing Simple Roots of Nonlinear Equations",
 * G. Alefeld, F.A. Potra, Yixun Shi, Mathematics of Computation 61,
 * pp733-744 (1993).  Fortran code available from $(HTTP
 * www.netlib.org,www.netlib.org) as algorithm TOMS748.
 *
 */
T findRoot(T, DF, DT)(scope DF f, in T a, in T b,
    scope DT tolerance) //= (T a, T b) => false)
if (
    isFloatingPoint!T &&
    is(typeof(tolerance(T.init, T.init)) : bool) &&
    is(typeof(f(T.init)) == R, R) && isFloatingPoint!R
    )
{
    // Exact hits at either end need no search.
    immutable fa = f(a);
    if (fa == 0)
        return a;
    immutable fb = f(b);
    if (fb == 0)
        return b;

    // bracket = (lo, hi, f(lo), f(hi))
    immutable bracket = findRoot(f, a, b, fa, fb, tolerance);

    // Prefer the first endpoint unless its residual is strictly larger;
    // a NaN residual makes the comparison false, keeping the first one.
    if (fabs(bracket[2]) > fabs(bracket[3]))
        return bracket[1];
    return bracket[0];
}
820 
///ditto
T findRoot(T, DF)(scope DF f, in T a, in T b)
{
    // A tolerance predicate that always returns false, so the bracketing
    // search is never terminated early.
    auto neverStop = (T lo, T hi) => false;
    return findRoot(f, a, b, neverStop);
}
826 
/** Find root of a real function f(x) by bracketing, allowing the
 * termination condition to be specified.
 *
 * Params:
 *
 * f = Function to be analyzed
 *
 * ax = Left bound of initial range of `f` known to contain the
 * root. (`ax` and `bx` may be given in either order.)
 *
 * bx = Right bound of initial range of `f` known to contain the
 * root.
 *
 * fax = Value of $(D f(ax)).
 *
 * fbx = Value of $(D f(bx)). $(D fax) and $(D fbx) should have opposite signs.
 * ($(D f(ax)) and $(D f(bx)) are commonly known in advance.)
 *
 *
 * tolerance = Defines an early termination condition. Receives the
 *             current lower and upper bounds on the root. The
 *             delegate must return $(D true) when these bounds are
 *             acceptable. If this function always returns $(D false),
 *             full machine precision will be achieved.
 *
 * Returns:
 *
 * A tuple of four elements. The first two elements are the final
 * bracket (in `x`) around the root, while the second pair of elements
 * are the corresponding function values at those points. If an exact
 * root was found, both of the first two elements will contain the
 * root, and the second pair of elements will be 0. If `f` returned NaN,
 * the first element is the point at which that happened and the third
 * element is NaN.
 */
Tuple!(T, T, R, R) findRoot(T, R, DF, DT)(scope DF f, in T ax, in T bx, in R fax, in R fbx,
    scope DT tolerance) // = (T a, T b) => false)
if (
    isFloatingPoint!T &&
    is(typeof(tolerance(T.init, T.init)) : bool) &&
    is(typeof(f(T.init)) == R) && isFloatingPoint!R
    )
in
{
    assert(!ax.isNaN() && !bx.isNaN(), "Limits must not be NaN");
    assert(signbit(fax) != signbit(fbx), "Parameters must bracket the root.");
}
body
{
    // Author: Don Clugston. This code is (heavily) modified from TOMS748
    // (www.netlib.org).  The changes to improve the worst-case performance are
    // entirely original.

    T a, b, d;  // [a .. b] is our current bracket. d is the third best guess.
    R fa, fb, fd; // Values of f at a, b, d.
    bool done = false; // Has a root been found?

    // Allow ax and bx to be provided in reverse order
    if (ax <= bx)
    {
        a = ax; fa = fax;
        b = bx; fb = fbx;
    }
    else
    {
        a = bx; fa = fbx;
        b = ax; fb = fax;
    }

    // Test the function at point c; update brackets accordingly.
    // An exact zero or a NaN is recorded in `a`/`fa` and sets `done`.
    void bracket(T c)
    {
        R fc = f(c);
        if (fc == 0 || fc.isNaN()) // Exact solution, or NaN
        {
            a = c;
            fa = fc;
            d = c;
            fd = fc;
            if (fc == 0)
            {
                // Collapse the bracket onto the exact root so that both
                // returned bounds are the root, as documented.
                b = c;
                fb = fc;
            }
            done = true;
            return;
        }

        // Determine new enclosing interval
        if (signbit(fa) != signbit(fc))
        {
            d = b;
            fd = fb;
            b = c;
            fb = fc;
        }
        else
        {
            d = a;
            fd = fa;
            a = c;
            fa = fc;
        }
    }

    /* Perform a secant interpolation. If the result would lie on a or b, or if
     a and b differ so wildly in magnitude that the result would be meaningless,
     perform a bisection instead.
    */
    static T secant_interpolate(T a, T b, R fa, R fb)
    {
        if (( ((a - b) == a) && b != 0) || (a != 0 && ((b - a) == b)))
        {
            // Catastrophic cancellation
            if (a == 0)
                a = copysign(T(0), b);
            else if (b == 0)
                b = copysign(T(0), a);
            else if (signbit(a) != signbit(b))
                return 0;
            T c = ieeeMean(a, b);
            return c;
        }
        // avoid overflow
        if (b - a > T.max)
            return b / 2 + a / 2;
        if (fb - fa > R.max)
            return a - (b - a) / 2;
        T c = a - (fa / (fb - fa)) * (b - a);
        if (c == a || c == b)
            return (a + b) / 2;
        return c;
    }

    /* Uses 'numsteps' newton steps to approximate the zero in [a .. b] of the
       quadratic polynomial interpolating f(x) at a, b, and d.
       Returns:
         The approximate zero in [a .. b] of the quadratic polynomial.
    */
    T newtonQuadratic(int numsteps)
    {
        // Find the coefficients of the quadratic polynomial.
        immutable T a0 = fa;
        immutable T a1 = (fb - fa)/(b - a);
        immutable T a2 = ((fd - fb)/(d - b) - a1)/(d - a);

        // Determine the starting point of newton steps.
        T c = oppositeSigns(a2, fa) ? a  : b;

        // start the safeguarded newton steps.
        foreach (int i; 0 .. numsteps)
        {
            immutable T pc = a0 + (a1 + a2 * (c - b))*(c - a);
            immutable T pdc = a1 + a2*((2 * c) - (a + b));
            if (pdc == 0)
                return a - a0 / a1;
            else
                c = c - pc / pdc;
        }
        return c;
    }

    // On the first iteration we take a secant step:
    if (fa == 0 || fa.isNaN())
    {
        done = true;
        b = a;
        fb = fa;
    }
    else if (fb == 0 || fb.isNaN())
    {
        done = true;
        a = b;
        fa = fb;
    }
    else
    {
        bracket(secant_interpolate(a, b, fa, fb));
    }

    // Starting with the second iteration, higher-order interpolation can
    // be used.
    int itnum = 1;   // Iteration number
    int baditer = 1; // Num bisections to take if an iteration is bad.
    T c, e;  // e is our fourth best guess
    R fe;

whileloop:
    while (!done && (b != nextUp(a)) && !tolerance(a, b))
    {
        T a0 = a, b0 = b; // record the brackets

        // Do two higher-order (cubic or parabolic) interpolation steps.
        foreach (int QQ; 0 .. 2)
        {
            // Cubic inverse interpolation requires that
            // all four function values fa, fb, fd, and fe are distinct;
            // otherwise use quadratic interpolation.
            bool distinct = (fa != fb) && (fa != fd) && (fa != fe)
                         && (fb != fd) && (fb != fe) && (fd != fe);
            // The first time, cubic interpolation is impossible.
            if (itnum<2) distinct = false;
            bool ok = distinct;
            if (distinct)
            {
                // Cubic inverse interpolation of f(x) at a, b, d, and e
                immutable q11 = (d - e) * fd / (fe - fd);
                immutable q21 = (b - d) * fb / (fd - fb);
                immutable q31 = (a - b) * fa / (fb - fa);
                immutable d21 = (b - d) * fd / (fd - fb);
                immutable d31 = (a - b) * fb / (fb - fa);

                immutable q22 = (d21 - q11) * fb / (fe - fb);
                immutable q32 = (d31 - q21) * fa / (fd - fa);
                immutable d32 = (d31 - q21) * fd / (fd - fa);
                immutable q33 = (d32 - q22) * fa / (fe - fa);
                c = a + (q31 + q32 + q33);
                if (c.isNaN() || (c <= a) || (c >= b))
                {
                    // DAC: If the interpolation predicts a or b, it's
                    // probable that it's the actual root. Only allow this if
                    // we're already close to the root.
                    if (c == a && a - b != a)
                    {
                        c = nextUp(a);
                    }
                    else if (c == b && a - b != -b)
                    {
                        c = nextDown(b);
                    }
                    else
                    {
                        ok = false;
                    }
                }
            }
            if (!ok)
            {
                // DAC: Alefeld doesn't explain why the number of newton steps
                // should vary.
                c = newtonQuadratic(distinct ? 3 : 2);
                if (c.isNaN() || (c <= a) || (c >= b))
                {
                    // Failure, try a secant step:
                    c = secant_interpolate(a, b, fa, fb);
                }
            }
            ++itnum;
            e = d;
            fe = fd;
            bracket(c);
            if (done || ( b == nextUp(a)) || tolerance(a, b))
                break whileloop;
            if (itnum == 2)
                continue whileloop;
        }

        // Now we take a double-length secant step:
        T u;
        R fu;
        if (fabs(fa) < fabs(fb))
        {
            u = a;
            fu = fa;
        }
        else
        {
            u = b;
            fu = fb;
        }
        c = u - 2 * (fu / (fb - fa)) * (b - a);

        // DAC: If the secant predicts a value equal to an endpoint, it's
        // probably false.
        if (c == a || c == b || c.isNaN() || fabs(c - u) > (b - a) / 2)
        {
            if ((a-b) == a || (b-a) == b)
            {
                if ((a>0 && b<0) || (a<0 && b>0))
                    c = 0;
                else
                {
                    if (a == 0)
                        c = ieeeMean(copysign(T(0), b), b);
                    else if (b == 0)
                        c = ieeeMean(copysign(T(0), a), a);
                    else
                        c = ieeeMean(a, b);
                }
            }
            else
            {
                c = a + (b - a) / 2;
            }
        }
        e = d;
        fe = fd;
        bracket(c);
        if (done || (b == nextUp(a)) || tolerance(a, b))
            break;

        // IMPROVE THE WORST-CASE PERFORMANCE
        // We must ensure that the bounds reduce by a factor of 2
        // in binary space! every iteration. If we haven't achieved this
        // yet, or if we don't yet know what the exponent is,
        // perform a binary chop.

        if ((a == 0 || b == 0 ||
            (fabs(a) >= T(0.5) * fabs(b) && fabs(b) >= T(0.5) * fabs(a)))
            &&  (b - a) < T(0.25) * (b0 - a0))
        {
            baditer = 1;
            continue;
        }

        // DAC: If this happens on consecutive iterations, we probably have a
        // pathological function. Perform a number of bisections equal to the
        // total number of consecutive bad iterations.

        if ((b - a) < T(0.25) * (b0 - a0))
            baditer = 1;
        foreach (int QQ; 0 .. baditer)
        {
            e = d;
            fe = fd;

            T w;
            if ((a>0 && b<0) || (a<0 && b>0))
                w = 0;
            else
            {
                T usea = a;
                T useb = b;
                if (a == 0)
                    usea = copysign(T(0), b);
                else if (b == 0)
                    useb = copysign(T(0), a);
                w = ieeeMean(usea, useb);
            }
            bracket(w);
            // An exact root (or NaN) ends the search. Further bisections
            // would overwrite `a`/`fa` and could leave a bracket whose
            // endpoints have the same sign.
            if (done)
                break;
        }
        ++baditer;
    }
    return Tuple!(T, T, R, R)(a, b, fa, fb);
}
1165 
1166 ///ditto
Tuple!(T, T, R, R) findRoot(T, R, DF)(scope DF f, in T ax, in T bx, in R fax, in R fbx)
{
    // Same as the overload taking `tolerance`, with a predicate that never
    // accepts the bracket early, so full machine precision is reached.
    return findRoot(f, ax, bx, fax, fbx, (T lo, T hi) => false);
}
1171 
1172 ///ditto
findRoot(T,R)1173 T findRoot(T, R)(scope R delegate(T) f, in T a, in T b,
1174     scope bool delegate(T lo, T hi) tolerance = (T a, T b) => false)
1175 {
1176     return findRoot!(T, R delegate(T), bool delegate(T lo, T hi))(f, a, b, tolerance);
1177 }
1178 
@safe nothrow unittest
{
    int numProblems = 0;
    int numCalls;

    // Runs findRoot on [x1, x2] and checks that the returned bracket is
    // either an exact root or still straddles a sign change.
    void testFindRoot(real delegate(real) @nogc @safe nothrow pure f , real x1, real x2) @nogc @safe nothrow pure
    {
        //numCalls=0;
        //++numProblems;
        assert(!x1.isNaN() && !x2.isNaN());
        assert(signbit(x1) != signbit(x2));
        auto result = findRoot(f, x1, x2, f(x1), f(x2),
          (real lo, real hi) { return false; });

        auto flo = f(result[0]);
        auto fhi = f(result[1]);
        if (flo != 0)
        {
            assert(oppositeSigns(flo, fhi));
        }
    }

    // Test functions
    real cubicfn(real x) @nogc @safe nothrow pure
    {
        //++numCalls;
        if (x>float.max)
            x = float.max;
        if (x<-double.max)
            x = -double.max;
        // This has a single real root at -59.286543284815
        return 0.386*x*x*x + 23*x*x + 15.7*x + 525.2;
    }
    // Test a function with more than one root.
    real multisine(real x) { ++numCalls; return sin(x); }
    //testFindRoot( &multisine, 6, 90);
    //testFindRoot(&cubicfn, -100, 100);
    //testFindRoot( &cubicfn, -double.max, real.max);


/* Tests from the paper:
 * "On Enclosing Simple Roots of Nonlinear Equations", G. Alefeld, F.A. Potra,
 *   Yixun Shi, Mathematics of Computation 61, pp733-744 (1993).
 */
    // Parameters common to many alefeld tests.
    int n;
    real ale_a, ale_b;

    int powercalls = 0;

    real power(real x)
    {
        ++powercalls;
        ++numCalls;
        return pow(x, n) + double.min_normal;
    }
    int [] power_nvals = [3, 5, 7, 9, 19, 25];
    // Alefeld paper states that pow(x,n) is a very poor case, where bisection
    // outperforms his method, and gives total numcalls =
    // 921 for bisection (2.4 calls per bit), 1830 for Alefeld (4.76/bit),
    // 2624 for brent (6.8/bit)
    // ... but that is for double, not real80.
    // This poor performance seems mainly due to catastrophic cancellation,
    // which is avoided here by the use of ieeeMean().
    // I get: 231 (0.48/bit).
    // IE this is 10X faster in Alefeld's worst case
    numProblems=0;
    foreach (k; power_nvals)
    {
        n = k;
        //testFindRoot(&power, -1, 10);
    }

    int powerProblems = numProblems;

    // Tests from Alefeld paper

    int [9] alefeldSums;
    real alefeld0(real x)
    {
        ++alefeldSums[0];
        ++numCalls;
        real q =  sin(x) - x/2;
        for (int i=1; i<20; ++i)
            q+=(2*i-5.0)*(2*i-5.0)/((x-i*i)*(x-i*i)*(x-i*i));
        return q;
    }
    real alefeld1(real x)
    {
        ++numCalls;
        ++alefeldSums[1];
        return ale_a*x + exp(ale_b * x);
    }
    real alefeld2(real x)
    {
        ++numCalls;
        ++alefeldSums[2];
        return pow(x, n) - ale_a;
    }
    real alefeld3(real x)
    {
        ++numCalls;
        ++alefeldSums[3];
        return (1.0 +pow(1.0L-n, 2))*x - pow(1.0L-n*x, 2);
    }
    real alefeld4(real x)
    {
        ++numCalls;
        ++alefeldSums[4];
        return x*x - pow(1-x, n);
    }
    real alefeld5(real x)
    {
        ++numCalls;
        ++alefeldSums[5];
        return (1+pow(1.0L-n, 4))*x - pow(1.0L-n*x, 4);
    }
    real alefeld6(real x)
    {
        ++numCalls;
        ++alefeldSums[6];
        return exp(-n*x)*(x-1.01L) + pow(x, n);
    }
    real alefeld7(real x)
    {
        ++numCalls;
        ++alefeldSums[7];
        return (n*x-1)/((n-1)*x);
    }

    numProblems=0;
    //testFindRoot(&alefeld0, PI_2, PI);
    for (n=1; n <= 10; ++n)
    {
        //testFindRoot(&alefeld0, n*n+1e-9L, (n+1)*(n+1)-1e-9L);
    }
    ale_a = -40; ale_b = -1;
    //testFindRoot(&alefeld1, -9, 31);
    ale_a = -100; ale_b = -2;
    //testFindRoot(&alefeld1, -9, 31);
    ale_a = -200; ale_b = -3;
    //testFindRoot(&alefeld1, -9, 31);
    int [] nvals_3 = [1, 2, 5, 10, 15, 20];
    int [] nvals_5 = [1, 2, 4, 5, 8, 15, 20];
    int [] nvals_6 = [1, 5, 10, 15, 20];
    int [] nvals_7 = [2, 5, 15, 20];

    for (int i=4; i<12; i+=2)
    {
       n = i;
       ale_a = 0.2;
       //testFindRoot(&alefeld2, 0, 5);
       ale_a=1;
       //testFindRoot(&alefeld2, 0.95, 4.05);
       //testFindRoot(&alefeld2, 0, 1.5);
    }
    foreach (i; nvals_3)
    {
        n=i;
        //testFindRoot(&alefeld3, 0, 1);
    }
    foreach (i; nvals_3)
    {
        n=i;
        //testFindRoot(&alefeld4, 0, 1);
    }
    foreach (i; nvals_5)
    {
        n=i;
        //testFindRoot(&alefeld5, 0, 1);
    }
    foreach (i; nvals_6)
    {
        n=i;
        //testFindRoot(&alefeld6, 0, 1);
    }
    foreach (i; nvals_7)
    {
        n=i;
        //testFindRoot(&alefeld7, 0.01L, 1);
    }
    real worstcase(real x)
    {
        ++numCalls;
        return x<0.3*real.max? -0.999e-3 : 1.0;
    }
    //testFindRoot(&worstcase, -real.max, real.max);

    // just check that the double + float cases compile
    //findRoot((double x){ return 0.0; }, -double.max, double.max);
    //findRoot((float x){ return 0.0f; }, -float.max, float.max);

/*
   int grandtotal=0;
   foreach (calls; alefeldSums)
   {
       grandtotal+=calls;
   }
   grandtotal-=2*numProblems;
   printf("\nALEFELD TOTAL = %d avg = %f (alefeld avg=19.3 for double)\n",
   grandtotal, (1.0*grandtotal)/numProblems);
   powercalls -= 2*powerProblems;
   printf("POWER TOTAL = %d avg = %f ", powercalls,
        (1.0*powercalls)/powerProblems);
*/
    //Issue 14231
    auto xp = findRoot((float x) => x, 0f, 1f);
    auto xn = findRoot((float x) => x, -1f, -0f);
    // In both calls one endpoint is an exact zero of f, so that endpoint
    // must be returned.
    assert(xp == 0);
    assert(xn == 0);
}
1388 
// Regression control: findRoot must accept functions whose return type
// differs from the argument type (float -> real, real -> double).
@system unittest
{
    // @system due to the case in the 2nd line
    static assert(__traits(compiles, findRoot((float x)=>cast(real) x, float.init, float.init)));
    static assert(__traits(compiles, findRoot!real((x)=>cast(double) x, real.init, real.init)));
    static assert(__traits(compiles, findRoot((real x)=>cast(double) x, real.init, real.init)));
}
1397 
1398 /++
1399 Find a real minimum of a real function `f(x)` via bracketing.
1400 Given a function `f` and a range `(ax .. bx)`,
1401 returns the value of `x` in the range which is closest to a minimum of `f(x)`.
1402 `f` is never evaluted at the endpoints of `ax` and `bx`.
1403 If `f(x)` has more than one minimum in the range, one will be chosen arbitrarily.
1404 If `f(x)` returns NaN or -Infinity, `(x, f(x), NaN)` will be returned;
1405 otherwise, this algorithm is guaranteed to succeed.
1406 
1407 Params:
1408     f = Function to be analyzed
1409     ax = Left bound of initial range of f known to contain the minimum.
1410     bx = Right bound of initial range of f known to contain the minimum.
1411     relTolerance = Relative tolerance.
1412     absTolerance = Absolute tolerance.
1413 
1414 Preconditions:
1415     `ax` and `bx` shall be finite reals. $(BR)
1416     $(D relTolerance) shall be normal positive real. $(BR)
1417     $(D absTolerance) shall be normal positive real no less then $(D T.epsilon*2).
1418 
1419 Returns:
1420     A tuple consisting of `x`, `y = f(x)` and `error = 3 * (absTolerance * fabs(x) + relTolerance)`.
1421 
1422     The method used is a combination of golden section search and
1423 successive parabolic interpolation. Convergence is never much slower
1424 than that for a Fibonacci search.
1425 
1426 References:
1427     "Algorithms for Minimization without Derivatives", Richard Brent, Prentice-Hall, Inc. (1973)
1428 
1429 See_Also: $(LREF findRoot), $(REF isNormal, std,math)
1430 +/
1431 Tuple!(T, "x", Unqual!(ReturnType!DF), "y", T, "error")
1432 findLocalMin(T, DF)(
1433         scope DF f,
1434         in T ax,
1435         in T bx,
1436         in T relTolerance = sqrt(T.epsilon),
1437         in T absTolerance = sqrt(T.epsilon),
1438         )
1439 if (isFloatingPoint!T
1440     && __traits(compiles, {T _ = DF.init(T.init);}))
1441 in
1442 {
1443     assert(isFinite(ax), "ax is not finite");
1444     assert(isFinite(bx), "bx is not finite");
1445     assert(isNormal(relTolerance), "relTolerance is not normal floating point number");
1446     assert(isNormal(absTolerance), "absTolerance is not normal floating point number");
1447     assert(relTolerance >= 0, "absTolerance is not positive");
1448     assert(absTolerance >= T.epsilon*2, "absTolerance is not greater then `2*T.epsilon`");
1449 }
out(result)1450 out (result)
1451 {
1452     assert(isFinite(result.x));
1453 }
1454 body
1455 {
1456     alias R = Unqual!(CommonType!(ReturnType!DF, T));
1457     // c is the squared inverse of the golden ratio
1458     // (3 - sqrt(5))/2
1459     // Value obtained from Wolfram Alpha.
1460     enum T c = 0x0.61c8864680b583ea0c633f9fa31237p+0L;
1461     enum T cm1 = 0x0.9e3779b97f4a7c15f39cc0605cedc8p+0L;
1462     R tolerance;
1463     T a = ax > bx ? bx : ax;
1464     T b = ax > bx ? ax : bx;
1465     // sequence of declarations suitable for SIMD instructions
1466     T  v = a * cm1 + b * c;
1467     assert(isFinite(v));
1468     R fv = f(v);
1469     if (isNaN(fv) || fv == -T.infinity)
1470     {
1471         return typeof(return)(v, fv, T.init);
1472     }
1473     T  w = v;
1474     R fw = fv;
1475     T  x = v;
1476     R fx = fv;
1477     size_t i;
1478     for (R d = 0, e = 0;;)
1479     {
1480         i++;
1481         T m = (a + b) / 2;
1482         // This fix is not part of the original algorithm
1483         if (!isFinite(m)) // fix infinity loop. Issue can be reproduced in R.
1484         {
1485             m = a / 2 + b / 2;
1486             if (!isFinite(m)) // fast-math compiler switch is enabled
1487             {
1488                 //SIMD instructions can be used by compiler, do not reduce declarations
1489                 int a_exp = void;
1490                 int b_exp = void;
1491                 immutable an = frexp(a, a_exp);
1492                 immutable bn = frexp(b, b_exp);
1493                 immutable am = ldexp(an, a_exp-1);
1494                 immutable bm = ldexp(bn, b_exp-1);
1495                 m = am + bm;
1496                 if (!isFinite(m)) // wrong input: constraints are disabled in release mode
1497                 {
1498                     return typeof(return).init;
1499                 }
1500             }
1501         }
1502         tolerance = absTolerance * fabs(x) + relTolerance;
1503         immutable t2 = tolerance * 2;
1504         // check stopping criterion
1505         if (!(fabs(x - m) > t2 - (b - a) / 2))
1506         {
1507             break;
1508         }
1509         R p = 0;
1510         R q = 0;
1511         R r = 0;
1512         // fit parabola
1513         if (fabs(e) > tolerance)
1514         {
1515             immutable  xw =  x -  w;
1516             immutable fxw = fx - fw;
1517             immutable  xv =  x -  v;
1518             immutable fxv = fx - fv;
1519             immutable xwfxv = xw * fxv;
1520             immutable xvfxw = xv * fxw;
1521             p = xv * xvfxw - xw * xwfxv;
1522             q = (xvfxw - xwfxv) * 2;
1523             if (q > 0)
1524                 p = -p;
1525             else
1526                 q = -q;
1527             r = e;
1528             e = d;
1529         }
1530         T u;
1531         // a parabolic-interpolation step
1532         if (fabs(p) < fabs(q * r / 2) && p > q * (a - x) && p < q * (b - x))
1533         {
1534             d = p / q;
1535             u = x + d;
1536             // f must not be evaluated too close to a or b
1537             if (u - a < t2 || b - u < t2)
1538                 d = x < m ? tolerance : -tolerance;
1539         }
1540         // a golden-section step
1541         else
1542         {
1543             e = (x < m ? b : a) - x;
1544             d = c * e;
1545         }
1546         // f must not be evaluated too close to x
1547         u = x + (fabs(d) >= tolerance ? d : d > 0 ? tolerance : -tolerance);
1548         immutable fu = f(u);
1549         if (isNaN(fu) || fu == -T.infinity)
1550         {
1551             return typeof(return)(u, fu, T.init);
1552         }
1553         //  update  a, b, v, w, and x
1554         if (fu <= fx)
1555         {
1556             u < x ? b : a = x;
1557             v = w; fv = fw;
1558             w = x; fw = fx;
1559             x = u; fx = fu;
1560         }
1561         else
1562         {
1563             u < x ? a : b = u;
1564             if (fu <= fw || w == x)
1565             {
1566                 v = w; fv = fw;
1567                 w = u; fw = fu;
1568             }
1569             else if (fu <= fv || v == x || v == w)
1570             { // do not remove this braces
1571                 v = u; fv = fu;
1572             }
1573         }
1574     }
1575     return typeof(return)(x, fx, tolerance * 3);
1576 }
1577 
1578 ///
1579 @safe unittest
1580 {
1581     import std.math : approxEqual;
1582 
1583     auto ret = findLocalMin((double x) => (x-4)^^2, -1e7, 1e7);
1584     assert(ret.x.approxEqual(4.0));
1585     assert(ret.y.approxEqual(0.0));
1586 }
1587 
@safe unittest
{
    import std.meta : AliasSeq;
    foreach (T; AliasSeq!(double, float, real))
    {
        // Smooth parabola with its minimum at 4.
        {
            auto res = findLocalMin!T((T x) => (x-4)^^2, T.min_normal, 1e7);
            assert(res.x.approxEqual(T(4)));
            assert(res.y.approxEqual(T(0)));
        }
        // Non-smooth |x - 1| over a huge interval, with tight tolerances.
        {
            auto res = findLocalMin!T((T x) => fabs(x-1), -T.max/4, T.max/4, T.min_normal, 2*T.epsilon);
            assert(approxEqual(res.x, T(1)));
            assert(approxEqual(res.y, T(0)));
            assert(res.error <= 10 * T.epsilon);
        }
        // f always yields NaN: x stays finite, y and error are NaN.
        {
            auto res = findLocalMin!T((T x) => T.init, 0, 1, T.min_normal, 2*T.epsilon);
            assert(!res.x.isNaN);
            assert(res.y.isNaN);
            assert(res.error.isNaN);
        }
        // log(x) on (0, 1): the infimum is at the left end.
        {
            auto res = findLocalMin!T((T x) => log(x), 0, 1, T.min_normal, 2*T.epsilon);
            assert(res.error < 3.00001 * ((2*T.epsilon)*fabs(res.x)+ T.min_normal));
            assert(res.x >= 0 && res.x <= res.error);
        }
        // log(x) on (0, T.max).
        {
            auto res = findLocalMin!T((T x) => log(x), 0, T.max, T.min_normal, 2*T.epsilon);
            assert(res.y < -18);
            assert(res.error < 5e-08);
            assert(res.x >= 0 && res.x <= res.error);
        }
        // -|x| on (-1, 1): minima at either end.
        {
            auto res = findLocalMin!T((T x) => -fabs(x), -1, 1, T.min_normal, 2*T.epsilon);
            assert(res.x.fabs.approxEqual(T(1)));
            assert(res.y.fabs.approxEqual(T(1)));
            assert(res.error.approxEqual(T(0)));
        }
    }
}
1629 
1630 /**
1631 Computes $(LINK2 https://en.wikipedia.org/wiki/Euclidean_distance,
1632 Euclidean distance) between input ranges $(D a) and
1633 $(D b). The two ranges must have the same length. The three-parameter
1634 version stops computation as soon as the distance is greater than or
1635 equal to $(D limit) (this is useful to save computation if a small
1636 distance is sought).
1637  */
1638 CommonType!(ElementType!(Range1), ElementType!(Range2))
1639 euclideanDistance(Range1, Range2)(Range1 a, Range2 b)
1640 if (isInputRange!(Range1) && isInputRange!(Range2))
1641 {
1642     enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
1643     static if (haveLen) assert(a.length == b.length);
1644     Unqual!(typeof(return)) result = 0;
1645     for (; !a.empty; a.popFront(), b.popFront())
1646     {
1647         immutable t = a.front - b.front;
1648         result += t * t;
1649     }
1650     static if (!haveLen) assert(b.empty);
1651     return sqrt(result);
1652 }
1653 
/// Ditto
CommonType!(ElementType!(Range1), ElementType!(Range2))
euclideanDistance(Range1, Range2, F)(Range1 a, Range2 b, F limit)
if (isInputRange!(Range1) && isInputRange!(Range2))
{
    // Compare squared sums against the squared limit to avoid sqrt per step.
    limit *= limit;
    enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
    static if (haveLen) assert(a.length == b.length);
    Unqual!(typeof(return)) result = 0;
    for (; ; a.popFront(), b.popFront())
    {
        if (a.empty)
        {
            static if (!haveLen) assert(b.empty);
            break;
        }
        // Without lengths, catch a shorter `b` before reading from it.
        static if (!haveLen) assert(!b.empty, "euclideanDistance: ranges have different lengths");
        immutable t = a.front - b.front;
        result += t * t;
        if (result >= limit) break;
    }
    return sqrt(result);
}
1676 
@safe unittest
{
    import std.meta : AliasSeq;
    // Exercise mutable, const and immutable element types.
    foreach (E; AliasSeq!(double, const double, immutable double))
    {
        E[] lhs = [ 1.0, 2.0, ];
        E[] rhs = [ 4.0, 6.0, ];
        // Full distance: sqrt(3^2 + 4^2) == 5.
        assert(euclideanDistance(lhs, rhs) == 5);
        // Limits of 5 and 4 are only reached on the last term: still 5.
        assert(euclideanDistance(lhs, rhs, 5) == 5);
        assert(euclideanDistance(lhs, rhs, 4) == 5);
        // A limit of 2 stops after the first term: sqrt(9) == 3.
        assert(euclideanDistance(lhs, rhs, 2) == 3);
    }
}
1690 
1691 /**
1692 Computes the $(LINK2 https://en.wikipedia.org/wiki/Dot_product,
1693 dot product) of input ranges $(D a) and $(D
1694 b). The two ranges must have the same length. If both ranges define
1695 length, the check is done once; otherwise, it is done at each
1696 iteration.
1697  */
1698 CommonType!(ElementType!(Range1), ElementType!(Range2))
1699 dotProduct(Range1, Range2)(Range1 a, Range2 b)
1700 if (isInputRange!(Range1) && isInputRange!(Range2) &&
1701     !(isArray!(Range1) && isArray!(Range2)))
1702 {
1703     enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
1704     static if (haveLen) assert(a.length == b.length);
1705     Unqual!(typeof(return)) result = 0;
1706     for (; !a.empty; a.popFront(), b.popFront())
1707     {
1708         result += a.front * b.front;
1709     }
1710     static if (!haveLen) assert(b.empty);
1711     return result;
1712 }
1713 
/// Ditto
CommonType!(F1, F2)
dotProduct(F1, F2)(in F1[] avector, in F2[] bvector)
{
    // Both slices must have the same length (checked by assert only).
    immutable n = avector.length;
    assert(n == bvector.length);
    // Walk both slices through raw pointers (pointer indexing is not
    // bounds-checked).
    auto avec = avector.ptr, bvec = bvector.ptr;
    // Two accumulators: even offsets within a block go to sum0, odd to sum1.
    Unqual!(typeof(return)) sum0 = 0, sum1 = 0;

    // End markers: whole array, last multiple of 4, last multiple of 16.
    const all_endp = avec + n;
    const smallblock_endp = avec + (n & ~3);
    const bigblock_endp = avec + (n & ~15);

    // Main loop, unrolled 16 elements at a time.
    for (; avec != bigblock_endp; avec += 16, bvec += 16)
    {
        sum0 += avec[0] * bvec[0];
        sum1 += avec[1] * bvec[1];
        sum0 += avec[2] * bvec[2];
        sum1 += avec[3] * bvec[3];
        sum0 += avec[4] * bvec[4];
        sum1 += avec[5] * bvec[5];
        sum0 += avec[6] * bvec[6];
        sum1 += avec[7] * bvec[7];
        sum0 += avec[8] * bvec[8];
        sum1 += avec[9] * bvec[9];
        sum0 += avec[10] * bvec[10];
        sum1 += avec[11] * bvec[11];
        sum0 += avec[12] * bvec[12];
        sum1 += avec[13] * bvec[13];
        sum0 += avec[14] * bvec[14];
        sum1 += avec[15] * bvec[15];
    }

    // Remaining whole blocks of 4 elements.
    for (; avec != smallblock_endp; avec += 4, bvec += 4)
    {
        sum0 += avec[0] * bvec[0];
        sum1 += avec[1] * bvec[1];
        sum0 += avec[2] * bvec[2];
        sum1 += avec[3] * bvec[3];
    }

    // Merge the accumulators. The exact order of these additions fixes the
    // floating-point rounding of the result.
    sum0 += sum1;

    /* Do trailing portion in naive loop. */
    while (avec != all_endp)
    {
        sum0 += *avec * *bvec;
        ++avec;
        ++bvec;
    }

    return sum0;
}
1767 
@system unittest
{
    // @system due to dotProduct and assertCTFEable
    import std.exception : assertCTFEable;
    import std.meta : AliasSeq;

    foreach (E; AliasSeq!(double, const double, immutable double))
    {
        E[] lhs = [ 1.0, 2.0 ];
        E[] rhs = [ 4.0, 6.0 ];
        assert(dotProduct(lhs, rhs) == 16);
        assert(dotProduct([1, 3, -5], [4, -2, -1]) == 3);
    }

    // 18 elements = one 16-element chunk plus a 2-element tail, so the
    // unrolled code path runs, both at compile time and at run time.
    static const x =
        [1.0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18];
    static const y =
        [2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19];
    assertCTFEable!({ assert(dotProduct(x, y) == 2280); });
}
1788 
1789 /**
1790 Computes the $(LINK2 https://en.wikipedia.org/wiki/Cosine_similarity,
1791 cosine similarity) of input ranges $(D a) and $(D
1792 b). The two ranges must have the same length. If both ranges define
1793 length, the check is done once; otherwise, it is done at each
1794 iteration. If either range has all-zero elements, return 0.
1795  */
1796 CommonType!(ElementType!(Range1), ElementType!(Range2))
1797 cosineSimilarity(Range1, Range2)(Range1 a, Range2 b)
1798 if (isInputRange!(Range1) && isInputRange!(Range2))
1799 {
1800     enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
1801     static if (haveLen) assert(a.length == b.length);
1802     Unqual!(typeof(return)) norma = 0, normb = 0, dotprod = 0;
1803     for (; !a.empty; a.popFront(), b.popFront())
1804     {
1805         immutable t1 = a.front, t2 = b.front;
1806         norma += t1 * t1;
1807         normb += t2 * t2;
1808         dotprod += t1 * t2;
1809     }
1810     static if (!haveLen) assert(b.empty);
1811     if (norma == 0 || normb == 0) return 0;
1812     return dotprod / sqrt(norma * normb);
1813 }
1814 
@safe unittest
{
    import std.meta : AliasSeq;

    foreach (E; AliasSeq!(double, const double, immutable double))
    {
        E[] u = [ 1.0, 2.0 ];
        E[] v = [ 4.0, 3.0 ];
        // u.v = 10, |u|^2 = 5, |v|^2 = 25
        immutable expected = 10.0 / sqrt(5.0 * 25);
        assert(approxEqual(cosineSimilarity(u, v), expected, 0.01));
    }
}
1827 
1828 /**
1829 Normalizes values in $(D range) by multiplying each element with a
1830 number chosen such that values sum up to $(D sum). If elements in $(D
1831 range) sum to zero, assigns $(D sum / range.length) to
1832 all. Normalization makes sense only if all elements in $(D range) are
1833 positive. $(D normalize) assumes that is the case without checking it.
1834 
1835 Returns: $(D true) if normalization completed normally, $(D false) if
1836 all elements in $(D range) were zero or if $(D range) is empty.
1837  */
1838 bool normalize(R)(R range, ElementType!(R) sum = 1)
1839 if (isForwardRange!(R))
1840 {
1841     ElementType!(R) s = 0;
1842     // Step 1: Compute sum and length of the range
1843     static if (hasLength!(R))
1844     {
1845         const length = range.length;
foreach(e;range)1846         foreach (e; range)
1847         {
1848             s += e;
1849         }
1850     }
1851     else
1852     {
1853         uint length = 0;
foreach(e;range)1854         foreach (e; range)
1855         {
1856             s += e;
1857             ++length;
1858         }
1859     }
1860     // Step 2: perform normalization
1861     if (s == 0)
1862     {
1863         if (length)
1864         {
1865             immutable f = sum / range.length;
1866             foreach (ref e; range) e = f;
1867         }
1868         return false;
1869     }
1870     // The path most traveled
1871     assert(s >= 0);
1872     immutable f = sum / s;
1873     foreach (ref e; range)
1874         e *= f;
1875     return true;
1876 }
1877 
1878 ///
1879 @safe unittest
1880 {
1881     double[] a = [];
1882     assert(!normalize(a));
1883     a = [ 1.0, 3.0 ];
1884     assert(normalize(a));
1885     assert(a == [ 0.25, 0.75 ]);
1886     a = [ 0.0, 0.0 ];
1887     assert(!normalize(a));
1888     assert(a == [ 0.5, 0.5 ]);
1889 }
1890 
1891 /**
1892 Compute the sum of binary logarithms of the input range $(D r).
1893 The error of this method is much smaller than with a naive sum of log2.
1894  */
1895 ElementType!Range sumOfLog2s(Range)(Range r)
1896 if (isInputRange!Range && isFloatingPoint!(ElementType!Range))
1897 {
1898     long exp = 0;
1899     Unqual!(typeof(return)) x = 1;
foreach(e;r)1900     foreach (e; r)
1901     {
1902         if (e < 0)
1903             return typeof(return).nan;
1904         int lexp = void;
1905         x *= frexp(e, lexp);
1906         exp += lexp;
1907         if (x < 0.5)
1908         {
1909             x *= 2;
1910             exp--;
1911         }
1912     }
1913     return exp + log2(x);
1914 }
1915 
1916 ///
1917 @safe unittest
1918 {
1919     import std.math : isNaN;
1920 
1921     assert(sumOfLog2s(new double[0]) == 0);
1922     assert(sumOfLog2s([0.0L]) == -real.infinity);
1923     assert(sumOfLog2s([-0.0L]) == -real.infinity);
1924     assert(sumOfLog2s([2.0L]) == 1);
1925     assert(sumOfLog2s([-2.0L]).isNaN());
1926     assert(sumOfLog2s([real.nan]).isNaN());
1927     assert(sumOfLog2s([-real.nan]).isNaN());
1928     assert(sumOfLog2s([real.infinity]) == real.infinity);
1929     assert(sumOfLog2s([-real.infinity]).isNaN());
1930     assert(sumOfLog2s([ 0.25, 0.25, 0.25, 0.125 ]) == -9);
1931 }
1932 
1933 /**
1934 Computes $(LINK2 https://en.wikipedia.org/wiki/Entropy_(information_theory),
1935 _entropy) of input range $(D r) in bits. This
1936 function assumes (without checking) that the values in $(D r) are all
1937 in $(D [0, 1]). For the entropy to be meaningful, often $(D r) should
1938 be normalized too (i.e., its values should sum to 1). The
1939 two-parameter version stops evaluating as soon as the intermediate
1940 result is greater than or equal to $(D max).
1941  */
1942 ElementType!Range entropy(Range)(Range r)
1943 if (isInputRange!Range)
1944 {
1945     Unqual!(typeof(return)) result = 0.0;
1946     for (;!r.empty; r.popFront)
1947     {
1948         if (!r.front) continue;
1949         result -= r.front * log2(r.front);
1950     }
1951     return result;
1952 }
1953 
/// Ditto
ElementType!Range entropy(Range, F)(Range r, F max)
if (isInputRange!Range &&
    !is(CommonType!(ElementType!Range, F) == void))
{
    Unqual!(typeof(return)) result = 0.0;
    for (; !r.empty; r.popFront())
    {
        // Read `front` once per element: for lazy ranges (e.g. `map`)
        // every access may recompute the value.
        immutable p = r.front;
        if (!p) continue;       // 0 * log2(0) is taken to be 0
        result -= p * log2(p);
        // Early exit once the caller's cap has been reached.
        if (result >= max) break;
    }
    return result;
}
1968 
@safe unittest
{
    import std.meta : AliasSeq;

    foreach (E; AliasSeq!(double, const double, immutable double))
    {
        // A certain outcome carries no information.
        E[] certain = [ 0.0, 0, 0, 1 ];
        assert(entropy(certain) == 0);

        // Four equally likely outcomes carry two bits; capped at 1 bit,
        // the evaluation stops early.
        E[] uniform = [ 0.25, 0.25, 0.25, 0.25 ];
        assert(entropy(uniform) == 2);
        assert(entropy(uniform, 1) == 1);
    }
}
1981 
1982 /**
1983 Computes the $(LINK2 https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence,
1984 Kullback-Leibler divergence) between input ranges
1985 $(D a) and $(D b), which is the sum $(D ai * log(ai / bi)). The base
1986 of logarithm is 2. The ranges are assumed to contain elements in $(D
1987 [0, 1]). Usually the ranges are normalized probability distributions,
1988 but this is not required or checked by $(D
1989 kullbackLeiblerDivergence). If any element $(D bi) is zero and the
1990 corresponding element $(D ai) nonzero, returns infinity. (Otherwise,
1991 if $(D ai == 0 && bi == 0), the term $(D ai * log(ai / bi)) is
1992 considered zero.) If the inputs are normalized, the result is
1993 positive.
1994  */
1995 CommonType!(ElementType!Range1, ElementType!Range2)
1996 kullbackLeiblerDivergence(Range1, Range2)(Range1 a, Range2 b)
1997 if (isInputRange!(Range1) && isInputRange!(Range2))
1998 {
1999     enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
2000     static if (haveLen) assert(a.length == b.length);
2001     Unqual!(typeof(return)) result = 0;
2002     for (; !a.empty; a.popFront(), b.popFront())
2003     {
2004         immutable t1 = a.front;
2005         if (t1 == 0) continue;
2006         immutable t2 = b.front;
2007         if (t2 == 0) return result.infinity;
2008         assert(t1 > 0 && t2 > 0);
2009         result += t1 * log2(t1 / t2);
2010     }
2011     static if (!haveLen) assert(b.empty);
2012     return result;
2013 }
2014 
2015 ///
2016 @safe unittest
2017 {
2018     import std.math : approxEqual;
2019 
2020     double[] p = [ 0.0, 0, 0, 1 ];
2021     assert(kullbackLeiblerDivergence(p, p) == 0);
2022     double[] p1 = [ 0.25, 0.25, 0.25, 0.25 ];
2023     assert(kullbackLeiblerDivergence(p1, p1) == 0);
2024     assert(kullbackLeiblerDivergence(p, p1) == 2);
2025     assert(kullbackLeiblerDivergence(p1, p) == double.infinity);
2026     double[] p2 = [ 0.2, 0.2, 0.2, 0.4 ];
2027     assert(approxEqual(kullbackLeiblerDivergence(p1, p2), 0.0719281));
2028     assert(approxEqual(kullbackLeiblerDivergence(p2, p1), 0.0780719));
2029 }
2030 
2031 /**
2032 Computes the $(LINK2 https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence,
2033 Jensen-Shannon divergence) between $(D a) and $(D
2034 b), which is the sum $(D (ai * log(2 * ai / (ai + bi)) + bi * log(2 *
2035 bi / (ai + bi))) / 2). The base of logarithm is 2. The ranges are
2036 assumed to contain elements in $(D [0, 1]). Usually the ranges are
2037 normalized probability distributions, but this is not required or
2038 checked by $(D jensenShannonDivergence). If the inputs are normalized,
2039 the result is bounded within $(D [0, 1]). The three-parameter version
2040 stops evaluations as soon as the intermediate result is greater than
2041 or equal to $(D limit).
2042  */
2043 CommonType!(ElementType!Range1, ElementType!Range2)
2044 jensenShannonDivergence(Range1, Range2)(Range1 a, Range2 b)
2045 if (isInputRange!Range1 && isInputRange!Range2 &&
2046     is(CommonType!(ElementType!Range1, ElementType!Range2)))
2047 {
2048     enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
2049     static if (haveLen) assert(a.length == b.length);
2050     Unqual!(typeof(return)) result = 0;
2051     for (; !a.empty; a.popFront(), b.popFront())
2052     {
2053         immutable t1 = a.front;
2054         immutable t2 = b.front;
2055         immutable avg = (t1 + t2) / 2;
2056         if (t1 != 0)
2057         {
2058             result += t1 * log2(t1 / avg);
2059         }
2060         if (t2 != 0)
2061         {
2062             result += t2 * log2(t2 / avg);
2063         }
2064     }
2065     static if (!haveLen) assert(b.empty);
2066     return result / 2;
2067 }
2068 
/// Ditto
CommonType!(ElementType!Range1, ElementType!Range2)
jensenShannonDivergence(Range1, Range2, F)(Range1 a, Range2 b, F limit)
if (isInputRange!Range1 && isInputRange!Range2 &&
    is(typeof(CommonType!(ElementType!Range1, ElementType!Range2).init
    >= F.init) : bool))
{
    enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
    static if (haveLen) assert(a.length == b.length);
    Unqual!(typeof(return)) result = 0;
    // `result` holds twice the divergence until the final halving, so the
    // threshold is doubled to match.
    limit *= 2;
    for (; !a.empty; a.popFront(), b.popFront())
    {
        // Lengths unknown: `b` must not be read or popped past its end.
        static if (!haveLen) assert(!b.empty);
        immutable t1 = a.front;
        immutable t2 = b.front;
        immutable avg = (t1 + t2) / 2;
        if (t1 != 0)
        {
            result += t1 * log2(t1 / avg);
        }
        if (t2 != 0)
        {
            result += t2 * log2(t2 / avg);
        }
        if (result >= limit) break;
    }
    // After an early exit neither range is exhausted, so `b` may only be
    // required to be empty when the loop consumed all of `a`.
    static if (!haveLen) assert(!a.empty || b.empty);
    return result / 2;
}
2098 
2099 ///
2100 @safe unittest
2101 {
2102     import std.math : approxEqual;
2103 
2104     double[] p = [ 0.0, 0, 0, 1 ];
2105     assert(jensenShannonDivergence(p, p) == 0);
2106     double[] p1 = [ 0.25, 0.25, 0.25, 0.25 ];
2107     assert(jensenShannonDivergence(p1, p1) == 0);
2108     assert(approxEqual(jensenShannonDivergence(p1, p), 0.548795));
2109     double[] p2 = [ 0.2, 0.2, 0.2, 0.4 ];
2110     assert(approxEqual(jensenShannonDivergence(p1, p2), 0.0186218));
2111     assert(approxEqual(jensenShannonDivergence(p2, p1), 0.0186218));
2112     assert(approxEqual(jensenShannonDivergence(p2, p1, 0.005), 0.00602366));
2113 }
2114 
2115 /**
2116 The so-called "all-lengths gap-weighted string kernel" computes a
2117 similarity measure between $(D s) and $(D t) based on all of their
2118 common subsequences of all lengths. Gapped subsequences are also
2119 included.
2120 
2121 To understand what $(D gapWeightedSimilarity(s, t, lambda)) computes,
2122 consider first the case $(D lambda = 1) and the strings $(D s =
2123 ["Hello", "brave", "new", "world"]) and $(D t = ["Hello", "new",
2124 "world"]). In that case, $(D gapWeightedSimilarity) counts the
2125 following matches:
2126 
2127 $(OL $(LI three matches of length 1, namely $(D "Hello"), $(D "new"),
2128 and $(D "world");) $(LI three matches of length 2, namely ($(D
2129 "Hello", "new")), ($(D "Hello", "world")), and ($(D "new", "world"));)
2130 $(LI one match of length 3, namely ($(D "Hello", "new", "world")).))
2131 
2132 The call $(D gapWeightedSimilarity(s, t, 1)) simply counts all of
2133 these matches and adds them up, returning 7.
2134 
2135 ----
2136 string[] s = ["Hello", "brave", "new", "world"];
2137 string[] t = ["Hello", "new", "world"];
2138 assert(gapWeightedSimilarity(s, t, 1) == 7);
2139 ----
2140 
2141 Note how the gaps in matching are simply ignored, for example ($(D
2142 "Hello", "new")) is deemed as good a match as ($(D "new",
2143 "world")). This may be too permissive for some applications. To
2144 eliminate gapped matches entirely, use $(D lambda = 0):
2145 
2146 ----
2147 string[] s = ["Hello", "brave", "new", "world"];
2148 string[] t = ["Hello", "new", "world"];
2149 assert(gapWeightedSimilarity(s, t, 0) == 4);
2150 ----
2151 
2152 The call above eliminated the gapped matches ($(D "Hello", "new")),
2153 ($(D "Hello", "world")), and ($(D "Hello", "new", "world")) from the
2154 tally. That leaves only 4 matches.
2155 
2156 The most interesting case is when gapped matches still participate in
2157 the result, but not as strongly as ungapped matches. The result will
2158 be a smooth, fine-grained similarity measure between the input
2159 strings. This is where values of $(D lambda) between 0 and 1 enter
2160 into play: gapped matches are $(I exponentially penalized with the
2161 number of gaps) with base $(D lambda). This means that an ungapped
2162 match adds 1 to the return value; a match with one gap in either
2163 string adds $(D lambda) to the return value; ...; a match with a total
2164 of $(D n) gaps in both strings adds $(D pow(lambda, n)) to the return
2165 value. In the example above, we have 4 matches without gaps, 2 matches
2166 with one gap, and 1 match with three gaps. The latter match is ($(D
2167 "Hello", "world")), which has two gaps in the first string and one gap
2168 in the second string, totaling to three gaps. Summing these up we get
2169 $(D 4 + 2 * lambda + pow(lambda, 3)).
2170 
2171 ----
2172 string[] s = ["Hello", "brave", "new", "world"];
2173 string[] t = ["Hello", "new", "world"];
2174 assert(gapWeightedSimilarity(s, t, 0.5) == 4 + 0.5 * 2 + 0.125);
2175 ----
2176 
2177 $(D gapWeightedSimilarity) is useful wherever a smooth similarity
2178 measure between sequences allowing for approximate matches is
2179 needed. The examples above are given with words, but any sequences
2180 with elements comparable for equality are allowed, e.g. characters or
2181 numbers. $(D gapWeightedSimilarity) uses a highly optimized dynamic
2182 programming implementation that needs $(D 16 * min(s.length,
2183 t.length)) extra bytes of memory and $(BIGOH s.length * t.length) time
2184 to complete.
2185  */
2186 F gapWeightedSimilarity(alias comp = "a == b", R1, R2, F)(R1 s, R2 t, F lambda)
2187 if (isRandomAccessRange!(R1) && hasLength!(R1) &&
2188     isRandomAccessRange!(R2) && hasLength!(R2))
2189 {
2190     import core.exception : onOutOfMemoryError;
2191     import core.stdc.stdlib : malloc, free;
2192     import std.algorithm.mutation : swap;
2193     import std.functional : binaryFun;
2194 
2195     if (s.length < t.length) return gapWeightedSimilarity(t, s, lambda);
2196     if (!t.length) return 0;
2197 
2198     auto dpvi = cast(F*) malloc(F.sizeof * 2 * t.length);
2199     if (!dpvi)
2200         onOutOfMemoryError();
2201 
2202     auto dpvi1 = dpvi + t.length;
2203     scope(exit) free(dpvi < dpvi1 ? dpvi : dpvi1);
2204     dpvi[0 .. t.length] = 0;
2205     dpvi1[0] = 0;
2206     immutable lambda2 = lambda * lambda;
2207 
2208     F result = 0;
2209     foreach (i; 0 .. s.length)
2210     {
2211         const si = s[i];
2212         for (size_t j = 0;;)
2213         {
2214             F dpsij = void;
2215             if (binaryFun!(comp)(si, t[j]))
2216             {
2217                 dpsij = 1 + dpvi[j];
2218                 result += dpsij;
2219             }
2220             else
2221             {
2222                 dpsij = 0;
2223             }
2224             immutable j1 = j + 1;
2225             if (j1 == t.length) break;
2226             dpvi1[j1] = dpsij + lambda * (dpvi1[j] + dpvi[j1]) -
2227                         lambda2 * dpvi[j];
2228             j = j1;
2229         }
2230         swap(dpvi, dpvi1);
2231     }
2232     return result;
2233 }
2234 
@system unittest
{
    string[] longer = ["Hello", "brave", "new", "world"];
    string[] shorter = ["Hello", "new", "world"];
    // lambda = 1 counts every match; lambda = 0 keeps only ungapped ones;
    // lambda = 0.5 halves a match's weight for every gap.
    assert(gapWeightedSimilarity(longer, shorter, 1) == 7);
    assert(gapWeightedSimilarity(longer, shorter, 0) == 4);
    assert(gapWeightedSimilarity(longer, shorter, 0.5) == 4 + 2 * 0.5 + 0.125);
}
2243 
2244 /**
2245 The similarity per $(D gapWeightedSimilarity) has an issue in that it
2246 grows with the lengths of the two strings, even though the strings are
2247 not actually very similar. For example, the range $(D ["Hello",
2248 "world"]) is increasingly similar with the range $(D ["Hello",
2249 "world", "world", "world",...]) as more instances of $(D "world") are
2250 appended. To prevent that, $(D gapWeightedSimilarityNormalized)
2251 computes a normalized version of the similarity that is computed as
2252 $(D gapWeightedSimilarity(s, t, lambda) /
2253 sqrt(gapWeightedSimilarity(s, t, lambda) * gapWeightedSimilarity(s, t,
2254 lambda))). The function $(D gapWeightedSimilarityNormalized) (a
2255 so-called normalized kernel) is bounded in $(D [0, 1]), reaches $(D 0)
2256 only for ranges that don't match in any position, and $(D 1) only for
2257 identical ranges.
2258 
2259 The optional parameters $(D sSelfSim) and $(D tSelfSim) are meant for
2260 avoiding duplicate computation. Many applications may have already
2261 computed $(D gapWeightedSimilarity(s, s, lambda)) and/or $(D
2262 gapWeightedSimilarity(t, t, lambda)). In that case, they can be passed
2263 as $(D sSelfSim) and $(D tSelfSim), respectively.
2264  */
2265 Select!(isFloatingPoint!(F), F, double)
2266 gapWeightedSimilarityNormalized(alias comp = "a == b", R1, R2, F)
2267         (R1 s, R2 t, F lambda, F sSelfSim = F.init, F tSelfSim = F.init)
2268 if (isRandomAccessRange!(R1) && hasLength!(R1) &&
2269     isRandomAccessRange!(R2) && hasLength!(R2))
2270 {
uncomputed(F n)2271     static bool uncomputed(F n)
2272     {
2273         static if (isFloatingPoint!(F))
2274             return isNaN(n);
2275         else
2276             return n == n.init;
2277     }
2278     if (uncomputed(sSelfSim))
2279         sSelfSim = gapWeightedSimilarity!(comp)(s, s, lambda);
2280     if (sSelfSim == 0) return 0;
2281     if (uncomputed(tSelfSim))
2282         tSelfSim = gapWeightedSimilarity!(comp)(t, t, lambda);
2283     if (tSelfSim == 0) return 0;
2284 
2285     return gapWeightedSimilarity!(comp)(s, t, lambda) /
2286            sqrt(cast(typeof(return)) sSelfSim * tSelfSim);
2287 }
2288 
2289 ///
2290 @system unittest
2291 {
2292     import std.math : approxEqual, sqrt;
2293 
2294     string[] s = ["Hello", "brave", "new", "world"];
2295     string[] t = ["Hello", "new", "world"];
2296     assert(gapWeightedSimilarity(s, s, 1) == 15);
2297     assert(gapWeightedSimilarity(t, t, 1) == 7);
2298     assert(gapWeightedSimilarity(s, t, 1) == 7);
2299     assert(approxEqual(gapWeightedSimilarityNormalized(s, t, 1),
2300                     7.0 / sqrt(15.0 * 7), 0.01));
2301 }
2302 
2303 /**
2304 Similar to $(D gapWeightedSimilarity), just works in an incremental
2305 manner by first revealing the matches of length 1, then gapped matches
2306 of length 2, and so on. The memory requirement is $(BIGOH s.length *
2307 t.length). The time complexity is $(BIGOH s.length * t.length) time
2308 for computing each step. Continuing on the previous example:
2309 
2310 The implementation is based on the pseudocode in Fig. 4 of the paper
2311 $(HTTP jmlr.csail.mit.edu/papers/volume6/rousu05a/rousu05a.pdf,
2312 "Efficient Computation of Gapped Substring Kernels on Large Alphabets")
2313 by Rousu et al., with additional algorithmic and systems-level
2314 optimizations.
2315  */
2316 struct GapWeightedSimilarityIncremental(Range, F = double)
2317 if (isRandomAccessRange!(Range) && hasLength!(Range))
2318 {
2319     import core.stdc.stdlib : malloc, realloc, alloca, free;
2320 
2321 private:
2322     Range s, t;
2323     F currentValue = 0;
2324     F* kl;
2325     size_t gram = void;
2326     F lambda = void, lambda2 = void;
2327 
2328 public:
2329 /**
2330 Constructs an object given two ranges $(D s) and $(D t) and a penalty
2331 $(D lambda). Constructor completes in $(BIGOH s.length * t.length)
2332 time and computes all matches of length 1.
2333  */
thisGapWeightedSimilarityIncremental2334     this(Range s, Range t, F lambda)
2335     {
2336         import core.exception : onOutOfMemoryError;
2337 
2338         assert(lambda > 0);
2339         this.gram = 0;
2340         this.lambda = lambda;
2341         this.lambda2 = lambda * lambda; // for efficiency only
2342 
2343         size_t iMin = size_t.max, jMin = size_t.max,
2344             iMax = 0, jMax = 0;
2345         /* initialize */
2346         Tuple!(size_t, size_t) * k0;
2347         size_t k0len;
2348         scope(exit) free(k0);
2349         currentValue = 0;
2350         foreach (i, si; s)
2351         {
2352             foreach (j; 0 .. t.length)
2353             {
2354                 if (si != t[j]) continue;
2355                 k0 = cast(typeof(k0)) realloc(k0, ++k0len * (*k0).sizeof);
2356                 with (k0[k0len - 1])
2357                 {
2358                     field[0] = i;
2359                     field[1] = j;
2360                 }
2361                 // Maintain the minimum and maximum i and j
2362                 if (iMin > i) iMin = i;
2363                 if (iMax < i) iMax = i;
2364                 if (jMin > j) jMin = j;
2365                 if (jMax < j) jMax = j;
2366             }
2367         }
2368 
2369         if (iMin > iMax) return;
2370         assert(k0len);
2371 
2372         currentValue = k0len;
2373         // Chop strings down to the useful sizes
2374         s = s[iMin .. iMax + 1];
2375         t = t[jMin .. jMax + 1];
2376         this.s = s;
2377         this.t = t;
2378 
2379         kl = cast(F*) malloc(s.length * t.length * F.sizeof);
2380         if (!kl)
2381             onOutOfMemoryError();
2382 
2383         kl[0 .. s.length * t.length] = 0;
2384         foreach (pos; 0 .. k0len)
2385         {
2386             with (k0[pos])
2387             {
2388                 kl[(field[0] - iMin) * t.length + field[1] -jMin] = lambda2;
2389             }
2390         }
2391     }
2392 
2393     /**
2394     Returns: $(D this).
2395      */
opSliceGapWeightedSimilarityIncremental2396     ref GapWeightedSimilarityIncremental opSlice()
2397     {
2398         return this;
2399     }
2400 
2401     /**
2402     Computes the match of the popFront length. Completes in $(BIGOH s.length *
2403     t.length) time.
2404      */
popFrontGapWeightedSimilarityIncremental2405     void popFront()
2406     {
2407         import std.algorithm.mutation : swap;
2408 
2409         // This is a large source of optimization: if similarity at
2410         // the gram-1 level was 0, then we can safely assume
2411         // similarity at the gram level is 0 as well.
2412         if (empty) return;
2413 
2414         // Now attempt to match gapped substrings of length `gram'
2415         ++gram;
2416         currentValue = 0;
2417 
2418         auto Si = cast(F*) alloca(t.length * F.sizeof);
2419         Si[0 .. t.length] = 0;
2420         foreach (i; 0 .. s.length)
2421         {
2422             const si = s[i];
2423             F Sij_1 = 0;
2424             F Si_1j_1 = 0;
2425             auto kli = kl + i * t.length;
2426             for (size_t j = 0;;)
2427             {
2428                 const klij = kli[j];
2429                 const Si_1j = Si[j];
2430                 const tmp = klij + lambda * (Si_1j + Sij_1) - lambda2 * Si_1j_1;
2431                 // now update kl and currentValue
2432                 if (si == t[j])
2433                     currentValue += kli[j] = lambda2 * Si_1j_1;
2434                 else
2435                     kli[j] = 0;
2436                 // commit to Si
2437                 Si[j] = tmp;
2438                 if (++j == t.length) break;
2439                 // get ready for the popFront step; virtually increment j,
2440                 // so essentially stuffj_1 <-- stuffj
2441                 Si_1j_1 = Si_1j;
2442                 Sij_1 = tmp;
2443             }
2444         }
2445         currentValue /= pow(lambda, 2 * (gram + 1));
2446 
2447         version (none)
2448         {
2449             Si_1[0 .. t.length] = 0;
2450             kl[0 .. min(t.length, maxPerimeter + 1)] = 0;
2451             foreach (i; 1 .. min(s.length, maxPerimeter + 1))
2452             {
2453                 auto kli = kl + i * t.length;
2454                 assert(s.length > i);
2455                 const si = s[i];
2456                 auto kl_1i_1 = kl_1 + (i - 1) * t.length;
2457                 kli[0] = 0;
2458                 F lastS = 0;
2459                 foreach (j; 1 .. min(maxPerimeter - i + 1, t.length))
2460                 {
2461                     immutable j_1 = j - 1;
2462                     immutable tmp = kl_1i_1[j_1]
2463                         + lambda * (Si_1[j] + lastS)
2464                         - lambda2 * Si_1[j_1];
2465                     kl_1i_1[j_1] = float.nan;
2466                     Si_1[j_1] = lastS;
2467                     lastS = tmp;
2468                     if (si == t[j])
2469                     {
2470                         currentValue += kli[j] = lambda2 * lastS;
2471                     }
2472                     else
2473                     {
2474                         kli[j] = 0;
2475                     }
2476                 }
2477                 Si_1[t.length - 1] = lastS;
2478             }
2479             currentValue /= pow(lambda, 2 * (gram + 1));
2480             // get ready for the popFront computation
2481             swap(kl, kl_1);
2482         }
2483     }
2484 
2485     /**
2486     Returns: The gapped similarity at the current match length (initially
2487     1, grows with each call to $(D popFront)).
2488     */
frontGapWeightedSimilarityIncremental2489     @property F front() { return currentValue; }
2490 
2491     /**
2492     Returns: Whether there are more matches.
2493      */
emptyGapWeightedSimilarityIncremental2494     @property bool empty()
2495     {
2496         if (currentValue) return false;
2497         if (kl)
2498         {
2499             free(kl);
2500             kl = null;
2501         }
2502         return true;
2503     }
2504 }
2505 
2506 /**
2507 Ditto
2508  */
2509 GapWeightedSimilarityIncremental!(R, F) gapWeightedSimilarityIncremental(R, F)
2510 (R r1, R r2, F penalty)
2511 {
2512     return typeof(return)(r1, r2, penalty);
2513 }
2514 
2515 ///
2516 @system unittest
2517 {
2518     string[] s = ["Hello", "brave", "new", "world"];
2519     string[] t = ["Hello", "new", "world"];
2520     auto simIter = gapWeightedSimilarityIncremental(s, t, 1.0);
2521     assert(simIter.front == 3); // three 1-length matches
2522     simIter.popFront();
2523     assert(simIter.front == 3); // three 2-length matches
2524     simIter.popFront();
2525     assert(simIter.front == 1); // one 3-length match
2526     simIter.popFront();
2527     assert(simIter.empty);     // no more match
2528 }
2529 
@system unittest
{
    import std.conv : text;

    string[] s = ["Hello", "brave", "new", "world"];
    string[] t = ["Hello", "new", "world"];
    auto it = gapWeightedSimilarityIncremental(s, t, 1.0);
    assert(it.front == 3);                    // three 1-length matches
    it.popFront();
    assert(it.front == 3, text(it.front));    // three 2-length matches
    it.popFront();
    assert(it.front == 1);                    // one 3-length match
    it.popFront();
    assert(it.empty);                         // no more match

    // Nothing in common.
    s = ["Hello"];
    t = ["bye"];
    it = gapWeightedSimilarityIncremental(s, t, 0.5);
    assert(it.empty);

    // Identical single-element ranges.
    s = ["Hello"];
    t = ["Hello"];
    it = gapWeightedSimilarityIncremental(s, t, 0.5);
    assert(it.front == 1);
    it.popFront();
    assert(it.empty);

    // Only one element of the longer range matches.
    s = ["Hello", "world"];
    t = ["Hello"];
    it = gapWeightedSimilarityIncremental(s, t, 0.5);
    assert(it.front == 1);
    it.popFront();
    assert(it.empty);

    // One 2-gram match with a single gap, weighted by lambda.
    s = ["Hello", "world"];
    t = ["Hello", "yah", "world"];
    it = gapWeightedSimilarityIncremental(s, t, 0.5);
    assert(it.front == 2);                    // two 1-gram matches
    it.popFront();
    assert(it.front == 0.5, text(it.front));  // one 2-gram match, 1 gap
}
2571 
@system unittest
{
    alias Sim = GapWeightedSimilarityIncremental!(string[]);

    Sim sim = Sim(
        ["nyuk", "I", "have", "no", "chocolate", "giba"],
        ["wyda", "I", "have", "I", "have", "have", "I", "have", "hehe"],
        0.5);
    double[] witness = [ 7.0, 4.03125, 0, 0 ];
    foreach (e; sim)
    {
        assert(e == witness.front);
        witness.popFront();
    }

    witness = [ 3.0, 1.3125, 0.25 ];
    sim = Sim(
        ["I", "have", "no", "chocolate"],
        ["I", "have", "some", "chocolate"],
        0.5);
    foreach (e; sim)
    {
        assert(e == witness.front);
        witness.popFront();
    }
    assert(witness.empty);
}
2599 
2600 /**
2601 Computes the greatest common divisor of $(D a) and $(D b) by using
2602 an efficient algorithm such as $(HTTPS en.wikipedia.org/wiki/Euclidean_algorithm, Euclid's)
2603 or $(HTTPS en.wikipedia.org/wiki/Binary_GCD_algorithm, Stein's) algorithm.
2604 
2605 Params:
2606     T = Any numerical type that supports the modulo operator `%`. If
2607         bit-shifting `<<` and `>>` are also supported, Stein's algorithm will
2608         be used; otherwise, Euclid's algorithm is used as _a fallback.
2609 Returns:
2610     The greatest common divisor of the given arguments.
2611  */
2612 T gcd(T)(T a, T b)
2613     if (isIntegral!T)
2614 {
2615     static if (is(T == const) || is(T == immutable))
2616     {
2617         return gcd!(Unqual!T)(a, b);
2618     }
version(DigitalMars)2619     else version (DigitalMars)
2620     {
2621         static if (T.min < 0)
2622         {
2623             assert(a >= 0 && b >= 0);
2624         }
2625         while (b)
2626         {
2627             immutable t = b;
2628             b = a % b;
2629             a = t;
2630         }
2631         return a;
2632     }
2633     else
2634     {
2635         if (a == 0)
2636             return b;
2637         if (b == 0)
2638             return a;
2639 
2640         import core.bitop : bsf;
2641         import std.algorithm.mutation : swap;
2642 
2643         immutable uint shift = bsf(a | b);
2644         a >>= a.bsf;
2645 
2646         do
2647         {
2648             b >>= b.bsf;
2649             if (a > b)
2650                 swap(a, b);
2651             b -= a;
2652         } while (b);
2653 
2654         return a << shift;
2655     }
2656 }
2657 
///
@safe unittest
{
    // 490 = 2 * 5 * 7 * 7 and 385 = 5 * 7 * 11 share the factors 5 and 7.
    assert(gcd(2 * 5 * 7 * 7, 5 * 7 * 11) == 5 * 7);

    // Qualified (const) operands work as well.
    const int a = 5 * 13 * 23 * 23;
    const int b = 13 * 59;
    assert(gcd(a, b) == 13);
}
2665 
// This overload is for non-builtin numerical types like BigInt or
// user-defined types.
/// ditto
T gcd(T)(T a, T b)
    if (!isIntegral!T &&
        is(typeof(T.init % T.init)) &&
        is(typeof(T.init == 0 || T.init > 0)))
{
    import std.algorithm.mutation : swap;

    // Stein's algorithm needs shifts, subtraction, a parity test and swap;
    // a type lacking any of them falls back to Euclid's algorithm, which
    // only needs %.
    enum canUseBinaryGcd = is(typeof(() {
        T t, u;
        t <<= 1;
        t >>= 1;
        t -= u;
        bool b = (t & 1) == 0;
        swap(t, u);
    }));

    assert(a >= 0 && b >= 0);

    static if (canUseBinaryGcd)
    {
        // gcd(x, 0) == x.  Zero must be handled up front: it never becomes
        // odd, so the shifting loops below would never terminate on it.
        if (a == 0)
            return b;
        if (b == 0)
            return a;

        // Factor out the power of two shared by both operands.
        uint shift = 0;
        while ((a & 1) == 0 && (b & 1) == 0)
        {
            a >>= 1;
            b >>= 1;
            shift++;
        }

        // At most one operand is still even now.  The loop below keeps `a`
        // odd and only strips factors of two from `b`, so make sure an even
        // operand ends up in `b`.
        if ((a & 1) == 0)
            swap(a, b);

        do
        {
            assert((a & 1) != 0);
            while ((b & 1) == 0)
                b >>= 1;
            if (a > b)
                swap(a, b);
            b -= a;
        } while (b);

        return a << shift;
    }
    else
    {
        // The only thing we have is %; fallback to Euclidean algorithm.
        while (b != 0)
        {
            auto t = b;
            b = a % b;
            a = t;
        }
        return a;
    }
}
2721 
// Issue 7102: gcd must work for BigInt operands.
@system pure unittest
{
    import std.bigint : BigInt;

    auto lhs = BigInt("71_000_000_000_000_000_000");
    auto rhs = BigInt("31_000_000_000_000_000_000");
    auto expected = BigInt("1_000_000_000_000_000_000");
    assert(gcd(lhs, rhs) == expected);
}
2730 
@safe pure nothrow unittest
{
    // A number type offering only %, equality and comparison with int.
    // Without shifts, gcd has to take the Euclidean code path.
    struct ModOnly
    {
        int value;

        ModOnly opBinary(string op : "%")(ModOnly rhs)
        {
            return ModOnly(value % rhs.value);
        }

        int opEquals(ModOnly rhs) { return value == rhs.value; }
        int opEquals(int rhs) { return value == rhs; }

        int opCmp(int rhs)
        {
            if (value < rhs)
                return -1;
            if (value > rhs)
                return 1;
            return 0;
        }
    }

    // 2310 = 2 * 3 * 5 * 7 * 11 and 1309 = 7 * 11 * 17 share 7 * 11 = 77.
    assert(gcd(ModOnly(2310), ModOnly(1309)) == ModOnly(77));
}
2748 
// Element type of the precomputed -sin tables (twiddle factors) that Fft
// stores in `negSinLookup`.
// This is to make tweaking the speed/size vs. accuracy tradeoff easy,
// though floats seem accurate enough for all practical purposes, since
// they pass the "approxEqual(inverseFft(fft(arr)), arr)" test even for
// size 2 ^^ 22.
private alias lookup_t = float;
2754 
/**A class for performing fast Fourier transforms of power of two sizes.
 * This class encapsulates a large amount of state that is reusable when
 * performing multiple FFTs of sizes smaller than or equal to that specified
 * in the constructor.  This results in substantial speedups when performing
 * multiple FFTs with a known maximum size.  However,
 * a free function API is provided for convenience if you need to perform a
 * one-off FFT.
 *
 * The reusable state is a set of precomputed $(D -sin) tables, one for
 * every power-of-two transform length up to the constructed size.
 *
 * References:
 * $(HTTP en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm)
 */
final class Fft
{
    import core.bitop : bsf;
    import std.algorithm.iteration : map;
    import std.array : uninitializedArray;

private:
    // Twiddle-factor tables.  Row k (k >= 1) holds 2 ^^ k entries, with
    // negSinLookup[k][j] == -sin(2 * PI * j / 2 ^^ k).  The last row belongs
    // to the full constructed size.  Row 0 is only populated when the size
    // is 1, and the whole array stays null when the size is 0.
    immutable lookup_t[][] negSinLookup;

    // Checks that `range` is no longer than the size this object was
    // constructed for.  This is an assert, so it is compiled out in release
    // builds despite the "enforce" in the name.
    void enforceSize(R)(R range) const
    {
        import std.conv : text;
        assert(range.length <= size, text(
            "FFT size mismatch.  Expected ", size, ", got ", range.length));
    }

    // Recursive radix-2 step.  Transforms the even-indexed elements of
    // `range` into buf[0 .. $ / 2] and the odd-indexed ones into
    // buf[$ / 2 .. $], then merges the two halves in place with `butterfly`.
    // Half-transforms of length 2 are computed directly with slowFourier2.
    void fftImpl(Ret, R)(Stride!R range, Ret buf) const
    in
    {
        assert(range.length >= 4);
        assert(isPowerOf2(range.length));
    }
    body
    {
        auto recurseRange = range;
        // Take every other element: the view now covers the even indices.
        recurseRange.doubleSteps();

        if (buf.length > 4)
        {
            fftImpl(recurseRange, buf[0..$ / 2]);
            // Shift the view by one original element to get the odd indices.
            recurseRange.popHalf();
            fftImpl(recurseRange, buf[$ / 2..$]);
        }
        else
        {
            // Do this here instead of in another recursion to save on
            // recursion overhead.
            slowFourier2(recurseRange, buf[0..$ / 2]);
            recurseRange.popHalf();
            slowFourier2(recurseRange, buf[$ / 2..$]);
        }

        // Combine the even and odd half-transforms.
        butterfly(buf);
    }

    // This algorithm works by performing the even and odd parts of our FFT
    // using the "two for the price of one" method mentioned at
    // http://www.engineeringproductivitytools.com/stuff/T0001/PT10.HTM#Head521
    // by making the odd terms into the imaginary components of our new FFT,
    // and then using symmetry to recombine them.
    void fftImplPureReal(Ret, R)(R range, Ret buf) const
    in
    {
        assert(range.length >= 4);
        assert(isPowerOf2(range.length));
    }
    body
    {
        alias E = ElementType!R;

        // Converts odd indices of range to the imaginary components of
        // a range half the size.  The even indices become the real components.
        static if (isArray!R && isFloatingPoint!E)
        {
            // Then the memory layout of complex numbers provides a dirt
            // cheap way to convert.  This is a common case, so take advantage.
            auto oddsImag = cast(Complex!E[]) range;
        }
        else
        {
            // General case:  Use a higher order range.  We can assume
            // source.length is even because it has to be a power of 2.
            // Element i of this view is C(source[2 * i], source[2 * i + 1]).
            static struct OddToImaginary
            {
                R source;
                alias C = Complex!(CommonType!(E, typeof(buf[0].re)));

                @property
                {
                    C front()
                    {
                        return C(source[0], source[1]);
                    }

                    C back()
                    {
                        immutable n = source.length;
                        return C(source[n - 2], source[n - 1]);
                    }

                    typeof(this) save()
                    {
                        return typeof(this)(source.save);
                    }

                    bool empty()
                    {
                        return source.empty;
                    }

                    size_t length()
                    {
                        return source.length / 2;
                    }
                }

                void popFront()
                {
                    source.popFront();
                    source.popFront();
                }

                void popBack()
                {
                    source.popBack();
                    source.popBack();
                }

                C opIndex(size_t index)
                {
                    return C(source[index * 2], source[index * 2 + 1]);
                }

                typeof(this) opSlice(size_t lower, size_t upper)
                {
                    return typeof(this)(source[lower * 2 .. upper * 2]);
                }
            }

            auto oddsImag = OddToImaginary(range);
        }

        // Half-length complex FFT Z of z[m] = x[2m] + i * x[2m + 1], written
        // into the first half of buf.
        fft(oddsImag, buf[0..$ / 2]);
        auto evenFft = buf[0..$ / 2];
        auto oddFft = buf[$ / 2..$];
        immutable halfN = evenFft.length;
        // k == 0: E[0] = Z[0].re and O[0] = Z[0].im.
        oddFft[0].re = buf[0].im;
        oddFft[0].im = 0;
        evenFft[0].im = 0;
        // evenFft[0].re is already right b/c it's aliased with buf[0].re.

        // Unpack Z into the transforms of the even samples (E) and odd
        // samples (O):
        //   E[k] = (Z[k] + conj(Z[halfN - k])) / 2
        //   O[k] = (Z[k] - conj(Z[halfN - k])) / (2i)
        // Both are transforms of real data, hence conjugate-symmetric, so
        // only k <= halfN / 2 is computed and the rest is mirrored.  bufk and
        // bufnk are copied before evenFft (aliased with buf) is overwritten.
        foreach (k; 1 .. halfN / 2 + 1)
        {
            immutable bufk = buf[k];
            immutable bufnk = buf[buf.length / 2 - k];
            evenFft[k].re = 0.5 * (bufk.re + bufnk.re);
            evenFft[halfN - k].re = evenFft[k].re;
            evenFft[k].im = 0.5 * (bufk.im - bufnk.im);
            evenFft[halfN - k].im = -evenFft[k].im;

            oddFft[k].re = 0.5 * (bufk.im + bufnk.im);
            oddFft[halfN - k].re = oddFft[k].re;
            oddFft[k].im = 0.5 * (bufnk.re - bufk.re);
            oddFft[halfN - k].im = -oddFft[k].im;
        }

        butterfly(buf);
    }

    // In-place radix-2 combination step.  buf[0 .. n / 2] holds the
    // transform of the even samples and buf[n / 2 .. n] that of the odd
    // samples.  For each k < n / 2 it computes
    //   buf[k]         = lower + w^k * upper
    //   buf[k + n / 2] = lower - w^k * upper
    // with w^k = cos(2 PI k / n) - i sin(2 PI k / n), read from the table row
    // for length n.  The loop handles two k per iteration, so n is presumably
    // at least 4 (the callers' contracts require that).
    void butterfly(R)(R buf) const
    in
    {
        assert(isPowerOf2(buf.length));
    }
    body
    {
        immutable n = buf.length;
        immutable localLookup = negSinLookup[bsf(n)];
        assert(localLookup.length == n);

        // Index arithmetic modulo n (n is a power of 2).
        immutable cosMask = n - 1;
        immutable cosAdd = n / 4 * 3;

        lookup_t negSinFromLookup(size_t index) pure nothrow
        {
            return localLookup[index];
        }

        lookup_t cosFromLookup(size_t index) pure nothrow
        {
            // cos is just -sin shifted by PI * 3 / 2.
            return localLookup[(index + cosAdd) & cosMask];
        }

        immutable halfLen = n / 2;

        // This loop is unrolled and the two iterations are interleaved
        // relative to the textbook FFT to increase ILP.  This gives roughly 5%
        // speedups on DMD.
        for (size_t k = 0; k < halfLen; k += 2)
        {
            immutable cosTwiddle1 = cosFromLookup(k);
            immutable sinTwiddle1 = negSinFromLookup(k);
            immutable cosTwiddle2 = cosFromLookup(k + 1);
            immutable sinTwiddle2 = negSinFromLookup(k + 1);

            immutable realLower1 = buf[k].re;
            immutable imagLower1 = buf[k].im;
            immutable realLower2 = buf[k + 1].re;
            immutable imagLower2 = buf[k + 1].im;

            immutable upperIndex1 = k + halfLen;
            immutable upperIndex2 = upperIndex1 + 1;
            immutable realUpper1 = buf[upperIndex1].re;
            immutable imagUpper1 = buf[upperIndex1].im;
            immutable realUpper2 = buf[upperIndex2].re;
            immutable imagUpper2 = buf[upperIndex2].im;

            // Complex product twiddle * upper; sinTwiddle already holds -sin.
            immutable realAdd1 = cosTwiddle1 * realUpper1
                               - sinTwiddle1 * imagUpper1;
            immutable imagAdd1 = sinTwiddle1 * realUpper1
                               + cosTwiddle1 * imagUpper1;
            immutable realAdd2 = cosTwiddle2 * realUpper2
                               - sinTwiddle2 * imagUpper2;
            immutable imagAdd2 = sinTwiddle2 * realUpper2
                               + cosTwiddle2 * imagUpper2;

            buf[k].re += realAdd1;
            buf[k].im += imagAdd1;
            buf[k + 1].re += realAdd2;
            buf[k + 1].im += imagAdd2;

            buf[upperIndex1].re = realLower1 - realAdd1;
            buf[upperIndex1].im = imagLower1 - imagAdd1;
            buf[upperIndex2].re = realLower2 - realAdd2;
            buf[upperIndex2].im = imagLower2 - imagAdd2;
        }
    }

    // This constructor is used within this module for allocating the
    // buffer space elsewhere besides the GC heap.  It's definitely **NOT**
    // part of the public API and definitely **IS** subject to change.
    //
    // Also, this is unsafe because the memSpace buffer will be cast
    // to immutable.
    //
    // memSpace.length / 2 is taken as the transform size; the tables are
    // carved out of memSpace and are not copied, so memSpace must outlive
    // this object.
    public this(lookup_t[] memSpace)  // Public b/c of bug 4636.
    {
        immutable size = memSpace.length / 2;

        /* Create a lookup table of all negative sine values at a resolution of
         * size and all smaller power of two resolutions.  This may seem
         * inefficient, but having all the lookups be next to each other in
         * memory at every level of iteration is a huge win performance-wise.
         */
        if (size == 0)
        {
            return;
        }

        assert(isPowerOf2(size),
            "Can only do FFTs on ranges with a size that is a power of two.");

        // One row per power of two from 2 ^^ 0 up to size.
        auto table = new lookup_t[][bsf(size) + 1];

        // The full-resolution row occupies the last `size` elements of
        // memSpace; the smaller rows are carved from the first half.
        table[$ - 1] = memSpace[$ - size..$];
        memSpace = memSpace[0 .. size];

        auto lastRow = table[$ - 1];
        lastRow[0] = 0;  // -sin(0) == 0.
        foreach (ptrdiff_t i; 1 .. size)
        {
            // The hard coded cases are for improved accuracy and to prevent
            // annoying non-zeroness when stuff should be zero.

            if (i == size / 4)
                lastRow[i] = -1;  // -sin(pi / 2) == -1.
            else if (i == size / 2)
                lastRow[i] = 0;   // -sin(pi) == 0.
            else if (i == size * 3 / 4)
                lastRow[i] = 1;  // -sin(pi * 3 / 2) == 1
            else
                lastRow[i] = -sin(i * 2.0L * PI / size);
        }

        // Fill in all the other rows with strided versions.
        // Row i takes every (size / 2 ^^ i)-th entry of the last row, i.e.
        // 2 ^^ i entries, and is stored at the end of what remains of
        // memSpace.  Row 0 is not filled here.
        foreach (i; 1 .. table.length - 1)
        {
            immutable strideLength = size / (2 ^^ i);
            auto strided = Stride!(lookup_t[])(lastRow, strideLength);
            table[i] = memSpace[$ - strided.length..$];
            memSpace = memSpace[0..$ - strided.length];

            size_t copyIndex;
            foreach (elem; strided)
            {
                table[i][copyIndex++] = elem;
            }
        }

        negSinLookup = cast(immutable) table;
    }

public:
    /**Create an $(D Fft) object for computing fast Fourier transforms of
     * power of two sizes of $(D size) or smaller.  $(D size) must be a
     * power of two.
     */
    this(size_t size)
    {
        // Allocate all twiddle factor buffers in one contiguous block so that,
        // when one is done being used, the next one is next in cache.
        auto memSpace = uninitializedArray!(lookup_t[])(2 * size);
        this(memSpace);
    }

    // The largest transform length supported, i.e. the length of the last
    // table row; 0 when constructed with size 0.
    @property size_t size() const
    {
        return (negSinLookup is null) ? 0 : negSinLookup[$ - 1].length;
    }

    /**Compute the Fourier transform of range using the $(BIGOH N log N)
     * Cooley-Tukey Algorithm.  $(D range) must be a random-access range with
     * slicing and a length equal to $(D size) as provided at the construction of
     * this object.  The contents of range can be either  numeric types,
     * which will be interpreted as pure real values, or complex types with
     * properties or members $(D .re) and $(D .im) that can be read.
     *
     * Note:  Pure real FFTs are automatically detected and the relevant
     *        optimizations are performed.
     *
     * Returns:  An array of complex numbers representing the transformed data in
     *           the frequency domain.
     *
     * Conventions: The exponent is negative and the factor is one,
     *              i.e., output[j] := sum[ exp(-2 PI i j k / N) input[k] ].
     */
    Complex!F[] fft(F = double, R)(R range) const
        if (isFloatingPoint!F && isRandomAccessRange!R)
    {
        // Only range.length <= size is actually asserted.
        enforceSize(range);
        Complex!F[] ret;
        if (range.length == 0)
        {
            return ret;
        }

        // Don't waste time initializing the memory for ret.
        ret = uninitializedArray!(Complex!F[])(range.length);

        fft(range,  ret);
        return ret;
    }

    /**Same as the overload, but allows for the results to be stored in a user-
     * provided buffer.  The buffer must be of the same length as range, must be
     * a random-access range, must have slicing, and must contain elements that are
     * complex-like.  This means that they must have a .re and a .im member or
     * property that can be both read and written and are floating point numbers.
     */
    void fft(Ret, R)(R range, Ret buf) const
        if (isRandomAccessRange!Ret && isComplexLike!(ElementType!Ret) && hasSlicing!Ret)
    {
        assert(buf.length == range.length);
        enforceSize(range);

        // Lengths 0, 1 and 2 are handled directly; longer inputs go through
        // the real-only or the general recursive implementation.
        if (range.length == 0)
        {
            return;
        }
        else if (range.length == 1)
        {
            buf[0] = range[0];
            return;
        }
        else if (range.length == 2)
        {
            slowFourier2(range, buf);
            return;
        }
        else
        {
            alias E = ElementType!R;
            static if (is(E : real))
            {
                return fftImplPureReal(range, buf);
            }
            else
            {
                // NOTE(review): `is(R : Stride!R)` looks like it can never be
                // true (Stride!R is a different type from R), so every range
                // appears to take the unit-stride wrapping branch — confirm.
                static if (is(R : Stride!R))
                    return fftImpl(range, buf);
                else
                    return fftImpl(Stride!R(range, 1), buf);
            }
        }
    }

    /**
     * Computes the inverse Fourier transform of a range.  The range must be a
     * random access range with slicing, have a length equal to the size
     * provided at construction of this object, and contain elements that are
     * either of type std.complex.Complex or have essentially
     * the same compile-time interface.
     *
     * Returns:  The time-domain signal.
     *
     * Conventions: The exponent is positive and the factor is 1/N, i.e.,
     *              output[j] := (1 / N) sum[ exp(+2 PI i j k / N) input[k] ].
     */
    Complex!F[] inverseFft(F = double, R)(R range) const
        if (isRandomAccessRange!R && isComplexLike!(ElementType!R) && isFloatingPoint!F)
    {
        enforceSize(range);
        Complex!F[] ret;
        if (range.length == 0)
        {
            return ret;
        }

        // Don't waste time initializing the memory for ret.
        ret = uninitializedArray!(Complex!F[])(range.length);

        inverseFft(range, ret);
        return ret;
    }

    /**
     * Inverse FFT that allows a user-supplied buffer to be provided.  The buffer
     * must be a random access range with slicing, and its elements
     * must be some complex-like type.
     */
    void inverseFft(Ret, R)(R range, Ret buf) const
        if (isRandomAccessRange!Ret && isComplexLike!(ElementType!Ret) && hasSlicing!Ret)
    {
        enforceSize(range);

        // Swap trick: swapping re/im of the input, running the forward
        // transform and swapping re/im of the output yields the transform
        // with the positive exponent.  The 1/N factor is applied below.
        auto swapped = map!swapRealImag(range);
        fft(swapped,  buf);

        // Swap back and scale by 1/N.
        immutable lenNeg1 = 1.0 / buf.length;
        foreach (ref elem; buf)
        {
            immutable temp = elem.re * lenNeg1;
            elem.re = elem.im * lenNeg1;
            elem.im = temp;
        }
    }
}
3203 
// This mixin creates an Fft object in the scope it's mixed into such that all
// memory owned by the object is deterministically destroyed at the end of that
// scope.
//
// Requirements on the mixing-in scope: a variable named `range` with a
// `length`, and visibility of `lookup_t`, `Fft` and `scoped`.  It declares
// `lookupBuf` (malloc'ed table storage of 2 * range.length elements, freed
// on scope exit) and `fftObj` (the scoped Fft built over that storage).
private enum string MakeLocalFft = q{
    import core.stdc.stdlib : free, malloc;
    import core.exception : onOutOfMemoryError;

    // Guard the byte-count computation below against size_t overflow.
    if (range.length > size_t.max / (2 * lookup_t.sizeof))
        onOutOfMemoryError();

    auto lookupBuf = (cast(lookup_t*) malloc(range.length * 2 * lookup_t.sizeof))
                     [0 .. 2 * range.length];
    // malloc(0) may legitimately return null, so a null result is only an
    // allocation failure when something was actually requested.
    if (!lookupBuf.ptr && range.length != 0)
        onOutOfMemoryError();

    scope(exit) free(cast(void*) lookupBuf.ptr);
    auto fftObj = scoped!Fft(lookupBuf);
};
3219 
/**Convenience wrappers that build a temporary $(D Fft) object sized for
 * $(D range), run the forward or inverse transform on it and return the
 * result.  Handy for one-off FFTs.
 *
 * Note:  Besides being convenient, these wrappers are slightly cheaper than
 *        constructing an Fft object by hand for a single transform, because
 *        the object and its lookup tables are released deterministically
 *        before the function returns.
 */
Complex!F[] fft(F = double, R)(R range)
{
    // Declares `fftObj`, whose table memory is freed when this scope ends.
    mixin(MakeLocalFft);
    auto spectrum = fftObj.fft!(F, R)(range);
    return spectrum;
}
3233 
/// ditto
void fft(Ret, R)(R range, Ret buf)
{
    // Declares `fftObj`, whose table memory is freed when this scope ends.
    mixin(MakeLocalFft);
    fftObj.fft!(Ret, R)(range, buf);
}
3240 
/// ditto
Complex!F[] inverseFft(F = double, R)(R range)
{
    // Declares `fftObj`, whose table memory is freed when this scope ends.
    mixin(MakeLocalFft);
    auto signal = fftObj.inverseFft!(F, R)(range);
    return signal;
}
3247 
/// ditto
void inverseFft(Ret, R)(R range, Ret buf)
{
    // Declares `fftObj`, whose table memory is freed when this scope ends.
    mixin(MakeLocalFft);
    fftObj.inverseFft!(Ret, R)(range, buf);
}
3254 
@system unittest
{
    import std.algorithm;
    import std.conv;
    import std.range;

    // Expected spectra were obtained from R and Octave.
    auto input = [1,2,3,4,5,6,7,8];
    auto spectrum = fft(input);
    assert(approxEqual(map!"a.re"(spectrum),
        [36.0, -4, -4, -4, -4, -4, -4, -4]));
    assert(approxEqual(map!"a.im"(spectrum),
        [0, 9.6568, 4, 1.6568, 0, -1.6568, -4, -9.6568]));

    // A non-array input range exercises the generic (non-cast) code path.
    auto reversedSpectrum = fft(retro(input));
    assert(approxEqual(map!"a.re"(reversedSpectrum),
        [36.0, 4, 4, 4, 4, 4, 4, 4]));
    assert(approxEqual(map!"a.im"(reversedSpectrum),
        [0, -9.6568, -4, -1.6568, 0, 1.6568, 4, 9.6568]));

    // Single-precision input must agree with the double-precision result.
    auto floatSpectrum = fft(to!(float[])(input));
    assert(approxEqual(map!"a.re"(spectrum), map!"a.re"(floatSpectrum)));
    assert(approxEqual(map!"a.im"(spectrum), map!"a.im"(floatSpectrum)));

    // Genuinely complex input.
    alias C = Complex!float;
    auto complexInput = [C(1,2), C(3,4), C(5,6), C(7,8), C(9,10),
        C(11,12), C(13,14), C(15,16)];
    auto complexSpectrum = fft(complexInput);
    assert(approxEqual(map!"a.re"(complexSpectrum),
        [64.0, -27.3137, -16, -11.3137, -8, -4.6862, 0, 11.3137]));
    assert(approxEqual(map!"a.im"(complexSpectrum),
        [72, 11.3137, 0, -4.686, -8, -11.3137, -16, -27.3137]));

    // Round trips through the inverse transform.
    auto roundTrip = inverseFft(spectrum);
    assert(approxEqual(map!"a.re"(roundTrip), input));
    assert(reduce!max(map!"a.im"(roundTrip)) < 1e-10);

    auto complexRoundTrip = inverseFft(complexSpectrum);
    assert(approxEqual(map!"a.re"(complexRoundTrip), map!"a.re"(complexInput)));
    assert(approxEqual(map!"a.im"(complexRoundTrip), map!"a.im"(complexInput)));

    // Lengths 0, 1 and 2 take special-cased paths.
    ushort[] none;
    assert(fft(none) == null);
    assert(inverseFft(fft(none)) == null);

    real[] single = [4.5L];
    auto singleFft = fft(single);
    assert(singleFft.length == 1);
    assert(singleFft[0].re == 4.5L);
    assert(singleFft[0].im == 0);

    auto singleInv = inverseFft(singleFft);
    assert(singleInv.length == 1);
    assert(approxEqual(singleInv[0].re, 4.5));
    assert(approxEqual(singleInv[0].im, 0));

    long[2] pair = [8, 4];
    auto pairFft = fft(pair[]);
    assert(pairFft.length == 2);
    assert(approxEqual(pairFft[0].re, 12));
    assert(approxEqual(pairFft[0].im, 0));
    assert(approxEqual(pairFft[1].re, 4));
    assert(approxEqual(pairFft[1].im, 0));

    auto pairInv = inverseFft(pairFft);
    assert(approxEqual(pairInv[0].re, 8));
    assert(approxEqual(pairInv[0].im, 0));
    assert(approxEqual(pairInv[1].re, 4));
    assert(approxEqual(pairInv[1].im, 0));
}
3324 
// Returns a copy of `input` with its real and imaginary parts exchanged.
// Fft.inverseFft applies this to its input and, in effect, to its output, to
// obtain the inverse transform from the forward one.
C swapRealImag(C)(C input)
{
    auto newRe = input.im;
    auto newIm = input.re;
    return C(newRe, newIm);
}
3331 
3332 private:
// A strided view of a random-access range: element i of the view is
// range[i * nSteps].  std.range.stride is not used because its step length
// cannot be changed on the fly, and because this type relies on shortcuts
// that are only valid when lengths and steps are powers of 2.
struct Stride(R)
{
    import core.bitop : bsf;

    Unqual!R range;   // underlying range; the view starts at range[0]
    size_t _nSteps;   // distance between consecutive elements of the view
    size_t _length;   // number of elements in the view
    alias E = ElementType!(R);

    this(R range, size_t nStepsIn)
    {
        this.range = range;
        _nSteps = nStepsIn;
        // Round up: a trailing partial step still contributes one element.
        _length = (range.length + _nSteps - 1) / _nSteps;
    }

    @property size_t length() const
    {
        return _length;
    }

    @property typeof(this) save()
    {
        typeof(this) copy = this;
        copy.range = copy.range.save;
        return copy;
    }

    E opIndex(size_t index)
    {
        return range[index * _nSteps];
    }

    @property E front()
    {
        return range[0];
    }

    void popFront()
    {
        // Not enough left for a full step: the view is exhausted.
        if (range.length < _nSteps)
        {
            range = range[0 .. 0];
            _length = 0;
            return;
        }
        range = range[_nSteps .. range.length];
        --_length;
    }

    // Advances the start of the view by half a step (e.g. from the even to
    // the odd elements after doubleSteps).  The length is left unchanged.
    void popHalf()
    {
        range = range[_nSteps / 2 .. range.length];
    }

    @property bool empty() const
    {
        return _length == 0;
    }

    @property size_t nSteps() const
    {
        return _nSteps;
    }

    // Keeps every other element of the view.  Halving the length rounds
    // down, which is exact only when the current length is even.
    void doubleSteps()
    {
        _nSteps *= 2;
        _length /= 2;
    }

    @property size_t nSteps(size_t newVal)
    {
        _nSteps = newVal;

        // Shifting by bsf is a few cycles faster than dividing, but only
        // equivalent to it when the step is a power of 2.
        _length = (range.length + _nSteps - 1) >> bsf(_nSteps);
        return newVal;
    }
}
3418 
// Two-point DFT: buf = [range[0] + range[1], range[0] - range[1]].  Used as
// the FFT recursion's base case; hard-coding it is much faster than a
// generic slow DFT and appears to be the best choice of base case.
// (A one-point transform is simply buf[0] = range[0].)
void slowFourier2(Ret, R)(R range, Ret buf)
{
    assert(range.length == 2);
    assert(buf.length == 2);

    auto x0 = range[0];
    auto x1 = range[1];
    buf[0] = x0 + x1;
    buf[1] = x0 - x1;
}
3429 
// Hard-coded four-point DFT with the forward (negative-exponent) convention
// and no scaling.  Kept as an alternative base case; it performed worse
// than the two-point one.
void slowFourier4(Ret, R)(R range, Ret buf)
{
    alias C = ElementType!Ret;

    assert(range.length == 4);
    assert(buf.length == 4);

    // The imaginary unit; the four twiddle factors are 1, -i, -1 and i.
    auto imagUnit = C(0, 1);

    buf[0] = range[0] + range[1] + range[2] + range[3];
    buf[1] = range[0] - range[1] * imagUnit - range[2] + range[3] * imagUnit;
    buf[2] = range[0] - range[1] + range[2] - range[3];
    buf[3] = range[0] + range[1] * imagUnit - range[2] - range[3] * imagUnit;
}
3443 
// Returns the largest power of two that is <= num, i.e. num with every bit
// below its most significant set bit cleared.  Zero has no such power and is
// returned unchanged.  NOTE(review): negative inputs are not supported.
N roundDownToPowerOf2(N)(N num)
if (isScalarType!N && !isFloatingPoint!N)
{
    import core.bitop : bsr;

    // core.bitop.bsr's result is undefined for 0, so never call it with 0.
    if (num == 0)
        return num;

    // Keep only the most significant set bit.  The cast undoes integral
    // promotion for types narrower than int; the result is <= num, so it fits.
    return cast(N)(num & (cast(N) 1 << bsr(num)));
}
3450 
@safe unittest
{
    // 7 == 0b111 keeps only its top bit; an exact power of two is unchanged.
    assert(4 == roundDownToPowerOf2(7));
    assert(4 == roundDownToPowerOf2(4));
}
3456 
// True when T exposes `re` and `im` members or properties.  The FFT code uses
// this to accept std.complex.Complex as well as look-alike types.
enum bool isComplexLike(T) = is(typeof(T.init.re)) && is(typeof(T.init.im));
3462 
@safe unittest
{
    alias Cd = Complex!double;
    static assert( isComplexLike!Cd);   // has both .re and .im
    static assert(!isComplexLike!uint); // has neither
}
3468