1"""Tests for generic code printing."""
2
3import pytest
4
5from diofant import Dummy, Idx, IndexedBase, Matrix, MatrixSymbol, symbols
6from diofant.printing.codeprinter import Assignment, CodePrinter
7
8
9__all__ = ()
10
11
def setup_test_printer(**kwargs):
    """Build a CodePrinter from *kwargs* with empty bookkeeping sets.

    The keyword arguments are forwarded as the printer's ``settings``;
    ``_not_supported`` and ``_number_symbols`` are initialised to fresh
    empty sets so that the private ``_print*`` methods can be called
    directly from the tests.
    """
    printer = CodePrinter(settings=kwargs)
    printer._not_supported, printer._number_symbols = set(), set()
    return printer
17
18
def test_print_Dummy():
    """A Dummy is printed as its name, an underscore and its dummy index."""
    dummy = Dummy('d')
    printer = setup_test_printer()
    expected = f'd_{dummy.dummy_index:d}'
    assert printer._print_Dummy(dummy) == expected
23
24
def test_Assignment():
    """Check which lhs/rhs combinations Assignment accepts and rejects."""
    x, y = symbols('x, y')
    A = MatrixSymbol('A', 3, 1)
    mat = Matrix([1, 2, 3])
    B = IndexedBase('B')
    n = symbols('n', integer=True)
    i = Idx('i', n)

    # Combinations that must construct without raising.
    accepted = (
        (x, y),
        (x, 0),
        (A, mat),
        (A[1, 0], 0),
        (A[1, 0], x),
        (B[i], x),
        (B[i], 0),
    )
    for lhs, rhs in accepted:
        Assignment(lhs, rhs)

    # Shape mismatches are expected to raise ValueError.
    shape_mismatches = (
        # Matrix to scalar
        lambda: Assignment(B[i], A),
        lambda: Assignment(B[i], mat),
        lambda: Assignment(x, mat),
        lambda: Assignment(x, A),
        lambda: Assignment(A[1, 0], mat),
        # Scalar to matrix
        lambda: Assignment(A, x),
        lambda: Assignment(A, 0),
    )
    for attempt in shape_mismatches:
        pytest.raises(ValueError, attempt)

    # A non-atomic lhs is expected to raise TypeError.
    invalid_targets = (
        lambda: Assignment(mat, A),
        lambda: Assignment(0, x),
        lambda: Assignment(x*x, 1),
        lambda: Assignment(A + A, mat),
        lambda: Assignment(B, 0),
    )
    for attempt in invalid_targets:
        pytest.raises(TypeError, attempt)
56
57
def test_print_Symbol():
    """Check how symbols are printed when their name is a reserved word."""
    plain, keyword = symbols('x, if')

    # With no reserved words registered, names are printed unchanged.
    printer = setup_test_printer()
    assert printer._print(plain) == 'x'
    assert printer._print(keyword) == 'if'

    # Once 'if' is reserved, the printed name gets a trailing underscore.
    printer.reserved_words.update(('if',))
    assert printer._print(keyword) == 'if_'

    # With error_on_reserved set, printing a reserved name raises instead.
    strict = setup_test_printer(error_on_reserved=True)
    strict.reserved_words.update(('if',))
    pytest.raises(ValueError, lambda: strict._print(keyword))

    # A custom suffix is appended in place of the underscore.
    suffixed = setup_test_printer(reserved_word_suffix='_He_Man')
    suffixed.reserved_words.update(('if',))
    assert suffixed._print(keyword) == 'if_He_Man'
76