1"""
2Tests for `attr._funcs`.
3"""
4
5from __future__ import absolute_import, division, print_function
6
7from collections import OrderedDict
8
9import pytest
10
11from hypothesis import assume, given
12from hypothesis import strategies as st
13
14import attr
15
16from attr import asdict, assoc, astuple, evolve, fields, has
17from attr._compat import TYPE, Mapping, Sequence, ordered_dict
18from attr.exceptions import AttrsAttributeNotFoundError
19from attr.validators import instance_of
20
21from .strategies import nested_classes, simple_classes
22
23
# Mapping classes used both as ``dict_factory`` arguments and as container
# types whose class should survive ``retain_collection_types=True``.
MAPPING_TYPES = (
    dict,
    OrderedDict,
)
# Sequence classes used both as ``tuple_factory`` arguments and as container
# types whose class should survive ``retain_collection_types=True``.
SEQUENCE_TYPES = (
    list,
    tuple,
)
26
27
class TestAsDict(object):
    """
    Tests for `asdict`, which converts attrs instances into mappings.
    """

    @given(st.sampled_from(MAPPING_TYPES))
    def test_shallow(self, C, dict_factory):
        """
        With recursion switched off, each attribute maps to its raw value.
        """
        instance = C(x=1, y=2)
        expected = {"x": 1, "y": 2}

        assert expected == asdict(instance, False, dict_factory=dict_factory)

    @given(st.sampled_from(MAPPING_TYPES))
    def test_recurse(self, C, dict_class):
        """
        With recursion on (the default), nested instances become mappings.
        """
        nested = C(C(1, 2), C(3, 4))
        expected = {"x": {"x": 1, "y": 2}, "y": {"x": 3, "y": 4}}

        assert expected == asdict(nested, dict_factory=dict_class)

    def test_nested_lists(self, C):
        """
        Instances buried inside lists of lists are converted too.
        """
        outer = C([[C(1, 2)]], None)
        expected = {"x": [[{"x": 1, "y": 2}]], "y": None}

        assert expected == asdict(outer)

    def test_nested_dicts(self, C):
        """
        Instances buried inside dicts of dicts are converted too.
        """
        outer = C({1: {2: C(1, 2)}}, None)
        expected = {"x": {1: {2: {"x": 1, "y": 2}}}, "y": None}

        assert expected == asdict(outer)

    @given(nested_classes, st.sampled_from(MAPPING_TYPES))
    def test_recurse_property(self, cls, dict_class):
        """
        Property test: every mapping produced by a recursive asdict is an
        instance of the requested ``dict_factory``. This includes mappings
        for instances nested directly, inside sequences, or inside mappings.
        """
        instance = cls()
        as_mapping = asdict(instance, dict_factory=dict_class)

        def check(inst, mapping):
            assert isinstance(mapping, dict_class)

            for attribute in fields(inst.__class__):
                value = getattr(inst, attribute.name)

                if has(value.__class__):
                    # A nested attrs instance: verify its mapping as well.
                    check(value, mapping[attribute.name])
                    continue

                if isinstance(value, Sequence):
                    converted_items = mapping[attribute.name]
                    for element, element_mapping in zip(
                        value, converted_items
                    ):
                        if has(element.__class__):
                            check(element, element_mapping)
                    continue

                if isinstance(value, Mapping):
                    # Plain dicts held by the instance are rebuilt with the
                    # factory too, not just the attrs-derived ones.
                    assert isinstance(mapping[attribute.name], dict_class)

                    for key, item in value.items():
                        if has(item.__class__):
                            check(item, mapping[attribute.name][key])

        check(instance, as_mapping)

    @given(st.sampled_from(MAPPING_TYPES))
    def test_filter(self, C, dict_factory):
        """
        Attributes rejected by ``filter`` are left out of the result.
        """

        def drop_y(attribute, value):
            return attribute.name != "y"

        expected = {"x": {"x": 1}}
        result = asdict(
            C(C(1, 2), C(3, 4)), filter=drop_y, dict_factory=dict_factory
        )

        assert expected == result

    @given(container=st.sampled_from(SEQUENCE_TYPES))
    def test_lists_tuples(self, container, C):
        """
        With recursion on, lists and tuples are descended into as well.
        """
        value = C(1, container([C(2, 3), C(4, 5), "a"]))
        expected = {
            "x": 1,
            "y": [{"x": 2, "y": 3}, {"x": 4, "y": 5}, "a"],
        }

        assert expected == asdict(value)

    @given(container=st.sampled_from(SEQUENCE_TYPES))
    def test_lists_tuples_retain_type(self, container, C):
        """
        With recursion and ``retain_collection_types`` on, sequences are
        descended into and keep their original type instead of becoming
        lists.
        """
        value = C(1, container([C(2, 3), C(4, 5), "a"]))
        expected = {
            "x": 1,
            "y": container([{"x": 2, "y": 3}, {"x": 4, "y": 5}, "a"]),
        }

        assert expected == asdict(value, retain_collection_types=True)

    @given(st.sampled_from(MAPPING_TYPES))
    def test_dicts(self, C, dict_factory):
        """
        With recursion on, dict values are descended into as well.
        """
        result = asdict(C(1, {"a": C(4, 5)}), dict_factory=dict_factory)
        expected = {"x": 1, "y": {"a": {"x": 4, "y": 5}}}

        assert expected == result
        assert isinstance(result, dict_factory)

    @given(simple_classes(private_attrs=False), st.sampled_from(MAPPING_TYPES))
    def test_roundtrip(self, cls, dict_class):
        """
        Hypothesis-generated instances survive a trip to a mapping and back.

        Private attributes are excluded because their attribute name differs
        from the initializer argument name, so they cannot round-trip.
        """
        instance = cls()
        as_mapping = asdict(instance, dict_factory=dict_class)

        assert isinstance(as_mapping, dict_class)

        rebuilt = cls(**as_mapping)

        assert instance == rebuilt

    @given(simple_classes())
    def test_asdict_preserve_order(self, cls):
        """
        Dumping into an ordered mapping keeps the attribute declaration order.
        """
        declared = [attribute.name for attribute in fields(cls)]
        dumped = asdict(cls(), dict_factory=ordered_dict)

        assert declared == list(dumped.keys())
172
173
class TestAsTuple(object):
    """
    Tests for `astuple`.
    """

    @given(st.sampled_from(SEQUENCE_TYPES))
    def test_shallow(self, C, tuple_factory):
        """
        Shallow astuple returns correct tuple.
        """
        assert tuple_factory([1, 2]) == astuple(
            C(x=1, y=2), False, tuple_factory=tuple_factory
        )

    @given(st.sampled_from(SEQUENCE_TYPES))
    def test_recurse(self, C, tuple_factory):
        """
        Deep astuple returns correct tuple.
        """
        assert tuple_factory(
            [tuple_factory([1, 2]), tuple_factory([3, 4])]
        ) == astuple(C(C(1, 2), C(3, 4)), tuple_factory=tuple_factory)

    @given(nested_classes, st.sampled_from(SEQUENCE_TYPES))
    def test_recurse_property(self, cls, tuple_class):
        """
        Property tests for recursive astuple: every tuple produced for an
        instance (including directly nested instances) uses
        ``tuple_factory``.
        """
        obj = cls()
        obj_tuple = astuple(obj, tuple_factory=tuple_class)

        def assert_proper_tuple_class(obj, obj_tuple):
            assert isinstance(obj_tuple, tuple_class)
            for index, field in enumerate(fields(obj.__class__)):
                field_val = getattr(obj, field.name)
                if has(field_val.__class__):
                    # This field holds a class, recurse the assertions.
                    assert_proper_tuple_class(field_val, obj_tuple[index])

        assert_proper_tuple_class(obj, obj_tuple)

    @given(nested_classes, st.sampled_from(SEQUENCE_TYPES))
    def test_recurse_retain(self, cls, tuple_class):
        """
        Property tests for asserting collection types are retained.
        """
        obj = cls()
        obj_tuple = astuple(
            obj, tuple_factory=tuple_class, retain_collection_types=True
        )

        def assert_proper_col_class(obj, obj_tuple):
            # Iterate over all attributes, and if they are lists or mappings
            # in the original, assert they are the same class in the dumped.
            for index, field in enumerate(fields(obj.__class__)):
                field_val = getattr(obj, field.name)
                if has(field_val.__class__):
                    # This field holds a class, recurse the assertions.
                    assert_proper_col_class(field_val, obj_tuple[index])
                elif isinstance(field_val, (list, tuple)):
                    # This field holds a sequence of something.
                    expected_type = type(obj_tuple[index])
                    assert type(field_val) is expected_type  # noqa: E721
                    for obj_e, obj_tuple_e in zip(field_val, obj_tuple[index]):
                        if has(obj_e.__class__):
                            assert_proper_col_class(obj_e, obj_tuple_e)
                elif isinstance(field_val, dict):
                    orig = field_val
                    tupled = obj_tuple[index]
                    assert type(orig) is type(tupled)  # noqa: E721
                    for obj_e, obj_tuple_e in zip(
                        orig.items(), tupled.items()
                    ):
                        if has(obj_e[0].__class__):  # Dict key
                            assert_proper_col_class(obj_e[0], obj_tuple_e[0])
                        if has(obj_e[1].__class__):  # Dict value
                            assert_proper_col_class(obj_e[1], obj_tuple_e[1])

        assert_proper_col_class(obj, obj_tuple)

    @given(st.sampled_from(SEQUENCE_TYPES))
    def test_filter(self, C, tuple_factory):
        """
        Attributes that are supposed to be skipped are skipped.
        """
        assert tuple_factory([tuple_factory([1])]) == astuple(
            C(C(1, 2), C(3, 4)),
            filter=lambda a, v: a.name != "y",
            tuple_factory=tuple_factory,
        )

    @given(container=st.sampled_from(SEQUENCE_TYPES))
    def test_lists_tuples(self, container, C):
        """
        If recurse is True, also recurse into lists and tuples; without
        retain_collection_types they come back as lists.
        """
        assert (1, [(2, 3), (4, 5), "a"]) == astuple(
            C(1, container([C(2, 3), C(4, 5), "a"]))
        )

    @given(st.sampled_from(SEQUENCE_TYPES))
    def test_dicts(self, C, tuple_factory):
        """
        If recurse is True, also recurse into dicts.
        """
        res = astuple(C(1, {"a": C(4, 5)}), tuple_factory=tuple_factory)
        assert tuple_factory([1, {"a": tuple_factory([4, 5])}]) == res
        assert isinstance(res, tuple_factory)

    @given(container=st.sampled_from(SEQUENCE_TYPES))
    def test_lists_tuples_retain_type(self, container, C):
        """
        If recurse and retain_collection_types are True, also recurse
        into lists and do not convert them into list.
        """
        assert (1, container([(2, 3), (4, 5), "a"])) == astuple(
            C(1, container([C(2, 3), C(4, 5), "a"])),
            retain_collection_types=True,
        )

    @given(container=st.sampled_from(MAPPING_TYPES))
    def test_dicts_retain_type(self, container, C):
        """
        If recurse and retain_collection_types are True, also recurse
        into dicts and do not convert them into plain dicts.
        """
        res = astuple(
            C(1, container({"a": C(4, 5)})), retain_collection_types=True
        )

        assert (1, container({"a": (4, 5)})) == res
        # ``dict`` and ``OrderedDict`` compare equal to each other, so the
        # equality above cannot tell whether the mapping type was retained;
        # check the exact type explicitly.
        assert type(res[1]) is container  # noqa: E721

    @given(simple_classes(), st.sampled_from(SEQUENCE_TYPES))
    def test_roundtrip(self, cls, tuple_class):
        """
        Test dumping to tuple and back for Hypothesis-generated classes.
        """
        instance = cls()
        tuple_instance = astuple(instance, tuple_factory=tuple_class)

        assert isinstance(tuple_instance, tuple_class)

        roundtrip_instance = cls(*tuple_instance)

        assert instance == roundtrip_instance
317
318
class TestHas(object):
    """
    Tests for `has`, the "is this an attrs class?" predicate.
    """

    def test_positive(self, C):
        """
        An attrs-decorated class is recognized.
        """
        recognized = has(C)
        assert recognized

    def test_positive_empty(self):
        """
        An attrs-decorated class is recognized even without any attributes.
        """

        @attr.s
        class Empty(object):
            pass

        recognized = has(Empty)
        assert recognized

    def test_negative(self):
        """
        A class that was never decorated is not recognized.
        """
        undecorated = object
        assert not has(undecorated)
346
347
class TestAssoc(object):
    """
    Tests for the deprecated `assoc` helper.
    """

    @given(slots=st.booleans(), frozen=st.booleans())
    def test_empty(self, slots, frozen):
        """
        An attribute-less instance is copied when no changes are requested.
        """

        @attr.s(slots=slots, frozen=frozen)
        class C(object):
            pass

        original = C()
        with pytest.deprecated_call():
            duplicate = assoc(original)

        assert original is not duplicate
        assert original == duplicate

    @given(simple_classes())
    def test_no_changes(self, C):
        """
        Requesting no changes yields an equal but distinct copy.
        """
        original = C()
        with pytest.deprecated_call():
            duplicate = assoc(original)

        assert original is not duplicate
        assert original == duplicate

    @given(simple_classes(), st.data())
    def test_change(self, C, data):
        """
        Requested changes show up on the returned instance.
        """
        assume(fields(C))  # Nothing to change on attribute-less classes.
        names = [attribute.name for attribute in fields(C)]
        original = C()
        # Pick an arbitrary (possibly empty) subset of the attributes and
        # assign each one a freshly drawn integer.
        picked = data.draw(st.sets(st.sampled_from(names)))
        changes = {name: data.draw(st.integers()) for name in picked}

        with pytest.deprecated_call():
            changed = assoc(original, **changes)

        for name, new_value in changes.items():
            assert getattr(changed, name) == new_value

    @given(simple_classes())
    def test_unknown(self, C):
        """
        Asking to change an attribute the class does not have raises
        AttrsAttributeNotFoundError.
        """
        # Generated classes never have an attribute with a four letter name.
        with pytest.raises(AttrsAttributeNotFoundError) as exc_info:
            with pytest.deprecated_call():
                assoc(C(), aaaa=2)

        message = "aaaa is not an attrs attribute on {cls!r}.".format(cls=C)
        assert (message,) == exc_info.value.args

    def test_frozen(self):
        """
        Frozen classes are supported.
        """

        @attr.s(frozen=True)
        class C(object):
            x = attr.ib()
            y = attr.ib()

        with pytest.deprecated_call():
            expected = C(3, 2)
            assert expected == assoc(C(1, 2), x=3)

    def test_warning(self):
        """
        The DeprecationWarning is attributed to the calling test file.
        """

        @attr.s
        class C(object):
            x = attr.ib()

        with pytest.warns(DeprecationWarning) as recorded:
            assert C(2) == assoc(C(1), x=2)

        first_warning = recorded.list[0]
        assert __file__ == first_warning.filename
442
443
class TestEvolve(object):
    """
    Tests for `evolve`.
    """

    @given(slots=st.booleans(), frozen=st.booleans())
    def test_empty(self, slots, frozen):
        """
        Empty classes without changes get copied.
        """

        @attr.s(slots=slots, frozen=frozen)
        class C(object):
            pass

        i1 = C()
        i2 = evolve(i1)

        assert i1 is not i2
        assert i1 == i2

    @given(simple_classes())
    def test_no_changes(self, C):
        """
        No changes means a verbatim copy.
        """
        i1 = C()
        i2 = evolve(i1)

        assert i1 is not i2
        assert i1 == i2

    @given(simple_classes(), st.data())
    def test_change(self, C, data):
        """
        Changes work.
        """
        # Pick an arbitrary (possibly empty) subset of attributes to change.
        assume(fields(C))  # Skip classes with no attributes.
        field_names = [a.name for a in fields(C)]
        original = C()
        chosen_names = data.draw(st.sets(st.sampled_from(field_names)))
        # Private attributes must be passed under their `__init__` argument
        # name, i.e. with the leading underscore(s) stripped. Only strip the
        # *leading* underscores so names with inner underscores stay intact.
        change_dict = {
            name.lstrip("_"): data.draw(st.integers())
            for name in chosen_names
        }
        changed = evolve(original, **change_dict)
        for name in chosen_names:
            assert getattr(changed, name) == change_dict[name.lstrip("_")]

    @given(simple_classes())
    def test_unknown(self, C):
        """
        Wanting to change an unknown attribute raises a TypeError, because
        evolve() forwards the changes to the class's `__init__`.
        """
        # No generated class will have a four letter attribute.
        with pytest.raises(TypeError) as e:
            evolve(C(), aaaa=2)

        # Since Python 3.10 the message is prefixed with the qualified name
        # of the function ("HypClass.__init__() ..."), so only pin the part
        # that is stable across interpreter versions.
        expected = "__init__() got an unexpected keyword argument 'aaaa'"
        assert 1 == len(e.value.args)
        assert e.value.args[0].endswith(expected)

    def test_validator_failure(self):
        """
        TypeError isn't swallowed when validation fails within evolve.
        """

        @attr.s
        class C(object):
            a = attr.ib(validator=instance_of(int))

        with pytest.raises(TypeError) as e:
            evolve(C(a=1), a="some string")
        m = e.value.args[0]

        assert m.startswith("'a' must be <{type} 'int'>".format(type=TYPE))

    def test_private(self):
        """
        evolve() acts as `__init__` with regards to private attributes.
        """

        @attr.s
        class C(object):
            _a = attr.ib()

        assert evolve(C(1), a=2)._a == 2

        with pytest.raises(TypeError):
            evolve(C(1), _a=2)

        with pytest.raises(TypeError):
            evolve(C(1), a=3, _a=2)

    def test_non_init_attrs(self):
        """
        evolve() handles `init=False` attributes: they are not passed to
        `__init__`, so the new instance gets their default value.
        """

        @attr.s
        class C(object):
            a = attr.ib()
            b = attr.ib(init=False, default=0)

        evolved = evolve(C(1), a=2)

        assert evolved.a == 2
        assert evolved.b == 0
551