1from sympy import Set, symbols
2from sympy.core import Basic, Expr
3from sympy.multipledispatch import dispatch
4from sympy.sets import Interval
5
# Module-level private symbols. They are not referenced in the visible part
# of this module -- check other users before removing them.
_x, _y = symbols("x, y")
7
8
@dispatch(Basic, Basic)  # type: ignore # noqa:F811
def _set_mul(x, y):  # noqa:F811
    """Fallback multiplication handler for two generic ``Basic`` operands.

    No product rule is implemented for this pair, so ``None`` is returned
    (presumably interpreted by the caller as "not handled" -- confirm at the
    call site).
    """
    return
12
@dispatch(Set, Set)  # type: ignore # noqa:F811
def _set_mul(x, y):  # noqa:F811
    """Fallback multiplication handler for two generic ``Set`` operands.

    No product rule is implemented for this pair, so ``None`` is returned
    (presumably interpreted by the caller as "not handled" -- confirm at the
    call site).
    """
    return
16
@dispatch(Expr, Expr)  # type: ignore # noqa:F811
def _set_mul(x, y):  # noqa:F811
    """Multiply two plain ``Expr`` operands with the ordinary ``*`` operator."""
    product = x*y
    return product
20
@dispatch(Interval, Interval)  # type: ignore # noqa:F811
def _set_mul(x, y): # noqa:F811
    """
    Multiplications in interval arithmetic
    https://en.wikipedia.org/wiki/Interval_arithmetic

    The result is the interval spanned by the four endpoint products.

    Openness of each result bound: a bound is closed if at least one endpoint
    product equal to it is actually attained by some pair of points.  That
    happens when

    * both contributing endpoints are closed, or
    * one contributing endpoint is a closed ``0`` (``0`` times any point of
      the other, non-empty, interval is ``0``).

    ``0*oo`` is treated as ``0`` (the usual extended interval-arithmetic
    convention) instead of ``nan``, so intervals touching both ``0`` and an
    infinite bound no longer fail when the products are compared.
    """
    from sympy import S

    def corner(a, a_open, b, b_open):
        # Return (a*b, is_open) for one pair of endpoints.
        value = a*b
        if value is S.NaN:
            # For real (extended) endpoints only 0*(+-oo) produces nan.
            value = S.Zero
        attained = ((not a_open and not b_open)
                    or (a.is_zero and not a_open)
                    or (b.is_zero and not b_open))
        return value, not attained

    corners = (
        corner(x.start, x.left_open, y.start, y.left_open),
        corner(x.start, x.left_open, y.end, y.right_open),
        corner(x.end, x.right_open, y.start, y.left_open),
        corner(x.end, x.right_open, y.end, y.right_open),
    )
    # TODO: handle symbolic intervals
    minval = min(value for value, _ in corners)
    maxval = max(value for value, _ in corners)
    # A bound is open only if *every* corner reaching it is open.  Comparing
    # (value, flag) tuples would let max() prefer an open corner on ties.
    # <= / >= (not ==) so numerically equal Float/Rational values count as ties.
    minopen = all(is_open for value, is_open in corners if value <= minval)
    maxopen = all(is_open for value, is_open in corners if value >= maxval)
    return Interval(minval, maxval, minopen, maxopen)
43
@dispatch(Basic, Basic)  # type: ignore # noqa:F811
def _set_div(x, y):  # noqa:F811
    """Fallback division handler for two generic ``Basic`` operands.

    No quotient rule is implemented for this pair, so ``None`` is returned
    (presumably interpreted by the caller as "not handled" -- confirm at the
    call site).
    """
    return
47
@dispatch(Expr, Expr)  # type: ignore # noqa:F811
def _set_div(x, y):  # noqa:F811
    """Divide two plain ``Expr`` operands with the ordinary ``/`` operator."""
    quotient = x/y
    return quotient
51
@dispatch(Set, Set)  # type: ignore # noqa:F811
def _set_div(x, y):  # noqa:F811
    """Fallback division handler for two generic ``Set`` operands.

    No quotient rule is implemented for this pair, so ``None`` is returned
    (presumably interpreted by the caller as "not handled" -- confirm at the
    call site).
    """
    return
55
@dispatch(Interval, Interval)  # type: ignore # noqa:F811
def _set_div(x, y): # noqa:F811
    """
    Divisions in interval arithmetic
    https://en.wikipedia.org/wiki/Interval_arithmetic

    Computed as ``x * (1/y)`` via ``set_mul``.  If the endpoints of ``y``
    have opposite signs (``y`` straddles 0) the result is widened to
    ``(-oo, oo)``.  A zero endpoint of ``y`` maps to an infinite endpoint of
    ``1/y``.
    """
    from sympy.sets.setexpr import set_mul
    from sympy import oo
    if (y.start*y.end).is_negative:
        # 0 lies strictly inside y, so 1/y is unbounded on both sides.
        return Interval(-oo, oo)
    # Taking reciprocals swaps the endpoints (and their open flags).  A zero
    # endpoint of y sends the corresponding reciprocal bound to infinity.
    # ``is_zero`` also recognises Float zeros, which structural ``== 0`` may
    # not, and is None (treated as "not zero") for undecidable symbols.
    s2 = oo if y.start.is_zero else 1/y.start
    s1 = -oo if y.end.is_zero else 1/y.end
    return set_mul(x, Interval(s1, s2, y.right_open, y.left_open))
75