1from sympy import TableForm, S
2from sympy.printing.latex import latex
3from sympy.abc import x
4from sympy.functions.elementary.miscellaneous import sqrt
5from sympy.functions.elementary.trigonometric import sin
6from sympy.testing.pytest import raises
7
8from textwrap import dedent
9
10
def test_TableForm():
    """Check the plain-text (``str``) rendering of ``TableForm``.

    Exercises automatic and explicit headings, ``wipe_zeros``, column
    alignments (and the errors raised for unsupported ones) and the
    ``pad`` filler.  Outputs whose rows end in spaces are spelled out
    row by row so the significant trailing whitespace stays visible;
    the others are written as ``dedent``-ed literals.
    """
    def joined(*rows):
        # One argument per rendered row, newline-separated, no final newline.
        return '\n'.join(rows)

    def group_table(**options):
        # A 3x2 table with explicit row ("Group ...") and column ("y1", "y2")
        # headings; extra keyword options are forwarded to TableForm.
        # A fresh data literal is built on every call.
        return TableForm(
            [[5, 7], [4, 2], [10, 3]],
            headings=[["Group A", "Group B", "Group C"], ["y1", "y2"]],
            **options)

    # Automatic headings on both axes; by default the 0 cell is blanked.
    table = TableForm([["a", "b"], ["c", "d"], ["e", 0]],
                      headings="automatic")
    assert str(table) == joined(
        '  | 1 2',
        '-------',
        '1 | a b',
        '2 | c d',
        '3 | e  ',
    )

    # wipe_zeros=False keeps the 0.
    table = TableForm([["a", "b"], ["c", "d"], ["e", 0]],
                      headings="automatic", wipe_zeros=False)
    assert str(table) == dedent('''\
          | 1 2
        -------
        1 | a b
        2 | c d
        3 | e 0''')

    # Automatic row headings only.
    table = TableForm([[x**2, "b"], ["c", x**2], ["e", "f"]],
                      headings=("automatic", None))
    assert str(table) == joined(
        '1 | x**2 b   ',
        '2 | c    x**2',
        '3 | e    f   ',
    )

    # Automatic column headings only.
    table = TableForm([["a", "b"], ["c", "d"], ["e", "f"]],
                      headings=(None, "automatic"))
    assert str(table) == dedent('''\
        1 2
        ---
        a b
        c d
        e f''')

    # Explicit headings: default layout, an unknown alignment name, and
    # right alignment.
    assert str(group_table()) == joined(
        '        | y1 y2',
        '---------------',
        'Group A | 5  7 ',
        'Group B | 4  2 ',
        'Group C | 10 3 ',
    )
    raises(ValueError, lambda: group_table(alignments="middle"))
    assert str(group_table(alignments="right")) == dedent('''\
                | y1 y2
        ---------------
        Group A |  5  7
        Group B |  4  2
        Group C | 10  3''')

    # Other alignment strings, with row headings only.  With 'l' and 'lr'
    # the heading column stays right-aligned; with the three-letter 'clr'
    # the heading column comes out centered.
    d = [[1, 100], [100, 1]]
    row_headings = (('xxx', 'x'), None)
    for align, rows in (
            ('l', ('xxx | 1   100', '  x | 100 1  ')),
            ('lr', ('xxx | 1   100', '  x | 100   1')),
            ('clr', ('xxx | 1   100', ' x  | 100   1'))):
        table = TableForm(d, headings=row_headings, alignments=align)
        assert str(table) == joined(*rows)

    # Omitting alignments gives the same layout as alignments='l'.
    table = TableForm(d, headings=row_headings)
    assert str(table) == joined('xxx | 1   100', '  x | 100 1  ')

    # Without row headings, three alignments for a two-column table are
    # rejected.
    raises(ValueError, lambda: TableForm(d, alignments='clr'))

    # pad: None entries and the missing cells of the short row are
    # rendered with the pad character.
    table = TableForm([[None, "-", 2], [1]], pad='?')
    assert str(table) == dedent('''\
        ? - 2
        1 ? ?''')
100
101
def test_TableForm_latex():
    """Check the LaTeX rendering of ``TableForm`` as a tabular environment.

    Covers automatic headings, ``wipe_zeros``, the column spec produced
    by ``alignments``, per-column ``formats`` given either as format
    strings or as callables, and a table without any headings.
    """
    def tabular(spec, rows, header=None):
        # Assemble the expected output: the \begin line with the column
        # spec, an optional header row followed by \hline, one
        # "\\"-terminated line per data row, then \end{tabular}.
        parts = ['\\begin{tabular}{%s}' % spec]
        if header is not None:
            parts.extend([header + ' \\\\', '\\hline'])
        parts.extend(row + ' \\\\' for row in rows)
        parts.append('\\end{tabular}')
        return '\n'.join(parts)

    def with_zero():
        # A new nested list on every call, so no table shares its input.
        return [[0, x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]]

    def with_letter():
        # Same as with_zero() but with "a" in place of the 0 entry.
        return [["a", x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]]

    auto_header = ' & 1 & 2'

    # wipe_zeros blanks the 0 cell.  A lone 'l' leaves the row-heading
    # column at 'r'; 'lll' sets it as well.
    zero_rows = [
        '1 &   & $x^{3}$',
        '2 & $c$ & $\\frac{1}{4}$',
        '3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$',
    ]
    for options, spec in (
            ({}, 'r l l'),
            ({'alignments': 'l'}, 'r l l'),
            ({'alignments': 'l'*3}, 'l l l')):
        s = latex(TableForm(with_zero(), wipe_zeros=True,
                            headings=("automatic", "automatic"), **options))
        assert s == tabular(spec, zero_rows, header=auto_header)

    # Default formatting puts every entry in math mode.
    s = latex(TableForm(with_letter(), headings=("automatic", "automatic")))
    assert s == tabular('r l l', [
        '1 & $a$ & $x^{3}$',
        '2 & $c$ & $\\frac{1}{4}$',
        '3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$',
    ], header=auto_header)

    # A '%s' format string receives the plain str form of the entry (no
    # math mode); None keeps the default formatting for that column.
    s = latex(TableForm(with_letter(), formats=['(%s)', None],
                        headings=("automatic", "automatic")))
    assert s == tabular('r l l', [
        '1 & (a) & $x^{3}$',
        '2 & (c) & $\\frac{1}{4}$',
        '3 & (sqrt(x)) & $\\sin{\\left(x^{2} \\right)}$',
    ], header=auto_header)

    def neg_in_paren(x, i, j):
        # Only odd-indexed rows are formatted here; returning None for the
        # others falls back to the default printer.
        if not i % 2:
            return None
        return ('(%s)' if x < 0 else '%s') % x

    s = latex(TableForm([[-1, 2], [-3, 4]], formats=[neg_in_paren]*2,
                        headings=("automatic", "automatic")))
    assert s == tabular('r l l', ['1 & -1 & 2', '2 & (-3) & 4'],
                        header=auto_header)

    # No headings: no header row, no \hline, one spec letter per column.
    s = latex(TableForm(with_letter()))
    assert s == tabular('l l', [
        '$a$ & $x^{3}$',
        '$c$ & $\\frac{1}{4}$',
        '$\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$',
    ])
182