from sympy import (S, Dummy, Lambda, symbols, Interval, Intersection, Set,
                   EmptySet, FiniteSet, Union, ComplexRegion, Mul)
from sympy.multipledispatch import dispatch
from sympy.sets.conditionset import ConditionSet
from sympy.sets.fancysets import (Integers, Naturals, Reals, Range,
                                  ImageSet, Rationals)
from sympy.sets.sets import UniversalSet, imageset, ProductSet
from sympy.simplify.radsimp import numer


# Every ``intersection_sets`` overload below is registered with
# ``multipledispatch`` for one pair of set types.  A handler returns the
# evaluated intersection, or ``None`` when it cannot (or chooses not to)
# evaluate it -- presumably the caller then leaves the intersection
# unevaluated.


@dispatch(ConditionSet, ConditionSet)  # type: ignore # noqa:F811
def intersection_sets(a, b):  # noqa:F811
    """Leave the intersection of two ConditionSets unevaluated."""
    return None


@dispatch(ConditionSet, Set)  # type: ignore # noqa:F811
def intersection_sets(a, b):  # noqa:F811
    """Push ``b`` into the base set of the ConditionSet ``a``."""
    return ConditionSet(a.sym, a.condition, Intersection(a.base_set, b))


@dispatch(Naturals, Integers)  # type: ignore # noqa:F811
def intersection_sets(a, b):  # noqa:F811
    """Return ``a``, which lies inside the integers."""
    return a


@dispatch(Naturals, Naturals)  # type: ignore # noqa:F811
def intersection_sets(a, b):  # noqa:F811
    """Return ``S.Naturals`` if either argument is it, otherwise ``b``."""
    return a if a is S.Naturals else b


@dispatch(Interval, Naturals)  # type: ignore # noqa:F811
def intersection_sets(a, b):  # noqa:F811
    """Delegate to the (Naturals, Interval) overload."""
    return intersection_sets(b, a)


@dispatch(ComplexRegion, Set)  # type: ignore # noqa:F811
def intersection_sets(self, other):  # noqa:F811
    """Intersect a ComplexRegion with another region or with a real set.

    Two regions given in the same form (both rectangular or both polar)
    are intersected component-wise.  If ``other`` is a subset of the
    reals, the real points of ``self`` are collected and intersected with
    ``other``.  Every other combination returns ``None``.
    """
    if other.is_ComplexRegion:
        # both in rectangular form
        if (not self.polar) and (not other.polar):
            return ComplexRegion(Intersection(self.sets, other.sets))

        # both in polar form
        elif self.polar and other.polar:
            r1, theta1 = self.a_interval, self.b_interval
            r2, theta2 = other.a_interval, other.b_interval
            new_r_interval = Intersection(r1, r2)
            new_theta_interval = Intersection(theta1, theta2)

            # 0 and 2*Pi mean the same angle
            if ((2*S.Pi in theta1 and S.Zero in theta2) or
                    (2*S.Pi in theta2 and S.Zero in theta1)):
                new_theta_interval = Union(new_theta_interval,
                                           FiniteSet(0))
            return ComplexRegion(new_r_interval*new_theta_interval,
                                 polar=True)

    if other.is_subset(S.Reals):
        new_interval = []
        x = symbols("x", cls=Dummy, real=True)

        # self in rectangular form: keep the first component of every
        # product set whose second component contains 0 (presumably the
        # real-part interval where the imaginary part may be 0)
        if not self.polar:
            for element in self.psets:
                if S.Zero in element.args[1]:
                    new_interval.append(element.args[0])
            new_interval = Union(*new_interval)
            return Intersection(new_interval, other)

        # self in polar form: an angle of 0 contributes the radii, an
        # angle of pi contributes their negatives, and a radius of 0
        # contributes the origin
        elif self.polar:
            for element in self.psets:
                if S.Zero in element.args[1]:
                    new_interval.append(element.args[0])
                if S.Pi in element.args[1]:
                    new_interval.append(ImageSet(Lambda(x, -x), element.args[0]))
                if S.Zero in element.args[0]:
                    new_interval.append(FiniteSet(0))
            new_interval = Union(*new_interval)
            return Intersection(new_interval, other)


@dispatch(Integers, Reals)  # type: ignore # noqa:F811
def intersection_sets(a, b):  # noqa:F811
    """Return ``a``, which lies inside the reals."""
    return a


@dispatch(Range, Interval)  # type: ignore # noqa:F811
def intersection_sets(a, b):  # noqa:F811
    """Intersect a Range with an Interval whose endpoints are numbers.

    The Interval is converted to the step-1 Range of integers it contains
    (open endpoints are excluded), and that Range is intersected with
    ``a`` by the (Range, Range) overload.  Returns ``None`` if either
    endpoint of ``b`` is not a number.
    """
    from sympy.functions.elementary.integers import floor, ceiling
    if not all(i.is_number for i in b.args[:2]):
        return

    # In case of null Range, return an EmptySet.
    if a.size == 0:
        return S.EmptySet

    # trim down to self's size, and represent
    # as a Range with step 1.
    start = ceiling(max(b.inf, a.inf))
    if start not in b:
        # an open integer endpoint is not part of the interval
        start += 1
    end = floor(min(b.sup, a.sup))
    if end not in b:
        end -= 1
    return intersection_sets(a, Range(start, end + 1))


@dispatch(Range, Naturals)  # type: ignore # noqa:F811
def intersection_sets(a, b):  # noqa:F811
    """Treat ``b`` as ``Interval(b.inf, oo)`` and reuse (Range, Interval)."""
    return intersection_sets(a, Interval(b.inf, S.Infinity))


@dispatch(Range, Range)  # type: ignore # noqa:F811
def intersection_sets(a, b):  # noqa:F811
    """Intersect two Ranges.

    A common point is found by solving the linear Diophantine equation
    ``r1.start + i*r1.step == r2.start + j*r2.step``.  The result is a
    Range, with step ``|lcm(r1.step, r2.step)|``, clipped to the overlap
    of both Ranges.
    """
    from sympy.solvers.diophantine.diophantine import diop_linear
    from sympy.core.numbers import ilcm
    from sympy import sign

    # non-overlap quick exits
    if not b:
        return S.EmptySet
    if not a:
        return S.EmptySet
    if b.sup < a.inf:
        return S.EmptySet
    if b.inf > a.sup:
        return S.EmptySet

    # work with finite end at the start
    r1 = a
    if r1.start.is_infinite:
        r1 = r1.reversed
    r2 = b
    if r2.start.is_infinite:
        r2 = r2.reversed

    # If both ends are infinite then it means that one Range is just the set
    # of all integers (the step must be 1).
    if r1.start.is_infinite:
        return b
    if r2.start.is_infinite:
        return a

    # this equation represents the values of the Range;
    # it's a linear equation
    eq = lambda r, i: r.start + i*r.step

    # we want to know when the two equations might
    # have integer solutions so we use the diophantine
    # solver
    va, vb = diop_linear(eq(r1, Dummy('a')) - eq(r2, Dummy('b')))

    # check for no solution
    no_solution = va is None and vb is None
    if no_solution:
        return S.EmptySet

    # there is a solution
    # -------------------

    # find the coincident point, c
    a0 = va.as_coeff_Add()[0]
    c = eq(r1, a0)

    # find the first point, if possible, in each range
    # since c may not be that point
    def _first_finite_point(r1, c):
        # Return the member of the step-``step`` lattice through ``c``
        # that is closest to ``r1.start`` and lies in ``r1``, or None.
        if c == r1.start:
            return c
        # st is the signed step we need to take to
        # get from c to r1.start
        st = sign(r1.start - c)*step
        # use Range to calculate the first point:
        # we want to get as close as possible to
        # r1.start; the Range will not be null since
        # it will at least contain c
        s1 = Range(c, r1.start + st, st)[-1]
        if s1 == r1.start:
            pass
        else:
            # if we didn't hit r1.start then, if the
            # sign of st didn't match the sign of r1.step
            # we are off by one and s1 is not in r1
            if sign(r1.step) != sign(st):
                s1 -= st
        if s1 not in r1:
            return
        return s1

    # calculate the step size of the new Range
    # (read by _first_finite_point and _updated_range as a closure)
    step = abs(ilcm(r1.step, r2.step))
    s1 = _first_finite_point(r1, c)
    if s1 is None:
        return S.EmptySet
    s2 = _first_finite_point(r2, c)
    if s2 is None:
        return S.EmptySet

    # replace the corresponding start or stop in
    # the original Ranges with these points; the
    # result must have at least one point since
    # we know that s1 and s2 are in the Ranges
    def _updated_range(r, first):
        st = sign(r.step)*step
        if r.start.is_finite:
            rv = Range(first, r.stop, st)
        else:
            rv = Range(r.start, first + st, st)
        return rv
    r1 = _updated_range(a, s1)
    r2 = _updated_range(b, s2)

    # work with them both in the increasing direction
    if sign(r1.step) < 0:
        r1 = r1.reversed
    if sign(r2.step) < 0:
        r2 = r2.reversed

    # return clipped Range with positive step; it
    # can't be empty at this point
    start = max(r1.start, r2.start)
    stop = min(r1.stop, r2.stop)
    return Range(start, stop, step)


@dispatch(Range, Integers)  # type: ignore # noqa:F811
def intersection_sets(a, b):  # noqa:F811
    """Return ``a``, which lies inside the integers."""
    return a


@dispatch(ImageSet, Set)  # type: ignore # noqa:F811
def intersection_sets(self, other):  # noqa:F811
    """Intersect a univariate ImageSet ``{f(n) : n in base}`` with ``other``.

    Handled cases:

    * base is Integers and ``other`` is Integers or an Integers-based
      ImageSet: solve ``f(n) == g(m)`` as a Diophantine equation;
    * ``other`` is the Reals: keep the ``n`` for which the imaginary part
      of ``f`` vanishes and drop zeros of the denominators of ``f``;
    * ``other`` is an Interval: map its endpoints back through ``f`` and
      intersect the base set with the resulting interval.

    Returns ``None`` for anything else, including multivariate lambdas
    and cases where the needed solver is unavailable or inconclusive.
    """
    from sympy.solvers.diophantine import diophantine

    # Only handle the straight-forward univariate case
    if (len(self.lamda.variables) > 1
            or self.lamda.signature != self.lamda.variables):
        return None
    base_set = self.base_sets[0]

    # Intersection between ImageSets with Integers as base set
    # For {f(n) : n in Integers} & {g(m) : m in Integers} we solve the
    # diophantine equations f(n)=g(m).
    # If the solutions for n are {h(t) : t in Integers} then we return
    # {f(h(t)) : t in integers}.
    # If the solutions for n are {n_1, n_2, ..., n_k} then we return
    # {f(n_i) : 1 <= i <= k}.
    if base_set is S.Integers:
        gm = None
        if isinstance(other, ImageSet) and other.base_sets == (S.Integers,):
            gm = other.lamda.expr
            var = other.lamda.variables[0]
            # Symbol of second ImageSet lambda must be distinct from first
            m = Dummy('m')
            gm = gm.subs(var, m)
        elif other is S.Integers:
            m = gm = Dummy('m')
        if gm is not None:
            fn = self.lamda.expr
            n = self.lamda.variables[0]
            try:
                solns = list(diophantine(fn - gm, syms=(n, m), permute=True))
            except (TypeError, NotImplementedError):
                # TypeError if equation not polynomial with rational coeff.
                # NotImplementedError if correct format but no solver.
                return
            # 3 cases are possible for solns:
            # - empty set,
            # - one or more parametric (infinite) solutions,
            # - a finite number of (non-parametric) solution couples.
            # Among those, there is one type of solution set that is
            # not helpful here: multiple parametric solutions.
            if len(solns) == 0:
                return S.EmptySet
            elif any(not isinstance(s, int) and s.free_symbols
                     for tupl in solns for s in tupl):
                if len(solns) == 1:
                    soln, solm = solns[0]
                    (t,) = soln.free_symbols
                    expr = fn.subs(n, soln.subs(t, n)).expand()
                    return imageset(Lambda(n, expr), S.Integers)
                else:
                    return
            else:
                return FiniteSet(*(fn.subs(n, s[0]) for s in solns))

    if other == S.Reals:
        from sympy.core.function import expand_complex
        from sympy.solvers.solvers import denoms, solve_linear
        from sympy.core.relational import Eq

        def _solution_union(exprs, sym):
            # return a union of linear solutions to i in expr;
            # if i cannot be solved, use a ConditionSet for solution
            sols = []
            for i in exprs:
                x, xis = solve_linear(i, 0, [sym])
                if x == sym:
                    sols.append(FiniteSet(xis))
                else:
                    sols.append(ConditionSet(sym, Eq(i, 0)))
            return Union(*sols)

        f = self.lamda.expr
        n = self.lamda.variables[0]

        # split f into real and imaginary parts assuming a real argument
        n_ = Dummy(n.name, real=True)
        f_ = f.subs(n, n_)

        re, im = f_.as_real_imag()
        im = expand_complex(im)

        re = re.subs(n_, n)
        im = im.subs(n_, n)
        ifree = im.free_symbols
        lam = Lambda(n, re)
        if im.is_zero:
            # allow re-evaluation
            # of self in this case to make
            # the result canonical
            pass
        elif im.is_zero is False:
            return S.EmptySet
        elif ifree != {n}:
            return None
        else:
            # univariate imaginary part in same variable;
            # use numer instead of as_numer_denom to keep
            # this as fast as possible while still handling
            # simple cases
            base_set &= _solution_union(
                Mul.make_args(numer(im)), n)
        # exclude values that make denominators 0
        base_set -= _solution_union(denoms(f), n)
        return imageset(lam, base_set)

    elif isinstance(other, Interval):
        from sympy.solvers.solveset import (invert_real, invert_complex,
                                            solveset)

        f = self.lamda.expr
        n = self.lamda.variables[0]
        new_inf, new_sup = None, None
        new_lopen, new_ropen = other.left_open, other.right_open

        if f.is_real:
            inverter = invert_real
        else:
            inverter = invert_complex

        g1, h1 = inverter(f, other.inf, n)
        g2, h2 = inverter(f, other.sup, n)

        if all(isinstance(i, FiniteSet) for i in (h1, h2)):
            if g1 == n:
                if len(h1) == 1:
                    new_inf = h1.args[0]
            if g2 == n:
                if len(h2) == 1:
                    new_sup = h2.args[0]
            # TODO: Design a technique to handle multiple-inverse
            # functions

            # Any of the new boundary values cannot be determined
            if any(i is None for i in (new_sup, new_inf)):
                return

            range_set = S.EmptySet

            if all(i.is_real for i in (new_sup, new_inf)):
                # this assumes continuity of underlying function
                # however fixes the case when it is decreasing
                try:
                    decreasing = bool(new_inf > new_sup)
                except TypeError:
                    # the preimage bounds cannot be ordered (e.g. they
                    # are symbolic); leave the intersection unevaluated
                    return
                if decreasing:
                    # a decreasing f maps other.inf to the *larger*
                    # preimage, so each open/closed flag must travel
                    # with its endpoint
                    new_inf, new_sup = new_sup, new_inf
                    new_lopen, new_ropen = new_ropen, new_lopen
                new_interval = Interval(new_inf, new_sup, new_lopen, new_ropen)
                range_set = base_set.intersect(new_interval)
            else:
                if other.is_subset(S.Reals):
                    solutions = solveset(f, n, S.Reals)
                    # NOTE(review): range_set is always S.EmptySet here, so
                    # this test always passes and the ``else`` below is
                    # unreachable; it may have been meant to test
                    # ``solutions`` -- confirm before changing.
                    if not isinstance(range_set, (ImageSet, ConditionSet)):
                        range_set = solutions.intersect(other)
                    else:
                        return

            if range_set is S.EmptySet:
                return S.EmptySet
            elif isinstance(range_set, Range) and range_set.size is not S.Infinity:
                range_set = FiniteSet(*list(range_set))

            if range_set is not None:
                return imageset(Lambda(n, f), range_set)
            return
        else:
            return


@dispatch(ProductSet, ProductSet)  # type: ignore # noqa:F811
def intersection_sets(a, b):  # noqa:F811
    """Intersect two ProductSets component-wise.

    ProductSets with different numbers of components do not intersect.
    """
    if len(b.args) != len(a.args):
        return S.EmptySet
    return ProductSet(*(i.intersect(j) for i, j in zip(a.sets, b.sets)))


@dispatch(Interval, Interval)  # type: ignore # noqa:F811
def intersection_sets(a, b):  # noqa:F811
    """Intersect two Intervals with comparable endpoints.

    Returns ``b`` when ``a`` is the whole line, ``None`` when the
    endpoints cannot be compared, otherwise the overlapping Interval
    (with the correct open/closed ends) or ``S.EmptySet``.
    """
    # handle (-oo, oo)
    infty = S.NegativeInfinity, S.Infinity
    if a == Interval(*infty):
        # NOTE(review): l and r are a's endpoints, i.e. -oo and oo, so this
        # condition is always true; perhaps b's endpoints were intended.
        l, r = a.left, a.right
        if l.is_real or l in infty or r.is_real or r in infty:
            return b

    # We can't intersect [0,3] with [x,6] -- we don't know if x>0 or x<0
    if not a._is_comparable(b):
        return None

    empty = False

    if a.start <= b.end and b.start <= a.end:
        # Get topology right.
        if a.start < b.start:
            start = b.start
            left_open = b.left_open
        elif a.start > b.start:
            start = a.start
            left_open = a.left_open
        else:
            start = a.start
            left_open = a.left_open or b.left_open

        if a.end < b.end:
            end = a.end
            right_open = a.right_open
        elif a.end > b.end:
            end = b.end
            right_open = b.right_open
        else:
            end = a.end
            right_open = a.right_open or b.right_open

        # a single point with an open side is empty
        if end - start == 0 and (left_open or right_open):
            empty = True
    else:
        empty = True

    if empty:
        return S.EmptySet

    return Interval(start, end, left_open, right_open)


@dispatch(type(EmptySet), Set)  # type: ignore # noqa:F811
def intersection_sets(a, b):  # noqa:F811
    """The empty set intersected with anything is empty."""
    return S.EmptySet


@dispatch(UniversalSet, Set)  # type: ignore # noqa:F811
def intersection_sets(a, b):  # noqa:F811
    """The universal set intersected with ``b`` is ``b``."""
    return b


@dispatch(FiniteSet, FiniteSet)  # type: ignore # noqa:F811
def intersection_sets(a, b):  # noqa:F811
    """Intersect the ``_elements`` collections of two FiniteSets."""
    return FiniteSet(*(a._elements & b._elements))


@dispatch(FiniteSet, Set)  # type: ignore # noqa:F811
def intersection_sets(a, b):  # noqa:F811
    """Keep the elements of ``a`` that are members of ``b``.

    Returns ``None`` if membership raises ``TypeError``.
    """
    try:
        return FiniteSet(*[el for el in a if el in b])
    except TypeError:
        return None  # could not evaluate `el in b` due to symbolic ranges.
@dispatch(Set, Set)  # type: ignore # noqa:F811
def intersection_sets(a, b):  # noqa:F811
    """Fallback when no more specific overload matches: do not evaluate."""
    return None


@dispatch(Integers, Rationals)  # type: ignore # noqa:F811
def intersection_sets(a, b):  # noqa:F811
    """Return ``a``, which lies inside the rationals."""
    return a


@dispatch(Naturals, Rationals)  # type: ignore # noqa:F811
def intersection_sets(a, b):  # noqa:F811
    """Return ``a``, which lies inside the rationals."""
    return a


@dispatch(Rationals, Reals)  # type: ignore # noqa:F811
def intersection_sets(a, b):  # noqa:F811
    """Return ``a``, which lies inside the reals."""
    return a


def _intlike_interval(a, b):
    """Intersect an integer-like set ``a`` with the Interval ``b``.

    ``a`` is an ``Integers`` or ``Naturals`` instance (see the two
    overloads below).  If ``b`` is the whole real line, ``a`` is returned
    unchanged.  Otherwise the integers from ``max(a.inf, ceiling(b.left))``
    to ``floor(b.right)`` are built as a Range and passed to the
    (Range, Interval) overload, which also drops open endpoints of ``b``.

    Returns ``None`` (leave unevaluated) when that Range cannot be built
    or its bounds cannot be compared, e.g. for symbolic endpoints.
    """
    from sympy.functions.elementary.integers import floor, ceiling
    if b._inf is S.NegativeInfinity and b._sup is S.Infinity:
        return a
    try:
        s = Range(max(a.inf, ceiling(b.left)), floor(b.right) + 1)
        # take out endpoints if open interval
        return intersection_sets(s, b)
    except (TypeError, ValueError):
        # TypeError: ``max`` could not decide how symbolic bounds compare
        # (the truth value of a Relational is undetermined);
        # ValueError: the Range constructor rejected its arguments.
        return None


@dispatch(Integers, Interval)  # type: ignore # noqa:F811
def intersection_sets(a, b):  # noqa:F811
    """Integers inside an Interval, via ``_intlike_interval``."""
    return _intlike_interval(a, b)


@dispatch(Naturals, Interval)  # type: ignore # noqa:F811
def intersection_sets(a, b):  # noqa:F811
    """Naturals inside an Interval, via ``_intlike_interval``."""
    return _intlike_interval(a, b)