1"""
2This file has the main algorithm for Slice.as_subindex(Slice)
3
4Since Integer can use the same algorithm via Slice(i, i+1), and IntegerArray
5needs to do this but in a way that only uses array friendly operations, we
6need to have this factored out into a separately callable function.
7
8TODO: we could remove the dependency on SymPy if we wanted to, by implementing
9the special cases for ilcm(a, b) and the Chinese Remainder Theorem for 2
10equations. It wouldn't be too bad (it just requires the extended gcd
11algorithm), but depending on SymPy also isn't a big deal for the time being.
12
13"""
14from numpy import broadcast_arrays, amin, amax, where
15
16def _crt(m1, m2, v1, v2):
17    """
18    Chinese Remainder Theorem
19
20    Returns x such that x = v1 (mod m1) and x = v2 (mod m2), or None if no
21    such solution exists.
22
23    """
24    # Avoid calling sympy_crt in the cases where the inputs would be arrays.
25    if m1 == 1:
26        return v2 % m2
27    if m2 == 1:
28        return v1 % m1
29
30    # Only import SymPy when necessary
31    from sympy.ntheory.modular import crt as sympy_crt
32
33    res = sympy_crt([m1, m2], [v1, v2])
34    if res is None:
35        return res
36    # Make sure the result isn't a gmpy2.mpz
37    return int(res[0])
38
39def _ilcm(a, b):
40    # Avoid calling sympy_ilcm in the cases where the inputs would be arrays.
41    if a == 1:
42        return b
43    if b == 1:
44        return a
45
46    # Only import SymPy when necessary
47    from sympy import ilcm as sympy_ilcm
48
49    return sympy_ilcm(a, b)
50
def ceiling(a, b):
    """
    Returns ceil(a/b)

    Computed with negation and floor division only, using the identity
    ceil(a/b) = -floor(-a/b).
    """
    floor_of_negated = (-a)//b
    return -floor_of_negated
56
57def _max(a, b):
58    if isinstance(a, int) and isinstance(b, int):
59        return max(a, b)
60    return amax(broadcast_arrays(a, b), axis=0)
61
62def _min(a, b):
63    if isinstance(a, int) and isinstance(b, int):
64        return min(a, b)
65    return amin(broadcast_arrays(a, b), axis=0)
66
def _smallest(x, a, m):
    """
    Gives the smallest integer >= x that equals a (mod m)

    Assumes x >= 0, m >= 1, and 0 <= a < m.
    """
    # Number of steps of size m to take from a to reach or pass x
    steps = ceiling(x - a, m)
    return a + steps*m
75
def subindex_slice(s_start, s_stop, s_step, i_start, i_stop, i_step):
    """
    Computes s.as_subindex(i) for slices s and i in a way that is (mostly)
    compatible with NumPy arrays.

    The arguments are the start, stop, and step of s followed by the start,
    stop, and step of i.

    Returns (start, stop, step). If the two slices have no element in common
    modulo their steps, (0, 0, 1) is returned.

    """
    # An element selected by both slices must solve the pair of congruences
    #
    # x = s.start (mod s.step)
    # x = index.start (mod index.step)
    #
    # which is what the Chinese Remainder Theorem helper computes. A result
    # of None means there are no solutions (the slices do not overlap).
    residue = _crt(s_step, i_step, s_start, i_start)
    if residue is None:
        return (0, 0, 1)

    # Solutions repeat with period lcm(s.step, index.step)
    period = _ilcm(s_step, i_step)

    # First solution that is not before the start of either slice
    lower_bound = _max(s_start, i_start)
    first = _smallest(lower_bound, residue, period)
    # Shift it so that it is relative to index
    rel_start = (first - i_start)//i_step

    # The stop is the earlier of the two stops, relative to index and
    # clipped at 0
    upper_bound = _min(s_stop, i_stop)
    rel_stop = ceiling(upper_bound - i_start, i_step)
    rel_stop = where(rel_stop < 0, 0, rel_stop)

    # Equivalently, s_step//igcd(s_step, i_step)
    rel_step = period//i_step

    return (rel_start, rel_stop, rel_step)
109