1 #include "PolynomialSolver.h"
2 #include <limits>
3 #include <cmath>
4 #include <iostream>
5 #include "ATC_Error.h"
6 
7 namespace ATC {
8   // Utility functions used by solvers, but not globally accessible.
9   static const double PI_OVER_3 = acos(-1.0)*(1.0/3.0);
is_zero(double x)10   static bool is_zero(double x)
11   {
12     static double GT_ZERO = 1.0e2*std::numeric_limits<double>::epsilon();
13     static double LT_ZERO = -GT_ZERO;
14     return x>LT_ZERO && x<GT_ZERO;
15   }
sign(double x)16   static double sign(double x)
17   {
18     static double s[] = {-1.0,1.0};
19     return s[x>0];
20   }
21 
22   // Linear solver
solve_linear(double c[2],double x0[1])23   int solve_linear(double c[2], double x0[1])
24   {
25     if (c[1] == 0) return 0;  // constant function
26     *x0 = -c[0] / c[1];
27     return 1;
28   }
29 
30   // Quadratic solver
solve_quadratic(double c[3],double x0[2])31   int solve_quadratic(double c[3], double x0[2])
32   {
33     if (is_zero(c[2])) return solve_linear(c, x0);
34     const double ainv = 1.0/c[2];       // ax^2 + bx + c = 0
35     const double p = 0.5 * c[1] * ainv; // -b/2a
36     const double q = c[0] * ainv;       // c/a
37     double D = p*p-q;
38 
39     if (is_zero(D))  { // quadratic has one repeated root
40       x0[0] = -p;
41       return 1;
42     }
43     if (D > 0) {       // quadratic has two real roots
44       D = sqrt(D);
45       x0[0] =  D - p;
46       x0[1] = -D - p;
47       return 2;
48     }
49     return 0;          // quadratic has no real roots
50   }
51 
52   // Cubic solver
solve_cubic(double c[4],double x0[3])53   int solve_cubic(double c[4], double x0[3])
54   {
55     int num_roots;
56     if (is_zero(c[3])) return solve_quadratic(c, x0);
57     // normalize to  x^3 + Ax^2 + Bx + C = 0
58     const double c3inv = 1.0/c[3];
59     const double A = c[2] * c3inv;
60     const double B = c[1] * c3inv;
61     const double C = c[0] * c3inv;
62 
63     // substitute x = t - A/3 so t^3 + pt + q = 0
64     const double A2 = A*A;
65     const double p = (1.0/3.0)*((-1.0/3.0)*A2 + B);
66     const double q = 0.5*((2.0/27.0)*A*A2 - (1.0/3.0)*A*B + C);
67 
68     // Cardano's fomula
69     const double p3 = p*p*p;
70     const double D  = q*q + p3;
71     if (is_zero(D)) {
72       if (is_zero(q)) { // one triple soln
73         x0[0] = 0.0;
74         num_roots = 1;
75       }
76       else {            // one single and one double soln
77         const double u  = pow(fabs(q), 1.0/3.0)*sign(q);
78         x0[0] = -2.0*u;
79         x0[1] = u;
80         num_roots = 2;
81       }
82     }
83     else {
84       if (D < 0.0) {    // three real roots
85         const double phi = 1.0/3.0 * acos(-q/sqrt(-p3));
86         const double t   = 2.0 * sqrt(-p);
87         x0[0] =  t * cos(phi);
88         x0[1] = -t * cos(phi + PI_OVER_3);
89         x0[2] = -t * cos(phi - PI_OVER_3);
90         num_roots = 3;
91       }
92       else {            // one real root
93         const double sqrt_D = sqrt(D);
94         const double u      = pow(sqrt_D + fabs(q), 1.0/3.0);
95         if (q > 0) x0[0] = -u + p / u;
96         else       x0[0] =  u - p / u;
97         num_roots = 1;
98       }
99     }
100     double sub = (1.0/3.0)*A;
101     for (int i=0; i<num_roots; i++) x0[i] -= sub;
102     return num_roots;
103   }
104 
105   // solve ode with polynomial source : y'n + a_n-1 y'n-1 + ... = b_n x^n +...
integrate_ode(double x,int na,double * a,double * y0,double * y,int nb,double *)106   void integrate_ode(double x,
107                      int na, double * a, double * y0, double * y, int nb, double * /* b */ )
108   {
109     if (na == 2) {
110       // particular
111       if ( a[1] == 0) {
112         if ( a[0] == 0) {
113           y[0] = y0[0]+y0[1]*x;
114           y[1] =       y0[1];
115         }
116         else {
117           double c = sqrt(a[0]);
118           y[0] =    y0[0]*cos(c*x)+y0[1]/c*sin(c*x);
119           y[1] = -c*y0[0]*cos(c*x)+y0[1]  *sin(c*x);
120         }
121       }
122       else {
123         // use solve_quadratic
124         throw ATC_Error("not yet supported");
125       }
126       // homogenous
127       double c = 1.;
128       double z = x;
129       int j = 2;
130       for (int i = 0; i < nb; i++,j++) {
131         y[1] += j*c*z;
132         c /= j;
133         z *= x;
134         y[0] += c*z;
135       }
136     }
137     else throw ATC_Error("can only integrate 2nd order ODEs currently");
138   }
139 }
140