1"""Special sets."""
2
3from ..core import Basic, Expr, Integer, Lambda, Rational, S, oo
4from ..core.compatibility import as_int
5from ..core.singleton import Singleton
6from ..core.sympify import converter, sympify
7from ..logic import false, true
8from ..utilities.iterables import cantor_product
9from .sets import EmptySet, FiniteSet, Intersection, Interval, Set
10
11
class Naturals(Set, metaclass=Singleton):
    """The positive integers 1, 2, 3, ...

    These are the counting numbers; zero is not a member.  The set is a
    singleton and can be reached as S.Naturals.

    Examples
    ========

    >>> 5 in S.Naturals
    True
    >>> iterable = iter(S.Naturals)
    >>> next(iterable)
    1
    >>> next(iterable)
    2
    >>> next(iterable)
    3
    >>> S.Naturals.intersection(Interval(0, 10))
    Range(1, 11, 1)

    See Also
    ========

    Naturals0 : non-negative integers
    Integers : also includes negative integers

    """

    is_iterable = True
    inf = Integer(1)
    sup = oo

    def _intersection(self, other):
        # Only intervals are simplified here; returning None hands every
        # other kind of set back to the generic intersection machinery.
        if not other.is_Interval:
            return None
        # The integers, cut down to ``other`` and to the ray [inf, oo).
        right_ray = Interval(self.inf, oo, False, True)
        return Intersection(S.Integers, other, right_ray)

    def _contains(self, other):
        # Anything that is not an expression can never be a member.
        if not isinstance(other, Expr):
            return false
        # Known non-integers and known non-positives are rejected first.
        if other.is_integer is False or other.is_positive is False:
            return false
        if other.is_positive and other.is_integer:
            return true
        # Assumptions are inconclusive: leave membership undecided.
        return None

    def __iter__(self):
        # Count upward from the smallest member, without end.
        value = self.inf
        while True:
            yield value
            value += 1

    @property
    def boundary(self):
        # The set is reported as its own boundary.
        return self
68
69
class Naturals0(Naturals):
    """The non-negative integers 0, 1, 2, ...

    Identical to Naturals except that zero is a member.  Iteration and
    interval intersection are inherited from Naturals and both start from
    the lower bound ``inf``, which is overridden to 0 here.

    See Also
    ========

    Naturals : positive integers
    Integers : also includes the negative integers

    """

    inf = Integer(0)

    def _contains(self, other):
        # Anything that is not an expression can never be a member.
        if not isinstance(other, Expr):
            return false
        # Known non-integers and known negatives are rejected first.
        if other.is_integer is False or other.is_nonnegative is False:
            return false
        if other.is_integer and other.is_nonnegative:
            return true
        # Assumptions are inconclusive: leave membership undecided.
        return None
93
94
class Integers(Set, metaclass=Singleton):
    """The set of all integers.

    Represents all integers: positive, negative and zero. This set
    is also available as the Singleton, S.Integers.

    Examples
    ========

    >>> 5 in S.Integers
    True
    >>> iterable = iter(S.Integers)
    >>> next(iterable)
    0
    >>> next(iterable)
    1
    >>> next(iterable)
    -1
    >>> next(iterable)
    2

    >>> S.Integers.intersection(Interval(-4, 4))
    Range(-4, 5, 1)

    See Also
    ========

    Naturals0 : non-negative integers
    Naturals : positive integers

    """

    is_iterable = True

    def _intersection(self, other):
        """Intersect with an interval.

        Returns ``self`` for the whole real line, a ``Range`` (with open
        endpoints removed) for any other interval, and ``None`` otherwise.
        """
        from ..functions import ceiling, floor
        # Compare by value: an ``is`` test against a freshly constructed
        # Interval only succeeds if construction returns a cached object.
        if other == S.Reals:
            return self
        elif other.is_Interval:
            s = Range(ceiling(other.left), floor(other.right) + 1)
            return s.intersection(other)  # take out endpoints if open interval

    def _contains(self, other):
        if not isinstance(other, Expr):
            return false
        elif other.is_integer:
            return true
        elif other.is_integer is False:
            return false

    def __iter__(self):
        # Enumerate as 0, 1, -1, 2, -2, ...
        yield Integer(0)
        i = Integer(1)
        while True:
            yield i
            yield -i
            i = i + 1

    @property
    def inf(self):
        return -oo

    @property
    def sup(self):
        return oo

    @property
    def boundary(self):
        return self

    def _eval_imageset(self, f):
        """Simplify the image of the integers under a one-variable ``f``.

        When ``f.expr`` matches ``a*n + b`` with ``a`` and ``b`` free of the
        variable ``n``:

        * a negative ``a`` is made positive by substituting ``n -> -n``,
          which leaves the image of the integers unchanged, and
        * ``n + b`` with integer ``b`` is reduced to ``n``.

        Any other expression is wrapped unchanged in an ImageSet over the
        integers.  Returns ``None`` when ``f`` has more than one variable.
        """
        from ..core import Wild
        expr = f.expr
        if len(f.variables) > 1:
            return
        n = f.variables[0]

        # Exclude n so that only expressions genuinely linear in n match
        # (without it, e.g. n**2 would match with a = n).
        a = Wild('a', exclude=[n])
        b = Wild('b', exclude=[n])

        match = expr.match(a*n + b)
        if match is not None and match[a].is_negative:
            # {a*n + b : n in Z} == {-a*n + b : n in Z}.  Negating the whole
            # expression instead would map e.g. -3n+1 to 3n-1, a different
            # residue class.
            expr = expr.subs({n: -n})
            match = expr.match(a*n + b)

        if match is not None and match[a] == 1:
            shift = match.get(b, S.Zero)
            if shift.is_integer:
                # n + shift runs over all integers, the same as n itself.
                expr = expr - shift

        return ImageSet(Lambda(n, expr), S.Integers)
184
185
class Rationals(Set, metaclass=Singleton):
    """The set of all rational numbers.

    Iterating yields ``Rational(n, d)`` for integer numerators ``n`` and
    positive denominators ``d``, in the order produced by
    ``cantor_product(S.Integers, S.Naturals)``.  Values already produced
    are skipped, so each rational appears exactly once.
    """

    def _contains(self, other):
        # Same guard as the other number sets: non-expressions are never
        # members.
        if not isinstance(other, Expr):
            return false
        elif other.is_rational:
            return true
        elif other.is_rational is False:
            return false

    @property
    def inf(self):
        return -oo

    @property
    def sup(self):
        return oo

    @property
    def boundary(self):
        return self

    def __iter__(self):
        # A hash set keeps each duplicate check O(1).  It still grows with
        # every value yielded, since the enumeration is unbounded.
        seen = set()
        for numer, denom in cantor_product(S.Integers, S.Naturals):
            value = Rational(numer, denom)
            if value not in seen:
                seen.add(value)
                yield value
216
217
class Reals(Interval, metaclass=Singleton):
    """The real line, i.e. the interval (-oo, oo) open at both ends.

    Compares and hashes exactly like ``Interval(-oo, oo, True, True)``.
    """

    def __new__(cls):
        # Both endpoint flags are True: neither -oo nor oo is included.
        return Interval.__new__(cls, -oo, oo, True, True)

    def __eq__(self, other):
        open_line = Interval(-oo, oo, True, True)
        return other == open_line

    def __hash__(self):
        open_line = Interval(-oo, oo, True, True)
        return hash(open_line)
229
230
class ExtendedReals(Interval, metaclass=Singleton):
    """The extended real line, i.e. the closed interval [-oo, oo].

    Compares and hashes exactly like ``Interval(-oo, oo)``.
    """

    def __new__(cls):
        # Default endpoint flags: both infinities are included.
        return Interval.__new__(cls, -oo, oo)

    def __eq__(self, other):
        closed_line = Interval(-oo, oo)
        return other == closed_line

    def __hash__(self):
        closed_line = Interval(-oo, oo)
        return hash(closed_line)
242
243
class ImageSet(Set):
    """Image of a set under a mathematical function.

    Examples
    ========

    >>> squares = ImageSet(Lambda(x, x**2), S.Naturals)
    >>> 4 in squares
    True
    >>> 5 in squares
    False

    >>> FiniteSet(0, 1, 2, 3, 4, 5, 6, 7, 9, 10).intersection(squares)
    {1, 4, 9}

    >>> square_iterable = iter(squares)
    >>> for i in range(4):
    ...     next(square_iterable)
    1
    4
    9
    16

    The mapping is available as ``lamda`` and can be called directly to
    get its value at a point such as `x = 2` or `x = 1/2`.  No check is
    made that the point lies in ``base_set``; check that first if it
    matters.

    >>> squares.lamda(2)
    4
    >>> squares.lamda(Rational(1, 2))
    1/4

    """

    def __new__(cls, lamda, base_set):
        return Basic.__new__(cls, lamda, base_set)

    lamda = property(lambda self: self.args[0])
    base_set = property(lambda self: self.args[1])

    def __iter__(self):
        """Yield the distinct images of ``base_set``, in its iteration order."""
        already_seen = set()
        for i in self.base_set:
            val = self.lamda(i)
            if val in already_seen:
                continue
            already_seen.add(val)
            yield val

    def _contains(self, other):
        """Decide membership by solving ``lamda(var) == other`` for ``var``.

        Returns ``true`` as soon as one solution lies in ``base_set`` and
        ``false`` if none does.  Multi-variable mappings are not handled
        (returns ``None``).
        """
        from ..solvers import solve

        L = self.lamda
        if len(L.variables) > 1:
            return  # pragma: no cover

        var = L.variables[0]
        solns = solve(L.expr - other, var)

        for soln in solns:
            if soln[var] in self.base_set:
                return true
        return false

    @property
    def is_iterable(self):
        return self.base_set.is_iterable

    def _intersection(self, other):
        """Intersect with another integer ImageSet or with the reals.

        Only images of ``S.Integers`` are handled; ``None`` is returned
        whenever no closed form is found.
        """
        from ..core import Dummy, expand_complex
        from ..solvers.diophantine import diophantine
        from .sets import imageset
        if self.base_set is S.Integers:
            if isinstance(other, ImageSet) and other.base_set is S.Integers:
                f, g = self.lamda.expr, other.lamda.expr
                n, m = self.lamda.variables[0], other.lamda.variables[0]

                # Diophantine sorts the solutions according to the alphabetic
                # order of the variable names, since the result should not depend
                # on the variable name, they are replaced by the dummy variables
                # below
                a, b = Dummy('a'), Dummy('b')
                f, g = f.subs({n: a}), g.subs({m: b})
                # Solve once and reuse the result.
                solns = list(diophantine(f - g))
                if not solns:
                    return EmptySet()
                if len(solns) != 1:
                    return  # pragma: no cover

                # since 'a' < 'b', the first entry is the value of 'a'
                a_value = solns[0][0]
                params = a_value.free_symbols
                if not params:
                    # A single, fixed solution: the sets meet in one point
                    # (e.g. n**2 == -m**2 only at 0).
                    return FiniteSet(f.subs({a: a_value}))
                t = list(params)[0]
                return imageset(Lambda(t, f.subs({a: a_value})), S.Integers)

        if other == S.Reals:
            if len(self.lamda.variables) > 1 or self.base_set is not S.Integers:
                return  # pragma: no cover

            f = self.lamda.expr
            n = self.lamda.variables[0]

            n_ = Dummy(n.name, integer=True)
            f_ = f.subs({n: n_})

            re, im = map(expand_complex, f_.as_real_imag())

            # Keep only the integers where the imaginary part vanishes.
            sols = list(diophantine(im, n_))
            if all(s[0].has(n_) is False for s in sols):
                s = FiniteSet(*[s[0] for s in sols])
            elif len(sols) == 1 and sols[0][0].has(n_):
                s = imageset(Lambda(n_, sols[0][0]), S.Integers)
            else:
                return  # pragma: no cover

            return imageset(Lambda(n_, re), self.base_set.intersection(s))
359
360
class Range(Set):
    """Represents a range of integers.

    Accepts the same arguments as Python's built-in ``range`` (or a
    ``range`` object itself); endpoints may also be ``-oo``/``oo``.  The
    result is stored canonically as ``Range(inf, sup + step, step)`` with a
    positive step, and an empty range collapses to ``S.EmptySet``.

    Examples
    ========

    >>> list(Range(5))
    [0, 1, 2, 3, 4]
    >>> list(Range(10, 15))
    [10, 11, 12, 13, 14]
    >>> list(Range(10, 20, 2))
    [10, 12, 14, 16, 18]
    >>> list(Range(20, 10, -2))
    [12, 14, 16, 18, 20]

    """

    is_iterable = True

    def __new__(cls, *args):
        from ..functions import ceiling
        if len(args) == 1 and isinstance(args[0], range):
            args = args[0].start, args[0].stop, args[0].step

        # expand range
        slc = slice(*args)
        start, stop = slc.start or 0, slc.stop
        # Test for None explicitly: ``slc.step or 1`` would silently turn a
        # zero step into 1 instead of rejecting it like ``range`` does.
        step = 1 if slc.step is None else slc.step
        try:
            start, stop, step = [w if w in [-oo, oo] else Integer(as_int(w))
                                 for w in (start, stop, step)]
        except ValueError:
            raise ValueError('Inputs to Range must be Integer Valued\n' +
                             'Use ImageSets of Ranges for other cases')

        if step.is_zero:
            raise ValueError('Step cannot be zero')
        elif not step.is_finite:
            raise ValueError('Infinite step is not allowed')
        elif start == stop:
            return S.EmptySet

        n = ceiling((stop - start)/step)
        if n <= 0:
            return S.EmptySet

        # normalize args: regardless of how they are entered they will show
        # canonically as Range(inf, sup, step) with step > 0
        if n.is_finite:
            start, stop = sorted((start, start + (n - 1)*step))
        else:
            start, stop = sorted((start, stop - step))

        step = abs(step)

        return Basic.__new__(cls, start, stop + step, step)

    start = property(lambda self: self.args[0])
    stop = property(lambda self: self.args[1])
    step = property(lambda self: self.args[2])

    def _intersection(self, other):
        """Intersect with an interval, S.Naturals, S.Naturals0 or S.Integers.

        Returns ``None`` for any other set.
        """
        from ..functions import Max, Min, ceiling, floor
        if other.is_Interval:
            osup = other.sup
            oinf = other.inf
            # if other is [0, 10) we can only go up to 9
            if osup.is_integer and other.right_open:
                osup -= 1
            if oinf.is_integer and other.left_open:
                oinf += 1

            # Take the most restrictive of the bounds set by the two sets
            # round inwards
            inf = ceiling(Max(self.inf, oinf))
            sup = floor(Min(self.sup, osup))
            # if we are off the sequence, get back on
            if inf.is_finite and self.inf.is_finite:
                off = (inf - self.inf) % self.step
            else:
                off = Integer(0)
            if off:
                inf += self.step - off

            return Range(inf, sup + 1, self.step)

        # Both natural-number sets are the integers in [other.inf, oo).
        if other == S.Naturals or other == S.Naturals0:
            return self._intersection(Interval(other.inf, oo, False, True))

        if other == S.Integers:
            return self

    def _contains(self, other):
        """Membership test.

        ``other`` must lie on the lattice ``anchor + k*step`` (``k``
        integer, ``anchor`` a finite endpoint) and between ``inf`` and
        ``sup``.  Returns ``None`` when integrality of the lattice offset
        cannot be decided, so the Contains stays unevaluated.
        """
        from ..logic import And
        if not isinstance(other, Expr):
            return false

        if self.start.is_finite:
            anchor = self.start
        elif self.stop.is_finite:
            anchor = self.stop
        elif self.step == 1:
            # Range(-oo, oo, 1) is every integer.
            anchor = Integer(0)
        else:
            # NOTE(review): the lattice offset of Range(-oo, oo, step > 1)
            # is lost by normalization; non-membership is reported as before.
            return false

        on_lattice = ((anchor - other)/self.step).is_integer
        if on_lattice is False:
            return false
        if on_lattice is None:
            return
        # And() keeps symbolic comparisons intact; Python's ``and`` would
        # call bool() on a Relational and raise TypeError.
        return And(other >= self.inf, other <= self.sup)

    def __iter__(self):
        # A range unbounded below is walked downward from its top element.
        if self.start == -oo:
            i = self.stop - self.step
            step = -self.step
        else:
            i = self.start
            step = self.step

        while i < self.stop and i >= self.start:
            yield i
            i += step

    def __len__(self):
        return int((self.stop - self.start)//self.step)

    def __bool__(self):
        # Empty ranges are never constructed (they become S.EmptySet).
        return True

    def _ith_element(self, i):
        return self.start + i*self.step

    @property
    def _last_element(self):
        if self.stop is oo:
            return oo
        elif self.start == -oo:
            return self.stop - self.step
        else:
            return self._ith_element(len(self) - 1)

    @property
    def inf(self):
        return self.start

    @property
    def sup(self):
        return self.stop - self.step

    @property
    def boundary(self):
        return self
498
499
# Let sympify() turn Python's built-in range objects into Range.
converter.update({range: Range})
501