1from sympy.core import Basic, Dict, sympify
2from sympy.core.compatibility import as_int, default_sort_key
3from sympy.core.sympify import _sympify
4from sympy.functions.combinatorial.numbers import bell
5from sympy.matrices import zeros
6from sympy.sets.sets import FiniteSet, Union
7from sympy.utilities.iterables import flatten, group
8
9from collections import defaultdict
10
11
class Partition(FiniteSet):
    """
    This class represents an abstract partition.

    A partition is a set of disjoint sets whose union equals a given set.

    See Also
    ========

    sympy.utilities.iterables.partitions,
    sympy.utilities.iterables.multiset_partitions
    """

    # Caches filled in on first access by the ``rank`` and ``partition``
    # properties respectively.
    _rank = None
    _partition = None

    def __new__(cls, *partition):
        """
        Generates a new partition object.

        This method also verifies if the arguments passed are
        valid and raises a ValueError if they are not.

        Each argument is one block of the partition and must be a list,
        a set or a FiniteSet. A ValueError is raised if an argument is of
        another type, if a list repeats an element, or if an element
        appears in more than one block.

        Examples
        ========

        Creating Partition from Python lists:

        >>> from sympy.combinatorics.partitions import Partition
        >>> a = Partition([1, 2], [3])
        >>> a
        Partition({3}, {1, 2})
        >>> a.partition
        [[1, 2], [3]]
        >>> len(a)
        2
        >>> a.members
        (1, 2, 3)

        Creating Partition from Python sets:

        >>> Partition({1, 2, 3}, {4, 5})
        Partition({4, 5}, {1, 2, 3})

        Creating Partition from SymPy finite sets:

        >>> from sympy.sets.sets import FiniteSet
        >>> a = FiniteSet(1, 2, 3)
        >>> b = FiniteSet(4, 5)
        >>> Partition(a, b)
        Partition({4, 5}, {1, 2, 3})
        """
        args = []
        dups = False
        for arg in partition:
            if isinstance(arg, list):
                # A list may repeat an element; converting it to a set
                # would hide that, so compare the sizes first.
                as_set = set(arg)
                if len(as_set) < len(arg):
                    dups = True
                    break  # error below
                arg = as_set
            args.append(_sympify(arg))

        if not all(isinstance(part, FiniteSet) for part in args):
            raise ValueError(
                "Each argument to Partition should be "
                "a list, set, or a FiniteSet")

        # The union holds each element once, so it is smaller than the
        # combined size of the blocks whenever two blocks share an element.
        U = Union(*args)
        if dups or len(U) < sum(len(arg) for arg in args):
            raise ValueError("Partition contained duplicate elements.")

        obj = FiniteSet.__new__(cls, *args)
        obj.members = tuple(U)
        obj.size = len(U)
        return obj

    def sort_key(self, order=None):
        """Return a canonical key that can be used for sorting.

        Ordering is based on the size and sorted elements of the partition
        and ties are broken with the rank.

        Examples
        ========

        >>> from sympy.utilities.iterables import default_sort_key
        >>> from sympy.combinatorics.partitions import Partition
        >>> from sympy.abc import x
        >>> a = Partition([1, 2])
        >>> b = Partition([3, 4])
        >>> c = Partition([1, x])
        >>> d = Partition(list(range(4)))
        >>> l = [d, b, a + 1, a, c]
        >>> l.sort(key=default_sort_key); l
        [Partition({1, 2}), Partition({1}, {2}), Partition({1, x}), Partition({3, 4}), Partition({0, 1, 2, 3})]
        """
        if order is None:
            members = self.members
        else:
            members = tuple(sorted(self.members,
                             key=lambda w: default_sort_key(w, order)))
        return tuple(map(default_sort_key, (self.size, members, self.rank)))

    @property
    def partition(self):
        """Return partition as a sorted list of lists.

        The elements of each block, and the blocks themselves, are ordered
        with ``default_sort_key``. The result is cached on the instance.

        Examples
        ========

        >>> from sympy.combinatorics.partitions import Partition
        >>> Partition([1], [2, 3]).partition
        [[1], [2, 3]]
        """
        if self._partition is None:
            blocks = [sorted(p, key=default_sort_key) for p in self.args]
            # Order the blocks by the sort keys of their elements rather
            # than by comparing the elements directly: ``<`` between
            # symbolic elements gives an unevaluated relational that
            # cannot be used as a sort predicate, and ``RGS`` orders the
            # elements with ``default_sort_key`` as well.
            blocks.sort(key=lambda block: [default_sort_key(e)
                                           for e in block])
            self._partition = blocks
        return self._partition

    def __add__(self, other):
        """
        Return partition whose rank is ``other`` greater than current rank,
        (mod the maximum rank for the set).

        Examples
        ========

        >>> from sympy.combinatorics.partitions import Partition
        >>> a = Partition([1, 2], [3])
        >>> a.rank
        1
        >>> (a + 1).rank
        2
        >>> (a + 100).rank
        1
        """
        other = as_int(other)
        offset = self.rank + other
        # Wrap around the number of partitions of a set of this size.
        result = RGS_unrank(offset % RGS_enum(self.size), self.size)
        return Partition.from_rgs(result, self.members)

    def __sub__(self, other):
        """
        Return partition whose rank is ``other`` less than current rank,
        (mod the maximum rank for the set).

        Examples
        ========

        >>> from sympy.combinatorics.partitions import Partition
        >>> a = Partition([1, 2], [3])
        >>> a.rank
        1
        >>> (a - 1).rank
        0
        >>> (a - 100).rank
        1
        """
        return self.__add__(-other)

    def __le__(self, other):
        """
        Checks if a partition is less than or equal to
        the other based on rank.

        Examples
        ========

        >>> from sympy.combinatorics.partitions import Partition
        >>> a = Partition([1, 2], [3, 4, 5])
        >>> b = Partition([1], [2, 3], [4], [5])
        >>> a.rank, b.rank
        (9, 34)
        >>> a <= a
        True
        >>> a <= b
        True
        """
        return self.sort_key() <= sympify(other).sort_key()

    def __lt__(self, other):
        """
        Checks if a partition is less than the other.

        Examples
        ========

        >>> from sympy.combinatorics.partitions import Partition
        >>> a = Partition([1, 2], [3, 4, 5])
        >>> b = Partition([1], [2, 3], [4], [5])
        >>> a.rank, b.rank
        (9, 34)
        >>> a < b
        True
        """
        return self.sort_key() < sympify(other).sort_key()

    @property
    def rank(self):
        """
        Gets the rank of a partition.

        The rank is computed from the restricted growth string on first
        access and cached on the instance.

        Examples
        ========

        >>> from sympy.combinatorics.partitions import Partition
        >>> a = Partition([1, 2], [3], [4, 5])
        >>> a.rank
        13
        """
        if self._rank is not None:
            return self._rank
        self._rank = RGS_rank(self.RGS)
        return self._rank

    @property
    def RGS(self):
        """
        Returns the "restricted growth string" of the partition.

        Explanation
        ===========

        The RGS is returned as a list of indices, L, where L[i] indicates
        the block in which element i appears. For example, in a partition
        of 3 elements (a, b, c) into 2 blocks ([c], [a, b]) the RGS is
        [1, 1, 0]: "a" is in block 1, "b" is in block 1 and "c" is in block 0.

        Examples
        ========

        >>> from sympy.combinatorics.partitions import Partition
        >>> a = Partition([1, 2], [3], [4, 5])
        >>> a.members
        (1, 2, 3, 4, 5)
        >>> a.RGS
        (0, 0, 1, 2, 2)
        >>> a + 1
        Partition({3}, {4}, {5}, {1, 2})
        >>> _.RGS
        (0, 0, 1, 2, 3)
        """
        # Map every element to the index of the block that contains it.
        rgs = {}
        partition = self.partition
        for i, part in enumerate(partition):
            for j in part:
                rgs[j] = i
        # Read the block indices back in the sorted order of the elements.
        return tuple([rgs[i] for i in sorted(
            [i for p in partition for i in p], key=default_sort_key)])

    @classmethod
    def from_rgs(cls, rgs, elements):
        """
        Creates a set partition from a restricted growth string.

        Explanation
        ===========

        The indices given in rgs are assumed to be the index
        of the element as given in elements *as provided* (the
        elements are not sorted by this routine). Block numbering
        starts from 0. If any block was not referenced in ``rgs``
        an error will be raised.

        A ValueError is also raised when ``rgs`` and ``elements`` differ
        in length.

        Examples
        ========

        >>> from sympy.combinatorics.partitions import Partition
        >>> Partition.from_rgs([0, 1, 2, 0, 1], list('abcde'))
        Partition({c}, {a, d}, {b, e})
        >>> Partition.from_rgs([0, 1, 2, 0, 1], list('cbead'))
        Partition({e}, {a, c}, {b, d})
        >>> a = Partition([1, 4], [2], [3, 5])
        >>> Partition.from_rgs(a.RGS, a.members)
        Partition({2}, {1, 4}, {3, 5})
        """
        if len(rgs) != len(elements):
            raise ValueError('mismatch in rgs and element lengths')
        # One (initially empty) block for every index up to the largest
        # one that appears in rgs.
        max_elem = max(rgs) + 1
        partition = [[] for i in range(max_elem)]
        for block_index, element in zip(rgs, elements):
            partition[block_index].append(element)
        if not all(p for p in partition):
            raise ValueError('some blocks of the partition were empty.')
        return cls(*partition)
303
304
class IntegerPartition(Basic):
    """
    This class represents an integer partition.

    Explanation
    ===========

    In number theory and combinatorics, a partition of a positive integer,
    ``n``, also called an integer partition, is a way of writing ``n`` as a
    list of positive integers that sum to n. Two partitions that differ only
    in the order of summands are considered to be the same partition; if order
    matters then the partitions are referred to as compositions. For example,
    4 has five partitions: [4], [3, 1], [2, 2], [2, 1, 1], and [1, 1, 1, 1];
    the compositions [1, 2, 1] and [1, 1, 2] are the same as partition
    [2, 1, 1].

    See Also
    ========

    sympy.utilities.iterables.partitions,
    sympy.utilities.iterables.multiset_partitions

    References
    ==========

    https://en.wikipedia.org/wiki/Partition_%28number_theory%29
    """

    # Caches filled in by ``as_dict``: the {summand: multiplicity} mapping
    # and the distinct summands in the order they occur in ``partition``.
    _dict = None
    _keys = None

    def __new__(cls, partition, integer=None):
        """
        Generates a new IntegerPartition object from a list or dictionary.

        Explanation
        ===========

        The partition can be given as a list of positive integers or a
        dictionary of (integer, multiplicity) items. If the partition is
        preceded by an integer an error will be raised if the partition
        does not sum to that given integer.

        A ValueError is also raised if any summand is less than 1.

        Examples
        ========

        >>> from sympy.combinatorics.partitions import IntegerPartition
        >>> a = IntegerPartition([5, 4, 3, 1, 1])
        >>> a
        IntegerPartition(14, (5, 4, 3, 1, 1))
        >>> print(a)
        [5, 4, 3, 1, 1]
        >>> IntegerPartition({1:3, 2:1})
        IntegerPartition(5, (2, 1, 1, 1))

        If the value that the partition should sum to is given first, a check
        will be made and an error will be raised if there is a discrepancy:

        >>> IntegerPartition(10, [5, 4, 3, 1])
        Traceback (most recent call last):
        ...
        ValueError: Partition did not add to 10

        """
        # With two arguments the call is IntegerPartition(integer, partition),
        # so the positional values arrive in the opposite order.
        if integer is not None:
            integer, partition = partition, integer
        if isinstance(partition, (dict, Dict)):
            summands = []
            for k, v in sorted(partition.items(), reverse=True):
                if not v:
                    continue
                k, v = as_int(k), as_int(v)
                summands.extend([k]*v)
            partition = tuple(summands)
        else:
            partition = tuple(sorted(map(as_int, partition), reverse=True))

        if integer is None:
            integer = sum(partition)
        else:
            integer = as_int(integer)
            if sum(partition) != integer:
                raise ValueError("Partition did not add to %s" % integer)
        if any(i < 1 for i in partition):
            raise ValueError("All integer summands must be greater than zero")

        obj = Basic.__new__(cls, integer, partition)
        obj.partition = list(partition)
        obj.integer = integer
        return obj

    def prev_lex(self):
        """Return the previous partition of the integer, n, in lexical order,
        wrapping around to [n] if the partition is [1, ..., 1].

        Examples
        ========

        >>> from sympy.combinatorics.partitions import IntegerPartition
        >>> p = IntegerPartition([4])
        >>> print(p.prev_lex())
        [3, 1]
        >>> p.partition > p.prev_lex().partition
        True
        """
        # Work on a copy of the multiplicities; as_dict() also fills in
        # self._keys, the distinct summands from largest to smallest.
        d = defaultdict(int)
        d.update(self.as_dict())
        keys = self._keys
        if keys == [1]:
            return IntegerPartition({self.integer: 1})
        if keys[-1] != 1:
            # Split one copy of the smallest summand, s, into s - 1 and 1.
            d[keys[-1]] -= 1
            if keys[-1] == 2:
                d[1] = 2
            else:
                d[keys[-1] - 1] = d[1] = 1
        else:
            # Remove one copy of the smallest summand above 1 and
            # redistribute it, together with all the 1s, over strictly
            # smaller summands, largest first.
            d[keys[-2]] -= 1
            left = d[1] + keys[-2]
            new = keys[-2]
            d[1] = 0
            while left:
                new -= 1
                if left - new >= 0:
                    d[new] += left//new
                    left -= d[new]*new
        return IntegerPartition(self.integer, d)

    def next_lex(self):
        """Return the next partition of the integer, n, in lexical order,
        wrapping around to [1, ..., 1] if the partition is [n].

        Examples
        ========

        >>> from sympy.combinatorics.partitions import IntegerPartition
        >>> p = IntegerPartition([3, 1])
        >>> print(p.next_lex())
        [4]
        >>> p.partition < p.next_lex().partition
        True
        """
        # Work on a copy of the multiplicities; as_dict() also fills in
        # self._keys, the distinct summands from largest to smallest.
        d = defaultdict(int)
        d.update(self.as_dict())
        key = self._keys
        a = key[-1]  # smallest summand
        if a == self.integer:
            d.clear()
            d[1] = self.integer
        elif a == 1:
            if d[a] > 1:
                # Merge two of the 1s into a 2.
                d[a + 1] += 1
                d[a] -= 2
            else:
                # Only a single 1: grow the next smallest summand, b, by
                # one and turn the remaining copies of b into 1s.
                b = key[-2]
                d[b + 1] += 1
                d[1] = (d[b] - 1)*b
                d[b] = 0
        else:
            if d[a] > 1:
                if len(key) == 1:
                    d.clear()
                    d[a + 1] = 1
                    d[1] = self.integer - a - 1
                else:
                    a1 = a + 1
                    d[a1] += 1
                    d[1] = d[a]*a - a1
                    d[a] = 0
            else:
                b = key[-2]
                b1 = b + 1
                d[b1] += 1
                need = d[b]*b + d[a]*a - b1
                d[a] = d[b] = 0
                d[1] = need
        return IntegerPartition(self.integer, d)

    def as_dict(self):
        """Return the partition as a dictionary whose keys are the
        partition integers and the values are the multiplicity of that
        integer.

        The dictionary is built on the first call and the same object is
        returned by later calls, so copy it before modifying it.

        Examples
        ========

        >>> from sympy.combinatorics.partitions import IntegerPartition
        >>> IntegerPartition([1]*3 + [2] + [3]*4).as_dict()
        {1: 3, 2: 1, 3: 4}
        """
        if self._dict is None:
            groups = group(self.partition, multiple=False)
            self._keys = [g[0] for g in groups]
            self._dict = dict(groups)
        return self._dict

    @property
    def conjugate(self):
        """
        Computes the conjugate partition of itself.

        The result is returned as a list of integers.

        Examples
        ========

        >>> from sympy.combinatorics.partitions import IntegerPartition
        >>> a = IntegerPartition([6, 3, 3, 2, 1])
        >>> a.conjugate
        [5, 4, 3, 1, 1, 1]
        """
        j = 1
        # The trailing 0 is a sentinel that lets the inner loop run k
        # down to 0 once the last summand has been reached.
        temp_arr = list(self.partition) + [0]
        k = temp_arr[0]
        b = [0]*k
        while k > 0:
            while k > temp_arr[j]:
                b[k - 1] = j
                k -= 1
            j += 1
        return b

    def __lt__(self, other):
        """Return True if self is less than other when the partition
        is listed from smallest to biggest.

        Examples
        ========

        >>> from sympy.combinatorics.partitions import IntegerPartition
        >>> a = IntegerPartition([3, 1])
        >>> a < a
        False
        >>> b = a.next_lex()
        >>> a < b
        True
        >>> a == b
        False
        """
        return list(reversed(self.partition)) < list(reversed(other.partition))

    def __le__(self, other):
        """Return True if self is less than or equal to other when the
        partition is listed from smallest to biggest.

        Examples
        ========

        >>> from sympy.combinatorics.partitions import IntegerPartition
        >>> a = IntegerPartition([4])
        >>> a <= a
        True
        """
        return list(reversed(self.partition)) <= list(reversed(other.partition))

    def as_ferrers(self, char='#'):
        """
        Returns the Ferrers diagram of a partition as a string, one row
        of ``char`` per summand.

        Examples
        ========

        >>> from sympy.combinatorics.partitions import IntegerPartition
        >>> print(IntegerPartition([1, 1, 5]).as_ferrers())
        #####
        #
        #
        """
        return "\n".join([char*i for i in self.partition])

    def __str__(self):
        return str(list(self.partition))
577
578
def random_integer_partition(n, seed=None):
    """
    Return a randomly chosen integer partition of ``n`` as a list of
    integers in decreasing order.

    ``n`` must be a positive integer, otherwise a ValueError is raised.
    ``seed`` is passed on to the random number source.

    Examples
    ========

    >>> from sympy.combinatorics.partitions import random_integer_partition

    For the following, a seed is given so a known value can be shown; in
    practice, the seed would not be given.

    >>> random_integer_partition(100, seed=[1, 1, 12, 1, 2, 1, 85, 1])
    [85, 12, 2, 1]
    >>> random_integer_partition(10, seed=[1, 2, 3, 1, 5, 1])
    [5, 3, 1, 1]
    >>> random_integer_partition(1)
    [1]
    """
    from sympy.testing.randtest import _randint

    remaining = as_int(n)
    if remaining < 1:
        raise ValueError('n must be a positive integer')

    draw = _randint(seed)

    # Repeatedly pick a summand and how many copies of it to use until
    # nothing is left of n.
    chosen = []
    while remaining > 0:
        summand = draw(1, remaining)
        copies = draw(1, remaining//summand)
        chosen.append((summand, copies))
        remaining -= summand*copies

    # Expand the (summand, copies) pairs, largest pair first.
    result = []
    for summand, copies in sorted(chosen, reverse=True):
        result.extend([summand]*copies)
    return result
616
617
def RGS_generalized(m):
    """
    Build the (m + 1) x (m + 1) matrix of generalized unrestricted
    growth string counts.

    Row 0 consists of ones and every later entry is obtained from the row
    above it as ``d[i, j] = j*d[i - 1, j] + d[i - 1, j + 1]``; entries
    with ``j > m - i`` are zero.

    Examples
    ========

    >>> from sympy.combinatorics.partitions import RGS_generalized
    >>> RGS_generalized(6)
    Matrix([
    [  1,   1,   1,  1,  1, 1, 1],
    [  1,   2,   3,  4,  5, 6, 0],
    [  2,   5,  10, 17, 26, 0, 0],
    [  5,  15,  37, 77,  0, 0, 0],
    [ 15,  52, 151,  0,  0, 0, 0],
    [ 52, 203,   0,  0,  0, 0, 0],
    [203,   0,   0,  0,  0, 0, 0]])
    """
    size = m + 1
    table = zeros(size)
    for col in range(size):
        table[0, col] = 1

    # Only the entries on or above the anti-diagonal are filled in; the
    # rest keep the zero they were created with.
    for row in range(1, size):
        above = row - 1
        for col in range(size - row):
            table[row, col] = col * table[above, col] + table[above, col + 1]
    return table
648
649
def RGS_enum(m):
    """
    Return the number of restricted growth strings that exist for a
    superset of size m.

    The result is 0 when m is less than 1, 1 when m is 1 and the Bell
    number ``bell(m)`` otherwise.

    Examples
    ========

    >>> from sympy.combinatorics.partitions import RGS_enum
    >>> from sympy.combinatorics.partitions import Partition
    >>> RGS_enum(4)
    15
    >>> RGS_enum(5)
    52
    >>> RGS_enum(6)
    203

    We can check that the enumeration is correct by actually generating
    the partitions. Here, the 15 partitions of 4 items are generated:

    >>> a = Partition(list(range(4)))
    >>> s = set()
    >>> for i in range(20):
    ...     s.add(a)
    ...     a += 1
    ...
    >>> assert len(s) == 15

    """
    if m < 1:
        return 0
    return 1 if m == 1 else bell(m)
685
686
def RGS_unrank(rank, m):
    """
    Return the restricted growth string of length ``m`` that has the
    given rank, as a list of block indices starting from 0.

    A ValueError is raised if ``m`` is less than 1 or if ``rank`` is
    negative or not smaller than ``RGS_enum(m)``.

    Examples
    ========

    >>> from sympy.combinatorics.partitions import RGS_unrank
    >>> RGS_unrank(14, 4)
    [0, 1, 2, 3]
    >>> RGS_unrank(0, 4)
    [0, 0, 0, 0]
    """
    if m < 1:
        raise ValueError("The superset size must be >= 1")
    if not (0 <= rank < RGS_enum(m)):
        raise ValueError("Invalid arguments")

    table = RGS_generalized(m)
    # The first element always goes in block 0.
    string = [0]
    blocks = 1  # number of blocks opened so far
    for position in range(2, m + 1):
        weight = table[m - position, blocks]
        threshold = blocks*weight
        if threshold <= rank:
            # The element opens a new block.
            string.append(blocks)
            rank -= threshold
            blocks += 1
        else:
            # The element joins one of the blocks that already exist.
            string.append(int(rank / weight + 1) - 1)
            rank %= weight
    return string
720
721
def RGS_rank(rgs):
    """
    Computes the rank of a restricted growth string.

    ``rgs`` is a sequence of block indices. Strings with fewer than two
    entries have rank 0.

    Examples
    ========

    >>> from sympy.combinatorics.partitions import RGS_rank, RGS_unrank
    >>> RGS_rank([0, 1, 2, 1, 3])
    42
    >>> RGS_rank(RGS_unrank(4, 7))
    4
    """
    rgs_size = len(rgs)
    if rgs_size < 2:
        # Nothing after the first entry contributes, so there is no need
        # to build the weight matrix.
        return 0

    D = RGS_generalized(rgs_size)
    rank = 0
    # Largest block index among the entries before position i, kept as a
    # running value instead of re-scanning the prefix on every step.
    largest = rgs[0]
    for i in range(1, rgs_size):
        # rgs_size - i - 1 is the number of entries after position i.
        rank += D[rgs_size - i - 1, largest + 1] * rgs[i]
        if rgs[i] > largest:
            largest = rgs[i]
    return rank
743