1from sympy.core.add import Add
2from sympy.core.assumptions import check_assumptions
3from sympy.core.containers import Tuple
4from sympy.core.compatibility import as_int, is_sequence, ordered
5from sympy.core.exprtools import factor_terms
6from sympy.core.function import _mexpand
7from sympy.core.mul import Mul
8from sympy.core.numbers import Rational
9from sympy.core.numbers import igcdex, ilcm, igcd
10from sympy.core.power import integer_nthroot, isqrt
11from sympy.core.relational import Eq
12from sympy.core.singleton import S
13from sympy.core.symbol import Symbol, symbols
14from sympy.core.sympify import _sympify
15from sympy.functions.elementary.complexes import sign
16from sympy.functions.elementary.integers import floor
17from sympy.functions.elementary.miscellaneous import sqrt
18from sympy.matrices.dense import MutableDenseMatrix as Matrix
19from sympy.ntheory.factor_ import (
20    divisors, factorint, multiplicity, perfect_power)
21from sympy.ntheory.generate import nextprime
22from sympy.ntheory.primetest import is_square, isprime
23from sympy.ntheory.residue_ntheory import sqrt_mod
24from sympy.polys.polyerrors import GeneratorsNeeded
25from sympy.polys.polytools import Poly, factor_list
26from sympy.simplify.simplify import signsimp
27from sympy.solvers.solveset import solveset_real
28from sympy.utilities import default_sort_key, numbered_symbols
29from sympy.utilities.misc import filldedent
30
31
# These names are the public API exported by
# 'from sympy.solvers.diophantine import *'.
__all__ = ['diophantine', 'classify_diop']
34
35
class DiophantineSolutionSet(set):
    """
    Container for a set of solutions to a particular diophantine equation.

    The base representation is a set of tuples representing each of the solutions.

    Parameters
    ==========

    symbols : list
        List of free symbols in the original equation.
    parameters : list
        List of parameters to be used in the solution.

    Examples
    ========

    Adding solutions:

        >>> from sympy.solvers.diophantine.diophantine import DiophantineSolutionSet
        >>> from sympy.abc import x, y, t, u
        >>> s1 = DiophantineSolutionSet([x, y], [t, u])
        >>> s1
        set()
        >>> s1.add((2, 3))
        >>> s1.add((-1, u))
        >>> s1
        {(-1, u), (2, 3)}
        >>> s2 = DiophantineSolutionSet([x, y], [t, u])
        >>> s2.add((3, 4))
        >>> s1.update(*s2)
        >>> s1
        {(-1, u), (2, 3), (3, 4)}

    Conversion of solutions into dicts:

        >>> list(s1.dict_iterator())
        [{x: -1, y: u}, {x: 2, y: 3}, {x: 3, y: 4}]

    Substituting values:

        >>> s3 = DiophantineSolutionSet([x, y], [t, u])
        >>> s3.add((t**2, t + u))
        >>> s3
        {(t**2, t + u)}
        >>> s3.subs({t: 2, u: 3})
        {(4, 5)}
        >>> s3.subs(t, -1)
        {(1, u - 1)}
        >>> s3.subs(t, 3)
        {(9, u + 3)}

    Evaluation at specific values. Positional arguments are given in the same order as the parameters:

        >>> s3(-2, 3)
        {(4, 1)}
        >>> s3(5)
        {(25, u + 5)}
        >>> s3(None, 2)
        {(t**2, t + 2)}
    """

    def __init__(self, symbols_seq, parameters):
        super().__init__()

        if not is_sequence(symbols_seq):
            raise ValueError("Symbols must be given as a sequence.")

        if not is_sequence(parameters):
            raise ValueError("Parameters must be given as a sequence.")

        self.symbols = tuple(symbols_seq)
        self.parameters = tuple(parameters)

    def add(self, solution):
        """
        Add a single solution, stored as a SymPy ``Tuple``.

        Raises ValueError if the solution's length differs from the number
        of symbols.
        """
        if len(solution) != len(self.symbols):
            raise ValueError("Solution should have a length of %s, not %s" % (len(self.symbols), len(solution)))
        super().add(Tuple(*solution))

    def update(self, *solutions):
        """
        Add every positional argument as a separate solution via ``add``.

        Unlike ``set.update`` the solutions are passed individually, e.g.
        ``s.update(*other)``, not as one iterable.
        """
        for solution in solutions:
            self.add(solution)

    def dict_iterator(self):
        """Yield each solution as a ``{symbol: value}`` dict, in ``ordered`` order."""
        for solution in ordered(self):
            yield dict(zip(self.symbols, solution))

    def subs(self, *args, **kwargs):
        """Return a new solution set with ``subs(*args, **kwargs)`` applied to every solution."""
        result = DiophantineSolutionSet(self.symbols, self.parameters)
        for solution in self:
            result.add(solution.subs(*args, **kwargs))
        return result

    def __call__(self, *args):
        """
        Substitute values for the parameters, positionally.

        A value of ``None`` leaves the corresponding parameter untouched.
        Raises ValueError if more values than parameters are given.
        """
        if len(args) > len(self.parameters):
            raise ValueError("Evaluation should have at most %s values, not %s" % (len(self.parameters), len(args)))
        # Drop ``None`` placeholders explicitly instead of relying on ``subs``
        # silently skipping pairs it cannot sympify. A list (not a dict) keeps
        # the substitutions in parameter order.
        replacements = [(param, value) for param, value in zip(self.parameters, args)
                        if value is not None]
        return self.subs(replacements)
133
134
class DiophantineEquationType:
    """
    Internal representation of a particular diophantine equation type.

    Subclasses set ``name`` and override ``matches`` and ``solve``.

    Parameters
    ==========

    equation :
        The diophantine equation that is being solved. It is sympified and
        expanded with ``expand(force=True)`` before being stored.
    free_symbols : list (optional)
        The symbols being solved for. Defaults to the free symbols of the
        expanded equation, sorted with ``default_sort_key``.

    Attributes
    ==========

    total_degree :
        The maximum of the degrees of all terms in the equation.
    homogeneous :
        True when the equation has no constant (degree 0) term.
    homogeneous_order :
        True when no term of the equation is a bare symbol being solved
        for, i.e. there are no degree-1 terms in those symbols.
    dimension :
        The number of symbols being solved for.

    Raises
    ======

    ValueError
        If there are no symbols to solve for.
    TypeError
        If a coefficient of the expanded equation is not an integer.
    """
    name = None  # type: str

    def __init__(self, equation, free_symbols=None):
        expanded = _sympify(equation).expand(force=True)
        self.equation = expanded

        # Without explicit symbols, solve for every free symbol in a
        # deterministic, sorted order.
        self.free_symbols = (
            sorted(expanded.free_symbols, key=default_sort_key)
            if free_symbols is None else free_symbols)
        if not self.free_symbols:
            raise ValueError('equation should have 1 or more free symbols')

        coefficients = expanded.as_coefficients_dict()
        self.coeff = coefficients
        if any(not _is_int(value) for value in coefficients.values()):
            raise TypeError("Coefficients should be Integers")

        self.total_degree = Poly(expanded).total_degree()
        self.homogeneous = 1 not in coefficients
        self.homogeneous_order = set(coefficients).isdisjoint(self.free_symbols)
        self.dimension = len(self.free_symbols)
        # Generated lazily by the ``parameters`` property.
        self._parameters = None

    def matches(self):
        """
        Determine whether the given equation can be matched to the particular equation type.

        The base class matches nothing; subclasses override this.
        """
        return False

    @property
    def n_parameters(self):
        """Number of parameters used for parametric solutions (the dimension by default)."""
        return self.dimension

    @property
    def parameters(self):
        """Parameter symbols; integer ``t_0, t_1, ...`` unless set via ``pre_solve``."""
        cached = self._parameters
        if cached is None:
            cached = symbols('t_:%i' % (self.n_parameters,), integer=True)
            self._parameters = cached
        return cached

    def solve(self, parameters=None, limit=None) -> DiophantineSolutionSet:
        """Solve the equation; every concrete subclass must override this."""
        raise NotImplementedError('No solver has been written for %s.' % self.name)

    def pre_solve(self, parameters=None):
        """
        Validate the equation type and ``parameters``, then store the parameters.

        Raises ValueError if the equation does not match this type, or if
        ``parameters`` is given with a length other than ``n_parameters``.
        Passing None clears any stored parameters so defaults are generated.
        """
        if not self.matches():
            raise ValueError("This equation does not match the %s equation type." % self.name)

        if parameters is not None and len(parameters) != self.n_parameters:
            raise ValueError("Expected %s parameter(s) but got %s" % (self.n_parameters, len(parameters)))

        self._parameters = parameters
211
212
class Univariate(DiophantineEquationType):
    """
    Representation of a univariate diophantine equation.

    A univariate diophantine equation is an equation of the form
    `a_{0} + a_{1}x + a_{2}x^2 + .. + a_{n}x^n = 0` where `a_{0}, a_{1}, ..a_{n}` are
    integer constants and `x` is an integer variable.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import Univariate
    >>> from sympy.abc import x
    >>> Univariate((x - 2)*(x - 3)**2).solve() # solves equation (x - 2)*(x - 3)**2 == 0
    {(2,), (3,)}

    """

    name = 'univariate'

    def matches(self):
        # Exactly one unknown is being solved for.
        return self.dimension == 1

    def solve(self, parameters=None, limit=None):
        """Return the integer roots as 1-tuples; ``limit`` is not used."""
        self.pre_solve(parameters)

        solutions = DiophantineSolutionSet(self.free_symbols, parameters=self.parameters)
        unknown = self.free_symbols[0]
        # Solve over the reals, then keep only the integer roots.
        real_roots = solveset_real(self.equation, unknown)
        for root in real_roots.intersect(S.Integers):
            solutions.add((root,))
        return solutions
243
244
class Linear(DiophantineEquationType):
    """
    Representation of a linear diophantine equation.

    A linear diophantine equation is an equation of the form `a_{1}x_{1} +
    a_{2}x_{2} + .. + a_{n}x_{n} + c = 0` where `a_{1}, a_{2}, ..a_{n}` and `c`
    are integer constants and `x_{1}, x_{2}, ..x_{n}` are integer variables.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import Linear
    >>> from sympy.abc import x, y, z
    >>> l1 = Linear(2*x - 3*y - 5)
    >>> l1.matches() # is this equation linear
    True
    >>> l1.solve() # solves equation 2*x - 3*y - 5 == 0
    {(3*t_0 - 5, 2*t_0 - 5)}

    Here x = 3*t_0 - 5 and y = 2*t_0 - 5

    >>> Linear(2*x - 3*y - 4*z -3).solve()
    {(t_0, 2*t_0 + 4*t_1 + 3, -t_0 - 3*t_1 - 3)}

    """

    name = 'linear'

    def matches(self):
        # Every term has degree at most 1.
        return self.total_degree == 1

    def solve(self, parameters=None, limit=None):
        """
        Return a solution set holding at most one (parametric) solution tuple.

        An empty set is returned when the equation has no integer solution.
        ``limit`` is not used.
        """
        self.pre_solve(parameters)

        coeff = self.coeff
        var = self.free_symbols

        if 1 in coeff:
            # negate coeff[] because input is of the form: ax + by + c ==  0
            #                              but is used as: ax + by     == -c
            c = -coeff[1]
        else:
            c = 0

        result = DiophantineSolutionSet(var, parameters=self.parameters)
        params = result.parameters

        # One unknown: a*x == c is solvable iff a divides c.
        if len(var) == 1:
            q, r = divmod(c, coeff[var[0]])
            if not r:
                result.add((q,))
                return result
            else:
                return result

        '''
        base_solution_linear() can solve diophantine equations of the form:

        a*x + b*y == c

        We break down multivariate linear diophantine equations into a
        series of bivariate linear diophantine equations which can then
        be solved individually by base_solution_linear().

        Consider the following:

        a_0*x_0 + a_1*x_1 + a_2*x_2 == c

        which can be re-written as:

        a_0*x_0 + g_0*y_0 == c

        where

        g_0 == gcd(a_1, a_2)

        and

        y == (a_1*x_1)/g_0 + (a_2*x_2)/g_0

        This leaves us with two binary linear diophantine equations.
        For the first equation:

        a == a_0
        b == g_0
        c == c

        For the second:

        a == a_1/g_0
        b == a_2/g_0
        c == the solution we find for y_0 in the first equation.

        The arrays A and B are the arrays of integers used for
        'a' and 'b' in each of the n-1 bivariate equations we solve.
        '''

        A = [coeff[v] for v in var]
        B = []
        if len(var) > 2:
            # Fold gcds from the right: B[i] is the gcd that stands in for
            # all coefficients to the right of A[i]; A entries are divided
            # by the gcd they were folded into.
            B.append(igcd(A[-2], A[-1]))
            A[-2] = A[-2] // B[0]
            A[-1] = A[-1] // B[0]
            for i in range(len(A) - 3, 0, -1):
                gcd = igcd(B[0], A[i])
                B[0] = B[0] // gcd
                A[i] = A[i] // gcd
                B.insert(0, gcd)
        B.append(A[-1])

        '''
        Consider the trivariate linear equation:

        4*x_0 + 6*x_1 + 3*x_2 == 2

        This can be re-written as:

        4*x_0 + 3*y_0 == 2

        where

        y_0 == 2*x_1 + x_2
        (Note that gcd(3, 6) == 3)

        The complete integral solution to this equation is:

        x_0 ==  2 + 3*t_0
        y_0 == -2 - 4*t_0

        where 't_0' is any integer.

        Now that we have a solution for 'x_0', find 'x_1' and 'x_2':

        2*x_1 + x_2 == -2 - 4*t_0

        We can then solve for '-2' and '-4' independently,
        and combine the results:

        2*x_1a + x_2a == -2
        x_1a == 0 + t_0
        x_2a == -2 - 2*t_0

        2*x_1b + x_2b == -4*t_0
        x_1b == 0*t_0 + t_1
        x_2b == -4*t_0 - 2*t_1

        ==>

        x_1 == t_0 + t_1
        x_2 == -2 - 6*t_0 - 2*t_1

        where 't_0' and 't_1' are any integers.

        Note that:

        4*(2 + 3*t_0) + 6*(t_0 + t_1) + 3*(-2 - 6*t_0 - 2*t_1) == 2

        for any integral values of 't_0', 't_1'; as required.

        This method is generalised for many variables, below.

        '''
        solutions = []
        for i in range(len(B)):
            tot_x, tot_y = [], []

            # c is the right-hand side for this stage: an integer on the
            # first pass, afterwards an expression in earlier parameters.
            # Each additive term is solved separately and the results summed.
            for j, arg in enumerate(Add.make_args(c)):
                if arg.is_Integer:
                    # example: 5 -> k = 5
                    k, p = arg, S.One
                    pnew = params[0]
                else:  # arg is a Mul or Symbol
                    # example: 3*t_1 -> k = 3
                    # example: t_0 -> k = 1
                    k, p = arg.as_coeff_Mul()
                    # introduce the parameter that follows p
                    pnew = params[params.index(p) + 1]

                sol = sol_x, sol_y = base_solution_linear(k, A[i], B[i], pnew)

                if p is S.One:
                    # no solution for the constant part -> no solution at all
                    if None in sol:
                        return result
                else:
                    # convert a + b*pnew -> a*p + b*pnew
                    if isinstance(sol_x, Add):
                        sol_x = sol_x.args[0]*p + sol_x.args[1]
                    if isinstance(sol_y, Add):
                        sol_y = sol_y.args[0]*p + sol_y.args[1]

                tot_x.append(sol_x)
                tot_y.append(sol_y)

            solutions.append(Add(*tot_x))
            # the y-part becomes the right-hand side of the next stage
            c = Add(*tot_y)

        # the final right-hand side is the value of the last variable
        solutions.append(c)
        result.add(solutions)
        return result
443
444
class BinaryQuadratic(DiophantineEquationType):
    """
    Representation of a binary quadratic diophantine equation.

    A binary quadratic diophantine equation is an equation of the
    form `Ax^2 + Bxy + Cy^2 + Dx + Ey + F = 0`, where `A, B, C, D, E,
    F` are integer constants and `x` and `y` are integer variables.

    Examples
    ========

    >>> from sympy.abc import x, y
    >>> from sympy.solvers.diophantine.diophantine import BinaryQuadratic
    >>> b1 = BinaryQuadratic(x**3 + y**2 + 1)
    >>> b1.matches()
    False
    >>> b2 = BinaryQuadratic(x**2 + y**2 + 2*x + 2*y + 2)
    >>> b2.matches()
    True
    >>> b2.solve()
    {(-1, -1)}

    References
    ==========

    .. [1] Methods to solve Ax^2 + Bxy + Cy^2 + Dx + Ey + F = 0, [online],
          Available: http://www.alpertron.com.ar/METHODS.HTM
    .. [2] Solving the equation ax^2+ bxy + cy^2 + dx + ey + f= 0, [online],
          Available: https://web.archive.org/web/20160323033111/http://www.jpr2718.org/ax2p.pdf

    """

    name = 'binary_quadratic'

    def matches(self):
        return self.total_degree == 2 and self.dimension == 2

    def solve(self, parameters=None, limit=None) -> DiophantineSolutionSet:
        """
        Return the integer solutions as a ``DiophantineSolutionSet``.

        ``parameters``, when given, must contain two symbols; they are used
        in parametric families of solutions. ``limit`` is not used.

        The coefficients select one of four methods:

        1. ``A == C == 0`` and ``B != 0``: simple hyperbolic case;
        2. ``B**2 - 4*A*C == 0``: parabolic case;
        3. ``B**2 - 4*A*C`` a nonzero perfect square;
        4. otherwise: reduction to a generalized Pell equation.
        """
        self.pre_solve(parameters)

        var = self.free_symbols
        coeff = self.coeff

        x, y = var

        A = coeff[x**2]
        B = coeff[x*y]
        C = coeff[y**2]
        D = coeff[x]
        E = coeff[y]
        F = coeff[S.One]

        A, B, C, D, E, F = [as_int(i) for i in _remove_gcd(A, B, C, D, E, F)]

        # ``result.parameters`` is always a sequence: either the caller's
        # parameters (stored by ``pre_solve``) or generated defaults. Helpers
        # below get it instead of the raw ``parameters`` argument, which may
        # be None.
        result = DiophantineSolutionSet(var, self.parameters)
        params = result.parameters
        t, u = params

        discr = B**2 - 4*A*C

        # (1) Simple-Hyperbolic case: A = C = 0, B != 0
        # In this case equation can be converted to (Bx + E)(By + D) = DE - BF
        # We consider two cases; DE - BF = 0 and DE - BF != 0
        # More details, http://www.alpertron.com.ar/METHODS.HTM#SHyperb
        if A == 0 and C == 0 and B != 0:

            if D*E - B*F == 0:
                q, r = divmod(E, B)
                if not r:
                    result.add((-q, t))
                q, r = divmod(D, B)
                if not r:
                    result.add((t, -q))
            else:
                div = divisors(D*E - B*F)
                div = div + [-term for term in div]
                for d in div:
                    x0, r = divmod(d - E, B)
                    if not r:
                        q, r = divmod(D*E - B*F, d)
                        if not r:
                            y0, r = divmod(q - D, B)
                            if not r:
                                result.add((x0, y0))

        # (2) Parabolic case: B**2 - 4*A*C = 0
        # There are two subcases to be considered in this case.
        # sqrt(c)D - sqrt(a)E = 0 and sqrt(c)D - sqrt(a)E != 0
        # More Details, http://www.alpertron.com.ar/METHODS.HTM#Parabol
        elif discr == 0:

            if A == 0:
                # Swap roles of x and y so that the x**2 coefficient is nonzero.
                s = BinaryQuadratic(self.equation, free_symbols=[y, x]).solve(parameters=[t, u])
                for soln in s:
                    result.add((soln[1], soln[0]))

            else:
                g = sign(A)*igcd(A, C)
                a = A // g
                c = C // g
                # sign(B/A) == sign(A*B) for A != 0; the integer product avoids
                # float true division, which can overflow or underflow for
                # very large coefficients.
                e = sign(A*B)

                sqa = isqrt(a)
                sqc = isqrt(c)
                _c = e*sqc*D - sqa*E
                if not _c:
                    z = symbols("z", real=True)
                    eq = sqa*g*z**2 + D*z + sqa*F
                    roots = solveset_real(eq, z).intersect(S.Integers)
                    for root in roots:
                        ans = diop_solve(sqa*x + e*sqc*y - root)
                        result.add((ans[0], ans[1]))

                elif _is_int(c):
                    solve_x = lambda w: -e*sqc*g*_c*t**2 - (E + 2*e*sqc*g*w)*t \
                                        - (e*sqc*g*w**2 + E*w + e*sqc*F) // _c

                    solve_y = lambda w: sqa*g*_c*t**2 + (D + 2*sqa*g*w)*t \
                                        + (sqa*g*w**2 + D*w + sqa*F) // _c

                    for z0 in range(0, abs(_c)):
                        # Check if the coefficients of y and x obtained are integers or not
                        if (divisible(sqa*g*z0**2 + D*z0 + sqa*F, _c) and
                            divisible(e*sqc*g*z0**2 + E*z0 + e*sqc*F, _c)):
                            result.add((solve_x(z0), solve_y(z0)))

        # (3) Method used when B**2 - 4*A*C is a square, is described in p. 6 of the below paper
        # by John P. Robertson.
        # https://web.archive.org/web/20160323033111/http://www.jpr2718.org/ax2p.pdf
        elif is_square(discr):
            if A != 0:
                r = sqrt(discr)
                # Python names differ from ``u`` so the parameter is not
                # shadowed; the SymPy symbols are still named u and v.
                sym_u, sym_v = symbols("u, v", integer=True)
                eq = _mexpand(
                    4*A*r*sym_u*sym_v + 4*A*D*(B*sym_v + r*sym_u + r*sym_v - B*sym_u) +
                    2*A*4*A*E*(sym_u - sym_v) + 4*A*r*4*A*F)

                solution = diop_solve(eq, t)

                for s0, t0 in solution:

                    num = B*t0 + r*s0 + r*t0 - B*s0
                    x_0 = S(num) / (4*A*r)
                    y_0 = S(s0 - t0) / (2*r)
                    if isinstance(s0, Symbol) or isinstance(t0, Symbol):
                        # Compute once; check_param needs a parameter
                        # sequence, never None.
                        ans = check_param(x_0, y_0, 4*A*r, params)
                        if len(ans) > 0:
                            result.update(*ans)
                    elif x_0.is_Integer and y_0.is_Integer:
                        if is_solution_quad(var, coeff, x_0, y_0):
                            result.add((x_0, y_0))

            else:
                s = BinaryQuadratic(self.equation, free_symbols=var[::-1]).solve(parameters=[t, u])  # Interchange x and y
                while s:
                    result.add(s.pop()[::-1])  # and solution <--------+

        # (4) B**2 - 4*A*C > 0 and B**2 - 4*A*C not a square or B**2 - 4*A*C < 0
        else:

            P, Q = _transformation_to_DN(var, coeff)
            # From here on D is the Pell-equation D, not the coefficient of x.
            D, N = _find_DN(var, coeff)
            solns_pell = diop_DN(D, N)

            if D < 0:
                # Loop names differ from x, y so the free symbols are not shadowed.
                for x0, y0 in solns_pell:
                    for sx in [-x0, x0]:
                        for sy in [-y0, y0]:
                            s = P*Matrix([sx, sy]) + Q
                            try:
                                result.add([as_int(_) for _ in s])
                            except ValueError:
                                # transformed point is not integral
                                pass
            else:
                # In this case equation can be transformed into a Pell equation

                solns_pell = set(solns_pell)
                for X, Y in list(solns_pell):
                    solns_pell.add((-X, -Y))

                a = diop_DN(D, 1)
                T = a[0][0]
                U = a[0][1]

                if all(_is_int(_) for _ in P[:4] + Q[:2]):
                    for r, s in solns_pell:
                        _a = (r + s*sqrt(D))*(T + U*sqrt(D))**t
                        _b = (r - s*sqrt(D))*(T - U*sqrt(D))**t
                        x_n = _mexpand(S(_a + _b) / 2)
                        y_n = _mexpand(S(_a - _b) / (2*sqrt(D)))
                        s = P*Matrix([x_n, y_n]) + Q
                        result.add(s)

                else:
                    L = ilcm(*[_.q for _ in P[:4] + Q[:2]])

                    k = 1

                    T_k = T
                    U_k = U

                    # Smallest power of the fundamental solution that is
                    # congruent to (1, 0) modulo L.
                    while (T_k - 1) % L != 0 or U_k % L != 0:
                        T_k, U_k = T_k*T + D*U_k*U, T_k*U + U_k*T
                        k += 1

                    for X, Y in solns_pell:

                        for i in range(k):
                            if all(_is_int(_) for _ in P*Matrix([X, Y]) + Q):
                                _a = (X + sqrt(D)*Y)*(T_k + sqrt(D)*U_k)**t
                                _b = (X - sqrt(D)*Y)*(T_k - sqrt(D)*U_k)**t
                                Xt = S(_a + _b) / 2
                                Yt = S(_a - _b) / (2*sqrt(D))
                                s = P*Matrix([Xt, Yt]) + Q
                                result.add(s)

                            X, Y = X*T + D*U*Y, X*U + Y*T

        return result
666
667
class InhomogeneousTernaryQuadratic(DiophantineEquationType):
    """

    Representation of an inhomogeneous ternary quadratic.

    Matches degree-2 equations in three unknowns that have no constant
    term but do contain at least one linear term in those unknowns.

    No solver is currently implemented for this equation type.

    """

    name = 'inhomogeneous_ternary_quadratic'

    def matches(self):
        is_ternary_quadratic = self.total_degree == 2 and self.dimension == 3
        return is_ternary_quadratic and self.homogeneous and not self.homogeneous_order
685
686
class HomogeneousTernaryQuadraticNormal(DiophantineEquationType):
    """
    Representation of a homogeneous ternary quadratic normal diophantine equation.

    This is an equation of the form `ax^2 + by^2 + cz^2 = 0`: exactly three
    nonzero terms, which are the squares of the three unknowns.

    Examples
    ========

    >>> from sympy.abc import x, y, z
    >>> from sympy.solvers.diophantine.diophantine import HomogeneousTernaryQuadraticNormal
    >>> HomogeneousTernaryQuadraticNormal(4*x**2 - 5*y**2 + z**2).solve()
    {(1, 2, 4)}

    """

    name = 'homogeneous_ternary_quadratic_normal'

    def matches(self):
        if not (self.total_degree == 2 and self.dimension == 3):
            return False
        if not self.homogeneous:
            return False
        if not self.homogeneous_order:
            return False

        # only the three pure squares may appear
        nonzero = [k for k in self.coeff if self.coeff[k]]
        return len(nonzero) == 3 and all(i**2 in nonzero for i in self.free_symbols)

    def solve(self, parameters=None, limit=None) -> DiophantineSolutionSet:
        """
        Return a set holding at most one solution tuple.

        The set is empty when the solvability checks below fail.
        ``limit`` is not used.
        """
        self.pre_solve(parameters)

        var = self.free_symbols
        coeff = self.coeff

        x, y, z = var

        a = coeff[x**2]
        b = coeff[y**2]
        c = coeff[z**2]

        # Square-free normalisation of the coefficients; the three groups
        # returned by sqf_normal(..., steps=True) are used below to rebuild a
        # solution of the original equation from one of the reduced one.
        (sqf_of_a, sqf_of_b, sqf_of_c), (a_1, b_1, c_1), (a_2, b_2, c_2) = \
            sqf_normal(a, b, c, steps=True)

        A = -a_2*c_2
        B = -b_2*c_2

        result = DiophantineSolutionSet(var, parameters=self.parameters)

        # If following two conditions are satisfied then there are no solutions
        if A < 0 and B < 0:
            return result

        # Each product must be a quadratic residue modulo the remaining
        # coefficient; otherwise there is no nontrivial solution.
        if (
            sqrt_mod(-b_2*c_2, a_2) is None or
            sqrt_mod(-c_2*a_2, b_2) is None or
            sqrt_mod(-a_2*b_2, c_2) is None):
            return result

        z_0, x_0, y_0 = descent(A, B)

        # Clear the denominator of z_0 and scale x_0, y_0 to match.
        z_0, q = _rational_pq(z_0, abs(c_2))
        x_0 *= q
        y_0 *= q

        x_0, y_0, z_0 = _remove_gcd(x_0, y_0, z_0)

        # Holzer reduction
        if sign(a) == sign(b):
            x_0, y_0, z_0 = holzer(x_0, y_0, z_0, abs(a_2), abs(b_2), abs(c_2))
        elif sign(a) == sign(c):
            x_0, z_0, y_0 = holzer(x_0, z_0, y_0, abs(a_2), abs(c_2), abs(b_2))
        else:
            y_0, z_0, x_0 = holzer(y_0, z_0, x_0, abs(b_2), abs(c_2), abs(a_2))

        # Undo the square-free normalisation.
        x_0 = reconstruct(b_1, c_1, x_0)
        y_0 = reconstruct(a_1, c_1, y_0)
        z_0 = reconstruct(a_1, b_1, z_0)

        sq_lcm = ilcm(sqf_of_a, sqf_of_b, sqf_of_c)

        x_0 = abs(x_0*sq_lcm // sqf_of_a)
        y_0 = abs(y_0*sq_lcm // sqf_of_b)
        z_0 = abs(z_0*sq_lcm // sqf_of_c)

        result.add(_remove_gcd(x_0, y_0, z_0))
        return result
772
773
class HomogeneousTernaryQuadratic(DiophantineEquationType):
    """
    Representation of a homogeneous ternary quadratic diophantine equation.

    Matches degree-2 equations in three unknowns with no constant and no
    linear terms that are not of the pure-squares "normal" form handled by
    ``HomogeneousTernaryQuadraticNormal``.

    Examples
    ========

    >>> from sympy.abc import x, y, z
    >>> from sympy.solvers.diophantine.diophantine import HomogeneousTernaryQuadratic
    >>> HomogeneousTernaryQuadratic(x**2 + y**2 - 3*z**2 + x*y).solve()
    {(-1, 2, 1)}
    >>> HomogeneousTernaryQuadratic(3*x**2 + y**2 - 3*z**2 + 5*x*y + y*z).solve()
    {(3, 12, 13)}

    """

    name = 'homogeneous_ternary_quadratic'

    def matches(self):
        if not (self.total_degree == 2 and self.dimension == 3):
            return False
        if not self.homogeneous:
            return False
        if not self.homogeneous_order:
            return False

        # exclude the normal form (only the three pure squares present)
        nonzero = [k for k in self.coeff if self.coeff[k]]
        return not (len(nonzero) == 3 and all(i**2 in nonzero for i in self.free_symbols))

    def solve(self, parameters=None, limit=None):
        """
        Return a set holding at most one solution tuple (empty if none is found).

        Most branches permute the variables so the x**2 coefficient is
        nonzero and delegate to ``_diop_ternary_quadratic`` or
        ``_diop_ternary_quadratic_normal``. ``limit`` is not used.
        """
        self.pre_solve(parameters)

        _var = self.free_symbols
        coeff = self.coeff

        x, y, z = _var
        # working copy whose order is permuted by the branches below
        var = [x, y, z]

        # Equations of the form B*x*y + C*z*x + E*y*z = 0 and At least two of the
        # coefficients A, B, C are non-zero.
        # There are infinitely many solutions for the equation.
        # Ex: (0, 0, t), (0, t, 0), (t, 0, 0)
        # Equation can be re-written as y*(B*x + E*z) = -C*x*z and we can find rather
        # unobvious solutions. Set y = -C and B*x + E*z = x*z. The latter can be solved by
        # using methods for binary quadratic diophantine equations. Let's select the
        # solution which minimizes |x| + |z|

        result = DiophantineSolutionSet(var, parameters=self.parameters)

        def unpack_sol(sol):
            # first solution of a (possibly empty) solution set, or all-None
            if len(sol) > 0:
                return list(sol)[0]
            return None, None, None

        if not any(coeff[i**2] for i in var):
            if coeff[x*z]:
                sols = diophantine(coeff[x*y]*x + coeff[y*z]*z - x*z)
                s = sols.pop()
                min_sum = abs(s[0]) + abs(s[1])

                # NOTE(review): if a parametric solution comes back here the
                # `<` comparison below may not be decidable - confirm.
                for r in sols:
                    m = abs(r[0]) + abs(r[1])
                    if m < min_sum:
                        s = r
                        min_sum = m

                result.add(_remove_gcd(s[0], -coeff[x*z], s[1]))
                return result

            else:
                # no x*z term: swap x and y and solve
                var[0], var[1] = _var[1], _var[0]
                y_0, x_0, z_0 = unpack_sol(_diop_ternary_quadratic(var, coeff))
                if x_0 is not None:
                    result.add((x_0, y_0, z_0))
                return result

        if coeff[x**2] == 0:
            # If the coefficient of x is zero change the variables
            if coeff[y**2] == 0:
                var[0], var[2] = _var[2], _var[0]
                z_0, y_0, x_0 = unpack_sol(_diop_ternary_quadratic(var, coeff))

            else:
                var[0], var[1] = _var[1], _var[0]
                y_0, x_0, z_0 = unpack_sol(_diop_ternary_quadratic(var, coeff))

        else:
            if coeff[x*y] or coeff[x*z]:
                # Apply the transformation x --> X - (B*y + C*z)/(2*A)
                A = coeff[x**2]
                B = coeff[x*y]
                C = coeff[x*z]
                D = coeff[y**2]
                E = coeff[y*z]
                F = coeff[z**2]

                # coefficients of the transformed equation (scaled by 4*A),
                # which has no x*y or x*z terms
                _coeff = dict()

                _coeff[x**2] = 4*A**2
                _coeff[y**2] = 4*A*D - B**2
                _coeff[z**2] = 4*A*F - C**2
                _coeff[y*z] = 4*A*E - 2*B*C
                _coeff[x*y] = 0
                _coeff[x*z] = 0

                x_0, y_0, z_0 = unpack_sol(_diop_ternary_quadratic(var, _coeff))

                if x_0 is None:
                    return result

                # map X back to x and clear the denominator 2*A
                p, q = _rational_pq(B*y_0 + C*z_0, 2*A)
                x_0, y_0, z_0 = x_0*q - p, y_0*q, z_0*q

            elif coeff[z*y] != 0:
                if coeff[y**2] == 0:
                    if coeff[z**2] == 0:
                        # Equations of the form A*x**2 + E*yz = 0.
                        A = coeff[x**2]
                        E = coeff[y*z]

                        # with b/a == -E/A, (b, a, b) satisfies A*b**2 + E*a*b == 0
                        b, a = _rational_pq(-E, A)

                        x_0, y_0, z_0 = b, a, b

                    else:
                        # Ax**2 + E*y*z + F*z**2  = 0
                        var[0], var[2] = _var[2], _var[0]
                        z_0, y_0, x_0 = unpack_sol(_diop_ternary_quadratic(var, coeff))

                else:
                    # A*x**2 + D*y**2 + E*y*z + F*z**2 = 0, C may be zero
                    var[0], var[1] = _var[1], _var[0]
                    y_0, x_0, z_0 = unpack_sol(_diop_ternary_quadratic(var, coeff))

            else:
                # Ax**2 + D*y**2 + F*z**2 = 0, C may be zero
                x_0, y_0, z_0 = unpack_sol(_diop_ternary_quadratic_normal(var, coeff))

        if x_0 is None:
            return result

        result.add(_remove_gcd(x_0, y_0, z_0))
        return result
917
918
class InhomogeneousGeneralQuadratic(DiophantineEquationType):
    """
    Representation of an inhomogeneous general quadratic in three or more
    variables.

    No solver is currently implemented for this equation type.
    """

    name = 'inhomogeneous_general_quadratic'

    def matches(self):
        """Return whether the equation is an inhomogeneous quadratic form."""
        is_multivariate_quadratic = (self.total_degree == 2
                                     and self.dimension >= 3)
        if not is_multivariate_quadratic:
            return False
        if not self.homogeneous_order:
            # terms of mixed degree: inhomogeneous by construction
            return True
        # keys are either Pow (x**2) or Mul (x*y); only equations that
        # carry cross terms are classified here
        has_cross_terms = any(k.is_Mul for k in self.coeff)
        return has_cross_terms and not self.homogeneous
940
941
class HomogeneousGeneralQuadratic(DiophantineEquationType):
    """
    Representation of a homogeneous general quadratic in three or more
    variables that contains at least one cross term (such as ``x*y``).

    No solver is currently implemented for this equation type.
    """

    name = 'homogeneous_general_quadratic'

    def matches(self):
        """Return whether the equation is a homogeneous quadratic form."""
        is_multivariate_quadratic = (self.total_degree == 2
                                     and self.dimension >= 3)
        if not is_multivariate_quadratic:
            return False
        if not self.homogeneous_order:
            return False
        # keys are either Pow (x**2) or Mul (x*y); cross terms are required
        has_cross_terms = any(k.is_Mul for k in self.coeff)
        return has_cross_terms and self.homogeneous
963
964
class GeneralSumOfSquares(DiophantineEquationType):
    r"""
    Representation of the diophantine equation

    `x_{1}^2 + x_{2}^2 + . . . + x_{n}^2 - k = 0`.

    Details
    =======

    When `n = 3` and `k = 4^a(8m + 7)` for some `a, m \in Z` the equation has
    no solutions. Refer to [1]_ for more details.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import GeneralSumOfSquares
    >>> from sympy.abc import a, b, c, d, e
    >>> GeneralSumOfSquares(a**2 + b**2 + c**2 + d**2 + e**2 - 2345).solve()
    {(15, 22, 22, 24, 24)}

    By default only 1 solution is returned. Use the `limit` keyword for more:

    >>> sorted(GeneralSumOfSquares(a**2 + b**2 + c**2 + d**2 + e**2 - 2345).solve(limit=3))
    [(15, 22, 22, 24, 24), (16, 19, 24, 24, 24), (16, 20, 22, 23, 26)]

    References
    ==========

    .. [1] Representing an integer as a sum of three squares, [online],
        Available:
        http://www.proofwiki.org/wiki/Integer_as_Sum_of_Three_Squares
    """

    name = 'general_sum_of_squares'

    def matches(self):
        """Match homogeneous sums of unit-coefficient squares in 3+ variables."""
        if self.total_degree != 2 or self.dimension < 3:
            return False
        if not self.homogeneous_order:
            return False
        # cross terms such as x*y rule this type out
        if any(k.is_Mul for k in self.coeff):
            return False
        # every non-constant term must have coefficient exactly 1
        return all(self.coeff[k] == 1 for k in self.coeff if k != 1)

    def solve(self, parameters=None, limit=1):
        """
        Return at most ``limit`` solutions produced by ``sum_of_squares``.

        Variables known to be nonpositive receive the negated value.
        """
        self.pre_solve(parameters)

        var = self.free_symbols
        k = -int(self.coeff[1])
        n = self.dimension

        result = DiophantineSolutionSet(var, parameters=self.parameters)

        if k < 0 or limit < 1:
            return result

        signs = [-1 if x.is_nonpositive else 1 for x in var]
        has_negative = -1 in signs

        for count, t in enumerate(sum_of_squares(k, n, zeros=True), start=1):
            if has_negative:
                result.add([signs[i]*j for i, j in enumerate(t)])
            else:
                result.add(t)
            if count == limit:
                break
        return result
1034
1035
class GeneralPythagorean(DiophantineEquationType):
    """
    Representation of the general pythagorean equation,
    `a_{1}^2x_{1}^2 + a_{2}^2x_{2}^2 + . . . + a_{n}^2x_{n}^2 - a_{n + 1}^2x_{n + 1}^2 = 0`.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import GeneralPythagorean
    >>> from sympy.abc import a, b, c, d, e, x, y, z, t
    >>> GeneralPythagorean(a**2 + b**2 + c**2 - d**2).solve()
    {(t_0**2 + t_1**2 - t_2**2, 2*t_0*t_2, 2*t_1*t_2, t_0**2 + t_1**2 + t_2**2)}
    >>> GeneralPythagorean(9*a**2 - 4*b**2 + 16*c**2 + 25*d**2 + e**2).solve(parameters=[x, y, z, t])
    {(-10*t**2 + 10*x**2 + 10*y**2 + 10*z**2, 15*t**2 + 15*x**2 + 15*y**2 + 15*z**2, 15*t*x, 12*t*y, 60*t*z)}
    """

    name = 'general_pythagorean'

    def matches(self):
        """Match homogeneous sums of square-coefficient squares, one sign differing."""
        if not (self.total_degree == 2 and self.dimension >= 3):
            return False
        if not self.homogeneous_order:
            return False
        if any(k.is_Mul for k in self.coeff):
            return False
        # all-ones coefficients belong to GeneralSumOfSquares instead
        if all(self.coeff[k] == 1 for k in self.coeff if k != 1):
            return False
        if not all(is_square(abs(self.coeff[k])) for k in self.coeff):
            return False
        # all but one has the same sign
        # e.g. 4*x**2 + y**2 - 4*z**2
        return abs(sum(sign(self.coeff[k]) for k in self.coeff)) == self.dimension - 2

    @property
    def n_parameters(self):
        """One parameter fewer than the number of variables."""
        return self.dimension - 1

    def solve(self, parameters=None, limit=1):
        """
        Return a one-element solution set holding the parametric solution.

        ``limit`` is accepted for interface compatibility and is not used.
        """
        self.pre_solve(parameters)

        # Work on a copy: the sign normalization below must not rewrite the
        # coefficient table stored on the instance.
        coeff = self.coeff.copy()
        var = self.free_symbols
        n = self.dimension

        # ``matches`` requires all but one coefficient to share a sign, so at
        # least two of the first three carry the majority sign; flip every
        # coefficient if that majority is negative.
        if sign(coeff[var[0] ** 2]) + sign(coeff[var[1] ** 2]) + sign(coeff[var[2] ** 2]) < 0:
            for key in coeff:
                coeff[key] = -coeff[key]

        result = DiophantineSolutionSet(var, parameters=self.parameters)

        # position of the (single) negatively signed variable
        index = 0
        for i, v in enumerate(var):
            if sign(coeff[v ** 2]) == -1:
                index = i

        m = result.parameters

        ith = sum(m_i ** 2 for m_i in m)
        L = [ith - 2 * m[n - 2] ** 2]
        L.extend([2 * m[i] * m[n - 2] for i in range(n - 2)])
        sol = L[:index] + [ith] + L[index:]

        # scale so that dividing by each sqrt(|coefficient|) stays integral
        lcm = 1
        for i, v in enumerate(var):
            if i == index or (index > 0 and i == 0) or (index == 0 and i == 1):
                lcm = ilcm(lcm, sqrt(abs(coeff[v ** 2])))
            else:
                s = sqrt(coeff[v ** 2])
                lcm = ilcm(lcm, s if _odd(s) else s // 2)

        for i, v in enumerate(var):
            sol[i] = (lcm * sol[i]) / sqrt(abs(coeff[v ** 2]))

        result.add(sol)
        return result
1112
1113
class CubicThue(DiophantineEquationType):
    """
    Representation of a cubic Thue diophantine equation.

    A cubic Thue diophantine equation is a polynomial of the form
    `f(x, y) = r` of degree 3, where `x` and `y` are integers
    and `r` is a rational number.

    ``matches`` only checks that the equation has total degree 3 and exactly
    two variables, so any bivariate cubic is accepted (see the example).

    No solver is currently implemented for this equation type.

    Examples
    ========

    >>> from sympy.abc import x, y
    >>> from sympy.solvers.diophantine.diophantine import CubicThue
    >>> c1 = CubicThue(x**3 + y**2 + 1)
    >>> c1.matches()
    True

    """

    name = 'cubic_thue'

    def matches(self):
        """Return True for bivariate equations of total degree 3."""
        if self.total_degree != 3:
            return False
        return self.dimension == 2
1139
1140
class GeneralSumOfEvenPowers(DiophantineEquationType):
    """
    Representation of the diophantine equation

    `x_{1}^e + x_{2}^e + . . . + x_{n}^e - k = 0`

    where `e` is an even, integer power.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import GeneralSumOfEvenPowers
    >>> from sympy.abc import a, b
    >>> GeneralSumOfEvenPowers(a**4 + b**4 - (2**4 + 3**4)).solve()
    {(2, 3)}

    """

    name = 'general_sum_of_even_powers'

    def matches(self):
        """Match sums of unit-coefficient ``x**e`` terms with even ``e > 3``."""
        if not self.total_degree > 3:
            return False
        if self.total_degree % 2 != 0:
            return False
        if not all(k.is_Pow and k.exp == self.total_degree for k in self.coeff if k != 1):
            return False
        return all(self.coeff[k] == 1 for k in self.coeff if k != 1)

    def solve(self, parameters=None, limit=1):
        """
        Return at most ``limit`` solutions produced by ``power_representation``.

        Variables known to be nonpositive receive the negated value.
        """
        self.pre_solve(parameters)

        var = self.free_symbols
        coeff = self.coeff

        # exponent of the power terms (the last nonzero Pow key seen)
        p = None
        for q in coeff.keys():
            if q.is_Pow and coeff[q]:
                p = q.exp

        k = len(var)
        n = -coeff[1]

        result = DiophantineSolutionSet(var, parameters=self.parameters)

        if n < 0 or limit < 1:
            return result

        # named ``signs`` so the module-level sympy ``sign`` function is not
        # shadowed inside this method (matches GeneralSumOfSquares.solve)
        signs = [-1 if x.is_nonpositive else 1 for x in var]
        negs = -1 in signs

        took = 0
        for t in power_representation(n, p, k):
            if negs:
                result.add([signs[i]*j for i, j in enumerate(t)])
            else:
                result.add(t)
            took += 1
            if took == limit:
                break
        return result
1202
# these types are known (but not necessarily handled)
# note that order is important here (in the current solver state)
# ``classify_diop`` and ``_diop_solve`` walk this list from the top and use
# the first class whose ``matches()`` returns True, so when several classes
# could accept an equation the earlier entry wins.
all_diop_classes = [
    Linear,
    Univariate,
    BinaryQuadratic,
    InhomogeneousTernaryQuadratic,
    HomogeneousTernaryQuadraticNormal,
    HomogeneousTernaryQuadratic,
    InhomogeneousGeneralQuadratic,
    HomogeneousGeneralQuadratic,
    GeneralSumOfSquares,
    GeneralPythagorean,
    CubicThue,
    GeneralSumOfEvenPowers,
]

# the ``name`` strings of every class above (used for membership checks)
diop_known = {diop_class.name for diop_class in all_diop_classes}
1221
1222
def _is_int(i):
    """
    Return True if ``as_int`` accepts ``i`` as an integer, else False.

    The previous version fell off the end on failure and returned None;
    an explicit False keeps the predicate strictly boolean.
    """
    try:
        as_int(i)
    except ValueError:
        return False
    return True
1229
1230
1231def _sorted_tuple(*i):
1232    return tuple(sorted(i))
1233
1234
def _remove_gcd(*x):
    """
    Divide every entry of ``x`` by the gcd of all entries.

    Returns the original tuple ``x`` when the gcd is 1, otherwise a new
    tuple of ``i//g`` values.

    If ``igcd`` rejects the entries with ``ValueError`` (presumably because
    some entry is not an integer, e.g. a symbolic expression), the gcd is
    taken over the extracted numeric content of the nonzero entries instead;
    with fewer than two nonzero entries ``x`` is returned unchanged.
    A ``TypeError`` from ``igcd`` is re-raised with a usage hint.
    """
    try:
        g = igcd(*x)
    except ValueError:
        # zero entries do not affect the gcd, so drop them
        fx = list(filter(None, x))
        if len(fx) < 2:
            return x
        # ``as_content_primitive()[0]`` is each entry's numeric content
        g = igcd(*[i.as_content_primitive()[0] for i in fx])
    except TypeError:
        raise TypeError('_remove_gcd(a,b,c) or _remove_gcd(*container)')
    if g == 1:
        return x
    return tuple([i//g for i in x])
1248
1249
def _rational_pq(a, b):
    """
    Return ``(numer, denom)`` representing ``a/b`` in lowest terms.

    The sign is carried by the numerator and the denominator is
    nonnegative; common factors are removed with ``_remove_gcd``.
    """
    numer = sign(b)*a
    denom = abs(b)
    return _remove_gcd(numer, denom)
1253
1254
1255def _nint_or_floor(p, q):
1256    # return nearest int to p/q; in case of tie return floor(p/q)
1257    w, r = divmod(p, q)
1258    if abs(r) <= abs(q)//2:
1259        return w
1260    return w + 1
1261
1262
1263def _odd(i):
1264    return i % 2 != 0
1265
1266
1267def _even(i):
1268    return i % 2 == 0
1269
1270
def diophantine(eq, param=symbols("t", integer=True), syms=None,
                permute=False):
    """
    Simplify the solution procedure of diophantine equation ``eq`` by
    converting it into a product of terms which should equal zero.

    Explanation
    ===========

    For example, when solving, `x^2 - y^2 = 0` this is treated as
    `(x + y)(x - y) = 0` and `x + y = 0` and `x - y = 0` are solved
    independently and combined. Each term is solved by calling
    ``diop_solve()``. (Although it is possible to call ``diop_solve()``
    directly, one must be careful to pass an equation in the correct
    form and to interpret the output correctly; ``diophantine()`` is
    the public-facing function to use in general.)

    Output of ``diophantine()`` is a set of tuples. The elements of the
    tuple are the solutions for each variable in the equation and
    are arranged according to the alphabetic ordering of the variables.
    e.g. For an equation with two variables, `a` and `b`, the first
    element of the tuple is the solution for `a` and the second for `b`.

    Usage
    =====

    ``diophantine(eq, t, syms)``: Solve the diophantine
    equation ``eq``.
    ``t`` is the optional parameter to be used by ``diop_solve()``.
    ``syms`` is an optional list of symbols which determines the
    order of the elements in the returned tuple. Repeated symbols and
    symbols that do not occur in ``eq`` are ignored; symbols of ``eq``
    that are not listed follow the listed ones in the default order.

    By default, only the base solution is returned. If ``permute`` is set to
    True then permutations of the base solution and/or permutations of the
    signs of the values will be returned when applicable. Signs are only
    permuted for equations without linear terms, since flipping the sign of
    a variable that appears linearly does not, in general, give another
    solution.

    Examples
    ========

    >>> from sympy.solvers.diophantine import diophantine
    >>> from sympy.abc import a, b
    >>> eq = a**4 + b**4 - (2**4 + 3**4)
    >>> diophantine(eq)
    {(2, 3)}
    >>> diophantine(eq, permute=True)
    {(-3, -2), (-3, 2), (-2, -3), (-2, 3), (2, -3), (2, 3), (3, -2), (3, 2)}

    Details
    =======

    ``eq`` should be an expression which is assumed to be zero.
    ``t`` is the parameter to be used in the solution.

    Examples
    ========

    >>> from sympy.abc import x, y, z
    >>> diophantine(x**2 - y**2)
    {(t_0, -t_0), (t_0, t_0)}

    >>> diophantine(x*(2*x + 3*y - z))
    {(0, n1, n2), (t_0, t_1, 2*t_0 + 3*t_1)}
    >>> diophantine(x**2 + 3*x*y + 4*x)
    {(0, n1), (3*t_0 - 4, -t_0)}

    See Also
    ========

    diop_solve()
    sympy.utilities.iterables.permute_signs
    sympy.utilities.iterables.signed_permutations
    """

    from sympy.utilities.iterables import (
        subsets, permute_signs, signed_permutations)

    eq = _sympify(eq)

    if isinstance(eq, Eq):
        eq = eq.lhs - eq.rhs

    try:
        var = list(eq.expand(force=True).free_symbols)
        var.sort(key=default_sort_key)
        if syms:
            if not is_sequence(syms):
                raise TypeError(
                    'syms should be given as a sequence, e.g. a list')
            # keep the requested order without duplicates, then append any
            # symbol of ``eq`` that was not mentioned so no value is lost
            syms = list(dict.fromkeys(i for i in syms if i in var))
            syms += [i for i in var if i not in syms]
            if syms != var:
                # solutions of the recursive call are ordered like ``var``;
                # position j of each returned tuple must hold syms[j]
                var_index = {s: i for i, s in enumerate(var)}
                return {tuple([sol[var_index[s]] for s in syms])
                        for sol in diophantine(eq, param, permute=permute)}
        n, d = eq.as_numer_denom()
        if n.is_number:
            return set()
        if not d.is_number:
            # solutions of the numerator that do not zero the denominator
            dsol = diophantine(d, param=param, permute=permute)
            good = diophantine(n, param=param, permute=permute) - dsol
            return {s for s in good if _mexpand(d.subs(zip(var, s)))}
        else:
            eq = n
        eq = factor_terms(eq)
        assert not eq.is_number
        eq = eq.as_independent(*var, as_Add=False)[1]
        p = Poly(eq)
        assert not any(g.is_number for g in p.gens)
        eq = p.as_expr()
        assert eq.is_polynomial()
    except (GeneratorsNeeded, AssertionError):
        raise TypeError(filldedent('''
    Equation should be a polynomial with Rational coefficients.'''))

    # permute only sign
    do_permute_signs = False
    # permute sign and values
    do_permute_signs_var = False
    # permute signs, keeping the sign of the product of the first two values
    permute_few_signs = False
    try:
        v, c, t = classify_diop(eq)
    except (TypeError, NotImplementedError):
        # unrecognized as a whole: fall back to factoring below
        t = None
    else:
        if permute:
            if t in (GeneralSumOfSquares.name, GeneralSumOfEvenPowers.name):
                do_permute_signs_var = True
            elif t in (HomogeneousTernaryQuadratic.name,
                       HomogeneousTernaryQuadraticNormal.name,
                       BinaryQuadratic.name) and len(v) in (2, 3):
                # A cross term (e.g. x*y) ties the signs of its variables
                # together and a linear term (e.g. x) fixes the sign of its
                # variable; so *any* such term restricts sign flipping.
                cross = any(c.get(a*b) for a, b in subsets(v, 2))
                linear = any(c.get(a) for a in v)
                if not (cross or linear):
                    # only squares and a constant: every sign is free
                    do_permute_signs = True
                elif not linear:
                    # e.g. (x_val, y_val) => (x_val, y_val), (-x_val, -y_val)
                    permute_few_signs = True

    if t == GeneralSumOfSquares.name:
        # trying to factor such expressions will sometimes hang
        terms = [(eq, 1)]
    else:
        fl = factor_list(eq)
        if fl[0].is_Rational and fl[0] != 1:
            return diophantine(eq/fl[0], param=param, syms=syms, permute=permute)
        terms = fl[1]

    sols = set()

    for term in terms:

        base, _ = term
        var_t, _, eq_type = classify_diop(base, _dict=False)
        _, base = signsimp(base, evaluate=False).as_coeff_Mul()
        solution = diop_solve(base, param)

        if eq_type in [
                Linear.name,
                HomogeneousTernaryQuadratic.name,
                HomogeneousTernaryQuadraticNormal.name,
                GeneralPythagorean.name]:
            sols.add(merge_solution(var, var_t, solution))

        elif eq_type in [
                BinaryQuadratic.name,
                GeneralSumOfSquares.name,
                GeneralSumOfEvenPowers.name,
                Univariate.name]:
            for sol in solution:
                sols.add(merge_solution(var, var_t, sol))

        else:
            raise NotImplementedError('unhandled type: %s' % eq_type)

    # remove null merge results
    sols.discard(())
    null = tuple([0]*len(var))
    # if there is no solution, return trivial solution
    if not sols and eq.subs(zip(var, null)).is_zero:
        sols.add(null)

    final_soln = set()
    for sol in sols:
        if not all(_is_int(s) for s in sol):
            # parametric solutions are never permuted
            final_soln.add(sol)
        elif do_permute_signs:
            final_soln.update(permute_signs(sol))
        elif permute_few_signs:
            final_soln.update(
                s for s in permute_signs(sol) if s[0]*s[1] == sol[0]*sol[1])
        elif do_permute_signs_var:
            final_soln.update(signed_permutations(sol))
        else:
            final_soln.add(sol)
    return final_soln
1533
1534
def merge_solution(var, var_t, solution):
    """
    This is used to construct the full solution from the solutions of sub
    equations.

    Explanation
    ===========

    For example when solving the equation `(x - y)(x^2 + y^2 - z^2) = 0`,
    solutions for each of the equations `x - y = 0` and `x^2 + y^2 - z^2` are
    found independently. Solutions for `x - y = 0` are `(x, y) = (t, t)`. But
    we should introduce a value for z when we output the solution for the
    original equation. This function converts `(t, t)` into `(t, t, n_{1})`
    where `n_{1}` is an integer parameter.

    An empty tuple is returned when ``solution`` contains None or when a
    merged value contradicts the assumptions of its symbol.
    """
    if None in solution:
        return ()

    values = iter(solution)
    # fresh integer symbols n1, n2, ... for variables absent from var_t
    fresh = numbered_symbols("n", integer=True, start=1)
    merged = [next(values) if v in var_t else next(fresh) for v in var]

    if any(check_assumptions(val, **symb.assumptions0) is False
           for val, symb in zip(merged, var)):
        return tuple()

    return tuple(merged)
1568
1569
def _diop_solve(eq, params=None):
    """
    Solve ``eq`` with the first class in ``all_diop_classes`` that matches.

    The matched instance is reused for solving instead of building (and
    re-expanding ``eq`` for) a second one. Returns None if no class matches.
    """
    for diop_class in all_diop_classes:
        diop_type = diop_class(eq)
        if diop_type.matches():
            return diop_type.solve(parameters=params)
1574
1575
def diop_solve(eq, param=symbols("t", integer=True)):
    """
    Solves the diophantine equation ``eq``.

    Explanation
    ===========

    Unlike ``diophantine()``, factoring of ``eq`` is not attempted. Uses
    ``classify_diop()`` to determine the type of the equation and calls
    the appropriate solver function.

    Use of ``diophantine()`` is recommended over other helper functions.
    ``diop_solve()`` can return either a set or a tuple depending on the
    nature of the equation.

    Usage
    =====

    ``diop_solve(eq, t)``: Solve diophantine equation, ``eq`` using ``t``
    as a parameter if needed.

    Details
    =======

    ``eq`` should be an expression which is assumed to be zero.
    ``t`` is a parameter to be used in the solution.

    Examples
    ========

    >>> from sympy.solvers.diophantine import diop_solve
    >>> from sympy.abc import x, y, z, w
    >>> diop_solve(2*x + 3*y - 5)
    (3*t_0 - 5, 5 - 2*t_0)
    >>> diop_solve(4*x + 3*y - 4*z + 5)
    (t_0, 8*t_0 + 4*t_1 + 5, 7*t_0 + 3*t_1 + 5)
    >>> diop_solve(x + 3*y - 4*z + w - 6)
    (t_0, t_0 + t_1, 6*t_0 + 5*t_1 + 4*t_2 - 6, 5*t_0 + 4*t_1 + 3*t_2 - 6)
    >>> diop_solve(x**2 + y**2 - 5)
    {(-2, -1), (-2, 1), (-1, -2), (-1, 2), (1, -2), (1, 2), (2, -1), (2, 1)}


    See Also
    ========

    diophantine()
    """
    var, coeff, eq_type = classify_diop(eq, _dict=False)

    if eq_type == Linear.name:
        return diop_linear(eq, param)

    elif eq_type == BinaryQuadratic.name:
        return diop_quadratic(eq, param)

    elif eq_type == HomogeneousTernaryQuadratic.name:
        return diop_ternary_quadratic(eq, parameterize=True)

    elif eq_type == HomogeneousTernaryQuadraticNormal.name:
        return diop_ternary_quadratic_normal(eq, parameterize=True)

    elif eq_type == GeneralPythagorean.name:
        return diop_general_pythagorean(eq, param)

    elif eq_type == Univariate.name:
        return diop_univariate(eq)

    elif eq_type == GeneralSumOfSquares.name:
        return diop_general_sum_of_squares(eq, limit=S.Infinity)

    elif eq_type == GeneralSumOfEvenPowers.name:
        return diop_general_sum_of_even_powers(eq, limit=S.Infinity)

    if eq_type is not None and eq_type not in diop_known:
        raise ValueError(filldedent('''
    Although this type of equation was identified, it is not yet
    handled. It should, however, be listed in `diop_known`.
    Developers should see comments at the end of
    `classify_diop`.
            '''))  # pragma: no cover
    # recognized type without a dedicated solver (e.g. cubic_thue)
    raise NotImplementedError(
        'No solver has been written for %s.' % eq_type)
1659
1660
def classify_diop(eq, _dict=True):
    # docstring supplied externally (assigned after this definition)

    # the first class in ``all_diop_classes`` that matches determines the type
    for diop_class in all_diop_classes:
        diop_type = diop_class(eq)
        if diop_type.matches():
            coeff = dict(diop_type.coeff) if _dict else diop_type.coeff
            return diop_type.free_symbols, coeff, diop_type.name

    # new diop type instructions
    # --------------------------
    # if this error raises and the equation *can* be classified,
    #  * it should be identified in the if-block above
    #  * the type should be added to the diop_known
    # if a solver can be written for it,
    #  * a dedicated handler should be written (e.g. diop_linear)
    #  * it should be passed to that handler in diop_solve
    raise NotImplementedError(filldedent('''
        This equation is not yet recognized or else has not been
        simplified sufficiently to put it in a form recognized by
        classify_diop().'''))
1687
1688
# Python 3 reads a function's docstring from ``__doc__``; ``func_doc`` was
# the Python 2 spelling and is now just an ordinary attribute, so assigning
# only to it left ``help(classify_diop)`` empty.
classify_diop.__doc__ = (
    '''
    Helper routine used by diop_solve() to find information about ``eq``.

    Explanation
    ===========

    Returns a tuple ``(variables, coefficients, type)`` describing the
    diophantine equation. Variables (free symbols) are returned as a list
    and coefficients are returned as a dict with the key being the
    respective term and the constant term is keyed to 1. The type is one
    of the following:

    * %s

    Usage
    =====

    ``classify_diop(eq)``: Return variables, coefficients and type of the
    ``eq``.

    Details
    =======

    ``eq`` should be an expression which is assumed to be zero.
    ``_dict`` is for internal use: when True (default) a dict is returned,
    otherwise a defaultdict which supplies 0 for missing keys is returned.

    Examples
    ========

    >>> from sympy.solvers.diophantine import classify_diop
    >>> from sympy.abc import x, y, z, w, t
    >>> classify_diop(4*x + 6*y - 4)
    ([x, y], {1: -4, x: 4, y: 6}, 'linear')
    >>> classify_diop(x + 3*y -4*z + 5)
    ([x, y, z], {1: 5, x: 1, y: 3, z: -4}, 'linear')
    >>> classify_diop(x**2 + y**2 - x*y + x + 5)
    ([x, y], {1: 5, x: 1, x**2: 1, y**2: 1, x*y: -1}, 'binary_quadratic')
    ''' % ('\n    * '.join(sorted(diop_known))))
# kept as an alias for any code that still reads the old attribute
classify_diop.func_doc = classify_diop.__doc__  # type: ignore
1729
1730
def diop_linear(eq, param=symbols("t", integer=True)):
    """
    Solves linear diophantine equations.

    A linear diophantine equation is an equation of the form `a_{1}x_{1} +
    a_{2}x_{2} + .. + a_{n}x_{n} = 0` where `a_{1}, a_{2}, ..a_{n}` are
    integer constants and `x_{1}, x_{2}, ..x_{n}` are integer variables.

    Usage
    =====

    ``diop_linear(eq)``: Returns a tuple containing solutions to the
    diophantine equation ``eq``. Values in the tuple are arranged in the same
    order as the sorted variables.

    Details
    =======

    ``eq`` is a linear diophantine equation which is assumed to be zero.
    ``param`` is the parameter to be used in the solution. If ``param`` is
    None, every parameter is replaced by 0, giving one particular solution.
    When the equation has no integer solution a tuple containing one
    ``None`` per variable is returned; if ``eq`` is not linear the result
    is None.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import diop_linear
    >>> from sympy.abc import x, y, z
    >>> diop_linear(2*x - 3*y - 5) # solves equation 2*x - 3*y - 5 == 0
    (3*t_0 - 5, 2*t_0 - 5)

    Here x = 3*t_0 - 5 and y = 2*t_0 - 5

    >>> diop_linear(2*x - 3*y - 4*z -3)
    (t_0, 2*t_0 + 4*t_1 + 3, -t_0 - 3*t_1 - 3)

    See Also
    ========

    diop_quadratic(), diop_ternary_quadratic(), diop_general_pythagorean(),
    diop_general_sum_of_squares()
    """
    var, coeff, diop_type = classify_diop(eq, _dict=False)

    if diop_type != Linear.name:
        return None

    parameters = None
    if param is not None:
        parameters = symbols('%s_0:%i' % (param, len(var)), integer=True)

    result = Linear(eq).solve(parameters=parameters)

    if param is None:
        # substitute 0 for every parameter to obtain a particular solution
        result = result(*[0]*len(result.parameters))

    if len(result) > 0:
        return next(iter(result))
    # no integer solution: one placeholder per variable, as documented
    return tuple([None]*len(var))
1787
1788
def base_solution_linear(c, a, b, t=None):
    """
    Return the base solution for the linear equation, `ax + by = c`.

    Explanation
    ===========

    Used by ``diop_linear()`` to find the base solution of a linear
    Diophantine equation. If ``t`` is given then the parametrized solution is
    returned. ``(None, None)`` is returned when no integer solution exists.

    Usage
    =====

    ``base_solution_linear(c, a, b, t)``: ``a``, ``b``, ``c`` are coefficients
    in `ax + by = c` and ``t`` is the parameter to be used in the solution.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import base_solution_linear
    >>> from sympy.abc import t
    >>> base_solution_linear(5, 2, 3) # equation 2*x + 3*y = 5
    (-5, 5)
    >>> base_solution_linear(0, 5, 7) # equation 5*x + 7*y = 0
    (0, 0)
    >>> base_solution_linear(5, 2, 3, t) # equation 2*x + 3*y = 5
    (3*t - 5, 5 - 2*t)
    >>> base_solution_linear(0, 5, 7, t) # equation 5*x + 7*y = 0
    (7*t, -5*t)
    """
    a, b, c = _remove_gcd(a, b, c)

    if c == 0:
        # homogeneous case: (0, 0) or the one-parameter family
        if t is None:
            return (0, 0)
        if b < 0:
            t = -t
        return (b*t, -a*t)

    # Bezout coefficients for |a| and |b|, then restore the signs
    x0, y0, d = igcdex(abs(a), abs(b))
    x0 *= sign(a)
    y0 *= sign(b)

    if not divisible(c, d):
        return (None, None)
    if t is None:
        return (c*x0, c*y0)
    if b < 0:
        t = -t
    return (c*x0 + b*t, c*y0 - a*t)
1844
1845
def diop_univariate(eq):
    """
    Solves a univariate diophantine equation.

    Explanation
    ===========

    A univariate diophantine equation is an equation of the form
    `a_{0} + a_{1}x + a_{2}x^2 + .. + a_{n}x^n = 0` where `a_{0}, a_{1}, .., a_{n}` are
    integer constants and `x` is an integer variable.

    Usage
    =====

    ``diop_univariate(eq)``: Returns a set containing solutions to the
    diophantine equation ``eq``.

    Details
    =======

    ``eq`` is a univariate diophantine equation which is assumed to be zero.
    Its real roots are computed with ``solveset_real`` and only the integer
    ones are kept, each wrapped in a 1-tuple. ``None`` is returned when
    ``eq`` is not classified as univariate.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import diop_univariate
    >>> from sympy.abc import x
    >>> diop_univariate((x - 2)*(x - 3)**2) # solves equation (x - 2)*(x - 3)**2 == 0
    {(2,), (3,)}

    """
    free_vars, _, eq_type = classify_diop(eq, _dict=False)
    if eq_type != Univariate.name:
        return None

    integer_roots = solveset_real(eq, free_vars[0]).intersect(S.Integers)
    return {(int(root),) for root in integer_roots}
1882
1883
def divisible(a, b):
    """
    Return `True` when ``a`` leaves no remainder on division by ``b``,
    and `False` otherwise.
    """
    if a % b:
        return False
    return True
1889
1890
def diop_quadratic(eq, param=symbols("t", integer=True)):
    """
    Solves quadratic diophantine equations.

    i.e. equations of the form `Ax^2 + Bxy + Cy^2 + Dx + Ey + F = 0`. Returns a
    set of tuples `(x, y)` built from the solutions produced by
    ``BinaryQuadratic.solve``. ``None`` is returned when ``eq`` is not
    classified as a binary quadratic.

    Usage
    =====

    ``diop_quadratic(eq, param)``: ``eq`` is a quadratic binary diophantine
    equation. ``param`` is used to indicate the parameter to be used in the
    solution.

    Details
    =======

    ``eq`` should be an expression which is assumed to be zero.
    ``param`` is a parameter to be used in the solution; when it is given,
    an integer symbol ``u`` is used as the second parameter.

    Examples
    ========

    >>> from sympy.abc import x, y, t
    >>> from sympy.solvers.diophantine.diophantine import diop_quadratic
    >>> diop_quadratic(x**2 + y**2 + 2*x + 2*y + 2, t)
    {(-1, -1)}

    References
    ==========

    .. [1] Methods to solve Ax^2 + Bxy + Cy^2 + Dx + Ey + F = 0, [online],
          Available: http://www.alpertron.com.ar/METHODS.HTM
    .. [2] Solving the equation ax^2+ bxy + cy^2 + dx + ey + f= 0, [online],
          Available: https://web.archive.org/web/20160323033111/http://www.jpr2718.org/ax2p.pdf

    See Also
    ========

    diop_linear(), diop_ternary_quadratic(), diop_general_sum_of_squares(),
    diop_general_pythagorean()
    """
    _, _, eq_type = classify_diop(eq, _dict=False)
    if eq_type != BinaryQuadratic.name:
        return None

    parameters = None if param is None else [param, Symbol("u", integer=True)]
    return set(BinaryQuadratic(eq).solve(parameters=parameters))
1942
1943
def is_solution_quad(var, coeff, u, v):
    """
    Check whether `(u, v)` is solution to the quadratic binary diophantine
    equation with the variable list ``var`` and coefficient dictionary
    ``coeff``.

    ``coeff`` maps each monomial in ``var`` to its coefficient. ``u`` and
    ``v`` are substituted for the variables, the terms are summed and the
    expanded result is compared with zero.

    Not intended for use by normal users.
    """
    substitution = dict(zip(var, (u, v)))
    terms = [coefficient*monomial.xreplace(substitution)
             for monomial, coefficient in coeff.items()]
    return _mexpand(Add(*terms)) == 0
1955
1956
def diop_DN(D, N, t=symbols("t", integer=True)):
    """
    Solves the equation `x^2 - Dy^2 = N`.

    Explanation
    ===========

    Mainly concerned with the case `D > 0, D` is not a perfect square,
    which is the same as the generalized Pell equation. The LMM
    algorithm [1]_ is used to solve this equation.

    Returns one solution tuple, (`x, y)` for each class of the solutions.
    Other solutions of the class can be constructed according to the
    values of ``D`` and ``N``. An empty list is returned when there are no
    solutions.

    Usage
    =====

    ``diop_DN(D, N, t)``: D and N are integers as in `x^2 - Dy^2 = N` and
    ``t`` is the parameter to be used in the solutions.

    Details
    =======

    ``D`` and ``N`` correspond to D and N in the equation.
    ``t`` is the parameter to be used in the solutions.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import diop_DN
    >>> diop_DN(13, -4) # Solves equation x**2 - 13*y**2 = -4
    [(3, 1), (393, 109), (36, 10)]

    The output can be interpreted as follows: There are three fundamental
    solutions to the equation `x^2 - 13y^2 = -4` given by (3, 1), (393, 109)
    and (36, 10). Each tuple is in the form (x, y), i.e. solution (3, 1) means
    that `x = 3` and `y = 1`.

    >>> diop_DN(986, 1) # Solves equation x**2 - 986*y**2 = 1
    [(49299, 1570)]

    See Also
    ========

    find_DN(), diop_bf_DN()

    References
    ==========

    .. [1] Solving the generalized Pell equation x**2 - D*y**2 = N, John P.
        Robertson, July 31, 2004, Pages 16 - 17. [online], Available:
        https://web.archive.org/web/20160323033128/http://www.jpr2718.org/pell.pdf
    """
    if D < 0:
        if N == 0:
            return [(0, 0)]
        elif N < 0:
            return []
        elif N > 0:
            # x**2 + |D|*y**2 = N: collect primitive solutions of every
            # N // d**2 (d**2 | N) and scale them back by d.
            sol = []
            for d in divisors(square_factor(N)):
                sols = cornacchia(1, -D, N // d**2)
                if sols:
                    for x, y in sols:
                        sol.append((d*x, d*y))
                        if D == -1:
                            sol.append((d*y, d*x))
            return sol

    elif D == 0:
        if N < 0:
            return []
        if N == 0:
            return [(0, t)]
        sN, _exact = integer_nthroot(N, 2)
        if _exact:
            return [(sN, t)]
        else:
            return []

    else:  # D > 0
        sD, _exact = integer_nthroot(D, 2)
        if _exact:
            if N == 0:
                return [(sD*t, t)]
            else:
                sol = []

                # y is bounded by (N - 1)/(2*sD) for N > 0 and by
                # (|N| + 1)/(2*sD) for N < 0.
                for y in range(floor(sign(N)*(N - 1)/(2*sD)) + 1):
                    try:
                        sq, _exact = integer_nthroot(D*y**2 + N, 2)
                    except ValueError:
                        _exact = False
                    if _exact:
                        sol.append((sq, y))

                return sol

        elif 1 < N**2 < D:
            # It is much faster to call `_special_diop_DN`.
            return _special_diop_DN(D, N)

        else:
            if N == 0:
                return [(0, 0)]

            elif abs(N) == 1:

                # Expand sqrt(D) until the end of its first period, which is
                # marked by the partial quotient 2*sD.
                pqa = PQa(0, 1, D)
                j = 0
                G = []
                B = []

                for i in pqa:

                    a = i[2]
                    G.append(i[5])
                    B.append(i[4])

                    if j != 0 and a == 2*sD:
                        break
                    j = j + 1

                if _odd(j):

                    if N == -1:
                        x = G[j - 1]
                        y = B[j - 1]
                    else:
                        # Odd period: x**2 - D*y**2 = 1 needs a second period.
                        count = j
                        while count < 2*j - 1:
                            i = next(pqa)
                            G.append(i[5])
                            B.append(i[4])
                            count += 1

                        x = G[count]
                        y = B[count]
                else:
                    if N == 1:
                        x = G[j - 1]
                        y = B[j - 1]
                    else:
                        return []

                return [(x, y)]

            else:

                fs = []
                sol = []
                div = divisors(N)

                for d in div:
                    if divisible(N, d**2):
                        fs.append(d)

                # Solutions of x**2 - D*y**2 = -1; these depend only on D,
                # so they are computed at most once (lazily, on first need).
                neg_pell = None

                for f in fs:
                    m = N // f**2

                    zs = sqrt_mod(D, abs(m), all_roots=True)
                    zs = [i for i in zs if i <= abs(m) // 2 ]

                    if abs(m) != 2:
                        zs = zs + [-i for i in zs if i]  # omit dupl 0

                    for z in zs:

                        pqa = PQa(z, abs(m), D)
                        # Invariant for this z: hoisted out of the loop below
                        # (it expands the whole periodic continued fraction).
                        period = length(z, abs(m), D)
                        j = 0
                        G = []
                        B = []

                        for i in pqa:

                            G.append(i[5])
                            B.append(i[4])

                            if j != 0 and abs(i[1]) == 1:
                                r = G[j-1]
                                s = B[j-1]

                                if r**2 - D*s**2 == m:
                                    sol.append((f*r, f*s))

                                else:
                                    if neg_pell is None:
                                        neg_pell = diop_DN(D, -1)
                                    if neg_pell:
                                        p, q = neg_pell[0]
                                        # Compose with the fundamental solution
                                        # of the negative Pell equation.
                                        sol.append((f*(r*p + q*s*D), f*(r*q + s*p)))

                                break

                            j = j + 1
                            if j == period:
                                break

                return sol
2154
2155
def _special_diop_DN(D, N):
    """
    Solves the equation `x^2 - Dy^2 = N` for the special case where
    `1 < N**2 < D` and `D` is not a perfect square.
    Prefer calling `diop_DN`, which verifies the condition `1 < N**2 < D`
    and only dispatches here when it holds.

    Usage
    =====

    WARNING: Internal method. Do not call directly!

    ``_special_diop_DN(D, N)``: D and N are integers as in `x^2 - Dy^2 = N`.

    Details
    =======

    ``D`` and ``N`` correspond to D and N in the equation. The convergents
    of the continued fraction of `\\sqrt{D}` are scanned over an even number
    of terms ending where `Q = 1`; each convergent that solves a reduced
    equation `X^2 - DY^2 = N/f^2` yields the solution `(fX, fY)`.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import _special_diop_DN
    >>> _special_diop_DN(13, -3) # Solves equation x**2 - 13*y**2 = -3
    [(7, 2), (137, 38)]

    The output can be interpreted as follows: There are two fundamental
    solutions to the equation `x^2 - 13y^2 = -3` given by (7, 2) and
    (137, 38). Each tuple is in the form (x, y), i.e. solution (7, 2) means
    that `x = 7` and `y = 2`.

    >>> _special_diop_DN(2445, -20) # Solves equation x**2 - 2445*y**2 = -20
    [(445, 9), (17625560, 356454), (698095554475, 14118073569)]

    See Also
    ========

    diop_DN()

    References
    ==========

    .. [1] Section 4.4.4 of the following book:
        Quadratic Diophantine Equations, T. Andreescu and D. Andrica,
        Springer, 2015.
    """
    # Precondition (deliberately not checked here, for speed; `diop_DN`
    # enforces it): 1 < N**2 < D and D is not a perfect square.

    root_D = sqrt(D)

    # Pairs (N // k**2, k) for every k >= 1 with k**2 dividing N.
    reduced = [(N, 1)]
    k = 2
    while k*k <= abs(N):
        quotient, remainder = divmod(N, k*k)
        if not remainder:
            reduced.append((quotient, k))
        k += 1

    P, Q = 0, 1
    G_prev, G_cur = 0, 1
    B_prev, B_cur = 1, 0

    found = []
    step = 0
    while True:
        a = floor((P + root_D) / Q)
        P = a*Q - P
        Q = (D - P**2) // Q
        G_next = a*G_cur + G_prev
        B_next = a*B_cur + B_prev

        for target, scale in reduced:
            if G_next**2 - D*B_next**2 == target:
                found.append((scale*G_next, scale*B_next))

        step += 1
        if Q == 1 and step % 2 == 0:
            break

        G_prev, G_cur = G_cur, G_next
        B_prev, B_cur = B_cur, B_next

    return found
2249
2250
def cornacchia(a, b, m):
    r"""
    Solves `ax^2 + by^2 = m` where `\gcd(a, b) = 1 = gcd(a, m)` and `a, b > 0`.

    Explanation
    ===========

    Uses the algorithm due to Cornacchia. The method only finds primitive
    solutions, i.e. ones with `\gcd(x, y) = 1`. So this method can't be used to
    find the solutions of `x^2 + y^2 = 20` since the only solution of the
    former with `x \geq y > 0` is `(x, y) = (4, 2)` and it is not primitive.
    When `a = b`, only the solutions with `x \geq y` are found. For more
    details, see the References.

    Returns ``None`` when `-b a^{-1}` has no square root modulo ``m``;
    otherwise returns a (possibly empty) set of tuples `(x, y)` of
    nonnegative Python ints.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import cornacchia
    >>> cornacchia(2, 3, 35) # equation 2x**2 + 3y**2 = 35
    {(2, 3), (4, 1)}
    >>> cornacchia(1, 1, 25) # equation x**2 + y**2 = 25
    {(4, 3)}

    References
    ===========

    .. [1] A. Nitaj, "L'algorithme de Cornacchia"
    .. [2] Solving the diophantine equation ax**2 + by**2 = m by Cornacchia's
        method, [online], Available:
        http://www.numbertheory.org/php/cornacchia.html

    See Also
    ========

    sympy.utilities.iterables.signed_permutations
    """
    sols = set()

    # a1 is the inverse of a modulo m (gcd(a, m) == 1 is assumed), so the
    # roots t below satisfy t**2 == -b/a (mod m).
    a1 = igcdex(a, m)[0]
    v = sqrt_mod(-b*a1, m, all_roots=True)
    if not v:
        return None

    for t in v:
        # Roots come in pairs t, m - t; only the upper one of each is used.
        if t < m // 2:
            continue

        u, r = t, m

        # Euclidean remainders of (m, t) until a*r**2 < m.
        while True:
            u, r = r, u % r
            if a*r**2 < m:
                break

        m1 = m - a*r**2

        if m1 % b == 0:
            m1 = m1 // b
            s, _exact = integer_nthroot(m1, 2)
            if _exact:
                # For a == b the roles of x and y are symmetric; keep x >= y.
                if a == b and r < s:
                    r, s = s, r
                sols.add((int(r), int(s)))

    return sols
2315
2316
def PQa(P_0, Q_0, D):
    r"""
    Returns useful information needed to solve the Pell equation.

    Explanation
    ===========

    Generates, term by term, the six integer sequences associated with the
    continued fraction expansion of `\frac{P_{0} + \sqrt{D}}{Q_{0}}`, namely
    {`P_{i}`}, {`Q_{i}`}, {`a_{i}`}, {`A_{i}`}, {`B_{i}`}, {`G_{i}`}. Each
    yielded value is the 6-tuple `(P_{i}, Q_{i}, a_{i}, A_{i}, B_{i}, G_{i})`,
    in that order. The generator never stops by itself; callers decide when
    to stop consuming it. Refer [1]_ for more detailed information.

    Usage
    =====

    ``PQa(P_0, Q_0, D)``: ``P_0``, ``Q_0`` and ``D`` are integers corresponding
    to `P_{0}`, `Q_{0}` and `D` in the continued fraction
    `\frac{P_{0} + \sqrt{D}}{Q_{0}}`.
    It is assumed that `P_{0}^2 \equiv D \pmod{|Q_{0}|}` and that `D` is not
    a perfect square.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import PQa
    >>> pqa = PQa(13, 4, 5) # (13 + sqrt(5))/4
    >>> next(pqa) # (P_0, Q_0, a_0, A_0, B_0, G_0)
    (13, 4, 3, 3, 1, -1)
    >>> next(pqa) # (P_1, Q_1, a_1, A_1, B_1, G_1)
    (-1, 1, 1, 4, 1, 3)

    References
    ==========

    .. [1] Solving the generalized Pell equation x^2 - Dy^2 = N, John P.
        Robertson, July 31, 2004, Pages 4 - 8. https://web.archive.org/web/20160323033128/http://www.jpr2718.org/pell.pdf
    """
    root = sqrt(D)

    # Each recurrence needs its two previous terms (index i-2, index i-1);
    # these are the standard seed values.
    A_prev2, A_prev1 = 0, 1
    B_prev2, B_prev1 = 1, 0
    G_prev2, G_prev1 = -P_0, Q_0

    P, Q = P_0, Q_0

    while True:
        a = floor((P + root)/Q)
        A = a*A_prev1 + A_prev2
        B = a*B_prev1 + B_prev2
        G = a*G_prev1 + G_prev2

        yield P, Q, a, A, B, G

        A_prev2, A_prev1 = A_prev1, A
        B_prev2, B_prev1 = B_prev1, B
        G_prev2, G_prev1 = G_prev1, G

        P = a*Q - P
        Q = (D - P**2)/Q
2378
2379
def diop_bf_DN(D, N, t=symbols("t", integer=True)):
    r"""
    Uses brute force to solve the equation, `x^2 - Dy^2 = N`.

    Explanation
    ===========

    Mainly concerned with the generalized Pell equation which is the case when
    `D > 0, D` is not a perfect square. For more information on the case refer
    [1]_. Let `(t, u)` be the minimal positive solution of the equation
    `x^2 - Dy^2 = 1`. Then this method requires
    `\sqrt{\frac{\mid N \mid (t \pm 1)}{2D}}` to be small.

    Usage
    =====

    ``diop_bf_DN(D, N, t)``: ``D`` and ``N`` are coefficients in
    `x^2 - Dy^2 = N` and ``t`` is the parameter to be used in the solutions.

    Details
    =======

    ``D`` and ``N`` correspond to D and N in the equation.
    ``t`` is the parameter to be used in the solutions.

    For `N = 0` and `|N| = 1` the answer is produced without a search, so the
    minimal solution of `x^2 - Dy^2 = 1` is only computed when a search
    over `y` is actually performed.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import diop_bf_DN
    >>> diop_bf_DN(13, -4)
    [(3, 1), (-3, 1), (36, 10)]
    >>> diop_bf_DN(986, 1)
    [(49299, 1570)]

    See Also
    ========

    diop_DN()

    References
    ==========

    .. [1] Solving the generalized Pell equation x**2 - D*y**2 = N, John P.
        Robertson, July 31, 2004, Page 15. https://web.archive.org/web/20160323033128/http://www.jpr2718.org/pell.pdf
    """
    D = as_int(D)
    N = as_int(N)

    if abs(N) == 1:
        return diop_DN(D, N)

    if N == 0:
        if D < 0:
            return [(0, 0)]
        elif D == 0:
            return [(0, t)]
        sD, _exact = integer_nthroot(D, 2)
        if _exact:
            return [(sD*t, t), (-sD*t, t)]
        return [(0, 0)]

    # x-component of the minimal solution of x**2 - D*y**2 = 1. It is only
    # needed to bound the search, so it is not computed for the cases above.
    u = diop_DN(D, 1)[0][0]

    # Integer floor division keeps the bounds exact; true division can round
    # through a float when the operands are Python ints. For D > 0 this gives
    # the same values as truncating the true quotient.
    if N > 1:
        L1 = 0
        L2 = integer_nthroot(int(N*(u - 1)//(2*D)), 2)[0] + 1
    else:  # N < -1
        L1, _exact = integer_nthroot(int((-N)//D), 2)
        if not _exact:
            L1 += 1
        L2 = integer_nthroot(int((-N)*(u + 1)//(2*D)), 2)[0] + 1

    sol = []
    for y in range(L1, L2):
        try:
            x, _exact = integer_nthroot(N + D*y**2, 2)
        except ValueError:
            # N + D*y**2 is negative: no real square root.
            _exact = False
        if _exact:
            sol.append((x, y))
            # (-x, y) is reported only if it is in a different class.
            if not equivalent(x, y, -x, y, D, N):
                sol.append((-x, y))

    return sol
2469
2470
def equivalent(u, v, r, s, D, N):
    """
    Returns True if two solutions `(u, v)` and `(r, s)` of `x^2 - Dy^2 = N`
    belongs to the same equivalence class and False otherwise.

    Explanation
    ===========

    Two solutions `(u, v)` and `(r, s)` of the above equation belong to the
    same equivalence class exactly when `N` divides both `(ur - Dvs)` and
    `(us - vr)`. See reference [1]_. The inputs are not checked to actually
    solve the equation; that is the caller's responsibility.

    Usage
    =====

    ``equivalent(u, v, r, s, D, N)``: `(u, v)` and `(r, s)` are two solutions
    of the equation `x^2 - Dy^2 = N` and all parameters involved are integers.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import equivalent
    >>> equivalent(18, 5, -18, -5, 13, -1)
    True
    >>> equivalent(3, 1, -18, 393, 109, -4)
    False

    References
    ==========

    .. [1] Solving the generalized Pell equation x**2 - D*y**2 = N, John P.
        Robertson, July 31, 2004, Page 12. https://web.archive.org/web/20160323033128/http://www.jpr2718.org/pell.pdf

    """
    if (u*r - D*v*s) % N:
        return False
    return not (u*s - v*r) % N
2508
2509
def length(P, Q, D):
    r"""
    Returns the (length of aperiodic part + length of periodic part) of
    continued fraction representation of `\frac{P + \sqrt{D}}{Q}`.

    Note that the result is the total length of both parts, NOT just the
    length of the periodic part.

    Usage
    =====

    ``length(P, Q, D)``: ``P``, ``Q`` and ``D`` are integers corresponding to
    the continued fraction `\frac{P + \sqrt{D}}{Q}`.

    Details
    =======

    ``P``, ``D`` and ``Q`` corresponds to P, D and Q in the continued fraction,
    `\frac{P + \sqrt{D}}{Q}`. The expansion is obtained from
    ``continued_fraction_periodic``, whose final entry is a list when the
    expansion has a repeating block.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import length
    >>> length(-2 , 4, 5) # (-2 + sqrt(5))/4
    3
    >>> length(-5, 4, 17) # (-5 + sqrt(17))/4
    4

    See Also
    ========
    sympy.ntheory.continued_fraction.continued_fraction_periodic
    """
    from sympy.ntheory.continued_fraction import continued_fraction_periodic
    terms = continued_fraction_periodic(P, Q, D)
    last = terms[-1]
    if type(last) is list:
        # Every entry before the trailing list is a non-repeating term.
        return (len(terms) - 1) + len(last)
    return len(terms)
2553
2554
def transformation_to_DN(eq):
    """
    This function transforms general quadratic,
    `ax^2 + bxy + cy^2 + dx + ey + f = 0`
    to more easy to deal with `X^2 - DY^2 = N` form.

    Explanation
    ===========

    This is used to solve the general quadratic equation by transforming it to
    the latter form. Refer [1]_ for more detailed information on the
    transformation. This function returns a tuple (A, B) where A is a 2 X 2
    matrix and B is a 2 X 1 matrix such that,

    Transpose([x y]) =  A * Transpose([X Y]) + B

    ``None`` is returned when ``eq`` is not classified as a binary quadratic.

    Usage
    =====

    ``transformation_to_DN(eq)``: where ``eq`` is the quadratic to be
    transformed.

    Examples
    ========

    >>> from sympy.abc import x, y
    >>> from sympy.solvers.diophantine.diophantine import transformation_to_DN
    >>> A, B = transformation_to_DN(x**2 - 3*x*y - y**2 - 2*y + 1)
    >>> A
    Matrix([
    [1/26, 3/26],
    [   0, 1/13]])
    >>> B
    Matrix([
    [-6/13],
    [-4/13]])

    A, B  returned are such that Transpose((x y)) =  A * Transpose((X Y)) + B.
    Substituting these values for `x` and `y` and a bit of simplifying work
    will give an equation of the form `x^2 - Dy^2 = N`.

    >>> from sympy.abc import X, Y
    >>> from sympy import Matrix, simplify
    >>> u = (A*Matrix([X, Y]) + B)[0] # Transformation for x
    >>> u
    X/26 + 3*Y/26 - 6/13
    >>> v = (A*Matrix([X, Y]) + B)[1] # Transformation for y
    >>> v
    Y/13 - 4/13

    Next we will substitute these formulas for `x` and `y` and do
    ``simplify()``.

    >>> eq = simplify((x**2 - 3*x*y - y**2 - 2*y + 1).subs(zip((x, y), (u, v))))
    >>> eq
    X**2/676 - Y**2/52 + 17/13

    By multiplying the denominator appropriately, we can get a Pell equation
    in the standard form.

    >>> eq * 676
    X**2 - 13*Y**2 + 884

    If only the final equation is needed, ``find_DN()`` can be used.

    See Also
    ========

    find_DN()

    References
    ==========

    .. [1] Solving the equation ax^2 + bxy + cy^2 + dx + ey + f = 0,
           John P.Robertson, May 8, 2003, Page 7 - 11.
           https://web.archive.org/web/20160323033111/http://www.jpr2718.org/ax2p.pdf
    """
    variables, coefficients, eq_type = classify_diop(eq, _dict=False)
    if eq_type != BinaryQuadratic.name:
        return None
    return _transformation_to_DN(variables, coefficients)
2636
2637
def _transformation_to_DN(var, coeff):
    """
    Compute the affine change of variables used by ``transformation_to_DN``.

    ``var`` is the pair ``(x, y)`` and ``coeff`` maps the monomials
    ``x**2, x*y, y**2, x, y, 1`` to the coefficients `a, b, c, d, e, f` of
    `ax^2 + bxy + cy^2 + dx + ey + f`. Returns ``(A, B)``, a 2 x 2 and a
    2 x 1 Matrix, such that ``Matrix([x, y]) == A*Matrix([X, Y]) + B``.

    Each call removes one non-square term (the ``x*y`` term first, then the
    ``x`` term, then the ``y`` term). The coefficients of the resulting
    equation in new variables ``X, Y`` are written out explicitly and the
    function recurses on them; this step's matrices are then composed with
    the recursive result.
    """

    x, y = var

    a = coeff[x**2]
    b = coeff[x*y]
    c = coeff[y**2]
    d = coeff[x]
    e = coeff[y]
    f = coeff[1]

    # Divide out the common gcd; as_int rejects non-integer coefficients.
    a, b, c, d, e, f = [as_int(i) for i in _remove_gcd(a, b, c, d, e, f)]

    X, Y = symbols("X, Y", integer=True)

    if b:
        # NOTE(review): presumably _rational_pq(p, q) gives p/q in lowest
        # terms as a (numerator, denominator) pair -- confirm at its definition.
        B, C = _rational_pq(2*a, b)
        A, T = _rational_pq(a, B**2)

        # eq_1 = A*B*X**2 + B*(c*T - A*C**2)*Y**2 + d*T*X + (B*e*T - d*T*C)*Y + f*T*B
        coeff = {X**2: A*B, X*Y: 0, Y**2: B*(c*T - A*C**2), X: d*T, Y: B*e*T - d*T*C, 1: f*T*B}
        A_0, B_0 = _transformation_to_DN([X, Y], coeff)
        # This step maps (X, Y) to (X/B - C*Y/B, Y), with no translation.
        return Matrix(2, 2, [S.One/B, -S(C)/B, 0, 1])*A_0, Matrix(2, 2, [S.One/B, -S(C)/B, 0, 1])*B_0

    else:
        if d:
            B, C = _rational_pq(2*a, d)
            A, T = _rational_pq(a, B**2)

            # eq_2 = A*X**2 + c*T*Y**2 + e*T*Y + f*T - A*C**2
            coeff = {X**2: A, X*Y: 0, Y**2: c*T, X: 0, Y: e*T, 1: f*T - A*C**2}
            A_0, B_0 = _transformation_to_DN([X, Y], coeff)
            # This step maps (X, Y) to ((X - C)/B, Y).
            return Matrix(2, 2, [S.One/B, 0, 0, 1])*A_0, Matrix(2, 2, [S.One/B, 0, 0, 1])*B_0 + Matrix([-S(C)/B, 0])

        else:
            if e:
                B, C = _rational_pq(2*c, e)
                A, T = _rational_pq(c, B**2)

                # eq_3 = a*T*X**2 + A*Y**2 + f*T - A*C**2
                coeff = {X**2: a*T, X*Y: 0, Y**2: A, X: 0, Y: 0, 1: f*T - A*C**2}
                A_0, B_0 = _transformation_to_DN([X, Y], coeff)
                # This step maps (X, Y) to (X, (Y - C)/B).
                return Matrix(2, 2, [1, 0, 0, S.One/B])*A_0, Matrix(2, 2, [1, 0, 0, S.One/B])*B_0 + Matrix([0, -S(C)/B])

            else:
                # TODO: pre-simplification: Not necessary but may simplify
                # the equation.

                # Base case a*x**2 + c*y**2 + f: x = X/a, y = Y, so that
                # multiplying through by a gives X**2 + a*c*Y**2 + a*f.
                return Matrix(2, 2, [S.One/a, 0, 0, 1]), Matrix([0, 0])
2687
2688
def find_DN(eq):
    """
    This function returns a tuple, `(D, N)` of the simplified form,
    `x^2 - Dy^2 = N`, corresponding to the general quadratic,
    `ax^2 + bxy + cy^2 + dx + ey + f = 0`.

    Solving the general quadratic is then equivalent to solving the equation
    `X^2 - DY^2 = N` and transforming the solutions by using the transformation
    matrices returned by ``transformation_to_DN()``. ``None`` is returned
    when ``eq`` is not classified as a binary quadratic.

    Usage
    =====

    ``find_DN(eq)``: where ``eq`` is the quadratic to be transformed.

    Examples
    ========

    >>> from sympy.abc import x, y
    >>> from sympy.solvers.diophantine.diophantine import find_DN
    >>> find_DN(x**2 - 3*x*y - y**2 - 2*y + 1)
    (13, -884)

    Interpretation of the output is that we get `X^2 -13Y^2 = -884` after
    transforming `x^2 - 3xy - y^2 - 2y + 1` using the transformation returned
    by ``transformation_to_DN()``.

    See Also
    ========

    transformation_to_DN()

    References
    ==========

    .. [1] Solving the equation ax^2 + bxy + cy^2 + dx + ey + f = 0,
           John P.Robertson, May 8, 2003, Page 7 - 11.
           https://web.archive.org/web/20160323033111/http://www.jpr2718.org/ax2p.pdf
    """
    variables, coefficients, eq_type = classify_diop(eq, _dict=False)
    if eq_type != BinaryQuadratic.name:
        return None
    return _find_DN(variables, coefficients)
2731
2732
def _find_DN(var, coeff):
    # Substitute the transformation from _transformation_to_DN into the
    # quadratic and read D and N off the coefficients of the result,
    # normalized by the X**2 coefficient.
    x, y = var
    X, Y = symbols("X, Y", integer=True)
    A, B = _transformation_to_DN(var, coeff)

    image = A*Matrix([X, Y]) + B
    u, v = image[0], image[1]

    original = (x**2*coeff[x**2] + x*y*coeff[x*y] + y**2*coeff[y**2]
                + x*coeff[x] + y*coeff[y] + coeff[1])
    transformed = _mexpand(original.subs([(x, u), (y, v)]))

    terms = transformed.as_coefficients_dict()
    lead = terms[X**2]
    return -terms[Y**2]/lead, -terms[1]/lead
2748
2749
def check_param(x, y, a, params):
    """
    If there is a number modulo ``a`` such that ``x`` and ``y`` are both
    integers, then return a parametric representation for ``x`` and ``y``;
    otherwise return a ``DiophantineSolutionSet`` built from ``[x, y]`` and
    ``params`` with no solutions added.

    Here ``x`` and ``y`` are functions of ``t``.

    Solving ``x == m`` and ``y == n`` (``m``, ``n`` integer symbols) for the
    parameter and equating the two results gives a linear relation in ``m``
    and ``n``; that relation is passed to ``_diop_solve``.
    """
    from sympy.simplify.simplify import clear_coefficients

    # A non-integer constant can never become an integer.
    for expr in (x, y):
        if expr.is_number and not expr.is_Integer:
            return DiophantineSolutionSet([x, y], parameters=params)

    m, n = symbols("m, n", integer=True)
    content, _ = (m*x + n*y).as_content_primitive()
    if a % content.q:
        return DiophantineSolutionSet([x, y], parameters=params)

    # clear_coefficients(mx + b, R)[1] -> (R - b)/m
    relation = clear_coefficients(x, m)[1] - clear_coefficients(y, n)[1]
    _, relation = relation.as_content_primitive()

    return _diop_solve(relation, params=params)
2776
2777
def diop_ternary_quadratic(eq, parameterize=False):
    """
    Solves the general quadratic ternary form,
    `ax^2 + by^2 + cz^2 + fxy + gyz + hxz = 0`.

    Returns a tuple `(x, y, z)` which is a base solution for the above
    equation. If there are no solutions, `(None, None, None)` is returned.
    With ``parameterize=True`` the base solution is passed to
    ``_parametrize_ternary_quadratic`` and its result is returned instead.
    ``None`` is returned when ``eq`` is not classified as a homogeneous
    ternary quadratic.

    Usage
    =====

    ``diop_ternary_quadratic(eq)``: Return a tuple containing a basic solution
    to ``eq``.

    Details
    =======

    ``eq`` should be an homogeneous expression of degree two in three variables
    and it is assumed to be zero.

    Examples
    ========

    >>> from sympy.abc import x, y, z
    >>> from sympy.solvers.diophantine.diophantine import diop_ternary_quadratic
    >>> diop_ternary_quadratic(x**2 + 3*y**2 - z**2)
    (1, 0, 1)
    >>> diop_ternary_quadratic(4*x**2 + 5*y**2 - z**2)
    (1, 0, 2)
    >>> diop_ternary_quadratic(45*x**2 - 7*y**2 - 8*x*y - z**2)
    (28, 45, 105)
    >>> diop_ternary_quadratic(x**2 - 49*y**2 - z**2 + 13*z*y -8*x*y)
    (9, 1, 5)
    """
    var, coeff, diop_type = classify_diop(eq, _dict=False)

    accepted = (
        HomogeneousTernaryQuadratic.name,
        HomogeneousTernaryQuadraticNormal.name)
    if diop_type not in accepted:
        return None

    solutions = _diop_ternary_quadratic(var, coeff)
    if len(solutions) == 0:
        base = (None, None, None)
    else:
        x_0, y_0, z_0 = next(iter(solutions))
        base = (x_0, y_0, z_0)

    if not parameterize:
        return base
    return _parametrize_ternary_quadratic(base, var, coeff)
2827
2828
def _diop_ternary_quadratic(_var, coeff):
    # Rebuild the expression from its monomial -> coefficient mapping and
    # dispatch it to the first ternary-quadratic solver that accepts it.
    # Returns None when neither solver matches.
    eq = sum([term*coeff[term] for term in coeff])
    for solver in (HomogeneousTernaryQuadratic,
                   HomogeneousTernaryQuadraticNormal):
        if solver(eq).matches():
            return solver(eq, free_symbols=_var).solve()
    return None
2835
2836
def transformation_to_normal(eq):
    """
    Returns the transformation Matrix that converts a general ternary
    quadratic equation ``eq`` (`ax^2 + by^2 + cz^2 + dxy + eyz + fxz`)
    to a form without cross terms: `ax^2 + by^2 + cz^2 = 0`. This is
    not used in solving ternary quadratics; it is only implemented for
    the sake of completeness.

    None is returned when ``eq`` is not classified as a homogeneous
    ternary quadratic (normal or general).
    """
    var, coeff, diop_type = classify_diop(eq, _dict=False)

    # Use the solver classes' names, as the other ternary wrappers do,
    # instead of repeating the type strings here.
    if diop_type in (
            HomogeneousTernaryQuadratic.name,
            HomogeneousTernaryQuadraticNormal.name):
        return _transformation_to_normal(var, coeff)
    return None
2851
2852
def _transformation_to_normal(var, coeff):
    """
    Return the 3x3 transformation Matrix for ``transformation_to_normal``.

    ``var`` is the ordered triple of variables and ``coeff`` maps the
    monomials of a homogeneous ternary quadratic (``x**2``, ``x*y``, ...)
    to their coefficients. The function recurses. When a needed leading
    coefficient is zero, the variables are swapped, the recursive result is
    computed, and the matching rows/columns of the returned Matrix are
    swapped back. ``x*y``/``x*z`` cross terms are removed by completing the
    square in ``x``.
    """
    _var = list(var)  # copy
    x, y, z = var

    if not any(coeff[i**2] for i in var):
        # No square terms at all: only x*y, y*z and x*z are present.
        # https://math.stackexchange.com/questions/448051/transform-quadratic-ternary-form-to-normal-form/448065#448065
        a = coeff[x*y]
        b = coeff[y*z]
        c = coeff[x*z]
        swap = False
        if not a:  # b can't be 0 or else there aren't 3 vars
            swap = True
            a, b = b, a
        T = Matrix(((1, 1, -b/a), (1, -1, -c/a), (0, 0, 1)))
        if swap:
            T.row_swap(0, 1)
            T.col_swap(0, 1)
        return T

    if coeff[x**2] == 0:
        # If the coefficient of x is zero change the variables
        if coeff[y**2] == 0:
            # z**2 must be present (checked above): exchange x and z
            _var[0], _var[2] = var[2], var[0]
            T = _transformation_to_normal(_var, coeff)
            T.row_swap(0, 2)
            T.col_swap(0, 2)
            return T

        else:
            # exchange x and y so that the leading square term is non-zero
            _var[0], _var[1] = var[1], var[0]
            T = _transformation_to_normal(_var, coeff)
            T.row_swap(0, 1)
            T.col_swap(0, 1)
            return T

    # Apply the transformation x --> X - (B*Y + C*Z)/(2*A)
    if coeff[x*y] != 0 or coeff[x*z] != 0:
        A = coeff[x**2]
        B = coeff[x*y]
        C = coeff[x*z]
        D = coeff[y**2]
        E = coeff[y*z]
        F = coeff[z**2]

        # Coefficients of 4*A*eq after completing the square:
        # 4*A*eq = (2*A*x + B*y + C*z)**2 + (4*A*D - B**2)*y**2
        #          + (4*A*F - C**2)*z**2 + (4*A*E - 2*B*C)*y*z,
        # and with X = x + (B*y + C*z)/(2*A) the first term is 4*A**2*X**2.
        _coeff = dict()

        _coeff[x**2] = 4*A**2
        _coeff[y**2] = 4*A*D - B**2
        _coeff[z**2] = 4*A*F - C**2
        _coeff[y*z] = 4*A*E - 2*B*C
        _coeff[x*y] = 0
        _coeff[x*z] = 0

        T_0 = _transformation_to_normal(_var, _coeff)
        return Matrix(3, 3, [1, S(-B)/(2*A), S(-C)/(2*A), 0, 1, 0, 0, 0, 1])*T_0

    elif coeff[y*z] != 0:
        if coeff[y**2] == 0:
            if coeff[z**2] == 0:
                # Equations of the form A*x**2 + E*yz = 0.
                # Apply transformation y -> Y + Z and z -> Y - Z
                return Matrix(3, 3, [1, 0, 0, 0, 1, 1, 0, 1, -1])

            else:
                # Ax**2 + E*y*z + F*z**2  = 0
                _var[0], _var[2] = var[2], var[0]
                T = _transformation_to_normal(_var, coeff)
                T.row_swap(0, 2)
                T.col_swap(0, 2)
                return T

        else:
            # A*x**2 + D*y**2 + E*y*z + F*z**2 = 0, F may be zero
            _var[0], _var[1] = var[1], var[0]
            T = _transformation_to_normal(_var, coeff)
            T.row_swap(0, 1)
            T.col_swap(0, 1)
            return T

    else:
        # no cross terms left: already in normal form
        return Matrix.eye(3)
2935
2936
def parametrize_ternary_quadratic(eq):
    """
    Returns the parametrized general solution for the ternary quadratic
    equation ``eq`` which has the form
    `ax^2 + by^2 + cz^2 + fxy + gyz + hxz = 0`.

    If the equation has no non-trivial integer solution,
    ``(None, None, None)`` is returned. If ``eq`` is not classified as a
    homogeneous ternary quadratic, None is returned.

    Examples
    ========

    >>> from sympy import Tuple, ordered
    >>> from sympy.abc import x, y, z
    >>> from sympy.solvers.diophantine.diophantine import parametrize_ternary_quadratic

    The parametrized solution may be returned with three parameters:

    >>> parametrize_ternary_quadratic(2*x**2 + y**2 - 2*z**2)
    (p**2 - 2*q**2, -2*p**2 + 4*p*q - 4*p*r - 4*q**2, p**2 - 4*p*q + 2*q**2 - 4*q*r)

    There might also be only two parameters:

    >>> parametrize_ternary_quadratic(4*x**2 + 2*y**2 - 3*z**2)
    (2*p**2 - 3*q**2, -4*p**2 + 12*p*q - 6*q**2, 4*p**2 - 8*p*q + 6*q**2)

    Notes
    =====

    Consider ``p`` and ``q`` in the previous 2-parameter
    solution and observe that more than one solution can be represented
    by a given pair of parameters. If ``p`` and ``q`` are not coprime, this is
    trivially true since the common factor will also be a common factor of the
    solution values. But it may also be true even when ``p`` and
    ``q`` are coprime:

    >>> sol = Tuple(*_)
    >>> p, q = ordered(sol.free_symbols)
    >>> sol.subs([(p, 3), (q, 2)])
    (6, 12, 12)
    >>> sol.subs([(q, 1), (p, 1)])
    (-1, 2, 2)
    >>> sol.subs([(q, 0), (p, 1)])
    (2, -4, 4)
    >>> sol.subs([(q, 1), (p, 0)])
    (-3, -6, 6)

    Except for sign and a common factor, these are equivalent to
    the solution of (1, 2, 2).

    References
    ==========

    .. [1] The algorithmic resolution of Diophantine equations, Nigel P. Smart,
           London Mathematical Society Student Texts 41, Cambridge University
           Press, Cambridge, 1998.

    """
    var, coeff, diop_type = classify_diop(eq, _dict=False)

    if diop_type in (
            HomogeneousTernaryQuadratic.name,
            HomogeneousTernaryQuadraticNormal.name):
        sol = _diop_ternary_quadratic(var, coeff)
        # An empty solution set means there is no base solution; indexing it
        # would raise IndexError, so pass the "no solution" marker instead
        # (the parametrizer maps it to (None, None, None)).
        if sol:
            x_0, y_0, z_0 = list(sol)[0]
        else:
            x_0, y_0, z_0 = None, None, None
        return _parametrize_ternary_quadratic(
            (x_0, y_0, z_0), var, coeff)
    return None
3000
3001
def _parametrize_ternary_quadratic(solution, _var, coeff):
    # called for a*x**2 + b*y**2 + c*z**2 + d*x*y + e*y*z + f*x*z = 0
    #
    # Lines through the base point ``solution`` are substituted as
    # (x, y, z) = (r*x_0, r*y_0 + p, r*z_0 + q). Because the base point is a
    # root, the expanded equation has no r**2 term, so it splits as A + B
    # with A free of r and B linear in r. Solving for r and clearing the
    # denominator B/r gives the integer parametrization returned here.
    # Returns (None, None, None) when there is no usable base solution.
    assert 1 not in coeff

    x_0, y_0, z_0 = solution
    syms = list(_var)  # copy

    # No base solution, or two zero entries (the equation then reduces to
    # k*X**2 == 0, so only the trivial solution exists).
    if x_0 is None or solution.count(0) >= 2:
        return (None, None, None)

    if x_0 == 0:
        # The construction scales by x_0, so swap the roles of x and y.
        syms[0], syms[1] = syms[1], syms[0]
        y_p, x_p, z_p = _parametrize_ternary_quadratic(
            (y_0, x_0, z_0), syms, coeff)
        return x_p, y_p, z_p

    x, y, z = syms
    r, p, q = symbols("r, p, q", integer=True)

    expr = sum(mono*cf for mono, cf in coeff.items())
    substituted = _mexpand(expr.subs(zip(
        (x, y, z), (r*x_0, r*y_0 + p, r*z_0 + q))))
    A, B = substituted.as_independent(r, as_Add=True)

    return _remove_gcd(
        A*x_0,
        A*y_0 - _mexpand(B/r*p),
        A*z_0 - _mexpand(B/r*q))
3040
3041
def diop_ternary_quadratic_normal(eq, parameterize=False):
    """
    Solves the quadratic ternary diophantine equation,
    `ax^2 + by^2 + cz^2 = 0`.

    Explanation
    ===========

    Here the coefficients `a`, `b`, and `c` should be non zero. Otherwise the
    equation will be a quadratic binary or univariate equation. If solvable,
    returns a tuple `(x, y, z)` that satisfies the given equation. If the
    equation does not have integer solutions, `(None, None, None)` is returned.
    With ``parameterize=True`` the parametrized solution built from that base
    solution is returned instead. If ``eq`` is not of this normal form,
    None is returned.

    Usage
    =====

    ``diop_ternary_quadratic_normal(eq)``: where ``eq`` is an equation of the form
    `ax^2 + by^2 + cz^2 = 0`.

    Examples
    ========

    >>> from sympy.abc import x, y, z
    >>> from sympy.solvers.diophantine.diophantine import diop_ternary_quadratic_normal
    >>> diop_ternary_quadratic_normal(x**2 + 3*y**2 - z**2)
    (1, 0, 1)
    >>> diop_ternary_quadratic_normal(4*x**2 + 5*y**2 - z**2)
    (1, 0, 2)
    >>> diop_ternary_quadratic_normal(34*x**2 - 3*y**2 - 301*z**2)
    (4, 9, 1)
    """
    var, coeff, diop_type = classify_diop(eq, _dict=False)
    if diop_type != HomogeneousTernaryQuadraticNormal.name:
        return None

    found = _diop_ternary_quadratic_normal(var, coeff)
    if len(found) == 0:
        base = (None, None, None)
    else:
        base = list(found)[0]
    x_0, y_0, z_0 = base

    if not parameterize:
        return x_0, y_0, z_0
    return _parametrize_ternary_quadratic((x_0, y_0, z_0), var, coeff)
3084
3085
def _diop_ternary_quadratic_normal(var, coeff):
    # Reassemble the expression from its monomial -> coefficient mapping
    # and return the solution set computed by the normal-form solver.
    expr = sum([mono * coeff[mono] for mono in coeff])
    solver = HomogeneousTernaryQuadraticNormal(expr, free_symbols=var)
    return solver.solve()
3089
3090
def sqf_normal(a, b, c, steps=False):
    """
    Return `a', b', c'`, the coefficients of the square-free normal
    form of `ax^2 + by^2 + cz^2 = 0`, where `a', b', c'` are pairwise
    prime.  If `steps` is True then also return three tuples:
    `sq`, `sqf`, and `(a', b', c')` where `sq` contains the square
    factors of `a`, `b` and `c` after removing the `gcd(a, b, c)`;
    `sqf` contains the values of `a`, `b` and `c` after removing
    both the `gcd(a, b, c)` and the square factors.

    The solutions for `ax^2 + by^2 + cz^2 = 0` can be
    recovered from the solutions of `a'x^2 + b'y^2 + c'z^2 = 0`.

    Explanation
    ===========

    Once the common gcd and the square factors are removed, the
    coefficients `A, B, C` are square-free. If `g` divides two of them,
    say `A` and `B`, multiplying the equation by `g` and absorbing `g^2`
    into `x` and `y` gives `(A/g)X^2 + (B/g)Y^2 + (gC)z^2 = 0`. This is
    applied to each of the three pairs. The shared factor is always divided
    out of that pair and multiplied into the remaining coefficient.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import sqf_normal
    >>> sqf_normal(2 * 3**2 * 5, 2 * 5 * 11, 2 * 7**2 * 11)
    (11, 1, 5)
    >>> sqf_normal(2 * 3**2 * 5, 2 * 5 * 11, 2 * 7**2 * 11, True)
    ((3, 1, 7), (5, 55, 11), (11, 1, 5))
    >>> sqf_normal(3, 1, 3)
    (1, 3, 1)

    References
    ==========

    .. [1] Legendre's Theorem, Lagrange's Descent,
           http://public.csusm.edu/aitken_html/notes/legendre.pdf


    See Also
    ========

    reconstruct()
    """
    ABC = _remove_gcd(a, b, c)
    sq = tuple(square_factor(i) for i in ABC)
    sqf = A, B, C = tuple([i//j**2 for i, j in zip(ABC, sq)])

    # Exact (floor) division keeps the coefficients integral; every
    # division below is exact because each divisor is a common factor.
    pc = igcd(A, B)
    A //= pc
    B //= pc
    pa = igcd(B, C)
    B //= pa
    C //= pa
    pb = igcd(A, C)
    A //= pb
    C //= pb  # the factor shared by A and C comes out of C (not B)

    # each removed factor is moved onto the coefficient outside its pair
    A *= pa
    B *= pb
    C *= pc

    if steps:
        return (sq, sqf, (A, B, C))
    else:
        return A, B, C
3146
3147
def square_factor(a):
    r"""
    Return the largest integer `c` such that `c^2` divides `a`, i.e.
    `a = c^2k, \ c,k \in Z` with `k` square free. `a` may be given either
    as an integer or as a dictionary mapping prime factors to exponents.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import square_factor
    >>> square_factor(24)
    2
    >>> square_factor(-36*3)
    6
    >>> square_factor(1)
    1
    >>> square_factor({3: 2, 2: 1, -1: 1})  # -18
    3

    See Also
    ========
    sympy.ntheory.factor_.core
    """
    factors = a if isinstance(a, dict) else factorint(a)
    # each prime contributes half of its exponent (rounded down)
    root_parts = [prime**(exp//2) for prime, exp in factors.items()]
    return Mul(*root_parts)
3172
3173
def reconstruct(A, B, z):
    """
    Reconstruct the `z` value of an equivalent solution of `ax^2 + by^2 + cz^2`
    from the `z` value of a solution of the square-free normal form of the
    equation, `a'*x^2 + b'*y^2 + c'*z^2`, where `a'`, `b'` and `c'` are square
    free and `gcd(a', b', c') == 1`.

    ``z`` is multiplied by every prime dividing ``gcd(A, B)``. ValueError is
    raised if that gcd contains a repeated prime factor.
    """
    factors = factorint(igcd(A, B))
    if any(exp != 1 for exp in factors.values()):
        raise ValueError('a and b should be square-free')
    for prime in factors:
        z *= prime
    return z
3187
3188
def ldescent(A, B):
    """
    Return a non-trivial solution to `w^2 = Ax^2 + By^2` using
    Lagrange's method; return None if there is no such solution.

    Here, `A \\neq 0` and `B \\neq 0` and `A` and `B` are square free. Output a
    tuple `(w_0, x_0, y_0)` which is a solution to the above equation.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import ldescent
    >>> ldescent(1, 1) # w^2 = x^2 + y^2
    (1, 1, 0)
    >>> ldescent(4, -7) # w^2 = 4x^2 - 7y^2
    (2, -1, 0)

    This means that `x = -1, y = 0` and `w = 2` is a solution to the equation
    `w^2 = 4x^2 - 7y^2`

    >>> ldescent(5, -1) # w^2 = 5x^2 - y^2
    (2, 1, -1)

    When no non-trivial solution exists, None is returned:

    >>> ldescent(2, 3) is None  # w^2 = 2x^2 + 3y^2
    True

    References
    ==========

    .. [1] The algorithmic resolution of Diophantine equations, Nigel P. Smart,
           London Mathematical Society Student Texts 41, Cambridge University
           Press, Cambridge, 1998.
    .. [2] Efficient Solution of Rational Conics, J. E. Cremona and D. Rusin,
           [online], Available:
           http://eprints.nottingham.ac.uk/60/1/kvxefz87.pdf
    """
    if abs(A) > abs(B):
        # work with |A| <= |B|; swapping A and B swaps x and y
        sol = ldescent(B, A)
        if sol is None:
            return None
        w, y, x = sol
        return w, x, y

    if A == 1:
        return (1, 1, 0)

    if B == 1:
        return (1, 0, 1)

    if B == -1:  # and A == -1: w^2 = -x^2 - y^2 has only the trivial solution
        return None

    # For square-free B a non-trivial solution forces A to be a square
    # modulo B; sqrt_mod returns None when it is not.
    r = sqrt_mod(A, B)
    if r is None:
        return None

    Q = (r**2 - A) // B

    if Q == 0:
        B_0 = 1
        d = 0
    else:
        # write Q = B_0*d**2 using the first divisor whose cofactor is square
        B_0 = None
        for i in divisors(Q):
            sQ, _exact = integer_nthroot(abs(Q) // i, 2)
            if _exact:
                B_0, d = sign(Q)*i, sQ
                break

    if B_0 is None:
        return None

    # Solve the smaller problem w^2 = A*x^2 + B_0*y^2 and lift it back.
    sol = ldescent(A, B_0)
    if sol is None:
        return None
    W, X, Y = sol
    return _remove_gcd((-A*X + r*W), (r*X - W), Y*(B_0*d))
3256
3257
def descent(A, B):
    """
    Returns a non-trivial solution, (x, y, z), to `x^2 = Ay^2 + Bz^2`
    using Lagrange's descent method with lattice-reduction. `A` and `B`
    are assumed to be valid for such a solution to exist.

    This is faster than the normal Lagrange's descent algorithm because
    the Gaussian reduction is used.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import descent
    >>> descent(3, 1) # x**2 = 3*y**2 + z**2
    (1, 0, 1)

    `(x, y, z) = (1, 0, 1)` is a solution to the above equation.

    >>> descent(41, -113)
    (-16, -3, 1)

    References
    ==========

    .. [1] Efficient Solution of Rational Conics, J. E. Cremona and D. Rusin,
           Mathematics of Computation, Volume 00, Number 0.
    """
    if abs(A) > abs(B):
        # Work with |A| <= |B|; swapping A and B swaps the roles of y and z.
        x, y, z = descent(B, A)
        return x, z, y

    # Cases that can be answered directly.
    if B == 1:
        return (1, 0, 1)
    if A == 1:
        return (1, 1, 0)
    if B == -A:
        return (0, 1, 1)
    if B == A:
        # The unpacked x, z, y satisfy x**2 + z**2 == A*y**2, hence
        # (A*y)**2 == A*z**2 + A*x**2.
        x, z, y = descent(-1, A)
        return (A*y, z, x)

    # w**2 == A (mod B). gaussian_reduce returns a reduced (x_0, z_0) with
    # x_0**2 - A*z_0**2 divisible by B.
    w = sqrt_mod(A, B)
    x_0, z_0 = gaussian_reduce(w, A, B)

    # Write t = (x_0**2 - A*z_0**2)/B as t_1*t_2**2 (t_2 from square_factor)
    # and recurse on the smaller problem with coefficients (A, t_1).
    t = (x_0**2 - A*z_0**2) // B
    t_2 = square_factor(t)
    t_1 = t // t_2**2

    x_1, z_1, y_1 = descent(A, t_1)

    # Combine the reduced vector with the recursive solution and strip any
    # common factor.
    return _remove_gcd(x_0*x_1 + A*z_0*z_1, z_0*x_1 + x_0*z_1, t_1*t_2*y_1)
3309
3310
def gaussian_reduce(w, a, b):
    r"""
    Returns a reduced solution `(x, z)` to the congruence
    `X^2 - aZ^2 \equiv 0 \ (mod \ b)` so that `x^2 + |a|z^2` is minimal.

    Details
    =======

    Here ``w`` is a solution of the congruence `x^2 \equiv a \ (mod \ b)`

    The reduction works in the lattice of coefficient pairs `(c_0, c_1)`.
    For any such pair, `x = c_0 w + b c_1`, `z = c_0` satisfies the
    congruence because `w^2 \equiv a \ (mod \ b)`. Lengths are measured
    with ``dot``/``norm``.

    References
    ==========

    .. [1] Gaussian lattice Reduction [online]. Available:
           http://home.ie.cuhk.edu.hk/~wkshum/wordpress/?p=404
    .. [2] Efficient Solution of Rational Conics, J. E. Cremona and D. Rusin,
           Mathematics of Computation, Volume 00, Number 0.
    """
    u = (0, 1)
    v = (1, 0)

    # Orient the starting basis so that its dot product is non-negative.
    if dot(u, v, w, a, b) < 0:
        v = (-v[0], -v[1])

    # Keep u as the longer of the two basis vectors.
    if norm(u, w, a, b) < norm(v, w, a, b):
        u, v = v, u

    # Replace (u, v) by (v, u - k*v), k being the floored projection
    # coefficient, as long as the vector being replaced is the longer one.
    while norm(u, w, a, b) > norm(v, w, a, b):
        k = dot(u, v, w, a, b) // dot(v, v, w, a, b)
        u, v = v, (u[0]- k*v[0], u[1]- k*v[1])

    u, v = v, u

    # Final choice between v and u - v as the short vector.
    if dot(u, v, w, a, b) < dot(v, v, w, a, b)/2 or norm((u[0]-v[0], u[1]-v[1]), w, a, b) > norm(v, w, a, b):
        c = v
    else:
        c = (u[0] - v[0], u[1] - v[1])

    # map lattice coordinates back to (x, z)
    return c[0]*w + b*c[1], c[0]
3350
3351
def dot(u, v, w, a, b):
    r"""
    Return the bilinear form used to reduce solutions of the congruence
    `X^2 - aZ^2 \equiv 0 \ (mod \ b)`, evaluated on `u = (u_{1}, u_{2})`
    and `v = (v_{1}, v_{2})`:

    `(wu_{1} + bu_{2})(wv_{1} + bv_{2}) + |a|u_{1}v_{1}`.
    """
    first_u, second_u = u
    first_v, second_v = v
    left = w*first_u + b*second_u
    right = w*first_v + b*second_v
    return left*right + abs(a)*first_u*first_v
3361
3362
def norm(u, w, a, b):
    r"""
    Return the length of `u = (u_{1}, u_{2})`: the square root of ``dot``
    applied to ``u`` with itself, where
    `u \cdot v = (wu_{1} + bu_{2})(w*v_{1} + bv_{2}) + |a|*u_{1}*v_{1}`.
    """
    first, second = u
    squared = dot((first, second), (first, second), w, a, b)
    return sqrt(squared)
3371
3372
def holzer(x, y, z, a, b, c):
    r"""
    Simplify the solution `(x, y, z)` of the equation
    `ax^2 + by^2 = cz^2` with `a, b, c > 0` and `z^2 \geq \mid ab \mid` to
    a new reduced solution `(x', y', z')` such that `z'^2 \leq \mid ab \mid`.

    The algorithm is an interpretation of Mordell's reduction as described
    on page 8 of Cremona and Rusin's paper [1]_ and the work of Mordell in
    reference [2]_.

    Raises ``ValueError`` if the starting triple is not a solution. In two
    cases the last verified solution is returned, as a tuple of Python
    ints, even though it may not meet the bound above:

    - a reduction step cannot be performed (``base_solution_linear`` gives
      ``None``);
    - the step does not yield another solution.

    References
    ==========

    .. [1] Efficient Solution of Rational Conics, J. E. Cremona and D. Rusin,
           Mathematics of Computation, Volume 00, Number 0.
    .. [2] Diophantine Equations, L. J. Mordell, page 48.

    """

    # divisor used to rescale every reduced triple
    if _odd(c):
        k = 2*c
    else:
        k = c//2

    # Holzer bound: since t3 == t1 + t2 is the largest term (a, b, c > 0),
    # max(t1, t2, t3) <= a*b*c is equivalent to z**2 <= a*b.
    small = a*b*c
    step = 0
    while True:
        t1, t2, t3 = a*x**2, b*y**2, c*z**2
        # check that it's a solution
        if t1 + t2 != t3:
            if step == 0:
                raise ValueError('bad starting solution')
            # a later step failed: fall back to the previous solution
            break
        x_0, y_0, z_0 = x, y, z
        if max(t1, t2, t3) <= small:
            # Holzer condition
            break

        # NOTE(review): presumably (u, v) solves y_0*u - x_0*v == k via
        # base_solution_linear; a None entry means no such solution.
        uv = u, v = base_solution_linear(k, y_0, -x_0)
        if None in uv:
            break

        # choose w close to p/q: nearest integer for even c; for odd c a
        # floor adjusted so that a*u + b*v + c*w is even
        p, q = -(a*u*x_0 + b*v*y_0), c*z_0
        r = Rational(p, q)
        if _even(c):
            w = _nint_or_floor(p, q)
            assert abs(w - r) <= S.Half
        else:
            w = p//q  # floor
            if _odd(a*u + b*v + c*w):
                w += 1
            assert abs(w - r) <= S.One

        # reflect the current solution through (u, v, w) and scale by 1/k;
        # the assert checks that the new triple is integral
        A = (a*u**2 + b*v**2 + c*w**2)
        B = (a*u*x_0 + b*v*y_0 + c*w*z_0)
        x = Rational(x_0*A - 2*u*B, k)
        y = Rational(y_0*A - 2*v*B, k)
        z = Rational(z_0*A - 2*w*B, k)
        assert all(i.is_Integer for i in (x, y, z))
        step += 1

    return tuple([int(i) for i in (x_0, y_0, z_0)])
3435
3436
def diop_general_pythagorean(eq, param=symbols("m", integer=True)):
    """
    Solves the general pythagorean equation,
    `a_{1}^2x_{1}^2 + a_{2}^2x_{2}^2 + . . . + a_{n}^2x_{n}^2 - a_{n + 1}^2x_{n + 1}^2 = 0`.

    Returns a tuple which contains a parametrized solution to the equation,
    sorted in the same order as the input variables.

    Usage
    =====

    ``diop_general_pythagorean(eq, param)``: where ``eq`` is a general
    pythagorean equation which is assumed to be zero and ``param`` is the base
    parameter used to construct other parameters by subscripting. If ``eq``
    is not a general pythagorean equation, None is returned.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import diop_general_pythagorean
    >>> from sympy.abc import a, b, c, d, e
    >>> diop_general_pythagorean(a**2 + b**2 + c**2 - d**2)
    (m1**2 + m2**2 - m3**2, 2*m1*m3, 2*m2*m3, m1**2 + m2**2 + m3**2)
    >>> diop_general_pythagorean(9*a**2 - 4*b**2 + 16*c**2 + 25*d**2 + e**2)
    (10*m1**2  + 10*m2**2  + 10*m3**2 - 10*m4**2, 15*m1**2  + 15*m2**2  + 15*m3**2  + 15*m4**2, 15*m1*m4, 12*m2*m4, 60*m3*m4)
    """
    var, coeff, diop_type = classify_diop(eq, _dict=False)

    if diop_type != GeneralPythagorean.name:
        return None

    params = None
    if param is not None:
        # one subscripted parameter per variable except the last
        params = symbols('%s1:%i' % (param, len(var)), integer=True)
    solutions = GeneralPythagorean(eq).solve(parameters=params)
    return list(solutions)[0]
3470
3471
def diop_general_sum_of_squares(eq, limit=1):
    r"""
    Solves the equation `x_{1}^2 + x_{2}^2 + . . . + x_{n}^2 - k = 0`.

    Returns at most ``limit`` number of solutions, as a set of tuples.
    None is returned if ``eq`` is not of this form.

    Usage
    =====

    ``general_sum_of_squares(eq, limit)`` : Here ``eq`` is an expression which
    is assumed to be zero. Also, ``eq`` should be in the form,
    `x_{1}^2 + x_{2}^2 + . . . + x_{n}^2 - k = 0`.

    Details
    =======

    When `n = 3` if `k = 4^a(8m + 7)` for some `a, m \in Z` then there will be
    no solutions. Refer [1]_ for more details.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import diop_general_sum_of_squares
    >>> from sympy.abc import a, b, c, d, e
    >>> diop_general_sum_of_squares(a**2 + b**2 + c**2 + d**2 + e**2 - 2345)
    {(15, 22, 22, 24, 24)}

    Reference
    =========

    .. [1] Representing an integer as a sum of three squares, [online],
        Available:
        http://www.proofwiki.org/wiki/Integer_as_Sum_of_Three_Squares
    """
    _, _, diop_type = classify_diop(eq, _dict=False)

    if diop_type != GeneralSumOfSquares.name:
        return None
    solver = GeneralSumOfSquares(eq)
    return set(solver.solve(limit=limit))
3510
3511
def diop_general_sum_of_even_powers(eq, limit=1):
    """
    Solves the equation `x_{1}^e + x_{2}^e + . . . + x_{n}^e - k = 0`
    where `e` is an even, integer power.

    Returns at most ``limit`` number of solutions, as a set of tuples.
    None is returned if ``eq`` is not of this form.

    Usage
    =====

    ``general_sum_of_even_powers(eq, limit)`` : Here ``eq`` is an expression which
    is assumed to be zero. Also, ``eq`` should be in the form,
    `x_{1}^e + x_{2}^e + . . . + x_{n}^e - k = 0`.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import diop_general_sum_of_even_powers
    >>> from sympy.abc import a, b
    >>> diop_general_sum_of_even_powers(a**4 + b**4 - (2**4 + 3**4))
    {(2, 3)}

    See Also
    ========

    power_representation
    """
    _, _, diop_type = classify_diop(eq, _dict=False)

    if diop_type != GeneralSumOfEvenPowers.name:
        return None
    solver = GeneralSumOfEvenPowers(eq)
    return set(solver.solve(limit=limit))
3543
3544
3545## Functions below this comment can be more suitably grouped under
3546## an Additive number theory module rather than the Diophantine
3547## equation module.
3548
3549
def partition(n, k=None, zeros=False):
    """
    Returns a generator that can be used to generate partitions of an integer
    `n`.

    Explanation
    ===========

    A partition of `n` is a set of positive integers which add up to `n`. For
    example, partitions of 3 are 3, 1 + 2, 1 + 1 + 1. Each partition is
    yielded as a tuple. If ``k`` is None, partitions of every size are
    generated. Otherwise only partitions with exactly ``k`` parts are
    generated, unless ``zeros`` is True. In that case partitions with
    1, 2, ..., ``k`` parts are generated (in that order), each left-padded
    with zeros to length ``k``.

    ``zeros`` is only considered when ``k`` is not None. Like any generator,
    calling ``next()`` after the last partition raises ``StopIteration``.

    Details
    =======

    ``partition(n, k)``: Here ``n`` is a positive integer and ``k`` is the size
    of the partition which is also positive integer.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import partition
    >>> f = partition(5)
    >>> next(f)
    (1, 1, 1, 1, 1)
    >>> next(f)
    (1, 1, 1, 2)
    >>> g = partition(5, 3)
    >>> next(g)
    (1, 1, 3)
    >>> next(g)
    (1, 2, 2)
    >>> g = partition(5, 3, zeros=True)
    >>> next(g)
    (0, 0, 5)

    """
    from sympy.utilities.iterables import ordered_partitions
    if zeros and k is not None:
        for size in range(1, k + 1):
            for part in ordered_partitions(n, size):
                # copy to a tuple before padding/yielding
                part = tuple(part)
                yield (0,)*(k - len(part)) + part
        return
    for part in ordered_partitions(n, k):
        yield tuple(part)
3605
3606
def prime_as_sum_of_two_squares(p):
    """
    Represent a prime `p` as a unique sum of two squares; this can
    only be done if the prime is congruent to 1 mod 4.

    Returns a tuple `(a, b)` with `a^2 + b^2 = p`, or None when `p` is not
    congruent to 1 mod 4.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import prime_as_sum_of_two_squares
    >>> prime_as_sum_of_two_squares(7)  # can't be done
    >>> prime_as_sum_of_two_squares(5)
    (1, 2)

    Reference
    =========

    .. [1] Representing a number as a sum of four squares, [online],
        Available: http://schorn.ch/lagrange.html

    See Also
    ========
    sum_of_squares()
    """
    if p % 4 != 1:
        return None

    if p % 8 == 5:
        # 2 is a quadratic non-residue of every prime p == 5 (mod 8)
        nonresidue = 2
    else:
        # search the primes, starting at 3, for a quadratic non-residue
        nonresidue = 3
        while pow(nonresidue, (p - 1) // 2, p) == 1:
            nonresidue = nextprime(nonresidue)

    # By Euler's criterion this is a square root of -1 modulo p.
    small = pow(nonresidue, (p - 1) // 4, p)
    big = p

    # Euclid's algorithm on (p, small), stopped once small**2 <= p.
    while small**2 > p:
        big, small = small, big % small

    return (int(big % small), int(small))  # convert from long
3648
3649
def sum_of_three_squares(n):
    r"""
    Returns a 3-tuple `(a, b, c)` such that `a^2 + b^2 + c^2 = n` and
    `a, b, c \geq 0`.

    Returns None if `n = 4^a(8m + 7)` for some `a, m \in Z`. See
    [1]_ for more details.

    Most results are sorted in ascending order. The exception is a
    perfect-square `n` that is not in the internal table of special cases:
    it is returned as `(s, 0, 0)`.

    Usage
    =====

    ``sum_of_three_squares(n)``: Here ``n`` is a non-negative integer.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import sum_of_three_squares
    >>> sum_of_three_squares(44542)
    (18, 37, 207)

    References
    ==========

    .. [1] Representing a number as a sum of three squares, [online],
        Available: http://schorn.ch/lagrange.html

    See Also
    ========

    sum_of_squares()
    """
    # Precomputed representations, used directly instead of the prime
    # search below.
    special = {1:(1, 0, 0), 2:(1, 1, 0), 3:(1, 1, 1), 10: (1, 3, 0), 34: (3, 3, 4), 58:(3, 7, 0),
        85:(6, 7, 0), 130:(3, 11, 0), 214:(3, 6, 13), 226:(8, 9, 9), 370:(8, 9, 15),
        526:(6, 7, 21), 706:(15, 15, 16), 730:(1, 27, 0), 1414:(6, 17, 33), 1906:(13, 21, 36),
        2986: (21, 32, 39), 9634: (56, 57, 57)}

    v = 0

    if n == 0:
        return (0, 0, 0)

    # Remove the factor 4**v; each part of the answer is scaled by 2**v.
    v = multiplicity(4, n)
    n //= 4**v

    if n % 8 == 7:
        return

    if n in special.keys():
        x, y, z = special[n]
        return _sorted_tuple(2**v*x, 2**v*y, 2**v*z)

    s, _exact = integer_nthroot(n, 2)

    if _exact:
        return (2**v*s, 0, 0)

    x = None

    if n % 8 == 3:
        # Search for odd x with N = (n - x**2)/2 prime (then N == 1 mod 4).
        # If N = y**2 + z**2, then n = x**2 + (y + z)**2 + (y - z)**2.
        s = s if _odd(s) else s - 1

        for x in range(s, -1, -2):
            N = (n - x**2) // 2
            if isprime(N):
                y, z = prime_as_sum_of_two_squares(N)
                return _sorted_tuple(2**v*x, 2**v*(y + z), 2**v*abs(y - z))
        return

    # Pick the parity of x so that N = n - x**2 is 1 mod 4 and can be a
    # sum of two squares when prime.
    if n % 8 == 2 or n % 8 == 6:
        s = s if _odd(s) else s - 1
    else:
        s = s - 1 if _odd(s) else s

    for x in range(s, -1, -2):
        N = n - x**2
        if isprime(N):
            y, z = prime_as_sum_of_two_squares(N)
            return _sorted_tuple(2**v*x, 2**v*y, 2**v*z)
3728
3729
def sum_of_four_squares(n):
    r"""
    Returns a 4-tuple `(a, b, c, d)` such that `a^2 + b^2 + c^2 + d^2 = n`.

    Here `a, b, c, d \geq 0` and the tuple is sorted in ascending order.

    Usage
    =====

    ``sum_of_four_squares(n)``: Here ``n`` is a non-negative integer.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import sum_of_four_squares
    >>> sum_of_four_squares(3456)
    (8, 8, 32, 48)
    >>> sum_of_four_squares(1294585930293)
    (0, 1234, 2161, 1137796)

    References
    ==========

    .. [1] Representing a number as a sum of four squares, [online],
        Available: http://schorn.ch/lagrange.html

    See Also
    ========

    sum_of_squares()
    """
    if n == 0:
        return (0, 0, 0, 0)

    # strip the factor 4**v; every part is scaled back by 2**v at the end
    v = multiplicity(4, n)
    n //= 4**v

    # Peel off one square d**2 so that the remainder is a sum of three
    # squares (i.e. not of the form 8m + 7).
    residue = n % 8
    if residue == 7:
        d, n = 2, n - 4
    elif residue in (2, 6):
        d, n = 1, n - 1
    else:
        d = 0

    x, y, z = sum_of_three_squares(n)
    scale = 2**v
    return _sorted_tuple(scale*d, scale*x, scale*y, scale*z)
3779
3780
def power_representation(n, p, k, zeros=False):
    r"""
    Returns a generator for finding k-tuples of integers,
    `(n_{1}, n_{2}, . . . n_{k})`, such that
    `n = n_{1}^p + n_{2}^p + . . . n_{k}^p`.

    Usage
    =====

    ``power_representation(n, p, k, zeros)``: Represent non-negative number
    ``n`` as a sum of ``k`` ``p``\ th powers. If ``zeros`` is true, then the
    solutions are allowed to contain zeros. If ``n`` is negative, solutions
    exist only for odd ``p``; they are the solutions for ``-n`` with every
    entry negated.

    A ``ValueError`` is raised (when the generator is first advanced) if
    ``p`` or ``k`` is not a positive integer.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import power_representation

    Represent 1729 as a sum of two cubes:

    >>> f = power_representation(1729, 3, 2)
    >>> next(f)
    (9, 10)
    >>> next(f)
    (1, 12)

    If the flag `zeros` is True, the solution may contain tuples with
    zeros; any such solutions will be generated after the solutions
    without zeros:

    >>> list(power_representation(125, 2, 3, zeros=True))
    [(5, 6, 8), (3, 4, 10), (0, 5, 10), (0, 2, 11)]

    A perfect ``p``\ th power (including 1) is its own one-term
    representation. By Fermat's Last Theorem, for ``p > 2`` it is a sum of
    two ``p``\ th powers only when a zero is allowed:

    >>> list(power_representation(1, 2, 1))
    [(1,)]
    >>> list(power_representation(8, 3, 2))
    []
    >>> list(power_representation(8, 3, 2, zeros=True))
    [(0, 2)]

    For even `p` the `permute_signs` function can be used to get all
    signed values:

    >>> from sympy.utilities.iterables import permute_signs
    >>> list(permute_signs((1, 12)))
    [(1, 12), (-1, 12), (1, -12), (-1, -12)]

    All possible signed permutations can also be obtained:

    >>> from sympy.utilities.iterables import signed_permutations
    >>> list(signed_permutations((1, 12)))
    [(1, 12), (-1, 12), (1, -12), (-1, -12), (12, 1), (-12, 1), (12, -1), (-12, -1)]
    """
    n, p, k = [as_int(i) for i in (n, p, k)]

    # Validate first so that bad (p, k) are reported for every n,
    # including negative n.
    if p < 1 or k < 1:
        raise ValueError(filldedent('''
    Expecting positive integers for `(p, k)`, but got `(%s, %s)`'''
    % (p, k)))

    if n < 0:
        # (-b)**p == -(b**p) for odd p; an even power is never negative
        if p % 2:
            for t in power_representation(-n, p, k, zeros):
                yield tuple(-i for i in t)
        return

    if n == 0:
        if zeros:
            yield (0,)*k
        return

    if k == 1:
        # n must itself be a perfect p-th power; integer_nthroot answers
        # this for every n >= 1 (including n == 1) and for p == 1
        root, exact = integer_nthroot(n, p)
        if exact:
            yield (root,)
        return

    if p == 1:
        yield from partition(n, k, zeros=zeros)
        return

    if p == 2:
        feasible = _can_do_sum_of_squares(n, k)
        if not feasible:
            return
        if not zeros and n > 33 and k >= 5 and k <= n and n - k in (
                13, 10, 7, 5, 4, 2, 1):
            # Todd G. Will, "When Is n^2 a Sum of k Squares?", [online].
            # Available: https://www.maa.org/sites/default/files/Will-MMz-201037918.pdf
            return
        if feasible is not True:  # it's prime and k == 2
            yield prime_as_sum_of_two_squares(n)
            return

    if k == 2 and p > 2:
        c, exact = integer_nthroot(n, p)
        if exact:
            # Fermat: a**p + b**p == c**p has no solution in positive
            # integers for p > 2, so the only representation is (0, c),
            # which is valid only when zeros are allowed.
            if zeros:
                yield (0, c)
            return

    if n >= k:
        # the other k - 1 terms are each >= 1, which bounds the largest base
        a = integer_nthroot(n - (k - 1), p)[0]
        for t in pow_rep_recursive(a, k, n, [], p):
            yield tuple(reversed(t))

    if zeros:
        a = integer_nthroot(n, p)[0]
        # exactly i nonzero terms, padded with k - i zeros (placed first)
        for i in range(1, k):
            for t in pow_rep_recursive(a, i, n, [], p):
                yield tuple(reversed(t + (0,)*(k - i)))
3890
3891
# ``sum_of_powers`` is an alias: the very same function object as
# ``power_representation`` (same signature, same generator behavior).
sum_of_powers = power_representation
3893
3894
def pow_rep_recursive(n_i, k, n_remaining, terms, p):
    """Generate the ways of writing ``n_remaining`` as a sum of ``k``
    ``p``-th powers of positive integers, each base at most ``n_i``.

    Each solution is yielded as ``tuple(terms)`` extended by the chosen
    bases. The bases are appended in non-increasing order, and solutions
    come out in lexicographically increasing order of the appended bases.

    Candidate bases are scanned with a loop, and the function recurses
    only once per chosen term. The recursion depth is therefore ``k``
    rather than growing with ``n_i``. A depth proportional to ``n_i``
    exhausts Python's recursion limit for ``n_i`` around 1000.

    Parameters
    ==========

    n_i : int
        Upper bound for the next base to choose.
    k : int
        Number of terms still to choose.
    n_remaining : int
        Value that the remaining ``k`` powers must sum to.
    terms : list
        Bases chosen so far; it is not modified.
    p : int
        The exponent. It is assumed positive (callers in this module pass
        ``p >= 2``), so ``b**p`` increases with ``b``.
    """
    if k <= 0:
        # Nothing left to choose: a solution only if nothing is left over.
        if k == 0 and n_remaining == 0:
            yield tuple(terms)
        return

    # Every remaining base lies in [1, n_i], so the remaining sum lies in
    # [k, k*n_i**p]; branches outside that range cannot succeed.
    if n_remaining < k or k*pow(n_i, p) < n_remaining:
        return

    for base in range(1, n_i + 1):
        residual = n_remaining - pow(base, p)
        if residual < 0:
            break  # larger bases only make the residual more negative
        if k == 1:
            # last term: it must consume the remainder exactly
            if residual == 0:
                yield tuple(terms) + (base,)
        else:
            yield from pow_rep_recursive(
                base, k - 1, residual, terms + [base], p)
3905
3906
def sum_of_squares(n, k, zeros=False):
    """Return a generator that yields the k-tuples of nonnegative
    values, the squares of which sum to n. If zeros is False (default)
    then the solution will not contain zeros. The nonnegative
    elements of a tuple are sorted.

    * If k == 1 and n is a perfect square, the single tuple
      (sqrt(n),) is generated.

    * If k == 2 then n can only be written as a sum of squares if
      every prime in the factorization of n that has the form
      4*m + 3 has an even multiplicity. If n is prime then
      it can only be written as a sum of two squares if it is
      of the form 4*m + 1.

    * If k == 3 then n can be written as a sum of squares if and only
      if it does not have the form 4**a*(8*b + 7).

    * All nonnegative integers can be written as the sum of 4 squares.

    * If k > 4 and zeros are allowed, every nonnegative n can be written
      as a sum of k squares, since the extra terms may be 0. If zeros are
      not allowed, some n have no representation. For example, with
      n > 33 and 5 <= k <= n, none exists when n - k is one of
      1, 2, 4, 5, 7, 10 or 13.

    Examples
    ========

    >>> from sympy.solvers.diophantine.diophantine import sum_of_squares
    >>> list(sum_of_squares(25, 1))
    [(5,)]
    >>> list(sum_of_squares(25, 2))
    [(3, 4)]
    >>> list(sum_of_squares(25, 2, True))
    [(3, 4), (0, 5)]
    >>> list(sum_of_squares(25, 4))
    [(1, 2, 2, 4)]

    See Also
    ========

    sympy.utilities.iterables.signed_permutations
    """
    yield from power_representation(n, 2, k, zeros)
3953
3954
3955def _can_do_sum_of_squares(n, k):
3956    """Return True if n can be written as the sum of k squares,
3957    False if it can't, or 1 if ``k == 2`` and ``n`` is prime (in which
3958    case it *can* be written as a sum of two squares). A False
3959    is returned only if it can't be written as ``k``-squares, even
3960    if 0s are allowed.
3961    """
3962    if k < 1:
3963        return False
3964    if n < 0:
3965        return False
3966    if n == 0:
3967        return True
3968    if k == 1:
3969        return is_square(n)
3970    if k == 2:
3971        if n in (1, 2):
3972            return True
3973        if isprime(n):
3974            if n % 4 == 1:
3975                return 1  # signal that it was prime
3976            return False
3977        else:
3978            f = factorint(n)
3979            for p, m in f.items():
3980                # we can proceed iff no prime factor in the form 4*k + 3
3981                # has an odd multiplicity
3982                if (p % 4 == 3) and m % 2:
3983                    return False
3984            return True
3985    if k == 3:
3986        if (n//4**multiplicity(4, n)) % 8 == 7:
3987            return False
3988    # every number can be written as a sum of 4 squares; for k > 4 partitions
3989    # can be 0
3990    return True
3991