1"""Sparse rational function fields."""
2
3from __future__ import annotations
4
5import functools
6import operator
7
8from ..core import Expr, Symbol
9from ..core.sympify import CantSympify, sympify
10from ..domains.compositedomain import CompositeDomain
11from ..domains.domainelement import DomainElement
12from ..domains.field import Field
13from .orderings import lex
14from .polyerrors import CoercionFailed, GeneratorsError
15from .rings import PolyElement, PolynomialRing
16
17
def field(symbols, domain, order=lex):
    """Create a rational function field together with its generators.

    Returns the tuple ``(K, x1, ..., xn)``: ``K`` is the
    :class:`FractionField` over ``domain`` in ``symbols`` (with monomial
    ``order``) and ``x1, ..., xn`` are its generators, in order.
    """
    new_field = FractionField(domain, symbols, order)
    return (new_field, *new_field.gens)
22
23
class FractionField(Field, CompositeDomain):
    """A class for representing multivariate rational function fields.

    Instances are cached in ``_field_cache``: constructing a field with the
    same (normalized) symbols, domain and order returns the same object, so
    field equality is object identity.
    """

    is_FractionField = True

    has_assoc_Ring = True

    def __new__(cls, domain, symbols, order=lex):
        """Return the (cached) fraction field over ``domain`` in ``symbols``.

        The arguments are normalized by building the underlying
        :class:`PolynomialRing` first; the normalized symbols, number of
        generators, domain and order form the cache key.
        """
        ring = PolynomialRing(domain, symbols, order)
        symbols = ring.symbols
        ngens = ring.ngens
        domain = ring.domain
        order = ring.order

        key = cls.__name__, symbols, ngens, domain, order
        obj = _field_cache.get(key)

        if obj is None:
            obj = object.__new__(cls)
            obj._hash = hash(key)
            # Element class bound to this particular field, so every element
            # reaches its parent through the ``field`` class attribute.
            obj.dtype = type('FracElement', (FracElement,), {'field': obj})
            obj.symbols = symbols
            obj.ngens = ngens
            obj.domain = domain
            obj.order = order

            # ``dtype(...)`` without a denominator reads ``field.ring``, which
            # needs ``symbols``/``domain``/``order`` to be set already.
            obj.zero = obj.dtype(ring.zero)
            obj.one = obj.dtype(ring.one)

            obj.gens = obj._gens()

            obj.rep = str(domain) + '(' + ','.join(map(str, symbols)) + ')'

            # Expose generators as attributes named after their symbols
            # (e.g. ``K.x``) without shadowing any existing attribute.
            for symbol, generator in zip(obj.symbols, obj.gens):
                if isinstance(symbol, Symbol):
                    name = symbol.name

                    if not hasattr(obj, name):
                        setattr(obj, name, generator)

            _field_cache[key] = obj

        return obj

    def __getnewargs_ex__(self):
        """Arguments for ``__new__`` when unpickling: domain, symbols and order."""
        return (self.domain, self.symbols), {'order': self.order}

    @property
    def characteristic(self):
        """Characteristic of the ground domain."""
        return self.domain.characteristic

    def _gens(self):
        """Return the generators of this field as a tuple of field elements."""
        return tuple(self.dtype(gen) for gen in self.ring.gens)

    def __hash__(self):
        return self._hash

    def __eq__(self, other):
        # Fields are cached in ``__new__``, so identity is sufficient.
        return self is other

    def clone(self, symbols=None, domain=None, order=None):
        """Return a field with any of ``symbols``, ``domain``, ``order`` replaced.

        Falsy arguments fall back to this field's own values.
        """
        return self.__class__(domain or self.domain, symbols or self.symbols, order or self.order)

    def __ne__(self, other):
        return self is not other

    def raw_new(self, numer, denom=None):
        """Create an element from ``numer``/``denom`` without cancelling them."""
        return self.dtype(numer, denom)

    def domain_new(self, element):
        """Convert ``element`` into the ground domain."""
        return self.domain.convert(element)

    def ground_new(self, element):
        """Embed a ground element into this field.

        If ``element`` does not convert into the polynomial ring's ground
        domain, but the domain is not a field and has an associated
        ``field`` of fractions, the element is converted there and split into
        numerator and denominator.  Otherwise raises NotImplementedError.
        """
        try:
            return self(self.ring.ground_new(element))
        except CoercionFailed:
            domain = self.domain

            if not domain.is_Field and hasattr(domain, 'field'):
                ring = self.ring
                ground_field = domain.field
                element = ground_field.convert(element)
                numer = ring.ground_new(element.numerator)
                denom = ring.ground_new(element.denominator)
                return self.raw_new(numer, denom)
            else:
                raise NotImplementedError

    def __call__(self, element):
        """Convert ``element`` into an element of this field.

        Accepted inputs:

        * an element of this field (returned unchanged);
        * a polynomial (its ground denominators are cleared and it is moved
          into this field's polynomial ring);
        * a ``(numer, denom)`` pair (both converted by the ring, then
          cancelled);
        * an expression (via ``convert``);
        * anything else is treated as a ground element (``ground_new``).

        Raises NotImplementedError for elements of other fraction fields and
        for strings.
        """
        if isinstance(element, FracElement):
            if self == element.field:
                return element
            else:
                raise NotImplementedError('conversion')
        elif isinstance(element, PolyElement):
            denom, numer = element.clear_denoms()
            numer = numer.set_ring(self.ring)
            denom = self.ring.ground_new(denom)
            return self.raw_new(numer, denom)
        elif isinstance(element, tuple) and len(element) == 2:
            numer, denom = list(map(self.ring.__call__, element))
            numer, denom = numer.cancel(denom)
            return self.raw_new(numer, denom)
        elif isinstance(element, str):
            raise NotImplementedError('parsing')
        elif isinstance(element, Expr):
            return self.convert(element)
        else:
            return self.ground_new(element)

    def from_expr(self, expr):
        """Convert a (sympifiable) expression into an element of this field.

        Symbols of the field map to its generators.  Sums, products and
        integer powers are rebuilt recursively with field arithmetic; every
        other subexpression is converted through the ground domain (or its
        field of fractions when the domain is not a field).

        Raises ValueError if some part cannot be converted.
        """
        expr = sympify(expr)
        domain = self.domain
        mapping = dict(zip(self.symbols, self.gens))

        def _rebuild(expr):
            if (generator := mapping.get(expr)) is not None:
                return generator
            elif expr.is_Add:
                return functools.reduce(operator.add, list(map(_rebuild, expr.args)))
            elif expr.is_Mul:
                return functools.reduce(operator.mul, list(map(_rebuild, expr.args)))
            elif expr.is_Pow:
                # Split the exponent into an integer factor and the rest so
                # that e.g. base**(3*a) becomes (base**a)**3.
                c, a = expr.exp.as_coeff_Mul(rational=True)
                if c.is_Integer and c != 1:
                    return _rebuild(expr.base**a)**int(c)

            if not domain.is_Field and hasattr(domain, 'field'):
                frac = domain.field.convert(expr)
            else:
                frac = domain.convert(expr)

            return self(frac)

        try:
            return _rebuild(expr)
        except CoercionFailed:
            raise ValueError('expected an expression convertible to a '
                             f'rational function in {self}, got {expr}')

    def to_ring(self):
        """Return the polynomial ring with the same symbols, domain and order."""
        return self.domain.poly_ring(*self.symbols, order=self.order)

    def to_expr(self, element):
        """Convert ``element`` to an expression ``numerator/denominator``."""
        ring = self.ring
        return ring.to_expr(element.numerator)/ring.to_expr(element.denominator)

    def _from_PythonIntegerRing(self, a, K0):
        """Convert a number ``a`` from the numeric domain ``K0`` via the ground domain."""
        return self(self.domain.convert(a, K0))
    _from_GMPYIntegerRing = _from_PythonIntegerRing
    _from_PythonRationalField = _from_PythonIntegerRing
    _from_GMPYRationalField = _from_PythonIntegerRing
    _from_RealField = _from_PythonIntegerRing
    _from_ComplexField = _from_PythonIntegerRing

    def _from_PolynomialRing(self, a, K0):
        """Convert a polynomial ``a`` from ``K0``.

        Returns None when conversion fails (presumably interpreted by the
        generic ``convert`` machinery as "cannot convert" -- confirm).
        """
        try:
            return self(a)
        except (CoercionFailed, GeneratorsError):
            return

    def _from_FractionField(self, a, K0):
        """Convert a rational function ``a`` from another fraction field ``K0``.

        Returns None when conversion fails.
        """
        try:
            return a.set_field(self)
        except (CoercionFailed, GeneratorsError):
            return

    @property
    def ring(self):
        """Polynomial ring of numerators and denominators (via ``to_ring``)."""
        return self.to_ring()

    def is_normal(self, a):
        """Check normality of ``a`` via the leading coefficient of its numerator."""
        return self.domain.is_normal(a.numerator.LC)
198
199
# Cache of constructed fraction fields, keyed by
# (class name, symbols, ngens, domain, order); filled in FractionField.__new__
# so that equal parameters always yield the same field object.
_field_cache: dict[tuple, FractionField] = {}
201
202
class FracElement(DomainElement, CantSympify):
    """Element of multivariate distributed rational function field.

    An element is a pair of polynomials ``numerator/denominator`` from
    ``field.ring``.  Results built with :meth:`new` have common factors
    cancelled (via ``PolyElement.cancel``); :meth:`raw_new` stores its
    arguments as given.  The ``field`` attribute is bound on the per-field
    subclass that :class:`FractionField` creates as its ``dtype``.

    See Also
    ========

    FractionField

    """

    def __init__(self, numer, denom=None):
        """Store ``numer/denom``; a missing denominator means ``1``.

        No cancellation is performed here.  Raises ZeroDivisionError for a
        zero denominator.
        """
        if denom is None:
            denom = self.field.ring.one
        elif not denom:
            raise ZeroDivisionError('zero denominator')

        self._numerator = numer
        self._denominator = denom

    def __reduce__(self):
        # Pickle by rebuilding from the (numerator, denominator) pair through
        # the parent field's ``__call__``.
        return self.parent.__call__, ((self.numerator, self.denominator),)

    def raw_new(self, numer, denom):
        """Create an element of the same field without cancelling."""
        return self.__class__(numer, denom)

    def new(self, numer, denom):
        """Create an element of the same field after cancelling common factors."""
        return self.raw_new(*numer.cancel(denom))

    def to_poly(self):
        """Return the numerator; the denominator must be ``1``.

        Raises ValueError otherwise.
        """
        if self.denominator != 1:
            raise ValueError('self.denominator should be 1')
        return self.numerator

    @property
    def numerator(self):
        """Numerator polynomial."""
        return self._numerator

    @property
    def denominator(self):
        """Denominator polynomial."""
        return self._denominator

    @property
    def parent(self):
        """The field this element belongs to."""
        return self.field

    # Hash is computed lazily on first use and then cached on the instance.
    _hash = None

    def __hash__(self):
        _hash = self._hash
        if _hash is None:
            self._hash = _hash = hash((self.field, self.numerator, self.denominator))
        return _hash

    def copy(self):
        """Return a copy with copied numerator and denominator polynomials."""
        return self.raw_new(self.numerator.copy(), self.denominator.copy())

    def set_field(self, new_field):
        """Return ``self`` converted into ``new_field`` (``self`` if unchanged)."""
        if self.field == new_field:
            return self
        else:
            new_ring = new_field.ring
            numer = self.numerator.set_ring(new_ring)
            denom = self.denominator.set_ring(new_ring)
            return new_field((numer, denom))

    def __eq__(self, other):
        # Elements of this field compare numerator and denominator; any other
        # object equals ``self`` only if the denominator is 1 and the
        # numerator equals it.
        if isinstance(other, self.field.dtype):
            return self.numerator == other.numerator and self.denominator == other.denominator
        else:
            return self.numerator == other and self.denominator == 1

    def __bool__(self):
        return bool(self.numerator)

    def __pos__(self):
        return self.raw_new(self.numerator, self.denominator)

    def __neg__(self):
        """Negate all coefficients in ``self``."""
        return self.raw_new(-self.numerator, self.denominator)

    def _extract_ground(self, element):
        """Classify ``element`` as a ground value of this field.

        Returns ``(1, value, None)`` if ``element`` converts into the ground
        domain, ``(-1, numer, denom)`` if it only converts into the ground
        domain's field of fractions, and ``(0, None, None)`` if neither works.
        """
        domain = self.field.domain

        try:
            element = domain.convert(element)
        except CoercionFailed:
            # Fall back to a field of fractions only where one exists (same
            # guard as ``FractionField.ground_new``), so a domain without
            # ``field`` reports "not convertible" instead of AttributeError.
            if domain.is_Field or not hasattr(domain, 'field'):
                return 0, None, None

            ground_field = domain.field

            try:
                element = ground_field.convert(element)
            except CoercionFailed:
                return 0, None, None
            else:
                return -1, element.numerator, element.denominator
        else:
            return 1, element, None

    def __add__(self, other):
        """Add rational functions ``self`` and ``other``.

        The zero-``self`` shortcut applies only to operands from this field;
        polynomial and ground operands always go through conversion, so the
        result is a field element rather than the raw ``other``.
        """
        field = self.field

        if not other:
            return self
        elif isinstance(other, field.dtype):
            if not self:
                return other
            if self.denominator == other.denominator:
                return self.new(self.numerator + other.numerator, self.denominator)
            else:
                return self.new(self.numerator*other.denominator + self.denominator*other.numerator,
                                self.denominator*other.denominator)
        elif isinstance(other, field.ring.dtype):
            return self.new(self.numerator + self.denominator*other, self.denominator)
        else:
            if isinstance(other, FracElement):
                # ``other`` belongs to this field's ground domain (itself a
                # fraction field): handle it as a ground element below.
                if isinstance(field.domain, FractionField) and field.domain.field == other.field:
                    pass
                # This field is ``other``'s ground domain: let ``other`` lift us.
                elif isinstance(other.field.domain, FractionField) and other.field.domain.field == field:
                    return other.__radd__(self)
                else:
                    return NotImplemented
            elif isinstance(other, PolyElement):
                if isinstance(field.domain, PolynomialRing) and field.domain.ring == other.ring:
                    pass
                else:
                    return other.__radd__(self)

        return self.__radd__(other)

    def __radd__(self, other):
        """Compute ``other + self`` for a ground-domain ``other``."""
        op, other_numer, other_denom = self._extract_ground(other)

        if op == 1:
            return self.new(self.numerator + self.denominator*other_numer, self.denominator)
        elif not op:
            return NotImplemented
        else:
            return self.new(self.numerator*other_denom + self.denominator*other_numer,
                            self.denominator*other_denom)

    def __sub__(self, other):
        """Subtract rational functions ``self`` and ``other``.

        As in :meth:`__add__`, the zero-``self`` shortcut applies only to
        operands from this field so the result is always converted.
        """
        field = self.field

        if not other:
            return self
        elif isinstance(other, field.dtype):
            if not self:
                return -other
            if self.denominator == other.denominator:
                return self.new(self.numerator - other.numerator, self.denominator)
            else:
                return self.new(self.numerator*other.denominator - self.denominator*other.numerator,
                                self.denominator*other.denominator)
        elif isinstance(other, field.ring.dtype):
            return self.new(self.numerator - self.denominator*other, self.denominator)
        else:
            if isinstance(other, FracElement):
                if isinstance(field.domain, FractionField) and field.domain.field == other.field:
                    pass
                elif isinstance(other.field.domain, FractionField) and other.field.domain.field == field:
                    return other.__rsub__(self)
                else:
                    return NotImplemented
            elif isinstance(other, PolyElement):
                if isinstance(field.domain, PolynomialRing) and field.domain.ring == other.ring:
                    pass
                else:
                    return other.__rsub__(self)

        op, other_numer, other_denom = self._extract_ground(other)

        if op == 1:
            return self.new(self.numerator - self.denominator*other_numer, self.denominator)
        elif not op:
            return NotImplemented
        else:
            return self.new(self.numerator*other_denom - self.denominator*other_numer,
                            self.denominator*other_denom)

    def __rsub__(self, other):
        """Compute ``other - self`` for a ground-domain ``other``."""
        op, other_numer, other_denom = self._extract_ground(other)

        if op == 1:
            return self.new(-self.numerator + self.denominator*other_numer, self.denominator)
        elif not op:
            return NotImplemented
        else:
            return self.new(-self.numerator*other_denom + self.denominator*other_numer,
                            self.denominator*other_denom)

    def __mul__(self, other):
        """Multiply rational functions ``self`` and ``other``."""
        field = self.field

        if not self or not other:
            return field.zero
        elif isinstance(other, field.dtype):
            return self.new(self.numerator*other.numerator, self.denominator*other.denominator)
        elif isinstance(other, field.ring.dtype):
            return self.new(self.numerator*other, self.denominator)
        else:
            if isinstance(other, FracElement):
                if isinstance(field.domain, FractionField) and field.domain.field == other.field:
                    pass
                elif isinstance(other.field.domain, FractionField) and other.field.domain.field == field:
                    return other.__rmul__(self)
                else:
                    return NotImplemented
            elif isinstance(other, PolyElement):
                if isinstance(field.domain, PolynomialRing) and field.domain.ring == other.ring:
                    pass
                else:
                    return other.__rmul__(self)

        return self.__rmul__(other)

    def __rmul__(self, other):
        """Compute ``other * self`` for a ground-domain ``other``."""
        op, other_numer, other_denom = self._extract_ground(other)

        if op == 1:
            return self.new(self.numerator*other_numer, self.denominator)
        elif not op:
            return NotImplemented
        else:
            return self.new(self.numerator*other_numer, self.denominator*other_denom)

    def __truediv__(self, other):
        """Computes quotient of fractions ``self`` and ``other``.

        Raises ZeroDivisionError if ``other`` is zero.
        """
        field = self.field

        if not other:
            raise ZeroDivisionError
        elif isinstance(other, field.dtype):
            return self.new(self.numerator*other.denominator, self.denominator*other.numerator)
        elif isinstance(other, field.ring.dtype):
            return self.new(self.numerator, self.denominator*other)
        else:
            if isinstance(other, FracElement):
                if isinstance(field.domain, FractionField) and field.domain.field == other.field:
                    pass
                elif isinstance(other.field.domain, FractionField) and other.field.domain.field == field:
                    return other.__rtruediv__(self)
                else:
                    return NotImplemented
            elif isinstance(other, PolyElement):
                if isinstance(field.domain, PolynomialRing) and field.domain.ring == other.ring:
                    pass
                else:
                    return NotImplemented

        op, other_numer, other_denom = self._extract_ground(other)

        if op == 1:
            return self.new(self.numerator, self.denominator*other_numer)
        elif not op:
            return NotImplemented
        else:
            return self.new(self.numerator*other_denom, self.denominator*other_numer)

    def __rtruediv__(self, other):
        """Compute ``other / self`` for a polynomial or ground ``other``.

        Raises ZeroDivisionError if ``self`` is zero.
        """
        if not self:
            raise ZeroDivisionError
        elif isinstance(other, self.field.ring.dtype):
            return self.new(self.denominator*other, self.numerator)

        op, other_numer, other_denom = self._extract_ground(other)

        if op == 1:
            return self.new(self.denominator*other_numer, self.numerator)
        elif not op:
            return NotImplemented
        else:
            return self.new(self.denominator*other_numer, self.numerator*other_denom)

    def __pow__(self, n):
        """Raise ``self`` to an integer power ``n``.

        Negative powers invert ``self`` first; raises ZeroDivisionError for
        a zero base with a negative exponent.
        """
        if n >= 0:
            # Powers of a cancelled fraction stay cancelled.
            return self.raw_new(self.numerator**n, self.denominator**n)
        elif not self:
            raise ZeroDivisionError
        else:
            # Swapping numerator and denominator with ``raw_new`` would skip
            # the normalization ``cancel`` applies to every other result
            # (e.g. to the denominator's sign), so invert through ``new``.
            return self.new(self.denominator, self.numerator)**-n

    def diff(self, x):
        """Computes partial derivative in ``x``.

        Examples
        ========

        >>> _, x, y, z = field('x y z', ZZ)
        >>> ((x**2 + y)/(z + 1)).diff(x)
        2*x/(z + 1)

        """
        x = x.to_poly()
        # Quotient rule: (N'D - ND')/D**2.
        return self.new(self.numerator.diff(x)*self.denominator -
                        self.numerator*self.denominator.diff(x), self.denominator**2)

    def __call__(self, *values):
        """Evaluate at ``values`` for the leading generators of the field.

        Raises ValueError unless 1 <= len(values) <= field.ngens.
        """
        if 0 < len(values) <= self.field.ngens:
            return self.eval(list(zip(self.field.gens, values)))
        else:
            raise ValueError(f'expected at least 1 and at most {self.field.ngens} values, got {len(values)}')

    def eval(self, x, a=None):
        """Evaluate at ``x = a`` or at a list of ``(generator, value)`` pairs.

        Returns the evaluated numerator when the evaluated denominator is
        the ground ``1``; otherwise a fraction in the field of the remaining
        variables (or in this field when the result is a ground value).
        """
        if isinstance(x, list) and a is None:
            x = [(X.to_poly(), a) for X, a in x]
            numer, denom = self.numerator.eval(x), self.denominator.eval(x)
        else:
            x = x.to_poly()
            numer, denom = self.numerator.eval(x, a), self.denominator.eval(x, a)

        if self._extract_ground(denom) == (1, 1, None):
            return numer
        if isinstance(numer, PolyElement):
            field = numer.ring.field
        else:
            field = self.field
        return field((field.ring(numer), field.ring(denom)))

    def compose(self, x, a=None):
        """Computes the functional composition.

        Substitutes values for generators, either as ``compose(X, a)`` for
        a single generator ``X`` or as ``compose([(X1, a1), (X2, a2), ...])``.
        Generators that are not mentioned are left unchanged.  Each value is
        converted with ``field(...)``, so elements of this field, polynomials
        and ground elements are accepted.

        The substitution is carried out with exact field arithmetic, so it
        is correct for arbitrary (non-monomial) numerators and denominators.

        Raises ValueError if some ``X`` is not a generator of the field, and
        ZeroDivisionError if the denominator vanishes after substitution.
        """
        field = self.field

        if isinstance(x, list) and a is None:
            replacements = x
        else:
            replacements = [(x, a)]

        # Start from the identity substitution and overwrite the requested
        # generators by position.
        values = list(field.gens)
        for X, value in replacements:
            values[field.gens.index(X)] = field(value)

        def substitute(poly):
            # NOTE(review): relies on PolyElement mapping exponent tuples
            # (ordered like ``field.gens``) to ground coefficients.
            result = field.zero
            for monom, coeff in poly.items():
                term = field.ground_new(coeff)
                for value, exp in zip(values, monom):
                    if exp:
                        term *= value**exp
                result += term
            return result

        return substitute(self.numerator)/substitute(self.denominator)
543