1"""
2Hypothesis-based tests for pvector.
3"""
4
5import gc
6
7from pyrsistent._compat import Iterable
8from functools import wraps
9from pyrsistent import PClass, field
10
11from pytest import fixture
12
13from pyrsistent import pvector, discard
14
15from hypothesis import strategies as st, assume
16from hypothesis.stateful import RuleBasedStateMachine, Bundle, rule
17
18
class TestObject(object):
    """
    An object that might catch reference count errors sometimes.

    Each instance records its own ``id()`` at construction time and
    re-checks it when finalized.  If the object were a dangling memory
    reference the check might fail (or the process might crash).
    """
    # The class name starts with ``Test`` and it defines ``__init__``; without
    # this pytest tries to collect it as a test class and emits a collection
    # warning (here and in any module that imports it).
    __test__ = False

    def __init__(self):
        self.id = id(self)

    def __repr__(self):
        return "<%s>" % (self.id,)

    def __del__(self):
        # If self is a dangling memory reference this check might fail. Or
        # segfault :)
        if self.id != id(self):
            raise RuntimeError(
                "TestObject finalized with id %r but was created with id %r"
                % (id(self), self.id))
34
35
@fixture(scope="module")
def gc_when_done(request):
    """
    Module-scoped fixture that forces a garbage-collection pass after the
    last test using it in this module has finished.

    Presumably this is so leftover ``TestObject`` finalizers run (and can
    report problems) while this module is still the one being tested.
    """
    # Nothing to set up; everything after ``yield`` is teardown.  ``request``
    # is kept only so the fixture's signature is unchanged.
    yield
    gc.collect()
39
40
41def test_setup(gc_when_done):
42    """
43    Ensure we GC when tests finish.
44    """
45
46
def _with_pvector(items):
    """Pair a list with a pvector built from the same items."""
    return items, pvector(items)


# Strategy yielding ``(list, pvector)`` pairs with identical contents, built
# from lists of freshly created ``TestObject`` instances:
PVectorAndLists = st.lists(st.builds(TestObject)).map(_with_pvector)
50
51
def verify_inputs_unmodified(original):
    """
    Decorator asserting that the wrapped state-machine rule does not modify
    its inputs.

    Every keyword argument that is ``Iterable`` is treated as a
    ``(list, pvector)`` pair.  The contents of both halves are captured
    before the rule runs and compared again once it has finished.
    """
    def snapshot(pairs):
        # Copy each half into a tuple so any later in-place change to the
        # original list or pvector shows up as a difference.
        frozen = []
        for seq, vec in pairs:
            frozen.append((tuple(seq), tuple(vec)))
        return frozen

    @wraps(original)
    def wrapper(self, **kwargs):
        pairs = list(filter(lambda arg: isinstance(arg, Iterable),
                            kwargs.values()))
        before = snapshot(pairs)
        try:
            return original(self, **kwargs)
        finally:
            # Checked even when the rule raises (e.g. a rejected ``assume``),
            # so a mutation is never silently skipped.
            assert snapshot(pairs) == before
    return wrapper
70
71
def assert_equal(l, pv):
    """
    Assert that list ``l`` and pvector ``pv`` agree on every read operation.

    Checks equality, length, positive and negative indexing, every slice
    ``[i:j]`` with ``0 <= i <= j <= len(l)`` (including slices that run to
    the end of the sequence), and iteration order.  The number of slices
    checked grows quadratically with the length, so callers keep the
    sequences small.
    """
    assert l == pv
    assert len(l) == len(pv)
    length = len(l)
    for i in range(length):
        assert l[i] == pv[i]
        # Negative indices address the same elements counted from the end.
        assert l[i - length] == pv[i - length]
    for i in range(length + 1):
        # ``j`` runs up to and including ``length`` so that slices ending at
        # the last element (e.g. the full slice ``[0:length]``) are compared.
        for j in range(i, length + 1):
            assert l[i:j] == pv[i:j]
    assert l == list(iter(pv))
82
83
class PVectorBuilder(RuleBasedStateMachine):
    """
    Build a list and matching pvector step-by-step.

    In each step of the state machine we apply the same operation to a list
    and to a pvector, and then when we're done we compare the two.  Rules
    that consume ``sequences`` are wrapped in ``verify_inputs_unmodified``,
    so every step also checks that its input list and pvector were not
    changed in place.
    """
    sequences = Bundle("sequences")

    @rule(target=sequences, start=PVectorAndLists)
    def initial_value(self, start):
        """
        Some initial values generated by a hypothesis strategy.
        """
        return start

    @rule(target=sequences, former=sequences)
    @verify_inputs_unmodified
    def append(self, former):
        """
        Append a fresh item to the pair of sequences.
        """
        l, pv = former
        obj = TestObject()
        l2 = l[:]
        l2.append(obj)
        return l2, pv.append(obj)

    @rule(target=sequences, start=sequences, end=sequences)
    @verify_inputs_unmodified
    def extend(self, start, end):
        """
        Extend a pair of sequences with another pair of sequences.
        """
        l, pv = start
        l2, pv2 = end
        # compare() has O(N**2) behavior, so don't want too-large lists:
        assume(len(l) + len(l2) < 50)
        l3 = l[:]
        l3.extend(l2)
        return l3, pv.extend(pv2)

    @rule(target=sequences, former=sequences, data=st.data())
    @verify_inputs_unmodified
    def remove(self, former, data):
        """
        Remove an item from the sequences.
        """
        l, pv = former
        assume(l)
        l2 = l[:]
        i = data.draw(st.sampled_from(range(len(l))))
        del l2[i]
        return l2, pv.delete(i)

    @rule(target=sequences, former=sequences, data=st.data())
    @verify_inputs_unmodified
    def set(self, former, data):
        """
        Overwrite an item in the sequence.
        """
        l, pv = former
        assume(l)
        l2 = l[:]
        i = data.draw(st.sampled_from(range(len(l))))
        obj = TestObject()
        l2[i] = obj
        return l2, pv.set(i, obj)

    @rule(target=sequences, former=sequences, data=st.data())
    @verify_inputs_unmodified
    def transform_set(self, former, data):
        """
        Transform the sequence by setting value.
        """
        l, pv = former
        assume(l)
        l2 = l[:]
        i = data.draw(st.sampled_from(range(len(l))))
        obj = TestObject()
        l2[i] = obj
        return l2, pv.transform([i], obj)

    @rule(target=sequences, former=sequences, data=st.data())
    @verify_inputs_unmodified
    def transform_discard(self, former, data):
        """
        Transform the sequence by discarding a value.
        """
        l, pv = former
        assume(l)
        l2 = l[:]
        i = data.draw(st.sampled_from(range(len(l))))
        del l2[i]
        return l2, pv.transform([i], discard)

    @rule(target=sequences, former=sequences, data=st.data())
    @verify_inputs_unmodified
    def subset(self, former, data):
        """
        A slice ``[i:j]`` of the previous sequence.

        Both bounds are drawn from ``0..len(l)`` inclusive, so slices that
        run to the end of the sequence (and slices of an empty sequence) are
        generated too.
        """
        l, pv = former
        bounds = st.sampled_from(range(len(l) + 1))
        i = data.draw(bounds)
        j = data.draw(bounds)
        return l[i:j], pv[i:j]

    @rule(pair=sequences)
    @verify_inputs_unmodified
    def compare(self, pair):
        """
        The list and pvector must match.
        """
        l, pv = pair
        # compare() has O(N**2) behavior, so don't want too-large lists:
        assume(len(l) < 50)
        assert_equal(l, pv)


PVectorBuilderTests = PVectorBuilder.TestCase
205
206
class EvolverItem(PClass):
    """
    State for one evolver under test in ``PVectorEvolverBuilder``.

    It holds the generated list/pvector pair as first created, plus a list
    and an evolver that the state-machine rules mutate in lock-step.
    """
    # The list as originally generated; compared against original_pvector.
    original_list = field()
    # The pvector as originally generated; evolver operations must not
    # change it.
    original_pvector = field()
    # A copy of original_list that rules mutate in place.
    current_list = field()
    # An evolver of original_pvector, mutated alongside current_list.
    current_evolver = field()
212
213
class PVectorEvolverBuilder(RuleBasedStateMachine):
    """
    Build a list and matching pvector evolver step-by-step.

    In each step of the state machine we apply the same operation to a list
    and to a pvector evolver, and then when we're done we compare the two.
    Items in the bundle are mutated in place, so every rule here returns
    nothing.
    """
    sequences = Bundle("evolver_sequences")

    @rule(target=sequences, start=PVectorAndLists)
    def initial_value(self, start):
        """
        Some initial values generated by a hypothesis strategy.
        """
        l, pv = start
        return EvolverItem(original_list=l,
                           original_pvector=pv,
                           current_list=l[:],
                           current_evolver=pv.evolver())

    @rule(item=sequences)
    def append(self, item):
        """
        Append a fresh item to the pair of sequences.
        """
        obj = TestObject()
        item.current_list.append(obj)
        item.current_evolver.append(obj)

    @rule(start=sequences, end=sequences)
    def extend(self, start, end):
        """
        Extend a pair of sequences with the contents of another pair.
        """
        # compare() has O(N**2) behavior, so don't want too-large lists:
        assume(len(start.current_list) + len(end.current_list) < 50)
        # The evolver is extended first on purpose: ``start`` and ``end`` may
        # be the same item, and extending the list first would then feed the
        # evolver the already-extended list.
        start.current_evolver.extend(end.current_list)
        start.current_list.extend(end.current_list)

    @rule(item=sequences, data=st.data())
    def delete(self, item, data):
        """
        Remove an item from the sequences.
        """
        assume(item.current_list)
        i = data.draw(st.sampled_from(range(len(item.current_list))))
        del item.current_list[i]
        del item.current_evolver[i]

    @rule(item=sequences, data=st.data())
    def setitem(self, item, data):
        """
        Overwrite an item in the sequence using ``__setitem__``.
        """
        assume(item.current_list)
        i = data.draw(st.sampled_from(range(len(item.current_list))))
        obj = TestObject()
        item.current_list[i] = obj
        item.current_evolver[i] = obj

    @rule(item=sequences, data=st.data())
    def set(self, item, data):
        """
        Overwrite an item in the sequence using ``set``.
        """
        assume(item.current_list)
        i = data.draw(st.sampled_from(range(len(item.current_list))))
        obj = TestObject()
        item.current_list[i] = obj
        item.current_evolver.set(i, obj)

    @rule(item=sequences)
    def compare(self, item):
        """
        The list and pvector evolver must match.
        """
        # Called only to exercise is_dirty(); its result is not checked.
        item.current_evolver.is_dirty()
        # The evolver's own length must track the list, independently of the
        # persistent() snapshot checked below.  This is cheap, so it runs
        # before the size filter.
        assert len(item.current_evolver) == len(item.current_list)
        # compare() has O(N**2) behavior, so don't want too-large lists:
        assume(len(item.current_list) < 50)
        # original object unmodified
        assert item.original_list == item.original_pvector
        # evolver matches, element by element, via the evolver's own indexing:
        for i, expected in enumerate(item.current_list):
            assert expected == item.current_evolver[i]
        # persistent version matches
        assert_equal(item.current_list, item.current_evolver.persistent())
        # original object still unmodified
        assert item.original_list == item.original_pvector


PVectorEvolverBuilderTests = PVectorEvolverBuilder.TestCase
305