1 // Written in the D programming language.
2 
3 /**
4 This module is a port of a growing fragment of the $(D_PARAM numeric)
5 header in Alexander Stepanov's $(LINK2 https://en.wikipedia.org/wiki/Standard_Template_Library,
6 Standard Template Library), with a few additions.
7 
8 Macros:
9 Copyright: Copyright Andrei Alexandrescu 2008 - 2009.
10 License:   $(HTTP www.boost.org/LICENSE_1_0.txt, Boost License 1.0).
11 Authors:   $(HTTP erdani.org, Andrei Alexandrescu),
12                    Don Clugston, Robert Jacques, Ilya Yaroshenko
13 Source:    $(PHOBOSSRC std/numeric.d)
14 */
15 /*
16          Copyright Andrei Alexandrescu 2008 - 2009.
17 Distributed under the Boost Software License, Version 1.0.
18    (See accompanying file LICENSE_1_0.txt or copy at
19          http://www.boost.org/LICENSE_1_0.txt)
20 */
21 module std.numeric;
22 
23 import std.complex;
24 import std.math;
25 import core.math : fabs, ldexp, sin, sqrt;
26 import std.range.primitives;
27 import std.traits;
28 import std.typecons;
29 
/// Bit flags that select the storage format of a CustomFloat.
public enum CustomFloatFlags
{
    /// Adds a sign bit to allow for signed numbers.
    signed = 1 << 0,

    /**
     * Store values in normalized form by default. The actual precision of the
     * significand is extended by 1 bit by assuming an implicit leading bit of 1
     * instead of 0, i.e. `1.nnnn` instead of `0.nnnn`.
     * True for all $(LINK2 https://en.wikipedia.org/wiki/IEEE_floating_point, IEEE754) types.
     */
    storeNormalized = 1 << 1,

    /**
     * Stores the significand in $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Denormalized_numbers,
     * IEEE754 denormalized) form when the exponent is 0. Required to express the value 0.
     */
    allowDenorm = 1 << 2,

    /**
     * Allows the storage of $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Positive_and_negative_infinity,
     * IEEE754 _infinity) values.
     */
    infinity = 1 << 3,

    /// Allows the storage of $(LINK2 https://en.wikipedia.org/wiki/NaN, IEEE754 Not a Number) values.
    nan = 1 << 4,

    /**
     * If set, select an exponent bias such that max_exp = 1,
     * i.e. so that the maximum value is >= 1.0 and < 2.0.
     * Ignored if the exponent bias is manually specified.
     */
    probability = 1 << 5,

    /// If set, unsigned custom floats are assumed to be negative.
    negativeUnsigned = 1 << 6,

    /**
     * If set, 0 is the only allowed $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Denormalized_numbers,
     * IEEE754 denormalized) number.
     * Requires allowDenorm and storeNormalized, so both bits are part of this value.
     */
    allowDenormZeroOnly = (1 << 7) | allowDenorm | storeNormalized,

    /// Include _all of the $(LINK2 https://en.wikipedia.org/wiki/IEEE_floating_point, IEEE754) options.
    ieee = signed | storeNormalized | allowDenorm | infinity | nan,

    /// Include none of the above options.
    none = 0
}
81 
// Translates a total storage width in bits into the parameter sequence built
// by the (precision, exponentWidth, flags) overload of this template.
private template CustomFloatParams(uint bits)
{
    // All IEEE flags, except that the 80-bit layout has storeNormalized
    // toggled off by the XOR.
    enum CustomFloatFlags flags = CustomFloatFlags.ieee
                ^ ((bits != 80) ? CustomFloatFlags.none : CustomFloatFlags.storeNormalized);

    // (precision, exponentWidth) for each supported width; for any other
    // width the eponymous alias is simply not declared.
    static if (bits == 8)
        alias CustomFloatParams = CustomFloatParams!( 4,  3, flags);
    else static if (bits == 16)
        alias CustomFloatParams = CustomFloatParams!(10,  5, flags);
    else static if (bits == 32)
        alias CustomFloatParams = CustomFloatParams!(23,  8, flags);
    else static if (bits == 64)
        alias CustomFloatParams = CustomFloatParams!(52, 11, flags);
    else static if (bits == 80)
        alias CustomFloatParams = CustomFloatParams!(64, 15, flags);
}
92 
// Builds the full argument sequence (precision, exponentWidth, flags, bias)
// for the CustomFloat struct, deriving the exponent bias from the flags.
private template CustomFloatParams(uint precision, uint exponentWidth, CustomFloatFlags flags)
{
    import std.meta : AliasSeq;

    // Named pieces of the bias formula.
    enum bool isProbability = (flags & flags.probability) != 0;
    enum bool hasSpecials   = (flags & (flags.nan | flags.infinity)) != 0;

    alias CustomFloatParams =
        AliasSeq!(
            precision,
            exponentWidth,
            flags,
            // 2^(exponentWidth-1), or 2^exponentWidth for probability formats,
            // minus one for each of: inf/nan support, probability.
            (1 << (exponentWidth - !isProbability)) - hasSpecials - isProbability
        );
}
105 
/**
 * Allows user code to define custom floating-point formats. These formats are
 * for storage only; every operation on them is performed by first implicitly
 * extracting the value to `real`. After the operation is completed, the
 * result can be stored in a custom floating-point value via assignment.
 */
template CustomFloat(uint bits)
if (bits == 80 || bits == 64 || bits == 32 || bits == 16 || bits == 8)
{
    alias CustomFloat = CustomFloat!(CustomFloatParams!bits);
}
117 
/// ditto
// The format must be non-empty and, counting the optional sign bit, must fill
// a whole number of bytes.
template CustomFloat(uint precision, uint exponentWidth, CustomFloatFlags flags = CustomFloatFlags.ieee)
if (precision + exponentWidth > 0 && ((flags & flags.signed) + precision + exponentWidth) % 8 == 0)
{
    alias CustomFloat = CustomFloat!(CustomFloatParams!(precision, exponentWidth, flags));
}
124 
///
@safe unittest
{
    import std.math.trigonometry : sin, cos;

    // Define 16-bit floating point values
    CustomFloat!16                                x;     // Using the number of bits
    CustomFloat!(10, 5)                           y;     // Using the precision and exponent width
    CustomFloat!(10, 5,CustomFloatFlags.ieee)     z;     // Using the precision, exponent width and format flags
    CustomFloat!(10, 5,CustomFloatFlags.ieee, 15) w;     // Using the precision, exponent width, format flags and exponent offset bias

    // Use the 16-bit floats mostly like normal numbers
    w = x*y - 1;

    // Function calls require conversion
    z = sin(+x)           + cos(+y);                     // Use unary plus to concisely convert to a real
    z = sin(x.get!float)  + cos(y.get!float);            // Or use get!T
    z = sin(cast(float) x) + cos(cast(float) y);           // Or use cast(T) to explicitly convert

    // Define an 8-bit custom float for storing probabilities
    alias Probability = CustomFloat!(4, 4, CustomFloatFlags.ieee^CustomFloatFlags.probability^CustomFloatFlags.signed );
    auto p = Probability(0.5);
}
148 
// Helper for converting built-in numeric types to custom floats: overlays a
// value of type F with the CustomFloat that has the matching bit layout.
private union ToBinary(F)
if (is(F == real) || is(typeof(CustomFloatParams!(F.sizeof*8))))
{
    import std.algorithm.comparison : min;

    F set;

    // On Linux or Mac, 80-bit reals are padded; capping the width at 80
    // ignores that padding.
    CustomFloat!(CustomFloatParams!(min(F.sizeof*8, 80))) get;

    // Convert F to the correct binary type: write through `set`, read the
    // same bits back through `get`.
    static typeof(get) opCall(F value)
    {
        ToBinary overlay;
        overlay.set = value;
        return overlay.get;
    }
    alias get this;
}
169 
170 /// ditto
171 struct CustomFloat(uint             precision,  // fraction bits (23 for float)
172                    uint             exponentWidth,  // exponent bits (8 for float)  Exponent width
173                    CustomFloatFlags flags,
174                    uint             bias)
175 if (isCorrectCustomFloat(precision, exponentWidth, flags))
176 {
177     import std.bitmanip : bitfields;
178     import std.meta : staticIndexOf;
179 private:
180     // get the correct unsigned bitfield type to support > 32 bits
uType(uint bits)181     template uType(uint bits)
182     {
183         static if (bits <= size_t.sizeof*8)  alias uType = size_t;
184         else                                alias uType = ulong ;
185     }
186 
187     // get the correct signed   bitfield type to support > 32 bits
sType(uint bits)188     template sType(uint bits)
189     {
190         static if (bits <= ptrdiff_t.sizeof*8-1) alias sType = ptrdiff_t;
191         else                                    alias sType = long;
192     }
193 
    // Storage types of the significand and exponent fields, plus the signed
    // type used when an exponent has to go negative (see min_exp).
    alias T_sig = uType!precision;
    alias T_exp = uType!exponentWidth;
    alias T_signed_exp = sType!exponentWidth;

    // Shorthand for the flags enum used throughout this struct.
    alias Flags = CustomFloatFlags;
199 
200     // Perform IEEE rounding with round to nearest detection
roundedShift(T,U)201     void roundedShift(T,U)(ref T sig, U shift)
202     {
203         if (shift >= T.sizeof*8)
204         {
205             // avoid illegal shift
206             sig = 0;
207         }
208         else if (sig << (T.sizeof*8 - shift) == cast(T) 1uL << (T.sizeof*8 - 1))
209         {
210             // round to even
211             sig >>= shift;
212             sig  += sig & 1;
213         }
214         else
215         {
216             sig >>= shift - 1;
217             sig  += sig & 1;
218             // Perform standard rounding
219             sig >>= 1;
220         }
221     }
222 
    // Convert the current value to signed exponent, normalized form.
    //
    // Params:
    //   sig = receives the stored significand, shifted left toward the most
    //         significant bits of T
    //   exp = receives the stored exponent with `bias` subtracted; exp.max is
    //         written as the marker for infinity/NaN and exp.min as the
    //         marker for zero (fromNormalized tests for the same markers)
    void toNormalized(T,U)(ref T sig, ref U exp)
    {
        sig = significand;
        // Left shift that moves a `precision`-bit field to the top of T
        auto shift = (T.sizeof*8) - precision;
        exp = exponent;
        static if (flags&(Flags.infinity|Flags.nan))
        {
            // Handle inf or nan
            if (exp == exponent_max)
            {
                exp = exp.max;
                sig <<= shift;
                static if (flags&Flags.storeNormalized)
                {
                    // Save inf/nan in denormalized format
                    sig >>= 1;
                    sig  += cast(T) 1uL << (T.sizeof*8 - 1);
                }
                return;
            }
        }
        // Taken for every value when the format has no implicit leading bit,
        // and for exponent 0 when denormals are allowed.
        if ((~flags&Flags.storeNormalized) ||
            // Convert denormalized form to normalized form
            ((flags&Flags.allowDenorm) && exp == 0))
        {
            if (sig > 0)
            {
                import core.bitop : bsr;
                // Extra shift derived from the position of the highest set
                // bit; the exponent is lowered to compensate.
                auto shift2 = precision - bsr(sig);
                exp  -= shift2-1;
                shift += shift2;
            }
            else                                // value = 0.0
            {
                exp = exp.min;
                return;
            }
        }
        sig <<= shift;
        exp -= bias;
    }
265 
    // Set the current value from signed exponent, normalized form.
    //
    // Rewrites `sig` and `exp` in place from the intermediate form produced by
    // toNormalized into the values to store in this format's significand and
    // exponent fields. The fields themselves are not written here; the callers
    // visible in this struct assign them afterwards.
    //
    // Params:
    //   sig = in: significand in the intermediate form; out: field value
    //   exp = in: unbiased exponent, or exp.max (infinity/NaN) or exp.min
    //         (zero); out: biased exponent field value
    void fromNormalized(T,U)(ref T sig, ref U exp)
    {
        // Right shift that brings the top of T down to a `precision`-bit field
        auto shift = (T.sizeof*8) - precision;
        if (exp == exp.max)
        {
            // infinity or nan
            exp = exponent_max;
            // Undo the one-bit right shift toNormalized applies to inf/nan
            // when storeNormalized is set.
            static if (flags & Flags.storeNormalized)
                sig <<= 1;

            // convert back to normalized form
            static if (~flags & Flags.infinity)
                // No infinity support?
                assert(sig != 0, "Infinity floating point value assigned to a "
                        ~ typeof(this).stringof ~ " (no infinity support).");

            static if (~flags & Flags.nan)  // No NaN support?
                assert(sig == 0, "NaN floating point value assigned to a " ~
                        typeof(this).stringof ~ " (no nan support).");
            sig >>= shift;
            return;
        }
        if (exp == exp.min)     // 0.0
        {
             exp = 0;
             sig = 0;
             return;
        }

        exp += bias;
        if (exp <= 0)
        {
            // The biased exponent does not fit a normal number.
            static if ((flags&Flags.allowDenorm) ||
                       // Convert from normalized form to denormalized
                       (~flags&Flags.storeNormalized))
            {
                // Widen the final right shift by the amount the exponent is
                // below zero.
                shift += -exp;
                roundedShift(sig,1);
                sig   += cast(T) 1uL << (T.sizeof*8 - 1);
                // Add the leading 1
                exp    = 0;
            }
            else
                assert((flags&Flags.storeNormalized) && exp == 0,
                    "Underflow occured assigning to a " ~
                    typeof(this).stringof ~ " (no denormal support).");
        }
        else
        {
            static if (~flags&Flags.storeNormalized)
            {
                // Convert from normalized form to denormalized
                roundedShift(sig,1);
                sig  += cast(T) 1uL << (T.sizeof*8 - 1);
                // Add the leading 1
            }
        }

        // roundedShift is only called with a non-zero shift count
        if (shift > 0)
            roundedShift(sig,shift);
        if (sig > significand_max)
        {
            // handle significand overflow (should only be 1 bit)
            static if (~flags&Flags.storeNormalized)
            {
                sig >>= 1;
            }
            else
                sig &= significand_max;
            exp++;
        }
        static if ((flags&Flags.allowDenormZeroOnly)==Flags.allowDenormZeroOnly)
        {
            // disallow non-zero denormals
            if (exp == 0)
            {
                sig <<= 1;
                if (sig > significand_max && (sig&significand_max) > 0)
                    // Check and round to even
                    exp++;
                sig = 0;
            }
        }

        if (exp >= exponent_max)
        {
            static if (flags&(Flags.infinity|Flags.nan))
            {
                // The maximum exponent is reserved: the value is replaced by
                // exponent_max with a zero significand.
                sig         = 0;
                exp         = exponent_max;
                static if (~flags&(Flags.infinity))
                    assert(0, "Overflow occured assigning to a " ~
                        typeof(this).stringof ~ " (no infinity support).");
            }
            else
                assert(exp == exponent_max, "Overflow occured assigning to a "
                    ~ typeof(this).stringof ~ " (no infinity support).");
        }
    }
366 
367 public:
    // Storage layout. With a 64-bit significand the significand is a plain
    // ulong member and only exponent and sign are packed as bitfields;
    // otherwise all three parts are packed by a single bitfields mixin.
    // The sign field is `flags & flags.signed` bits wide.
    static if (precision == 64) // CustomFloat!80 support hack
    {
        ulong significand;
        enum ulong significand_max = ulong.max;
        mixin(bitfields!(
            T_exp , "exponent", exponentWidth,
            bool  , "sign"    , flags & flags.signed ));
    }
    else
    {
        // NOTE(review): significand_max is not declared in this branch and
        // exponent_max is not declared in either one; presumably both are
        // generated by the bitfields mixin -- confirm against std.bitmanip.
        mixin(bitfields!(
            T_sig, "significand", precision,
            T_exp, "exponent"   , exponentWidth,
            bool , "sign"       , flags & flags.signed ));
    }
383 
384     /// Returns: infinity value
385     static if (flags & Flags.infinity)
infinity()386         static @property CustomFloat infinity()
387         {
388             CustomFloat value;
389             static if (flags & Flags.signed)
390                 value.sign          = 0;
391             value.significand   = 0;
392             value.exponent      = exponent_max;
393             return value;
394         }
395 
396     /// Returns: NaN value
397     static if (flags & Flags.nan)
nan()398         static @property CustomFloat nan()
399         {
400             CustomFloat value;
401             static if (flags & Flags.signed)
402                 value.sign          = 0;
403             value.significand   = cast(typeof(significand_max)) 1L << (precision-1);
404             value.exponent      = exponent_max;
405             return value;
406         }
407 
408     /// Returns: number of decimal digits of precision
dig()409     static @property size_t dig()
410     {
411         auto shiftcnt = precision - ((flags&Flags.storeNormalized) == 0);
412         return shiftcnt == 64 ? 19 : cast(size_t) log10(1uL << shiftcnt);
413     }
414 
415     /// Returns: smallest increment to the value 1
epsilon()416     static @property CustomFloat epsilon()
417     {
418         CustomFloat one = CustomFloat(1);
419         CustomFloat onePlusEpsilon = one;
420         onePlusEpsilon.significand = onePlusEpsilon.significand | 1; // |= does not work here
421 
422         return CustomFloat(onePlusEpsilon - one);
423     }
424 
425     /// the number of bits in mantissa
426     enum mant_dig = precision + ((flags&Flags.storeNormalized) != 0);
427 
428     /// Returns: maximum int value such that 10<sup>max_10_exp</sup> is representable
max_10_exp()429     static @property int max_10_exp(){ return cast(int) log10( +max ); }
430 
431     /// maximum int value such that 2<sup>max_exp-1</sup> is representable
432     enum max_exp = exponent_max - bias - ((flags & (Flags.infinity | Flags.nan)) != 0) + 1;
433 
434     /// Returns: minimum int value such that 10<sup>min_10_exp</sup> is representable
min_10_exp()435     static @property int min_10_exp(){ return cast(int) log10( +min_normal ); }
436 
437     /// minimum int value such that 2<sup>min_exp-1</sup> is representable as a normalized value
438     enum min_exp = cast(T_signed_exp) -(cast(long) bias) + 1 + ((flags & Flags.allowDenorm) != 0);
439 
440     /// Returns: largest representable value that's not infinity
max()441     static @property CustomFloat max()
442     {
443         CustomFloat value;
444         static if (flags & Flags.signed)
445             value.sign        = 0;
446         value.exponent    = exponent_max - ((flags&(flags.infinity|flags.nan)) != 0);
447         value.significand = significand_max;
448         return value;
449     }
450 
451     /// Returns: smallest representable normalized value that's not 0
min_normal()452     static @property CustomFloat min_normal()
453     {
454         CustomFloat value;
455         static if (flags & Flags.signed)
456             value.sign = 0;
457         value.exponent = (flags & Flags.allowDenorm) != 0;
458         static if (flags & Flags.storeNormalized)
459             value.significand = 0;
460         else
461             value.significand = cast(T_sig) 1uL << (precision - 1);
462         return value;
463     }
464 
465     /// Returns: real part
re()466     @property CustomFloat re() { return this; }
467 
468     /// Returns: imaginary part
im()469     static @property CustomFloat im() { return CustomFloat(0.0f); }
470 
471     /// Initialize from any `real` compatible type.
472     this(F)(F input) if (__traits(compiles, cast(real) input ))
473     {
474         this = input;
475     }
476 
477     /// Self assignment
478     void opAssign(F:CustomFloat)(F input)
479     {
480         static if (flags & Flags.signed)
481             sign        = input.sign;
482         exponent    = input.exponent;
483         significand = input.significand;
484     }
485 
486     /// Assigns from any `real` compatible type.
487     void opAssign(F)(F input)
488         if (__traits(compiles, cast(real) input))
489     {
490         import std.conv : text;
491 
492         static if (staticIndexOf!(immutable F, immutable float, immutable double, immutable real) >= 0)
493             auto value = ToBinary!(Unqual!F)(input);
494         else
495             auto value = ToBinary!(real    )(input);
496 
497         // Assign the sign bit
498         static if (~flags & Flags.signed)
499             assert((!value.sign) ^ ((flags&flags.negativeUnsigned) > 0),
500                 "Incorrectly signed floating point value assigned to a " ~
501                 typeof(this).stringof ~ " (no sign support).");
502         else
503             sign = value.sign;
504 
505         CommonType!(T_signed_exp ,value.T_signed_exp) exp = value.exponent;
506         CommonType!(T_sig,        value.T_sig       ) sig = value.significand;
507 
508         value.toNormalized(sig,exp);
509         fromNormalized(sig,exp);
510 
511         assert(exp <= exponent_max,    text(typeof(this).stringof ~
512             " exponent too large: "   ,exp," > ",exponent_max,   "\t",input,"\t",sig));
513         assert(sig <= significand_max, text(typeof(this).stringof ~
514             " significand too large: ",sig," > ",significand_max,
515             "\t",input,"\t",exp," ",exponent_max));
516         exponent    = cast(T_exp) exp;
517         significand = cast(T_sig) sig;
518     }
519 
520     /// Fetches the stored value either as a `float`, `double` or `real`.
521     @property F get(F)()
522         if (staticIndexOf!(immutable F, immutable float, immutable double, immutable real) >= 0)
523     {
524         import std.conv : text;
525 
526         ToBinary!F result;
527 
528         static if (flags&Flags.signed)
529             result.sign = sign;
530         else
531             result.sign = (flags&flags.negativeUnsigned) > 0;
532 
533         CommonType!(T_signed_exp ,result.get.T_signed_exp ) exp = exponent; // Assign the exponent and fraction
534         CommonType!(T_sig,        result.get.T_sig        ) sig = significand;
535 
536         toNormalized(sig,exp);
537         result.fromNormalized(sig,exp);
538         assert(exp <= result.exponent_max,    text("get exponent too large: "   ,exp," > ",result.exponent_max) );
539         assert(sig <= result.significand_max, text("get significand too large: ",sig," > ",result.significand_max) );
540         result.exponent     = cast(result.get.T_exp) exp;
541         result.significand  = cast(result.get.T_sig) sig;
542         return result.set;
543     }
544 
545     ///ditto
546     alias opCast = get;
547 
548     /// Convert the CustomFloat to a real and perform the relevant operator on the result
549     real opUnary(string op)()
550         if (__traits(compiles, mixin(op~`(get!real)`)) || op=="++" || op=="--")
551     {
552         static if (op=="++" || op=="--")
553         {
554             auto result = get!real;
555             this = mixin(op~`result`);
556             return result;
557         }
558         else
559             return mixin(op~`get!real`);
560     }
561 
    /// ditto
    // Define an opBinary `CustomFloat op CustomFloat` so that those below
    // do not match equally, which is disallowed by the spec:
    // https://dlang.org/spec/operatoroverloading.html#binary
    // This overload is selected when the right operand itself offers
    // `get!real`; both sides are extracted to real before applying `op`.
    real opBinary(string op,T)(T b)
         if (__traits(compiles, mixin(`get!real`~op~`b.get!real`)))
     {
         return mixin(`get!real`~op~`b.get!real`);
     }

    /// ditto
    // Right operand usable directly next to a real, and not covered by the
    // overload above (its constraint is negated here).
    real opBinary(string op,T)(T b)
        if ( __traits(compiles, mixin(`get!real`~op~`b`)) &&
            !__traits(compiles, mixin(`get!real`~op~`b.get!real`)))
    {
        return mixin(`get!real`~op~`b`);
    }

    /// ditto
    // Left operand of another type, this value on the right.
    // NOTE(review): the two negated clauses mention `b`, but the parameter of
    // this template is `a`. Unless some `b` is visible from the enclosing
    // scope, those mixins cannot compile and the clauses are always true --
    // confirm the intended condition before changing it.
    real opBinaryRight(string op,T)(T a)
        if ( __traits(compiles, mixin(`a`~op~`get!real`)) &&
            !__traits(compiles, mixin(`get!real`~op~`b`)) &&
            !__traits(compiles, mixin(`get!real`~op~`b.get!real`)))
    {
        return mixin(`a`~op~`get!real`);
    }
588 
589     /// ditto
590     int opCmp(T)(auto ref T b)
591         if (__traits(compiles, cast(real) b))
592     {
593         auto x = get!real;
594         auto y = cast(real) b;
595         return  (x >= y)-(x <= y);
596     }
597 
598     /// ditto
599     void opOpAssign(string op, T)(auto ref T b)
600         if (__traits(compiles, mixin(`get!real`~op~`cast(real) b`)))
601     {
602         return mixin(`this = this `~op~` cast(real) b`);
603     }
604 
605     /// ditto
toString()606     template toString()
607     {
608         import std.format.spec : FormatSpec;
609         import std.format.write : formatValue;
610         // Needs to be a template because of https://issues.dlang.org/show_bug.cgi?id=13737.
611         void toString()(scope void delegate(const(char)[]) sink, scope const ref FormatSpec!char fmt)
612         {
613             sink.formatValue(get!real, fmt);
614         }
615     }
616 }
617 
// Arithmetic through the op-assign operators on several formats.
@safe unittest
{
    import std.meta;
    alias Formats =
        AliasSeq!(
            CustomFloat!(5, 10),
            CustomFloat!(5, 11, CustomFloatFlags.ieee ^ CustomFloatFlags.signed),
            CustomFloat!(1, 7, CustomFloatFlags.ieee ^ CustomFloatFlags.signed),
            CustomFloat!(4, 3, CustomFloatFlags.ieee | CustomFloatFlags.probability ^ CustomFloatFlags.signed)
        );

    foreach (Fmt; Formats)
    {
        // construction
        auto val = Fmt(0.125);
        assert(val.get!double == 0.125);
        assert(val.get!float == 0.125F);

        // subtraction
        val -= 0.0625;
        assert(val.get!double == 0.0625);
        assert(val.get!float == 0.0625F);

        // multiplication
        val *= 2;
        assert(val.get!double == 0.125);
        assert(val.get!float == 0.125F);

        // division
        val /= 4;
        assert(val.get!double == 0.03125);
        assert(val.get!float == 0.03125);

        // assignment followed by exponentiation
        val = 0.5;
        val ^^= 4;
        assert(val.get!double == 1 / 16.0);
        assert(val.get!float == 1 / 16.0F);
    }
}
653 
@system unittest
{
    // @system due to to!string(CustomFloat)
    import std.conv;

    alias Fmt = CustomFloat!(5, 10);
    Fmt eighth = Fmt(0.125);
    assert(to!string(eighth) == "0.125");
}
661 
// Properties of a small signed format: 5 significand bits, 2 exponent bits.
@safe unittest
{
    alias cf = CustomFloat!(5, 2);

    // special values
    auto inf = cf.infinity;
    assert(inf.sign == 0);
    assert(inf.exponent == 3);
    assert(inf.significand == 0);

    auto notANumber = cf.nan;
    assert(notANumber.exponent == 3);
    assert(notANumber.significand != 0);

    // precision
    assert(cf.dig == 1);
    assert(cf.mant_dig == 6);

    auto eps = cf.epsilon;
    assert(eps.sign == 0);
    assert(eps.exponent == 0);
    assert(eps.significand == 1);

    // exponent range
    assert(cf.max_10_exp == 0);
    assert(cf.max_exp == 2);
    assert(cf.min_10_exp == 0);
    assert(cf.min_exp == 1);

    // extreme finite values
    auto largest = cf.max;
    assert(largest.sign == 0);
    assert(largest.exponent == 2);
    assert(largest.significand == 31);

    auto smallest = cf.min_normal;
    assert(smallest.sign == 0);
    assert(smallest.exponent == 1);
    assert(smallest.significand == 0);

    // real and imaginary parts
    assert(smallest.re == smallest);
    assert(smallest.im == cf(0.0));
}
702 
// check whether CustomFloats identical to float/double behave like float/double
@safe unittest
{
    import std.conv : to;

    // same layout as float
    {
        alias Single = CustomFloat!(23, 8);

        static assert(Single.dig == float.dig);
        static assert(Single.mant_dig == float.mant_dig);
        static assert(Single.max_exp == float.max_exp);
        static assert(Single.min_exp == float.min_exp);

        assert(Single.max_10_exp == float.max_10_exp);
        assert(Single.min_10_exp == float.min_10_exp);
        assert(to!float(Single.epsilon) == float.epsilon);
        assert(to!float(Single.max) == float.max);
        assert(to!float(Single.min_normal) == float.min_normal);
    }

    // same layout as double
    {
        alias Double = CustomFloat!(52, 11);

        static assert(Double.dig == double.dig);
        static assert(Double.mant_dig == double.mant_dig);
        static assert(Double.max_exp == double.max_exp);
        static assert(Double.min_exp == double.min_exp);

        assert(Double.max_10_exp == double.max_10_exp);
        assert(Double.min_10_exp == double.min_10_exp);
        assert(to!double(Double.epsilon) == double.epsilon);
        assert(to!double(Double.max) == double.max);
        assert(to!double(Double.min_normal) == double.min_normal);
    }
}
732 
// testing .dig
@safe unittest
{
    enum noFlags = CustomFloatFlags.none;

    // default (IEEE) flags
    static assert(CustomFloat!(1, 6).dig == 0);
    static assert(CustomFloat!(9, 6).dig == 2);
    static assert(CustomFloat!(10, 5).dig == 3);
    static assert(CustomFloat!(64, 7).dig == 19);

    // no flags at all
    static assert(CustomFloat!(10, 6, noFlags).dig == 2);
    static assert(CustomFloat!(11, 5, noFlags).dig == 3);
}
743 
// testing .mant_dig
@safe unittest
{
    alias Ieee = CustomFloat!(10, 5);
    alias Bare = CustomFloat!(10, 6, CustomFloatFlags.none);

    static assert(Ieee.mant_dig == 11);
    static assert(Bare.mant_dig == 10);
}
750 
// testing .max_exp
@safe unittest
{
    enum noFlags = CustomFloatFlags.none;
    enum nanOnly = CustomFloatFlags.nan;

    // 6-bit exponent
    static assert(CustomFloat!(1, 6).max_exp == 2^^5);
    static assert(CustomFloat!(2, 6, noFlags).max_exp == 2^^5);
    static assert(CustomFloat!(2, 6, nanOnly).max_exp == 2^^5);

    // 10-bit exponent
    static assert(CustomFloat!(5, 10).max_exp == 2^^9);
    static assert(CustomFloat!(6, 10, noFlags).max_exp == 2^^9);
    static assert(CustomFloat!(6, 10, nanOnly).max_exp == 2^^9);
}
761 
// testing .min_exp
@safe unittest
{
    enum noFlags    = CustomFloatFlags.none;
    enum nanOnly    = CustomFloatFlags.nan;
    enum denormOnly = CustomFloatFlags.allowDenorm;

    // 6-bit exponent
    static assert(CustomFloat!(1, 6).min_exp == -2^^5+3);
    static assert(CustomFloat!(2, 6, noFlags).min_exp == -2^^5+1);
    static assert(CustomFloat!(2, 6, nanOnly).min_exp == -2^^5+2);
    static assert(CustomFloat!(2, 6, denormOnly).min_exp == -2^^5+2);

    // 10-bit exponent
    static assert(CustomFloat!(5, 10).min_exp == -2^^9+3);
    static assert(CustomFloat!(6, 10, noFlags).min_exp == -2^^9+1);
    static assert(CustomFloat!(6, 10, nanOnly).min_exp == -2^^9+2);
    static assert(CustomFloat!(6, 10, denormOnly).min_exp == -2^^9+2);
}
774 
// testing .max_10_exp
@safe unittest
{
    enum noFlags = CustomFloatFlags.none;
    enum nanOnly = CustomFloatFlags.nan;

    // 6-bit exponent
    assert(CustomFloat!(1, 6).max_10_exp == 9);
    assert(CustomFloat!(2, 6, noFlags).max_10_exp == 9);
    assert(CustomFloat!(2, 6, nanOnly).max_10_exp == 9);

    // 10-bit exponent
    assert(CustomFloat!(5, 10).max_10_exp == 154);
    assert(CustomFloat!(6, 10, noFlags).max_10_exp == 154);
    assert(CustomFloat!(6, 10, nanOnly).max_10_exp == 154);
}
785 
// testing .min_10_exp
@safe unittest
{
    enum noFlags    = CustomFloatFlags.none;
    enum nanOnly    = CustomFloatFlags.nan;
    enum denormOnly = CustomFloatFlags.allowDenorm;

    // 6-bit exponent
    assert(CustomFloat!(1, 6).min_10_exp == -9);
    assert(CustomFloat!(2, 6, noFlags).min_10_exp == -9);
    assert(CustomFloat!(2, 6, nanOnly).min_10_exp == -9);
    assert(CustomFloat!(2, 6, denormOnly).min_10_exp == -9);

    // 10-bit exponent
    assert(CustomFloat!(5, 10).min_10_exp == -153);
    assert(CustomFloat!(6, 10, noFlags).min_10_exp == -154);
    assert(CustomFloat!(6, 10, nanOnly).min_10_exp == -153);
    assert(CustomFloat!(6, 10, denormOnly).min_10_exp == -153);
}
798 
// testing .epsilon
@safe unittest
{
    {
        auto eps = CustomFloat!(1,6).epsilon;
        assert(eps.sign == 0);
        assert(eps.exponent == 30);
        assert(eps.significand == 0);
    }
    {
        auto eps = CustomFloat!(2,5).epsilon;
        assert(eps.sign == 0);
        assert(eps.exponent == 13);
        assert(eps.significand == 0);
    }
    {
        auto eps = CustomFloat!(3,4).epsilon;
        assert(eps.sign == 0);
        assert(eps.exponent == 4);
        assert(eps.significand == 0);
    }
    // the following epsilons are only available, when denormalized numbers are allowed:
    {
        auto eps = CustomFloat!(4,3).epsilon;
        assert(eps.sign == 0);
        assert(eps.exponent == 0);
        assert(eps.significand == 4);
    }
    {
        auto eps = CustomFloat!(5,2).epsilon;
        assert(eps.sign == 0);
        assert(eps.exponent == 0);
        assert(eps.significand == 1);
    }
}
819 
// testing .max
@safe unittest
{
    {
        alias F = CustomFloat!(5,2);
        static assert(F.max.sign == 0);
        static assert(F.max.exponent == 2);
        static assert(F.max.significand == 31);
    }
    {
        alias F = CustomFloat!(4,3);
        static assert(F.max.sign == 0);
        static assert(F.max.exponent == 6);
        static assert(F.max.significand == 15);
    }
    {
        alias F = CustomFloat!(3,4);
        static assert(F.max.sign == 0);
        static assert(F.max.exponent == 14);
        static assert(F.max.significand == 7);
    }
    {
        alias F = CustomFloat!(2,5);
        static assert(F.max.sign == 0);
        static assert(F.max.exponent == 30);
        static assert(F.max.significand == 3);
    }
    {
        alias F = CustomFloat!(1,6);
        static assert(F.max.sign == 0);
        static assert(F.max.exponent == 62);
        static assert(F.max.significand == 1);
    }
    // no flags: the maximum exponent is not reserved
    {
        alias F = CustomFloat!(3,5, CustomFloatFlags.none);
        static assert(F.max.exponent == 31);
        static assert(F.max.significand == 7);
    }
}
841 
// testing .min_normal
@safe unittest
{
    // Evaluated at compile time: true when `CF.min_normal` is non-negative and
    // has the given raw exponent and significand fields.
    static bool minNormalIs(CF)(int exponent, int significand)
    {
        return CF.min_normal.sign == 0
            && CF.min_normal.exponent == exponent
            && CF.min_normal.significand == significand;
    }

    static assert(minNormalIs!(CustomFloat!(5,2))(1, 0));
    static assert(minNormalIs!(CustomFloat!(4,3))(1, 0));
    static assert(minNormalIs!(CustomFloat!(3,4))(1, 0));
    static assert(minNormalIs!(CustomFloat!(2,5))(1, 0));
    static assert(minNormalIs!(CustomFloat!(1,6))(1, 0));

    // With no flags set, only the exponent and significand fields are checked.
    alias Flagless = CustomFloat!(3,5, CustomFloatFlags.none);
    static assert(Flagless.min_normal.exponent == 0);
    static assert(Flagless.min_normal.significand == 4);
}
863 
@safe unittest
{
    import std.math.traits : isNaN;

    alias Tiny = CustomFloat!(5, 2);

    // NaN is still NaN after conversion to a built-in float.
    auto nanAsFloat = Tiny.nan.get!float();
    assert(isNaN(nanAsFloat));

    Tiny value;

    // A value far beyond the representable range compares equal to infinity.
    value = real.max;
    assert(value == Tiny.infinity);

    // This input is stored with both the exponent and significand fields at 0.
    value = 0.015625;
    assert(value.exponent == 0);
    assert(value.significand == 0);

    // This input is stored with exponent field 1 and an all-zero significand.
    value = 0.984375;
    assert(value.exponent == 1);
    assert(value.significand == 0);
}
885 
@system unittest
{
    import core.exception : AssertError;
    import std.exception : assertThrown;

    // Format with no flags set: assigning real.max is expected to fail an assertion.
    alias Flagless = CustomFloat!(3, 5, CustomFloatFlags.none);

    Flagless target;
    assertThrown!AssertError(target = real.max);
}
896 
@system unittest
{
    import core.exception : AssertError;
    import std.exception : assertThrown;

    // Format with only the nan flag set: assigning real.max is expected to
    // fail an assertion.
    alias NanOnly = CustomFloat!(3, 5, CustomFloatFlags.nan);

    NanOnly target;
    assertThrown!AssertError(target = real.max);
}
907 
@system unittest
{
    import core.exception : AssertError;
    import std.exception : assertThrown;

    // 32-bit format with no flags set: assigning float.infinity is expected to
    // fail an assertion.
    alias Flagless32 = CustomFloat!(24, 8, CustomFloatFlags.none);

    Flagless32 target;
    assertThrown!AssertError(target = float.infinity);
}
918 
/*
 * Checks whether the given template arguments describe a format that
 * CustomFloat can store.
 *
 * Params:
 *     precision = number of significand bits
 *     exponentWidth = number of exponent bits
 *     flags = format flags
 *
 * Returns: `true` only if all of the following hold:
 *     - `precision` is non-zero;
 *     - `exponentWidth` is at least 1, or at least 2 when any of the
 *       allowDenorm, infinity or nan flags is set;
 *     - sign bit + exponent bits + significand bits total 8, 16, 32 or 64
 *       (a precision of 64 is not counted towards that total);
 *     - the significand fits into `real`'s (or precision is exactly 64);
 *     - `2^^(exponentWidth - 1)` does not exceed `real.max_exp`.
 */
private bool isCorrectCustomFloat(uint precision, uint exponentWidth, CustomFloatFlags flags) @safe pure nothrow @nogc
{
    // mantissa should have at least one bit
    if (precision == 0) return false;

    // exponent should have at least one bit, in some cases two.
    // This is checked first so that `exponentWidth - 1` below cannot wrap around.
    if (exponentWidth <= ((flags & (flags.allowDenorm | flags.infinity | flags.nan)) != 0)) return false;

    // Restrictions from bitfield
    // due to CustomFloat!80 support hack precision with 64 bits is handled specially
    auto length = (flags & flags.signed) + exponentWidth + ((precision == 64) ? 0 : precision);
    if (length != 8 && length != 16 && length != 32 && length != 64) return false;

    // mantissa needs to fit into real mantissa
    if (precision > real.mant_dig - 1 && precision != 64) return false;

    // exponent needs to fit into real exponent.
    // A width above 63 can never fit, and must be rejected before shifting:
    // `1L << 63` is negative and larger shift counts are not valid for a long.
    if (exponentWidth > 63 || (1L << (exponentWidth - 1)) > real.max_exp) return false;

    return true;
}
940 
@safe pure nothrow @nogc unittest
{
    enum ieee = CustomFloatFlags.ieee;
    enum none = CustomFloatFlags.none;

    // Accepted combinations.
    assert(isCorrectCustomFloat(3, 4, ieee));
    assert(isCorrectCustomFloat(3, 5, none));
    assert(isCorrectCustomFloat(64, 7, ieee));
    assert(isCorrectCustomFloat(7, 1, none));

    // Rejected combinations.
    assert(!isCorrectCustomFloat(3, 3, ieee));
    assert(!isCorrectCustomFloat(64, 4, ieee));
    assert(!isCorrectCustomFloat(508, 3, ieee));
    assert(!isCorrectCustomFloat(3, 100, ieee));
    assert(!isCorrectCustomFloat(0, 7, ieee));
    assert(!isCorrectCustomFloat(6, 1, ieee));
    assert(!isCorrectCustomFloat(8, 0, none));
}
955 
/**
Defines the fastest type to use when storing temporaries of a
calculation intended to ultimately yield a result of type `F`
(where `F` must be one of `float`, `double`, or `real`).
When doing a multi-step computation, you may want to store
intermediate results as `FPTemporary!F`.

The necessity of `FPTemporary` stems from the optimized
floating-point operations and registers present in virtually all
processors. When adding numbers in the example below, the addition may
in fact be done in `real` precision internally. In that case,
storing the intermediate `result` in `double` format is not only
less precise, it is also (surprisingly) slower, because a conversion
from `real` to `double` is performed every pass through the
loop. This being a lose-lose situation, `FPTemporary!F` has been
defined as the $(I fastest) type to use for calculations at precision
`F`. There is no need to define a type for the $(I most accurate)
calculations, as that is always `real`.

Finally, there is no guarantee that using `FPTemporary!F` will
always be fastest, as the speed of floating-point calculations depends
on very many factors.
 */
template FPTemporary(F)
if (isFloatingPoint!F)
{
    version (X86)
    {
        // On X86 builds every temporary is kept as `real`.
        alias FPTemporary = real;
    }
    else
    {
        // Elsewhere the temporary is `F` itself with its qualifiers stripped.
        alias FPTemporary = Unqual!F;
    }
}
987 
///
@safe unittest
{
    import std.math.operations : isClose;

    // Average numbers in an array
    double avg(in double[] a)
    {
        if (a.length == 0)
            return 0;

        // Accumulate in the temporary type; convert only when returning.
        FPTemporary!double result = 0;
        foreach (i; 0 .. a.length)
            result += a[i];
        return result / a.length;
    }

    auto samples = [1.0, 2.0, 3.0];
    assert(isClose(avg(samples), 2));
}
1005 
/**
Implements the $(HTTP tinyurl.com/2zb9yr, secant method) for finding a
root of the function `fun`, starting from the two points `xn_1` and `xn`
(ideally close to the root). `Num` may be `float`, `double`,
or `real`.

Iteration stops once the step between successive estimates is within
`1e-5` of zero, or is no longer finite; the current estimate is returned.
*/
template secantMethod(alias fun)
{
    import std.functional : unaryFun;
    Num secantMethod(Num)(Num xn_1, Num xn)
    {
        auto fCurrent = unaryFun!(fun)(xn_1);
        auto step = xn_1 - xn;
        typeof(fCurrent) fPrevious;

        xn = xn_1;
        for (;;)
        {
            // Converged (step is approximately zero) or diverged (step is
            // infinite or NaN): either way there is nothing more to do.
            if (isClose(step, 0, 0.0, 1e-5) || !isFinite(step))
                break;

            xn_1 = xn;
            xn -= step;
            fPrevious = fCurrent;
            fCurrent = unaryFun!(fun)(xn);
            // Scale the step by the secant through the last two samples.
            step *= -fCurrent / (fCurrent - fPrevious);
        }
        return xn;
    }
}
1032 
///
@safe unittest
{
    import std.math.operations : isClose;
    import std.math.trigonometry : cos;

    // The root of this function is where cos(x) and x^3 intersect.
    float f(float x)
    {
        immutable cube = x*x*x;
        return cos(x) - cube;
    }

    auto root = secantMethod!(f)(0f, 1f);
    assert(isClose(root, 0.865474));
}
1046 
@system unittest
{
    // @system because of __gshared stderr
    import std.stdio;
    scope(failure) stderr.writeln("Failure testing secantMethod");

    float g(float x)
    {
        return cos(x) - x*x*x;
    }

    // Pass the nested function itself as the alias argument ...
    immutable viaFunction = secantMethod!(g)(0f, 1f);
    assert(isClose(viaFunction, 0.865474));

    // ... and then a local delegate variable referring to it.
    auto dg = &g;
    immutable viaDelegate = secantMethod!(dg)(0f, 1f);
    assert(isClose(viaDelegate, 0.865474));
}
1062 
1063 
/**
 * Returns: `true` if the sign bits of `a` and `b` differ, `false` otherwise.
 */
private bool oppositeSigns(T1, T2)(T1 a, T2 b)
{
    immutable signOfA = signbit(a);
    immutable signOfB = signbit(b);
    return signOfA != signOfB;
}
1071 
1072 public:
1073 
/**  Find a real root of a real function f(x) via bracketing.
 *
 * Given a function `f` and a range `[a .. b]` such that `f(a)`
 * and `f(b)` have opposite signs or at least one of them equals ±0,
 * returns the value of `x` in the range which is closest to a root
 * of `f(x)`.  If `f(x)` has more than one root in the range, one will
 * be chosen arbitrarily.  If `f(x)` returns NaN, NaN will be returned;
 * otherwise, this algorithm is guaranteed to succeed.
 *
 * If `f(a)` or `f(b)` is exactly zero, that endpoint is returned
 * immediately (`a` is tried first).
 *
 * Uses an algorithm based on TOMS748, which uses inverse cubic
 * interpolation whenever possible, otherwise reverting to parabolic
 * or secant interpolation. Compared to TOMS748, this implementation
 * improves worst-case performance by a factor of more than 100, and
 * typical performance by a factor of 2. For 80-bit reals, most
 * problems require 8 to 15 calls to `f(x)` to achieve full machine
 * precision. The worst-case performance (pathological cases) is
 * approximately twice the number of bits.
 *
 * References: "On Enclosing Simple Roots of Nonlinear Equations",
 * G. Alefeld, F.A. Potra, Yixun Shi, Mathematics of Computation 61,
 * pp733-744 (1993).  Fortran code available from $(HTTP
 * www.netlib.org,www.netlib.org) as algorithm TOMS748.
 *
 */
T findRoot(T, DF, DT)(scope DF f, const T a, const T b,
    scope DT tolerance) //= (T a, T b) => false)
if (
    isFloatingPoint!T &&
    is(typeof(tolerance(T.init, T.init)) : bool) &&
    is(typeof(f(T.init)) == R, R) && isFloatingPoint!R
    )
{
    // An endpoint that is already an exact root ends the search at once;
    // f(b) is only evaluated when f(a) is non-zero.
    immutable fa = f(a);
    if (fa == 0)
        return a;
    immutable fb = f(b);
    if (fb == 0)
        return b;

    // bracket = (lower x, upper x, f(lower x), f(upper x))
    immutable bracket = findRoot(f, a, b, fa, fb, tolerance);

    // Pick the upper bound only if its function value is strictly smaller in
    // magnitude; ties and NaN comparisons fall through to the lower bound.
    if (fabs(bracket[2]) > fabs(bracket[3]))
        return bracket[1];
    return bracket[0];
}
1117 
///ditto
T findRoot(T, DF)(scope DF f, const T a, const T b)
{
    // Forwards to the overload that takes a tolerance predicate, passing one
    // that never accepts the current bounds (no early termination requested).
    return findRoot(f, a, b, (T a, T b) => false);
}
1123 
/** Find root of a real function f(x) by bracketing, allowing the
 * termination condition to be specified.
 *
 * Params:
 *
 * f = Function to be analyzed
 *
 * ax = Left bound of initial range of `f` known to contain the
 * root.
 *
 * bx = Right bound of initial range of `f` known to contain the
 * root. (`ax` and `bx` may also be given in reverse order.)
 *
 * fax = Value of `f(ax)`.
 *
 * fbx = Value of `f(bx)`. `fax` and `fbx` should have opposite signs.
 * (`f(ax)` and `f(bx)` are commonly known in advance.)
 *
 *
 * tolerance = Defines an early termination condition. Receives the
 *             current upper and lower bounds on the root. The
 *             delegate must return `true` when these bounds are
 *             acceptable. If this function always returns `false`,
 *             full machine precision will be achieved.
 *
 * Returns:
 *
 * A tuple consisting of two ranges. The first two elements are the
 * range (in `x`) of the root, while the second pair of elements
 * are the corresponding function values at those points. If `fax`
 * or `fbx` is already zero (or NaN), both of the first two elements
 * contain that endpoint and both function values are equal. If an exact
 * root (or a NaN) is hit during the search, the first element is that
 * point and the third element is its function value; the second and
 * fourth elements keep the last upper bound.
 */
Tuple!(T, T, R, R) findRoot(T, R, DF, DT)(scope DF f,
    const T ax, const T bx, const R fax, const R fbx,
    scope DT tolerance) // = (T a, T b) => false)
if (
    isFloatingPoint!T &&
    is(typeof(tolerance(T.init, T.init)) : bool) &&
    is(typeof(f(T.init)) == R) && isFloatingPoint!R
    )
in
{
    assert(!ax.isNaN() && !bx.isNaN(), "Limits must not be NaN");
    assert(signbit(fax) != signbit(fbx), "Parameters must bracket the root.");
}
do
{
    // Author: Don Clugston. This code is (heavily) modified from TOMS748
    // (www.netlib.org).  The changes to improve the worst-case performance are
    // entirely original.

    T a, b, d;  // [a .. b] is our current bracket. d is the third best guess.
    R fa, fb, fd; // Values of f at a, b, d.
    bool done = false; // Has a root been found?

    // Allow ax and bx to be provided in reverse order
    if (ax <= bx)
    {
        a = ax; fa = fax;
        b = bx; fb = fbx;
    }
    else
    {
        a = bx; fa = fbx;
        b = ax; fb = fax;
    }

    // Test the function at point c; update brackets accordingly
    void bracket(T c)
    {
        R fc = f(c);
        if (fc == 0 || fc.isNaN()) // Exact solution, or NaN
        {
            // Only the lower bound (and d) are moved onto c; b is left as is.
            a = c;
            fa = fc;
            d = c;
            fd = fc;
            done = true;
            return;
        }

        // Determine new enclosing interval: c replaces whichever bound has
        // the same sign as f(c); the replaced bound becomes d.
        if (signbit(fa) != signbit(fc))
        {
            d = b;
            fd = fb;
            b = c;
            fb = fc;
        }
        else
        {
            d = a;
            fd = fa;
            a = c;
            fa = fc;
        }
    }

   /* Perform a secant interpolation. If the result would lie on a or b, or if
     a and b differ so wildly in magnitude that the result would be meaningless,
     perform a bisection instead.
    */
    static T secant_interpolate(T a, T b, R fa, R fb)
    {
        if (( ((a - b) == a) && b != 0) || (a != 0 && ((b - a) == b)))
        {
            // Catastrophic cancellation
            if (a == 0)
                a = copysign(T(0), b);
            else if (b == 0)
                b = copysign(T(0), a);
            else if (signbit(a) != signbit(b))
                return 0;
            T c = ieeeMean(a, b);
            return c;
        }
        // avoid overflow
        if (b - a > T.max)
            return b / 2 + a / 2;
        // fb - fa overflowed, so the secant formula below would collapse onto
        // a. Bisect instead; the result must stay inside [a .. b] so that f is
        // never evaluated outside the caller's bracket.
        if (fb - fa > R.max)
            return a + (b - a) / 2;
        T c = a - (fa / (fb - fa)) * (b - a);
        if (c == a || c == b)
            return (a + b) / 2;
        return c;
    }

    /* Uses 'numsteps' newton steps to approximate the zero in [a .. b] of the
       quadratic polynomial interpolating f(x) at a, b, and d.
       Returns:
         The approximate zero in [a .. b] of the quadratic polynomial.
    */
    T newtonQuadratic(int numsteps)
    {
        // Find the coefficients of the quadratic polynomial.
        immutable T a0 = fa;
        immutable T a1 = (fb - fa)/(b - a);
        immutable T a2 = ((fd - fb)/(d - b) - a1)/(d - a);

        // Determine the starting point of newton steps: a if a2 and fa have
        // opposite signs, b otherwise.
        T c = (signbit(a2) != signbit(fa)) ? a : b;

        // start the safeguarded newton steps.
        foreach (int i; 0 .. numsteps)
        {
            immutable T pc = a0 + (a1 + a2 * (c - b))*(c - a);
            immutable T pdc = a1 + a2*((2 * c) - (a + b));
            if (pdc == 0)
                return a - a0 / a1;
            else
                c = c - pc / pdc;
        }
        return c;
    }

    // On the first iteration we take a secant step:
    if (fa == 0 || fa.isNaN())
    {
        done = true;
        b = a;
        fb = fa;
    }
    else if (fb == 0 || fb.isNaN())
    {
        done = true;
        a = b;
        fa = fb;
    }
    else
    {
        bracket(secant_interpolate(a, b, fa, fb));
    }

    // Starting with the second iteration, higher-order interpolation can
    // be used.
    int itnum = 1;   // Iteration number
    int baditer = 1; // Num bisections to take if an iteration is bad.
    T c, e;  // e is our fourth best guess
    R fe;

whileloop:
    while (!done && (b != nextUp(a)) && !tolerance(a, b))
    {
        T a0 = a, b0 = b; // record the brackets

        // Do two higher-order (cubic or parabolic) interpolation steps.
        foreach (int QQ; 0 .. 2)
        {
            // Cubic inverse interpolation requires that
            // all four function values fa, fb, fd, and fe are distinct;
            // otherwise use quadratic interpolation.
            bool distinct = (fa != fb) && (fa != fd) && (fa != fe)
                         && (fb != fd) && (fb != fe) && (fd != fe);
            // The first time, cubic interpolation is impossible.
            if (itnum<2) distinct = false;
            bool ok = distinct;
            if (distinct)
            {
                // Cubic inverse interpolation of f(x) at a, b, d, and e
                immutable q11 = (d - e) * fd / (fe - fd);
                immutable q21 = (b - d) * fb / (fd - fb);
                immutable q31 = (a - b) * fa / (fb - fa);
                immutable d21 = (b - d) * fd / (fd - fb);
                immutable d31 = (a - b) * fb / (fb - fa);

                immutable q22 = (d21 - q11) * fb / (fe - fb);
                immutable q32 = (d31 - q21) * fa / (fd - fa);
                immutable d32 = (d31 - q21) * fd / (fd - fa);
                immutable q33 = (d32 - q22) * fa / (fe - fa);
                c = a + (q31 + q32 + q33);
                if (c.isNaN() || (c <= a) || (c >= b))
                {
                    // DAC: If the interpolation predicts a or b, it's
                    // probable that it's the actual root. Only allow this if
                    // we're already close to the root.
                    if (c == a && a - b != a)
                    {
                        c = nextUp(a);
                    }
                    else if (c == b && a - b != -b)
                    {
                        c = nextDown(b);
                    }
                    else
                    {
                        ok = false;
                    }
                }
            }
            if (!ok)
            {
                // DAC: Alefeld doesn't explain why the number of newton steps
                // should vary.
                c = newtonQuadratic(distinct ? 3 : 2);
                if (c.isNaN() || (c <= a) || (c >= b))
                {
                    // Failure, try a secant step:
                    c = secant_interpolate(a, b, fa, fb);
                }
            }
            ++itnum;
            e = d;
            fe = fd;
            bracket(c);
            if (done || ( b == nextUp(a)) || tolerance(a, b))
                break whileloop;
            if (itnum == 2)
                continue whileloop;
        }

        // Now we take a double-length secant step:
        T u;
        R fu;
        if (fabs(fa) < fabs(fb))
        {
            u = a;
            fu = fa;
        }
        else
        {
            u = b;
            fu = fb;
        }
        c = u - 2 * (fu / (fb - fa)) * (b - a);

        // DAC: If the secant predicts a value equal to an endpoint, it's
        // probably false.
        if (c == a || c == b || c.isNaN() || fabs(c - u) > (b - a) / 2)
        {
            if ((a-b) == a || (b-a) == b)
            {
                if ((a>0 && b<0) || (a<0 && b>0))
                    c = 0;
                else
                {
                    if (a == 0)
                        c = ieeeMean(copysign(T(0), b), b);
                    else if (b == 0)
                        c = ieeeMean(copysign(T(0), a), a);
                    else
                        c = ieeeMean(a, b);
                }
            }
            else
            {
                c = a + (b - a) / 2;
            }
        }
        e = d;
        fe = fd;
        bracket(c);
        if (done || (b == nextUp(a)) || tolerance(a, b))
            break;

        // IMPROVE THE WORST-CASE PERFORMANCE
        // We must ensure that the bounds reduce by a factor of 2
        // in binary space! every iteration. If we haven't achieved this
        // yet, or if we don't yet know what the exponent is,
        // perform a binary chop.

        if ((a == 0 || b == 0 ||
            (fabs(a) >= T(0.5) * fabs(b) && fabs(b) >= T(0.5) * fabs(a)))
            &&  (b - a) < T(0.25) * (b0 - a0))
        {
            baditer = 1;
            continue;
        }

        // DAC: If this happens on consecutive iterations, we probably have a
        // pathological function. Perform a number of bisections equal to the
        // total number of consecutive bad iterations.

        if ((b - a) < T(0.25) * (b0 - a0))
            baditer = 1;
        foreach (int QQ; 0 .. baditer)
        {
            e = d;
            fe = fd;

            T w;
            if ((a>0 && b<0) || (a<0 && b>0))
                w = 0;
            else
            {
                T usea = a;
                T useb = b;
                if (a == 0)
                    usea = copysign(T(0), b);
                else if (b == 0)
                    useb = copysign(T(0), a);
                w = ieeeMean(usea, useb);
            }
            bracket(w);
        }
        ++baditer;
    }
    return Tuple!(T, T, R, R)(a, b, fa, fb);
}
1463 
///ditto
Tuple!(T, T, R, R) findRoot(T, R, DF)(scope DF f,
    const T ax, const T bx, const R fax, const R fbx)
{
    // Forwards to the overload that takes a tolerance predicate, passing one
    // that never accepts the current bounds (no early termination requested).
    return findRoot(f, ax, bx, fax, fbx, (T a, T b) => false);
}
1470 
///ditto
T findRoot(T, R)(scope R delegate(T) f, const T a, const T b,
    scope bool delegate(T lo, T hi) tolerance = (T a, T b) => false)
{
    // Spell out the delegate types once and instantiate the generic overload
    // with them explicitly.
    alias Fun = R delegate(T);
    alias Tolerance = bool delegate(T lo, T hi);
    return findRoot!(T, Fun, Tolerance)(f, a, b, tolerance);
}
1477 
// Runs the six-argument findRoot over a collection of test functions and
// checks, for each one, that the returned bounds still enclose a sign change
// (or that f at the returned lower bound is exactly zero).
@safe nothrow unittest
{
    // Bookkeeping counters. numProblems is only ever reset to 0 here: its
    // increment inside testFindRoot is commented out.
    int numProblems = 0;
    int numCalls;

    // Calls findRoot on [x1 .. x2] with a tolerance that always returns false
    // and validates the returned bracket.
    void testFindRoot(real delegate(real) @nogc @safe nothrow pure f , real x1, real x2) @nogc @safe nothrow pure
    {
        //numCalls=0;
        //++numProblems;
        assert(!x1.isNaN() && !x2.isNaN());
        assert(signbit(f(x1)) != signbit(f(x2)));
        auto result = findRoot(f, x1, x2, f(x1), f(x2),
          (real lo, real hi) { return false; });

        // Re-evaluate f at both returned bounds; unless the lower bound is an
        // exact root, the two values must differ in sign.
        auto flo = f(result[0]);
        auto fhi = f(result[1]);
        if (flo != 0)
        {
            assert(oppositeSigns(flo, fhi));
        }
    }

    // Test functions
    real cubicfn(real x) @nogc @safe nothrow pure
    {
        //++numCalls;
        // Clamp x to [-float.max .. float.max] before evaluating the cubic.
        if (x>float.max)
            x = float.max;
        if (x<-float.max)
            x = -float.max;
        // This has a single real root at -59.286543284815
        return 0.386*x*x*x + 23*x*x + 15.7*x + 525.2;
    }
    // Test a function with more than one root.
    real multisine(real x) { ++numCalls; return sin(x); }
    testFindRoot( &multisine, 6, 90);
    testFindRoot(&cubicfn, -100, 100);
    testFindRoot( &cubicfn, -double.max, real.max);


/* Tests from the paper:
 * "On Enclosing Simple Roots of Nonlinear Equations", G. Alefeld, F.A. Potra,
 *   Yixun Shi, Mathematics of Computation 61, pp733-744 (1993).
 */
    // Parameters common to many alefeld tests.
    // The nested test functions below read n, ale_a and ale_b, so each must be
    // set before the corresponding testFindRoot call.
    int n;
    real ale_a, ale_b;

    int powercalls = 0;

    real power(real x)
    {
        ++powercalls;
        ++numCalls;
        return pow(x, n) + double.min_normal;
    }
    int [] power_nvals = [3, 5, 7, 9, 19, 25];
    // Alefeld paper states that pow(x,n) is a very poor case, where bisection
    // outperforms his method, and gives total numcalls =
    // 921 for bisection (2.4 calls per bit), 1830 for Alefeld (4.76/bit),
    // 2624 for brent (6.8/bit)
    // ... but that is for double, not real80.
    // This poor performance seems mainly due to catastrophic cancellation,
    // which is avoided here by the use of ieeeMean().
    // I get: 231 (0.48/bit).
    // IE this is 10X faster in Alefeld's worst case
    numProblems=0;
    foreach (k; power_nvals)
    {
        n = k;
        testFindRoot(&power, -1, 10);
    }

    int powerProblems = numProblems;

    // Tests from Alefeld paper

    // Per-function call counters; alefeldN increments alefeldSums[N].
    int [9] alefeldSums;
    real alefeld0(real x)
    {
        ++alefeldSums[0];
        ++numCalls;
        real q =  sin(x) - x/2;
        for (int i=1; i<20; ++i)
            q+=(2*i-5.0)*(2*i-5.0)/((x-i*i)*(x-i*i)*(x-i*i));
        return q;
    }
    real alefeld1(real x)
    {
        ++numCalls;
        ++alefeldSums[1];
        return ale_a*x + exp(ale_b * x);
    }
    real alefeld2(real x)
    {
        ++numCalls;
        ++alefeldSums[2];
        return pow(x, n) - ale_a;
    }
    real alefeld3(real x)
    {
        ++numCalls;
        ++alefeldSums[3];
        return (1.0 +pow(1.0L-n, 2))*x - pow(1.0L-n*x, 2);
    }
    real alefeld4(real x)
    {
        ++numCalls;
        ++alefeldSums[4];
        return x*x - pow(1-x, n);
    }
    real alefeld5(real x)
    {
        ++numCalls;
        ++alefeldSums[5];
        return (1+pow(1.0L-n, 4))*x - pow(1.0L-n*x, 4);
    }
    real alefeld6(real x)
    {
        ++numCalls;
        ++alefeldSums[6];
        return exp(-n*x)*(x-1.01L) + pow(x, n);
    }
    real alefeld7(real x)
    {
        ++numCalls;
        ++alefeldSums[7];
        return (n*x-1)/((n-1)*x);
    }

    numProblems=0;
    testFindRoot(&alefeld0, PI_2, PI);
    // Search each interval between consecutive squares n^2 and (n+1)^2,
    // moved inwards by 1e-9 at both ends.
    for (n=1; n <= 10; ++n)
    {
        testFindRoot(&alefeld0, n*n+1e-9L, (n+1)*(n+1)-1e-9L);
    }
    ale_a = -40; ale_b = -1;
    testFindRoot(&alefeld1, -9, 31);
    ale_a = -100; ale_b = -2;
    testFindRoot(&alefeld1, -9, 31);
    ale_a = -200; ale_b = -3;
    testFindRoot(&alefeld1, -9, 31);
    // Values of n to try; nvals_3 is used for both alefeld3 and alefeld4.
    int [] nvals_3 = [1, 2, 5, 10, 15, 20];
    int [] nvals_5 = [1, 2, 4, 5, 8, 15, 20];
    int [] nvals_6 = [1, 5, 10, 15, 20];
    int [] nvals_7 = [2, 5, 15, 20];

    for (int i=4; i<12; i+=2)
    {
        n = i;
        ale_a = 0.2;
        testFindRoot(&alefeld2, 0, 5);
        ale_a=1;
        testFindRoot(&alefeld2, 0.95, 4.05);
        testFindRoot(&alefeld2, 0, 1.5);
    }
    foreach (i; nvals_3)
    {
        n=i;
        testFindRoot(&alefeld3, 0, 1);
    }
    foreach (i; nvals_3)
    {
        n=i;
        testFindRoot(&alefeld4, 0, 1);
    }
    foreach (i; nvals_5)
    {
        n=i;
        testFindRoot(&alefeld5, 0, 1);
    }
    foreach (i; nvals_6)
    {
        n=i;
        testFindRoot(&alefeld6, 0, 1);
    }
    foreach (i; nvals_7)
    {
        n=i;
        testFindRoot(&alefeld7, 0.01L, 1);
    }
    // Step function: a small negative constant below 0.3*real.max and 1.0
    // from there on, searched over the whole finite range of real.
    real worstcase(real x)
    {
        ++numCalls;
        return x<0.3*real.max? -0.999e-3 : 1.0;
    }
    testFindRoot(&worstcase, -real.max, real.max);

    // just check that the double + float cases compile
    findRoot((double x){ return 0.0; }, -double.max, double.max);
    findRoot((float x){ return 0.0f; }, -float.max, float.max);

/*
   int grandtotal=0;
   foreach (calls; alefeldSums)
   {
       grandtotal+=calls;
   }
   grandtotal-=2*numProblems;
   printf("\nALEFELD TOTAL = %d avg = %f (alefeld avg=19.3 for double)\n",
   grandtotal, (1.0*grandtotal)/numProblems);
   powercalls -= 2*powerProblems;
   printf("POWER TOTAL = %d avg = %f ", powercalls,
        (1.0*powercalls)/powerProblems);
*/
    // https://issues.dlang.org/show_bug.cgi?id=14231
    // Brackets that have a signed zero as one endpoint.
    auto xp = findRoot((float x) => x, 0f, 1f);
    auto xn = findRoot((float x) => x, -1f, -0f);
}
1687 
//regression control
@system unittest
{
    // Each flag records whether one particular call shape of findRoot compiles.
    enum floatArgRealResult =
        __traits(compiles, findRoot((float x)=>cast(real) x, float.init, float.init));
    // @system due to this case (the 2nd one)
    enum explicitRealUntypedLambda =
        __traits(compiles, findRoot!real((x)=>cast(double) x, real.init, real.init));
    enum realArgDoubleResult =
        __traits(compiles, findRoot((real x)=>cast(double) x, real.init, real.init));

    static assert(floatArgRealResult);
    static assert(explicitRealUntypedLambda);
    static assert(realArgDoubleResult);
}
1696 
1697 /++
1698 Find a real minimum of a real function `f(x)` via bracketing.
1699 Given a function `f` and a range `(ax .. bx)`,
1700 returns the value of `x` in the range which is closest to a minimum of `f(x)`.
`f` is never evaluated at the endpoints `ax` and `bx`.
1702 If `f(x)` has more than one minimum in the range, one will be chosen arbitrarily.
1703 If `f(x)` returns NaN or -Infinity, `(x, f(x), NaN)` will be returned;
1704 otherwise, this algorithm is guaranteed to succeed.
1705 
1706 Params:
1707     f = Function to be analyzed
1708     ax = Left bound of initial range of f known to contain the minimum.
1709     bx = Right bound of initial range of f known to contain the minimum.
1710     relTolerance = Relative tolerance.
1711     absTolerance = Absolute tolerance.
1712 
1713 Preconditions:
1714     `ax` and `bx` shall be finite reals. $(BR)
1715     `relTolerance` shall be normal positive real. $(BR)
    `absTolerance` shall be normal positive real no less than `T.epsilon*2`.
1717 
1718 Returns:
1719     A tuple consisting of `x`, `y = f(x)` and `error = 3 * (absTolerance * fabs(x) + relTolerance)`.
1720 
1721     The method used is a combination of golden section search and
1722 successive parabolic interpolation. Convergence is never much slower
1723 than that for a Fibonacci search.
1724 
1725 References:
1726     "Algorithms for Minimization without Derivatives", Richard Brent, Prentice-Hall, Inc. (1973)
1727 
1728 See_Also: $(LREF findRoot), $(REF isNormal, std,math)
1729 +/
Tuple!(T, "x", Unqual!(ReturnType!DF), "y", T, "error")
findLocalMin(T, DF)(
        scope DF f,
        const T ax,
        const T bx,
        const T relTolerance = sqrt(T.epsilon),
        const T absTolerance = sqrt(T.epsilon),
        )
if (isFloatingPoint!T
    && __traits(compiles, {T _ = DF.init(T.init);}))
in
{
    assert(isFinite(ax), "ax is not finite");
    assert(isFinite(bx), "bx is not finite");
    assert(isNormal(relTolerance), "relTolerance is not normal floating point number");
    assert(isNormal(absTolerance), "absTolerance is not normal floating point number");
    // This check is about `relTolerance`, so the message names `relTolerance`.
    assert(relTolerance >= 0, "relTolerance is not positive");
    assert(absTolerance >= T.epsilon*2, "absTolerance is not greater then `2*T.epsilon`");
}
out (result)
{
    assert(isFinite(result.x));
}
do
{
    // R is the type used for function values and step sizes.
    alias R = Unqual!(CommonType!(ReturnType!DF, T));
    // c is the squared inverse of the golden ratio
    // (3 - sqrt(5))/2
    // Value obtained from Wolfram Alpha.
    enum T c = 0x0.61c8864680b583ea0c633f9fa31237p+0L;
    enum T cm1 = 0x0.9e3779b97f4a7c15f39cc0605cedc8p+0L;
    R tolerance;
    // Order the bounds so that a <= b regardless of the argument order.
    T a = ax > bx ? bx : ax;
    T b = ax > bx ? ax : bx;
    // sequence of declarations suitable for SIMD instructions
    // First probe: the golden-section point of [a, b].
    T  v = a * cm1 + b * c;
    assert(isFinite(v));
    R fv = f(v);
    // NaN or -infinity ends the search at once; the error field is T.init (NaN).
    if (isNaN(fv) || fv == -T.infinity)
    {
        return typeof(return)(v, fv, T.init);
    }
    // x: best point so far, w: second best, v: previous value of w.
    T  w = v;
    R fw = fv;
    T  x = v;
    R fx = fv;
    // d: current step, e: step taken before the last one.
    for (R d = 0, e = 0;;)
    {
        T m = (a + b) / 2;
        // This fix is not part of the original algorithm
        if (!isFinite(m)) // fix infinite loop. Issue can be reproduced in R.
        {
            // a + b overflowed; halve before adding instead.
            m = a / 2 + b / 2;
            if (!isFinite(m)) // fast-math compiler switch is enabled
            {
                //SIMD instructions can be used by compiler, do not reduce declarations
                int a_exp = void;
                int b_exp = void;
                immutable an = frexp(a, a_exp);
                immutable bn = frexp(b, b_exp);
                // Halve by decrementing the binary exponent.
                immutable am = ldexp(an, a_exp-1);
                immutable bm = ldexp(bn, b_exp-1);
                m = am + bm;
                if (!isFinite(m)) // wrong input: constraints are disabled in release mode
                {
                    return typeof(return).init;
                }
            }
        }
        tolerance = absTolerance * fabs(x) + relTolerance;
        immutable t2 = tolerance * 2;
        // check stopping criterion
        if (!(fabs(x - m) > t2 - (b - a) / 2))
        {
            break;
        }
        R p = 0;
        R q = 0;
        R r = 0;
        // fit parabola
        if (fabs(e) > tolerance)
        {
            immutable  xw =  x -  w;
            immutable fxw = fx - fw;
            immutable  xv =  x -  v;
            immutable fxv = fx - fv;
            immutable xwfxv = xw * fxv;
            immutable xvfxw = xv * fxw;
            p = xv * xvfxw - xw * xwfxv;
            q = (xvfxw - xwfxv) * 2;
            if (q > 0)
                p = -p;
            else
                q = -q;
            r = e;
            e = d;
        }
        T u;
        // a parabolic-interpolation step
        if (fabs(p) < fabs(q * r / 2) && p > q * (a - x) && p < q * (b - x))
        {
            d = p / q;
            u = x + d;
            // f must not be evaluated too close to a or b
            if (u - a < t2 || b - u < t2)
                d = x < m ? tolerance : -tolerance;
        }
        // a golden-section step
        else
        {
            e = (x < m ? b : a) - x;
            d = c * e;
        }
        // f must not be evaluated too close to x
        u = x + (fabs(d) >= tolerance ? d : d > 0 ? tolerance : -tolerance);
        immutable fu = f(u);
        if (isNaN(fu) || fu == -T.infinity)
        {
            return typeof(return)(u, fu, T.init);
        }
        //  update  a, b, v, w, and x
        if (fu <= fx)
        {
            // u is the new best point; x becomes one end of the bracket.
            (u < x ? b : a) = x;
            v = w; fv = fw;
            w = x; fw = fx;
            x = u; fx = fu;
        }
        else
        {
            // x stays the best point; u becomes one end of the bracket.
            (u < x ? a : b) = u;
            if (fu <= fw || w == x)
            {
                v = w; fv = fw;
                w = u; fw = fu;
            }
            else if (fu <= fv || v == x || v == w)
            { // do not remove this braces
                v = u; fv = fu;
            }
        }
    }
    return typeof(return)(x, fx, tolerance * 3);
}
1876 
///
@safe unittest
{
    import std.math.operations : isClose;

    // Minimize the parabola (x-4)^2 over [-1e7, 1e7]: minimum at x = 4, value 0.
    auto ret = findLocalMin((double x) => (x-4)^^2, -1e7, 1e7);
    assert(ret.x.isClose(4.0));
    assert(ret.y.isClose(0.0, 0.0, 1e-10));
}
1886 
// Exercises findLocalMin for every built-in floating-point type.
@safe unittest
{
    import std.meta : AliasSeq;
    static foreach (T; AliasSeq!(double, float, real))
    {
        // Parabola with default tolerances; left bound is the smallest normal number.
        {
            auto ret = findLocalMin!T((T x) => (x-4)^^2, T.min_normal, 1e7);
            assert(ret.x.isClose(T(4)));
            assert(ret.y.isClose(T(0), 0.0, T.epsilon));
        }
        // |x - 1| over a huge bracket with the tightest tolerances used in this test.
        {
            auto ret = findLocalMin!T((T x) => fabs(x-1), -T.max/4, T.max/4, T.min_normal, 2*T.epsilon);
            assert(isClose(ret.x, T(1)));
            assert(isClose(ret.y, T(0), 0.0, T.epsilon));
            assert(ret.error <= 10 * T.epsilon);
        }
        // Callable returning T.init: x stays a number, y and error are NaN.
        {
            auto ret = findLocalMin!T((T x) => T.init, 0, 1, T.min_normal, 2*T.epsilon);
            assert(!ret.x.isNaN);
            assert(ret.y.isNaN);
            assert(ret.error.isNaN);
        }
        // log(x) on (0, 1): the result must lie within `error` of the left bound 0.
        {
            auto ret = findLocalMin!T((T x) => log(x), 0, 1, T.min_normal, 2*T.epsilon);
            assert(ret.error < 3.00001 * ((2*T.epsilon)*fabs(ret.x)+ T.min_normal));
            assert(ret.x >= 0 && ret.x <= ret.error);
        }
        // Same function with the right bound at T.max.
        {
            auto ret = findLocalMin!T((T x) => log(x), 0, T.max, T.min_normal, 2*T.epsilon);
            assert(ret.y < -18);
            assert(ret.error < 5e-08);
            assert(ret.x >= 0 && ret.x <= ret.error);
        }
        // -|x| on (-1, 1): the result must have |x| and |y| close to 1.
        {
            auto ret = findLocalMin!T((T x) => -fabs(x), -1, 1, T.min_normal, 2*T.epsilon);
            assert(ret.x.fabs.isClose(T(1)));
            assert(ret.y.fabs.isClose(T(1)));
            assert(ret.error.isClose(T(0), 0.0, 100*T.epsilon));
        }
    }
}
1928 
/**
Computes the $(LINK2 https://en.wikipedia.org/wiki/Euclidean_distance,
Euclidean distance) between input ranges `a` and `b`, i.e. the square root
of the sum of the squared element-wise differences. The two ranges must
have the same length. The three-parameter version stops computation as
soon as the distance is greater than or equal to `limit` (this is useful
to save computation if a small distance is sought).
 */
CommonType!(ElementType!(Range1), ElementType!(Range2))
euclideanDistance(Range1, Range2)(Range1 a, Range2 b)
if (isInputRange!(Range1) && isInputRange!(Range2))
{
    enum bool bothHaveLength = hasLength!(Range1) && hasLength!(Range2);
    // With known lengths the size check is made once, up front.
    static if (bothHaveLength)
        assert(a.length == b.length);

    Unqual!(typeof(return)) sumOfSquares = 0;
    while (!a.empty)
    {
        immutable diff = a.front - b.front;
        sumOfSquares += diff * diff;
        a.popFront();
        b.popFront();
    }

    // Without known lengths, `b` must be exhausted together with `a`.
    static if (!bothHaveLength)
        assert(b.empty);
    return sqrt(sumOfSquares);
}
1952 
/// Ditto
CommonType!(ElementType!(Range1), ElementType!(Range2))
euclideanDistance(Range1, Range2, F)(Range1 a, Range2 b, F limit)
if (isInputRange!(Range1) && isInputRange!(Range2))
{
    // The running sum holds squared distances, so it is compared against the
    // squared limit. A local is used instead of `limit *= limit` so that
    // `const`/`immutable` limit arguments are accepted too.
    const limitSquared = limit * limit;
    enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
    static if (haveLen) assert(a.length == b.length);
    Unqual!(typeof(return)) result = 0;
    for (; ; a.popFront(), b.popFront())
    {
        if (a.empty)
        {
            // Only a full traversal can tell whether `b` had the same length.
            static if (!haveLen) assert(b.empty);
            break;
        }
        immutable t = a.front - b.front;
        result += t * t;
        // Early exit: the distance already reached `limit`.
        if (result >= limitSquared) break;
    }
    return sqrt(result);
}
1975 
// Checks both euclideanDistance overloads for mutable, const and immutable elements.
@safe unittest
{
    import std.meta : AliasSeq;
    static foreach (T; AliasSeq!(double, const double, immutable double))
    {{
        // Element-wise differences are 3 and 4, so the full distance is 5.
        T[] a = [ 1.0, 2.0, ];
        T[] b = [ 4.0, 6.0, ];
        assert(euclideanDistance(a, b) == 5);
        assert(euclideanDistance(a, b, 6) == 5);
        assert(euclideanDistance(a, b, 5) == 5);
        assert(euclideanDistance(a, b, 4) == 5);
        // With limit 2 the computation stops after the first element (distance 3).
        assert(euclideanDistance(a, b, 2) == 3);
    }}
}
1990 
/**
Computes the $(LINK2 https://en.wikipedia.org/wiki/Dot_product,
dot product) of input ranges `a` and `b`. The two ranges must have the
same length. If both ranges define length, the check is done once;
otherwise, it is done at each iteration.
 */
CommonType!(ElementType!(Range1), ElementType!(Range2))
dotProduct(Range1, Range2)(Range1 a, Range2 b)
if (isInputRange!(Range1) && isInputRange!(Range2) &&
    !(isArray!(Range1) && isArray!(Range2)))
{
    enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
    static if (haveLen) assert(a.length == b.length);
    Unqual!(typeof(return)) result = 0;
    for (; !a.empty; a.popFront(), b.popFront())
    {
        // Per-iteration check promised by the documentation: never read the
        // front of an exhausted `b`.
        static if (!haveLen) assert(!b.empty, "dotProduct: b is shorter than a");
        result += a.front * b.front;
    }
    // `b` must not have elements left over once `a` is exhausted.
    static if (!haveLen) assert(b.empty, "dotProduct: b is longer than a");
    return result;
}
2013 
/// Ditto
CommonType!(F1, F2)
dotProduct(F1, F2)(in F1[] avector, in F2[] bvector)
{
    immutable len = avector.length;
    assert(len == bvector.length);

    auto pa = avector.ptr;
    auto pb = bvector.ptr;
    // Two independent accumulators: one for even offsets, one for odd offsets.
    Unqual!(typeof(return)) evenSum = 0;
    Unqual!(typeof(return)) oddSum = 0;

    // End markers, all measured from the start of `avector`:
    // the last whole block of 16, the last whole block of 4, and the very end.
    const endOfBlocks16 = pa + (len & ~15);
    const endOfBlocks4 = pa + (len & ~3);
    const endOfData = pa + len;

    // 16 elements per pass, unrolled at compile time.
    while (pa != endOfBlocks16)
    {
        static foreach (k; 0 .. 8)
        {
            evenSum += pa[2 * k] * pb[2 * k];
            oddSum += pa[2 * k + 1] * pb[2 * k + 1];
        }
        pa += 16;
        pb += 16;
    }

    // 4 elements per pass for what is left of the whole blocks of 4.
    while (pa != endOfBlocks4)
    {
        static foreach (k; 0 .. 2)
        {
            evenSum += pa[2 * k] * pb[2 * k];
            oddSum += pa[2 * k + 1] * pb[2 * k + 1];
        }
        pa += 4;
        pb += 4;
    }

    evenSum += oddSum;

    // Remaining (at most three) elements, one at a time.
    for (; pa != endOfData; ++pa, ++pb)
    {
        evenSum += *pa * *pb;
    }

    return evenSum;
}
2067 
/// ditto
F dotProduct(F, uint N)(const ref scope F[N] a, const ref scope F[N] b)
if (N <= 16)
{
    // Even positions go to one accumulator and odd positions to the other;
    // the whole loop is unrolled at compile time.
    F evenSum = 0;
    F oddSum = 0;
    static foreach (idx; 0 .. N)
    {
        static if (idx % 2 == 0)
            evenSum += a[idx] * b[idx];
        else
            oddSum += a[idx] * b[idx];
    }
    return evenSum + oddSum;
}
2085 
// Covers the slice, static-array and compile-time paths of dotProduct.
@system unittest
{
    // @system due to dotProduct and assertCTFEable
    import std.exception : assertCTFEable;
    import std.meta : AliasSeq;
    static foreach (T; AliasSeq!(double, const double, immutable double))
    {{
        T[] a = [ 1.0, 2.0, ];
        T[] b = [ 4.0, 6.0, ];
        assert(dotProduct(a, b) == 16);
        assert(dotProduct([1, 3, -5], [4, -2, -1]) == 3);
        // Test with fixed-length arrays.
        T[2] c = [ 1.0, 2.0, ];
        T[2] d = [ 4.0, 6.0, ];
        assert(dotProduct(c, d) == 16);
        T[3] e = [1,  3, -5];
        T[3] f = [4, -2, -1];
        assert(dotProduct(e, f) == 3);
    }}

    // Make sure the unrolled loop codepath gets tested.
    // 22 elements: one block of 16, one block of 4 and a 2-element tail.
    static const x =
        [1.0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22];
    static const y =
        [2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23];
    assertCTFEable!({ assert(dotProduct(x, y) == 4048); });
}
2113 
/**
Computes the $(LINK2 https://en.wikipedia.org/wiki/Cosine_similarity,
cosine similarity) of input ranges `a` and `b`. The two ranges must have
the same length. If both ranges define length, the check is done once;
otherwise, it is done at each iteration. If either range has all-zero
elements, return 0.
 */
CommonType!(ElementType!(Range1), ElementType!(Range2))
cosineSimilarity(Range1, Range2)(Range1 a, Range2 b)
if (isInputRange!(Range1) && isInputRange!(Range2))
{
    enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
    static if (haveLen) assert(a.length == b.length);
    // Squared norms of both inputs and their dot product, gathered in one pass.
    Unqual!(typeof(return)) norma = 0, normb = 0, dotprod = 0;
    for (; !a.empty; a.popFront(), b.popFront())
    {
        // Per-iteration check promised by the documentation: never read the
        // front of an exhausted `b`.
        static if (!haveLen) assert(!b.empty, "cosineSimilarity: b is shorter than a");
        immutable t1 = a.front, t2 = b.front;
        norma += t1 * t1;
        normb += t2 * t2;
        dotprod += t1 * t2;
    }
    static if (!haveLen) assert(b.empty, "cosineSimilarity: b is longer than a");
    // A zero squared norm would make the quotient below 0/0.
    if (norma == 0 || normb == 0) return 0;
    return dotprod / sqrt(norma * normb);
}
2139 
// Checks cosineSimilarity for mutable, const and immutable elements.
@safe unittest
{
    import std.meta : AliasSeq;
    static foreach (T; AliasSeq!(double, const double, immutable double))
    {{
        // dot product 1*4 + 2*3 = 10, squared norms 5 and 25
        T[] a = [ 1.0, 2.0, ];
        T[] b = [ 4.0, 3.0, ];
        assert(isClose(
                    cosineSimilarity(a, b), 10.0 / sqrt(5.0 * 25),
                    0.01));
    }}
}
2152 
/**
Normalizes values in `range` by multiplying each element with a
number chosen such that values sum up to `sum`. If elements in
`range` sum to zero, assigns $(D sum / length) to all of them, where
`length` is the number of elements in `range`. Normalization makes sense
only if all elements in `range` are positive. `normalize` assumes that is
the case without checking it.

Works with any forward range whose elements can be assigned through
`ref`, whether or not the range defines `length`.

Params:
    range = The elements to rescale in place.
    sum = The total the elements should add up to afterwards.

Returns: `true` if normalization completed normally, `false` if
the elements in `range` summed to zero or if `range` is empty.
 */
bool normalize(R)(R range, ElementType!(R) sum = 1)
if (isForwardRange!(R))
{
    ElementType!(R) s = 0;
    // Step 1: Compute sum and length of the range.
    // The pass runs over a saved copy so that `range` itself is still at its
    // start for the write pass below.
    static if (hasLength!(R))
    {
        const length = range.length;
        foreach (e; range.save)
        {
            s += e;
        }
    }
    else
    {
        // No `length` property: count the elements while summing them.
        size_t length = 0;
        foreach (e; range.save)
        {
            s += e;
            ++length;
        }
    }
    // Step 2: perform normalization
    if (s == 0)
    {
        if (length)
        {
            // Use the `length` obtained in step 1; `range.length` does not
            // exist for ranges without a length property.
            immutable f = sum / length;
            foreach (ref e; range) e = f;
        }
        return false;
    }
    // The path most traveled
    assert(s >= 0);
    immutable f = sum / s;
    foreach (ref e; range)
        e *= f;
    return true;
}
2202 
///
@safe unittest
{
    // An empty range cannot be normalized.
    double[] a = [];
    assert(!normalize(a));
    // Elements summing to 4 are scaled by 1/4.
    a = [ 1.0, 3.0 ];
    assert(normalize(a));
    assert(a == [ 0.25, 0.75 ]);
    assert(normalize!(typeof(a))(a, 50)); // a = [12.5, 37.5]
    // All-zero input: every element receives sum / length and false is returned.
    a = [ 0.0, 0.0 ];
    assert(!normalize(a));
    assert(a == [ 0.5, 0.5 ]);
}
2216 
/**
Compute the sum of binary logarithms of the input range `r`.
The error of this method is much smaller than with a naive sum of log2.

Each element is split with `frexp` into a fraction and a binary exponent;
the exponents are added up as integers and only the running product of
the fractions goes through a single `log2` call at the end. A negative
element makes the result NaN.
 */
ElementType!Range sumOfLog2s(Range)(Range r)
if (isInputRange!Range && isFloatingPoint!(ElementType!Range))
{
    long exponentSum = 0;
    Unqual!(typeof(return)) fractionProduct = 1;
    foreach (value; r)
    {
        if (value < 0)
            return typeof(return).nan;

        int valueExponent = void;
        fractionProduct *= frexp(value, valueExponent);
        exponentSum += valueExponent;

        // Rescale the running product by 2 whenever it drops below 0.5 and
        // account for the factor in the exponent sum.
        if (!(fractionProduct < 0.5))
            continue;
        fractionProduct *= 2;
        --exponentSum;
    }
    return exponentSum + log2(fractionProduct);
}
2241 
///
@safe unittest
{
    import std.math.traits : isNaN;

    // An empty range sums to 0.
    assert(sumOfLog2s(new double[0]) == 0);
    // Zero of either sign gives -infinity.
    assert(sumOfLog2s([0.0L]) == -real.infinity);
    assert(sumOfLog2s([-0.0L]) == -real.infinity);
    assert(sumOfLog2s([2.0L]) == 1);
    // Negative and NaN inputs give NaN.
    assert(sumOfLog2s([-2.0L]).isNaN());
    assert(sumOfLog2s([real.nan]).isNaN());
    assert(sumOfLog2s([-real.nan]).isNaN());
    assert(sumOfLog2s([real.infinity]) == real.infinity);
    assert(sumOfLog2s([-real.infinity]).isNaN());
    // -2 - 2 - 2 - 3 = -9
    assert(sumOfLog2s([ 0.25, 0.25, 0.25, 0.125 ]) == -9);
}
2258 
/**
Computes $(LINK2 https://en.wikipedia.org/wiki/Entropy_(information_theory),
_entropy) of input range `r` in bits. This
function assumes (without checking) that the values in `r` are all
in $(D [0, 1]). For the entropy to be meaningful, often `r` should
be normalized too (i.e., its values should sum to 1). The
two-parameter version stops evaluating as soon as the intermediate
result is greater than or equal to `max`.
 */
ElementType!Range entropy(Range)(Range r)
if (isInputRange!Range)
{
    Unqual!(typeof(return)) result = 0.0;
    for (;!r.empty; r.popFront)
    {
        // Read the element once; `front` may be computed lazily.
        immutable p = r.front;
        // A zero element contributes nothing (and log2(0) is -infinity).
        if (!p) continue;
        result -= p * log2(p);
    }
    return result;
}
2279 
/// Ditto
ElementType!Range entropy(Range, F)(Range r, F max)
if (isInputRange!Range &&
    !is(CommonType!(ElementType!Range, F) == void))
{
    Unqual!(typeof(return)) result = 0.0;
    for (;!r.empty; r.popFront)
    {
        // Read the element once; `front` may be computed lazily.
        immutable p = r.front;
        // A zero element contributes nothing (and log2(0) is -infinity).
        if (!p) continue;
        result -= p * log2(p);
        // Early exit once the partial sum has reached `max`.
        if (result >= max) break;
    }
    return result;
}
2294 
// Checks both entropy overloads for mutable, const and immutable elements.
@safe unittest
{
    import std.meta : AliasSeq;
    static foreach (T; AliasSeq!(double, const double, immutable double))
    {{
        // A distribution concentrated on one outcome has zero entropy.
        T[] p = [ 0.0, 0, 0, 1 ];
        assert(entropy(p) == 0);
        // Four equally likely outcomes: 2 bits.
        p = [ 0.25, 0.25, 0.25, 0.25 ];
        assert(entropy(p) == 2);
        // With max == 1 the evaluation stops as soon as the partial sum reaches 1.
        assert(entropy(p, 1) == 1);
    }}
}
2307 
/**
Computes the $(LINK2 https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence,
Kullback-Leibler divergence) between input ranges
`a` and `b`, which is the sum $(D ai * log(ai / bi)). The base
of logarithm is 2. The ranges are assumed to contain elements in
$(D [0, 1]). Usually the ranges are normalized probability distributions,
but this is not required or checked by `kullbackLeiblerDivergence`.
If any element `bi` is zero and the corresponding element `ai` nonzero,
returns infinity. (Otherwise, if $(D ai == 0), the term
$(D ai * log(ai / bi)) is considered zero.) If the inputs are normalized,
the result is non-negative.
 */
CommonType!(ElementType!Range1, ElementType!Range2)
kullbackLeiblerDivergence(Range1, Range2)(Range1 a, Range2 b)
if (isInputRange!(Range1) && isInputRange!(Range2))
{
    enum bool bothHaveLength = hasLength!(Range1) && hasLength!(Range2);
    static if (bothHaveLength)
        assert(a.length == b.length);

    Unqual!(typeof(return)) divergence = 0;
    while (!a.empty)
    {
        immutable pa = a.front;
        // Terms with ai == 0 are taken to be zero; `b.front` is not even read.
        if (pa != 0)
        {
            immutable pb = b.front;
            // ai nonzero and bi zero: the divergence is infinite.
            if (pb == 0)
                return typeof(divergence).infinity;
            assert(pa > 0 && pb > 0);
            divergence += pa * log2(pa / pb);
        }
        a.popFront();
        b.popFront();
    }

    // Without known lengths, `b` must be exhausted together with `a`.
    static if (!bothHaveLength)
        assert(b.empty);
    return divergence;
}
2340 
///
@safe unittest
{
    import std.math.operations : isClose;

    // The divergence of a distribution from itself is zero.
    double[] p = [ 0.0, 0, 0, 1 ];
    assert(kullbackLeiblerDivergence(p, p) == 0);
    double[] p1 = [ 0.25, 0.25, 0.25, 0.25 ];
    assert(kullbackLeiblerDivergence(p1, p1) == 0);
    // Only the last element of p is nonzero: 1 * log2(1 / 0.25) = 2.
    assert(kullbackLeiblerDivergence(p, p1) == 2);
    // p has zeros where p1 is nonzero, so the result is infinite.
    assert(kullbackLeiblerDivergence(p1, p) == double.infinity);
    // The two argument orders give different values.
    double[] p2 = [ 0.2, 0.2, 0.2, 0.4 ];
    assert(isClose(kullbackLeiblerDivergence(p1, p2), 0.0719281, 1e-5));
    assert(isClose(kullbackLeiblerDivergence(p2, p1), 0.0780719, 1e-5));
}
2356 
/**
Computes the $(LINK2 https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence,
Jensen-Shannon divergence) between `a` and `b`, which is the sum
$(D (ai * log(2 * ai / (ai + bi)) + bi * log(2 * bi / (ai + bi))) / 2).
The base of logarithm is 2. The ranges are assumed to contain elements in
$(D [0, 1]). Usually the ranges are normalized probability distributions,
but this is not required or checked by `jensenShannonDivergence`. If the
inputs are normalized, the result is bounded within $(D [0, 1]). The
three-parameter version stops evaluations as soon as the intermediate
result is greater than or equal to `limit`.
 */
CommonType!(ElementType!Range1, ElementType!Range2)
jensenShannonDivergence(Range1, Range2)(Range1 a, Range2 b)
if (isInputRange!Range1 && isInputRange!Range2 &&
    is(CommonType!(ElementType!Range1, ElementType!Range2)))
{
    enum bool bothHaveLength = hasLength!(Range1) && hasLength!(Range2);
    static if (bothHaveLength)
        assert(a.length == b.length);

    // Accumulates twice the divergence; the halving is done once on return.
    Unqual!(typeof(return)) twiceDivergence = 0;
    while (!a.empty)
    {
        immutable pa = a.front;
        immutable pb = b.front;
        immutable mid = (pa + pb) / 2;
        // A zero element contributes no term.
        if (pa != 0)
            twiceDivergence += pa * log2(pa / mid);
        if (pb != 0)
            twiceDivergence += pb * log2(pb / mid);
        a.popFront();
        b.popFront();
    }

    // Without known lengths, `b` must be exhausted together with `a`.
    static if (!bothHaveLength)
        assert(b.empty);
    return twiceDivergence / 2;
}
2394 
/// Ditto
CommonType!(ElementType!Range1, ElementType!Range2)
jensenShannonDivergence(Range1, Range2, F)(Range1 a, Range2 b, F limit)
if (isInputRange!Range1 && isInputRange!Range2 &&
    is(typeof(CommonType!(ElementType!Range1, ElementType!Range2).init
    >= F.init) : bool))
{
    enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
    static if (haveLen) assert(a.length == b.length);
    Unqual!(typeof(return)) result = 0;
    // `result` accumulates twice the divergence, so it is compared against
    // twice the limit. A local is used instead of `limit *= 2` so that
    // `const`/`immutable` limit arguments are accepted too.
    const doubledLimit = limit * 2;
    for (; ; a.popFront(), b.popFront())
    {
        if (a.empty)
        {
            // The equal-length check is only meaningful after a full
            // traversal; after an early exit on `limit`, `b` legitimately
            // still has elements.
            static if (!haveLen) assert(b.empty);
            break;
        }
        immutable t1 = a.front;
        immutable t2 = b.front;
        immutable avg = (t1 + t2) / 2;
        // A zero element contributes no term.
        if (t1 != 0)
        {
            result += t1 * log2(t1 / avg);
        }
        if (t2 != 0)
        {
            result += t2 * log2(t2 / avg);
        }
        if (result >= doubledLimit) break;
    }
    return result / 2;
}
2424 
///
@safe unittest
{
    import std.math.operations : isClose;

    // The divergence of a distribution from itself is zero.
    double[] p = [ 0.0, 0, 0, 1 ];
    assert(jensenShannonDivergence(p, p) == 0);
    double[] p1 = [ 0.25, 0.25, 0.25, 0.25 ];
    assert(jensenShannonDivergence(p1, p1) == 0);
    assert(isClose(jensenShannonDivergence(p1, p), 0.548795, 1e-5));
    // Both argument orders give the same value.
    double[] p2 = [ 0.2, 0.2, 0.2, 0.4 ];
    assert(isClose(jensenShannonDivergence(p1, p2), 0.0186218, 1e-5));
    assert(isClose(jensenShannonDivergence(p2, p1), 0.0186218, 1e-5));
    // With a limit the evaluation stops early and returns the partial result.
    assert(isClose(jensenShannonDivergence(p2, p1, 0.005), 0.00602366, 1e-5));
}
2440 
2441 /**
2442 The so-called "all-lengths gap-weighted string kernel" computes a
2443 similarity measure between `s` and `t` based on all of their
2444 common subsequences of all lengths. Gapped subsequences are also
2445 included.
2446 
2447 To understand what $(D gapWeightedSimilarity(s, t, lambda)) computes,
2448 consider first the case $(D lambda = 1) and the strings $(D s =
2449 ["Hello", "brave", "new", "world"]) and $(D t = ["Hello", "new",
2450 "world"]). In that case, `gapWeightedSimilarity` counts the
2451 following matches:
2452 
2453 $(OL $(LI three matches of length 1, namely `"Hello"`, `"new"`,
2454 and `"world"`;) $(LI three matches of length 2, namely ($(D
2455 "Hello", "new")), ($(D "Hello", "world")), and ($(D "new", "world"));)
2456 $(LI one match of length 3, namely ($(D "Hello", "new", "world")).))
2457 
2458 The call $(D gapWeightedSimilarity(s, t, 1)) simply counts all of
2459 these matches and adds them up, returning 7.
2460 
2461 ----
2462 string[] s = ["Hello", "brave", "new", "world"];
2463 string[] t = ["Hello", "new", "world"];
2464 assert(gapWeightedSimilarity(s, t, 1) == 7);
2465 ----
2466 
2467 Note how the gaps in matching are simply ignored, for example ($(D
2468 "Hello", "new")) is deemed as good a match as ($(D "new",
2469 "world")). This may be too permissive for some applications. To
2470 eliminate gapped matches entirely, use $(D lambda = 0):
2471 
2472 ----
2473 string[] s = ["Hello", "brave", "new", "world"];
2474 string[] t = ["Hello", "new", "world"];
2475 assert(gapWeightedSimilarity(s, t, 0) == 4);
2476 ----
2477 
2478 The call above eliminated the gapped matches ($(D "Hello", "new")),
2479 ($(D "Hello", "world")), and ($(D "Hello", "new", "world")) from the
2480 tally. That leaves only 4 matches.
2481 
2482 The most interesting case is when gapped matches still participate in
2483 the result, but not as strongly as ungapped matches. The result will
2484 be a smooth, fine-grained similarity measure between the input
2485 strings. This is where values of `lambda` between 0 and 1 enter
2486 into play: gapped matches are $(I exponentially penalized with the
2487 number of gaps) with base `lambda`. This means that an ungapped
2488 match adds 1 to the return value; a match with one gap in either
2489 string adds `lambda` to the return value; ...; a match with a total
2490 of `n` gaps in both strings adds $(D pow(lambda, n)) to the return
2491 value. In the example above, we have 4 matches without gaps, 2 matches
2492 with one gap, and 1 match with three gaps. The latter match is ($(D
2493 "Hello", "world")), which has two gaps in the first string and one gap
2494 in the second string, totaling to three gaps. Summing these up we get
2495 $(D 4 + 2 * lambda + pow(lambda, 3)).
2496 
2497 ----
2498 string[] s = ["Hello", "brave", "new", "world"];
2499 string[] t = ["Hello", "new", "world"];
2500 assert(gapWeightedSimilarity(s, t, 0.5) == 4 + 0.5 * 2 + 0.125);
2501 ----
2502 
2503 `gapWeightedSimilarity` is useful wherever a smooth similarity
2504 measure between sequences allowing for approximate matches is
2505 needed. The examples above are given with words, but any sequences
2506 with elements comparable for equality are allowed, e.g. characters or
2507 numbers. `gapWeightedSimilarity` uses a highly optimized dynamic
2508 programming implementation that needs $(D 16 * min(s.length,
2509 t.length)) extra bytes of memory and $(BIGOH s.length * t.length) time
2510 to complete.
2511  */
F gapWeightedSimilarity(alias comp = "a == b", R1, R2, F)(R1 s, R2 t, F lambda)
if (isRandomAccessRange!(R1) && hasLength!(R1) &&
    isRandomAccessRange!(R2) && hasLength!(R2))
{
    // The working buffers are sized by the second range handed to the helper,
    // so the shorter range goes there. When the roles are exchanged the helper
    // is told so, and it keeps calling `comp` with an element of `s` first and
    // an element of `t` second.
    if (s.length < t.length)
        return gapWeightedSimilarityImpl!(comp, true)(t, s, lambda);
    return gapWeightedSimilarityImpl!(comp, false)(s, t, lambda);
}

// Dynamic-programming core of `gapWeightedSimilarity`.
// `s` is the range walked by the outer loop and `t` the range that sizes the
// two work rows; the caller passes the shorter range as `t`.
// `swapped` is true when the caller exchanged its two ranges; the arguments
// of `comp` are then exchanged back.
private F gapWeightedSimilarityImpl(alias comp, bool swapped, R1, R2, F)
        (R1 s, R2 t, F lambda)
{
    import core.exception : onOutOfMemoryError;
    import core.stdc.stdlib : malloc, free;
    import std.algorithm.mutation : swap;
    import std.functional : binaryFun;

    if (!t.length) return 0;

    // One allocation holding two rows of t.length entries each.
    auto dpvi = cast(F*) malloc(F.sizeof * 2 * t.length);
    if (!dpvi)
        onOutOfMemoryError();

    auto dpvi1 = dpvi + t.length;
    // The rows are swapped after every outer iteration, so the lower of the
    // two pointers is the start of the allocation.
    scope(exit) free(dpvi < dpvi1 ? dpvi : dpvi1);
    dpvi[0 .. t.length] = 0;
    dpvi1[0] = 0;
    immutable lambda2 = lambda * lambda;

    F result = 0;
    foreach (i; 0 .. s.length)
    {
        const si = s[i];
        for (size_t j = 0;;)
        {
            F dpsij = void;
            static if (swapped)
                const bool matched = cast(bool) binaryFun!(comp)(t[j], si);
            else
                const bool matched = cast(bool) binaryFun!(comp)(si, t[j]);
            if (matched)
            {
                dpsij = 1 + dpvi[j];
                result += dpsij;
            }
            else
            {
                dpsij = 0;
            }
            immutable j1 = j + 1;
            if (j1 == t.length) break;
            dpvi1[j1] = dpsij + lambda * (dpvi1[j] + dpvi[j1]) -
                        lambda2 * dpvi[j];
            j = j1;
        }
        swap(dpvi, dpvi1);
    }
    return result;
}
2560 
// Mirrors the three examples given in the gapWeightedSimilarity documentation.
@system unittest
{
    string[] s = ["Hello", "brave", "new", "world"];
    string[] t = ["Hello", "new", "world"];
    // lambda = 1
    assert(gapWeightedSimilarity(s, t, 1) == 7);
    // lambda = 0
    assert(gapWeightedSimilarity(s, t, 0) == 4);
    // lambda = 0.5
    assert(gapWeightedSimilarity(s, t, 0.5) == 4 + 2 * 0.5 + 0.125);
}
2569 
2570 /**
2571 The similarity per `gapWeightedSimilarity` has an issue in that it
2572 grows with the lengths of the two strings, even though the strings are
2573 not actually very similar. For example, the range $(D ["Hello",
2574 "world"]) is increasingly similar with the range $(D ["Hello",
2575 "world", "world", "world",...]) as more instances of `"world"` are
2576 appended. To prevent that, `gapWeightedSimilarityNormalized`
2577 computes a normalized version of the similarity that is computed as
$(D gapWeightedSimilarity(s, t, lambda) /
sqrt(gapWeightedSimilarity(s, s, lambda) * gapWeightedSimilarity(t, t,
lambda))). The function `gapWeightedSimilarityNormalized` (a
2581 so-called normalized kernel) is bounded in $(D [0, 1]), reaches `0`
2582 only for ranges that don't match in any position, and `1` only for
2583 identical ranges.
2584 
2585 The optional parameters `sSelfSim` and `tSelfSim` are meant for
2586 avoiding duplicate computation. Many applications may have already
2587 computed $(D gapWeightedSimilarity(s, s, lambda)) and/or $(D
2588 gapWeightedSimilarity(t, t, lambda)). In that case, they can be passed
2589 as `sSelfSim` and `tSelfSim`, respectively.
2590  */
Select!(isFloatingPoint!(F), F, double)
gapWeightedSimilarityNormalized(alias comp = "a == b", R1, R2, F)
        (R1 s, R2 t, F lambda, F sSelfSim = F.init, F tSelfSim = F.init)
if (isRandomAccessRange!(R1) && hasLength!(R1) &&
    isRandomAccessRange!(R2) && hasLength!(R2))
{
    // A self-similarity counts as "not supplied" when it is NaN (floating
    // point `F`) or still equal to `F.init` (any other `F`).
    static bool isMissing(F value)
    {
        static if (isFloatingPoint!(F))
        {
            return isNaN(value);
        }
        else
        {
            return value == F.init;
        }
    }

    // Self-similarity of `s`; a zero makes the quotient meaningless, so the
    // function returns 0 right away.
    if (isMissing(sSelfSim))
    {
        sSelfSim = gapWeightedSimilarity!(comp)(s, s, lambda);
    }
    if (sSelfSim == 0)
    {
        return 0;
    }

    // Same treatment for `t`.
    if (isMissing(tSelfSim))
    {
        tSelfSim = gapWeightedSimilarity!(comp)(t, t, lambda);
    }
    if (tSelfSim == 0)
    {
        return 0;
    }

    // Cross similarity divided by the geometric mean of the self similarities.
    return gapWeightedSimilarity!(comp)(s, t, lambda)
        / sqrt(cast(typeof(return)) sSelfSim * tSelfSim);
}
2614 
2615 ///
@system unittest
{
    import std.math.algebraic : sqrt;
    import std.math.operations : isClose;

    string[] s = ["Hello", "brave", "new", "world"];
    string[] t = ["Hello", "new", "world"];

    // Raw similarities: each range against itself, then against the other.
    assert(gapWeightedSimilarity(s, s, 1) == 15);
    assert(gapWeightedSimilarity(t, t, 1) == 7);
    assert(gapWeightedSimilarity(s, t, 1) == 7);

    // The normalized value is the cross similarity over the square root of
    // the product of the two self similarities.
    immutable expected = 7.0 / sqrt(15.0 * 7);
    assert(isClose(gapWeightedSimilarityNormalized(s, t, 1), expected, 0.01));
}
2629 
2630 /**
2631 Similar to `gapWeightedSimilarity`, just works in an incremental
2632 manner by first revealing the matches of length 1, then gapped matches
2633 of length 2, and so on. The memory requirement is $(BIGOH s.length *
2634 t.length). The time complexity is $(BIGOH s.length * t.length) time
2635 for computing each step. Continuing on the previous example:
2636 
2637 The implementation is based on the pseudocode in Fig. 4 of the paper
2638 $(HTTP jmlr.csail.mit.edu/papers/volume6/rousu05a/rousu05a.pdf,
2639 "Efficient Computation of Gapped Substring Kernels on Large Alphabets")
2640 by Rousu et al., with additional algorithmic and systems-level
2641 optimizations.
2642  */
struct GapWeightedSimilarityIncremental(Range, F = double)
if (isRandomAccessRange!(Range) && hasLength!(Range))
{
    import core.stdc.stdlib : malloc, realloc, free;

private:
    // The two inputs, trimmed by the constructor to the smallest slices that
    // still contain every matching position.
    Range s, t;
    // Value reported by `front`; also drives `empty` (zero means exhausted).
    F currentValue = 0;
    // malloc'ed table of s.length * t.length entries, indexed
    // [i * t.length + j]. Released by `empty` once `currentValue` is zero.
    F* kl;
    // Number of `popFront` steps performed so far.
    size_t gram = void;
    F lambda = void, lambda2 = void;

public:
/**
Constructs an object given two ranges `s` and `t` and a penalty
`lambda`. Constructor completes in $(BIGOH s.length * t.length)
time and computes all matches of length 1.

Params:
    s = first range
    t = second range
    lambda = gap penalty; must be greater than zero (checked with `assert`)
 */
    this(Range s, Range t, F lambda)
    {
        import core.exception : onOutOfMemoryError;

        assert(lambda > 0);
        this.gram = 0;
        this.lambda = lambda;
        this.lambda2 = lambda * lambda; // for efficiency only

        size_t iMin = size_t.max, jMin = size_t.max,
            iMax = 0, jMax = 0;

        /* Collect every (i, j) with s[i] == t[j] in a temporary buffer that
           grows geometrically. */
        alias Match = Tuple!(size_t, size_t);
        Match* k0;
        size_t k0len, k0cap;
        scope(exit) free(k0);
        currentValue = 0;
        foreach (i, si; s)
        {
            foreach (j; 0 .. t.length)
            {
                if (si != t[j]) continue;
                if (k0len == k0cap)
                {
                    immutable newCap = k0cap ? k0cap * 2 : 16;
                    // Keep the old pointer until realloc is known to have
                    // succeeded, so it is still freed on failure.
                    auto grown = cast(Match*) realloc(k0, newCap * Match.sizeof);
                    if (!grown)
                        onOutOfMemoryError();
                    k0 = grown;
                    k0cap = newCap;
                }
                k0[k0len][0] = i;
                k0[k0len][1] = j;
                ++k0len;
                // Maintain the minimum and maximum i and j
                if (iMin > i) iMin = i;
                if (iMax < i) iMax = i;
                if (jMin > j) jMin = j;
                if (jMax < j) jMax = j;
            }
        }

        // No match at all: leave the object in its empty state.
        if (iMin > iMax) return;
        assert(k0len);

        currentValue = k0len;
        // Chop strings down to the useful sizes
        s = s[iMin .. iMax + 1];
        t = t[jMin .. jMax + 1];
        this.s = s;
        this.t = t;

        kl = cast(F*) malloc(s.length * t.length * F.sizeof);
        if (!kl)
            onOutOfMemoryError();

        // Seed the table: lambda2 at every matching position (re-based to the
        // trimmed ranges), zero elsewhere.
        kl[0 .. s.length * t.length] = 0;
        foreach (pos; 0 .. k0len)
        {
            kl[(k0[pos][0] - iMin) * t.length + k0[pos][1] - jMin] = lambda2;
        }
    }

    /**
    Returns: `this`.
     */
    ref GapWeightedSimilarityIncremental opSlice()
    {
        return this;
    }

    /**
    Computes the match of the next length. Completes in $(BIGOH s.length *
    t.length) time.
     */
    void popFront()
    {
        import core.exception : onOutOfMemoryError;

        // This is a large source of optimization: if similarity at
        // the gram-1 level was 0, then we can safely assume
        // similarity at the gram level is 0 as well.
        if (empty) return;

        // Now attempt to match gapped substrings of length `gram'
        ++gram;
        currentValue = 0;

        // Scratch row of t.length entries. It lives on the heap because its
        // size depends on the input.
        auto Si = cast(F*) malloc(t.length * F.sizeof);
        if (!Si)
            onOutOfMemoryError();
        scope(exit) free(Si);
        Si[0 .. t.length] = 0;

        foreach (i; 0 .. s.length)
        {
            const si = s[i];
            F Sij_1 = 0;
            F Si_1j_1 = 0;
            auto kli = kl + i * t.length;
            for (size_t j = 0;;)
            {
                const klij = kli[j];
                const Si_1j = Si[j];
                const tmp = klij + lambda * (Si_1j + Sij_1) - lambda2 * Si_1j_1;
                // now update kl and currentValue
                if (si == t[j])
                    currentValue += kli[j] = lambda2 * Si_1j_1;
                else
                    kli[j] = 0;
                // commit to Si
                Si[j] = tmp;
                if (++j == t.length) break;
                // get ready for the next step; virtually increment j,
                // so essentially stuffj_1 <-- stuffj
                Si_1j_1 = Si_1j;
                Sij_1 = tmp;
            }
        }
        currentValue /= pow(lambda, 2 * (gram + 1));
    }

    /**
    Returns: The gapped similarity at the current match length (initially
    1, grows with each call to `popFront`).
    */
    @property F front() { return currentValue; }

    /**
    Returns: Whether there are more matches.

    As a side effect, the internal table is freed the first time this
    returns `true`.
     */
    @property bool empty()
    {
        if (currentValue) return false;
        if (kl)
        {
            free(kl);
            kl = null;
        }
        return true;
    }
}
2832 
2833 /**
2834 Ditto
2835  */
GapWeightedSimilarityIncremental!(R, F) gapWeightedSimilarityIncremental(R, F)
(R r1, R r2, F penalty)
{
    // Convenience factory: lets R and F be deduced from the arguments.
    alias Result = GapWeightedSimilarityIncremental!(R, F);
    return Result(r1, r2, penalty);
}
2841 
2842 ///
@system unittest
{
    string[] first = ["Hello", "brave", "new", "world"];
    string[] second = ["Hello", "new", "world"];
    auto steps = gapWeightedSimilarityIncremental(first, second, 1.0);

    // three 1-length matches
    assert(steps.front == 3);

    // three 2-length matches
    steps.popFront();
    assert(steps.front == 3);

    // one 3-length match
    steps.popFront();
    assert(steps.front == 1);

    // no more match
    steps.popFront();
    assert(steps.empty);
}
2856 
@system unittest
{
    import std.conv : text;

    string[] s = ["Hello", "brave", "new", "world"];
    string[] t = ["Hello", "new", "world"];
    auto it = gapWeightedSimilarityIncremental(s, t, 1.0);

    // three 1-length matches
    assert(it.front == 3);
    it.popFront();
    // three 2-length matches
    assert(it.front == 3, text(it.front));
    it.popFront();
    // one 3-length match
    assert(it.front == 1);
    it.popFront();
    // no more match
    assert(it.empty);

    // Nothing in common: empty from the start.
    s = ["Hello"];
    t = ["bye"];
    it = gapWeightedSimilarityIncremental(s, t, 0.5);
    assert(it.empty);

    // Identical single-element ranges: one match, then done.
    s = ["Hello"];
    t = ["Hello"];
    it = gapWeightedSimilarityIncremental(s, t, 0.5);
    assert(it.front == 1);
    it.popFront();
    assert(it.empty);

    // Extra trailing element in `s`: still one match, then done.
    s = ["Hello", "world"];
    t = ["Hello"];
    it = gapWeightedSimilarityIncremental(s, t, 0.5);
    assert(it.front == 1);
    it.popFront();
    assert(it.empty);

    // Two 1-gram matches, then one 2-gram match with 1 gap.
    s = ["Hello", "world"];
    t = ["Hello", "yah", "world"];
    it = gapWeightedSimilarityIncremental(s, t, 0.5);
    assert(it.front == 2);
    it.popFront();
    assert(it.front == 0.5, text(it.front));
}
2898 
@system unittest
{
    alias Sim = GapWeightedSimilarityIncremental!(string[]);

    // First pair: iterate with foreach and compare against the expected
    // values one by one.
    Sim sim = Sim(
        ["nyuk", "I", "have", "no", "chocolate", "giba"],
        ["wyda", "I", "have", "I", "have", "have", "I", "have", "hehe"],
        0.5);
    double[] witness = [ 7.0, 4.03125, 0, 0 ];
    foreach (value; sim)
    {
        assert(value == witness.front);
        witness.popFront();
    }

    // Second pair: here every expected value must be consumed.
    witness = [ 3.0, 1.3125, 0.25 ];
    sim = Sim(
        ["I", "have", "no", "chocolate"],
        ["I", "have", "some", "chocolate"],
        0.5);
    foreach (value; sim)
    {
        assert(value == witness.front);
        witness.popFront();
    }
    assert(witness.empty);
}
2926 
2927 /**
2928 Computes the greatest common divisor of `a` and `b` by using
2929 an efficient algorithm such as $(HTTPS en.wikipedia.org/wiki/Euclidean_algorithm, Euclid's)
2930 or $(HTTPS en.wikipedia.org/wiki/Binary_GCD_algorithm, Stein's) algorithm.
2931 
2932 Params:
2933     a = Integer value of any numerical type that supports the modulo operator `%`.
2934         If bit-shifting `<<` and `>>` are also supported, Stein's algorithm will
2935         be used; otherwise, Euclid's algorithm is used as _a fallback.
2936     b = Integer value of any equivalent numerical type.
2937 
2938 Returns:
2939     The greatest common divisor of the given arguments.
2940  */
typeof(Unqual!(T).init % Unqual!(U).init) gcd(T, U)(T a, U b)
if (isIntegral!T && isIntegral!U)
{
    // All work happens on the unsigned flavour of the arguments' common type.
    alias Wide = Unsigned!(CommonType!(Unqual!T, Unqual!U));

    // Types that implicitly convert to short or byte are negated as `int`
    // and cast; every other type is negated directly in `Wide`.
    // (`std.math.abs` doesn't support unsigned integers, and `T.min` is
    // undefined.)
    enum bool negateAViaInt = is(T : immutable short) || is(T : immutable byte);
    enum bool negateBViaInt = is(U : immutable short) || is(U : immutable byte);

    static if (negateAViaInt)
        Wide magA = (isUnsigned!T || a >= 0) ? a : cast(Wide) -int(a);
    else
        Wide magA = (isUnsigned!T || a >= 0) ? a : -Wide(a);

    static if (negateBViaInt)
        Wide magB = (isUnsigned!U || b >= 0) ? b : cast(Wide) -int(b);
    else
        Wide magB = (isUnsigned!U || b >= 0) ? b : -Wide(b);

    // A zero operand makes the other operand the answer; `gcdImpl` is only
    // called with two non-zero values.
    if (magA == 0)
    {
        return magB;
    }
    else if (magB == 0)
    {
        return magA;
    }
    else
    {
        return gcdImpl(magA, magB);
    }
}
2966 
// Binary GCD on built-in integers; the callers in this module only pass
// non-zero operands.
private typeof(T.init % T.init) gcdImpl(T)(T a, T b)
if (isIntegral!T)
{
    pragma(inline, true);
    import core.bitop : bsf;
    import std.algorithm.mutation : swap;

    // Index of the lowest set bit of (a | b): the power of two that is
    // shifted back into the result at the end.
    immutable uint commonTwos = bsf(a | b);

    // Strip the trailing zero bits of `a`.
    a >>= bsf(a);
    for (;;)
    {
        // Strip the trailing zero bits of `b`, keep `a` as the smaller
        // value, then subtract.
        b >>= bsf(b);
        if (a > b)
        {
            swap(a, b);
        }
        b -= a;
        if (!b)
        {
            break;
        }
    }

    return a << commonTwos;
}
2986 
2987 ///
@safe unittest
{
    // Literal arguments.
    assert(gcd(2 * 5 * 7 * 7, 5 * 7 * 11) == 5 * 7);

    // Qualified (const) arguments are accepted too.
    const int a = 5 * 13 * 23 * 23;
    const int b = 13 * 59;
    assert(gcd(a, b) == 13);
}
2994 
@safe unittest
{
    import std.meta : AliasSeq;

    // Types tried for the first argument.
    alias FirstTypes = AliasSeq!(byte, ubyte, short, ushort, int, uint, long, ulong,
                                 const byte, const short, const int, const long,
                                 immutable ubyte, immutable ushort, immutable uint, immutable ulong);
    // Types tried for the second argument.
    alias SecondTypes = AliasSeq!(byte, ubyte, short, ushort, int, uint, long, ulong,
                                  const ubyte, const ushort, const uint, const ulong,
                                  immutable byte, immutable short, immutable int, immutable long);

    static foreach (T; FirstTypes)
    {
        static foreach (U; SecondTypes)
        {
            // Values valid for every combination.
            assert(gcd(T(0),   U(13))  == 13);
            assert(gcd(T(29),  U(0))   == 29);
            assert(gcd(T(0),   U(0))   == 0);
            assert(gcd(T(1),   U(2))   == 1);
            assert(gcd(T(9),   U(6))   == 3);
            assert(gcd(T(3),   U(4))   == 1);
            assert(gcd(T(32),  U(24))  == 8);
            assert(gcd(T(5),   U(6))   == 1);
            assert(gcd(T(54),  U(36))  == 18);

            // Values that need more than a signed byte / an unsigned byte.
            static if (T.max > byte.max && U.max > byte.max)
                assert(gcd(T(200), U(200)) == 200);
            static if (T.max > ubyte.max)
            {
                assert(gcd(T(2000), U(20))  == 20);
                assert(gcd(T(2011), U(17))  == 1);
            }
            static if (T.max > ubyte.max && U.max > ubyte.max)
                assert(gcd(T(1071), U(462)) == 21);

            // Int and Long tests.
            static if (T.max > short.max && U.max > short.max)
                assert(gcd(T(46391), U(62527)) == 2017);
            static if (T.max > ushort.max && U.max > ushort.max)
                assert(gcd(T(63245986), U(39088169)) == 1);
            static if (T.max > uint.max && U.max > uint.max)
            {
                assert(gcd(T(77160074263), U(47687519812)) == 1);
                assert(gcd(T(77160074264), U(47687519812)) == 4);
            }

            // Negative tests, only where the respective type is signed.
            static if (T.min < 0)
            {
                assert(gcd(T(-21), U(28)) == 7);
                assert(gcd(T(-3),  U(4))  == 1);
            }
            static if (U.min < 0)
            {
                assert(gcd(T(1),  U(-2))  == 1);
                assert(gcd(T(33), U(-44)) == 11);
            }
            static if (T.min < 0 && U.min < 0)
            {
                assert(gcd(T(-5),  U(-6))  == 1);
                assert(gcd(T(-50), U(-60)) == 10);
            }
        }
    }
}
3057 
3058 // https://issues.dlang.org/show_bug.cgi?id=21834
@safe unittest
{
    // Mixed signedness.
    assert(gcd(-120, 10U) == 10);
    assert(gcd(120U, -10) == 10);

    // int.min combined with long operands.
    assert(gcd(int.min, 0L) == 1L + int.max);
    assert(gcd(0L, int.min) == 1L + int.max);
    assert(gcd(int.min, 0L + int.min) == 1L + int.max);
    assert(gcd(int.min, 1L + int.max) == 1L + int.max);

    // short.min combined with an unsigned operand.
    assert(gcd(short.min, 1U + short.max) == 1U + short.max);
}
3069 
3070 // This overload is for non-builtin numerical types like BigInt or
3071 // user-defined types.
3072 /// ditto
auto gcd(T)(T a, T b)
if (!isIntegral!T &&
        is(typeof(T.init % T.init)) &&
        is(typeof(T.init == 0 || T.init > 0)))
{
    static if (is(T == Unqual!T))
    {
        // Replace negative arguments by their negation.
        a = a >= 0 ? a : -a;
        b = b >= 0 ? b : -b;

        // With a zero operand the other operand is the result; otherwise
        // defer to the implementation.
        if (a == 0)
        {
            return b;
        }
        else if (b == 0)
        {
            return a;
        }
        else
        {
            return gcdImpl(a, b);
        }
    }
    else
    {
        // Qualified T: re-enter with the unqualified type so the parameters
        // can be reassigned.
        return gcd!(Unqual!T)(a, b);
    }
}
3097 
// GCD for non-builtin numeric types. Uses the binary algorithm when `T`
// supports every operation that algorithm needs, and Euclid's algorithm
// (which only needs `%`, `!=` and assignment) otherwise. Callers in this
// module pass non-negative, non-zero operands.
private auto gcdImpl(T)(T a, T b)
if (!isIntegral!T)
{
    pragma(inline, true);
    import std.algorithm.mutation : swap;

    // The probe must cover every operation the binary branch below uses;
    // a type missing any one of them has to fall back to Euclid instead of
    // failing to compile.
    enum canUseBinaryGcd = is(typeof(() {
        T t, u;
        uint s;
        t <<= 1;
        t >>= 1;
        t -= u;
        bool isEven = (t & 1) == 0;
        bool isGreater = t > u;     // `a > b`
        if (t) {}                   // `while (b)`
        auto shifted = t << s;      // `a << shift`
        swap(t, u);
    }));

    static if (canUseBinaryGcd)
    {
        // Remove the factors of two common to both operands and remember
        // how many there were.
        uint shift = 0;
        while ((a & 1) == 0 && (b & 1) == 0)
        {
            a >>= 1;
            b >>= 1;
            shift++;
        }

        // Make sure `a` is the odd one.
        if ((a & 1) == 0) swap(a, b);

        do
        {
            assert((a & 1) != 0);
            while ((b & 1) == 0)
                b >>= 1;
            if (a > b)
                swap(a, b);
            b -= a;
        } while (b);

        return a << shift;
    }
    else
    {
        // The only thing we have is %; fallback to Euclidean algorithm.
        while (b != 0)
        {
            auto t = b;
            b = a % b;
            a = t;
        }
        return a;
    }
}
3148 
3149 // https://issues.dlang.org/show_bug.cgi?id=7102
@system pure unittest
{
    import std.bigint : BigInt;

    // Operands beyond the range of the built-in integer types.
    auto lhs = BigInt("71_000_000_000_000_000_000");
    auto rhs = BigInt("31_000_000_000_000_000_000");
    assert(gcd(lhs, rhs) == BigInt("1_000_000_000_000_000_000"));

    // A zero operand yields the other operand.
    assert(gcd(BigInt(0), BigInt(1234567)) == BigInt(1234567));
    assert(gcd(BigInt(1234567), BigInt(0)) == BigInt(1234567));
}
3160 
@safe pure nothrow unittest
{
    // A numerical type that only supports % and - (to force gcd implementation
    // to use Euclidean algorithm).
    struct CrippledInt
    {
        int impl;

        CrippledInt opBinary(string op : "%")(CrippledInt i)
        {
            return CrippledInt(impl % i.impl);
        }

        CrippledInt opUnary(string op : "-")()
        {
            return CrippledInt(-impl);
        }

        int opEquals(CrippledInt i)
        {
            return impl == i.impl;
        }

        int opEquals(int i)
        {
            return impl == i;
        }

        int opCmp(int i)
        {
            if (impl < i)
                return -1;
            if (impl > i)
                return 1;
            return 0;
        }
    }

    assert(gcd(CrippledInt(2310), CrippledInt(1309)) == CrippledInt(77));
    // Negative operands on either side.
    assert(gcd(CrippledInt(-120), CrippledInt(10U)) == CrippledInt(10));
    assert(gcd(CrippledInt(120U), CrippledInt(-10)) == CrippledInt(10));
}
3184 
3185 // https://issues.dlang.org/show_bug.cgi?id=19514
@system pure unittest
{
    import std.bigint : BigInt;

    auto two = BigInt(2);
    auto one = BigInt(1);
    assert(gcd(two, one) == BigInt(1));
}
3191 
// https://issues.dlang.org/show_bug.cgi?id=20924
@safe unittest
{
    import std.bigint : BigInt;

    // gcd must accept const-qualified BigInt arguments.
    const a = BigInt("123143238472389492934020");
    const b = BigInt("902380489324729338420924");
    enum bool accepted = __traits(compiles, gcd(a, b));
    assert(accepted);
}
3200 
3201 // https://issues.dlang.org/show_bug.cgi?id=21834
@safe unittest
{
    import std.bigint : BigInt;

    // Mixed signs.
    assert(gcd(BigInt(-120), BigInt(10U)) == BigInt(10));
    assert(gcd(BigInt(120U), BigInt(-10)) == BigInt(10));

    // Operands built from int.min.
    auto intMinMagnitude = BigInt(1L + int.max);
    assert(gcd(BigInt(int.min), BigInt(0L)) == intMinMagnitude);
    assert(gcd(BigInt(0L), BigInt(int.min)) == intMinMagnitude);
    assert(gcd(BigInt(int.min), BigInt(0L + int.min)) == intMinMagnitude);
    assert(gcd(BigInt(int.min), BigInt(1L + int.max)) == intMinMagnitude);

    // Operands built from short.min.
    assert(gcd(BigInt(short.min), BigInt(1U + short.max)) == BigInt(1U + short.max));
}
3213 
3214 
3215 /**
3216 Computes the least common multiple of `a` and `b`.
3217 Arguments are the same as $(MYREF gcd).
3218 
3219 Returns:
3220     The least common multiple of the given arguments.
3221  */
typeof(Unqual!(T).init % Unqual!(U).init) lcm(T, U)(T a, U b)
if (isIntegral!T && isIntegral!U)
{
    // All work happens on the unsigned flavour of the arguments' common type.
    alias Wide = Unsigned!(CommonType!(Unqual!T, Unqual!U));

    // Types that implicitly convert to short or byte are negated as `int`
    // and cast; every other type is negated directly in `Wide`.
    // (`std.math.abs` doesn't support unsigned integers, and `T.min` is
    // undefined.)
    enum bool negateAViaInt = is(T : immutable short) || is(T : immutable byte);
    enum bool negateBViaInt = is(U : immutable short) || is(U : immutable byte);

    static if (negateAViaInt)
        Wide magA = (isUnsigned!T || a >= 0) ? a : cast(Wide) -int(a);
    else
        Wide magA = (isUnsigned!T || a >= 0) ? a : -Wide(a);

    static if (negateBViaInt)
        Wide magB = (isUnsigned!U || b >= 0) ? b : cast(Wide) -int(b);
    else
        Wide magB = (isUnsigned!U || b >= 0) ? b : -Wide(b);

    // A zero operand makes the result zero; `gcdImpl` is only called with
    // two non-zero values.
    if (magA == 0)
    {
        return magA;
    }
    else if (magB == 0)
    {
        return magB;
    }
    else
    {
        // Divide before multiplying.
        return (magA / gcdImpl(magA, magB)) * magB;
    }
}
3247 
3248 ///
@safe unittest
{
    // Consecutive integers: the result equals their product.
    assert(lcm(1, 2) == 2);
    assert(lcm(3, 4) == 12);
    assert(lcm(5, 6) == 30);
}
3255 
@safe unittest
{
    import std.meta : AliasSeq;

    // Types tried for the first argument.
    alias FirstTypes = AliasSeq!(byte, ubyte, short, ushort, int, uint, long, ulong,
                                 const byte, const short, const int, const long,
                                 immutable ubyte, immutable ushort, immutable uint, immutable ulong);
    // Types tried for the second argument.
    alias SecondTypes = AliasSeq!(byte, ubyte, short, ushort, int, uint, long, ulong,
                                  const ubyte, const ushort, const uint, const ulong,
                                  immutable byte, immutable short, immutable int, immutable long);

    static foreach (T; FirstTypes)
    {
        static foreach (U; SecondTypes)
        {
            assert(lcm(T(21), U(6))  == 42);

            // Any zero operand gives zero.
            assert(lcm(T(41), U(0))  == 0);
            assert(lcm(T(0),  U(7))  == 0);
            assert(lcm(T(0),  U(0))  == 0);

            assert(lcm(T(1U), U(2))  == 2);
            assert(lcm(T(3),  U(4U)) == 12);
            assert(lcm(T(5U), U(6U)) == 30);

            // Negative first operand, only where T is signed.
            static if (T.min < 0)
                assert(lcm(T(-42), U(21U)) == 42);
        }
    }
}
3279 
3280 /// ditto
auto lcm(T)(T a, T b)
if (!isIntegral!T &&
        is(typeof(T.init % T.init)) &&
        is(typeof(T.init == 0 || T.init > 0)))
{
    static if (!is(T == Unqual!T))
    {
        // The body below reassigns its parameters, which is impossible for
        // a const/immutable T. Re-enter with the unqualified type, exactly
        // as the generic `gcd` overload does.
        return lcm!(Unqual!T)(a, b);
    }
    else
    {
        // Replace negative arguments by their negation.
        a = a >= 0 ? a : -a;
        b = b >= 0 ? b : -b;

        // Special cases: a zero operand gives zero.
        if (a == 0)
            return a;
        if (b == 0)
            return b;

        // Divide before multiplying.
        return (a / gcdImpl(a, b)) * b;
    }
}
3298 
@safe unittest
{
    import std.bigint : BigInt;

    assert(lcm(BigInt(21),  BigInt(6))   == BigInt(42));

    // Any zero operand gives zero.
    auto zero = BigInt(0);
    assert(lcm(BigInt(41),  BigInt(0))   == zero);
    assert(lcm(BigInt(0),   BigInt(7))   == zero);
    assert(lcm(BigInt(0),   BigInt(0))   == zero);

    // BigInts built from signed and unsigned literals.
    assert(lcm(BigInt(1U),  BigInt(2))   == BigInt(2));
    assert(lcm(BigInt(3),   BigInt(4U))  == BigInt(12));
    assert(lcm(BigInt(5U),  BigInt(6U))  == BigInt(30));

    // Negative operand.
    assert(lcm(BigInt(-42), BigInt(21U)) == BigInt(42));
}
3311 
3312 // This is to make tweaking the speed/size vs. accuracy tradeoff easy,
3313 // though floats seem accurate enough for all practical purposes, since
3314 // they pass the "isClose(inverseFft(fft(arr)), arr)" test even for
3315 // size 2 ^^ 22.
3316 private alias lookup_t = float;
3317 
3318 /**A class for performing fast Fourier transforms of power of two sizes.
3319  * This class encapsulates a large amount of state that is reusable when
3320  * performing multiple FFTs of sizes smaller than or equal to that specified
3321  * in the constructor.  This results in substantial speedups when performing
3322  * multiple FFTs with a known maximum size.  However,
3323  * a free function API is provided for convenience if you need to perform a
3324  * one-off FFT.
3325  *
3326  * References:
3327  * $(HTTP en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm)
3328  */
3329 final class Fft
3330 {
3331     import core.bitop : bsf;
3332     import std.algorithm.iteration : map;
3333     import std.array : uninitializedArray;
3334 
3335 private:
3336     immutable lookup_t[][] negSinLookup;
3337 
enforceSize(R)3338     void enforceSize(R)(R range) const
3339     {
3340         import std.conv : text;
3341         assert(range.length <= size, text(
3342             "FFT size mismatch.  Expected ", size, ", got ", range.length));
3343     }
3344 
fftImpl(Ret,R)3345     void fftImpl(Ret, R)(Stride!R range, Ret buf) const
3346     in
3347     {
3348         assert(range.length >= 4);
3349         assert(isPowerOf2(range.length));
3350     }
3351     do
3352     {
3353         auto recurseRange = range;
3354         recurseRange.doubleSteps();
3355 
3356         if (buf.length > 4)
3357         {
3358             fftImpl(recurseRange, buf[0..$ / 2]);
3359             recurseRange.popHalf();
3360             fftImpl(recurseRange, buf[$ / 2..$]);
3361         }
3362         else
3363         {
3364             // Do this here instead of in another recursion to save on
3365             // recursion overhead.
3366             slowFourier2(recurseRange, buf[0..$ / 2]);
3367             recurseRange.popHalf();
3368             slowFourier2(recurseRange, buf[$ / 2..$]);
3369         }
3370 
3371         butterfly(buf);
3372     }
3373 
3374     // This algorithm works by performing the even and odd parts of our FFT
3375     // using the "two for the price of one" method mentioned at
3376     // http://www.engineeringproductivitytools.com/stuff/T0001/PT10.HTM#Head521
3377     // by making the odd terms into the imaginary components of our new FFT,
3378     // and then using symmetry to recombine them.
fftImplPureReal(Ret,R)3379     void fftImplPureReal(Ret, R)(R range, Ret buf) const
3380     in
3381     {
3382         assert(range.length >= 4);
3383         assert(isPowerOf2(range.length));
3384     }
3385     do
3386     {
3387         alias E = ElementType!R;
3388 
3389         // Converts odd indices of range to the imaginary components of
3390         // a range half the size.  The even indices become the real components.
3391         static if (isArray!R && isFloatingPoint!E)
3392         {
3393             // Then the memory layout of complex numbers provides a dirt
3394             // cheap way to convert.  This is a common case, so take advantage.
3395             auto oddsImag = cast(Complex!E[]) range;
3396         }
3397         else
3398         {
3399             // General case:  Use a higher order range.  We can assume
3400             // source.length is even because it has to be a power of 2.
3401             static struct OddToImaginary
3402             {
3403                 R source;
3404                 alias C = Complex!(CommonType!(E, typeof(buf[0].re)));
3405 
3406                 @property
3407                 {
frontOddToImaginary3408                     C front()
3409                     {
3410                         return C(source[0], source[1]);
3411                     }
3412 
backOddToImaginary3413                     C back()
3414                     {
3415                         immutable n = source.length;
3416                         return C(source[n - 2], source[n - 1]);
3417                     }
3418 
saveOddToImaginary3419                     typeof(this) save()
3420                     {
3421                         return typeof(this)(source.save);
3422                     }
3423 
emptyOddToImaginary3424                     bool empty()
3425                     {
3426                         return source.empty;
3427                     }
3428 
lengthOddToImaginary3429                     size_t length()
3430                     {
3431                         return source.length / 2;
3432                     }
3433                 }
3434 
popFrontOddToImaginary3435                 void popFront()
3436                 {
3437                     source.popFront();
3438                     source.popFront();
3439                 }
3440 
popBackOddToImaginary3441                 void popBack()
3442                 {
3443                     source.popBack();
3444                     source.popBack();
3445                 }
3446 
opIndexOddToImaginary3447                 C opIndex(size_t index)
3448                 {
3449                     return C(source[index * 2], source[index * 2 + 1]);
3450                 }
3451 
opSliceOddToImaginary3452                 typeof(this) opSlice(size_t lower, size_t upper)
3453                 {
3454                     return typeof(this)(source[lower * 2 .. upper * 2]);
3455                 }
3456             }
3457 
3458             auto oddsImag = OddToImaginary(range);
3459         }
3460 
3461         fft(oddsImag, buf[0..$ / 2]);
3462         auto evenFft = buf[0..$ / 2];
3463         auto oddFft = buf[$ / 2..$];
3464         immutable halfN = evenFft.length;
3465         oddFft[0].re = buf[0].im;
3466         oddFft[0].im = 0;
3467         evenFft[0].im = 0;
3468         // evenFft[0].re is already right b/c it's aliased with buf[0].re.
3469 
3470         foreach (k; 1 .. halfN / 2 + 1)
3471         {
3472             immutable bufk = buf[k];
3473             immutable bufnk = buf[buf.length / 2 - k];
3474             evenFft[k].re = 0.5 * (bufk.re + bufnk.re);
3475             evenFft[halfN - k].re = evenFft[k].re;
3476             evenFft[k].im = 0.5 * (bufk.im - bufnk.im);
3477             evenFft[halfN - k].im = -evenFft[k].im;
3478 
3479             oddFft[k].re = 0.5 * (bufk.im + bufnk.im);
3480             oddFft[halfN - k].re = oddFft[k].re;
3481             oddFft[k].im = 0.5 * (bufnk.re - bufk.re);
3482             oddFft[halfN - k].im = -oddFft[k].im;
3483         }
3484 
3485         butterfly(buf);
3486     }
3487 
    // Performs one in-place butterfly pass over `buf`: every element in the
    // lower half is combined with the element `halfLen` positions after it,
    // where the upper element is first multiplied by a twiddle factor read
    // from the precomputed -sin lookup table.
    //
    // `buf.length` must be a power of two (checked by the `in` contract).
    // NOTE(review): the loop below also touches indices k + 1 and
    // k + halfLen + 1, so it appears to assume buf.length >= 4 -- confirm that
    // callers never pass a buffer of length 2.
    void butterfly(R)(R buf) const
    in
    {
        assert(isPowerOf2(buf.length));
    }
    do
    {
        immutable n = buf.length;
        // Select the table row indexed by log2(n); it must hold exactly n entries.
        immutable localLookup = negSinLookup[bsf(n)];
        assert(localLookup.length == n);

        // n is a power of two, so `& cosMask` wraps an index modulo n.
        immutable cosMask = n - 1;
        // Offset of three quarters of the table, used to read cos from the
        // -sin table.
        immutable cosAdd = n / 4 * 3;

        // Reads -sin for the given table index.
        lookup_t negSinFromLookup(size_t index) pure nothrow
        {
            return localLookup[index];
        }

        // Reads cos for the given table index by looking up the shifted -sin.
        lookup_t cosFromLookup(size_t index) pure nothrow
        {
            // cos is just -sin shifted by PI * 3 / 2.
            return localLookup[(index + cosAdd) & cosMask];
        }

        immutable halfLen = n / 2;

        // This loop is unrolled and the two iterations are interleaved
        // relative to the textbook FFT to increase ILP.  This gives roughly 5%
        // speedups on DMD.
        for (size_t k = 0; k < halfLen; k += 2)
        {
            // Twiddle factors for positions k and k + 1.
            immutable cosTwiddle1 = cosFromLookup(k);
            immutable sinTwiddle1 = negSinFromLookup(k);
            immutable cosTwiddle2 = cosFromLookup(k + 1);
            immutable sinTwiddle2 = negSinFromLookup(k + 1);

            // Snapshot the lower-half values before they are overwritten.
            immutable realLower1 = buf[k].re;
            immutable imagLower1 = buf[k].im;
            immutable realLower2 = buf[k + 1].re;
            immutable imagLower2 = buf[k + 1].im;

            // Matching upper-half partners.
            immutable upperIndex1 = k + halfLen;
            immutable upperIndex2 = upperIndex1 + 1;
            immutable realUpper1 = buf[upperIndex1].re;
            immutable imagUpper1 = buf[upperIndex1].im;
            immutable realUpper2 = buf[upperIndex2].re;
            immutable imagUpper2 = buf[upperIndex2].im;

            // Complex product (cosTwiddle + i * sinTwiddle) * upper, written
            // out component by component.
            immutable realAdd1 = cosTwiddle1 * realUpper1
                               - sinTwiddle1 * imagUpper1;
            immutable imagAdd1 = sinTwiddle1 * realUpper1
                               + cosTwiddle1 * imagUpper1;
            immutable realAdd2 = cosTwiddle2 * realUpper2
                               - sinTwiddle2 * imagUpper2;
            immutable imagAdd2 = sinTwiddle2 * realUpper2
                               + cosTwiddle2 * imagUpper2;

            // Lower half becomes lower + product ...
            buf[k].re += realAdd1;
            buf[k].im += imagAdd1;
            buf[k + 1].re += realAdd2;
            buf[k + 1].im += imagAdd2;

            // ... and upper half becomes (original lower) - product.
            buf[upperIndex1].re = realLower1 - realAdd1;
            buf[upperIndex1].im = imagLower1 - imagAdd1;
            buf[upperIndex2].re = realLower2 - realAdd2;
            buf[upperIndex2].im = imagLower2 - imagAdd2;
        }
    }
3557 
3558     // This constructor is used within this module for allocating the
3559     // buffer space elsewhere besides the GC heap.  It's definitely **NOT**
3560     // part of the public API and definitely **IS** subject to change.
3561     //
3562     // Also, this is unsafe because the memSpace buffer will be cast
3563     // to immutable.
3564     //
3565     // Public b/c of https://issues.dlang.org/show_bug.cgi?id=4636.
this(lookup_t[]memSpace)3566     public this(lookup_t[] memSpace)
3567     {
3568         immutable size = memSpace.length / 2;
3569 
3570         /* Create a lookup table of all negative sine values at a resolution of
3571          * size and all smaller power of two resolutions.  This may seem
3572          * inefficient, but having all the lookups be next to each other in
3573          * memory at every level of iteration is a huge win performance-wise.
3574          */
3575         if (size == 0)
3576         {
3577             return;
3578         }
3579 
3580         assert(isPowerOf2(size),
3581             "Can only do FFTs on ranges with a size that is a power of two.");
3582 
3583         auto table = new lookup_t[][bsf(size) + 1];
3584 
3585         table[$ - 1] = memSpace[$ - size..$];
3586         memSpace = memSpace[0 .. size];
3587 
3588         auto lastRow = table[$ - 1];
3589         lastRow[0] = 0;  // -sin(0) == 0.
3590         foreach (ptrdiff_t i; 1 .. size)
3591         {
3592             // The hard coded cases are for improved accuracy and to prevent
3593             // annoying non-zeroness when stuff should be zero.
3594 
3595             if (i == size / 4)
3596                 lastRow[i] = -1;  // -sin(pi / 2) == -1.
3597             else if (i == size / 2)
3598                 lastRow[i] = 0;   // -sin(pi) == 0.
3599             else if (i == size * 3 / 4)
3600                 lastRow[i] = 1;  // -sin(pi * 3 / 2) == 1
3601             else
3602                 lastRow[i] = -sin(i * 2.0L * PI / size);
3603         }
3604 
3605         // Fill in all the other rows with strided versions.
3606         foreach (i; 1 .. table.length - 1)
3607         {
3608             immutable strideLength = size / (2 ^^ i);
3609             auto strided = Stride!(lookup_t[])(lastRow, strideLength);
3610             table[i] = memSpace[$ - strided.length..$];
3611             memSpace = memSpace[0..$ - strided.length];
3612 
3613             size_t copyIndex;
3614             foreach (elem; strided)
3615             {
3616                 table[i][copyIndex++] = elem;
3617             }
3618         }
3619 
3620         negSinLookup = cast(immutable) table;
3621     }
3622 
3623 public:
3624     /**Create an `Fft` object for computing fast Fourier transforms of
3625      * power of two sizes of `size` or smaller.  `size` must be a
3626      * power of two.
3627      */
this(size_t size)3628     this(size_t size)
3629     {
3630         // Allocate all twiddle factor buffers in one contiguous block so that,
3631         // when one is done being used, the next one is next in cache.
3632         auto memSpace = uninitializedArray!(lookup_t[])(2 * size);
3633         this(memSpace);
3634     }
3635 
size()3636     @property size_t size() const
3637     {
3638         return (negSinLookup is null) ? 0 : negSinLookup[$ - 1].length;
3639     }
3640 
3641     /**Compute the Fourier transform of range using the $(BIGOH N log N)
3642      * Cooley-Tukey Algorithm.  `range` must be a random-access range with
3643      * slicing and a length equal to `size` as provided at the construction of
3644      * this object.  The contents of range can be either  numeric types,
3645      * which will be interpreted as pure real values, or complex types with
3646      * properties or members `.re` and `.im` that can be read.
3647      *
3648      * Note:  Pure real FFTs are automatically detected and the relevant
3649      *        optimizations are performed.
3650      *
3651      * Returns:  An array of complex numbers representing the transformed data in
3652      *           the frequency domain.
3653      *
3654      * Conventions: The exponent is negative and the factor is one,
3655      *              i.e., output[j] := sum[ exp(-2 PI i j k / N) input[k] ].
3656      */
3657     Complex!F[] fft(F = double, R)(R range) const
3658         if (isFloatingPoint!F && isRandomAccessRange!R)
3659     {
3660         enforceSize(range);
3661         Complex!F[] ret;
3662         if (range.length == 0)
3663         {
3664             return ret;
3665         }
3666 
3667         // Don't waste time initializing the memory for ret.
3668         ret = uninitializedArray!(Complex!F[])(range.length);
3669 
3670         fft(range,  ret);
3671         return ret;
3672     }
3673 
3674     /**Same as the overload, but allows for the results to be stored in a user-
3675      * provided buffer.  The buffer must be of the same length as range, must be
3676      * a random-access range, must have slicing, and must contain elements that are
3677      * complex-like.  This means that they must have a .re and a .im member or
3678      * property that can be both read and written and are floating point numbers.
3679      */
3680     void fft(Ret, R)(R range, Ret buf) const
3681         if (isRandomAccessRange!Ret && isComplexLike!(ElementType!Ret) && hasSlicing!Ret)
3682     {
3683         assert(buf.length == range.length);
3684         enforceSize(range);
3685 
3686         if (range.length == 0)
3687         {
3688             return;
3689         }
3690         else if (range.length == 1)
3691         {
3692             buf[0] = range[0];
3693             return;
3694         }
3695         else if (range.length == 2)
3696         {
3697             slowFourier2(range, buf);
3698             return;
3699         }
3700         else
3701         {
3702             alias E = ElementType!R;
3703             static if (is(E : real))
3704             {
3705                 return fftImplPureReal(range, buf);
3706             }
3707             else
3708             {
3709                 static if (is(R : Stride!R))
3710                     return fftImpl(range, buf);
3711                 else
3712                     return fftImpl(Stride!R(range, 1), buf);
3713             }
3714         }
3715     }
3716 
3717     /**
3718      * Computes the inverse Fourier transform of a range.  The range must be a
3719      * random access range with slicing, have a length equal to the size
3720      * provided at construction of this object, and contain elements that are
3721      * either of type std.complex.Complex or have essentially
3722      * the same compile-time interface.
3723      *
3724      * Returns:  The time-domain signal.
3725      *
3726      * Conventions: The exponent is positive and the factor is 1/N, i.e.,
3727      *              output[j] := (1 / N) sum[ exp(+2 PI i j k / N) input[k] ].
3728      */
3729     Complex!F[] inverseFft(F = double, R)(R range) const
3730         if (isRandomAccessRange!R && isComplexLike!(ElementType!R) && isFloatingPoint!F)
3731     {
3732         enforceSize(range);
3733         Complex!F[] ret;
3734         if (range.length == 0)
3735         {
3736             return ret;
3737         }
3738 
3739         // Don't waste time initializing the memory for ret.
3740         ret = uninitializedArray!(Complex!F[])(range.length);
3741 
3742         inverseFft(range, ret);
3743         return ret;
3744     }
3745 
3746     /**
3747      * Inverse FFT that allows a user-supplied buffer to be provided.  The buffer
3748      * must be a random access range with slicing, and its elements
3749      * must be some complex-like type.
3750      */
3751     void inverseFft(Ret, R)(R range, Ret buf) const
3752         if (isRandomAccessRange!Ret && isComplexLike!(ElementType!Ret) && hasSlicing!Ret)
3753     {
3754         enforceSize(range);
3755 
3756         auto swapped = map!swapRealImag(range);
3757         fft(swapped,  buf);
3758 
3759         immutable lenNeg1 = 1.0 / buf.length;
foreach(ref elem;buf)3760         foreach (ref elem; buf)
3761         {
3762             immutable temp = elem.re * lenNeg1;
3763             elem.re = elem.im * lenNeg1;
3764             elem.im = temp;
3765         }
3766     }
3767 }
3768 
// This mixin creates an Fft object in the scope it's mixed into such that all
// memory owned by the object is deterministically destroyed at the end of that
// scope.  The enclosing scope must provide a variable named `range`; the
// lookup buffer is sized as twice `range.length` elements of `lookup_t`.
private enum string MakeLocalFft = q{
    import core.stdc.stdlib;
    import core.exception : onOutOfMemoryError;

    auto lookupBuf = (cast(lookup_t*) malloc(range.length * 2 * lookup_t.sizeof))
                     [0 .. 2 * range.length];

    // malloc(0) is allowed to return null, so a null pointer only signals an
    // allocation failure when a non-empty buffer was actually requested.
    if (!lookupBuf.ptr && lookupBuf.length != 0)
        onOutOfMemoryError();

    // free(null) is a no-op, so this is safe for the empty case as well.
    scope(exit) free(cast(void*) lookupBuf.ptr);
    auto fftObj = scoped!Fft(lookupBuf);
};
3784 
3785 /**Convenience functions that create an `Fft` object, run the FFT or inverse
3786  * FFT and return the result.  Useful for one-off FFTs.
3787  *
3788  * Note:  In addition to convenience, these functions are slightly more
3789  *        efficient than manually creating an Fft object for a single use,
3790  *        as the Fft object is deterministically destroyed before these
3791  *        functions return.
3792  */
Complex!F[] fft(F = double, R)(R range)
{
    // Brings `fftObj`, a scope-local Fft built from a buffer sized for
    // `range`, into this function.
    mixin(MakeLocalFft);

    auto transformed = fftObj.fft!(F, R)(range);
    return transformed;
}
3798 
3799 /// ditto
void fft(Ret, R)(R range, Ret buf)
{
    // Brings `fftObj`, a scope-local Fft built from a buffer sized for
    // `range`, into this function.
    mixin(MakeLocalFft);

    // The result is written into the caller's buffer.
    fftObj.fft!(Ret, R)(range, buf);
}
3805 
3806 /// ditto
Complex!F[] inverseFft(F = double, R)(R range)
{
    // Brings `fftObj`, a scope-local Fft built from a buffer sized for
    // `range`, into this function.
    mixin(MakeLocalFft);

    auto restored = fftObj.inverseFft!(F, R)(range);
    return restored;
}
3812 
3813 /// ditto
void inverseFft(Ret, R)(R range, Ret buf)
{
    // Brings `fftObj`, a scope-local Fft built from a buffer sized for
    // `range`, into this function.
    mixin(MakeLocalFft);

    // The result is written into the caller's buffer.
    fftObj.inverseFft!(Ret, R)(range, buf);
}
3819 
// Exercises the free-function fft/inverseFft wrappers: integer, reversed,
// float and complex inputs, a forward/inverse round trip, and the
// special-cased sizes 0, 1 and 2.
@system unittest
{
    import std.algorithm;
    import std.conv;
    import std.range;
    // Test values from R and Octave.
    auto arr = [1,2,3,4,5,6,7,8];
    auto fft1 = fft(arr);
    assert(isClose(map!"a.re"(fft1),
        [36.0, -4, -4, -4, -4, -4, -4, -4], 1e-4));
    assert(isClose(map!"a.im"(fft1),
        [0, 9.6568, 4, 1.6568, 0, -1.6568, -4, -9.6568], 1e-4));

    // Same data in reverse order, passed as a retro range rather than an array.
    auto fft1Retro = fft(retro(arr));
    assert(isClose(map!"a.re"(fft1Retro),
        [36.0, 4, 4, 4, 4, 4, 4, 4], 1e-4));
    assert(isClose(map!"a.im"(fft1Retro),
        [0, -9.6568, -4, -1.6568, 0, 1.6568, 4, 9.6568], 1e-4));

    // A float[] input must give the same result as the int[] input.
    auto fft1Float = fft(to!(float[])(arr));
    assert(isClose(map!"a.re"(fft1), map!"a.re"(fft1Float)));
    assert(isClose(map!"a.im"(fft1), map!"a.im"(fft1Float)));

    // Complex-valued input.
    alias C = Complex!float;
    auto arr2 = [C(1,2), C(3,4), C(5,6), C(7,8), C(9,10),
        C(11,12), C(13,14), C(15,16)];
    auto fft2 = fft(arr2);
    assert(isClose(map!"a.re"(fft2),
        [64.0, -27.3137, -16, -11.3137, -8, -4.6862, 0, 11.3137], 1e-4));
    assert(isClose(map!"a.im"(fft2),
        [72, 11.3137, 0, -4.686, -8, -11.3137, -16, -27.3137], 1e-4));

    // Round trip: the inverse of the transform recovers the real input.
    auto inv1 = inverseFft(fft1);
    assert(isClose(map!"a.re"(inv1), arr, 1e-6));
    assert(reduce!max(map!"a.im"(inv1)) < 1e-10);

    // Round trip for the complex input.
    auto inv2 = inverseFft(fft2);
    assert(isClose(map!"a.re"(inv2), map!"a.re"(arr2)));
    assert(isClose(map!"a.im"(inv2), map!"a.im"(arr2)));

    // FFTs of size 0, 1 and 2 are handled as special cases.  Test them here.
    ushort[] empty;
    assert(fft(empty) == null);
    assert(inverseFft(fft(empty)) == null);

    // Size 1: the single element passes through unchanged.
    real[] oneElem = [4.5L];
    auto oneFft = fft(oneElem);
    assert(oneFft.length == 1);
    assert(oneFft[0].re == 4.5L);
    assert(oneFft[0].im == 0);

    auto oneInv = inverseFft(oneFft);
    assert(oneInv.length == 1);
    assert(isClose(oneInv[0].re, 4.5));
    assert(isClose(oneInv[0].im, 0, 0.0, 1e-10));

    // Size 2: outputs are the sum and the difference of the two inputs.
    long[2] twoElems = [8, 4];
    auto twoFft = fft(twoElems[]);
    assert(twoFft.length == 2);
    assert(isClose(twoFft[0].re, 12));
    assert(isClose(twoFft[0].im, 0, 0.0, 1e-10));
    assert(isClose(twoFft[1].re, 4));
    assert(isClose(twoFft[1].im, 0, 0.0, 1e-10));
    auto twoInv = inverseFft(twoFft);
    assert(isClose(twoInv[0].re, 8));
    assert(isClose(twoInv[0].im, 0, 0.0, 1e-10));
    assert(isClose(twoInv[1].re, 4));
    assert(isClose(twoInv[1].im, 0, 0.0, 1e-10));
}
3889 
// Builds a new value of type C whose real part is the input's imaginary part
// and whose imaginary part is the input's real part.  The inverse FFT uses
// this to reuse the forward transform.
C swapRealImag(C)(C input)
{
    auto swappedRe = input.im;
    auto swappedIm = input.re;
    return C(swappedRe, swappedIm);
}
3896 
/** Converts `decimal` into its representation in the factorial number system
and stores the digits in `fac`.

The value of a factorial number with `n` digits is
$(D fac[0] * (n - 1)! + fac[1] * (n - 2)! + ... + fac[n - 1] * 0!),
i.e. the digits are stored most significant first and the last digit (the
`0!` place) is always zero.

Params:
    decimal = The decimal value to convert into the factorial number system.
    fac = The array to store the factorial number. The array is of size 21 as
        `ulong.max` requires 21 digits in the factorial number system.
Returns:
    The number of digits of the factorial number stored in `fac`. This is at
    least 1: the value 0 is stored as the single digit 0.
*/
size_t decimalToFactorial(ulong decimal, ref ubyte[21] fac)
        @safe pure nothrow @nogc
{
    import std.algorithm.mutation : reverse;

    size_t idx = 0;
    ulong radix = 1;

    // Peel off digits least significant first: the digit for place
    // (radix - 1)! is the remainder modulo radix.
    while (decimal != 0)
    {
        fac[idx] = cast(ubyte)(decimal % radix);
        decimal /= radix;
        ++idx;
        ++radix;
    }

    // Zero still needs one digit.
    if (idx == 0)
    {
        fac[0] = cast(ubyte) 0;
        idx = 1;
    }

    // Switch to most-significant-first order.
    reverse(fac[0 .. idx]);

    // first digit of the number in factorial will always be zero
    assert(fac[idx - 1] == 0);

    return idx;
}
3936 
3937 ///
@safe pure @nogc unittest
{
    ubyte[21] fac;
    size_t idx = decimalToFactorial(2982, fac);

    // 2982 == 4 * 6! + 0 * 5! + 4 * 4! + 1 * 3! + 0 * 2! + 0 * 1! + 0 * 0!,
    // with the most significant digit stored first.
    assert(fac[0] == 4);
    assert(fac[1] == 0);
    assert(fac[2] == 4);
    assert(fac[3] == 1);
    assert(fac[4] == 0);
    assert(fac[5] == 0);
    assert(fac[6] == 0);
}
3951 
// Checks decimalToFactorial at the boundaries (0 and ulong.max) and for the
// documented example value.
@safe pure unittest
{
    ubyte[21] fac;
    // Zero is represented by a single zero digit.
    size_t idx = decimalToFactorial(0UL, fac);
    assert(idx == 1);
    assert(fac[0] == 0);

    // The largest input fills all 21 digits.
    fac[] = 0;
    idx = 0;
    idx = decimalToFactorial(ulong.max, fac);
    assert(idx == 21);
    auto t = [7, 11, 12, 4, 3, 15, 3, 5, 3, 5, 0, 8, 3, 5, 0, 0, 0, 2, 1, 1, 0];
    foreach (i, it; fac[0 .. 21])
    {
        assert(it == t[i]);
    }

    // Same value as in the documented example, this time also checking the
    // returned digit count.
    fac[] = 0;
    idx = decimalToFactorial(2982, fac);

    assert(idx == 7);
    t = [4, 0, 4, 1, 0, 0, 0];
    foreach (i, it; fac[0 .. idx])
    {
        assert(it == t[i]);
    }
}
3979 
3980 private:
// A minimal strided view over a sliceable random-access range.  A dedicated
// type is used instead of the library's stride adaptor because the step here
// must be changeable on the fly, and because this version carries a few
// performance shortcuts that rely on power-of-two steps.
struct Stride(R)
{
    import core.bitop : bsf;

    Unqual!R range;
    size_t _nSteps;
    size_t _length;
    alias E = ElementType!(R);

    // Wraps `range`, exposing every `nStepsIn`-th element.
    this(R range, size_t nStepsIn)
    {
        this.range = range;
        _nSteps = nStepsIn;
        // Rounded-up division: a trailing partial step still counts as one
        // element.
        _length = (range.length + _nSteps - 1) / nSteps;
    }

    // Number of elements visible through the view.
    @property size_t length() const
    {
        return _length;
    }

    // Independent copy of the view, with the underlying range saved as well.
    @property typeof(this) save()
    {
        typeof(this) copy = this;
        copy.range = copy.range.save;
        return copy;
    }

    // Element `index` of the view, i.e. element `index * step` of the
    // underlying range.
    E opIndex(size_t index)
    {
        return range[index * _nSteps];
    }

    @property E front()
    {
        return range[0];
    }

    // Advances by one full step; when less than a step remains the view
    // becomes empty.
    void popFront()
    {
        if (range.length < _nSteps)
        {
            range = range[0 .. 0];
            _length = 0;
            return;
        }

        range = range[_nSteps .. range.length];
        _length--;
    }

    // Pops half the range's stride.
    void popHalf()
    {
        range = range[_nSteps / 2 .. range.length];
    }

    @property bool empty() const
    {
        return _length == 0;
    }

    // Current step.
    @property size_t nSteps() const
    {
        return _nSteps;
    }

    // Doubles the step and halves the visible length accordingly.
    void doubleSteps()
    {
        _nSteps *= 2;
        _length /= 2;
    }

    // Sets a new step and recomputes the length.  The shift below is only
    // equal to a division when the step is a power of two.
    @property size_t nSteps(size_t newVal)
    {
        _nSteps = newVal;

        // Using >> bsf(nSteps) is a few cycles faster than / nSteps.
        _length = (range.length + _nSteps - 1) >> bsf(nSteps);
        return newVal;
    }
}
4066 
// Hard-coded base case for FFT of size 2.  This is actually a TON faster than
// using a generic slow DFT.  This seems to be the best base case.  (Size 1
// can be coded inline as buf[0] = range[0]).
//
// Writes the sum of the two inputs to buf[0] and their difference to buf[1].
// Both `range` and `buf` must have exactly two elements (asserted).
void slowFourier2(Ret, R)(R range, Ret buf)
{
    assert(range.length == 2);
    assert(buf.length == 2);
    // range[0] is read again after buf[0] has been written.
    buf[0] = range[0] + range[1];
    buf[1] = range[0] - range[1];
}
4077 
// Hard-coded base case for FFT of size 4.  Doesn't work as well as the size
// 2 case.
//
// Both `range` and `buf` must have exactly four elements (asserted).  Each
// output is a fixed combination of the four inputs in which range[1] and
// range[3] are multiplied by C(0, 1) where needed.
// NOTE(review): C(0, 1) is presumably the imaginary unit, i.e. C is assumed
// to be constructible as (re, im) -- confirm for custom element types.
void slowFourier4(Ret, R)(R range, Ret buf)
{
    alias C = ElementType!Ret;

    assert(range.length == 4);
    assert(buf.length == 4);
    buf[0] = range[0] + range[1] + range[2] + range[3];
    buf[1] = range[0] - range[1] * C(0, 1) - range[2] + range[3] * C(0, 1);
    buf[2] = range[0] - range[1] + range[2] - range[3];
    buf[3] = range[0] + range[1] * C(0, 1) - range[2] - range[3] * C(0, 1);
}
4091 
// Keeps only the highest set bit of `num`, which for a positive value is the
// largest power of two that does not exceed it.
N roundDownToPowerOf2(N)(N num)
if (isScalarType!N && !isFloatingPoint!N)
{
    import core.bitop : bsr;

    // Position of the most significant set bit.
    immutable highestBit = bsr(num);
    return num & (cast(N) 1 << highestBit);
}
4098 
// roundDownToPowerOf2: a non-power-of-two rounds down, and an exact power of
// two is returned unchanged.
@safe unittest
{
    assert(roundDownToPowerOf2(7) == 4);
    assert(roundDownToPowerOf2(4) == 4);
}
4104 
// True when a value of type T exposes both a `.re` and a `.im` member or
// property; nothing else about T is checked.
enum bool isComplexLike(T) =
    is(typeof(T.init.re)) && is(typeof(T.init.im));
4110 
// isComplexLike accepts std.complex.Complex and rejects a plain integer type.
@safe unittest
{
    static assert(isComplexLike!(Complex!double));
    static assert(!isComplexLike!(uint));
}
4116