1from sympy import (S, Dummy, Lambda, symbols, Interval, Intersection, Set,
2                   EmptySet, FiniteSet, Union, ComplexRegion, Mul)
3from sympy.multipledispatch import dispatch
4from sympy.sets.conditionset import ConditionSet
5from sympy.sets.fancysets import (Integers, Naturals, Reals, Range,
6    ImageSet, Rationals)
7from sympy.sets.sets import UniversalSet, imageset, ProductSet
8from sympy.simplify.radsimp import numer
9
@dispatch(ConditionSet, ConditionSet)  # type: ignore # noqa:F811
def intersection_sets(a, b): # noqa:F811
    """Do not simplify the intersection of two ConditionSets.

    This handler always returns None, so the pair is left as is.
    """
    result = None
    return result
13
@dispatch(ConditionSet, Set)  # type: ignore # noqa:F811
def intersection_sets(a, b): # noqa:F811
    """Intersect a ConditionSet with an arbitrary set.

    The symbol and condition of ``a`` are kept unchanged. Only its base
    set is narrowed by intersecting it with ``b``.
    """
    narrowed_base = Intersection(a.base_set, b)
    return ConditionSet(a.sym, a.condition, narrowed_base)
17
@dispatch(Naturals, Integers)  # type: ignore # noqa:F811
def intersection_sets(a, b): # noqa:F811
    """Intersect a Naturals-type set with the Integers.

    The natural numbers are contained in the integers, so ``a`` itself
    is the intersection.
    """
    smaller = a
    return smaller
21
@dispatch(Naturals, Naturals)  # type: ignore # noqa:F811
def intersection_sets(a, b): # noqa:F811
    """Intersect two Naturals-type sets.

    ``a`` is returned when it is S.Naturals; otherwise ``b`` is returned.
    This presumably works because any other Naturals-type set dispatched
    here (e.g. Naturals0) contains S.Naturals (confirm).
    """
    if a is S.Naturals:
        return a
    return b
25
@dispatch(Interval, Naturals)  # type: ignore # noqa:F811
def intersection_sets(a, b): # noqa:F811
    """Intersect an Interval with a Naturals-type set.

    Intersection is symmetric, so the arguments are swapped to reach the
    (Naturals, Interval) handler.
    """
    naturals, interval = b, a
    return intersection_sets(naturals, interval)
29
@dispatch(ComplexRegion, Set)  # type: ignore # noqa:F811
def intersection_sets(self, other): # noqa:F811
    """Intersect a ComplexRegion with another set.

    Two situations are evaluated:

    * ``other`` is a ComplexRegion in the same form as ``self`` (both
      rectangular or both polar). The defining sets are intersected and a
      new ComplexRegion is returned.
    * ``other`` is a subset of the reals. The parts of ``self`` lying on
      the real axis are collected into a Union, which is intersected with
      ``other``.

    If neither applies, None is returned.
    """
    if other.is_ComplexRegion:
        if not (self.polar or other.polar):
            # both rectangular: intersect the underlying product sets
            return ComplexRegion(Intersection(self.sets, other.sets))

        if self.polar and other.polar:
            r1, theta1 = self.a_interval, self.b_interval
            r2, theta2 = other.a_interval, other.b_interval
            radii = Intersection(r1, r2)
            angles = Intersection(theta1, theta2)

            # an angle of 2*pi and an angle of 0 describe the same ray
            wraps_around = ((2*S.Pi in theta1 and S.Zero in theta2) or
                            (2*S.Pi in theta2 and S.Zero in theta1))
            if wraps_around:
                angles = Union(angles, FiniteSet(0))
            return ComplexRegion(radii*angles, polar=True)

    if not other.is_subset(S.Reals):
        return None

    x = symbols("x", cls=Dummy, real=True)

    if self.polar:
        on_real_axis = []
        for piece in self.psets:
            radius, angle = piece.args[0], piece.args[1]
            if S.Zero in angle:
                # angle 0: the radii themselves lie on the real axis
                on_real_axis.append(radius)
            if S.Pi in angle:
                # angle pi: the negated radii lie on the real axis
                on_real_axis.append(ImageSet(Lambda(x, -x), radius))
            if S.Zero in radius:
                # radius 0 is the origin, whatever the angle
                on_real_axis.append(FiniteSet(0))
    else:
        # rectangular: keep the real part wherever the imaginary part can be 0
        on_real_axis = [piece.args[0] for piece in self.psets
                        if S.Zero in piece.args[1]]

    return Intersection(Union(*on_real_axis), other)
76
@dispatch(Integers, Reals)  # type: ignore # noqa:F811
def intersection_sets(a, b): # noqa:F811
    """Intersect the Integers with the Reals.

    The integers are contained in the reals, so ``a`` is the result.
    """
    smaller = a
    return smaller
80
@dispatch(Range, Interval)  # type: ignore # noqa:F811
def intersection_sets(a, b): # noqa:F811
    """Intersect a Range with an Interval.

    The Interval is clipped to the extent of ``a`` and represented as a
    unit-step Range of the integers it contains. That Range is then
    intersected with ``a`` through the (Range, Range) handler.

    Returns None when ``a`` or the Interval bounds are not numbers, and
    S.EmptySet when ``a`` is an empty Range.
    """
    from sympy.functions.elementary.integers import floor, ceiling

    # With symbolic bounds, `a.size`, max/min and the membership tests
    # below cannot be decided, so leave the intersection unevaluated.
    # (Checking only the Interval bounds let symbolic Ranges through.)
    if not all(i.is_number for i in a.args + b.args[:2]):
        return None

    # In case of null Range, return an EmptySet.
    if a.size == 0:
        return S.EmptySet

    # trim down to a's extent, and represent
    # as a Range with step 1.
    start = ceiling(max(b.inf, a.inf))
    if start not in b:
        # e.g. b is open at this endpoint: step inward
        start += 1
    end = floor(min(b.sup, a.sup))
    if end not in b:
        end -= 1
    return intersection_sets(a, Range(start, end + 1))
100
@dispatch(Range, Naturals)  # type: ignore # noqa:F811
def intersection_sets(a, b): # noqa:F811
    """Intersect a Range with a Naturals-type set.

    The naturals are represented as the half-line ``[b.inf, oo)``. The
    (Range, Interval) handler then selects the integers of ``a`` in it.
    """
    half_line = Interval(b.inf, S.Infinity)
    return intersection_sets(a, half_line)
104
@dispatch(Range, Range)  # type: ignore # noqa:F811
def intersection_sets(a, b): # noqa:F811
    """Intersect two Ranges.

    Returns None when either Range has a non-numeric argument, because the
    ordering and truth tests below cannot be decided for symbolic bounds.
    Returns S.EmptySet when the Ranges share no point.

    If one Range turns out to be all of the integers, the other Range is
    returned unchanged. Otherwise the common points are located by solving
    the linear diophantine equation
    ``r1.start + i*r1.step == r2.start + j*r2.step``. A Range with step
    ``abs(ilcm(r1.step, r2.step))`` is then returned.
    """
    from sympy.solvers.diophantine.diophantine import diop_linear
    from sympy.core.numbers import ilcm
    from sympy import sign

    # Symbolic bounds would make `b.sup < a.inf` etc. raise TypeError on an
    # undecidable relational; leave the intersection unevaluated instead.
    if not all(v.is_number for r in (a, b) for v in r.args):
        return None

    # non-overlap quick exits
    if not b:
        return S.EmptySet
    if not a:
        return S.EmptySet
    if b.sup < a.inf:
        return S.EmptySet
    if b.inf > a.sup:
        return S.EmptySet

    # work with finite end at the start
    r1 = a
    if r1.start.is_infinite:
        r1 = r1.reversed
    r2 = b
    if r2.start.is_infinite:
        r2 = r2.reversed

    # If both ends are infinite then it means that one Range is just the set
    # of all integers (the step must be 1).
    if r1.start.is_infinite:
        return b
    if r2.start.is_infinite:
        return a

    # this equation represents the values of the Range;
    # it's a linear equation
    eq = lambda r, i: r.start + i*r.step

    # we want to know when the two equations might
    # have integer solutions so we use the diophantine
    # solver
    va, vb = diop_linear(eq(r1, Dummy('a')) - eq(r2, Dummy('b')))

    # check for no solution
    no_solution = va is None and vb is None
    if no_solution:
        return S.EmptySet

    # there is a solution
    # -------------------

    # find the coincident point, c
    a0 = va.as_coeff_Add()[0]
    c = eq(r1, a0)

    def _first_finite_point(r, c, step):
        # Find the first point of r (the one nearest r.start) on the
        # lattice c + k*step; return None if that point is not in r.
        if c == r.start:
            return c
        # st is the signed step we need to take to
        # get from c to r.start
        st = sign(r.start - c)*step
        # use Range to calculate the first point:
        # we want to get as close as possible to
        # r.start; the Range will not be null since
        # it will at least contain c
        s = Range(c, r.start + st, st)[-1]
        if s != r.start:
            # if we didn't hit r.start then, if the
            # sign of st didn't match the sign of r.step
            # we are off by one and s is not in r
            if sign(r.step) != sign(st):
                s -= st
        if s not in r:
            return None
        return s

    def _updated_range(r, first, step):
        # Rebuild r with the combined step, anchored at `first` on
        # whichever end of r is finite.
        st = sign(r.step)*step
        if r.start.is_finite:
            rv = Range(first, r.stop, st)
        else:
            rv = Range(r.start, first + st, st)
        return rv

    # calculate the step size of the new Range
    step = abs(ilcm(r1.step, r2.step))
    s1 = _first_finite_point(r1, c, step)
    if s1 is None:
        return S.EmptySet
    s2 = _first_finite_point(r2, c, step)
    if s2 is None:
        return S.EmptySet

    # replace the corresponding start or stop in
    # the original Ranges with these points; the
    # result must have at least one point since
    # we know that s1 and s2 are in the Ranges
    r1 = _updated_range(a, s1, step)
    r2 = _updated_range(b, s2, step)

    # work with them both in the increasing direction
    if sign(r1.step) < 0:
        r1 = r1.reversed
    if sign(r2.step) < 0:
        r2 = r2.reversed

    # return clipped Range with positive step; it
    # can't be empty at this point
    start = max(r1.start, r2.start)
    stop = min(r1.stop, r2.stop)
    return Range(start, stop, step)
216
217
@dispatch(Range, Integers)  # type: ignore # noqa:F811
def intersection_sets(a, b): # noqa:F811
    """Intersect a Range with the Integers.

    A Range only contains integers, so ``a`` is the intersection.
    """
    smaller = a
    return smaller
221
222
@dispatch(ImageSet, Set)  # type: ignore # noqa:F811
def intersection_sets(self, other): # noqa:F811
    """Intersect an ImageSet with another set.

    Only univariate ImageSets whose lambda signature equals its variables
    are handled; otherwise None is returned. Three cases are evaluated:

    * The base set is S.Integers and ``other`` is S.Integers or an ImageSet
      over S.Integers. This is solved as the diophantine equation
      f(n) = g(m); see the comment block below.
    * ``other`` is S.Reals. The parameter values are restricted to those
      where the imaginary part of the expression vanishes, and values that
      make a denominator of the expression zero are excluded.
    * ``other`` is an Interval. The expression is inverted at the
      Interval's endpoints to obtain an interval of parameter values.

    A return value of None means this handler could not simplify the
    intersection.
    """
    from sympy.solvers.diophantine import diophantine

    # Only handle the straight-forward univariate case
    if (len(self.lamda.variables) > 1
            or self.lamda.signature != self.lamda.variables):
        return None
    base_set = self.base_sets[0]

    # Intersection between ImageSets with Integers as base set
    # For {f(n) : n in Integers} & {g(m) : m in Integers} we solve the
    # diophantine equations f(n)=g(m).
    # If the solutions for n are {h(t) : t in Integers} then we return
    # {f(h(t)) : t in integers}.
    # If the solutions for n are {n_1, n_2, ..., n_k} then we return
    # {f(n_i) : 1 <= i <= k}.
    if base_set is S.Integers:
        gm = None
        if isinstance(other, ImageSet) and other.base_sets == (S.Integers,):
            gm = other.lamda.expr
            var = other.lamda.variables[0]
            # Symbol of second ImageSet lambda must be distinct from first
            m = Dummy('m')
            gm = gm.subs(var, m)
        elif other is S.Integers:
            # Integers viewed as the ImageSet {m : m in Integers}
            m = gm = Dummy('m')
        if gm is not None:
            fn = self.lamda.expr
            n = self.lamda.variables[0]
            try:
                solns = list(diophantine(fn - gm, syms=(n, m), permute=True))
            except (TypeError, NotImplementedError):
                # TypeError if equation not polynomial with rational coeff.
                # NotImplementedError if correct format but no solver.
                return
            # 3 cases are possible for solns:
            # - empty set,
            # - one or more parametric (infinite) solutions,
            # - a finite number of (non-parametric) solution couples.
            # Among those, there is one type of solution set that is
            # not helpful here: multiple parametric solutions.
            if len(solns) == 0:
                # NOTE(review): `EmptySet` is the name imported from sympy at
                # the top of the file; presumably the same singleton as
                # S.EmptySet (the EmptySet handler dispatches on
                # type(EmptySet)) -- confirm.
                return EmptySet
            elif any(not isinstance(s, int) and s.free_symbols
                     for tupl in solns for s in tupl):
                if len(solns) == 1:
                    # single parametric family: substitute n = h(t) and
                    # rename the parameter t back to n
                    soln, solm = solns[0]
                    (t,) = soln.free_symbols
                    expr = fn.subs(n, soln.subs(t, n)).expand()
                    return imageset(Lambda(n, expr), S.Integers)
                else:
                    return
            else:
                return FiniteSet(*(fn.subs(n, s[0]) for s in solns))

    if other == S.Reals:
        from sympy.core.function import expand_complex
        from sympy.solvers.solvers import denoms, solve_linear
        from sympy.core.relational import Eq

        def _solution_union(exprs, sym):
            # return a union of linear solutions to i in expr;
            # if i cannot be solved, use a ConditionSet for solution
            sols = []
            for i in exprs:
                x, xis = solve_linear(i, 0, [sym])
                if x == sym:
                    sols.append(FiniteSet(xis))
                else:
                    sols.append(ConditionSet(sym, Eq(i, 0)))
            return Union(*sols)

        f = self.lamda.expr
        n = self.lamda.variables[0]

        # split f into real and imaginary parts treating the parameter as
        # real, then restore the original parameter symbol
        n_ = Dummy(n.name, real=True)
        f_ = f.subs(n, n_)

        re, im = f_.as_real_imag()
        im = expand_complex(im)

        re = re.subs(n_, n)
        im = im.subs(n_, n)
        ifree = im.free_symbols
        lam = Lambda(n, re)
        if im.is_zero:
            # allow re-evaluation
            # of self in this case to make
            # the result canonical
            pass
        elif im.is_zero is False:
            return S.EmptySet
        elif ifree != {n}:
            return None
        else:
            # univariate imaginary part in same variable;
            # use numer instead of as_numer_denom to keep
            # this as fast as possible while still handling
            # simple cases
            base_set &= _solution_union(
                Mul.make_args(numer(im)), n)
        # exclude values that make denominators 0
        base_set -= _solution_union(denoms(f), n)
        return imageset(lam, base_set)

    elif isinstance(other, Interval):
        from sympy.solvers.solveset import (invert_real, invert_complex,
                                            solveset)

        f = self.lamda.expr
        n = self.lamda.variables[0]
        new_inf, new_sup = None, None
        new_lopen, new_ropen = other.left_open, other.right_open

        # use the real inverter only when f is known to be real
        if f.is_real:
            inverter = invert_real
        else:
            inverter = invert_complex

        g1, h1 = inverter(f, other.inf, n)
        g2, h2 = inverter(f, other.sup, n)

        if all(isinstance(i, FiniteSet) for i in (h1, h2)):
            # accept an endpoint only when it inverts to a single value of n
            if g1 == n:
                if len(h1) == 1:
                    new_inf = h1.args[0]
            if g2 == n:
                if len(h2) == 1:
                    new_sup = h2.args[0]
            # TODO: Design a technique to handle multiple-inverse
            # functions

            # Any of the new boundary values cannot be determined
            if any(i is None for i in (new_sup, new_inf)):
                return


            range_set = S.EmptySet

            if all(i.is_real for i in (new_sup, new_inf)):
                # this assumes continuity of underlying function
                # however fixes the case when it is decreasing
                if new_inf > new_sup:
                    new_inf, new_sup = new_sup, new_inf
                new_interval = Interval(new_inf, new_sup, new_lopen, new_ropen)
                range_set = base_set.intersect(new_interval)
            else:
                # NOTE(review): if other.is_subset(S.Reals) is not True,
                # range_set stays S.EmptySet and S.EmptySet is returned
                # below; returning None may have been intended -- confirm.
                if other.is_subset(S.Reals):
                    # NOTE(review): solveset(f, n, ...) solves f = 0 for n;
                    # intersecting those n-values with `other` (a set of
                    # f-values) looks questionable -- confirm intent.
                    solutions = solveset(f, n, S.Reals)
                    # NOTE(review): range_set is still S.EmptySet here, so
                    # this isinstance check is always False and the
                    # else/return is unreachable; `solutions` was probably
                    # meant -- confirm before changing.
                    if not isinstance(range_set, (ImageSet, ConditionSet)):
                        range_set = solutions.intersect(other)
                    else:
                        return

            if range_set is S.EmptySet:
                return S.EmptySet
            elif isinstance(range_set, Range) and range_set.size is not S.Infinity:
                # expand a finite Range into explicit elements
                range_set = FiniteSet(*list(range_set))

            if range_set is not None:
                return imageset(Lambda(n, f), range_set)
            return
        else:
            return
388
389
@dispatch(ProductSet, ProductSet)  # type: ignore # noqa:F811
def intersection_sets(a, b): # noqa:F811
    """Intersect two ProductSets factor by factor.

    ProductSets whose ``args`` differ in length are treated as disjoint
    and give S.EmptySet. Otherwise the i-th factors are intersected
    pairwise and a new ProductSet is built from the results.
    """
    if len(a.args) == len(b.args):
        factors = [left.intersect(right) for left, right in zip(a.sets, b.sets)]
        return ProductSet(*factors)
    return S.EmptySet
395
396
@dispatch(Interval, Interval)  # type: ignore # noqa:F811
def intersection_sets(a, b): # noqa:F811
    """Intersect two Intervals.

    If ``a`` is the whole line (-oo, oo), ``b`` is returned. Otherwise all
    endpoints must be comparable, or None is returned. The result starts
    at the larger start and ends at the smaller end, and inherits openness
    from the interval that supplied each endpoint. An endpoint shared by
    both intervals is open if either interval is open there. Disjoint
    intervals, and a single point that is open on either side, give
    S.EmptySet.
    """
    # The original also tested a's endpoints here, but for (-oo, oo) they
    # are always -oo/oo, so that test always passed.
    if a == Interval(S.NegativeInfinity, S.Infinity):
        return b

    # We can't intersect [0,3] with [x,6] -- we don't know if x>0 or x<0
    if not a._is_comparable(b):
        return None

    # no overlap at all
    if not (a.start <= b.end and b.start <= a.end):
        return S.EmptySet

    # lower endpoint: the larger start wins
    if a.start < b.start:
        start, left_open = b.start, b.left_open
    elif a.start > b.start:
        start, left_open = a.start, a.left_open
    else:
        start, left_open = a.start, a.left_open or b.left_open

    # upper endpoint: the smaller end wins
    if a.end < b.end:
        end, right_open = a.end, a.right_open
    elif a.end > b.end:
        end, right_open = b.end, b.right_open
    else:
        end, right_open = a.end, a.right_open or b.right_open

    # a degenerate single point that is open on either side is empty
    if end - start == 0 and (left_open or right_open):
        return S.EmptySet

    return Interval(start, end, left_open, right_open)
443
@dispatch(type(EmptySet), Set)  # type: ignore # noqa:F811
def intersection_sets(a, b): # noqa:F811
    """Intersecting the empty set with anything gives the empty set."""
    empty = S.EmptySet
    return empty
447
@dispatch(UniversalSet, Set)  # type: ignore # noqa:F811
def intersection_sets(a, b): # noqa:F811
    """Intersecting the universal set with ``b`` gives ``b`` itself."""
    other = b
    return other
451
@dispatch(FiniteSet, FiniteSet)  # type: ignore # noqa:F811
def intersection_sets(a, b): # noqa:F811
    """Intersect two FiniteSets.

    The common elements are taken from the ``_elements`` collections of
    both sets and wrapped in a new FiniteSet.
    """
    shared = a._elements & b._elements
    return FiniteSet(*shared)
455
@dispatch(FiniteSet, Set)  # type: ignore # noqa:F811
def intersection_sets(a, b): # noqa:F811
    """Keep the elements of the FiniteSet ``a`` that belong to ``b``.

    Returns None if a membership test raises TypeError, for example
    because ``b`` has symbolic ranges.
    """
    try:
        members = [el for el in a if el in b]
        return FiniteSet(*members)
    except TypeError:
        # could not evaluate `el in b` due to symbolic ranges.
        return None
462
@dispatch(Set, Set)  # type: ignore # noqa:F811
def intersection_sets(a, b): # noqa:F811
    """Fallback for the most general (Set, Set) signature.

    Always returns None, so no simplification is made.
    """
    result = None
    return result
466
@dispatch(Integers, Rationals)  # type: ignore # noqa:F811
def intersection_sets(a, b): # noqa:F811
    """Every integer is rational, so the intersection is ``a``."""
    smaller = a
    return smaller
470
@dispatch(Naturals, Rationals)  # type: ignore # noqa:F811
def intersection_sets(a, b): # noqa:F811
    """Every natural number is rational, so the intersection is ``a``."""
    smaller = a
    return smaller
474
@dispatch(Rationals, Reals)  # type: ignore # noqa:F811
def intersection_sets(a, b): # noqa:F811
    """Every rational is real, so the intersection is ``a``."""
    smaller = a
    return smaller
478
def _intlike_interval(a, b):
    """Intersect an integer-like set ``a`` with the Interval ``b``.

    Used by the (Integers, Interval) and (Naturals, Interval) handlers.
    If ``b`` is the whole line, ``a`` is returned. Otherwise the integers
    of ``b`` at or above ``a.inf`` are represented as a unit-step Range,
    and the (Range, Interval) handler removes any open endpoints.

    Returns None when the bounds cannot be handled. This covers Range
    rejecting them (ValueError) and symbolic bounds whose comparison in
    ``max`` cannot be decided (TypeError).
    """
    from sympy.functions.elementary.integers import floor, ceiling
    try:
        if b._inf is S.NegativeInfinity and b._sup is S.Infinity:
            return a
        s = Range(max(a.inf, ceiling(b.left)), floor(b.right) + 1)
        return intersection_sets(s, b)  # take out endpoints if open interval
    except (TypeError, ValueError):
        # TypeError previously escaped for symbolic bounds (e.g. an
        # Interval(x, y) with unknown x); treat it like the FiniteSet/Set
        # handler does and leave the intersection unevaluated.
        return None
488
@dispatch(Integers, Interval)  # type: ignore # noqa:F811
def intersection_sets(a, b): # noqa:F811
    """Intersect the Integers with an Interval using ``_intlike_interval``."""
    result = _intlike_interval(a, b)
    return result
492
@dispatch(Naturals, Interval)  # type: ignore # noqa:F811
def intersection_sets(a, b): # noqa:F811
    """Intersect a Naturals-type set with an Interval using ``_intlike_interval``."""
    result = _intlike_interval(a, b)
    return result
496