1"""
2Tests for `attr._funcs`.
3"""
4
from __future__ import absolute_import, division, print_function

from collections import OrderedDict

try:
    from collections.abc import Mapping, Sequence
except ImportError:  # Python 2: the ABCs live directly in `collections`.
    from collections import Mapping, Sequence

import pytest

from hypothesis import HealthCheck, assume, given, settings
from hypothesis import strategies as st

import attr

from attr import asdict, assoc, astuple, evolve, fields, has
from attr._compat import TYPE
from attr.exceptions import AttrsAttributeNotFoundError
from attr.validators import instance_of

from .strategies import nested_classes, simple_classes
22
23
# Mapping types exercised as ``dict_factory`` for ``asdict`` and as
# container types in the mapping-retention tests.
MAPPING_TYPES = (
    dict,
    OrderedDict,
)
# Sequence types exercised as ``tuple_factory`` for ``astuple`` and as
# container types in the sequence-retention tests.
SEQUENCE_TYPES = (
    list,
    tuple,
)
26
27
class TestAsDict(object):
    """
    Tests for `asdict`.
    """
    @given(st.sampled_from(MAPPING_TYPES))
    def test_shallow(self, C, dict_factory):
        """
        Shallow asdict returns correct dict of the requested type.
        """
        res = asdict(C(x=1, y=2), False, dict_factory=dict_factory)

        assert {
            "x": 1,
            "y": 2,
        } == res
        assert isinstance(res, dict_factory)

    @given(st.sampled_from(MAPPING_TYPES))
    def test_recurse(self, C, dict_factory):
        """
        Deep asdict returns correct dict and uses `dict_factory` for the
        nested instances as well as for the outermost one.
        """
        res = asdict(C(
            C(1, 2),
            C(3, 4),
        ), dict_factory=dict_factory)

        assert {
            "x": {"x": 1, "y": 2},
            "y": {"x": 3, "y": 4},
        } == res
        # Mapping equality ignores the concrete type (an OrderedDict equals a
        # plain dict with the same items), so check the types explicitly.
        assert isinstance(res, dict_factory)
        assert isinstance(res["x"], dict_factory)
        assert isinstance(res["y"], dict_factory)

    @given(nested_classes, st.sampled_from(MAPPING_TYPES))
    @settings(suppress_health_check=[HealthCheck.too_slow])
    def test_recurse_property(self, cls, dict_factory):
        """
        Property tests for recursive asdict: every dumped attrs instance, at
        any nesting depth, is an instance of `dict_factory`.
        """
        obj = cls()
        obj_dict = asdict(obj, dict_factory=dict_factory)

        def assert_proper_dict_class(obj, obj_dict):
            assert isinstance(obj_dict, dict_factory)

            for field in fields(obj.__class__):
                field_val = getattr(obj, field.name)
                if has(field_val.__class__):
                    # This field holds a class, recurse the assertions.
                    assert_proper_dict_class(field_val, obj_dict[field.name])
                elif isinstance(field_val, Sequence):
                    # This field holds a sequence. Strings are Sequences too,
                    # but their items are never attrs instances, so the loop
                    # below is a no-op for them.
                    dict_val = obj_dict[field.name]
                    for item, item_dict in zip(field_val, dict_val):
                        if has(item.__class__):
                            assert_proper_dict_class(item, item_dict)
                elif isinstance(field_val, Mapping):
                    # This field holds a dictionary.
                    assert isinstance(obj_dict[field.name], dict_factory)

                    for key, val in field_val.items():
                        if has(val.__class__):
                            assert_proper_dict_class(
                                val, obj_dict[field.name][key])

        assert_proper_dict_class(obj, obj_dict)

    @given(st.sampled_from(MAPPING_TYPES))
    def test_filter(self, C, dict_factory):
        """
        Attributes that are supposed to be skipped are skipped.
        """
        assert {
            "x": {"x": 1},
        } == asdict(C(
            C(1, 2),
            C(3, 4),
        ), filter=lambda a, v: a.name != "y", dict_factory=dict_factory)

    @given(container=st.sampled_from(SEQUENCE_TYPES))
    def test_lists_tuples(self, container, C):
        """
        If recurse is True, also recurse into lists and tuples, dumping both
        as lists.
        """
        assert {
            "x": 1,
            "y": [{"x": 2, "y": 3}, {"x": 4, "y": 5}, "a"],
        } == asdict(C(1, container([C(2, 3), C(4, 5), "a"])))

    @given(container=st.sampled_from(SEQUENCE_TYPES))
    def test_lists_tuples_retain_type(self, container, C):
        """
        If recurse and retain_collection_types are True, also recurse into
        lists and tuples and keep their original container type.
        """
        assert {
            "x": 1,
            "y": container([{"x": 2, "y": 3}, {"x": 4, "y": 5}, "a"]),
        } == asdict(C(1, container([C(2, 3), C(4, 5), "a"])),
                    retain_collection_types=True)

    @given(st.sampled_from(MAPPING_TYPES))
    def test_dicts(self, C, dict_factory):
        """
        If recurse is True, also recurse into dicts.
        """
        res = asdict(C(1, {"a": C(4, 5)}), dict_factory=dict_factory)
        assert {
            "x": 1,
            "y": {"a": {"x": 4, "y": 5}},
        } == res
        assert isinstance(res, dict_factory)

    @given(simple_classes(private_attrs=False), st.sampled_from(MAPPING_TYPES))
    def test_roundtrip(self, cls, dict_factory):
        """
        Test dumping to dicts and back for Hypothesis-generated classes.

        Private attributes don't round-trip (the attribute name is different
        than the initializer argument).
        """
        instance = cls()
        dict_instance = asdict(instance, dict_factory=dict_factory)

        assert isinstance(dict_instance, dict_factory)

        roundtrip_instance = cls(**dict_instance)

        assert instance == roundtrip_instance

    @given(simple_classes())
    def test_asdict_preserve_order(self, cls):
        """
        Field order should be preserved when dumping to OrderedDicts.
        """
        instance = cls()
        dict_instance = asdict(instance, dict_factory=OrderedDict)

        assert [a.name for a in fields(cls)] == list(dict_instance.keys())
161
class TestAsTuple(object):
    """
    Tests for `astuple`.
    """
    @given(st.sampled_from(SEQUENCE_TYPES))
    def test_shallow(self, C, tuple_factory):
        """
        Shallow astuple returns correct tuple.
        """
        assert (tuple_factory([1, 2]) ==
                astuple(C(x=1, y=2), False, tuple_factory=tuple_factory))

    @given(st.sampled_from(SEQUENCE_TYPES))
    def test_recurse(self, C, tuple_factory):
        """
        Deep astuple returns correct tuple.
        """
        expected = tuple_factory([
            tuple_factory([1, 2]),
            tuple_factory([3, 4]),
        ])

        assert expected == astuple(
            C(C(1, 2), C(3, 4)),
            tuple_factory=tuple_factory,
        )

    @given(nested_classes, st.sampled_from(SEQUENCE_TYPES))
    @settings(suppress_health_check=[HealthCheck.too_slow])
    def test_recurse_property(self, cls, tuple_factory):
        """
        Property tests for recursive astuple: every dumped attrs instance,
        at any nesting depth, is an instance of `tuple_factory`.
        """
        obj = cls()
        obj_tuple = astuple(obj, tuple_factory=tuple_factory)

        def assert_proper_tuple_class(obj, obj_tuple):
            assert isinstance(obj_tuple, tuple_factory)
            for index, field in enumerate(fields(obj.__class__)):
                field_val = getattr(obj, field.name)
                if has(field_val.__class__):
                    # This field holds a class, recurse the assertions.
                    assert_proper_tuple_class(field_val, obj_tuple[index])

        assert_proper_tuple_class(obj, obj_tuple)

    @given(nested_classes, st.sampled_from(SEQUENCE_TYPES))
    @settings(suppress_health_check=[HealthCheck.too_slow])
    def test_recurse_retain(self, cls, tuple_factory):
        """
        Property tests for asserting collection types are retained.
        """
        obj = cls()
        obj_tuple = astuple(obj, tuple_factory=tuple_factory,
                            retain_collection_types=True)

        def assert_proper_col_class(obj, obj_tuple):
            # Iterate over all attributes, and if they are lists or mappings
            # in the original, assert they are the same class in the dumped.
            for index, field in enumerate(fields(obj.__class__)):
                field_val = getattr(obj, field.name)
                if has(field_val.__class__):
                    # This field holds a class, recurse the assertions.
                    assert_proper_col_class(field_val, obj_tuple[index])
                elif isinstance(field_val, (list, tuple)):
                    # This field holds a sequence of something.
                    dumped = obj_tuple[index]
                    assert type(field_val) is type(dumped)  # noqa: E721
                    for obj_e, obj_tuple_e in zip(field_val, dumped):
                        if has(obj_e.__class__):
                            assert_proper_col_class(obj_e, obj_tuple_e)
                elif isinstance(field_val, dict):
                    orig = field_val
                    tupled = obj_tuple[index]
                    assert type(orig) is type(tupled)  # noqa: E721
                    for obj_e, obj_tuple_e in zip(orig.items(),
                                                  tupled.items()):
                        if has(obj_e[0].__class__):  # Dict key
                            assert_proper_col_class(obj_e[0], obj_tuple_e[0])
                        if has(obj_e[1].__class__):  # Dict value
                            assert_proper_col_class(obj_e[1], obj_tuple_e[1])

        assert_proper_col_class(obj, obj_tuple)

    @given(st.sampled_from(SEQUENCE_TYPES))
    def test_filter(self, C, tuple_factory):
        """
        Attributes that are supposed to be skipped are skipped.
        """
        assert tuple_factory([tuple_factory([1, ]), ]) == astuple(C(
            C(1, 2),
            C(3, 4),
        ), filter=lambda a, v: a.name != "y", tuple_factory=tuple_factory)

    @given(container=st.sampled_from(SEQUENCE_TYPES))
    def test_lists_tuples(self, container, C):
        """
        If recurse is True, also recurse into lists and tuples, dumping both
        as lists.
        """
        assert (1, [(2, 3), (4, 5), "a"]) == astuple(
            C(1, container([C(2, 3), C(4, 5), "a"]))
        )

    @given(st.sampled_from(SEQUENCE_TYPES))
    def test_dicts(self, C, tuple_factory):
        """
        If recurse is True, also recurse into dicts.
        """
        res = astuple(C(1, {"a": C(4, 5)}), tuple_factory=tuple_factory)
        assert tuple_factory([1, {"a": tuple_factory([4, 5])}]) == res
        assert isinstance(res, tuple_factory)

    @given(container=st.sampled_from(SEQUENCE_TYPES))
    def test_lists_tuples_retain_type(self, container, C):
        """
        If recurse and retain_collection_types are True, also recurse into
        lists and tuples and keep their original container type.
        """
        assert (
            (1, container([(2, 3), (4, 5), "a"]))
            == astuple(C(1, container([C(2, 3), C(4, 5), "a"])),
                       retain_collection_types=True))

    @given(container=st.sampled_from(MAPPING_TYPES))
    def test_dicts_retain_type(self, container, C):
        """
        If recurse and retain_collection_types are True, also recurse into
        dicts and keep their original mapping type instead of converting
        them into dict.
        """
        res = astuple(C(1, container({"a": C(4, 5)})),
                      retain_collection_types=True)

        assert (1, container({"a": (4, 5)})) == res
        # Mapping equality ignores the concrete type (an OrderedDict equals a
        # plain dict with the same items), so check the type explicitly.
        assert type(res[1]) is container  # noqa: E721

    @given(simple_classes(), st.sampled_from(SEQUENCE_TYPES))
    def test_roundtrip(self, cls, tuple_factory):
        """
        Test dumping to tuple and back for Hypothesis-generated classes.
        """
        instance = cls()
        tuple_instance = astuple(instance, tuple_factory=tuple_factory)

        assert isinstance(tuple_instance, tuple_factory)

        roundtrip_instance = cls(*tuple_instance)

        assert instance == roundtrip_instance
307
308
class TestHas(object):
    """
    Tests for `has`, which tells attrs-decorated classes apart from others.
    """
    def test_positive(self, C):
        """
        A class decorated with attrs is recognized by `has`.
        """
        is_attrs_class = has(C)

        assert is_attrs_class

    def test_positive_empty(self):
        """
        Decoration alone is enough: a class without any attributes is still
        recognized.
        """
        @attr.s
        class Empty(object):
            pass

        is_attrs_class = has(Empty)

        assert is_attrs_class

    def test_negative(self):
        """
        A plain, undecorated class is not recognized.
        """
        undecorated = object

        assert not has(undecorated)
334
335
class TestAssoc(object):
    """
    Tests for `assoc`.

    `assoc` is deprecated, so every call is wrapped in
    `pytest.deprecated_call()` to assert the DeprecationWarning is emitted.
    """
    @given(slots=st.booleans(), frozen=st.booleans())
    def test_empty(self, slots, frozen):
        """
        Empty classes without changes get copied.
        """
        @attr.s(slots=slots, frozen=frozen)
        class C(object):
            pass

        i1 = C()
        with pytest.deprecated_call():
            i2 = assoc(i1)

        assert i1 is not i2
        assert i1 == i2

    @given(simple_classes())
    def test_no_changes(self, C):
        """
        No changes means a verbatim copy.
        """
        i1 = C()
        with pytest.deprecated_call():
            i2 = assoc(i1)

        assert i1 is not i2
        assert i1 == i2

    @given(simple_classes(), st.data())
    def test_change(self, C, data):
        """
        Changes work.
        """
        assume(fields(C))  # Skip classes with no attributes.
        field_names = [a.name for a in fields(C)]
        original = C()
        # Change a randomly drawn (possibly empty) subset of the attributes.
        chosen_names = data.draw(st.sets(st.sampled_from(field_names)))
        change_dict = {name: data.draw(st.integers())
                       for name in chosen_names}

        with pytest.deprecated_call():
            changed = assoc(original, **change_dict)

        # assoc() returns a modified copy, never the original instance.
        assert changed is not original
        for k, v in change_dict.items():
            assert getattr(changed, k) == v

    @given(simple_classes())
    def test_unknown(self, C):
        """
        Wanting to change an unknown attribute raises an
        AttrsAttributeNotFoundError.
        """
        # No generated class will have a four letter attribute.
        with pytest.raises(AttrsAttributeNotFoundError) as e, \
                pytest.deprecated_call():
            assoc(C(), aaaa=2)

        assert (
            "aaaa is not an attrs attribute on {cls!r}.".format(cls=C),
        ) == e.value.args

    def test_frozen(self):
        """
        Works on frozen classes.
        """
        @attr.s(frozen=True)
        class C(object):
            x = attr.ib()
            y = attr.ib()

        with pytest.deprecated_call():
            assert C(3, 2) == assoc(C(1, 2), x=3)

    def test_warning(self):
        """
        DeprecationWarning points to the correct file.
        """
        @attr.s
        class C(object):
            x = attr.ib()

        with pytest.warns(DeprecationWarning) as wi:
            assert C(2) == assoc(C(1), x=2)

        # The warning must be attributed to the caller (this test module),
        # not to attrs' own source file.
        assert __file__ == wi.list[0].filename
426
427
class TestEvolve(object):
    """
    Tests for `evolve`.
    """
    @given(slots=st.booleans(), frozen=st.booleans())
    def test_empty(self, slots, frozen):
        """
        Empty classes without changes get copied.
        """
        @attr.s(slots=slots, frozen=frozen)
        class C(object):
            pass

        i1 = C()
        i2 = evolve(i1)

        assert i1 is not i2
        assert i1 == i2

    @given(simple_classes())
    def test_no_changes(self, C):
        """
        No changes means a verbatim copy.
        """
        i1 = C()
        i2 = evolve(i1)

        assert i1 is not i2
        assert i1 == i2

    @given(simple_classes(), st.data())
    def test_change(self, C, data):
        """
        Changes work.
        """
        assume(fields(C))  # Skip classes with no attributes.
        field_names = [a.name for a in fields(C)]
        original = C()
        # Change a randomly drawn (possibly empty) subset of the attributes.
        chosen_names = data.draw(st.sets(st.sampled_from(field_names)))
        # We pay special attention to private attributes, they should behave
        # like in `__init__`: they are passed by their argument name, i.e.
        # with the leading underscore(s) stripped.
        change_dict = {name.lstrip("_"): data.draw(st.integers())
                       for name in chosen_names}
        changed = evolve(original, **change_dict)

        # evolve() builds a new instance, never mutating the original.
        assert changed is not original
        for name in chosen_names:
            assert getattr(changed, name) == change_dict[name.lstrip("_")]

    @given(simple_classes())
    def test_unknown(self, C):
        """
        Wanting to change an unknown attribute raises a TypeError, just like
        passing an unknown keyword argument to `__init__` does.
        """
        # No generated class will have a four letter attribute.
        with pytest.raises(TypeError) as e:
            evolve(C(), aaaa=2)

        # Newer Python versions may prefix the function name in this message
        # with a qualified name, so only the stable suffix is compared.
        expected = "__init__() got an unexpected keyword argument 'aaaa'"
        assert 1 == len(e.value.args)
        assert e.value.args[0].endswith(expected)

    def test_validator_failure(self):
        """
        TypeError isn't swallowed when validation fails within evolve.
        """
        @attr.s
        class C(object):
            a = attr.ib(validator=instance_of(int))

        with pytest.raises(TypeError) as e:
            evolve(C(a=1), a="some string")
        m = e.value.args[0]

        assert m.startswith("'a' must be <{type} 'int'>".format(type=TYPE))

    def test_private(self):
        """
        evolve() acts as `__init__` with regards to private attributes.
        """
        @attr.s
        class C(object):
            _a = attr.ib()

        assert evolve(C(1), a=2)._a == 2

        with pytest.raises(TypeError):
            evolve(C(1), _a=2)

        with pytest.raises(TypeError):
            evolve(C(1), a=3, _a=2)

    def test_non_init_attrs(self):
        """
        evolve() handles `init=False` attributes.
        """
        @attr.s
        class C(object):
            a = attr.ib()
            b = attr.ib(init=False, default=0)

        evolved = evolve(C(1), a=2)

        assert 2 == evolved.a
        # `b` is not an `__init__` argument, so it must not be passed along;
        # the new instance gets it from its default.
        assert 0 == evolved.b
528