1"""Test state.py."""
2
3from pytype import compare
4from pytype import state
5from pytype.typegraph import cfg
6
7import unittest
8
9
def source_summary(binding, **varnames):
  """Render the source sets of a binding as a stable, sorted string.

  Args:
    binding: Object whose `origins` each expose `source_sets`; every source
      set is an iterable of objects with `variable` and `data` attributes.
    **varnames: Display name -> binding. A source is labelled with the name
      whose binding has the same `variable`.

  Returns:
    One clause per source set, each clause being its space-joined, sorted
    "name=data" entries; the clauses are sorted and joined with " | ".
  """
  names_by_variable = {}
  for label, named_binding in varnames.items():
    names_by_variable[named_binding.variable] = label

  rendered = []
  for origin in binding.origins:
    for source_set in origin.source_sets:
      parts = sorted(
          "%s=%s" % (names_by_variable[src.variable], src.data)
          for src in source_set)
      rendered.append(" ".join(parts))
  rendered.sort()
  return " | ".join(rendered)
19
20
class FakeValue:
  """Stand-in value with a preset compatibility answer per logical value.

  `compatible` maps True and False to the answers supplied at construction;
  str() of the instance is the given name.
  """

  def __init__(self, name, true_compat, false_compat):
    self._name = name
    answers = {}
    answers[True] = true_compat
    answers[False] = false_compat
    self.compatible = answers

  def __str__(self):
    return self._name
31
32
# Shared fake values: compatible with True only, with False only, or both.
ONLY_TRUE = FakeValue(name="T", true_compat=True, false_compat=False)
ONLY_FALSE = FakeValue(name="F", true_compat=False, false_compat=True)
AMBIGUOUS = FakeValue(name="?", true_compat=True, false_compat=True)
36
37
def fake_compatible_with(value, logical_value):
  """Return the answer stored in `value.compatible` for `logical_value`."""
  answers = value.compatible
  return answers[logical_value]
40
41
class ConditionTestBase(unittest.TestCase):
  """Shared fixture for the condition tests.

  Every test gets a fresh cfg.Program with a single CFG node named "test",
  and compare.compatible_with is replaced by fake_compatible_with for the
  duration of the test.
  """

  def setUp(self):
    super().setUp()
    self._program = cfg.Program()
    self._node = self._program.NewCFGNode("test")
    self._old_compatible_with = compare.compatible_with
    # Restore via a cleanup instead of tearDown(): unittest skips tearDown()
    # when setUp() raises (e.g. in a subclass that extends this fixture), but
    # still runs registered cleanups, so the fake cannot leak into other tests.
    self.addCleanup(
        setattr, compare, "compatible_with", self._old_compatible_with)
    compare.compatible_with = fake_compatible_with

  def new_binding(self, value=AMBIGUOUS):
    """Create a new variable in the test program and bind `value` to it.

    Args:
      value: Data for the binding; defaults to AMBIGUOUS.

    Returns:
      The result of AddBinding on the new variable.
    """
    var = self._program.NewVariable()
    return var.AddBinding(value)

  def check_binding(self, expected, binding, **varnames):
    """Assert `binding` has one origin at the test node with given sources.

    Args:
      expected: Expected source_summary() string for `binding`.
      binding: The binding to inspect.
      **varnames: Display name -> binding, forwarded to source_summary().
    """
    self.assertEqual(len(binding.origins), 1)
    self.assertEqual(self._node, binding.origins[0].where)
    self.assertEqual(expected, source_summary(binding, **varnames))
63
64
class ConditionTest(ConditionTestBase):
  """Tests building a state.Condition from a DNF of bindings."""

  def test_no_parent(self):
    """A condition over [[x, y], [z]] summarizes as "x=? y=? | z=?"."""
    x, y, z = (self.new_binding() for _ in range(3))
    dnf = [[x, y], [z]]
    condition = state.Condition(self._node, dnf)
    self.check_binding(
        "x=? y=? | z=?", condition.binding, x=x, y=y, z=z)

  def test_parent_combination(self):
    """Same DNF with an extra named binding `p` that is not part of it."""
    # `p` is created first and only handed to the summary's name map; it is
    # never passed to state.Condition.
    p, x, y, z = (self.new_binding() for _ in range(4))
    dnf = [[x, y], [z]]
    condition = state.Condition(self._node, dnf)
    self.check_binding(
        "x=? y=? | z=?", condition.binding, p=p, x=x, y=y, z=z)
82
83
class SplitConditionTest(ConditionTestBase):
  """Tests state.split_conditions on a variable with mixed bindings."""

  def test(self):
    # Verify that both halves of the split are produced and that the bindings
    # are passed through correctly.  Special cases inside _restrict_condition
    # are covered by their own tests.

    # An extra binding on a separate variable; it is not passed to
    # split_conditions.
    self.new_binding()
    var = self._program.NewVariable()
    for value in (ONLY_TRUE, ONLY_FALSE, AMBIGUOUS):
      var.AddBinding(value)
    true_cond, false_cond = state.split_conditions(self._node, var)
    # Any binding of `var` maps the variable to the name "v" in the summary.
    named = var.bindings[0]
    self.check_binding("v=? | v=T", true_cond.binding, v=named)
    self.check_binding("v=? | v=F", false_cond.binding, v=named)
101
102
class RestrictConditionTest(ConditionTestBase):
  """Tests for state._restrict_condition."""

  def setUp(self):
    super().setUp()
    parent_binding = self.new_binding()
    self._parent = state.Condition(self._node, [[parent_binding]])

  def test_no_bindings(self):
    """An empty binding list is unsatisfiable for both logical values."""
    for logical_value in (False, True):
      restricted = state._restrict_condition(self._node, [], logical_value)
      self.assertIs(state.UNSATISFIABLE, restricted)

  def test_none_restricted(self):
    """Ambiguous bindings are restricted for both logical values."""
    # NOTE(review): the results are not asserted, so this only verifies that
    # the calls complete without raising.
    x = self.new_binding()
    y = self.new_binding()
    for logical_value in (False, True):
      state._restrict_condition(self._node, [x, y], logical_value)

  def test_all_restricted(self):
    """Only-false bindings restricted to True are unsatisfiable."""
    x, y = (self.new_binding(ONLY_FALSE) for _ in range(2))
    restricted = state._restrict_condition(self._node, [x, y], True)
    self.assertIs(state.UNSATISFIABLE, restricted)

  def test_some_restricted_no_parent(self):
    """The only-false binding is dropped when restricting to True."""
    x = self.new_binding()  # Compatible with both True and False.
    y = self.new_binding(ONLY_FALSE)
    z = self.new_binding()  # Compatible with both True and False.
    restricted = state._restrict_condition(self._node, [x, y, z], True)
    self.check_binding("x=? | z=?", restricted.binding, x=x, y=y, z=z)

  def test_some_restricted_with_parent(self):
    """Same restriction as above, run with the parent fixture in place."""
    x = self.new_binding()  # Compatible with both True and False.
    y = self.new_binding(ONLY_FALSE)
    z = self.new_binding()  # Compatible with both True and False.
    restricted = state._restrict_condition(self._node, [x, y, z], True)
    self.check_binding("x=? | z=?", restricted.binding, x=x, y=y, z=z)

  def test_restricted_to_dnf(self):
    """A value whose True-compatibility is a DNF contributes its clauses."""
    a, b, c = (self.new_binding() for _ in range(3))
    # DNF for a | (b & c).
    dnf = [[a], [b, c]]
    x = self.new_binding()  # Compatible with everything.
    # y's answer for True is the DNF itself rather than a bool.
    y = self.new_binding(FakeValue("DNF", dnf, False))
    cond = state._restrict_condition(self._node, [x, y], True)
    self.check_binding(
        "a=? | b=? c=? | x=?", cond.binding, a=a, b=b, c=c, x=x, y=y)
155
156
# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
  unittest.main()
159