1 // Written in the D programming language.
2
3 /**
4 This module is a port of a growing fragment of the $(D_PARAM numeric)
5 header in Alexander Stepanov's $(LINK2 http://sgi.com/tech/stl,
6 Standard Template Library), with a few additions.
7
8 Macros:
9 Copyright: Copyright Andrei Alexandrescu 2008 - 2009.
10 License: $(HTTP www.boost.org/LICENSE_1_0.txt, Boost License 1.0).
11 Authors: $(HTTP erdani.org, Andrei Alexandrescu),
12 Don Clugston, Robert Jacques, Ilya Yaroshenko
13 Source: $(PHOBOSSRC std/_numeric.d)
14 */
15 /*
16 Copyright Andrei Alexandrescu 2008 - 2009.
17 Distributed under the Boost Software License, Version 1.0.
18 (See accompanying file LICENSE_1_0.txt or copy at
19 http://www.boost.org/LICENSE_1_0.txt)
20 */
21 module std.numeric;
22
23 import std.complex;
24 import std.math;
25 import std.range.primitives;
26 import std.traits;
27 import std.typecons;
28
29 version (unittest)
30 {
31 import std.stdio;
32 }
33 /// Format flags for CustomFloat.
34 public enum CustomFloatFlags
35 {
36 /// Adds a sign bit to allow for signed numbers.
37 signed = 1,
38
39 /**
40 * Store values in normalized form by default. The actual precision of the
41 * significand is extended by 1 bit by assuming an implicit leading bit of 1
42 * instead of 0. i.e. $(D 1.nnnn) instead of $(D 0.nnnn).
43  * True for all $(LINK2 https://en.wikipedia.org/wiki/IEEE_floating_point, IEEE754) types.
44 */
45 storeNormalized = 2,
46
47 /**
48 * Stores the significand in $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Denormalized_numbers,
49 * IEEE754 denormalized) form when the exponent is 0. Required to express the value 0.
50 */
51 allowDenorm = 4,
52
53 /**
54 * Allows the storage of $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Positive_and_negative_infinity,
55 * IEEE754 _infinity) values.
56 */
57 infinity = 8,
58
59 /// Allows the storage of $(LINK2 https://en.wikipedia.org/wiki/NaN, IEEE754 Not a Number) values.
60 nan = 16,
61
62 /**
63 * If set, select an exponent bias such that max_exp = 1.
64 * i.e. so that the maximum value is >= 1.0 and < 2.0.
65 * Ignored if the exponent bias is manually specified.
66 */
67 probability = 32,
68
69 /// If set, unsigned custom floats are assumed to be negative.
70 negativeUnsigned = 64,
71
72 /**If set, 0 is the only allowed $(LINK2 https://en.wikipedia.org/wiki/IEEE_754-1985#Denormalized_numbers,
73 * IEEE754 denormalized) number.
74 * Requires allowDenorm and storeNormalized.
75 */
76 allowDenormZeroOnly = 128 | allowDenorm | storeNormalized,
77
78 /// Include _all of the $(LINK2 https://en.wikipedia.org/wiki/IEEE_floating_point, IEEE754) options.
79 ieee = signed | storeNormalized | allowDenorm | infinity | nan ,
80
81 /// Include none of the above options.
82 none = 0
83 }
84
85 private template CustomFloatParams(uint bits)
86 {
87 enum CustomFloatFlags flags = CustomFloatFlags.ieee
88 ^ ((bits == 80) ? CustomFloatFlags.storeNormalized : CustomFloatFlags.none);
89 static if (bits == 8) alias CustomFloatParams = CustomFloatParams!( 4, 3, flags);
90 static if (bits == 16) alias CustomFloatParams = CustomFloatParams!(10, 5, flags);
91 static if (bits == 32) alias CustomFloatParams = CustomFloatParams!(23, 8, flags);
92 static if (bits == 64) alias CustomFloatParams = CustomFloatParams!(52, 11, flags);
93 static if (bits == 80) alias CustomFloatParams = CustomFloatParams!(64, 15, flags);
94 }
95
96 private template CustomFloatParams(uint precision, uint exponentWidth, CustomFloatFlags flags)
97 {
98 import std.meta : AliasSeq;
99 alias CustomFloatParams =
100 AliasSeq!(
101 precision,
102 exponentWidth,
103 flags,
104 (1 << (exponentWidth - ((flags & flags.probability) == 0)))
105 - ((flags & (flags.nan | flags.infinity)) != 0) - ((flags & flags.probability) != 0)
106 ); // ((flags & CustomFloatFlags.probability) == 0)
107 }
108
109 /**
110 * Allows user code to define custom floating-point formats. These formats are
111  * for storage only; all operations on them are performed by implicitly
112  * extracting them to $(D real) first. After the operation is completed, the
113 * result can be stored in a custom floating-point value via assignment.
114 */
115 template CustomFloat(uint bits)
116 if (bits == 8 || bits == 16 || bits == 32 || bits == 64 || bits == 80)
117 {
118 alias CustomFloat = CustomFloat!(CustomFloatParams!(bits));
119 }
120
121 /// ditto
122 template CustomFloat(uint precision, uint exponentWidth, CustomFloatFlags flags = CustomFloatFlags.ieee)
123 if (((flags & flags.signed) + precision + exponentWidth) % 8 == 0 && precision + exponentWidth > 0)
124 {
125 alias CustomFloat = CustomFloat!(CustomFloatParams!(precision, exponentWidth, flags));
126 }
127
128 ///
129 @safe unittest
130 {
131 import std.math : sin, cos;
132
133     // Define 16-bit floating-point values
134 CustomFloat!16 x; // Using the number of bits
135 CustomFloat!(10, 5) y; // Using the precision and exponent width
136     CustomFloat!(10, 5, CustomFloatFlags.ieee) z; // Using the precision, exponent width and format flags
137     CustomFloat!(10, 5, CustomFloatFlags.ieee, 15) w; // Using the precision, exponent width, format flags and exponent offset bias
138
139 // Use the 16-bit floats mostly like normal numbers
140 w = x*y - 1;
141
142     // Function calls require conversion
143 z = sin(+x) + cos(+y); // Use unary plus to concisely convert to a real
144 z = sin(x.get!float) + cos(y.get!float); // Or use get!T
145 z = sin(cast(float) x) + cos(cast(float) y); // Or use cast(T) to explicitly convert
146
147     // Define an 8-bit custom float for storing probabilities
148     alias Probability = CustomFloat!(4, 4, CustomFloatFlags.ieee ^ CustomFloatFlags.probability ^ CustomFloatFlags.signed);
149 auto p = Probability(0.5);
150 }
151
152 /// ditto
153 struct CustomFloat(uint precision, // fraction bits (23 for float)
154                    uint exponentWidth, // exponent bits (8 for float)
155 CustomFloatFlags flags,
156 uint bias)
157 if (((flags & flags.signed) + precision + exponentWidth) % 8 == 0 &&
158 precision + exponentWidth > 0)
159 {
160 import std.bitmanip : bitfields;
161 import std.meta : staticIndexOf;
162 private:
163 // get the correct unsigned bitfield type to support > 32 bits
164     template uType(uint bits)
165 {
166 static if (bits <= size_t.sizeof*8) alias uType = size_t;
167 else alias uType = ulong ;
168 }
169
170 // get the correct signed bitfield type to support > 32 bits
171     template sType(uint bits)
172 {
173 static if (bits <= ptrdiff_t.sizeof*8-1) alias sType = ptrdiff_t;
174 else alias sType = long;
175 }
176
177 alias T_sig = uType!precision;
178 alias T_exp = uType!exponentWidth;
179 alias T_signed_exp = sType!exponentWidth;
180
181 alias Flags = CustomFloatFlags;
182
183 // Facilitate converting numeric types to custom float
184 union ToBinary(F)
185 if (is(typeof(CustomFloatParams!(F.sizeof*8))) || is(F == real))
186 {
187 F set;
188
189 // If on Linux or Mac, where 80-bit reals are padded, ignore the
190 // padding.
191 import std.algorithm.comparison : min;
192 CustomFloat!(CustomFloatParams!(min(F.sizeof*8, 80))) get;
193
194 // Convert F to the correct binary type.
195         static typeof(get) opCall(F value)
196 {
197 ToBinary r;
198 r.set = value;
199 return r.get;
200 }
201 alias get this;
202 }
203
204 // Perform IEEE rounding with round to nearest detection
205     void roundedShift(T,U)(ref T sig, U shift)
206 {
207 if (sig << (T.sizeof*8 - shift) == cast(T) 1uL << (T.sizeof*8 - 1))
208 {
209 // round to even
210 sig >>= shift;
211 sig += sig & 1;
212 }
213 else
214 {
215 sig >>= shift - 1;
216 sig += sig & 1;
217 // Perform standard rounding
218 sig >>= 1;
219 }
220 }
221
222 // Convert the current value to signed exponent, normalized form
223     void toNormalized(T,U)(ref T sig, ref U exp)
224 {
225 sig = significand;
226 auto shift = (T.sizeof*8) - precision;
227 exp = exponent;
228 static if (flags&(Flags.infinity|Flags.nan))
229 {
230 // Handle inf or nan
231 if (exp == exponent_max)
232 {
233 exp = exp.max;
234 sig <<= shift;
235 static if (flags&Flags.storeNormalized)
236 {
237 // Save inf/nan in denormalized format
238 sig >>= 1;
239 sig += cast(T) 1uL << (T.sizeof*8 - 1);
240 }
241 return;
242 }
243 }
244 if ((~flags&Flags.storeNormalized) ||
245 // Convert denormalized form to normalized form
246 ((flags&Flags.allowDenorm) && exp == 0))
247 {
248 if (sig > 0)
249 {
250 import core.bitop : bsr;
251 auto shift2 = precision - bsr(sig);
252 exp -= shift2-1;
253 shift += shift2;
254 }
255 else // value = 0.0
256 {
257 exp = exp.min;
258 return;
259 }
260 }
261 sig <<= shift;
262 exp -= bias;
263 }
264
265 // Set the current value from signed exponent, normalized form
266     void fromNormalized(T,U)(ref T sig, ref U exp)
267 {
268 auto shift = (T.sizeof*8) - precision;
269 if (exp == exp.max)
270 {
271 // infinity or nan
272 exp = exponent_max;
273 static if (flags & Flags.storeNormalized)
274 sig <<= 1;
275
276 // convert back to normalized form
277 static if (~flags & Flags.infinity)
278 // No infinity support?
279 assert(sig != 0, "Infinity floating point value assigned to a "
280 ~ typeof(this).stringof ~ " (no infinity support).");
281
282 static if (~flags & Flags.nan) // No NaN support?
283 assert(sig == 0, "NaN floating point value assigned to a " ~
284 typeof(this).stringof ~ " (no nan support).");
285 sig >>= shift;
286 return;
287 }
288 if (exp == exp.min) // 0.0
289 {
290 exp = 0;
291 sig = 0;
292 return;
293 }
294
295 exp += bias;
296 if (exp <= 0)
297 {
298 static if ((flags&Flags.allowDenorm) ||
299 // Convert from normalized form to denormalized
300 (~flags&Flags.storeNormalized))
301 {
302 shift += -exp;
303 roundedShift(sig,1);
304 sig += cast(T) 1uL << (T.sizeof*8 - 1);
305 // Add the leading 1
306 exp = 0;
307 }
308 else
309 assert((flags&Flags.storeNormalized) && exp == 0,
310 "Underflow occured assigning to a " ~
311 typeof(this).stringof ~ " (no denormal support).");
312 }
313 else
314 {
315 static if (~flags&Flags.storeNormalized)
316 {
317 // Convert from normalized form to denormalized
318 roundedShift(sig,1);
319 sig += cast(T) 1uL << (T.sizeof*8 - 1);
320 // Add the leading 1
321 }
322 }
323
324 if (shift > 0)
325 roundedShift(sig,shift);
326 if (sig > significand_max)
327 {
328 // handle significand overflow (should only be 1 bit)
329 static if (~flags&Flags.storeNormalized)
330 {
331 sig >>= 1;
332 }
333 else
334 sig &= significand_max;
335 exp++;
336 }
337 static if ((flags&Flags.allowDenormZeroOnly)==Flags.allowDenormZeroOnly)
338 {
339 // disallow non-zero denormals
340 if (exp == 0)
341 {
342 sig <<= 1;
343 if (sig > significand_max && (sig&significand_max) > 0)
344 // Check and round to even
345 exp++;
346 sig = 0;
347 }
348 }
349
350 if (exp >= exponent_max)
351 {
352 static if (flags&(Flags.infinity|Flags.nan))
353 {
354 sig = 0;
355 exp = exponent_max;
356 static if (~flags&(Flags.infinity))
357                     assert(0, "Overflow occurred assigning to a " ~
358 typeof(this).stringof ~ " (no infinity support).");
359 }
360 else
361                 assert(exp == exponent_max, "Overflow occurred assigning to a "
362 ~ typeof(this).stringof ~ " (no infinity support).");
363 }
364 }
365
366 public:
367 static if (precision == 64) // CustomFloat!80 support hack
368 {
369 ulong significand;
370 enum ulong significand_max = ulong.max;
371 mixin(bitfields!(
372 T_exp , "exponent", exponentWidth,
373 bool , "sign" , flags & flags.signed ));
374 }
375 else
376 {
377 mixin(bitfields!(
378 T_sig, "significand", precision,
379 T_exp, "exponent" , exponentWidth,
380 bool , "sign" , flags & flags.signed ));
381 }
382
383 /// Returns: infinity value
384 static if (flags & Flags.infinity)
385     static @property CustomFloat infinity()
386 {
387 CustomFloat value;
388 static if (flags & Flags.signed)
389 value.sign = 0;
390 value.significand = 0;
391 value.exponent = exponent_max;
392 return value;
393 }
394
395 /// Returns: NaN value
396 static if (flags & Flags.nan)
397     static @property CustomFloat nan()
398 {
399 CustomFloat value;
400 static if (flags & Flags.signed)
401 value.sign = 0;
402 value.significand = cast(typeof(significand_max)) 1L << (precision-1);
403 value.exponent = exponent_max;
404 return value;
405 }
406
407 /// Returns: number of decimal digits of precision
408     static @property size_t dig()
409 {
410 auto shiftcnt = precision - ((flags&Flags.storeNormalized) != 0);
411 immutable x = (shiftcnt == 64) ? 0 : 1uL << shiftcnt;
412 return cast(size_t) log10(x);
413 }
414
415 /// Returns: smallest increment to the value 1
416     static @property CustomFloat epsilon()
417 {
418 CustomFloat value;
419 static if (flags & Flags.signed)
420 value.sign = 0;
421 T_signed_exp exp = -precision;
422 T_sig sig = 0;
423
424 value.fromNormalized(sig,exp);
425 if (exp == 0 && sig == 0) // underflowed to zero
426 {
427 static if ((flags&Flags.allowDenorm) ||
428 (~flags&Flags.storeNormalized))
429 sig = 1;
430 else
431 sig = cast(T) 1uL << (precision - 1);
432 }
433 value.exponent = cast(value.T_exp) exp;
434 value.significand = cast(value.T_sig) sig;
435 return value;
436 }
437
438 /// the number of bits in mantissa
439 enum mant_dig = precision + ((flags&Flags.storeNormalized) != 0);
440
441 /// Returns: maximum int value such that 10<sup>max_10_exp</sup> is representable
442     static @property int max_10_exp() { return cast(int) log10( +max ); }
443
444 /// maximum int value such that 2<sup>max_exp-1</sup> is representable
445 enum max_exp = exponent_max-bias+((~flags&(Flags.infinity|flags.nan))!=0);
446
447 /// Returns: minimum int value such that 10<sup>min_10_exp</sup> is representable
448     static @property int min_10_exp() { return cast(int) log10( +min_normal ); }
449
450 /// minimum int value such that 2<sup>min_exp-1</sup> is representable as a normalized value
451 enum min_exp = cast(T_signed_exp)-bias +1+ ((flags&Flags.allowDenorm)!=0);
452
453 /// Returns: largest representable value that's not infinity
454     static @property CustomFloat max()
455 {
456 CustomFloat value;
457 static if (flags & Flags.signed)
458 value.sign = 0;
459 value.exponent = exponent_max - ((flags&(flags.infinity|flags.nan)) != 0);
460 value.significand = significand_max;
461 return value;
462 }
463
464 /// Returns: smallest representable normalized value that's not 0
465     static @property CustomFloat min_normal() {
466 CustomFloat value;
467 static if (flags & Flags.signed)
468 value.sign = 0;
469 value.exponent = 1;
470 static if (flags&Flags.storeNormalized)
471 value.significand = 0;
472 else
473 value.significand = cast(T_sig) 1uL << (precision - 1);
474 return value;
475 }
476
477 /// Returns: real part
478     @property CustomFloat re() { return this; }
479
480 /// Returns: imaginary part
481     static @property CustomFloat im() { return CustomFloat(0.0f); }
482
483 /// Initialize from any $(D real) compatible type.
484 this(F)(F input) if (__traits(compiles, cast(real) input ))
485 {
486 this = input;
487 }
488
489 /// Self assignment
490 void opAssign(F:CustomFloat)(F input)
491 {
492 static if (flags & Flags.signed)
493 sign = input.sign;
494 exponent = input.exponent;
495 significand = input.significand;
496 }
497
498 /// Assigns from any $(D real) compatible type.
499 void opAssign(F)(F input)
500 if (__traits(compiles, cast(real) input))
501 {
502 import std.conv : text;
503
504 static if (staticIndexOf!(Unqual!F, float, double, real) >= 0)
505 auto value = ToBinary!(Unqual!F)(input);
506 else
507 auto value = ToBinary!(real )(input);
508
509 // Assign the sign bit
510 static if (~flags & Flags.signed)
511 assert((!value.sign) ^ ((flags&flags.negativeUnsigned) > 0),
512 "Incorrectly signed floating point value assigned to a " ~
513 typeof(this).stringof ~ " (no sign support).");
514 else
515 sign = value.sign;
516
517 CommonType!(T_signed_exp ,value.T_signed_exp) exp = value.exponent;
518 CommonType!(T_sig, value.T_sig ) sig = value.significand;
519
520 value.toNormalized(sig,exp);
521 fromNormalized(sig,exp);
522
523 assert(exp <= exponent_max, text(typeof(this).stringof ~
524 " exponent too large: " ,exp," > ",exponent_max, "\t",input,"\t",sig));
525 assert(sig <= significand_max, text(typeof(this).stringof ~
526 " significand too large: ",sig," > ",significand_max,
527 "\t",input,"\t",exp," ",exponent_max));
528 exponent = cast(T_exp) exp;
529 significand = cast(T_sig) sig;
530 }
531
532 /// Fetches the stored value either as a $(D float), $(D double) or $(D real).
533 @property F get(F)()
534 if (staticIndexOf!(Unqual!F, float, double, real) >= 0)
535 {
536 import std.conv : text;
537
538 ToBinary!F result;
539
540 static if (flags&Flags.signed)
541 result.sign = sign;
542 else
543 result.sign = (flags&flags.negativeUnsigned) > 0;
544
545 CommonType!(T_signed_exp ,result.get.T_signed_exp ) exp = exponent; // Assign the exponent and fraction
546 CommonType!(T_sig, result.get.T_sig ) sig = significand;
547
548 toNormalized(sig,exp);
549 result.fromNormalized(sig,exp);
550 assert(exp <= result.exponent_max, text("get exponent too large: " ,exp," > ",result.exponent_max) );
551 assert(sig <= result.significand_max, text("get significand too large: ",sig," > ",result.significand_max) );
552 result.exponent = cast(result.get.T_exp) exp;
553 result.significand = cast(result.get.T_sig) sig;
554 return result.set;
555 }
556
557 ///ditto
558 T opCast(T)() if (__traits(compiles, get!T )) { return get!T; }
559
560     /// Convert the CustomFloat to a real and perform the relevant operator on the result
561 real opUnary(string op)()
562 if (__traits(compiles, mixin(op~`(get!real)`)) || op=="++" || op=="--")
563 {
564 static if (op=="++" || op=="--")
565 {
566 auto result = get!real;
567 this = mixin(op~`result`);
568 return result;
569 }
570 else
571 return mixin(op~`get!real`);
572 }
573
574 /// ditto
575 real opBinary(string op,T)(T b)
576 if (__traits(compiles, mixin(`get!real`~op~`b`)))
577 {
578 return mixin(`get!real`~op~`b`);
579 }
580
581 /// ditto
582 real opBinaryRight(string op,T)(T a)
583 if ( __traits(compiles, mixin(`a`~op~`get!real`)) &&
584 !__traits(compiles, mixin(`get!real`~op~`b`)))
585 {
586 return mixin(`a`~op~`get!real`);
587 }
588
589 /// ditto
590 int opCmp(T)(auto ref T b)
591 if (__traits(compiles, cast(real) b))
592 {
593 auto x = get!real;
594 auto y = cast(real) b;
595 return (x >= y)-(x <= y);
596 }
597
598 /// ditto
599 void opOpAssign(string op, T)(auto ref T b)
600 if (__traits(compiles, mixin(`get!real`~op~`cast(real) b`)))
601 {
602 return mixin(`this = this `~op~` cast(real) b`);
603 }
604
605 /// ditto
606     template toString()
607 {
608 import std.format : FormatSpec, formatValue;
609 // Needs to be a template because of DMD @@BUG@@ 13737.
610 void toString()(scope void delegate(const(char)[]) sink, FormatSpec!char fmt)
611 {
612 sink.formatValue(get!real, fmt);
613 }
614 }
615 }
616
617 @safe unittest
618 {
619 import std.meta;
620 alias FPTypes =
621 AliasSeq!(
622 CustomFloat!(5, 10),
623 CustomFloat!(5, 11, CustomFloatFlags.ieee ^ CustomFloatFlags.signed),
624 CustomFloat!(1, 15, CustomFloatFlags.ieee ^ CustomFloatFlags.signed),
625 CustomFloat!(4, 3, CustomFloatFlags.ieee | CustomFloatFlags.probability ^ CustomFloatFlags.signed)
626 );
627
628     foreach (F; FPTypes)
629 {
630 auto x = F(0.125);
631 assert(x.get!float == 0.125F);
632 assert(x.get!double == 0.125);
633
634 x -= 0.0625;
635 assert(x.get!float == 0.0625F);
636 assert(x.get!double == 0.0625);
637
638 x *= 2;
639 assert(x.get!float == 0.125F);
640 assert(x.get!double == 0.125);
641
642 x /= 4;
643 assert(x.get!float == 0.03125);
644 assert(x.get!double == 0.03125);
645
646 x = 0.5;
647 x ^^= 4;
648 assert(x.get!float == 1 / 16.0F);
649 assert(x.get!double == 1 / 16.0);
650 }
651 }
652
653 @system unittest
654 {
655 // @system due to to!string(CustomFloat)
656 import std.conv;
657 CustomFloat!(5, 10) y = CustomFloat!(5, 10)(0.125);
658 assert(y.to!string == "0.125");
659 }
660
661 /**
662 Defines the fastest type to use when storing temporaries of a
663 calculation intended to ultimately yield a result of type $(D F)
664 (where $(D F) must be one of $(D float), $(D double), or $(D
665 real)). When doing a multi-step computation, you may want to store
666 intermediate results as $(D FPTemporary!F).
667
668 The necessity of $(D FPTemporary) stems from the optimized
669 floating-point operations and registers present in virtually all
670 processors. When adding numbers as in the example below, the addition may
671 in fact be done in $(D real) precision internally. In that case,
672 storing the intermediate $(D result) in $(D double) format is not only
673 less precise, it is also (surprisingly) slower, because a conversion
674 from $(D real) to $(D double) is performed every pass through the
675 loop. This being a lose-lose situation, $(D FPTemporary!F) has been
676 defined as the $(I fastest) type to use for calculations at precision
677 $(D F). There is no need to define a type for the $(I most accurate)
678 calculations, as that is always $(D real).
679
680 Finally, there is no guarantee that using $(D FPTemporary!F) will
681 always be fastest, as the speed of floating-point calculations depends
682 on very many factors.
683 */
684 template FPTemporary(F)
685 if (isFloatingPoint!F)
686 {
687 version (X86)
688 alias FPTemporary = real;
689 else
690 alias FPTemporary = Unqual!F;
691 }
692
693 ///
694 @safe unittest
695 {
696 import std.math : approxEqual;
697
698 // Average numbers in an array
699     double avg(in double[] a)
700 {
701 if (a.length == 0) return 0;
702 FPTemporary!double result = 0;
703 foreach (e; a) result += e;
704 return result / a.length;
705 }
706
707 auto a = [1.0, 2.0, 3.0];
708 assert(approxEqual(avg(a), 2));
709 }
710
711 /**
712 Implements the $(HTTP tinyurl.com/2zb9yr, secant method) for finding a
713 root of the function $(D fun) starting from points $(D [xn_1, xn])
714 (ideally close to the root). $(D Num) may be $(D float), $(D double),
715 or $(D real).
716 */
717 template secantMethod(alias fun)
718 {
719 import std.functional : unaryFun;
720 Num secantMethod(Num)(Num xn_1, Num xn)
721 {
722 auto fxn = unaryFun!(fun)(xn_1), d = xn_1 - xn;
723 typeof(fxn) fxn_1;
724
725 xn = xn_1;
726 while (!approxEqual(d, 0) && isFinite(d))
727 {
728 xn_1 = xn;
729 xn -= d;
730 fxn_1 = fxn;
731 fxn = unaryFun!(fun)(xn);
732 d *= -fxn / (fxn - fxn_1);
733 }
734 return xn;
735 }
736 }
737
738 ///
739 @safe unittest
740 {
741 import std.math : approxEqual, cos;
742
743     float f(float x)
744 {
745 return cos(x) - x*x*x;
746 }
747 auto x = secantMethod!(f)(0f, 1f);
748 assert(approxEqual(x, 0.865474));
749 }
750
751 @system unittest
752 {
753 // @system because of __gshared stderr
754 scope(failure) stderr.writeln("Failure testing secantMethod");
755     float f(float x)
756 {
757 return cos(x) - x*x*x;
758 }
759 immutable x = secantMethod!(f)(0f, 1f);
760 assert(approxEqual(x, 0.865474));
761 auto d = &f;
762 immutable y = secantMethod!(d)(0f, 1f);
763 assert(approxEqual(y, 0.865474));
764 }
765
766
767 /**
768 * Return true if a and b have opposite sign.
769 */
770 private bool oppositeSigns(T1, T2)(T1 a, T2 b)
771 {
772 return signbit(a) != signbit(b);
773 }
774
775 public:
776
777 /** Find a real root of a real function f(x) via bracketing.
778 *
779 * Given a function `f` and a range `[a .. b]` such that `f(a)`
780 * and `f(b)` have opposite signs or at least one of them equals ±0,
781 * returns the value of `x` in
782 * the range which is closest to a root of `f(x)`. If `f(x)`
783 * has more than one root in the range, one will be chosen
784 * arbitrarily. If `f(x)` returns NaN, NaN will be returned;
785 * otherwise, this algorithm is guaranteed to succeed.
786 *
787 * Uses an algorithm based on TOMS748, which uses inverse cubic
788 * interpolation whenever possible, otherwise reverting to parabolic
789 * or secant interpolation. Compared to TOMS748, this implementation
790 * improves worst-case performance by a factor of more than 100, and
791 * typical performance by a factor of 2. For 80-bit reals, most
792 * problems require 8 to 15 calls to `f(x)` to achieve full machine
793 * precision. The worst-case performance (pathological cases) is
794 * approximately twice the number of bits.
795 *
796 * References: "On Enclosing Simple Roots of Nonlinear Equations",
797 * G. Alefeld, F.A. Potra, Yixun Shi, Mathematics of Computation 61,
798 * pp733-744 (1993). Fortran code available from $(HTTP
799  * www.netlib.org,www.netlib.org) as algorithm TOMS748.
800 *
801 */
802 T findRoot(T, DF, DT)(scope DF f, in T a, in T b,
803 scope DT tolerance) //= (T a, T b) => false)
804 if (
805 isFloatingPoint!T &&
806 is(typeof(tolerance(T.init, T.init)) : bool) &&
807 is(typeof(f(T.init)) == R, R) && isFloatingPoint!R
808 )
809 {
810 immutable fa = f(a);
811 if (fa == 0)
812 return a;
813 immutable fb = f(b);
814 if (fb == 0)
815 return b;
816 immutable r = findRoot(f, a, b, fa, fb, tolerance);
817 // Return the first value if it is smaller or NaN
818 return !(fabs(r[2]) > fabs(r[3])) ? r[0] : r[1];
819 }
820
821 ///ditto
822 T findRoot(T, DF)(scope DF f, in T a, in T b)
823 {
824 return findRoot(f, a, b, (T a, T b) => false);
825 }
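
// A short usage sketch added for illustration; the function and the expected
// root below are assumptions of this example, not part of the reference tests.
@safe unittest
{
    import std.math : approxEqual, log, E;

    // log(x) - 1 changes sign on [1, 5]; its only root is at x = e.
    auto r = findRoot((real x) => log(x) - 1, 1.0L, 5.0L);
    assert(approxEqual(r, E));
}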
826
827 /** Find root of a real function f(x) by bracketing, allowing the
828 * termination condition to be specified.
829 *
830 * Params:
831 *
832 * f = Function to be analyzed
833 *
834 * ax = Left bound of initial range of `f` known to contain the
835 * root.
836 *
837 * bx = Right bound of initial range of `f` known to contain the
838 * root.
839 *
840 * fax = Value of $(D f(ax)).
841 *
842 * fbx = Value of $(D f(bx)). $(D fax) and $(D fbx) should have opposite signs.
843 * ($(D f(ax)) and $(D f(bx)) are commonly known in advance.)
844 *
845 *
846 * tolerance = Defines an early termination condition. Receives the
847 * current upper and lower bounds on the root. The
848 * delegate must return $(D true) when these bounds are
849 * acceptable. If this function always returns $(D false),
850 * full machine precision will be achieved.
851 *
852 * Returns:
853 *
854  * A tuple of four elements. The first two elements bracket the root
855  * (in `x`), while the second pair of elements
856 * are the corresponding function values at those points. If an exact
857 * root was found, both of the first two elements will contain the
858 * root, and the second pair of elements will be 0.
859 */
860 Tuple!(T, T, R, R) findRoot(T, R, DF, DT)(scope DF f, in T ax, in T bx, in R fax, in R fbx,
861 scope DT tolerance) // = (T a, T b) => false)
862 if (
863 isFloatingPoint!T &&
864 is(typeof(tolerance(T.init, T.init)) : bool) &&
865 is(typeof(f(T.init)) == R) && isFloatingPoint!R
866 )
867 in
868 {
869 assert(!ax.isNaN() && !bx.isNaN(), "Limits must not be NaN");
870 assert(signbit(fax) != signbit(fbx), "Parameters must bracket the root.");
871 }
872 body
873 {
874 // Author: Don Clugston. This code is (heavily) modified from TOMS748
875     // (www.netlib.org).  The changes to improve the worst-case performance are
876 // entirely original.
877
878 T a, b, d; // [a .. b] is our current bracket. d is the third best guess.
879 R fa, fb, fd; // Values of f at a, b, d.
880 bool done = false; // Has a root been found?
881
882 // Allow ax and bx to be provided in reverse order
883 if (ax <= bx)
884 {
885 a = ax; fa = fax;
886 b = bx; fb = fbx;
887 }
888 else
889 {
890 a = bx; fa = fbx;
891 b = ax; fb = fax;
892 }
893
894 // Test the function at point c; update brackets accordingly
895     void bracket(T c)
896 {
897 R fc = f(c);
898 if (fc == 0 || fc.isNaN()) // Exact solution, or NaN
899 {
900 a = c;
901 fa = fc;
902 d = c;
903 fd = fc;
904 done = true;
905 return;
906 }
907
908 // Determine new enclosing interval
909 if (signbit(fa) != signbit(fc))
910 {
911 d = b;
912 fd = fb;
913 b = c;
914 fb = fc;
915 }
916 else
917 {
918 d = a;
919 fd = fa;
920 a = c;
921 fa = fc;
922 }
923 }
924
925 /* Perform a secant interpolation. If the result would lie on a or b, or if
926 a and b differ so wildly in magnitude that the result would be meaningless,
927 perform a bisection instead.
928 */
929     static T secant_interpolate(T a, T b, R fa, R fb)
930 {
931 if (( ((a - b) == a) && b != 0) || (a != 0 && ((b - a) == b)))
932 {
933 // Catastrophic cancellation
934 if (a == 0)
935 a = copysign(T(0), b);
936 else if (b == 0)
937 b = copysign(T(0), a);
938 else if (signbit(a) != signbit(b))
939 return 0;
940 T c = ieeeMean(a, b);
941 return c;
942 }
943 // avoid overflow
944 if (b - a > T.max)
945 return b / 2 + a / 2;
946 if (fb - fa > R.max)
947 return a - (b - a) / 2;
948 T c = a - (fa / (fb - fa)) * (b - a);
949 if (c == a || c == b)
950 return (a + b) / 2;
951 return c;
952 }
953
954 /* Uses 'numsteps' newton steps to approximate the zero in [a .. b] of the
955 quadratic polynomial interpolating f(x) at a, b, and d.
956 Returns:
957 The approximate zero in [a .. b] of the quadratic polynomial.
958 */
959     T newtonQuadratic(int numsteps)
960 {
961 // Find the coefficients of the quadratic polynomial.
962 immutable T a0 = fa;
963 immutable T a1 = (fb - fa)/(b - a);
964 immutable T a2 = ((fd - fb)/(d - b) - a1)/(d - a);
965
966 // Determine the starting point of newton steps.
967 T c = oppositeSigns(a2, fa) ? a : b;
968
969 // start the safeguarded newton steps.
970 foreach (int i; 0 .. numsteps)
971 {
972 immutable T pc = a0 + (a1 + a2 * (c - b))*(c - a);
973 immutable T pdc = a1 + a2*((2 * c) - (a + b));
974 if (pdc == 0)
975 return a - a0 / a1;
976 else
977 c = c - pc / pdc;
978 }
979 return c;
980 }
981
982 // On the first iteration we take a secant step:
983 if (fa == 0 || fa.isNaN())
984 {
985 done = true;
986 b = a;
987 fb = fa;
988 }
989 else if (fb == 0 || fb.isNaN())
990 {
991 done = true;
992 a = b;
993 fa = fb;
994 }
995 else
996 {
997 bracket(secant_interpolate(a, b, fa, fb));
998 }
999
1000 // Starting with the second iteration, higher-order interpolation can
1001 // be used.
1002 int itnum = 1; // Iteration number
1003 int baditer = 1; // Num bisections to take if an iteration is bad.
1004 T c, e; // e is our fourth best guess
1005 R fe;
1006
1007 whileloop:
1008 while (!done && (b != nextUp(a)) && !tolerance(a, b))
1009 {
1010 T a0 = a, b0 = b; // record the brackets
1011
1012 // Do two higher-order (cubic or parabolic) interpolation steps.
1013 foreach (int QQ; 0 .. 2)
1014 {
1015 // Cubic inverse interpolation requires that
1016 // all four function values fa, fb, fd, and fe are distinct;
1017 // otherwise use quadratic interpolation.
1018 bool distinct = (fa != fb) && (fa != fd) && (fa != fe)
1019 && (fb != fd) && (fb != fe) && (fd != fe);
1020 // The first time, cubic interpolation is impossible.
1021 if (itnum<2) distinct = false;
1022 bool ok = distinct;
1023 if (distinct)
1024 {
1025 // Cubic inverse interpolation of f(x) at a, b, d, and e
1026 immutable q11 = (d - e) * fd / (fe - fd);
1027 immutable q21 = (b - d) * fb / (fd - fb);
1028 immutable q31 = (a - b) * fa / (fb - fa);
1029 immutable d21 = (b - d) * fd / (fd - fb);
1030 immutable d31 = (a - b) * fb / (fb - fa);
1031
1032 immutable q22 = (d21 - q11) * fb / (fe - fb);
1033 immutable q32 = (d31 - q21) * fa / (fd - fa);
1034 immutable d32 = (d31 - q21) * fd / (fd - fa);
1035 immutable q33 = (d32 - q22) * fa / (fe - fa);
1036 c = a + (q31 + q32 + q33);
1037 if (c.isNaN() || (c <= a) || (c >= b))
1038 {
1039 // DAC: If the interpolation predicts a or b, it's
1040 // probable that it's the actual root. Only allow this if
1041 // we're already close to the root.
1042 if (c == a && a - b != a)
1043 {
1044 c = nextUp(a);
1045 }
1046 else if (c == b && a - b != -b)
1047 {
1048 c = nextDown(b);
1049 }
1050 else
1051 {
1052 ok = false;
1053 }
1054 }
1055 }
1056 if (!ok)
1057 {
1058 // DAC: Alefeld doesn't explain why the number of newton steps
1059 // should vary.
1060 c = newtonQuadratic(distinct ? 3 : 2);
1061 if (c.isNaN() || (c <= a) || (c >= b))
1062 {
1063 // Failure, try a secant step:
1064 c = secant_interpolate(a, b, fa, fb);
1065 }
1066 }
1067 ++itnum;
1068 e = d;
1069 fe = fd;
1070 bracket(c);
1071 if (done || ( b == nextUp(a)) || tolerance(a, b))
1072 break whileloop;
1073 if (itnum == 2)
1074 continue whileloop;
1075 }
1076
1077 // Now we take a double-length secant step:
1078 T u;
1079 R fu;
1080 if (fabs(fa) < fabs(fb))
1081 {
1082 u = a;
1083 fu = fa;
1084 }
1085 else
1086 {
1087 u = b;
1088 fu = fb;
1089 }
1090 c = u - 2 * (fu / (fb - fa)) * (b - a);
1091
1092 // DAC: If the secant predicts a value equal to an endpoint, it's
1093 // probably false.
1094 if (c == a || c == b || c.isNaN() || fabs(c - u) > (b - a) / 2)
1095 {
1096 if ((a-b) == a || (b-a) == b)
1097 {
1098 if ((a>0 && b<0) || (a<0 && b>0))
1099 c = 0;
1100 else
1101 {
1102 if (a == 0)
1103 c = ieeeMean(copysign(T(0), b), b);
1104 else if (b == 0)
1105 c = ieeeMean(copysign(T(0), a), a);
1106 else
1107 c = ieeeMean(a, b);
1108 }
1109 }
1110 else
1111 {
1112 c = a + (b - a) / 2;
1113 }
1114 }
1115 e = d;
1116 fe = fd;
1117 bracket(c);
1118 if (done || (b == nextUp(a)) || tolerance(a, b))
1119 break;
1120
1121 // IMPROVE THE WORST-CASE PERFORMANCE
1122 // We must ensure that the bounds reduce by a factor of 2
1123         // (in binary space!) every iteration. If we haven't achieved this
1124 // yet, or if we don't yet know what the exponent is,
1125 // perform a binary chop.
1126
1127 if ((a == 0 || b == 0 ||
1128 (fabs(a) >= T(0.5) * fabs(b) && fabs(b) >= T(0.5) * fabs(a)))
1129 && (b - a) < T(0.25) * (b0 - a0))
1130 {
1131 baditer = 1;
1132 continue;
1133 }
1134
1135 // DAC: If this happens on consecutive iterations, we probably have a
1136 // pathological function. Perform a number of bisections equal to the
1137 // total number of consecutive bad iterations.
1138
1139 if ((b - a) < T(0.25) * (b0 - a0))
1140 baditer = 1;
1141 foreach (int QQ; 0 .. baditer)
1142 {
1143 e = d;
1144 fe = fd;
1145
1146 T w;
1147 if ((a>0 && b<0) || (a<0 && b>0))
1148 w = 0;
1149 else
1150 {
1151 T usea = a;
1152 T useb = b;
1153 if (a == 0)
1154 usea = copysign(T(0), b);
1155 else if (b == 0)
1156 useb = copysign(T(0), a);
1157 w = ieeeMean(usea, useb);
1158 }
1159 bracket(w);
1160 }
1161 ++baditer;
1162 }
1163 return Tuple!(T, T, R, R)(a, b, fa, fb);
1164 }
1165
1166 ///ditto
1167 Tuple!(T, T, R, R) findRoot(T, R, DF)(scope DF f, in T ax, in T bx, in R fax, in R fbx)
1168 {
1169 return findRoot(f, ax, bx, fax, fbx, (T a, T b) => false);
1170 }
1171
1172 ///ditto
1173 T findRoot(T, R)(scope R delegate(T) f, in T a, in T b,
1174 scope bool delegate(T lo, T hi) tolerance = (T a, T b) => false)
1175 {
1176 return findRoot!(T, R delegate(T), bool delegate(T lo, T hi))(f, a, b, tolerance);
1177 }
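
// A short usage sketch of the bracket-returning overload; the polynomial and
// the bounds below are assumptions of this example, not part of the reference tests.
@safe unittest
{
    import std.math : approxEqual, sqrt;

    // x*x - 2 has a root at sqrt(2); f(1) = -1 and f(2) = 2 bracket it.
    auto r = findRoot((real x) => x * x - 2, 1.0L, 2.0L, -1.0L, 2.0L);
    // r[0] and r[1] bound the root; r[2] and r[3] are the function values there.
    assert(approxEqual(r[0], sqrt(2.0L)));
    assert(approxEqual(r[1], sqrt(2.0L)));
}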
1178
1179 @safe nothrow unittest
1180 {
1181 int numProblems = 0;
1182 int numCalls;
1183
1184     void testFindRoot(real delegate(real) @nogc @safe nothrow pure f, real x1, real x2) @nogc @safe nothrow pure
1185 {
1186 //numCalls=0;
1187 //++numProblems;
1188 assert(!x1.isNaN() && !x2.isNaN());
1189 assert(signbit(x1) != signbit(x2));
1190 auto result = findRoot(f, x1, x2, f(x1), f(x2),
1191 (real lo, real hi) { return false; });
1192
1193 auto flo = f(result[0]);
1194 auto fhi = f(result[1]);
1195 if (flo != 0)
1196 {
1197 assert(oppositeSigns(flo, fhi));
1198 }
1199 }
1200
1201 // Test functions
1202     real cubicfn(real x) @nogc @safe nothrow pure
1203 {
1204 //++numCalls;
1205 if (x>float.max)
1206 x = float.max;
1207 if (x<-double.max)
1208 x = -double.max;
1209 // This has a single real root at -59.286543284815
1210 return 0.386*x*x*x + 23*x*x + 15.7*x + 525.2;
1211 }
1212 // Test a function with more than one root.
1213     real multisine(real x) { ++numCalls; return sin(x); }
1214 //testFindRoot( &multisine, 6, 90);
1215 //testFindRoot(&cubicfn, -100, 100);
1216 //testFindRoot( &cubicfn, -double.max, real.max);
1217
1218
1219 /* Tests from the paper:
1220 * "On Enclosing Simple Roots of Nonlinear Equations", G. Alefeld, F.A. Potra,
1221 * Yixun Shi, Mathematics of Computation 61, pp733-744 (1993).
1222 */
1223 // Parameters common to many alefeld tests.
1224 int n;
1225 real ale_a, ale_b;
1226
1227 int powercalls = 0;
1228
1229     real power(real x)
1230 {
1231 ++powercalls;
1232 ++numCalls;
1233 return pow(x, n) + double.min_normal;
1234 }
1235 int [] power_nvals = [3, 5, 7, 9, 19, 25];
1236 // Alefeld paper states that pow(x,n) is a very poor case, where bisection
1237 // outperforms his method, and gives total numcalls =
1238 // 921 for bisection (2.4 calls per bit), 1830 for Alefeld (4.76/bit),
1239 // 2624 for brent (6.8/bit)
1240 // ... but that is for double, not real80.
1241 // This poor performance seems mainly due to catastrophic cancellation,
1242 // which is avoided here by the use of ieeeMean().
1243 // I get: 231 (0.48/bit).
1244 // IE this is 10X faster in Alefeld's worst case
1245 numProblems=0;
1246     foreach (k; power_nvals)
1247 {
1248 n = k;
1249 //testFindRoot(&power, -1, 10);
1250 }
1251
1252 int powerProblems = numProblems;
1253
1254 // Tests from Alefeld paper
1255
1256 int [9] alefeldSums;
1257     real alefeld0(real x)
1258 {
1259 ++alefeldSums[0];
1260 ++numCalls;
1261 real q = sin(x) - x/2;
1262 for (int i=1; i<20; ++i)
1263 q+=(2*i-5.0)*(2*i-5.0)/((x-i*i)*(x-i*i)*(x-i*i));
1264 return q;
1265 }
1266     real alefeld1(real x)
1267 {
1268 ++numCalls;
1269 ++alefeldSums[1];
1270 return ale_a*x + exp(ale_b * x);
1271 }
1272     real alefeld2(real x)
1273 {
1274 ++numCalls;
1275 ++alefeldSums[2];
1276 return pow(x, n) - ale_a;
1277 }
1278     real alefeld3(real x)
1279 {
1280 ++numCalls;
1281 ++alefeldSums[3];
1282 return (1.0 +pow(1.0L-n, 2))*x - pow(1.0L-n*x, 2);
1283 }
1284     real alefeld4(real x)
1285 {
1286 ++numCalls;
1287 ++alefeldSums[4];
1288 return x*x - pow(1-x, n);
1289 }
1290     real alefeld5(real x)
1291 {
1292 ++numCalls;
1293 ++alefeldSums[5];
1294 return (1+pow(1.0L-n, 4))*x - pow(1.0L-n*x, 4);
1295 }
1296     real alefeld6(real x)
1297 {
1298 ++numCalls;
1299 ++alefeldSums[6];
1300 return exp(-n*x)*(x-1.01L) + pow(x, n);
1301 }
1302     real alefeld7(real x)
1303 {
1304 ++numCalls;
1305 ++alefeldSums[7];
1306 return (n*x-1)/((n-1)*x);
1307 }
1308
1309 numProblems=0;
1310 //testFindRoot(&alefeld0, PI_2, PI);
1311 for (n=1; n <= 10; ++n)
1312 {
1313 //testFindRoot(&alefeld0, n*n+1e-9L, (n+1)*(n+1)-1e-9L);
1314 }
1315 ale_a = -40; ale_b = -1;
1316 //testFindRoot(&alefeld1, -9, 31);
1317 ale_a = -100; ale_b = -2;
1318 //testFindRoot(&alefeld1, -9, 31);
1319 ale_a = -200; ale_b = -3;
1320 //testFindRoot(&alefeld1, -9, 31);
1321 int [] nvals_3 = [1, 2, 5, 10, 15, 20];
1322 int [] nvals_5 = [1, 2, 4, 5, 8, 15, 20];
1323 int [] nvals_6 = [1, 5, 10, 15, 20];
1324 int [] nvals_7 = [2, 5, 15, 20];
1325
1326 for (int i=4; i<12; i+=2)
1327 {
1328 n = i;
1329 ale_a = 0.2;
1330 //testFindRoot(&alefeld2, 0, 5);
1331 ale_a=1;
1332 //testFindRoot(&alefeld2, 0.95, 4.05);
1333 //testFindRoot(&alefeld2, 0, 1.5);
1334 }
1335     foreach (i; nvals_3)
1336 {
1337 n=i;
1338 //testFindRoot(&alefeld3, 0, 1);
1339 }
1340     foreach (i; nvals_3)
1341 {
1342 n=i;
1343 //testFindRoot(&alefeld4, 0, 1);
1344 }
1345     foreach (i; nvals_5)
1346 {
1347 n=i;
1348 //testFindRoot(&alefeld5, 0, 1);
1349 }
1350     foreach (i; nvals_6)
1351 {
1352 n=i;
1353 //testFindRoot(&alefeld6, 0, 1);
1354 }
1355     foreach (i; nvals_7)
1356 {
1357 n=i;
1358 //testFindRoot(&alefeld7, 0.01L, 1);
1359 }
1360     real worstcase(real x)
1361 {
1362 ++numCalls;
1363 return x<0.3*real.max? -0.999e-3 : 1.0;
1364 }
1365 //testFindRoot(&worstcase, -real.max, real.max);
1366
1367 // just check that the double + float cases compile
1368 //findRoot((double x){ return 0.0; }, -double.max, double.max);
1369 //findRoot((float x){ return 0.0f; }, -float.max, float.max);
1370
1371 /*
1372 int grandtotal=0;
1373 foreach (calls; alefeldSums)
1374 {
1375 grandtotal+=calls;
1376 }
1377 grandtotal-=2*numProblems;
1378 printf("\nALEFELD TOTAL = %d avg = %f (alefeld avg=19.3 for double)\n",
1379 grandtotal, (1.0*grandtotal)/numProblems);
1380 powercalls -= 2*powerProblems;
1381 printf("POWER TOTAL = %d avg = %f ", powercalls,
1382 (1.0*powercalls)/powerProblems);
1383 */
1384 //Issue 14231
1385 auto xp = findRoot((float x) => x, 0f, 1f);
1386 auto xn = findRoot((float x) => x, -1f, -0f);
1387 }
1388
1389 //regression control
1390 @system unittest
1391 {
1392 // @system due to the case in the 2nd line
1393 static assert(__traits(compiles, findRoot((float x)=>cast(real) x, float.init, float.init)));
1394 static assert(__traits(compiles, findRoot!real((x)=>cast(double) x, real.init, real.init)));
1395 static assert(__traits(compiles, findRoot((real x)=>cast(double) x, real.init, real.init)));
1396 }
1397
1398 /++
1399 Find a real minimum of a real function `f(x)` via bracketing.
1400 Given a function `f` and a range `(ax .. bx)`,
1401 returns the value of `x` in the range which is closest to a minimum of `f(x)`.
1402 `f` is never evaluated at the endpoints `ax` and `bx`.
1403 If `f(x)` has more than one minimum in the range, one will be chosen arbitrarily.
1404 If `f(x)` returns NaN or -Infinity, `(x, f(x), NaN)` will be returned;
1405 otherwise, this algorithm is guaranteed to succeed.
1406
1407 Params:
1408 f = Function to be analyzed
1409 ax = Left bound of initial range of f known to contain the minimum.
1410 bx = Right bound of initial range of f known to contain the minimum.
1411 relTolerance = Relative tolerance.
1412 absTolerance = Absolute tolerance.
1413
1414 Preconditions:
1415 `ax` and `bx` shall be finite reals. $(BR)
1416 $(D relTolerance) shall be a normal positive real. $(BR)
1417 $(D absTolerance) shall be a normal positive real no less than $(D T.epsilon*2).
1418
1419 Returns:
1420 A tuple consisting of `x`, `y = f(x)` and `error = 3 * (absTolerance * fabs(x) + relTolerance)`.
1421
1422 The method used is a combination of golden section search and
1423 successive parabolic interpolation. Convergence is never much slower
1424 than that for a Fibonacci search.
1425
1426 References:
1427 "Algorithms for Minimization without Derivatives", Richard Brent, Prentice-Hall, Inc. (1973)
1428
1429 See_Also: $(LREF findRoot), $(REF isNormal, std,math)
1430 +/
1431 Tuple!(T, "x", Unqual!(ReturnType!DF), "y", T, "error")
1432 findLocalMin(T, DF)(
1433 scope DF f,
1434 in T ax,
1435 in T bx,
1436 in T relTolerance = sqrt(T.epsilon),
1437 in T absTolerance = sqrt(T.epsilon),
1438 )
1439 if (isFloatingPoint!T
1440 && __traits(compiles, {T _ = DF.init(T.init);}))
1441 in
1442 {
1443 assert(isFinite(ax), "ax is not finite");
1444 assert(isFinite(bx), "bx is not finite");
1445 assert(isNormal(relTolerance), "relTolerance is not normal floating point number");
1446 assert(isNormal(absTolerance), "absTolerance is not normal floating point number");
1447     assert(relTolerance >= 0, "relTolerance is not positive");
1448     assert(absTolerance >= T.epsilon*2, "absTolerance is not greater than `2*T.epsilon`");
1449 }
1450 out (result)
1451 {
1452 assert(isFinite(result.x));
1453 }
1454 body
1455 {
1456 alias R = Unqual!(CommonType!(ReturnType!DF, T));
1457 // c is the squared inverse of the golden ratio
1458 // (3 - sqrt(5))/2
1459 // Value obtained from Wolfram Alpha.
1460 enum T c = 0x0.61c8864680b583ea0c633f9fa31237p+0L;
1461 enum T cm1 = 0x0.9e3779b97f4a7c15f39cc0605cedc8p+0L;
1462 R tolerance;
1463 T a = ax > bx ? bx : ax;
1464 T b = ax > bx ? ax : bx;
1465 // sequence of declarations suitable for SIMD instructions
1466 T v = a * cm1 + b * c;
1467 assert(isFinite(v));
1468 R fv = f(v);
1469 if (isNaN(fv) || fv == -T.infinity)
1470 {
1471 return typeof(return)(v, fv, T.init);
1472 }
1473 T w = v;
1474 R fw = fv;
1475 T x = v;
1476 R fx = fv;
1477 size_t i;
1478 for (R d = 0, e = 0;;)
1479 {
1480 i++;
1481 T m = (a + b) / 2;
1482 // This fix is not part of the original algorithm
1483         if (!isFinite(m)) // fix infinite loop. Issue can be reproduced in R.
1484 {
1485 m = a / 2 + b / 2;
1486 if (!isFinite(m)) // fast-math compiler switch is enabled
1487 {
1488 //SIMD instructions can be used by compiler, do not reduce declarations
1489 int a_exp = void;
1490 int b_exp = void;
1491 immutable an = frexp(a, a_exp);
1492 immutable bn = frexp(b, b_exp);
1493 immutable am = ldexp(an, a_exp-1);
1494 immutable bm = ldexp(bn, b_exp-1);
1495 m = am + bm;
1496 if (!isFinite(m)) // wrong input: constraints are disabled in release mode
1497 {
1498 return typeof(return).init;
1499 }
1500 }
1501 }
1502 tolerance = absTolerance * fabs(x) + relTolerance;
1503 immutable t2 = tolerance * 2;
1504 // check stopping criterion
1505 if (!(fabs(x - m) > t2 - (b - a) / 2))
1506 {
1507 break;
1508 }
1509 R p = 0;
1510 R q = 0;
1511 R r = 0;
1512 // fit parabola
1513 if (fabs(e) > tolerance)
1514 {
1515 immutable xw = x - w;
1516 immutable fxw = fx - fw;
1517 immutable xv = x - v;
1518 immutable fxv = fx - fv;
1519 immutable xwfxv = xw * fxv;
1520 immutable xvfxw = xv * fxw;
1521 p = xv * xvfxw - xw * xwfxv;
1522 q = (xvfxw - xwfxv) * 2;
1523 if (q > 0)
1524 p = -p;
1525 else
1526 q = -q;
1527 r = e;
1528 e = d;
1529 }
1530 T u;
1531 // a parabolic-interpolation step
1532 if (fabs(p) < fabs(q * r / 2) && p > q * (a - x) && p < q * (b - x))
1533 {
1534 d = p / q;
1535 u = x + d;
1536 // f must not be evaluated too close to a or b
1537 if (u - a < t2 || b - u < t2)
1538 d = x < m ? tolerance : -tolerance;
1539 }
1540 // a golden-section step
1541 else
1542 {
1543 e = (x < m ? b : a) - x;
1544 d = c * e;
1545 }
1546 // f must not be evaluated too close to x
1547 u = x + (fabs(d) >= tolerance ? d : d > 0 ? tolerance : -tolerance);
1548 immutable fu = f(u);
1549 if (isNaN(fu) || fu == -T.infinity)
1550 {
1551 return typeof(return)(u, fu, T.init);
1552 }
1553 // update a, b, v, w, and x
1554 if (fu <= fx)
1555 {
1556 u < x ? b : a = x;
1557 v = w; fv = fw;
1558 w = x; fw = fx;
1559 x = u; fx = fu;
1560 }
1561 else
1562 {
1563 u < x ? a : b = u;
1564 if (fu <= fw || w == x)
1565 {
1566 v = w; fv = fw;
1567 w = u; fw = fu;
1568 }
1569 else if (fu <= fv || v == x || v == w)
1570             { // do not remove these braces
1571 v = u; fv = fu;
1572 }
1573 }
1574 }
1575 return typeof(return)(x, fx, tolerance * 3);
1576 }
1577
1578 ///
1579 @safe unittest
1580 {
1581 import std.math : approxEqual;
1582
1583 auto ret = findLocalMin((double x) => (x-4)^^2, -1e7, 1e7);
1584 assert(ret.x.approxEqual(4.0));
1585 assert(ret.y.approxEqual(0.0));
1586 }
1587
1588 @safe unittest
1589 {
1590 import std.meta : AliasSeq;
1591 foreach (T; AliasSeq!(double, float, real))
1592 {
1593 {
1594 auto ret = findLocalMin!T((T x) => (x-4)^^2, T.min_normal, 1e7);
1595 assert(ret.x.approxEqual(T(4)));
1596 assert(ret.y.approxEqual(T(0)));
1597 }
1598 {
1599 auto ret = findLocalMin!T((T x) => fabs(x-1), -T.max/4, T.max/4, T.min_normal, 2*T.epsilon);
1600 assert(approxEqual(ret.x, T(1)));
1601 assert(approxEqual(ret.y, T(0)));
1602 assert(ret.error <= 10 * T.epsilon);
1603 }
1604 {
1605 auto ret = findLocalMin!T((T x) => T.init, 0, 1, T.min_normal, 2*T.epsilon);
1606 assert(!ret.x.isNaN);
1607 assert(ret.y.isNaN);
1608 assert(ret.error.isNaN);
1609 }
1610 {
1611 auto ret = findLocalMin!T((T x) => log(x), 0, 1, T.min_normal, 2*T.epsilon);
1612 assert(ret.error < 3.00001 * ((2*T.epsilon)*fabs(ret.x)+ T.min_normal));
1613 assert(ret.x >= 0 && ret.x <= ret.error);
1614 }
1615 {
1616 auto ret = findLocalMin!T((T x) => log(x), 0, T.max, T.min_normal, 2*T.epsilon);
1617 assert(ret.y < -18);
1618 assert(ret.error < 5e-08);
1619 assert(ret.x >= 0 && ret.x <= ret.error);
1620 }
1621 {
1622 auto ret = findLocalMin!T((T x) => -fabs(x), -1, 1, T.min_normal, 2*T.epsilon);
1623 assert(ret.x.fabs.approxEqual(T(1)));
1624 assert(ret.y.fabs.approxEqual(T(1)));
1625 assert(ret.error.approxEqual(T(0)));
1626 }
1627 }
1628 }
1629
1630 /**
1631 Computes $(LINK2 https://en.wikipedia.org/wiki/Euclidean_distance,
1632 Euclidean distance) between input ranges $(D a) and
1633 $(D b). The two ranges must have the same length. The three-parameter
1634 version stops computation as soon as the distance is greater than or
1635 equal to $(D limit) (this is useful to save computation if a small
1636 distance is sought).
1637 */
1638 CommonType!(ElementType!(Range1), ElementType!(Range2))
1639 euclideanDistance(Range1, Range2)(Range1 a, Range2 b)
1640 if (isInputRange!(Range1) && isInputRange!(Range2))
1641 {
1642 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
1643 static if (haveLen) assert(a.length == b.length);
1644 Unqual!(typeof(return)) result = 0;
1645 for (; !a.empty; a.popFront(), b.popFront())
1646 {
1647 immutable t = a.front - b.front;
1648 result += t * t;
1649 }
1650 static if (!haveLen) assert(b.empty);
1651 return sqrt(result);
1652 }
1653
1654 /// Ditto
1655 CommonType!(ElementType!(Range1), ElementType!(Range2))
1656 euclideanDistance(Range1, Range2, F)(Range1 a, Range2 b, F limit)
1657 if (isInputRange!(Range1) && isInputRange!(Range2))
1658 {
1659 limit *= limit;
1660 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
1661 static if (haveLen) assert(a.length == b.length);
1662 Unqual!(typeof(return)) result = 0;
1663 for (; ; a.popFront(), b.popFront())
1664 {
1665 if (a.empty)
1666 {
1667 static if (!haveLen) assert(b.empty);
1668 break;
1669 }
1670 immutable t = a.front - b.front;
1671 result += t * t;
1672 if (result >= limit) break;
1673 }
1674 return sqrt(result);
1675 }
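
// A short usage sketch added for illustration (a 3-4-5 right triangle); the
// values are assumptions of this example, not part of the reference tests.
@safe unittest
{
    double[] a = [ 0.0, 0.0 ];
    double[] b = [ 3.0, 4.0 ];
    assert(euclideanDistance(a, b) == 5);
    // The three-parameter version may stop early once the running sum of
    // squares reaches limit squared; here the full distance is still returned.
    assert(euclideanDistance(a, b, 5) == 5);
}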
1676
1677 @safe unittest
1678 {
1679 import std.meta : AliasSeq;
1680 foreach (T; AliasSeq!(double, const double, immutable double))
1681 {
1682 T[] a = [ 1.0, 2.0, ];
1683 T[] b = [ 4.0, 6.0, ];
1684 assert(euclideanDistance(a, b) == 5);
1685 assert(euclideanDistance(a, b, 5) == 5);
1686 assert(euclideanDistance(a, b, 4) == 5);
1687 assert(euclideanDistance(a, b, 2) == 3);
1688 }
1689 }
1690
1691 /**
1692 Computes the $(LINK2 https://en.wikipedia.org/wiki/Dot_product,
1693 dot product) of input ranges $(D a) and $(D
1694 b). The two ranges must have the same length. If both ranges define
1695 length, the check is done once; otherwise, it is done at each
1696 iteration.
1697 */
1698 CommonType!(ElementType!(Range1), ElementType!(Range2))
1699 dotProduct(Range1, Range2)(Range1 a, Range2 b)
1700 if (isInputRange!(Range1) && isInputRange!(Range2) &&
1701 !(isArray!(Range1) && isArray!(Range2)))
1702 {
1703 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
1704 static if (haveLen) assert(a.length == b.length);
1705 Unqual!(typeof(return)) result = 0;
1706 for (; !a.empty; a.popFront(), b.popFront())
1707 {
1708 result += a.front * b.front;
1709 }
1710 static if (!haveLen) assert(b.empty);
1711 return result;
1712 }
1713
1714 /// Ditto
1715 CommonType!(F1, F2)
1716 dotProduct(F1, F2)(in F1[] avector, in F2[] bvector)
1717 {
1718 immutable n = avector.length;
1719 assert(n == bvector.length);
1720 auto avec = avector.ptr, bvec = bvector.ptr;
1721 Unqual!(typeof(return)) sum0 = 0, sum1 = 0;
1722
1723 const all_endp = avec + n;
1724 const smallblock_endp = avec + (n & ~3);
1725 const bigblock_endp = avec + (n & ~15);
1726
1727 for (; avec != bigblock_endp; avec += 16, bvec += 16)
1728 {
1729 sum0 += avec[0] * bvec[0];
1730 sum1 += avec[1] * bvec[1];
1731 sum0 += avec[2] * bvec[2];
1732 sum1 += avec[3] * bvec[3];
1733 sum0 += avec[4] * bvec[4];
1734 sum1 += avec[5] * bvec[5];
1735 sum0 += avec[6] * bvec[6];
1736 sum1 += avec[7] * bvec[7];
1737 sum0 += avec[8] * bvec[8];
1738 sum1 += avec[9] * bvec[9];
1739 sum0 += avec[10] * bvec[10];
1740 sum1 += avec[11] * bvec[11];
1741 sum0 += avec[12] * bvec[12];
1742 sum1 += avec[13] * bvec[13];
1743 sum0 += avec[14] * bvec[14];
1744 sum1 += avec[15] * bvec[15];
1745 }
1746
1747 for (; avec != smallblock_endp; avec += 4, bvec += 4)
1748 {
1749 sum0 += avec[0] * bvec[0];
1750 sum1 += avec[1] * bvec[1];
1751 sum0 += avec[2] * bvec[2];
1752 sum1 += avec[3] * bvec[3];
1753 }
1754
1755 sum0 += sum1;
1756
1757 /* Do trailing portion in naive loop. */
1758 while (avec != all_endp)
1759 {
1760 sum0 += *avec * *bvec;
1761 ++avec;
1762 ++bvec;
1763 }
1764
1765 return sum0;
1766 }
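
// A short usage sketch added for illustration; the arrays below are
// assumptions of this example, not part of the reference tests.
@system unittest
{
    // Three elements are too few for the unrolled blocks, so the trailing
    // naive loop computes the whole product: 1*4 + 2*5 + 3*6 == 32.
    double[] a = [ 1.0, 2.0, 3.0 ];
    double[] b = [ 4.0, 5.0, 6.0 ];
    assert(dotProduct(a, b) == 32);
}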
1767
1768 @system unittest
1769 {
1770 // @system due to dotProduct and assertCTFEable
1771 import std.exception : assertCTFEable;
1772 import std.meta : AliasSeq;
1773 foreach (T; AliasSeq!(double, const double, immutable double))
1774 {
1775 T[] a = [ 1.0, 2.0, ];
1776 T[] b = [ 4.0, 6.0, ];
1777 assert(dotProduct(a, b) == 16);
1778 assert(dotProduct([1, 3, -5], [4, -2, -1]) == 3);
1779 }
1780
1781 // Make sure the unrolled loop codepath gets tested.
1782 static const x =
1783 [1.0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18];
1784 static const y =
1785 [2.0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19];
1786 assertCTFEable!({ assert(dotProduct(x, y) == 2280); });
1787 }
1788
1789 /**
1790 Computes the $(LINK2 https://en.wikipedia.org/wiki/Cosine_similarity,
1791 cosine similarity) of input ranges $(D a) and $(D
1792 b). The two ranges must have the same length. If both ranges define
1793 length, the check is done once; otherwise, it is done at each
1794 iteration. If either range has all-zero elements, 0 is returned.
1795 */
1796 CommonType!(ElementType!(Range1), ElementType!(Range2))
1797 cosineSimilarity(Range1, Range2)(Range1 a, Range2 b)
1798 if (isInputRange!(Range1) && isInputRange!(Range2))
1799 {
1800 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
1801 static if (haveLen) assert(a.length == b.length);
1802 Unqual!(typeof(return)) norma = 0, normb = 0, dotprod = 0;
1803 for (; !a.empty; a.popFront(), b.popFront())
1804 {
1805 immutable t1 = a.front, t2 = b.front;
1806 norma += t1 * t1;
1807 normb += t2 * t2;
1808 dotprod += t1 * t2;
1809 }
1810 static if (!haveLen) assert(b.empty);
1811 if (norma == 0 || normb == 0) return 0;
1812 return dotprod / sqrt(norma * normb);
1813 }
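
// A short usage sketch added for illustration; the parallel vectors below are
// assumptions of this example, not part of the reference tests.
@safe unittest
{
    import std.math : approxEqual;

    // Parallel vectors have cosine similarity 1 regardless of magnitude.
    double[] a = [ 3.0, 4.0 ];
    double[] b = [ 6.0, 8.0 ];
    assert(approxEqual(cosineSimilarity(a, b), 1.0));
}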
1814
1815 @safe unittest
1816 {
1817 import std.meta : AliasSeq;
1818 foreach (T; AliasSeq!(double, const double, immutable double))
1819 {
1820 T[] a = [ 1.0, 2.0, ];
1821 T[] b = [ 4.0, 3.0, ];
1822 assert(approxEqual(
1823 cosineSimilarity(a, b), 10.0 / sqrt(5.0 * 25),
1824 0.01));
1825 }
1826 }
1827
1828 /**
1829 Normalizes values in $(D range) by multiplying each element with a
1830 number chosen such that values sum up to $(D sum). If elements in $(D
1831 range) sum to zero, assigns $(D sum / range.length) to
1832 all. Normalization makes sense only if all elements in $(D range) are
1833 positive. $(D normalize) assumes that is the case without checking it.
1834
1835 Returns: $(D true) if normalization completed normally, $(D false) if
1836 all elements in $(D range) were zero or if $(D range) is empty.
1837 */
1838 bool normalize(R)(R range, ElementType!(R) sum = 1)
1839 if (isForwardRange!(R))
1840 {
1841 ElementType!(R) s = 0;
1842 // Step 1: Compute sum and length of the range
1843 static if (hasLength!(R))
1844 {
1845 const length = range.length;
1846         foreach (e; range)
1847 {
1848 s += e;
1849 }
1850 }
1851 else
1852 {
1853 uint length = 0;
1854         foreach (e; range)
1855 {
1856 s += e;
1857 ++length;
1858 }
1859 }
1860 // Step 2: perform normalization
1861 if (s == 0)
1862 {
1863 if (length)
1864 {
1865             immutable f = sum / length;
1866 foreach (ref e; range) e = f;
1867 }
1868 return false;
1869 }
1870 // The path most traveled
1871 assert(s >= 0);
1872 immutable f = sum / s;
1873 foreach (ref e; range)
1874 e *= f;
1875 return true;
1876 }
1877
1878 ///
1879 @safe unittest
1880 {
1881 double[] a = [];
1882 assert(!normalize(a));
1883 a = [ 1.0, 3.0 ];
1884 assert(normalize(a));
1885 assert(a == [ 0.25, 0.75 ]);
1886 a = [ 0.0, 0.0 ];
1887 assert(!normalize(a));
1888 assert(a == [ 0.5, 0.5 ]);
1889 }
1890
1891 /**
1892 Compute the sum of binary logarithms of the input range $(D r).
1893 The error of this method is much smaller than with a naive sum of log2.
1894 */
1895 ElementType!Range sumOfLog2s(Range)(Range r)
1896 if (isInputRange!Range && isFloatingPoint!(ElementType!Range))
1897 {
1898 long exp = 0;
1899 Unqual!(typeof(return)) x = 1;
1900     foreach (e; r)
1901 {
1902 if (e < 0)
1903 return typeof(return).nan;
1904 int lexp = void;
1905 x *= frexp(e, lexp);
1906 exp += lexp;
1907 if (x < 0.5)
1908 {
1909 x *= 2;
1910 exp--;
1911 }
1912 }
1913 return exp + log2(x);
1914 }
1915
1916 ///
1917 @safe unittest
1918 {
1919 import std.math : isNaN;
1920
1921 assert(sumOfLog2s(new double[0]) == 0);
1922 assert(sumOfLog2s([0.0L]) == -real.infinity);
1923 assert(sumOfLog2s([-0.0L]) == -real.infinity);
1924 assert(sumOfLog2s([2.0L]) == 1);
1925 assert(sumOfLog2s([-2.0L]).isNaN());
1926 assert(sumOfLog2s([real.nan]).isNaN());
1927 assert(sumOfLog2s([-real.nan]).isNaN());
1928 assert(sumOfLog2s([real.infinity]) == real.infinity);
1929 assert(sumOfLog2s([-real.infinity]).isNaN());
1930 assert(sumOfLog2s([ 0.25, 0.25, 0.25, 0.125 ]) == -9);
1931 }
1932
1933 /**
1934 Computes $(LINK2 https://en.wikipedia.org/wiki/Entropy_(information_theory),
1935 _entropy) of input range $(D r) in bits. This
1936 function assumes (without checking) that the values in $(D r) are all
1937 in $(D [0, 1]). For the entropy to be meaningful, often $(D r) should
1938 be normalized too (i.e., its values should sum to 1). The
1939 two-parameter version stops evaluating as soon as the intermediate
1940 result is greater than or equal to $(D max).
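
For example:

----
double[] p = [ 0.5, 0.5 ];
assert(entropy(p) == 1);        // a fair coin carries one bit of entropy
assert(entropy(p, 0.5) == 0.5); // stops as soon as the partial sum reaches 0.5
----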
1941 */
1942 ElementType!Range entropy(Range)(Range r)
1943 if (isInputRange!Range)
1944 {
1945 Unqual!(typeof(return)) result = 0.0;
1946 for (;!r.empty; r.popFront)
1947 {
1948 if (!r.front) continue;
1949 result -= r.front * log2(r.front);
1950 }
1951 return result;
1952 }
1953
1954 /// Ditto
1955 ElementType!Range entropy(Range, F)(Range r, F max)
1956 if (isInputRange!Range &&
1957 !is(CommonType!(ElementType!Range, F) == void))
1958 {
1959 Unqual!(typeof(return)) result = 0.0;
1960 for (;!r.empty; r.popFront)
1961 {
1962 if (!r.front) continue;
1963 result -= r.front * log2(r.front);
1964 if (result >= max) break;
1965 }
1966 return result;
1967 }
1968
1969 @safe unittest
1970 {
1971 import std.meta : AliasSeq;
1972 foreach (T; AliasSeq!(double, const double, immutable double))
1973 {
1974 T[] p = [ 0.0, 0, 0, 1 ];
1975 assert(entropy(p) == 0);
1976 p = [ 0.25, 0.25, 0.25, 0.25 ];
1977 assert(entropy(p) == 2);
1978 assert(entropy(p, 1) == 1);
1979 }
1980 }
1981
1982 /**
1983 Computes the $(LINK2 https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence,
1984 Kullback-Leibler divergence) between input ranges
1985 $(D a) and $(D b), which is the sum $(D ai * log(ai / bi)). The base
1986 of logarithm is 2. The ranges are assumed to contain elements in $(D
1987 [0, 1]). Usually the ranges are normalized probability distributions,
1988 but this is not required or checked by $(D
1989 kullbackLeiblerDivergence). If any element $(D bi) is zero and the
1990 corresponding element $(D ai) nonzero, returns infinity. (Otherwise,
1991 if $(D ai == 0 && bi == 0), the term $(D ai * log(ai / bi)) is
1992 considered zero.) If the inputs are normalized, the result is
1993 nonnegative.
1994 */
1995 CommonType!(ElementType!Range1, ElementType!Range2)
1996 kullbackLeiblerDivergence(Range1, Range2)(Range1 a, Range2 b)
1997 if (isInputRange!(Range1) && isInputRange!(Range2))
1998 {
1999 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
2000 static if (haveLen) assert(a.length == b.length);
2001 Unqual!(typeof(return)) result = 0;
2002 for (; !a.empty; a.popFront(), b.popFront())
2003 {
2004 immutable t1 = a.front;
2005 if (t1 == 0) continue;
2006 immutable t2 = b.front;
2007 if (t2 == 0) return result.infinity;
2008 assert(t1 > 0 && t2 > 0);
2009 result += t1 * log2(t1 / t2);
2010 }
2011 static if (!haveLen) assert(b.empty);
2012 return result;
2013 }
2014
2015 ///
2016 @safe unittest
2017 {
2018 import std.math : approxEqual;
2019
2020 double[] p = [ 0.0, 0, 0, 1 ];
2021 assert(kullbackLeiblerDivergence(p, p) == 0);
2022 double[] p1 = [ 0.25, 0.25, 0.25, 0.25 ];
2023 assert(kullbackLeiblerDivergence(p1, p1) == 0);
2024 assert(kullbackLeiblerDivergence(p, p1) == 2);
2025 assert(kullbackLeiblerDivergence(p1, p) == double.infinity);
2026 double[] p2 = [ 0.2, 0.2, 0.2, 0.4 ];
2027 assert(approxEqual(kullbackLeiblerDivergence(p1, p2), 0.0719281));
2028 assert(approxEqual(kullbackLeiblerDivergence(p2, p1), 0.0780719));
2029 }
2030
2031 /**
2032 Computes the $(LINK2 https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence,
2033 Jensen-Shannon divergence) between $(D a) and $(D
2034 b), which is the sum $(D (ai * log(2 * ai / (ai + bi)) + bi * log(2 *
2035 bi / (ai + bi))) / 2). The base of logarithm is 2. The ranges are
2036 assumed to contain elements in $(D [0, 1]). Usually the ranges are
2037 normalized probability distributions, but this is not required or
2038 checked by $(D jensenShannonDivergence). If the inputs are normalized,
2039 the result is bounded within $(D [0, 1]). The three-parameter version
2040 stops evaluations as soon as the intermediate result is greater than
2041 or equal to $(D limit).
2042 */
2043 CommonType!(ElementType!Range1, ElementType!Range2)
2044 jensenShannonDivergence(Range1, Range2)(Range1 a, Range2 b)
2045 if (isInputRange!Range1 && isInputRange!Range2 &&
2046 is(CommonType!(ElementType!Range1, ElementType!Range2)))
2047 {
2048 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
2049 static if (haveLen) assert(a.length == b.length);
2050 Unqual!(typeof(return)) result = 0;
2051 for (; !a.empty; a.popFront(), b.popFront())
2052 {
2053 immutable t1 = a.front;
2054 immutable t2 = b.front;
2055 immutable avg = (t1 + t2) / 2;
2056 if (t1 != 0)
2057 {
2058 result += t1 * log2(t1 / avg);
2059 }
2060 if (t2 != 0)
2061 {
2062 result += t2 * log2(t2 / avg);
2063 }
2064 }
2065 static if (!haveLen) assert(b.empty);
2066 return result / 2;
2067 }
2068
2069 /// Ditto
2070 CommonType!(ElementType!Range1, ElementType!Range2)
2071 jensenShannonDivergence(Range1, Range2, F)(Range1 a, Range2 b, F limit)
2072 if (isInputRange!Range1 && isInputRange!Range2 &&
2073 is(typeof(CommonType!(ElementType!Range1, ElementType!Range2).init
2074 >= F.init) : bool))
2075 {
2076 enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
2077 static if (haveLen) assert(a.length == b.length);
2078 Unqual!(typeof(return)) result = 0;
2079 limit *= 2;
2080 for (; !a.empty; a.popFront(), b.popFront())
2081 {
2082 immutable t1 = a.front;
2083 immutable t2 = b.front;
2084 immutable avg = (t1 + t2) / 2;
2085 if (t1 != 0)
2086 {
2087 result += t1 * log2(t1 / avg);
2088 }
2089 if (t2 != 0)
2090 {
2091 result += t2 * log2(t2 / avg);
2092 }
2093 if (result >= limit) break;
2094 }
2095 static if (!haveLen) assert(b.empty);
2096 return result / 2;
2097 }
2098
2099 ///
2100 @safe unittest
2101 {
2102 import std.math : approxEqual;
2103
2104 double[] p = [ 0.0, 0, 0, 1 ];
2105 assert(jensenShannonDivergence(p, p) == 0);
2106 double[] p1 = [ 0.25, 0.25, 0.25, 0.25 ];
2107 assert(jensenShannonDivergence(p1, p1) == 0);
2108 assert(approxEqual(jensenShannonDivergence(p1, p), 0.548795));
2109 double[] p2 = [ 0.2, 0.2, 0.2, 0.4 ];
2110 assert(approxEqual(jensenShannonDivergence(p1, p2), 0.0186218));
2111 assert(approxEqual(jensenShannonDivergence(p2, p1), 0.0186218));
2112 assert(approxEqual(jensenShannonDivergence(p2, p1, 0.005), 0.00602366));
2113 }
2114
2115 /**
2116 The so-called "all-lengths gap-weighted string kernel" computes a
2117 similarity measure between $(D s) and $(D t) based on all of their
2118 common subsequences of all lengths. Gapped subsequences are also
2119 included.
2120
2121 To understand what $(D gapWeightedSimilarity(s, t, lambda)) computes,
2122 consider first the case $(D lambda = 1) and the strings $(D s =
2123 ["Hello", "brave", "new", "world"]) and $(D t = ["Hello", "new",
2124 "world"]). In that case, $(D gapWeightedSimilarity) counts the
2125 following matches:
2126
2127 $(OL $(LI three matches of length 1, namely $(D "Hello"), $(D "new"),
2128 and $(D "world");) $(LI three matches of length 2, namely ($(D
2129 "Hello", "new")), ($(D "Hello", "world")), and ($(D "new", "world"));)
2130 $(LI one match of length 3, namely ($(D "Hello", "new", "world")).))
2131
2132 The call $(D gapWeightedSimilarity(s, t, 1)) simply counts all of
2133 these matches and adds them up, returning 7.
2134
2135 ----
2136 string[] s = ["Hello", "brave", "new", "world"];
2137 string[] t = ["Hello", "new", "world"];
2138 assert(gapWeightedSimilarity(s, t, 1) == 7);
2139 ----
2140
2141 Note how the gaps in matching are simply ignored, for example ($(D
2142 "Hello", "new")) is deemed as good a match as ($(D "new",
2143 "world")). This may be too permissive for some applications. To
2144 eliminate gapped matches entirely, use $(D lambda = 0):
2145
2146 ----
2147 string[] s = ["Hello", "brave", "new", "world"];
2148 string[] t = ["Hello", "new", "world"];
2149 assert(gapWeightedSimilarity(s, t, 0) == 4);
2150 ----
2151
2152 The call above eliminated the gapped matches ($(D "Hello", "new")),
2153 ($(D "Hello", "world")), and ($(D "Hello", "new", "world")) from the
2154 tally. That leaves only 4 matches.
2155
2156 The most interesting case is when gapped matches still participate in
2157 the result, but not as strongly as ungapped matches. The result will
2158 be a smooth, fine-grained similarity measure between the input
2159 strings. This is where values of $(D lambda) between 0 and 1 enter
2160 into play: gapped matches are $(I exponentially penalized with the
2161 number of gaps) with base $(D lambda). This means that an ungapped
2162 match adds 1 to the return value; a match with one gap in either
2163 string adds $(D lambda) to the return value; ...; a match with a total
2164 of $(D n) gaps in both strings adds $(D pow(lambda, n)) to the return
2165 value. In the example above, we have 4 matches without gaps, 2 matches
2166 with one gap, and 1 match with three gaps. The latter match is ($(D
2167 "Hello", "world")), which has two gaps in the first string and one gap
2168 in the second string, totaling to three gaps. Summing these up we get
2169 $(D 4 + 2 * lambda + pow(lambda, 3)).
2170
2171 ----
2172 string[] s = ["Hello", "brave", "new", "world"];
2173 string[] t = ["Hello", "new", "world"];
2174 assert(gapWeightedSimilarity(s, t, 0.5) == 4 + 0.5 * 2 + 0.125);
2175 ----
2176
2177 $(D gapWeightedSimilarity) is useful wherever a smooth similarity
2178 measure between sequences allowing for approximate matches is
2179 needed. The examples above are given with words, but any sequences
2180 with elements comparable for equality are allowed, e.g. characters or
2181 numbers. $(D gapWeightedSimilarity) uses a highly optimized dynamic
2182 programming implementation that needs $(D 16 * min(s.length,
2183 t.length)) extra bytes of memory and $(BIGOH s.length * t.length) time
2184 to complete.
2185 */
2186 F gapWeightedSimilarity(alias comp = "a == b", R1, R2, F)(R1 s, R2 t, F lambda)
2187 if (isRandomAccessRange!(R1) && hasLength!(R1) &&
2188 isRandomAccessRange!(R2) && hasLength!(R2))
2189 {
2190 import core.exception : onOutOfMemoryError;
2191 import core.stdc.stdlib : malloc, free;
2192 import std.algorithm.mutation : swap;
2193 import std.functional : binaryFun;
2194
2195 if (s.length < t.length) return gapWeightedSimilarity(t, s, lambda);
2196 if (!t.length) return 0;
2197
2198 auto dpvi = cast(F*) malloc(F.sizeof * 2 * t.length);
2199 if (!dpvi)
2200 onOutOfMemoryError();
2201
2202 auto dpvi1 = dpvi + t.length;
2203 scope(exit) free(dpvi < dpvi1 ? dpvi : dpvi1);
2204 dpvi[0 .. t.length] = 0;
2205 dpvi1[0] = 0;
2206 immutable lambda2 = lambda * lambda;
2207
2208 F result = 0;
2209 foreach (i; 0 .. s.length)
2210 {
2211 const si = s[i];
2212 for (size_t j = 0;;)
2213 {
2214 F dpsij = void;
2215 if (binaryFun!(comp)(si, t[j]))
2216 {
2217 dpsij = 1 + dpvi[j];
2218 result += dpsij;
2219 }
2220 else
2221 {
2222 dpsij = 0;
2223 }
2224 immutable j1 = j + 1;
2225 if (j1 == t.length) break;
2226 dpvi1[j1] = dpsij + lambda * (dpvi1[j] + dpvi[j1]) -
2227 lambda2 * dpvi[j];
2228 j = j1;
2229 }
2230 swap(dpvi, dpvi1);
2231 }
2232 return result;
2233 }
2234
2235 @system unittest
2236 {
2237 string[] s = ["Hello", "brave", "new", "world"];
2238 string[] t = ["Hello", "new", "world"];
2239 assert(gapWeightedSimilarity(s, t, 1) == 7);
2240 assert(gapWeightedSimilarity(s, t, 0) == 4);
2241 assert(gapWeightedSimilarity(s, t, 0.5) == 4 + 2 * 0.5 + 0.125);
2242 }
2243
2244 /**
2245 The similarity per $(D gapWeightedSimilarity) has an issue in that it
2246 grows with the lengths of the two strings, even when the strings are
2247 not actually very similar. For example, the range $(D ["Hello",
2248 "world"]) is increasingly similar with the range $(D ["Hello",
2249 "world", "world", "world",...]) as more instances of $(D "world") are
2250 appended. To prevent that, $(D gapWeightedSimilarityNormalized)
2251 computes a normalized version of the similarity that is computed as
2252 $(D gapWeightedSimilarity(s, t, lambda) /
2253 sqrt(gapWeightedSimilarity(s, s, lambda) * gapWeightedSimilarity(t, t,
2254 lambda))). The function $(D gapWeightedSimilarityNormalized) (a
2255 so-called normalized kernel) is bounded in $(D [0, 1]), reaches $(D 0)
2256 only for ranges that don't match in any position, and $(D 1) only for
2257 identical ranges.
2258
2259 The optional parameters $(D sSelfSim) and $(D tSelfSim) are meant for
2260 avoiding duplicate computation. Many applications may have already
2261 computed $(D gapWeightedSimilarity(s, s, lambda)) and/or $(D
2262 gapWeightedSimilarity(t, t, lambda)). In that case, they can be passed
2263 as $(D sSelfSim) and $(D tSelfSim), respectively.
2264 */
2265 Select!(isFloatingPoint!(F), F, double)
2266 gapWeightedSimilarityNormalized(alias comp = "a == b", R1, R2, F)
2267 (R1 s, R2 t, F lambda, F sSelfSim = F.init, F tSelfSim = F.init)
2268 if (isRandomAccessRange!(R1) && hasLength!(R1) &&
2269 isRandomAccessRange!(R2) && hasLength!(R2))
2270 {
2271 static bool uncomputed(F n)
2272 {
2273 static if (isFloatingPoint!(F))
2274 return isNaN(n);
2275 else
2276 return n == n.init;
2277 }
2278 if (uncomputed(sSelfSim))
2279 sSelfSim = gapWeightedSimilarity!(comp)(s, s, lambda);
2280 if (sSelfSim == 0) return 0;
2281 if (uncomputed(tSelfSim))
2282 tSelfSim = gapWeightedSimilarity!(comp)(t, t, lambda);
2283 if (tSelfSim == 0) return 0;
2284
2285 return gapWeightedSimilarity!(comp)(s, t, lambda) /
2286 sqrt(cast(typeof(return)) sSelfSim * tSelfSim);
2287 }
2288
2289 ///
2290 @system unittest
2291 {
2292 import std.math : approxEqual, sqrt;
2293
2294 string[] s = ["Hello", "brave", "new", "world"];
2295 string[] t = ["Hello", "new", "world"];
2296 assert(gapWeightedSimilarity(s, s, 1) == 15);
2297 assert(gapWeightedSimilarity(t, t, 1) == 7);
2298 assert(gapWeightedSimilarity(s, t, 1) == 7);
2299 assert(approxEqual(gapWeightedSimilarityNormalized(s, t, 1),
2300 7.0 / sqrt(15.0 * 7), 0.01));
2301 }
2302
2303 /**
2304 Similar to $(D gapWeightedSimilarity), just works in an incremental
2305 manner by first revealing the matches of length 1, then gapped matches
2306 of length 2, and so on. The memory requirement is $(BIGOH s.length *
2307 t.length). The time complexity is $(BIGOH s.length * t.length) time
2308 for computing each step. Continuing on the previous example:
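
----
string[] s = ["Hello", "brave", "new", "world"];
string[] t = ["Hello", "new", "world"];
auto simIter = gapWeightedSimilarityIncremental(s, t, 1.0);
assert(simIter.front == 3); // three 1-length matches
simIter.popFront();
assert(simIter.front == 3); // three 2-length matches
simIter.popFront();
assert(simIter.front == 1); // one 3-length match
simIter.popFront();
assert(simIter.empty);      // no more matches
----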
2309
2310 The implementation is based on the pseudocode in Fig. 4 of the paper
2311 $(HTTP jmlr.csail.mit.edu/papers/volume6/rousu05a/rousu05a.pdf,
2312 "Efficient Computation of Gapped Substring Kernels on Large Alphabets")
2313 by Rousu et al., with additional algorithmic and systems-level
2314 optimizations.
2315 */
2316 struct GapWeightedSimilarityIncremental(Range, F = double)
2317 if (isRandomAccessRange!(Range) && hasLength!(Range))
2318 {
2319 import core.stdc.stdlib : malloc, realloc, alloca, free;
2320
2321 private:
2322 Range s, t;
2323 F currentValue = 0;
2324 F* kl;
2325 size_t gram = void;
2326 F lambda = void, lambda2 = void;
2327
2328 public:
2329 /**
2330 Constructs an object given two ranges $(D s) and $(D t) and a penalty
2331 $(D lambda). Constructor completes in $(BIGOH s.length * t.length)
2332 time and computes all matches of length 1.
2333 */
2334 this(Range s, Range t, F lambda)
2335 {
2336 import core.exception : onOutOfMemoryError;
2337
2338 assert(lambda > 0);
2339 this.gram = 0;
2340 this.lambda = lambda;
2341 this.lambda2 = lambda * lambda; // for efficiency only
2342
2343 size_t iMin = size_t.max, jMin = size_t.max,
2344 iMax = 0, jMax = 0;
2345 /* initialize */
2346 Tuple!(size_t, size_t) * k0;
2347 size_t k0len;
2348 scope(exit) free(k0);
2349 currentValue = 0;
2350 foreach (i, si; s)
2351 {
2352 foreach (j; 0 .. t.length)
2353 {
2354 if (si != t[j]) continue;
2355 k0 = cast(typeof(k0)) realloc(k0, ++k0len * (*k0).sizeof);
2356 with (k0[k0len - 1])
2357 {
2358 field[0] = i;
2359 field[1] = j;
2360 }
2361 // Maintain the minimum and maximum i and j
2362 if (iMin > i) iMin = i;
2363 if (iMax < i) iMax = i;
2364 if (jMin > j) jMin = j;
2365 if (jMax < j) jMax = j;
2366 }
2367 }
2368
2369 if (iMin > iMax) return;
2370 assert(k0len);
2371
2372 currentValue = k0len;
2373 // Chop strings down to the useful sizes
2374 s = s[iMin .. iMax + 1];
2375 t = t[jMin .. jMax + 1];
2376 this.s = s;
2377 this.t = t;
2378
2379 kl = cast(F*) malloc(s.length * t.length * F.sizeof);
2380 if (!kl)
2381 onOutOfMemoryError();
2382
2383 kl[0 .. s.length * t.length] = 0;
2384 foreach (pos; 0 .. k0len)
2385 {
2386 with (k0[pos])
2387 {
2388 kl[(field[0] - iMin) * t.length + field[1] - jMin] = lambda2;
2389 }
2390 }
2391 }
2392
2393 /**
2394 Returns: $(D this).
2395 */
2396 ref GapWeightedSimilarityIncremental opSlice()
2397 {
2398 return this;
2399 }
2400
2401 /**
2402 Computes the match of the next length. Completes in $(BIGOH s.length *
2403 t.length) time.
2404 */
2405 void popFront()
2406 {
2407 import std.algorithm.mutation : swap;
2408
2409 // This is a large source of optimization: if similarity at
2410 // the gram-1 level was 0, then we can safely assume
2411 // similarity at the gram level is 0 as well.
2412 if (empty) return;
2413
2414 // Now attempt to match gapped substrings of length `gram'
2415 ++gram;
2416 currentValue = 0;
2417
2418 auto Si = cast(F*) alloca(t.length * F.sizeof);
2419 Si[0 .. t.length] = 0;
2420 foreach (i; 0 .. s.length)
2421 {
2422 const si = s[i];
2423 F Sij_1 = 0;
2424 F Si_1j_1 = 0;
2425 auto kli = kl + i * t.length;
2426 for (size_t j = 0;;)
2427 {
2428 const klij = kli[j];
2429 const Si_1j = Si[j];
2430 const tmp = klij + lambda * (Si_1j + Sij_1) - lambda2 * Si_1j_1;
2431 // now update kl and currentValue
2432 if (si == t[j])
2433 currentValue += kli[j] = lambda2 * Si_1j_1;
2434 else
2435 kli[j] = 0;
2436 // commit to Si
2437 Si[j] = tmp;
2438 if (++j == t.length) break;
2439 // get ready for the next step; virtually increment j,
2440 // so essentially stuffj_1 <-- stuffj
2441 Si_1j_1 = Si_1j;
2442 Sij_1 = tmp;
2443 }
2444 }
2445 currentValue /= pow(lambda, 2 * (gram + 1));
2446
2447 version (none)
2448 {
2449 Si_1[0 .. t.length] = 0;
2450 kl[0 .. min(t.length, maxPerimeter + 1)] = 0;
2451 foreach (i; 1 .. min(s.length, maxPerimeter + 1))
2452 {
2453 auto kli = kl + i * t.length;
2454 assert(s.length > i);
2455 const si = s[i];
2456 auto kl_1i_1 = kl_1 + (i - 1) * t.length;
2457 kli[0] = 0;
2458 F lastS = 0;
2459 foreach (j; 1 .. min(maxPerimeter - i + 1, t.length))
2460 {
2461 immutable j_1 = j - 1;
2462 immutable tmp = kl_1i_1[j_1]
2463 + lambda * (Si_1[j] + lastS)
2464 - lambda2 * Si_1[j_1];
2465 kl_1i_1[j_1] = float.nan;
2466 Si_1[j_1] = lastS;
2467 lastS = tmp;
2468 if (si == t[j])
2469 {
2470 currentValue += kli[j] = lambda2 * lastS;
2471 }
2472 else
2473 {
2474 kli[j] = 0;
2475 }
2476 }
2477 Si_1[t.length - 1] = lastS;
2478 }
2479 currentValue /= pow(lambda, 2 * (gram + 1));
2480 // get ready for the next computation
2481 swap(kl, kl_1);
2482 }
2483 }
2484
2485 /**
2486 Returns: The gapped similarity at the current match length (initially
2487 1, grows with each call to $(D popFront)).
2488 */
2489 @property F front() { return currentValue; }
2490
2491 /**
2492 Returns: Whether there are more matches.
2493 */
2494 @property bool empty()
2495 {
2496 if (currentValue) return false;
2497 if (kl)
2498 {
2499 free(kl);
2500 kl = null;
2501 }
2502 return true;
2503 }
2504 }
2505
2506 /**
2507 Ditto
2508 */
2509 GapWeightedSimilarityIncremental!(R, F) gapWeightedSimilarityIncremental(R, F)
2510 (R r1, R r2, F penalty)
2511 {
2512 return typeof(return)(r1, r2, penalty);
2513 }
2514
2515 ///
2516 @system unittest
2517 {
2518 string[] s = ["Hello", "brave", "new", "world"];
2519 string[] t = ["Hello", "new", "world"];
2520 auto simIter = gapWeightedSimilarityIncremental(s, t, 1.0);
2521 assert(simIter.front == 3); // three 1-length matches
2522 simIter.popFront();
2523 assert(simIter.front == 3); // three 2-length matches
2524 simIter.popFront();
2525 assert(simIter.front == 1); // one 3-length match
2526 simIter.popFront();
2527 assert(simIter.empty); // no more matches
2528 }
2529
2530 @system unittest
2531 {
2532 import std.conv : text;
2533 string[] s = ["Hello", "brave", "new", "world"];
2534 string[] t = ["Hello", "new", "world"];
2535 auto simIter = gapWeightedSimilarityIncremental(s, t, 1.0);
2536 //foreach (e; simIter) writeln(e);
2537 assert(simIter.front == 3); // three 1-length matches
2538 simIter.popFront();
2539 assert(simIter.front == 3, text(simIter.front)); // three 2-length matches
2540 simIter.popFront();
2541 assert(simIter.front == 1); // one 3-length match
2542 simIter.popFront();
2543 assert(simIter.empty); // no more matches
2544
2545 s = ["Hello"];
2546 t = ["bye"];
2547 simIter = gapWeightedSimilarityIncremental(s, t, 0.5);
2548 assert(simIter.empty);
2549
2550 s = ["Hello"];
2551 t = ["Hello"];
2552 simIter = gapWeightedSimilarityIncremental(s, t, 0.5);
2553 assert(simIter.front == 1); // one match
2554 simIter.popFront();
2555 assert(simIter.empty);
2556
2557 s = ["Hello", "world"];
2558 t = ["Hello"];
2559 simIter = gapWeightedSimilarityIncremental(s, t, 0.5);
2560 assert(simIter.front == 1); // one match
2561 simIter.popFront();
2562 assert(simIter.empty);
2563
2564 s = ["Hello", "world"];
2565 t = ["Hello", "yah", "world"];
2566 simIter = gapWeightedSimilarityIncremental(s, t, 0.5);
2567 assert(simIter.front == 2); // two 1-gram matches
2568 simIter.popFront();
2569 assert(simIter.front == 0.5, text(simIter.front)); // one 2-gram match, 1 gap
2570 }
2571
2572 @system unittest
2573 {
2574 GapWeightedSimilarityIncremental!(string[]) sim =
2575 GapWeightedSimilarityIncremental!(string[])(
2576 ["nyuk", "I", "have", "no", "chocolate", "giba"],
2577 ["wyda", "I", "have", "I", "have", "have", "I", "have", "hehe"],
2578 0.5);
2579 double[] witness = [ 7.0, 4.03125, 0, 0 ];
2580 foreach (e; sim)
2581 {
2582 //writeln(e);
2583 assert(e == witness.front);
2584 witness.popFront();
2585 }
2586 witness = [ 3.0, 1.3125, 0.25 ];
2587 sim = GapWeightedSimilarityIncremental!(string[])(
2588 ["I", "have", "no", "chocolate"],
2589 ["I", "have", "some", "chocolate"],
2590 0.5);
2591 foreach (e; sim)
2592 {
2593 //writeln(e);
2594 assert(e == witness.front);
2595 witness.popFront();
2596 }
2597 assert(witness.empty);
2598 }
2599
2600 /**
2601 Computes the greatest common divisor of $(D a) and $(D b) by using
2602 an efficient algorithm such as $(HTTPS en.wikipedia.org/wiki/Euclidean_algorithm, Euclid's)
2603 or $(HTTPS en.wikipedia.org/wiki/Binary_GCD_algorithm, Stein's) algorithm.
2604
2605 Params:
2606 T = Any numerical type that supports the modulo operator `%`. If
2607 bit-shifting `<<` and `>>` are also supported, Stein's algorithm will
2608 be used; otherwise, Euclid's algorithm is used as _a fallback.
2609 Returns:
2610 The greatest common divisor of the given arguments.
2611 */
2612 T gcd(T)(T a, T b)
2613 if (isIntegral!T)
2614 {
2615 static if (is(T == const) || is(T == immutable))
2616 {
2617 return gcd!(Unqual!T)(a, b);
2618 }
2619 else version (DigitalMars)
2620 {
2621 static if (T.min < 0)
2622 {
2623 assert(a >= 0 && b >= 0);
2624 }
2625 while (b)
2626 {
2627 immutable t = b;
2628 b = a % b;
2629 a = t;
2630 }
2631 return a;
2632 }
2633 else
2634 {
2635 if (a == 0)
2636 return b;
2637 if (b == 0)
2638 return a;
2639
2640 import core.bitop : bsf;
2641 import std.algorithm.mutation : swap;
2642
2643 immutable uint shift = bsf(a | b);
2644 a >>= a.bsf;
2645
2646 do
2647 {
2648 b >>= b.bsf;
2649 if (a > b)
2650 swap(a, b);
2651 b -= a;
2652 } while (b);
2653
2654 return a << shift;
2655 }
2656 }
2657
2658 ///
2659 @safe unittest
2660 {
2661 assert(gcd(2 * 5 * 7 * 7, 5 * 7 * 11) == 5 * 7);
2662 const int a = 5 * 13 * 23 * 23, b = 13 * 59;
2663 assert(gcd(a, b) == 13);
2664 }
2665
2666 // This overload is for non-builtin numerical types like BigInt or
2667 // user-defined types.
2668 /// ditto
2669 T gcd(T)(T a, T b)
2670 if (!isIntegral!T &&
2671 is(typeof(T.init % T.init)) &&
2672 is(typeof(T.init == 0 || T.init > 0)))
2673 {
2674 import std.algorithm.mutation : swap;
2675
2676 enum canUseBinaryGcd = is(typeof(() {
2677 T t, u;
2678 t <<= 1;
2679 t >>= 1;
2680 t -= u;
2681 bool b = (t & 1) == 0;
2682 swap(t, u);
2683 }));
2684
2685 assert(a >= 0 && b >= 0);
2686
2687 static if (canUseBinaryGcd)
2688 {
2689 uint shift = 0;
2690 while ((a & 1) == 0 && (b & 1) == 0)
2691 {
2692 a >>= 1;
2693 b >>= 1;
2694 shift++;
2695 }
2696
2697 do
2698 {
2699 assert((a & 1) != 0);
2700 while ((b & 1) == 0)
2701 b >>= 1;
2702 if (a > b)
2703 swap(a, b);
2704 b -= a;
2705 } while (b);
2706
2707 return a << shift;
2708 }
2709 else
2710 {
2711 // The only thing we have is %; fallback to Euclidean algorithm.
2712 while (b != 0)
2713 {
2714 auto t = b;
2715 b = a % b;
2716 a = t;
2717 }
2718 return a;
2719 }
2720 }
2721
2722 // Issue 7102
2723 @system pure unittest
2724 {
2725 import std.bigint : BigInt;
2726 assert(gcd(BigInt("71_000_000_000_000_000_000"),
2727 BigInt("31_000_000_000_000_000_000")) ==
2728 BigInt("1_000_000_000_000_000_000"));
2729 }
2730
2731 @safe pure nothrow unittest
2732 {
2733 // A numerical type that only supports % and - (to force gcd implementation
2734 // to use Euclidean algorithm).
2735 struct CrippledInt
2736 {
2737 int impl;
2738 CrippledInt opBinary(string op : "%")(CrippledInt i)
2739 {
2740 return CrippledInt(impl % i.impl);
2741 }
2742 int opEquals(CrippledInt i) { return impl == i.impl; }
2743 int opEquals(int i) { return impl == i; }
2744 int opCmp(int i) { return (impl < i) ? -1 : (impl > i) ? 1 : 0; }
2745 }
2746 assert(gcd(CrippledInt(2310), CrippledInt(1309)) == CrippledInt(77));
2747 }
2748
2749 // This is to make tweaking the speed/size vs. accuracy tradeoff easy,
2750 // though floats seem accurate enough for all practical purposes, since
2751 // they pass the "approxEqual(inverseFft(fft(arr)), arr)" test even for
2752 // size 2 ^^ 22.
2753 private alias lookup_t = float;
2754
2755 /**A class for performing fast Fourier transforms of power of two sizes.
2756 * This class encapsulates a large amount of state that is reusable when
2757 * performing multiple FFTs of sizes smaller than or equal to that specified
2758 * in the constructor. This results in substantial speedups when performing
2759 * multiple FFTs with a known maximum size. However,
2760 * a free function API is provided for convenience if you need to perform a
2761 * one-off FFT.
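 *
 * For instance, one $(D Fft) object can be reused for transforms of any
 * power-of-two size up to the size given at construction (the sizes and
 * values below are arbitrary):
 *
 * ----
 * auto fftEngine = new Fft(1024);
 * auto spectrumA = fftEngine.fft([1.0, 2.0, 3.0, 4.0]);
 * auto spectrumB = fftEngine.fft([4.0, 3.0, 2.0, 1.0]);
 * ----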
2762 *
2763 * References:
2764 * $(HTTP en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm)
2765 */
2766 final class Fft
2767 {
2768 import core.bitop : bsf;
2769 import std.algorithm.iteration : map;
2770 import std.array : uninitializedArray;
2771
2772 private:
2773 immutable lookup_t[][] negSinLookup;
2774
2775 void enforceSize(R)(R range) const
2776 {
2777 import std.conv : text;
2778 assert(range.length <= size, text(
2779 "FFT size mismatch. Expected ", size, ", got ", range.length));
2780 }
2781
2782 void fftImpl(Ret, R)(Stride!R range, Ret buf) const
2783 in
2784 {
2785 assert(range.length >= 4);
2786 assert(isPowerOf2(range.length));
2787 }
2788 body
2789 {
2790 auto recurseRange = range;
2791 recurseRange.doubleSteps();
2792
2793 if (buf.length > 4)
2794 {
2795 fftImpl(recurseRange, buf[0..$ / 2]);
2796 recurseRange.popHalf();
2797 fftImpl(recurseRange, buf[$ / 2..$]);
2798 }
2799 else
2800 {
2801 // Do this here instead of in another recursion to save on
2802 // recursion overhead.
2803 slowFourier2(recurseRange, buf[0..$ / 2]);
2804 recurseRange.popHalf();
2805 slowFourier2(recurseRange, buf[$ / 2..$]);
2806 }
2807
2808 butterfly(buf);
2809 }
2810
2811 // This algorithm works by performing the even and odd parts of our FFT
2812 // using the "two for the price of one" method mentioned at
2813 // http://www.engineeringproductivitytools.com/stuff/T0001/PT10.HTM#Head521
2814 // by making the odd terms into the imaginary components of our new FFT,
2815 // and then using symmetry to recombine them.
2816 void fftImplPureReal(Ret, R)(R range, Ret buf) const
2817 in
2818 {
2819 assert(range.length >= 4);
2820 assert(isPowerOf2(range.length));
2821 }
2822 body
2823 {
2824 alias E = ElementType!R;
2825
2826 // Converts odd indices of range to the imaginary components of
2827 // a range half the size. The even indices become the real components.
2828 static if (isArray!R && isFloatingPoint!E)
2829 {
2830 // Then the memory layout of complex numbers provides a dirt
2831 // cheap way to convert. This is a common case, so take advantage.
2832 auto oddsImag = cast(Complex!E[]) range;
2833 }
2834 else
2835 {
2836 // General case: Use a higher order range. We can assume
2837 // source.length is even because it has to be a power of 2.
2838 static struct OddToImaginary
2839 {
2840 R source;
2841 alias C = Complex!(CommonType!(E, typeof(buf[0].re)));
2842
2843 @property
2844 {
2845 C front()
2846 {
2847 return C(source[0], source[1]);
2848 }
2849
2850 C back()
2851 {
2852 immutable n = source.length;
2853 return C(source[n - 2], source[n - 1]);
2854 }
2855
2856 typeof(this) save()
2857 {
2858 return typeof(this)(source.save);
2859 }
2860
2861 bool empty()
2862 {
2863 return source.empty;
2864 }
2865
2866 size_t length()
2867 {
2868 return source.length / 2;
2869 }
2870 }
2871
2872 void popFront()
2873 {
2874 source.popFront();
2875 source.popFront();
2876 }
2877
2878 void popBack()
2879 {
2880 source.popBack();
2881 source.popBack();
2882 }
2883
2884 C opIndex(size_t index)
2885 {
2886 return C(source[index * 2], source[index * 2 + 1]);
2887 }
2888
2889 typeof(this) opSlice(size_t lower, size_t upper)
2890 {
2891 return typeof(this)(source[lower * 2 .. upper * 2]);
2892 }
2893 }
2894
2895 auto oddsImag = OddToImaginary(range);
2896 }
2897
2898 fft(oddsImag, buf[0..$ / 2]);
2899 auto evenFft = buf[0..$ / 2];
2900 auto oddFft = buf[$ / 2..$];
2901 immutable halfN = evenFft.length;
2902 oddFft[0].re = buf[0].im;
2903 oddFft[0].im = 0;
2904 evenFft[0].im = 0;
2905 // evenFft[0].re is already right b/c it's aliased with buf[0].re.
2906
2907 foreach (k; 1 .. halfN / 2 + 1)
2908 {
2909 immutable bufk = buf[k];
2910 immutable bufnk = buf[buf.length / 2 - k];
2911 evenFft[k].re = 0.5 * (bufk.re + bufnk.re);
2912 evenFft[halfN - k].re = evenFft[k].re;
2913 evenFft[k].im = 0.5 * (bufk.im - bufnk.im);
2914 evenFft[halfN - k].im = -evenFft[k].im;
2915
2916 oddFft[k].re = 0.5 * (bufk.im + bufnk.im);
2917 oddFft[halfN - k].re = oddFft[k].re;
2918 oddFft[k].im = 0.5 * (bufnk.re - bufk.re);
2919 oddFft[halfN - k].im = -oddFft[k].im;
2920 }
2921
2922 butterfly(buf);
2923 }
2924
2925 void butterfly(R)(R buf) const
2926 in
2927 {
2928 assert(isPowerOf2(buf.length));
2929 }
2930 body
2931 {
2932 immutable n = buf.length;
2933 immutable localLookup = negSinLookup[bsf(n)];
2934 assert(localLookup.length == n);
2935
2936 immutable cosMask = n - 1;
2937 immutable cosAdd = n / 4 * 3;
2938
2939 lookup_t negSinFromLookup(size_t index) pure nothrow
2940 {
2941 return localLookup[index];
2942 }
2943
2944 lookup_t cosFromLookup(size_t index) pure nothrow
2945 {
2946 // cos is just -sin shifted by PI * 3 / 2.
2947 return localLookup[(index + cosAdd) & cosMask];
2948 }
2949
2950 immutable halfLen = n / 2;
2951
2952 // This loop is unrolled and the two iterations are interleaved
2953 // relative to the textbook FFT to increase ILP. This gives roughly 5%
2954 // speedups on DMD.
2955 for (size_t k = 0; k < halfLen; k += 2)
2956 {
2957 immutable cosTwiddle1 = cosFromLookup(k);
2958 immutable sinTwiddle1 = negSinFromLookup(k);
2959 immutable cosTwiddle2 = cosFromLookup(k + 1);
2960 immutable sinTwiddle2 = negSinFromLookup(k + 1);
2961
2962 immutable realLower1 = buf[k].re;
2963 immutable imagLower1 = buf[k].im;
2964 immutable realLower2 = buf[k + 1].re;
2965 immutable imagLower2 = buf[k + 1].im;
2966
2967 immutable upperIndex1 = k + halfLen;
2968 immutable upperIndex2 = upperIndex1 + 1;
2969 immutable realUpper1 = buf[upperIndex1].re;
2970 immutable imagUpper1 = buf[upperIndex1].im;
2971 immutable realUpper2 = buf[upperIndex2].re;
2972 immutable imagUpper2 = buf[upperIndex2].im;
2973
2974 immutable realAdd1 = cosTwiddle1 * realUpper1
2975 - sinTwiddle1 * imagUpper1;
2976 immutable imagAdd1 = sinTwiddle1 * realUpper1
2977 + cosTwiddle1 * imagUpper1;
2978 immutable realAdd2 = cosTwiddle2 * realUpper2
2979 - sinTwiddle2 * imagUpper2;
2980 immutable imagAdd2 = sinTwiddle2 * realUpper2
2981 + cosTwiddle2 * imagUpper2;
2982
2983 buf[k].re += realAdd1;
2984 buf[k].im += imagAdd1;
2985 buf[k + 1].re += realAdd2;
2986 buf[k + 1].im += imagAdd2;
2987
2988 buf[upperIndex1].re = realLower1 - realAdd1;
2989 buf[upperIndex1].im = imagLower1 - imagAdd1;
2990 buf[upperIndex2].re = realLower2 - realAdd2;
2991 buf[upperIndex2].im = imagLower2 - imagAdd2;
2992 }
2993 }
2994
2995 // This constructor is used within this module for allocating the
2996 // buffer space elsewhere besides the GC heap. It's definitely **NOT**
2997 // part of the public API and definitely **IS** subject to change.
2998 //
2999 // Also, this is unsafe because the memSpace buffer will be cast
3000 // to immutable.
3001 public this(lookup_t[] memSpace) // Public b/c of bug 4636.
3002 {
3003 immutable size = memSpace.length / 2;
3004
3005 /* Create a lookup table of all negative sine values at a resolution of
3006 * size and all smaller power of two resolutions. This may seem
3007 * inefficient, but having all the lookups be next to each other in
3008 * memory at every level of iteration is a huge win performance-wise.
3009 */
3010 if (size == 0)
3011 {
3012 return;
3013 }
3014
3015 assert(isPowerOf2(size),
3016 "Can only do FFTs on ranges with a size that is a power of two.");
3017
3018 auto table = new lookup_t[][bsf(size) + 1];
3019
3020 table[$ - 1] = memSpace[$ - size..$];
3021 memSpace = memSpace[0 .. size];
3022
3023 auto lastRow = table[$ - 1];
3024 lastRow[0] = 0; // -sin(0) == 0.
3025 foreach (ptrdiff_t i; 1 .. size)
3026 {
3027 // The hard coded cases are for improved accuracy and to prevent
3028 // annoying non-zeroness when stuff should be zero.
3029
3030 if (i == size / 4)
3031 lastRow[i] = -1; // -sin(pi / 2) == -1.
3032 else if (i == size / 2)
3033 lastRow[i] = 0; // -sin(pi) == 0.
3034 else if (i == size * 3 / 4)
3035 lastRow[i] = 1; // -sin(pi * 3 / 2) == 1
3036 else
3037 lastRow[i] = -sin(i * 2.0L * PI / size);
3038 }
3039
3040 // Fill in all the other rows with strided versions.
3041 foreach (i; 1 .. table.length - 1)
3042 {
3043 immutable strideLength = size / (2 ^^ i);
3044 auto strided = Stride!(lookup_t[])(lastRow, strideLength);
3045 table[i] = memSpace[$ - strided.length..$];
3046 memSpace = memSpace[0..$ - strided.length];
3047
3048 size_t copyIndex;
3049 foreach (elem; strided)
3050 {
3051 table[i][copyIndex++] = elem;
3052 }
3053 }
3054
3055 negSinLookup = cast(immutable) table;
3056 }
3057
3058 public:
3059 /**Create an $(D Fft) object for computing fast Fourier transforms of
3060 * power of two sizes of $(D size) or smaller. $(D size) must be a
3061 * power of two.
3062 */
3063 this(size_t size)
3064 {
3065 // Allocate all twiddle factor buffers in one contiguous block so that,
3066 // when one is done being used, the next one is next in cache.
3067 auto memSpace = uninitializedArray!(lookup_t[])(2 * size);
3068 this(memSpace);
3069 }
3070
3071 @property size_t size() const
3072 {
3073 return (negSinLookup is null) ? 0 : negSinLookup[$ - 1].length;
3074 }
3075
3076 /**Compute the Fourier transform of range using the $(BIGOH N log N)
3077 * Cooley-Tukey Algorithm. $(D range) must be a random-access range with
3078 * slicing and a length equal to $(D size) as provided at the construction of
3079 * this object. The contents of range can be either numeric types,
3080 * which will be interpreted as pure real values, or complex types with
3081 * properties or members $(D .re) and $(D .im) that can be read.
3082 *
3083 * Note: Pure real FFTs are automatically detected and the relevant
3084 * optimizations are performed.
3085 *
3086 * Returns: An array of complex numbers representing the transformed data in
3087 * the frequency domain.
3088 *
3089 * Conventions: The exponent is negative and the factor is one,
3090 * i.e., output[j] := sum[ exp(-2 PI i j k / N) input[k] ].
3091 */
3092 Complex!F[] fft(F = double, R)(R range) const
3093 if (isFloatingPoint!F && isRandomAccessRange!R)
3094 {
3095 enforceSize(range);
3096 Complex!F[] ret;
3097 if (range.length == 0)
3098 {
3099 return ret;
3100 }
3101
3102 // Don't waste time initializing the memory for ret.
3103 ret = uninitializedArray!(Complex!F[])(range.length);
3104
3105 fft(range, ret);
3106 return ret;
3107 }
3108
3109 /**Same as the overload, but allows for the results to be stored in a user-
3110 * provided buffer. The buffer must be of the same length as range, must be
3111 * a random-access range, must have slicing, and must contain elements that are
3112 * complex-like. This means that they must have a .re and a .im member or
3113 * property that can be both read and written and are floating point numbers.
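 *
 * A short sketch (buffer type and sizes chosen for illustration):
 * ----
 * auto f = new Fft(8);
 * auto buf = new Complex!double[8];
 * f.fft([1.0, 2, 3, 4, 5, 6, 7, 8], buf); // results are written into buf
 * ----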
3114 */
3115 void fft(Ret, R)(R range, Ret buf) const
3116 if (isRandomAccessRange!Ret && isComplexLike!(ElementType!Ret) && hasSlicing!Ret)
3117 {
3118 assert(buf.length == range.length);
3119 enforceSize(range);
3120
3121 if (range.length == 0)
3122 {
3123 return;
3124 }
3125 else if (range.length == 1)
3126 {
3127 buf[0] = range[0];
3128 return;
3129 }
3130 else if (range.length == 2)
3131 {
3132 slowFourier2(range, buf);
3133 return;
3134 }
3135 else
3136 {
3137 alias E = ElementType!R;
3138 static if (is(E : real))
3139 {
3140 return fftImplPureReal(range, buf);
3141 }
3142 else
3143 {
3144 static if (is(R : Stride!R))
3145 return fftImpl(range, buf);
3146 else
3147 return fftImpl(Stride!R(range, 1), buf);
3148 }
3149 }
3150 }
3151
3152 /**
3153 * Computes the inverse Fourier transform of a range. The range must be a
3154 * random access range with slicing, have a length equal to the size
3155 * provided at construction of this object, and contain elements that are
3156 * either of type std.complex.Complex or have essentially
3157 * the same compile-time interface.
3158 *
3159 * Returns: The time-domain signal.
3160 *
3161 * Conventions: The exponent is positive and the factor is 1/N, i.e.,
3162 * output[j] := (1 / N) sum[ exp(+2 PI i j k / N) input[k] ].
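 *
 * With these conventions, $(D inverseFft) undoes $(D fft) up to rounding
 * error, e.g. (values chosen for illustration):
 * ----
 * auto f = new Fft(8);
 * auto spectrum = f.fft([1.0, 2, 3, 4, 5, 6, 7, 8]);
 * auto signal = f.inverseFft(spectrum);
 * // signal[k].re is approximately the original input; signal[k].im is ~0
 * ----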
3163 */
3164 Complex!F[] inverseFft(F = double, R)(R range) const
3165 if (isRandomAccessRange!R && isComplexLike!(ElementType!R) && isFloatingPoint!F)
3166 {
3167 enforceSize(range);
3168 Complex!F[] ret;
3169 if (range.length == 0)
3170 {
3171 return ret;
3172 }
3173
3174 // Don't waste time initializing the memory for ret.
3175 ret = uninitializedArray!(Complex!F[])(range.length);
3176
3177 inverseFft(range, ret);
3178 return ret;
3179 }
3180
3181 /**
3182 * Inverse FFT that allows a user-supplied buffer to be provided. The buffer
3183 * must be a random access range with slicing, and its elements
3184 * must be some complex-like type.
3185 */
3186 void inverseFft(Ret, R)(R range, Ret buf) const
3187 if (isRandomAccessRange!Ret && isComplexLike!(ElementType!Ret) && hasSlicing!Ret)
3188 {
3189 enforceSize(range);
3190
3191 auto swapped = map!swapRealImag(range);
3192 fft(swapped, buf);
3193
3194 immutable lenNeg1 = 1.0 / buf.length;
3195 foreach (ref elem; buf)
3196 {
3197 immutable temp = elem.re * lenNeg1;
3198 elem.re = elem.im * lenNeg1;
3199 elem.im = temp;
3200 }
3201 }
3202 }
3203
3204 // This mixin creates an Fft object in the scope it's mixed into such that all
3205 // memory owned by the object is deterministically destroyed at the end of that
3206 // scope.
3207 private enum string MakeLocalFft = q{
3208 import core.stdc.stdlib;
3209 import core.exception : onOutOfMemoryError;
3210
3211 auto lookupBuf = (cast(lookup_t*) malloc(range.length * 2 * lookup_t.sizeof))
3212 [0 .. 2 * range.length];
3213 if (!lookupBuf.ptr)
3214 onOutOfMemoryError();
3215
3216 scope(exit) free(cast(void*) lookupBuf.ptr);
3217 auto fftObj = scoped!Fft(lookupBuf);
3218 };
3219
3220 /**Convenience functions that create an $(D Fft) object, run the FFT or inverse
3221 * FFT and return the result. Useful for one-off FFTs.
3222 *
3223 * Note: In addition to convenience, these functions are slightly more
3224 * efficient than manually creating an Fft object for a single use,
3225 * as the Fft object is deterministically destroyed before these
3226 * functions return.
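 *
 * For example (the input values are arbitrary):
 * ----
 * auto spectrum = fft([0.0, 1.0, 0.0, -1.0]); // one-off FFT, no Fft object kept
 * auto roundTrip = inverseFft(spectrum);      // recovers the input up to rounding
 * ----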
3227 */
3228 Complex!F[] fft(F = double, R)(R range)
3229 {
3230 mixin(MakeLocalFft);
3231 return fftObj.fft!(F, R)(range);
3232 }
3233
3234 /// ditto
3235 void fft(Ret, R)(R range, Ret buf)
3236 {
3237 mixin(MakeLocalFft);
3238 return fftObj.fft!(Ret, R)(range, buf);
3239 }
3240
3241 /// ditto
3242 Complex!F[] inverseFft(F = double, R)(R range)
3243 {
3244 mixin(MakeLocalFft);
3245 return fftObj.inverseFft!(F, R)(range);
3246 }
3247
3248 /// ditto
3249 void inverseFft(Ret, R)(R range, Ret buf)
3250 {
3251 mixin(MakeLocalFft);
3252 return fftObj.inverseFft!(Ret, R)(range, buf);
3253 }
3254
3255 @system unittest
3256 {
3257 import std.algorithm;
3258 import std.conv;
3259 import std.range;
3260 // Test values from R and Octave.
3261 auto arr = [1,2,3,4,5,6,7,8];
3262 auto fft1 = fft(arr);
3263 assert(approxEqual(map!"a.re"(fft1),
3264 [36.0, -4, -4, -4, -4, -4, -4, -4]));
3265 assert(approxEqual(map!"a.im"(fft1),
3266 [0, 9.6568, 4, 1.6568, 0, -1.6568, -4, -9.6568]));
3267
3268 auto fft1Retro = fft(retro(arr));
3269 assert(approxEqual(map!"a.re"(fft1Retro),
3270 [36.0, 4, 4, 4, 4, 4, 4, 4]));
3271 assert(approxEqual(map!"a.im"(fft1Retro),
3272 [0, -9.6568, -4, -1.6568, 0, 1.6568, 4, 9.6568]));
3273
3274 auto fft1Float = fft(to!(float[])(arr));
3275 assert(approxEqual(map!"a.re"(fft1), map!"a.re"(fft1Float)));
3276 assert(approxEqual(map!"a.im"(fft1), map!"a.im"(fft1Float)));
3277
3278 alias C = Complex!float;
3279 auto arr2 = [C(1,2), C(3,4), C(5,6), C(7,8), C(9,10),
3280 C(11,12), C(13,14), C(15,16)];
3281 auto fft2 = fft(arr2);
3282 assert(approxEqual(map!"a.re"(fft2),
3283 [64.0, -27.3137, -16, -11.3137, -8, -4.6862, 0, 11.3137]));
3284 assert(approxEqual(map!"a.im"(fft2),
3285 [72, 11.3137, 0, -4.686, -8, -11.3137, -16, -27.3137]));
3286
3287 auto inv1 = inverseFft(fft1);
3288 assert(approxEqual(map!"a.re"(inv1), arr));
3289 assert(reduce!max(map!"a.im"(inv1)) < 1e-10);
3290
3291 auto inv2 = inverseFft(fft2);
3292 assert(approxEqual(map!"a.re"(inv2), map!"a.re"(arr2)));
3293 assert(approxEqual(map!"a.im"(inv2), map!"a.im"(arr2)));
3294
3295 // FFTs of size 0, 1 and 2 are handled as special cases. Test them here.
3296 ushort[] empty;
3297 assert(fft(empty) == null);
3298 assert(inverseFft(fft(empty)) == null);
3299
3300 real[] oneElem = [4.5L];
3301 auto oneFft = fft(oneElem);
3302 assert(oneFft.length == 1);
3303 assert(oneFft[0].re == 4.5L);
3304 assert(oneFft[0].im == 0);
3305
3306 auto oneInv = inverseFft(oneFft);
3307 assert(oneInv.length == 1);
3308 assert(approxEqual(oneInv[0].re, 4.5));
3309 assert(approxEqual(oneInv[0].im, 0));
3310
3311 long[2] twoElems = [8, 4];
3312 auto twoFft = fft(twoElems[]);
3313 assert(twoFft.length == 2);
3314 assert(approxEqual(twoFft[0].re, 12));
3315 assert(approxEqual(twoFft[0].im, 0));
3316 assert(approxEqual(twoFft[1].re, 4));
3317 assert(approxEqual(twoFft[1].im, 0));
3318 auto twoInv = inverseFft(twoFft);
3319 assert(approxEqual(twoInv[0].re, 8));
3320 assert(approxEqual(twoInv[0].im, 0));
3321 assert(approxEqual(twoInv[1].re, 4));
3322 assert(approxEqual(twoInv[1].im, 0));
3323 }
3324
3325 // Swaps the real and imaginary parts of a complex number. This is useful
3326 // for inverse FFTs.
3327 C swapRealImag(C)(C input)
3328 {
3329 return C(input.im, input.re);
3330 }
3331
3332 private:
3333 // The reasons I couldn't use std.algorithm were b/c its stride length isn't
3334 // modifiable on the fly and because this range has grown some performance
3335 // hacks for powers of 2.
3336 struct Stride(R)
3337 {
3338 import core.bitop : bsf;
3339 Unqual!R range;
3340 size_t _nSteps;
3341 size_t _length;
3342 alias E = ElementType!(R);
3343
3344 this(R range, size_t nStepsIn)
3345 {
3346 this.range = range;
3347 _nSteps = nStepsIn;
3348 _length = (range.length + _nSteps - 1) / nSteps;
3349 }
3350
3351 size_t length() const @property
3352 {
3353 return _length;
3354 }
3355
3356 typeof(this) save() @property
3357 {
3358 auto ret = this;
3359 ret.range = ret.range.save;
3360 return ret;
3361 }
3362
3363 E opIndex(size_t index)
3364 {
3365 return range[index * _nSteps];
3366 }
3367
3368 E front() @property
3369 {
3370 return range[0];
3371 }
3372
3373 void popFront()
3374 {
3375 if (range.length >= _nSteps)
3376 {
3377 range = range[_nSteps .. range.length];
3378 _length--;
3379 }
3380 else
3381 {
3382 range = range[0 .. 0];
3383 _length = 0;
3384 }
3385 }
3386
3387 // Pops half the range's stride.
3388 void popHalf()
3389 {
3390 range = range[_nSteps / 2 .. range.length];
3391 }
3392
3393 bool empty() const @property
3394 {
3395 return length == 0;
3396 }
3397
3398 size_t nSteps() const @property
3399 {
3400 return _nSteps;
3401 }
3402
3403 void doubleSteps()
3404 {
3405 _nSteps *= 2;
3406 _length /= 2;
3407 }
3408
3409 size_t nSteps(size_t newVal) @property
3410 {
3411 _nSteps = newVal;
3412
3413 // Using >> bsf(nSteps) is a few cycles faster than / nSteps.
3414 _length = (range.length + _nSteps - 1) >> bsf(nSteps);
3415 return newVal;
3416 }
3417 }
3418
3419 // Hard-coded base case for FFT of size 2. This is actually a TON faster than
3420 // using a generic slow DFT. This seems to be the best base case. (Size 1
3421 // can be coded inline as buf[0] = range[0]).
3422 void slowFourier2(Ret, R)(R range, Ret buf)
3423 {
3424 assert(range.length == 2);
3425 assert(buf.length == 2);
3426 buf[0] = range[0] + range[1];
3427 buf[1] = range[0] - range[1];
3428 }
3429
3430 // Hard-coded base case for FFT of size 4. Doesn't work as well as the size
3431 // 2 case.
3432 void slowFourier4(Ret, R)(R range, Ret buf)
3433 {
3434 alias C = ElementType!Ret;
3435
3436 assert(range.length == 4);
3437 assert(buf.length == 4);
3438 buf[0] = range[0] + range[1] + range[2] + range[3];
3439 buf[1] = range[0] - range[1] * C(0, 1) - range[2] + range[3] * C(0, 1);
3440 buf[2] = range[0] - range[1] + range[2] - range[3];
3441 buf[3] = range[0] + range[1] * C(0, 1) - range[2] - range[3] * C(0, 1);
3442 }
3443
3444 N roundDownToPowerOf2(N)(N num)
3445 if (isScalarType!N && !isFloatingPoint!N)
3446 {
3447 import core.bitop : bsr;
3448 return num & (cast(N) 1 << bsr(num));
3449 }
3450
3451 @safe unittest
3452 {
3453 assert(roundDownToPowerOf2(7) == 4);
3454 assert(roundDownToPowerOf2(4) == 4);
3455 }
3456
3457 template isComplexLike(T)
3458 {
3459 enum bool isComplexLike = is(typeof(T.init.re)) &&
3460 is(typeof(T.init.im));
3461 }
3462
3463 @safe unittest
3464 {
3465 static assert(isComplexLike!(Complex!double));
3466 static assert(!isComplexLike!(uint));
3467 }
3468