1import sys
2import ctypes
3import ctypes.util
4# from gmpy2 import mpq
5
# Load the libmps shared library. The .so version named here must be kept
# up to date: update it whenever the library version is bumped.
_mps = ctypes.CDLL("libmps.so.3")
9
# These libmps constructors return pointers. ctypes assumes a C int return
# type unless told otherwise, so declare them as returning void pointers.
_mps.mps_chebyshev_poly_new.restype = ctypes.c_void_p
_mps.mps_context_new.restype = ctypes.c_void_p
_mps.mps_monomial_poly_new.restype = ctypes.c_void_p


class Cplx(ctypes.Structure):
    """Wrapper around the cplx_t type of MPSolve, that is usually
    a direct mapping onto the complex type of C99, but has a fallback
    custom implementation for systems that do not provide the type.

    The structure is laid out as two consecutive C doubles, ``real``
    followed by ``imag``.
    """

    _fields_ = [("real", ctypes.c_double),
                ("imag", ctypes.c_double)]

    def __repr__(self):
        """Return the value formatted as ``<real> + <imag>i`` in %e notation."""
        return "%e + %ei" % (self.real, self.imag)

    def __complex__(self):
        """Convert to a Python complex number.

        The two parts are handed to complex() separately: computing
        ``real + 1j*imag`` would turn an infinite or NaN imaginary part
        into a NaN real part.
        """
        return complex(self.real, self.imag)
27
28
class Goal:
    """Goals that the solver can be asked to reach before returning.

    These are plain integer codes; Context.mpsolve() hands one of them
    to libmps when selecting the output goal.
    """

    (MPS_OUTPUT_GOAL_ISOLATE,
     MPS_OUTPUT_GOAL_APPROXIMATE,
     MPS_OUTPUT_GOAL_COUNT) = range(3)
35
36
class Algorithm:
    """Integer codes of the algorithms available in MPSolve.

    Pass one of these constants as the algorithm argument of
    Context.solve() or Context.mpsolve(). For example:

     poly = MonomialPoly(ctx, n)
     roots = ctx.solve(poly, Algorithm.SECULAR_GA)
    """

    STANDARD_MPSOLVE, SECULAR_GA = 0, 1
48
49
class Context:
    """The Context class is a wrapper around the mps_context type
    in libmps. A Context instance must be instantiated before
    allocating and/or solving polynomials and secular equations,
    and can then be used to specify the desired properties of the
    solution. """

    def __init__(self):
        """Allocate a new libmps context and keep its handle."""
        self._c_ctx = ctypes.c_void_p(_mps.mps_context_new())

    def __del__(self):
        """Release the libmps context, if one was actually allocated."""
        # __init__ may have raised before _c_ctx was assigned, so do not
        # assume the attribute exists.
        c_ctx = getattr(self, "_c_ctx", None)
        # _c_ctx is a c_void_p object, never None itself; a NULL handle is
        # detected through its truthiness instead.
        if c_ctx:
            _mps.mps_context_free(c_ctx)
            # Drop the handle so a second __del__ call cannot free it again.
            self._c_ctx = None

    def set_input_poly(self, poly):
        """Select the polynomial that should be solved when mpsolve() is
        called. Note that each Context can only solve one polynomial
        at a time."""
        self._set_input_poly(poly)

    def _set_input_poly(self, poly):
        """Register the C handle of poly as the input of this context."""
        _mps.mps_context_set_input_poly(self._c_ctx, poly._c_polynomial)

    def mpsolve(self, poly=None, algorithm=Algorithm.SECULAR_GA):
        """Solve the polynomial previously loaded with set_input_poly(),
        or the one passed as poly when it is not None.

        algorithm is one of the Algorithm constants. It is ignored when
        poly is a ChebyshevPoly: Algorithm.SECULAR_GA is always selected
        in that case.

        The output precision is set to 100 and the output goal to
        Goal.MPS_OUTPUT_GOAL_APPROXIMATE before solving. Nothing is
        returned; use get_roots() to read the result. """
        if poly is not None:
            self.set_input_poly(poly)

        # Select the proper algorithm for this polynomial
        if isinstance(poly, ChebyshevPoly):
            algorithm = Algorithm.SECULAR_GA

        _mps.mps_context_select_algorithm(self._c_ctx, algorithm)
        _mps.mps_context_set_output_prec(self._c_ctx, ctypes.c_long(100))
        _mps.mps_context_set_output_goal(self._c_ctx,
                                         Goal.MPS_OUTPUT_GOAL_APPROXIMATE)
        _mps.mps_mpsolve(self._c_ctx)

    def solve(self, poly=None, algorithm=Algorithm.SECULAR_GA):
        """Simple shorthand for the combination of set_input_poly() and
        mpsolve(). This function directly returns the approximations that
        could otherwise be obtained by a call to the get_roots() method. """
        self.mpsolve(poly, algorithm)
        return self.get_roots()

    def _read_roots_d(self, with_radii):
        """Call mps_context_get_roots_d and return (roots, radii).

        roots is a ctypes array of Cplx and radii a ctypes array of
        doubles, both sized by the degree reported by the context.
        radii is None, and NULL is passed to libmps in its place, when
        with_radii is false."""
        degree = _mps.mps_context_get_degree(self._c_ctx)
        roots = (Cplx * degree)()
        radii = None
        radii_arg = None
        if with_radii:
            radii = (ctypes.c_double * degree)()
            radii_arg = ctypes.pointer(ctypes.pointer(radii))

        # The buffers are allocated here and handed over by double pointer.
        _mps.mps_context_get_roots_d(self._c_ctx,
                                     ctypes.pointer(ctypes.pointer(roots)),
                                     radii_arg)
        return roots, radii

    def get_roots(self):
        """Return the approximations obtained by MPSolve after a call
        to the mpsolve method, as a list of Python complex numbers.
        Consider using the convenience solve() method instead."""
        roots, _ = self._read_roots_d(False)
        return [complex(x) for x in roots]

    def get_inclusion_radii(self):
        """Return a list of guaranteed inclusion radii for the
        approximations obtained through a call to get_roots()."""
        _, radii = self._read_roots_d(True)
        return list(radii)
123
124
class Polynomial:
    """This is a wrapper around mps_polynomial struct.

    Subclasses allocate the C object and store its handle in
    ``self._c_polynomial`` after calling Polynomial.__init__().
    """

    def __init__(self, ctx, degree):
        """Record the owning context and the degree (converted to int).

        ctx is kept as a reference because __del__ needs its C handle
        to free the polynomial."""
        self._degree = int(degree)
        self._ctx = ctx
        # No C object yet; subclasses overwrite this with the real handle.
        # Defining it here keeps __del__ safe if their allocation fails.
        self._c_polynomial = None

    def __del__(self):
        """Free the C polynomial, if one was allocated."""
        # Use getattr: the instance may be only partially constructed.
        c_polynomial = getattr(self, "_c_polynomial", None)
        ctx = getattr(self, "_ctx", None)
        # Nothing to release when there is no handle, a NULL handle or
        # no context.
        if not c_polynomial or ctx is None:
            return
        _mps.mps_polynomial_free(ctx._c_ctx, c_polynomial)
        # Drop the handle so a second __del__ call cannot free it again.
        self._c_polynomial = None
134
135
class MonomialPoly(Polynomial):
    """A polynomial specified with its monomial coefficients. """

    def __init__(self, ctx, degree):
        """Allocate a monomial polynomial of the given degree in ctx."""
        Polynomial.__init__(self, ctx, degree)
        # Pass the int-normalised degree stored by Polynomial.__init__, so
        # that e.g. a float degree is not handed to the C function as-is.
        self._c_polynomial = ctypes.c_void_p(
            _mps.mps_monomial_poly_new(ctx._c_ctx, self._degree))

    def set_coefficient(self, n, coeff_re, coeff_im=None):
        """Set coefficient of degree n of the polynomial
        to the value coeff_re + i*coeff_im. Please note that you should
        use the same data type for all the coefficients, and you
        should use integers when possible.

        coeff_re may be an int, a float or a str; coeff_im, when given,
        must have the same type and defaults to zero otherwise.

        Raises ValueError if the two parts have different types, and
        RuntimeError if n is out of bounds or the type is unsupported. """

        if coeff_im is not None and type(coeff_re) != type(coeff_im):
            raise ValueError("Coefficient's real and imaginary parts "
                             "have different types")

        if n < 0 or n > self._degree:
            raise RuntimeError("Coefficient degree is out of bounds")

        mp = self._c_polynomial
        cntxt = self._ctx._c_ctx

        if isinstance(coeff_re, int):
            if coeff_im is None:
                coeff_im = 0
            # NOTE(review): c_longlong does no overflow checking, so ints
            # that do not fit in a C long long are silently truncated.
            _mps.mps_monomial_poly_set_coefficient_int(
                cntxt, mp, n,
                ctypes.c_longlong(coeff_re), ctypes.c_longlong(coeff_im))
        elif isinstance(coeff_re, float):
            if coeff_im is None:
                coeff_im = 0.0
            _mps.mps_monomial_poly_set_coefficient_d(
                cntxt, mp, n,
                ctypes.c_double(coeff_re), ctypes.c_double(coeff_im))
        elif isinstance(coeff_re, str):
            if coeff_im is None:
                coeff_im = "0.0"
            # On Python 3 a str must be encoded to bytes before it can be
            # passed to C as a char pointer.
            if sys.version_info.major > 2:
                coeff_re = coeff_re.encode("ASCII")
                coeff_im = coeff_im.encode("ASCII")
            _mps.mps_monomial_poly_set_coefficient_s(cntxt, mp, n,
                                                     coeff_re, coeff_im)
        else:
            raise RuntimeError("Coefficient type not supported")

    def _read_coefficient(self, n):
        """Return coefficient n as a Python complex, without validation."""
        cf = Cplx()
        _mps.mps_monomial_poly_get_coefficient_d(self._ctx._c_ctx,
                                                 self._c_polynomial, n,
                                                 ctypes.pointer(cf))
        return complex(cf)

    def get_coefficient(self, n):
        """Get the coefficient of degree n as a Python complex.

        Raises ValueError if n is not an int in the range [0, degree]. """
        # Check the type first: comparing a non-numeric n with an int
        # would raise TypeError instead of the intended ValueError.
        if not isinstance(n, int) or n < 0 or n > self._degree:
            raise ValueError("Invalid coefficient degree")
        return self._read_coefficient(n)

    def get_coefficients(self):
        """Get the list of all coefficients, from degree 0 up to the
        degree of the polynomial, as Python complex numbers. """
        return [self._read_coefficient(n) for n in range(self._degree + 1)]
214
215
class ChebyshevPoly(Polynomial):
    """A polynomial represented in the Chebyshev base."""

    # 5 here is the equivalent of MPS_STRUCTURE_COMPLEX_RATIONAL
    _MPS_STRUCTURE_COMPLEX_RATIONAL = 5

    def __init__(self, ctx, degree):
        """Allocate a Chebyshev polynomial of the given degree in ctx."""
        Polynomial.__init__(self, ctx, degree)

        # Pass the int-normalised degree stored by Polynomial.__init__.
        self._c_polynomial = ctypes.c_void_p(
            _mps.mps_chebyshev_poly_new(ctx._c_ctx, self._degree,
                                        self._MPS_STRUCTURE_COMPLEX_RATIONAL))

    def set_coefficient(self, n, coeff_re, coeff_im=None):
        """Set the coefficient of degree n of the polynomial.

        Only int coefficients are supported. coeff_im, when given, must
        have the same type as coeff_re and defaults to 0 otherwise.

        Raises ValueError if the two parts have different types, and
        RuntimeError if n is out of bounds or the type is unsupported."""
        if coeff_im is not None and type(coeff_re) != type(coeff_im):
            raise ValueError("Coefficient's real and imaginary parts "
                             "have different types")

        if n < 0 or n > self._degree:
            raise RuntimeError("Coefficient degree is out of bounds")

        # The former float branch referenced undefined names and could only
        # fail with NameError. Until a working libmps setter for floats is
        # wired in, every non-int type gets the explicit error below.
        if not isinstance(coeff_re, int):
            raise RuntimeError("Unsupported type for coefficient")

        if coeff_im is None:
            coeff_im = 0
        _mps.mps_chebyshev_poly_set_coefficient_i(self._ctx._c_ctx,
                                                  self._c_polynomial, n,
                                                  coeff_re, coeff_im)
257