1# poly.rb -- polynomial-related stuff; poly.scm --> poly.rb
2
3# Translator: Michael Scholz <mi-scholz@users.sourceforge.net>
4# Created: 05/04/09 23:55:07
5# Changed: 20/11/08 00:06:07
6
7# class Complex
8#  to_f
9#  to_f_or_c
10#
11# class Poly < Vec
12#  inspect
13#  to_poly
14#  reduce
15#  +(other)
16#  *(other)
17#  /(other)
18#  derivative
19#  resultant(other)
20#  discriminant
21#  gcd(other)
22#  roots
23#  eval(x)
24#
25# class Float
26#  +(other)
27#  *(other)
28#  /(other)
29#
30# class String
31#  to_poly
32#
33# class Array
34#  to_poly
35#
36# class Vct
37#  to_poly
38#
39# Poly(obj)
40# make_poly(len, init, &body)
41# poly?(obj)
42# poly(*vals)
43# poly_reduce(obj)
44# poly_add(obj1, obj2)
45# poly_multiply(obj1, obj2)
46# poly_div(obj1, obj2)
47# poly_derivative(obj)
48# poly_gcd(obj1, obj2)
49# poly_roots(obj)
50
51require "clm"
52require "mix"
53include Math
54
# Extensions to the core Complex class used by the polynomial code.
class Complex
  # Historical note: an `attr_writer :real, :imag' used to be declared
  # here.  That no longer works because Complex objects are frozen
  # nowadays (Thu Nov 30 21:29:10 CET 2017).

  # Replaces Complex#to_f with a version that returns the real part as
  # a Float, whatever the imaginary part is.  The definition sits inside
  # with_silence, presumably to suppress the method-redefinition
  # warning -- confirm against the helper's definition.
  with_silence do
    def to_f
      real.to_f
    end
  end

  # Returns the real part as a Float when the imaginary part is zero,
  # otherwise the receiver itself.
  def to_f_or_c
    if imag.zero?
      to_f
    else
      self
    end
  end
end
70
71class Poly < Vec
72  Poly_roots_epsilon = 1.0e-6
73
  # Returns the printable representation of the receiver.  Sets @name to
  # "poly" and then defers to the superclass implementation (presumably
  # the inherited inspect uses @name as the label -- confirm in Vec).
  def inspect
    @name = "poly"
    super
  end
78
  # A Poly already is a poly: returns the receiver itself, not a copy.
  def to_poly
    self
  end
82
83  def reduce
84    if self.last.zero?
85      i = self.length - 1
86      while self[i].zero? and i > 0
87        i -= 1
88      end
89      # FIXME: ruby3 requires to_poly
90      self[0, i + 1].to_poly
91    else
92      self
93    end
94  end
95  # [1, 2, 3].to_poly.reduce             ==> poly(1.0, 2.0, 3.0)
96  # poly(1, 2, 3, 0, 0, 0).reduce        ==> poly(1.0, 2.0, 3.0)
97  # vct(0, 0, 0, 0, 1, 0).to_poly.reduce ==> poly(0.0, 0.0, 0.0, 0.0, 1.0)
98
99  def poly_add(other)
100    assert_type((array?(other) or vct?(other) or number?(other)),
101                other, 0, "a poly, a vct an array, or a number")
102    if number?(other)
103      v = self.dup
104      v[0] += other
105      v
106    else
107      if self.length > other.length
108        self.add(other)
109      else
110        Poly(other).add(self)
111      end
112    end
113  end
114  alias + poly_add
115  # poly(0.1, 0.2, 0.3) + poly(0, 1, 2, 3, 4) ==> poly(0.1, 1.2, 2.3, 3.0, 4.0)
116  # poly(0.1, 0.2, 0.3) + 0.5                 ==> poly(0.6, 0.2, 0.3)
117  # 0.5 + poly(0.1, 0.2, 0.3)                 ==> poly(0.6, 0.2, 0.3)
118
119  def poly_multiply(other)
120    assert_type((array?(other) or vct?(other) or number?(other)),
121                other, 0, "a poly, a vct, an array, or a number")
122    if number?(other)
123      Poly(self.scale(Float(other)))
124    else
125      len = self.length + other.length
126      m = Poly.new(len, 0.0)
127      self.each_with_index do |val1, i|
128        other.each_with_index do |val2, j|
129          m[i + j] = m[i + j] + val1 * val2
130        end
131      end
132      m
133    end
134  end
135  alias * poly_multiply
136  # poly(1, 1) * poly(-1, 1)        ==> poly(-1.0, 0.0, 1.0, 0.0)
137  # poly(-5, 1) * poly(3, 7, 2)     ==> poly(-15.0, -32.0, -3.0, 2.0, 0.0)
138  # poly(-30, -4, 2) * poly(0.5, 1) ==> poly(-15.0, -32.0, -3.0, 2.0, 0.0)
139  # poly(-30, -4, 2) * 0.5          ==> poly(-15.0, -2.0, 1.0)
140  # 2.0 * poly(-30, -4, 2)          ==> poly(-60.0, -8.0, 4.0)
141
142  def poly_div(other)
143    assert_type((array?(other) or vct?(other) or number?(other)),
144                other, 0, "a poly, a vct, an array, or a number")
145    if number?(other)
146      [self * (1.0 / other), poly(0.0)]
147    else
148      if other.length > self.length
149        [poly(0.0), other.to_poly]
150      else
151        r = self.dup
152        q = Poly.new(self.length, 0.0)
153        n = self.length - 1
154        nv = other.length - 1
155        (n - nv).downto(0) do |i|
156          q[i] = r[nv + i] / other[nv]
157          (nv + i - 1).downto(i) do |j|
158            r[j] = r[j] - q[i] * other[j - i]
159          end
160        end
161        nv.upto(n) do |i|
162          r[i] = 0.0
163        end
164        [q, r]
165      end
166    end
167  end
168  alias / poly_div
169  # poly(-1.0, 0.0, 1.0) / poly(1.0, 1.0)
170  #   ==> [poly(-1.0, 1.0, 0.0),       poly(0.0, 0.0, 0.0)]
171  # poly(-15, -32, -3, 2) / poly(-5, 1)
172  #   ==> [poly(3.0, 7.0, 2.0, 0.0),   poly(0.0, 0.0, 0.0, 0.0)]
173  # poly(-15, -32, -3, 2) / poly(3, 1)
174  #   ==> [poly(-5.0, -9.0, 2.0, 0.0), poly(0.0, 0.0, 0.0, 0.0)]
175  # poly(-15, -32, -3, 2) / poly(0.5, 1)
176  #   ==> [poly(-30.0, -4.0, 2.0, 0.0), poly(0.0, 0.0, 0.0, 0.0)]
177  # poly(-15, -32, -3, 2) / poly(3, 7, 2)
178  #   ==> [poly(-5.0, 1.0, 0.0, 0.0),  poly(0.0, 0.0, 0.0, 0.0)]
179  # poly(-15, -32, -3, 2) / 2.0
180  #   ==> [poly(-7.5, -16.0, -1.5, 1.0), poly(0.0)]
181
182  def derivative
183    len = self.length - 1
184    pl = Poly.new(len, 0.0)
185    j = len
186    (len - 1).downto(0) do |i|
187      pl[i] = self[j] * j
188      j -= 1
189    end
190    pl
191  end
192  # poly(0.5, 1.0, 2.0, 4.0).derivative ==> poly(1.0, 4.0, 12.0)
193
194  def resultant(other)
195    m = self.length
196    m1 = m - 1
197    n = other.length
198    n1 = n - 1
199    d = n1 + m1
200    mat = Array.new(d) do
201      Vct.new(d, 0.0)
202    end
203    n1.times do |i|
204      m.times do |j|
205        mat[i][i + j] = self[m1 - j]
206      end
207    end
208    m1.times do |i|
209      n.times do |j|
210        mat[i + n1][i + j] = other[n1 - j]
211      end
212    end
213    determinant(mat)
214  end
215  # poly(-1, 0, 1).resultant([1, -2, 1]) ==> 0.0
216  # poly(-1, 0, 2).resultant([1, -2, 1]) ==> 1.0
217  # poly(-1, 0, 1).resultant([1, 1])     ==> 0.0
218  # poly(-1, 0, 1).resultant([2, 1])     ==> 3.0
219
220  def discriminant
221    self.resultant(self.derivative)
222  end
223  # poly(-1, 0, 1).discriminant ==> -4.0
224  # poly(1, -2, 1).discriminant ==>  0.0
225  # (poly(-1, 1) * poly(-1, 1) * poly(3, 1)).reduce.discriminant
226  #   ==> 0.0
227  # (poly(-1, 1) * poly(-1, 1) * poly(3, 1) * poly(2, 1)).reduce.discriminant
228  #   ==> 0.0
229  # (poly(1, 1) * poly(-1, 1) * poly(3, 1) * poly(2, 1)).reduce.discriminant
230  #   ==> 2304.0
231  # (poly(1, 1) * poly(-1, 1) * poly(3, 1) * poly(3, 1)).reduce.discriminant
232  #   ==> 0.0
233
234  def gcd(other)
235    assert_type((array?(other) or vct?(other)), other, 0,
236                "a poly, a vct or an array")
237    if self.length < other.length
238      poly(0.0)
239    else
240      qr = self.poly_div(other).map do |m|
241        m.reduce
242      end
243      if qr[1].length == 1
244        if qr[1][0].zero?
245          Poly(other)
246        else
247          poly(0.0)
248        end
249      else
250        qr[0].gcd(qr[1])
251      end
252    end
253  end
254  # (poly(2, 1) * poly(-3, 1)).reduce.gcd(poly(2, 1))
255  #   ==> poly(2.0, 1.0)
256  # (poly(2, 1) * poly(-3, 1)).reduce.gcd(poly(3, 1))
257  #   ==> poly(0.0)
258  # (poly(2, 1) * poly(-3, 1)).reduce.gcd(poly(-3, 1))
259  #   ==> poly(-3.0, 1.0)
260  # (poly(8, 1) * poly(2, 1) * poly(-3, 1)).reduce.gcd(poly(-3, 1))
261  #   ==> poly(-3.0, 1.0)
262  # (poly(8, 1) * poly(2, 1) *
263  #  poly(-3, 1)).reduce.gcd((poly(8, 1) * poly(-3, 1)).reduce)
264  #   ==> poly(-24.0, 5.0, 1.0)
265  # poly(-1, 0, 1).gcd(poly(2, -2, -1, 1))
266  #   ==> poly(0.0)
267  # poly(2, -2, -1, 1).gcd(poly(-1, 0, 1))
268  #   ==> poly(1.0, -1.0)
269  # poly(2, -2, -1, 1).gcd(poly(-2.5, 1))
270  #   ==> poly(0.0)
271
272  def roots
273    rts = poly()
274    deg = self.length - 1
275    if deg.zero?
276      rts
277    else
278      if self[0].zero?
279        if deg == 1
280          poly(0.0)
281        else
282          Poly.new(deg) do |i|
283            self[i + 1]
284          end.roots.unshift(0.0)
285        end
286      else
287        if deg == 1
288          linear_root(self[1], self[0])
289        else
290          if deg == 2
291            quadratic_root(self[2], self[1], self[0])
292          else
293            if deg == 3 and
294               (rts = cubic_root(self[3], self[2], self[1], self[0]))
295              rts
296            else
297              if deg == 4 and
298                 (rts = quartic_root(self[4], self[3],
299                                     self[2], self[1], self[0]))
300                rts
301              else
302                ones = 0
303                1.upto(deg) do |i|
304                  if self[i].nonzero?
305                    ones += 1
306                  end
307                end
308                if ones == 1
309                  nth_root(self[deg], self[0], deg)
310                else
311                  if ones == 2 and deg.even? and self[deg / 2].nonzero?
312                    n = deg / 2
313                    poly(self[0], self[deg / 2], self[deg]).roots.each do |qr|
314                      rts.push(*nth_root(1.0, -qr, n.to_f))
315                    end
316                    rts
317                  else
318                    if deg > 3 and
319                        ones == 3 and
320                        (deg % 3).zero? and
321                        self[deg / 3].nonzero? and
322                        self[(deg * 2) / 3].nonzero?
323                      n = deg / 3
324                      poly(self[0],
325                           self[deg / 3],
326                           self[(deg * 2) / 3],
327                           self[deg]).roots.each do |qr|
328                        rts.push(*nth_root(1.0, -qr, n.to_f))
329                      end
330                      rts
331                    else
332                      q = self.dup
333                      pp = self.derivative
334                      qp = pp.dup
335                      n = deg
336                      x = Complex(1.3, 0.314159)
337                      v = q.eval(x)
338                      m = v.abs * v.abs
339                      20.times do # until c_g?
340                        if (dx = v / qp.eval(x)).abs <= Poly_roots_epsilon
341                          break
342                        end
343                        20.times do
344                          if dx.abs <= Poly_roots_epsilon
345                            break
346                          end
347                          y = x - dx
348                          v1 = q.eval(y)
349                          if (m1 = v1.abs * v1.abs) < m
350                            x = y
351                            v = v1
352                            m = m1
353                            break
354                          else
355                            dx /= 4.0
356                          end
357                        end
358                      end
359                      x = x - self.eval(x) / pp.eval(x)
360                      x = x - self.eval(x) / pp.eval(x)
361                      if x.imag < Poly_roots_epsilon
362                        q = q.poly_div(poly(-x.real, 1.0))
363                        n -= 1
364                      else
365                        q = q.poly_div(poly(x.abs, 0.0, 1.0))
366                        n -= 2
367                      end
368                      rts = if n > 0
369                              q.car.reduce.roots
370                            else
371                              poly()
372                            end
373                      rts << x.to_f_or_c
374                      rts
375                    end
376                  end
377                end
378              end
379            end
380          end
381        end
382      end
383    end
384  end
385
386  def eval(x)
387    sum = self.last
388    self.reverse[1..-1].each do |val|
389      sum = sum * x + val
390    end
391    sum
392  end
393
394  private
395  def submatrix(mx, row, col)
396    nmx = Array.new(mx.length - 1) do
397      Vct.new(mx.length - 1, 0.0)
398    end
399    ni = 0
400    mx.length.times do |i|
401      if i != row
402        nj = 0
403        mx.length.times do |j|
404          if j != col
405            nmx[ni][nj] = mx[i][j]
406            nj += 1
407          end
408        end
409        ni += 1
410      end
411    end
412    nmx
413  end
414
415  def determinant(mx)
416    if mx.length == 1
417      mx[0][0]
418    else
419      if mx.length == 2
420        mx[0][0] * mx[1][1] - mx[0][1] * mx[1][0]
421      else
422        if mx.length == 3
423          ((mx[0][0] * mx[1][1] * mx[2][2] +
424            mx[0][1] * mx[1][2] * mx[2][0] +
425            mx[0][2] * mx[1][0] * mx[2][1]) -
426           (mx[0][0] * mx[1][2] * mx[2][1] +
427            mx[0][1] * mx[1][0] * mx[2][2] +
428            mx[0][2] * mx[1][1] * mx[2][0]))
429        else
430          sum = 0.0
431          sign = 1
432          mx.length.times do |i|
433            mult = mx[0][i]
434            if mult != 0.0
435              sum = sum + sign * mult * determinant(submatrix(mx, 0, i))
436            end
437            sign = -sign
438          end
439          sum
440        end
441      end
442    end
443  end
444
445  # ax + b
446  def linear_root(a, b)
447    poly(-b / a)
448  end
449
450  # ax^2 + bx + c
451  def quadratic_root(a, b, c)
452    d = sqrt(b * b - 4.0 * a * c)
453    poly((-b + d) / (2.0 * a), (-b - d) / (2.0 * a))
454  end
455
456  # ax^3 + bx^2 + cx + d
  # Tries to find the three roots of a*x^3 + b*x^2 + c*x + d in closed
  # form.  Returns a three-element poly on success and false when no
  # candidate triple passes the check, so callers must test the result.
  #
  # The coefficients are first normalized by A.  The two cube-root terms
  # r1 and r2 are then each tried with all three rotations by 2*pi/3
  # (nine combinations).  A combination is accepted only if all three
  # candidate roots evaluate to less than Poly_roots_epsilon in
  # magnitude on the normalized polynomial.
  #
  # NOTE(review): sqrt(-3) and exp of a complex argument need
  # complex-aware sqrt/exp; the standard library's Math.sqrt raises
  # Math::DomainError for negative input -- confirm which versions are
  # in scope.
  def cubic_root(a, b, c, d)
    # Abramowitz & Stegun 3.8.2
    a0 = d / a
    a1 = c / a
    a2 = b / a
    q = (a1 / 3) - ((a2 * a2) / 9)
    r = ((a1 * a2 - 3 * a0) / 6) - ((a2 * a2 * a2) / 27)
    sq3r2 = sqrt(q * q * q + r * r)
    r1 = (r + sq3r2) ** (1 / 3.0)
    r2 = (r - sq3r2) ** (1 / 3.0)
    # Rotation step between the three complex cube roots.
    incr = (TWO_PI * Complex::I) / 3
    # Normalized (monic) polynomial used to verify the candidates.
    pl = poly(a0, a1, a2, 1)
    sqrt3 = sqrt(-3)
    3.times do |i|
      3.times do |j|
        s1 = r1 * exp(i * incr)
        s2 = r2 * exp(j * incr)
        z1 = simplify_complex((s1 + s2) - (a2 / 3))
        if pl.eval(z1).abs < Poly_roots_epsilon
          z2 = simplify_complex((-0.5 * (s1 + s2)) +
                                (a2 / -3) +
                                ((s1 - s2) * 0.5 * sqrt3))
          if pl.eval(z2).abs < Poly_roots_epsilon
            z3 = simplify_complex((-0.5 * (s1 + s2)) +
                                  (a2 / -3) +
                                  ((s1 - s2) * -0.5 * sqrt3))
            if pl.eval(z3).abs < Poly_roots_epsilon
              return poly(z1, z2, z3)
            end
          end
        end
      end
    end
    # No combination verified.
    false
  end
492
493  # ax^4 + bx^3 + cx^2 + dx + e
  # Tries to find the four roots of a*x^4 + b*x^3 + c*x^2 + d*x + e in
  # closed form.  Returns a four-element poly on success and false
  # otherwise, so callers must test the result.
  #
  # The coefficients are normalized by A and an auxiliary cubic is solved
  # with roots.  For each of its roots y1 the four candidates z1..z4 are
  # built; the set is returned as soon as z1 evaluates to less than
  # Poly_roots_epsilon in magnitude on the original polynomial.  Only z1
  # is verified; z2..z4 are returned unchecked.
  #
  # NOTE(review): the nested sqrt calls may receive negative or complex
  # arguments; this relies on a complex-aware sqrt being in scope --
  # confirm.
  def quartic_root(a, b, c, d, e)
    # Weisstein, "Encyclopedia of Mathematics"
    a0 = e / a
    a1 = d / a
    a2 = c / a
    a3 = b / a
    # Auxiliary cubic, coefficients lowest power first.
    if yroot = poly((4 * a2 * a0) + -(a1 * a1) + -(a3 * a3 * a0),
                    (a1 * a3) - (4 * a0),
                    -a2,
                    1).roots
      yroot.each do |y1|
        r = sqrt((0.25 * a3 * a3) + (-a2 + y1))
        # dd and ee use a separate formula when r is zero, because the
        # general one divides by r.
        dd = if r.zero?
              sqrt((0.75 * a3 * a3) +
                   (-2 * a2) +
                   (2 * sqrt(y1 * y1 - 4 * a0)))
            else
              sqrt((0.75 * a3 * a3) + (-2 * a2) + (-(r * r)) +
                   (0.25 * ((4 * a3 * a2) + (-8 * a1) + (-(a3 * a3 * a3)))) / r)
            end
        ee = if r.zero?
              sqrt((0.75 * a3 * a3) +
                   (-2 * a2) +
                   (-2 * sqrt((y1 * y1) - (4 * a0))))
            else
              sqrt((0.75 * a3 * a3) + (-2 * a2) + (-(r * r)) +
                   (-0.25 *
                    ((4 * a3 * a2) + (-8 * a1) + (-(a3 * a3 * a3)))) / r)
            end
        z1 = (-0.25 * a3) + ( 0.5 * r) + ( 0.5 * dd)
        z2 = (-0.25 * a3) + ( 0.5 * r) + (-0.5 * dd)
        z3 = (-0.25 * a3) + (-0.5 * r) + ( 0.5 * ee)
        z4 = (-0.25 * a3) + (-0.5 * r) + (-0.5 * ee)
        # Accept this y1 only if z1 really is a root of the input.
        if poly(e, d, c, b, a).eval(z1).abs < Poly_roots_epsilon
          return poly(z1, z2, z3, z4)
        end
      end
    end
    # No auxiliary root produced a verified z1.
    false
  end
534
535  # ax^n + b
536  def nth_root(a, b, deg)
537    n = (-b / a) ** (1.0 / deg)
538    incr = (TWO_PI * Complex::I) / deg
539    rts = poly()
540    deg.to_i.times do |i|
541      rts.unshift(simplify_complex(exp(i * incr) * n))
542    end
543    rts
544  end
545
546  Poly_roots_epsilon2 = 1.0e-6
547  def simplify_complex(a)
548    if a.imag.abs < Poly_roots_epsilon2
549      (a.real.abs < Poly_roots_epsilon2) ? 0.0 : a.real.to_f
550    else
551      if a.real.abs < Poly_roots_epsilon2
552        # XXX: a.real = 0.0
553        #      Doesn't work any longer (see above, class Complex).
554        a = Complex(0.0, a.imag)
555      end
556      a
557    end
558  end
559end
560
# Lets a Float stand on the left-hand side of +, * and / with a Poly on
# the right.  The built-in operators stay reachable as fp_plus, fp_times
# and fp_div and are used for every operand that is not a Poly.  Each
# section is guarded with defined? so that loading the file twice does
# not alias an operator onto itself.
class Float
  unless defined? 0.0.poly_plus
    alias fp_plus +
    # Float + Poly: returns a copy of OTHER with the receiver added to
    # its constant term.  OTHER itself is left untouched (it used to be
    # modified in place, unlike Poly#poly_add).
    def poly_plus(other)
      case other
      when Poly
        sum = other.dup
        sum[0] += self
        sum
      else
        self.fp_plus(other)
      end
    end
    alias + poly_plus
  end

  unless defined? 0.0.poly_times
    alias fp_times *
    # Float * Poly: returns OTHER scaled by the receiver, as a poly.
    def poly_times(other)
      case other
      when Poly
        Poly(other.scale(self))
      else
        self.fp_times(other)
      end
    end
    alias * poly_times
  end

  unless defined? 0.0.poly_div
    alias fp_div /
    # Float / Poly: returns [poly(0.0), other], i.e. a zero quotient with
    # OTHER in the remainder position.
    def poly_div(other)
      case other
      when Poly
        [poly(0.0), other]
      else
        self.fp_div(other)
      end
    end
    alias / poly_div
  end
end
602
# Conversion of strings such as "poly(1.0, 2.0)" to polys.
class String
  # Converts a textual poly literal into a poly.  Any string that is not
  # such a literal yields an empty poly.
  #
  # SECURITY: the accepted text is passed to eval, so the whole string
  # (\A...\z) must look like poly(...) with nothing but numeric-literal
  # characters, commas, parentheses and whitespace inside.  The previous
  # check used ^ (which matches at the start of any line) and had no end
  # anchor, so arbitrary code after or around a poly( prefix was
  # evaluated as well.
  def to_poly
    if self.match?(/\Apoly\([-+,.eEi()\d\s]*\)\s*\z/)
      eval(self)
    else
      poly()
    end
  end
end
612
# Conversion of plain arrays to polys.
class Array
  # Returns a poly built from the receiver's elements by passing them to
  # poly().
  def to_poly
    poly(*self)
  end
end
618
# Conversion of vcts to polys.
class Vct
  # Returns a poly built from the receiver's elements: the vct is turned
  # into an array first and its elements are passed to poly().
  def to_poly
    poly(*self.to_a)
  end
end
624
# Converts OBJ to a poly by calling its to_poly method.  nil is treated
# as an empty array.  Objects that do not respond to to_poly are rejected
# by assert_type.
def Poly(obj)
  source = obj.nil? ? [] : obj
  assert_type(source.respond_to?(:to_poly), source, 0,
              "an object containing method 'to_poly'")
  source.to_poly
end
633
# Creates a poly by forwarding LEN, INIT (default 0.0) and the optional
# block BODY to Poly.new.
def make_poly(len, init = 0.0, &body)
  Poly.new(len, init, &body)
end
637
# Returns true if OBJ is an instance of exactly the class Poly.  The test
# uses instance_of?, so instances of subclasses do not count.
def poly?(obj)
  obj.instance_of?(Poly)
end
641
# Creates a poly from the given values, lowest power first.  Values for
# which integer? holds are converted to Float; all others are stored
# unchanged.  Without arguments an empty poly is returned.
def poly(*vals)
  Poly.new(vals.length) do |idx|
    item = vals[idx]
    integer?(item) ? Float(item) : item
  end
end
651
# Functional form of Poly#reduce: converts OBJ to a poly and returns it
# with trailing zero coefficients removed.  OBJ must respond to to_poly.
def poly_reduce(obj)
  convertible = obj.respond_to?(:to_poly)
  assert_type(convertible, obj, 0, "an object containing method 'to_poly'")
  Poly(obj).reduce
end
657
# Functional form of poly addition.  When OBJ1 is a number it is added,
# as a Float, to OBJ2 converted to a poly; otherwise OBJ1 is converted
# to a poly and OBJ2 is added to it.  Whichever argument is converted
# must respond to to_poly.
def poly_add(obj1, obj2)
  expected = "an object containing method 'to_poly'"
  unless number?(obj1)
    assert_type(obj1.respond_to?(:to_poly), obj1, 0, expected)
    return Poly(obj1) + obj2
  end
  assert_type(obj2.respond_to?(:to_poly), obj2, 1, expected)
  Float(obj1) + Poly(obj2)
end
669
# Functional form of poly multiplication.  When OBJ1 is a number it is
# multiplied, as a Float, with OBJ2 converted to a poly; otherwise OBJ1
# is converted to a poly and multiplied by OBJ2.  Whichever argument is
# converted must respond to to_poly.
def poly_multiply(obj1, obj2)
  expected = "an object containing method 'to_poly'"
  unless number?(obj1)
    assert_type(obj1.respond_to?(:to_poly), obj1, 0, expected)
    return Poly(obj1) * obj2
  end
  assert_type(obj2.respond_to?(:to_poly), obj2, 1, expected)
  Float(obj1) * Poly(obj2)
end
681
# Functional form of poly division.  When OBJ1 is a number it is
# divided, as a Float, by OBJ2 converted to a poly; otherwise OBJ1 is
# converted to a poly and divided by OBJ2.  Whichever argument is
# converted must respond to to_poly.
def poly_div(obj1, obj2)
  expected = "an object containing method 'to_poly'"
  unless number?(obj1)
    assert_type(obj1.respond_to?(:to_poly), obj1, 0, expected)
    return Poly(obj1) / obj2
  end
  assert_type(obj2.respond_to?(:to_poly), obj2, 1, expected)
  Float(obj1) / Poly(obj2)
end
693
# Functional form of Poly#derivative: converts OBJ to a poly and returns
# its derivative.  OBJ must respond to to_poly.
def poly_derivative(obj)
  convertible = obj.respond_to?(:to_poly)
  assert_type(convertible, obj, 0, "an object containing method 'to_poly'")
  Poly(obj).derivative
end
699
# Functional form of Poly#gcd: converts OBJ1 to a poly and returns its
# gcd with OBJ2.  OBJ1 must respond to to_poly; OBJ2 is validated by
# Poly#gcd itself.  The type check used to name an undefined variable
# `obj', so every call raised NameError.
def poly_gcd(obj1, obj2)
  assert_type(obj1.respond_to?(:to_poly), obj1, 0,
              "an object containing method 'to_poly'")
  Poly(obj1).gcd(obj2)
end
705
# Functional form of Poly#roots: converts OBJ to a poly and returns its
# roots.  OBJ must respond to to_poly.
def poly_roots(obj)
  convertible = obj.respond_to?(:to_poly)
  assert_type(convertible, obj, 0, "an object containing method 'to_poly'")
  Poly(obj).roots
end
711
712# poly.rb ends here
713