1# Deliberately use "from dataclasses import *".  Every name in __all__
2# is tested, so they all must be present.  This is a way to catch
3# missing ones.
4
5from dataclasses import *
6
7import abc
8import pickle
9import inspect
10import builtins
11import types
12import unittest
13from unittest.mock import Mock
14from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol
15from typing import get_type_hints
16from collections import deque, OrderedDict, namedtuple
17from functools import total_ordering
18
19import typing       # Needed for the string "typing.ClassVar[int]" to work as an annotation.
20import dataclasses  # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
21
class CustomError(Exception):
    """An arbitrary custom exception type that tests raise and then catch."""
24
25class TestCase(unittest.TestCase):
26    def test_no_fields(self):
27        @dataclass
28        class C:
29            pass
30
31        o = C()
32        self.assertEqual(len(fields(C)), 0)
33
34    def test_no_fields_but_member_variable(self):
35        @dataclass
36        class C:
37            i = 0
38
39        o = C()
40        self.assertEqual(len(fields(C)), 0)
41
42    def test_one_field_no_default(self):
43        @dataclass
44        class C:
45            x: int
46
47        o = C(42)
48        self.assertEqual(o.x, 42)
49
50    def test_field_default_default_factory_error(self):
51        msg = "cannot specify both default and default_factory"
52        with self.assertRaisesRegex(ValueError, msg):
53            @dataclass
54            class C:
55                x: int = field(default=1, default_factory=int)
56
57    def test_field_repr(self):
58        int_field = field(default=1, init=True, repr=False)
59        int_field.name = "id"
60        repr_output = repr(int_field)
61        expected_output = "Field(name='id',type=None," \
62                           f"default=1,default_factory={MISSING!r}," \
63                           "init=True,repr=False,hash=None," \
64                           "compare=True,metadata=mappingproxy({})," \
65                           f"kw_only={MISSING!r}," \
66                           "_field_type=None)"
67
68        self.assertEqual(repr_output, expected_output)
69
70    def test_named_init_params(self):
71        @dataclass
72        class C:
73            x: int
74
75        o = C(x=32)
76        self.assertEqual(o.x, 32)
77
78    def test_two_fields_one_default(self):
79        @dataclass
80        class C:
81            x: int
82            y: int = 0
83
84        o = C(3)
85        self.assertEqual((o.x, o.y), (3, 0))
86
87        # Non-defaults following defaults.
88        with self.assertRaisesRegex(TypeError,
89                                    "non-default argument 'y' follows "
90                                    "default argument"):
91            @dataclass
92            class C:
93                x: int = 0
94                y: int
95
96        # A derived class adds a non-default field after a default one.
97        with self.assertRaisesRegex(TypeError,
98                                    "non-default argument 'y' follows "
99                                    "default argument"):
100            @dataclass
101            class B:
102                x: int = 0
103
104            @dataclass
105            class C(B):
106                y: int
107
108        # Override a base class field and add a default to
109        #  a field which didn't use to have a default.
110        with self.assertRaisesRegex(TypeError,
111                                    "non-default argument 'y' follows "
112                                    "default argument"):
113            @dataclass
114            class B:
115                x: int
116                y: int
117
118            @dataclass
119            class C(B):
120                x: int = 0
121
122    def test_overwrite_hash(self):
123        # Test that declaring this class isn't an error.  It should
124        #  use the user-provided __hash__.
125        @dataclass(frozen=True)
126        class C:
127            x: int
128            def __hash__(self):
129                return 301
130        self.assertEqual(hash(C(100)), 301)
131
132        # Test that declaring this class isn't an error.  It should
133        #  use the generated __hash__.
134        @dataclass(frozen=True)
135        class C:
136            x: int
137            def __eq__(self, other):
138                return False
139        self.assertEqual(hash(C(100)), hash((100,)))
140
141        # But this one should generate an exception, because with
142        #  unsafe_hash=True, it's an error to have a __hash__ defined.
143        with self.assertRaisesRegex(TypeError,
144                                    'Cannot overwrite attribute __hash__'):
145            @dataclass(unsafe_hash=True)
146            class C:
147                def __hash__(self):
148                    pass
149
150        # Creating this class should not generate an exception,
151        #  because even though __hash__ exists before @dataclass is
152        #  called, (due to __eq__ being defined), since it's None
153        #  that's okay.
154        @dataclass(unsafe_hash=True)
155        class C:
156            x: int
157            def __eq__(self):
158                pass
159        # The generated hash function works as we'd expect.
160        self.assertEqual(hash(C(10)), hash((10,)))
161
162        # Creating this class should generate an exception, because
163        #  __hash__ exists and is not None, which it would be if it
164        #  had been auto-generated due to __eq__ being defined.
165        with self.assertRaisesRegex(TypeError,
166                                    'Cannot overwrite attribute __hash__'):
167            @dataclass(unsafe_hash=True)
168            class C:
169                x: int
170                def __eq__(self):
171                    pass
172                def __hash__(self):
173                    pass
174
175    def test_overwrite_fields_in_derived_class(self):
176        # Note that x from C1 replaces x in Base, but the order remains
177        #  the same as defined in Base.
178        @dataclass
179        class Base:
180            x: Any = 15.0
181            y: int = 0
182
183        @dataclass
184        class C1(Base):
185            z: int = 10
186            x: int = 15
187
188        o = Base()
189        self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
190
191        o = C1()
192        self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
193
194        o = C1(x=5)
195        self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
196
197    def test_field_named_self(self):
198        @dataclass
199        class C:
200            self: str
201        c=C('foo')
202        self.assertEqual(c.self, 'foo')
203
204        # Make sure the first parameter is not named 'self'.
205        sig = inspect.signature(C.__init__)
206        first = next(iter(sig.parameters))
207        self.assertNotEqual('self', first)
208
209        # But we do use 'self' if no field named self.
210        @dataclass
211        class C:
212            selfx: str
213
214        # Make sure the first parameter is named 'self'.
215        sig = inspect.signature(C.__init__)
216        first = next(iter(sig.parameters))
217        self.assertEqual('self', first)
218
219    def test_field_named_object(self):
220        @dataclass
221        class C:
222            object: str
223        c = C('foo')
224        self.assertEqual(c.object, 'foo')
225
226    def test_field_named_object_frozen(self):
227        @dataclass(frozen=True)
228        class C:
229            object: str
230        c = C('foo')
231        self.assertEqual(c.object, 'foo')
232
233    def test_field_named_like_builtin(self):
234        # Attribute names can shadow built-in names
235        # since code generation is used.
236        # Ensure that this is not happening.
237        exclusions = {'None', 'True', 'False'}
238        builtins_names = sorted(
239            b for b in builtins.__dict__.keys()
240            if not b.startswith('__') and b not in exclusions
241        )
242        attributes = [(name, str) for name in builtins_names]
243        C = make_dataclass('C', attributes)
244
245        c = C(*[name for name in builtins_names])
246
247        for name in builtins_names:
248            self.assertEqual(getattr(c, name), name)
249
250    def test_field_named_like_builtin_frozen(self):
251        # Attribute names can shadow built-in names
252        # since code generation is used.
253        # Ensure that this is not happening
254        # for frozen data classes.
255        exclusions = {'None', 'True', 'False'}
256        builtins_names = sorted(
257            b for b in builtins.__dict__.keys()
258            if not b.startswith('__') and b not in exclusions
259        )
260        attributes = [(name, str) for name in builtins_names]
261        C = make_dataclass('C', attributes, frozen=True)
262
263        c = C(*[name for name in builtins_names])
264
265        for name in builtins_names:
266            self.assertEqual(getattr(c, name), name)
267
268    def test_0_field_compare(self):
269        # Ensure that order=False is the default.
270        @dataclass
271        class C0:
272            pass
273
274        @dataclass(order=False)
275        class C1:
276            pass
277
278        for cls in [C0, C1]:
279            with self.subTest(cls=cls):
280                self.assertEqual(cls(), cls())
281                for idx, fn in enumerate([lambda a, b: a < b,
282                                          lambda a, b: a <= b,
283                                          lambda a, b: a > b,
284                                          lambda a, b: a >= b]):
285                    with self.subTest(idx=idx):
286                        with self.assertRaisesRegex(TypeError,
287                                                    f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
288                            fn(cls(), cls())
289
290        @dataclass(order=True)
291        class C:
292            pass
293        self.assertLessEqual(C(), C())
294        self.assertGreaterEqual(C(), C())
295
296    def test_1_field_compare(self):
297        # Ensure that order=False is the default.
298        @dataclass
299        class C0:
300            x: int
301
302        @dataclass(order=False)
303        class C1:
304            x: int
305
306        for cls in [C0, C1]:
307            with self.subTest(cls=cls):
308                self.assertEqual(cls(1), cls(1))
309                self.assertNotEqual(cls(0), cls(1))
310                for idx, fn in enumerate([lambda a, b: a < b,
311                                          lambda a, b: a <= b,
312                                          lambda a, b: a > b,
313                                          lambda a, b: a >= b]):
314                    with self.subTest(idx=idx):
315                        with self.assertRaisesRegex(TypeError,
316                                                    f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
317                            fn(cls(0), cls(0))
318
319        @dataclass(order=True)
320        class C:
321            x: int
322        self.assertLess(C(0), C(1))
323        self.assertLessEqual(C(0), C(1))
324        self.assertLessEqual(C(1), C(1))
325        self.assertGreater(C(1), C(0))
326        self.assertGreaterEqual(C(1), C(0))
327        self.assertGreaterEqual(C(1), C(1))
328
329    def test_simple_compare(self):
330        # Ensure that order=False is the default.
331        @dataclass
332        class C0:
333            x: int
334            y: int
335
336        @dataclass(order=False)
337        class C1:
338            x: int
339            y: int
340
341        for cls in [C0, C1]:
342            with self.subTest(cls=cls):
343                self.assertEqual(cls(0, 0), cls(0, 0))
344                self.assertEqual(cls(1, 2), cls(1, 2))
345                self.assertNotEqual(cls(1, 0), cls(0, 0))
346                self.assertNotEqual(cls(1, 0), cls(1, 1))
347                for idx, fn in enumerate([lambda a, b: a < b,
348                                          lambda a, b: a <= b,
349                                          lambda a, b: a > b,
350                                          lambda a, b: a >= b]):
351                    with self.subTest(idx=idx):
352                        with self.assertRaisesRegex(TypeError,
353                                                    f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
354                            fn(cls(0, 0), cls(0, 0))
355
356        @dataclass(order=True)
357        class C:
358            x: int
359            y: int
360
361        for idx, fn in enumerate([lambda a, b: a == b,
362                                  lambda a, b: a <= b,
363                                  lambda a, b: a >= b]):
364            with self.subTest(idx=idx):
365                self.assertTrue(fn(C(0, 0), C(0, 0)))
366
367        for idx, fn in enumerate([lambda a, b: a < b,
368                                  lambda a, b: a <= b,
369                                  lambda a, b: a != b]):
370            with self.subTest(idx=idx):
371                self.assertTrue(fn(C(0, 0), C(0, 1)))
372                self.assertTrue(fn(C(0, 1), C(1, 0)))
373                self.assertTrue(fn(C(1, 0), C(1, 1)))
374
375        for idx, fn in enumerate([lambda a, b: a > b,
376                                  lambda a, b: a >= b,
377                                  lambda a, b: a != b]):
378            with self.subTest(idx=idx):
379                self.assertTrue(fn(C(0, 1), C(0, 0)))
380                self.assertTrue(fn(C(1, 0), C(0, 1)))
381                self.assertTrue(fn(C(1, 1), C(1, 0)))
382
383    def test_compare_subclasses(self):
384        # Comparisons fail for subclasses, even if no fields
385        #  are added.
386        @dataclass
387        class B:
388            i: int
389
390        @dataclass
391        class C(B):
392            pass
393
394        for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
395                                              (lambda a, b: a != b, True)]):
396            with self.subTest(idx=idx):
397                self.assertEqual(fn(B(0), C(0)), expected)
398
399        for idx, fn in enumerate([lambda a, b: a < b,
400                                  lambda a, b: a <= b,
401                                  lambda a, b: a > b,
402                                  lambda a, b: a >= b]):
403            with self.subTest(idx=idx):
404                with self.assertRaisesRegex(TypeError,
405                                            "not supported between instances of 'B' and 'C'"):
406                    fn(B(0), C(0))
407
408    def test_eq_order(self):
409        # Test combining eq and order.
410        for (eq,    order, result   ) in [
411            (False, False, 'neither'),
412            (False, True,  'exception'),
413            (True,  False, 'eq_only'),
414            (True,  True,  'both'),
415        ]:
416            with self.subTest(eq=eq, order=order):
417                if result == 'exception':
418                    with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
419                        @dataclass(eq=eq, order=order)
420                        class C:
421                            pass
422                else:
423                    @dataclass(eq=eq, order=order)
424                    class C:
425                        pass
426
427                    if result == 'neither':
428                        self.assertNotIn('__eq__', C.__dict__)
429                        self.assertNotIn('__lt__', C.__dict__)
430                        self.assertNotIn('__le__', C.__dict__)
431                        self.assertNotIn('__gt__', C.__dict__)
432                        self.assertNotIn('__ge__', C.__dict__)
433                    elif result == 'both':
434                        self.assertIn('__eq__', C.__dict__)
435                        self.assertIn('__lt__', C.__dict__)
436                        self.assertIn('__le__', C.__dict__)
437                        self.assertIn('__gt__', C.__dict__)
438                        self.assertIn('__ge__', C.__dict__)
439                    elif result == 'eq_only':
440                        self.assertIn('__eq__', C.__dict__)
441                        self.assertNotIn('__lt__', C.__dict__)
442                        self.assertNotIn('__le__', C.__dict__)
443                        self.assertNotIn('__gt__', C.__dict__)
444                        self.assertNotIn('__ge__', C.__dict__)
445                    else:
446                        assert False, f'unknown result {result!r}'
447
448    def test_field_no_default(self):
449        @dataclass
450        class C:
451            x: int = field()
452
453        self.assertEqual(C(5).x, 5)
454
455        with self.assertRaisesRegex(TypeError,
456                                    r"__init__\(\) missing 1 required "
457                                    "positional argument: 'x'"):
458            C()
459
460    def test_field_default(self):
461        default = object()
462        @dataclass
463        class C:
464            x: object = field(default=default)
465
466        self.assertIs(C.x, default)
467        c = C(10)
468        self.assertEqual(c.x, 10)
469
470        # If we delete the instance attribute, we should then see the
471        #  class attribute.
472        del c.x
473        self.assertIs(c.x, default)
474
475        self.assertIs(C().x, default)
476
477    def test_not_in_repr(self):
478        @dataclass
479        class C:
480            x: int = field(repr=False)
481        with self.assertRaises(TypeError):
482            C()
483        c = C(10)
484        self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
485
486        @dataclass
487        class C:
488            x: int = field(repr=False)
489            y: int
490        c = C(10, 20)
491        self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
492
493    def test_not_in_compare(self):
494        @dataclass
495        class C:
496            x: int = 0
497            y: int = field(compare=False, default=4)
498
499        self.assertEqual(C(), C(0, 20))
500        self.assertEqual(C(1, 10), C(1, 20))
501        self.assertNotEqual(C(3), C(4, 10))
502        self.assertNotEqual(C(3, 10), C(4, 10))
503
504    def test_hash_field_rules(self):
505        # Test all 6 cases of:
506        #  hash=True/False/None
507        #  compare=True/False
508        for (hash_,    compare, result  ) in [
509            (True,     False,   'field' ),
510            (True,     True,    'field' ),
511            (False,    False,   'absent'),
512            (False,    True,    'absent'),
513            (None,     False,   'absent'),
514            (None,     True,    'field' ),
515            ]:
516            with self.subTest(hash=hash_, compare=compare):
517                @dataclass(unsafe_hash=True)
518                class C:
519                    x: int = field(compare=compare, hash=hash_, default=5)
520
521                if result == 'field':
522                    # __hash__ contains the field.
523                    self.assertEqual(hash(C(5)), hash((5,)))
524                elif result == 'absent':
525                    # The field is not present in the hash.
526                    self.assertEqual(hash(C(5)), hash(()))
527                else:
528                    assert False, f'unknown result {result!r}'
529
530    def test_init_false_no_default(self):
531        # If init=False and no default value, then the field won't be
532        #  present in the instance.
533        @dataclass
534        class C:
535            x: int = field(init=False)
536
537        self.assertNotIn('x', C().__dict__)
538
539        @dataclass
540        class C:
541            x: int
542            y: int = 0
543            z: int = field(init=False)
544            t: int = 10
545
546        self.assertNotIn('z', C(0).__dict__)
547        self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
548
549    def test_class_marker(self):
550        @dataclass
551        class C:
552            x: int
553            y: str = field(init=False, default=None)
554            z: str = field(repr=False)
555
556        the_fields = fields(C)
557        # the_fields is a tuple of 3 items, each value
558        #  is in __annotations__.
559        self.assertIsInstance(the_fields, tuple)
560        for f in the_fields:
561            self.assertIs(type(f), Field)
562            self.assertIn(f.name, C.__annotations__)
563
564        self.assertEqual(len(the_fields), 3)
565
566        self.assertEqual(the_fields[0].name, 'x')
567        self.assertEqual(the_fields[0].type, int)
568        self.assertFalse(hasattr(C, 'x'))
569        self.assertTrue (the_fields[0].init)
570        self.assertTrue (the_fields[0].repr)
571        self.assertEqual(the_fields[1].name, 'y')
572        self.assertEqual(the_fields[1].type, str)
573        self.assertIsNone(getattr(C, 'y'))
574        self.assertFalse(the_fields[1].init)
575        self.assertTrue (the_fields[1].repr)
576        self.assertEqual(the_fields[2].name, 'z')
577        self.assertEqual(the_fields[2].type, str)
578        self.assertFalse(hasattr(C, 'z'))
579        self.assertTrue (the_fields[2].init)
580        self.assertFalse(the_fields[2].repr)
581
582    def test_field_order(self):
583        @dataclass
584        class B:
585            a: str = 'B:a'
586            b: str = 'B:b'
587            c: str = 'B:c'
588
589        @dataclass
590        class C(B):
591            b: str = 'C:b'
592
593        self.assertEqual([(f.name, f.default) for f in fields(C)],
594                         [('a', 'B:a'),
595                          ('b', 'C:b'),
596                          ('c', 'B:c')])
597
598        @dataclass
599        class D(B):
600            c: str = 'D:c'
601
602        self.assertEqual([(f.name, f.default) for f in fields(D)],
603                         [('a', 'B:a'),
604                          ('b', 'B:b'),
605                          ('c', 'D:c')])
606
607        @dataclass
608        class E(D):
609            a: str = 'E:a'
610            d: str = 'E:d'
611
612        self.assertEqual([(f.name, f.default) for f in fields(E)],
613                         [('a', 'E:a'),
614                          ('b', 'B:b'),
615                          ('c', 'D:c'),
616                          ('d', 'E:d')])
617
618    def test_class_attrs(self):
619        # We only have a class attribute if a default value is
620        #  specified, either directly or via a field with a default.
621        default = object()
622        @dataclass
623        class C:
624            x: int
625            y: int = field(repr=False)
626            z: object = default
627            t: int = field(default=100)
628
629        self.assertFalse(hasattr(C, 'x'))
630        self.assertFalse(hasattr(C, 'y'))
631        self.assertIs   (C.z, default)
632        self.assertEqual(C.t, 100)
633
634    def test_disallowed_mutable_defaults(self):
635        # For the known types, don't allow mutable default values.
636        for typ, empty, non_empty in [(list, [], [1]),
637                                      (dict, {}, {0:1}),
638                                      (set, set(), set([1])),
639                                      ]:
640            with self.subTest(typ=typ):
641                # Can't use a zero-length value.
642                with self.assertRaisesRegex(ValueError,
643                                            f'mutable default {typ} for field '
644                                            'x is not allowed'):
645                    @dataclass
646                    class Point:
647                        x: typ = empty
648
649
650                # Nor a non-zero-length value
651                with self.assertRaisesRegex(ValueError,
652                                            f'mutable default {typ} for field '
653                                            'y is not allowed'):
654                    @dataclass
655                    class Point:
656                        y: typ = non_empty
657
658                # Check subtypes also fail.
659                class Subclass(typ): pass
660
661                with self.assertRaisesRegex(ValueError,
662                                            f"mutable default .*Subclass'>"
663                                            ' for field z is not allowed'
664                                            ):
665                    @dataclass
666                    class Point:
667                        z: typ = Subclass()
668
669                # Because this is a ClassVar, it can be mutable.
670                @dataclass
671                class C:
672                    z: ClassVar[typ] = typ()
673
674                # Because this is a ClassVar, it can be mutable.
675                @dataclass
676                class C:
677                    x: ClassVar[typ] = Subclass()
678
679    def test_deliberately_mutable_defaults(self):
680        # If a mutable default isn't in the known list of
681        #  (list, dict, set), then it's okay.
682        class Mutable:
683            def __init__(self):
684                self.l = []
685
686        @dataclass
687        class C:
688            x: Mutable
689
690        # These 2 instances will share this value of x.
691        lst = Mutable()
692        o1 = C(lst)
693        o2 = C(lst)
694        self.assertEqual(o1, o2)
695        o1.x.l.extend([1, 2])
696        self.assertEqual(o1, o2)
697        self.assertEqual(o1.x.l, [1, 2])
698        self.assertIs(o1.x, o2.x)
699
700    def test_no_options(self):
701        # Call with dataclass().
702        @dataclass()
703        class C:
704            x: int
705
706        self.assertEqual(C(42).x, 42)
707
708    def test_not_tuple(self):
709        # Make sure we can't be compared to a tuple.
710        @dataclass
711        class Point:
712            x: int
713            y: int
714        self.assertNotEqual(Point(1, 2), (1, 2))
715
716        # And that we can't compare to another unrelated dataclass.
717        @dataclass
718        class C:
719            x: int
720            y: int
721        self.assertNotEqual(Point(1, 3), C(1, 3))
722
723    def test_not_other_dataclass(self):
724        # Test that some of the problems with namedtuple don't happen
725        #  here.
726        @dataclass
727        class Point3D:
728            x: int
729            y: int
730            z: int
731
732        @dataclass
733        class Date:
734            year: int
735            month: int
736            day: int
737
738        self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
739        self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
740
741        # Make sure we can't unpack.
742        with self.assertRaisesRegex(TypeError, 'unpack'):
743            x, y, z = Point3D(4, 5, 6)
744
745        # Make sure another class with the same field names isn't
746        #  equal.
747        @dataclass
748        class Point3Dv1:
749            x: int = 0
750            y: int = 0
751            z: int = 0
752        self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
753
754    def test_function_annotations(self):
755        # Some dummy class and instance to use as a default.
756        class F:
757            pass
758        f = F()
759
760        def validate_class(cls):
761            # First, check __annotations__, even though they're not
762            #  function annotations.
763            self.assertEqual(cls.__annotations__['i'], int)
764            self.assertEqual(cls.__annotations__['j'], str)
765            self.assertEqual(cls.__annotations__['k'], F)
766            self.assertEqual(cls.__annotations__['l'], float)
767            self.assertEqual(cls.__annotations__['z'], complex)
768
769            # Verify __init__.
770
771            signature = inspect.signature(cls.__init__)
772            # Check the return type, should be None.
773            self.assertIs(signature.return_annotation, None)
774
775            # Check each parameter.
776            params = iter(signature.parameters.values())
777            param = next(params)
778            # This is testing an internal name, and probably shouldn't be tested.
779            self.assertEqual(param.name, 'self')
780            param = next(params)
781            self.assertEqual(param.name, 'i')
782            self.assertIs   (param.annotation, int)
783            self.assertEqual(param.default, inspect.Parameter.empty)
784            self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
785            param = next(params)
786            self.assertEqual(param.name, 'j')
787            self.assertIs   (param.annotation, str)
788            self.assertEqual(param.default, inspect.Parameter.empty)
789            self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
790            param = next(params)
791            self.assertEqual(param.name, 'k')
792            self.assertIs   (param.annotation, F)
793            # Don't test for the default, since it's set to MISSING.
794            self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
795            param = next(params)
796            self.assertEqual(param.name, 'l')
797            self.assertIs   (param.annotation, float)
798            # Don't test for the default, since it's set to MISSING.
799            self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
800            self.assertRaises(StopIteration, next, params)
801
802
803        @dataclass
804        class C:
805            i: int
806            j: str
807            k: F = f
808            l: float=field(default=None)
809            z: complex=field(default=3+4j, init=False)
810
811        validate_class(C)
812
813        # Now repeat with __hash__.
814        @dataclass(frozen=True, unsafe_hash=True)
815        class C:
816            i: int
817            j: str
818            k: F = f
819            l: float=field(default=None)
820            z: complex=field(default=3+4j, init=False)
821
822        validate_class(C)
823
824    def test_missing_default(self):
825        # Test that MISSING works the same as a default not being
826        #  specified.
827        @dataclass
828        class C:
829            x: int=field(default=MISSING)
830        with self.assertRaisesRegex(TypeError,
831                                    r'__init__\(\) missing 1 required '
832                                    'positional argument'):
833            C()
834        self.assertNotIn('x', C.__dict__)
835
836        @dataclass
837        class D:
838            x: int
839        with self.assertRaisesRegex(TypeError,
840                                    r'__init__\(\) missing 1 required '
841                                    'positional argument'):
842            D()
843        self.assertNotIn('x', D.__dict__)
844
845    def test_missing_default_factory(self):
846        # Test that MISSING works the same as a default factory not
847        #  being specified (which is really the same as a default not
848        #  being specified, too).
849        @dataclass
850        class C:
851            x: int=field(default_factory=MISSING)
852        with self.assertRaisesRegex(TypeError,
853                                    r'__init__\(\) missing 1 required '
854                                    'positional argument'):
855            C()
856        self.assertNotIn('x', C.__dict__)
857
858        @dataclass
859        class D:
860            x: int=field(default=MISSING, default_factory=MISSING)
861        with self.assertRaisesRegex(TypeError,
862                                    r'__init__\(\) missing 1 required '
863                                    'positional argument'):
864            D()
865        self.assertNotIn('x', D.__dict__)
866
867    def test_missing_repr(self):
868        self.assertIn('MISSING_TYPE object', repr(MISSING))
869
870    def test_dont_include_other_annotations(self):
871        @dataclass
872        class C:
873            i: int
874            def foo(self) -> int:
875                return 4
876            @property
877            def bar(self) -> int:
878                return 5
879        self.assertEqual(list(C.__annotations__), ['i'])
880        self.assertEqual(C(10).foo(), 4)
881        self.assertEqual(C(10).bar, 5)
882        self.assertEqual(C(10).i, 10)
883
884    def test_post_init(self):
885        # Just make sure it gets called
886        @dataclass
887        class C:
888            def __post_init__(self):
889                raise CustomError()
890        with self.assertRaises(CustomError):
891            C()
892
893        @dataclass
894        class C:
895            i: int = 10
896            def __post_init__(self):
897                if self.i == 10:
898                    raise CustomError()
899        with self.assertRaises(CustomError):
900            C()
901        # post-init gets called, but doesn't raise. This is just
902        #  checking that self is used correctly.
903        C(5)
904
905        # If there's not an __init__, then post-init won't get called.
906        @dataclass(init=False)
907        class C:
908            def __post_init__(self):
909                raise CustomError()
910        # Creating the class won't raise
911        C()
912
913        @dataclass
914        class C:
915            x: int = 0
916            def __post_init__(self):
917                self.x *= 2
918        self.assertEqual(C().x, 0)
919        self.assertEqual(C(2).x, 4)
920
921        # Make sure that if we're frozen, post-init can't set
922        #  attributes.
923        @dataclass(frozen=True)
924        class C:
925            x: int = 0
926            def __post_init__(self):
927                self.x *= 2
928        with self.assertRaises(FrozenInstanceError):
929            C()
930
931    def test_post_init_super(self):
932        # Make sure super() post-init isn't called by default.
933        class B:
934            def __post_init__(self):
935                raise CustomError()
936
937        @dataclass
938        class C(B):
939            def __post_init__(self):
940                self.x = 5
941
942        self.assertEqual(C().x, 5)
943
944        # Now call super(), and it will raise.
945        @dataclass
946        class C(B):
947            def __post_init__(self):
948                super().__post_init__()
949
950        with self.assertRaises(CustomError):
951            C()
952
953        # Make sure post-init is called, even if not defined in our
954        #  class.
955        @dataclass
956        class C(B):
957            pass
958
959        with self.assertRaises(CustomError):
960            C()
961
962    def test_post_init_staticmethod(self):
963        flag = False
964        @dataclass
965        class C:
966            x: int
967            y: int
968            @staticmethod
969            def __post_init__():
970                nonlocal flag
971                flag = True
972
973        self.assertFalse(flag)
974        c = C(3, 4)
975        self.assertEqual((c.x, c.y), (3, 4))
976        self.assertTrue(flag)
977
978    def test_post_init_classmethod(self):
979        @dataclass
980        class C:
981            flag = False
982            x: int
983            y: int
984            @classmethod
985            def __post_init__(cls):
986                cls.flag = True
987
988        self.assertFalse(C.flag)
989        c = C(3, 4)
990        self.assertEqual((c.x, c.y), (3, 4))
991        self.assertTrue(C.flag)
992
993    def test_class_var(self):
994        # Make sure ClassVars are ignored in __init__, __repr__, etc.
995        @dataclass
996        class C:
997            x: int
998            y: int = 10
999            z: ClassVar[int] = 1000
1000            w: ClassVar[int] = 2000
1001            t: ClassVar[int] = 3000
1002            s: ClassVar      = 4000
1003
1004        c = C(5)
1005        self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
1006        self.assertEqual(len(fields(C)), 2)                 # We have 2 fields.
1007        self.assertEqual(len(C.__annotations__), 6)         # And 4 ClassVars.
1008        self.assertEqual(c.z, 1000)
1009        self.assertEqual(c.w, 2000)
1010        self.assertEqual(c.t, 3000)
1011        self.assertEqual(c.s, 4000)
1012        C.z += 1
1013        self.assertEqual(c.z, 1001)
1014        c = C(20)
1015        self.assertEqual((c.x, c.y), (20, 10))
1016        self.assertEqual(c.z, 1001)
1017        self.assertEqual(c.w, 2000)
1018        self.assertEqual(c.t, 3000)
1019        self.assertEqual(c.s, 4000)
1020
1021    def test_class_var_no_default(self):
1022        # If a ClassVar has no default value, it should not be set on the class.
1023        @dataclass
1024        class C:
1025            x: ClassVar[int]
1026
1027        self.assertNotIn('x', C.__dict__)
1028
1029    def test_class_var_default_factory(self):
1030        # It makes no sense for a ClassVar to have a default factory. When
1031        #  would it be called? Call it yourself, since it's class-wide.
1032        with self.assertRaisesRegex(TypeError,
1033                                    'cannot have a default factory'):
1034            @dataclass
1035            class C:
1036                x: ClassVar[int] = field(default_factory=int)
1037
1038            self.assertNotIn('x', C.__dict__)
1039
1040    def test_class_var_with_default(self):
1041        # If a ClassVar has a default value, it should be set on the class.
1042        @dataclass
1043        class C:
1044            x: ClassVar[int] = 10
1045        self.assertEqual(C.x, 10)
1046
1047        @dataclass
1048        class C:
1049            x: ClassVar[int] = field(default=10)
1050        self.assertEqual(C.x, 10)
1051
1052    def test_class_var_frozen(self):
1053        # Make sure ClassVars work even if we're frozen.
1054        @dataclass(frozen=True)
1055        class C:
1056            x: int
1057            y: int = 10
1058            z: ClassVar[int] = 1000
1059            w: ClassVar[int] = 2000
1060            t: ClassVar[int] = 3000
1061
1062        c = C(5)
1063        self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
1064        self.assertEqual(len(fields(C)), 2)                 # We have 2 fields
1065        self.assertEqual(len(C.__annotations__), 5)         # And 3 ClassVars
1066        self.assertEqual(c.z, 1000)
1067        self.assertEqual(c.w, 2000)
1068        self.assertEqual(c.t, 3000)
1069        # We can still modify the ClassVar, it's only instances that are
1070        #  frozen.
1071        C.z += 1
1072        self.assertEqual(c.z, 1001)
1073        c = C(20)
1074        self.assertEqual((c.x, c.y), (20, 10))
1075        self.assertEqual(c.z, 1001)
1076        self.assertEqual(c.w, 2000)
1077        self.assertEqual(c.t, 3000)
1078
1079    def test_init_var_no_default(self):
1080        # If an InitVar has no default value, it should not be set on the class.
1081        @dataclass
1082        class C:
1083            x: InitVar[int]
1084
1085        self.assertNotIn('x', C.__dict__)
1086
1087    def test_init_var_default_factory(self):
1088        # It makes no sense for an InitVar to have a default factory. When
1089        #  would it be called? Call it yourself, since it's class-wide.
1090        with self.assertRaisesRegex(TypeError,
1091                                    'cannot have a default factory'):
1092            @dataclass
1093            class C:
1094                x: InitVar[int] = field(default_factory=int)
1095
1096            self.assertNotIn('x', C.__dict__)
1097
1098    def test_init_var_with_default(self):
1099        # If an InitVar has a default value, it should be set on the class.
1100        @dataclass
1101        class C:
1102            x: InitVar[int] = 10
1103        self.assertEqual(C.x, 10)
1104
1105        @dataclass
1106        class C:
1107            x: InitVar[int] = field(default=10)
1108        self.assertEqual(C.x, 10)
1109
1110    def test_init_var(self):
1111        @dataclass
1112        class C:
1113            x: int = None
1114            init_param: InitVar[int] = None
1115
1116            def __post_init__(self, init_param):
1117                if self.x is None:
1118                    self.x = init_param*2
1119
1120        c = C(init_param=10)
1121        self.assertEqual(c.x, 20)
1122
1123    def test_init_var_preserve_type(self):
1124        self.assertEqual(InitVar[int].type, int)
1125
1126        # Make sure the repr is correct.
1127        self.assertEqual(repr(InitVar[int]), 'dataclasses.InitVar[int]')
1128        self.assertEqual(repr(InitVar[List[int]]),
1129                         'dataclasses.InitVar[typing.List[int]]')
1130        self.assertEqual(repr(InitVar[list[int]]),
1131                         'dataclasses.InitVar[list[int]]')
1132        self.assertEqual(repr(InitVar[int|str]),
1133                         'dataclasses.InitVar[int | str]')
1134
1135    def test_init_var_inheritance(self):
1136        # Note that this deliberately tests that a dataclass need not
1137        #  have a __post_init__ function if it has an InitVar field.
1138        #  It could just be used in a derived class, as shown here.
1139        @dataclass
1140        class Base:
1141            x: int
1142            init_base: InitVar[int]
1143
1144        # We can instantiate by passing the InitVar, even though
1145        #  it's not used.
1146        b = Base(0, 10)
1147        self.assertEqual(vars(b), {'x': 0})
1148
1149        @dataclass
1150        class C(Base):
1151            y: int
1152            init_derived: InitVar[int]
1153
1154            def __post_init__(self, init_base, init_derived):
1155                self.x = self.x + init_base
1156                self.y = self.y + init_derived
1157
1158        c = C(10, 11, 50, 51)
1159        self.assertEqual(vars(c), {'x': 21, 'y': 101})
1160
1161    def test_default_factory(self):
1162        # Test a factory that returns a new list.
1163        @dataclass
1164        class C:
1165            x: int
1166            y: list = field(default_factory=list)
1167
1168        c0 = C(3)
1169        c1 = C(3)
1170        self.assertEqual(c0.x, 3)
1171        self.assertEqual(c0.y, [])
1172        self.assertEqual(c0, c1)
1173        self.assertIsNot(c0.y, c1.y)
1174        self.assertEqual(astuple(C(5, [1])), (5, [1]))
1175
1176        # Test a factory that returns a shared list.
1177        l = []
1178        @dataclass
1179        class C:
1180            x: int
1181            y: list = field(default_factory=lambda: l)
1182
1183        c0 = C(3)
1184        c1 = C(3)
1185        self.assertEqual(c0.x, 3)
1186        self.assertEqual(c0.y, [])
1187        self.assertEqual(c0, c1)
1188        self.assertIs(c0.y, c1.y)
1189        self.assertEqual(astuple(C(5, [1])), (5, [1]))
1190
1191        # Test various other field flags.
1192        # repr
1193        @dataclass
1194        class C:
1195            x: list = field(default_factory=list, repr=False)
1196        self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
1197        self.assertEqual(C().x, [])
1198
1199        # hash
1200        @dataclass(unsafe_hash=True)
1201        class C:
1202            x: list = field(default_factory=list, hash=False)
1203        self.assertEqual(astuple(C()), ([],))
1204        self.assertEqual(hash(C()), hash(()))
1205
1206        # init (see also test_default_factory_with_no_init)
1207        @dataclass
1208        class C:
1209            x: list = field(default_factory=list, init=False)
1210        self.assertEqual(astuple(C()), ([],))
1211
1212        # compare
1213        @dataclass
1214        class C:
1215            x: list = field(default_factory=list, compare=False)
1216        self.assertEqual(C(), C([1]))
1217
1218    def test_default_factory_with_no_init(self):
1219        # We need a factory with a side effect.
1220        factory = Mock()
1221
1222        @dataclass
1223        class C:
1224            x: list = field(default_factory=factory, init=False)
1225
1226        # Make sure the default factory is called for each new instance.
1227        C().x
1228        self.assertEqual(factory.call_count, 1)
1229        C().x
1230        self.assertEqual(factory.call_count, 2)
1231
1232    def test_default_factory_not_called_if_value_given(self):
1233        # We need a factory that we can test if it's been called.
1234        factory = Mock()
1235
1236        @dataclass
1237        class C:
1238            x: int = field(default_factory=factory)
1239
1240        # Make sure that if a field has a default factory function,
1241        #  it's not called if a value is specified.
1242        C().x
1243        self.assertEqual(factory.call_count, 1)
1244        self.assertEqual(C(10).x, 10)
1245        self.assertEqual(factory.call_count, 1)
1246        C().x
1247        self.assertEqual(factory.call_count, 2)
1248
1249    def test_default_factory_derived(self):
1250        # See bpo-32896.
1251        @dataclass
1252        class Foo:
1253            x: dict = field(default_factory=dict)
1254
1255        @dataclass
1256        class Bar(Foo):
1257            y: int = 1
1258
1259        self.assertEqual(Foo().x, {})
1260        self.assertEqual(Bar().x, {})
1261        self.assertEqual(Bar().y, 1)
1262
1263        @dataclass
1264        class Baz(Foo):
1265            pass
1266        self.assertEqual(Baz().x, {})
1267
1268    def test_intermediate_non_dataclass(self):
1269        # Test that an intermediate class that defines
1270        #  annotations does not define fields.
1271
1272        @dataclass
1273        class A:
1274            x: int
1275
1276        class B(A):
1277            y: int
1278
1279        @dataclass
1280        class C(B):
1281            z: int
1282
1283        c = C(1, 3)
1284        self.assertEqual((c.x, c.z), (1, 3))
1285
1286        # .y was not initialized.
1287        with self.assertRaisesRegex(AttributeError,
1288                                    'object has no attribute'):
1289            c.y
1290
1291        # And if we again derive a non-dataclass, no fields are added.
1292        class D(C):
1293            t: int
1294        d = D(4, 5)
1295        self.assertEqual((d.x, d.z), (4, 5))
1296
1297    def test_classvar_default_factory(self):
1298        # It's an error for a ClassVar to have a factory function.
1299        with self.assertRaisesRegex(TypeError,
1300                                    'cannot have a default factory'):
1301            @dataclass
1302            class C:
1303                x: ClassVar[int] = field(default_factory=int)
1304
1305    def test_is_dataclass(self):
1306        class NotDataClass:
1307            pass
1308
1309        self.assertFalse(is_dataclass(0))
1310        self.assertFalse(is_dataclass(int))
1311        self.assertFalse(is_dataclass(NotDataClass))
1312        self.assertFalse(is_dataclass(NotDataClass()))
1313
1314        @dataclass
1315        class C:
1316            x: int
1317
1318        @dataclass
1319        class D:
1320            d: C
1321            e: int
1322
1323        c = C(10)
1324        d = D(c, 4)
1325
1326        self.assertTrue(is_dataclass(C))
1327        self.assertTrue(is_dataclass(c))
1328        self.assertFalse(is_dataclass(c.x))
1329        self.assertTrue(is_dataclass(d.d))
1330        self.assertFalse(is_dataclass(d.e))
1331
1332    def test_is_dataclass_when_getattr_always_returns(self):
1333        # See bpo-37868.
1334        class A:
1335            def __getattr__(self, key):
1336                return 0
1337        self.assertFalse(is_dataclass(A))
1338        a = A()
1339
1340        # Also test for an instance attribute.
1341        class B:
1342            pass
1343        b = B()
1344        b.__dataclass_fields__ = []
1345
1346        for obj in a, b:
1347            with self.subTest(obj=obj):
1348                self.assertFalse(is_dataclass(obj))
1349
1350                # Indirect tests for _is_dataclass_instance().
1351                with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1352                    asdict(obj)
1353                with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1354                    astuple(obj)
1355                with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1356                    replace(obj, x=0)
1357
1358    def test_is_dataclass_genericalias(self):
1359        @dataclass
1360        class A(types.GenericAlias):
1361            origin: type
1362            args: type
1363        self.assertTrue(is_dataclass(A))
1364        a = A(list, int)
1365        self.assertTrue(is_dataclass(type(a)))
1366        self.assertTrue(is_dataclass(a))
1367
1368
1369    def test_helper_fields_with_class_instance(self):
1370        # Check that we can call fields() on either a class or instance,
1371        #  and get back the same thing.
1372        @dataclass
1373        class C:
1374            x: int
1375            y: float
1376
1377        self.assertEqual(fields(C), fields(C(0, 0.0)))
1378
1379    def test_helper_fields_exception(self):
1380        # Check that TypeError is raised if not passed a dataclass or
1381        #  instance.
1382        with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1383            fields(0)
1384
1385        class C: pass
1386        with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1387            fields(C)
1388        with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1389            fields(C())
1390
1391    def test_helper_asdict(self):
1392        # Basic tests for asdict(), it should return a new dictionary.
1393        @dataclass
1394        class C:
1395            x: int
1396            y: int
1397        c = C(1, 2)
1398
1399        self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1400        self.assertEqual(asdict(c), asdict(c))
1401        self.assertIsNot(asdict(c), asdict(c))
1402        c.x = 42
1403        self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1404        self.assertIs(type(asdict(c)), dict)
1405
1406    def test_helper_asdict_raises_on_classes(self):
1407        # asdict() should raise on a class object.
1408        @dataclass
1409        class C:
1410            x: int
1411            y: int
1412        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1413            asdict(C)
1414        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1415            asdict(int)
1416
1417    def test_helper_asdict_copy_values(self):
1418        @dataclass
1419        class C:
1420            x: int
1421            y: List[int] = field(default_factory=list)
1422        initial = []
1423        c = C(1, initial)
1424        d = asdict(c)
1425        self.assertEqual(d['y'], initial)
1426        self.assertIsNot(d['y'], initial)
1427        c = C(1)
1428        d = asdict(c)
1429        d['y'].append(1)
1430        self.assertEqual(c.y, [])
1431
1432    def test_helper_asdict_nested(self):
1433        @dataclass
1434        class UserId:
1435            token: int
1436            group: int
1437        @dataclass
1438        class User:
1439            name: str
1440            id: UserId
1441        u = User('Joe', UserId(123, 1))
1442        d = asdict(u)
1443        self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1444        self.assertIsNot(asdict(u), asdict(u))
1445        u.id.group = 2
1446        self.assertEqual(asdict(u), {'name': 'Joe',
1447                                     'id': {'token': 123, 'group': 2}})
1448
1449    def test_helper_asdict_builtin_containers(self):
1450        @dataclass
1451        class User:
1452            name: str
1453            id: int
1454        @dataclass
1455        class GroupList:
1456            id: int
1457            users: List[User]
1458        @dataclass
1459        class GroupTuple:
1460            id: int
1461            users: Tuple[User, ...]
1462        @dataclass
1463        class GroupDict:
1464            id: int
1465            users: Dict[str, User]
1466        a = User('Alice', 1)
1467        b = User('Bob', 2)
1468        gl = GroupList(0, [a, b])
1469        gt = GroupTuple(0, (a, b))
1470        gd = GroupDict(0, {'first': a, 'second': b})
1471        self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1472                                                         {'name': 'Bob', 'id': 2}]})
1473        self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1474                                                         {'name': 'Bob', 'id': 2})})
1475        self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1476                                                         'second': {'name': 'Bob', 'id': 2}}})
1477
1478    def test_helper_asdict_builtin_object_containers(self):
1479        @dataclass
1480        class Child:
1481            d: object
1482
1483        @dataclass
1484        class Parent:
1485            child: Child
1486
1487        self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1488        self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1489
1490    def test_helper_asdict_factory(self):
1491        @dataclass
1492        class C:
1493            x: int
1494            y: int
1495        c = C(1, 2)
1496        d = asdict(c, dict_factory=OrderedDict)
1497        self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1498        self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1499        c.x = 42
1500        d = asdict(c, dict_factory=OrderedDict)
1501        self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1502        self.assertIs(type(d), OrderedDict)
1503
1504    def test_helper_asdict_namedtuple(self):
1505        T = namedtuple('T', 'a b c')
1506        @dataclass
1507        class C:
1508            x: str
1509            y: T
1510        c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1511
1512        d = asdict(c)
1513        self.assertEqual(d, {'x': 'outer',
1514                             'y': T(1,
1515                                    {'x': 'inner',
1516                                     'y': T(11, 12, 13)},
1517                                    2),
1518                             }
1519                         )
1520
1521        # Now with a dict_factory.  OrderedDict is convenient, but
1522        # since it compares to dicts, we also need to have separate
1523        # assertIs tests.
1524        d = asdict(c, dict_factory=OrderedDict)
1525        self.assertEqual(d, {'x': 'outer',
1526                             'y': T(1,
1527                                    {'x': 'inner',
1528                                     'y': T(11, 12, 13)},
1529                                    2),
1530                             }
1531                         )
1532
1533        # Make sure that the returned dicts are actually OrderedDicts.
1534        self.assertIs(type(d), OrderedDict)
1535        self.assertIs(type(d['y'][1]), OrderedDict)
1536
1537    def test_helper_asdict_namedtuple_key(self):
1538        # Ensure that a field that contains a dict which has a
1539        # namedtuple as a key works with asdict().
1540
1541        @dataclass
1542        class C:
1543            f: dict
1544        T = namedtuple('T', 'a')
1545
1546        c = C({T('an a'): 0})
1547
1548        self.assertEqual(asdict(c), {'f': {T(a='an a'): 0}})
1549
1550    def test_helper_asdict_namedtuple_derived(self):
1551        class T(namedtuple('Tbase', 'a')):
1552            def my_a(self):
1553                return self.a
1554
1555        @dataclass
1556        class C:
1557            f: T
1558
1559        t = T(6)
1560        c = C(t)
1561
1562        d = asdict(c)
1563        self.assertEqual(d, {'f': T(a=6)})
1564        # Make sure that t has been copied, not used directly.
1565        self.assertIsNot(d['f'], t)
1566        self.assertEqual(d['f'].my_a(), 6)
1567
1568    def test_helper_astuple(self):
1569        # Basic tests for astuple(), it should return a new tuple.
1570        @dataclass
1571        class C:
1572            x: int
1573            y: int = 0
1574        c = C(1)
1575
1576        self.assertEqual(astuple(c), (1, 0))
1577        self.assertEqual(astuple(c), astuple(c))
1578        self.assertIsNot(astuple(c), astuple(c))
1579        c.y = 42
1580        self.assertEqual(astuple(c), (1, 42))
1581        self.assertIs(type(astuple(c)), tuple)
1582
1583    def test_helper_astuple_raises_on_classes(self):
1584        # astuple() should raise on a class object.
1585        @dataclass
1586        class C:
1587            x: int
1588            y: int
1589        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1590            astuple(C)
1591        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1592            astuple(int)
1593
1594    def test_helper_astuple_copy_values(self):
1595        @dataclass
1596        class C:
1597            x: int
1598            y: List[int] = field(default_factory=list)
1599        initial = []
1600        c = C(1, initial)
1601        t = astuple(c)
1602        self.assertEqual(t[1], initial)
1603        self.assertIsNot(t[1], initial)
1604        c = C(1)
1605        t = astuple(c)
1606        t[1].append(1)
1607        self.assertEqual(c.y, [])
1608
1609    def test_helper_astuple_nested(self):
1610        @dataclass
1611        class UserId:
1612            token: int
1613            group: int
1614        @dataclass
1615        class User:
1616            name: str
1617            id: UserId
1618        u = User('Joe', UserId(123, 1))
1619        t = astuple(u)
1620        self.assertEqual(t, ('Joe', (123, 1)))
1621        self.assertIsNot(astuple(u), astuple(u))
1622        u.id.group = 2
1623        self.assertEqual(astuple(u), ('Joe', (123, 2)))
1624
1625    def test_helper_astuple_builtin_containers(self):
1626        @dataclass
1627        class User:
1628            name: str
1629            id: int
1630        @dataclass
1631        class GroupList:
1632            id: int
1633            users: List[User]
1634        @dataclass
1635        class GroupTuple:
1636            id: int
1637            users: Tuple[User, ...]
1638        @dataclass
1639        class GroupDict:
1640            id: int
1641            users: Dict[str, User]
1642        a = User('Alice', 1)
1643        b = User('Bob', 2)
1644        gl = GroupList(0, [a, b])
1645        gt = GroupTuple(0, (a, b))
1646        gd = GroupDict(0, {'first': a, 'second': b})
1647        self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1648        self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1649        self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1650
1651    def test_helper_astuple_builtin_object_containers(self):
1652        @dataclass
1653        class Child:
1654            d: object
1655
1656        @dataclass
1657        class Parent:
1658            child: Child
1659
1660        self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1661        self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1662
1663    def test_helper_astuple_factory(self):
1664        @dataclass
1665        class C:
1666            x: int
1667            y: int
1668        NT = namedtuple('NT', 'x y')
1669        def nt(lst):
1670            return NT(*lst)
1671        c = C(1, 2)
1672        t = astuple(c, tuple_factory=nt)
1673        self.assertEqual(t, NT(1, 2))
1674        self.assertIsNot(t, astuple(c, tuple_factory=nt))
1675        c.x = 42
1676        t = astuple(c, tuple_factory=nt)
1677        self.assertEqual(t, NT(42, 2))
1678        self.assertIs(type(t), NT)
1679
1680    def test_helper_astuple_namedtuple(self):
1681        T = namedtuple('T', 'a b c')
1682        @dataclass
1683        class C:
1684            x: str
1685            y: T
1686        c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1687
1688        t = astuple(c)
1689        self.assertEqual(t, ('outer', T(1, ('inner', (11, 12, 13)), 2)))
1690
1691        # Now, using a tuple_factory.  list is convenient here.
1692        t = astuple(c, tuple_factory=list)
1693        self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)])
1694
1695    def test_dynamic_class_creation(self):
1696        cls_dict = {'__annotations__': {'x': int, 'y': int},
1697                    }
1698
1699        # Create the class.
1700        cls = type('C', (), cls_dict)
1701
1702        # Make it a dataclass.
1703        cls1 = dataclass(cls)
1704
1705        self.assertEqual(cls1, cls)
1706        self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1707
1708    def test_dynamic_class_creation_using_field(self):
1709        cls_dict = {'__annotations__': {'x': int, 'y': int},
1710                    'y': field(default=5),
1711                    }
1712
1713        # Create the class.
1714        cls = type('C', (), cls_dict)
1715
1716        # Make it a dataclass.
1717        cls1 = dataclass(cls)
1718
1719        self.assertEqual(cls1, cls)
1720        self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1721
1722    def test_init_in_order(self):
1723        @dataclass
1724        class C:
1725            a: int
1726            b: int = field()
1727            c: list = field(default_factory=list, init=False)
1728            d: list = field(default_factory=list)
1729            e: int = field(default=4, init=False)
1730            f: int = 4
1731
1732        calls = []
1733        def setattr(self, name, value):
1734            calls.append((name, value))
1735
1736        C.__setattr__ = setattr
1737        c = C(0, 1)
1738        self.assertEqual(('a', 0), calls[0])
1739        self.assertEqual(('b', 1), calls[1])
1740        self.assertEqual(('c', []), calls[2])
1741        self.assertEqual(('d', []), calls[3])
1742        self.assertNotIn(('e', 4), calls)
1743        self.assertEqual(('f', 4), calls[4])
1744
1745    def test_items_in_dicts(self):
1746        @dataclass
1747        class C:
1748            a: int
1749            b: list = field(default_factory=list, init=False)
1750            c: list = field(default_factory=list)
1751            d: int = field(default=4, init=False)
1752            e: int = 0
1753
1754        c = C(0)
1755        # Class dict
1756        self.assertNotIn('a', C.__dict__)
1757        self.assertNotIn('b', C.__dict__)
1758        self.assertNotIn('c', C.__dict__)
1759        self.assertIn('d', C.__dict__)
1760        self.assertEqual(C.d, 4)
1761        self.assertIn('e', C.__dict__)
1762        self.assertEqual(C.e, 0)
1763        # Instance dict
1764        self.assertIn('a', c.__dict__)
1765        self.assertEqual(c.a, 0)
1766        self.assertIn('b', c.__dict__)
1767        self.assertEqual(c.b, [])
1768        self.assertIn('c', c.__dict__)
1769        self.assertEqual(c.c, [])
1770        self.assertNotIn('d', c.__dict__)
1771        self.assertIn('e', c.__dict__)
1772        self.assertEqual(c.e, 0)
1773
1774    def test_alternate_classmethod_constructor(self):
1775        # Since __post_init__ can't take params, use a classmethod
1776        #  alternate constructor.  This is mostly an example to show
1777        #  how to use this technique.
1778        @dataclass
1779        class C:
1780            x: int
1781            @classmethod
1782            def from_file(cls, filename):
1783                # In a real example, create a new instance
1784                #  and populate 'x' from contents of a file.
1785                value_in_file = 20
1786                return cls(value_in_file)
1787
1788        self.assertEqual(C.from_file('filename').x, 20)
1789
1790    def test_field_metadata_default(self):
1791        # Make sure the default metadata is read-only and of
1792        #  zero length.
1793        @dataclass
1794        class C:
1795            i: int
1796
1797        self.assertFalse(fields(C)[0].metadata)
1798        self.assertEqual(len(fields(C)[0].metadata), 0)
1799        with self.assertRaisesRegex(TypeError,
1800                                    'does not support item assignment'):
1801            fields(C)[0].metadata['test'] = 3
1802
1803    def test_field_metadata_mapping(self):
1804        # Make sure only a mapping can be passed as metadata
1805        #  zero length.
1806        with self.assertRaises(TypeError):
1807            @dataclass
1808            class C:
1809                i: int = field(metadata=0)
1810
1811        # Make sure an empty dict works.
1812        d = {}
1813        @dataclass
1814        class C:
1815            i: int = field(metadata=d)
1816        self.assertFalse(fields(C)[0].metadata)
1817        self.assertEqual(len(fields(C)[0].metadata), 0)
1818        # Update should work (see bpo-35960).
1819        d['foo'] = 1
1820        self.assertEqual(len(fields(C)[0].metadata), 1)
1821        self.assertEqual(fields(C)[0].metadata['foo'], 1)
1822        with self.assertRaisesRegex(TypeError,
1823                                    'does not support item assignment'):
1824            fields(C)[0].metadata['test'] = 3
1825
1826        # Make sure a non-empty dict works.
1827        d = {'test': 10, 'bar': '42', 3: 'three'}
1828        @dataclass
1829        class C:
1830            i: int = field(metadata=d)
1831        self.assertEqual(len(fields(C)[0].metadata), 3)
1832        self.assertEqual(fields(C)[0].metadata['test'], 10)
1833        self.assertEqual(fields(C)[0].metadata['bar'], '42')
1834        self.assertEqual(fields(C)[0].metadata[3], 'three')
1835        # Update should work.
1836        d['foo'] = 1
1837        self.assertEqual(len(fields(C)[0].metadata), 4)
1838        self.assertEqual(fields(C)[0].metadata['foo'], 1)
1839        with self.assertRaises(KeyError):
1840            # Non-existent key.
1841            fields(C)[0].metadata['baz']
1842        with self.assertRaisesRegex(TypeError,
1843                                    'does not support item assignment'):
1844            fields(C)[0].metadata['test'] = 3
1845
1846    def test_field_metadata_custom_mapping(self):
1847        # Try a custom mapping.
1848        class SimpleNameSpace:
1849            def __init__(self, **kw):
1850                self.__dict__.update(kw)
1851
1852            def __getitem__(self, item):
1853                if item == 'xyzzy':
1854                    return 'plugh'
1855                return getattr(self, item)
1856
1857            def __len__(self):
1858                return self.__dict__.__len__()
1859
1860        @dataclass
1861        class C:
1862            i: int = field(metadata=SimpleNameSpace(a=10))
1863
1864        self.assertEqual(len(fields(C)[0].metadata), 1)
1865        self.assertEqual(fields(C)[0].metadata['a'], 10)
1866        with self.assertRaises(AttributeError):
1867            fields(C)[0].metadata['b']
1868        # Make sure we're still talking to our custom mapping.
1869        self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1870
1871    def test_generic_dataclasses(self):
1872        T = TypeVar('T')
1873
1874        @dataclass
1875        class LabeledBox(Generic[T]):
1876            content: T
1877            label: str = '<unknown>'
1878
1879        box = LabeledBox(42)
1880        self.assertEqual(box.content, 42)
1881        self.assertEqual(box.label, '<unknown>')
1882
1883        # Subscripting the resulting class should work, etc.
1884        Alias = List[LabeledBox[int]]
1885
1886    def test_generic_extending(self):
1887        S = TypeVar('S')
1888        T = TypeVar('T')
1889
1890        @dataclass
1891        class Base(Generic[T, S]):
1892            x: T
1893            y: S
1894
1895        @dataclass
1896        class DataDerived(Base[int, T]):
1897            new_field: str
1898        Alias = DataDerived[str]
1899        c = Alias(0, 'test1', 'test2')
1900        self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1901
1902        class NonDataDerived(Base[int, T]):
1903            def new_method(self):
1904                return self.y
1905        Alias = NonDataDerived[float]
1906        c = Alias(10, 1.0)
1907        self.assertEqual(c.new_method(), 1.0)
1908
1909    def test_generic_dynamic(self):
1910        T = TypeVar('T')
1911
1912        @dataclass
1913        class Parent(Generic[T]):
1914            x: T
1915        Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)],
1916                               bases=(Parent[int], Generic[T]), namespace={'other': 42})
1917        self.assertIs(Child[int](1, 2).z, None)
1918        self.assertEqual(Child[int](1, 2, 3).z, 3)
1919        self.assertEqual(Child[int](1, 2, 3).other, 42)
1920        # Check that type aliases work correctly.
1921        Alias = Child[T]
1922        self.assertEqual(Alias[int](1, 2).x, 1)
1923        # Check MRO resolution.
1924        self.assertEqual(Child.__mro__, (Child, Parent, Generic, object))
1925
1926    def test_dataclasses_pickleable(self):
1927        global P, Q, R
1928        @dataclass
1929        class P:
1930            x: int
1931            y: int = 0
1932        @dataclass
1933        class Q:
1934            x: int
1935            y: int = field(default=0, init=False)
1936        @dataclass
1937        class R:
1938            x: int
1939            y: List[int] = field(default_factory=list)
1940        q = Q(1)
1941        q.y = 2
1942        samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1943        for sample in samples:
1944            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1945                with self.subTest(sample=sample, proto=proto):
1946                    new_sample = pickle.loads(pickle.dumps(sample, proto))
1947                    self.assertEqual(sample.x, new_sample.x)
1948                    self.assertEqual(sample.y, new_sample.y)
1949                    self.assertIsNot(sample, new_sample)
1950                    new_sample.x = 42
1951                    another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1952                    self.assertEqual(new_sample.x, another_new_sample.x)
1953                    self.assertEqual(sample.y, another_new_sample.y)
1954
1955    def test_dataclasses_qualnames(self):
1956        @dataclass(order=True, unsafe_hash=True, frozen=True)
1957        class A:
1958            x: int
1959            y: int
1960
1961        self.assertEqual(A.__init__.__name__, "__init__")
1962        for function in (
1963            '__eq__',
1964            '__lt__',
1965            '__le__',
1966            '__gt__',
1967            '__ge__',
1968            '__hash__',
1969            '__init__',
1970            '__repr__',
1971            '__setattr__',
1972            '__delattr__',
1973        ):
1974            self.assertEqual(getattr(A, function).__qualname__, f"TestCase.test_dataclasses_qualnames.<locals>.A.{function}")
1975
1976        with self.assertRaisesRegex(TypeError, r"A\.__init__\(\) missing"):
1977            A()
1978
1979
class TestFieldNoAnnotation(unittest.TestCase):
    # Assigning field() to a name without its own annotation is an error,
    #  even when a base class annotates the same name.

    def _raises_no_annotation(self):
        # Shared expectation: the error names the offending field 'f'.
        return self.assertRaisesRegex(
            TypeError, "'f' is a field but has no type annotation")

    def test_field_without_annotation(self):
        with self._raises_no_annotation():
            @dataclass
            class C:
                f = field()

    def test_field_without_annotation_but_annotation_in_base(self):
        @dataclass
        class B:
            f: int

        # Still an error: the annotation in the base class must not be
        #  picked up for C.
        with self._raises_no_annotation():
            @dataclass
            class C(B):
                f = field()

    def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
        # Same as above, but the base class is not a dataclass.
        class B:
            f: int

        with self._raises_no_annotation():
            @dataclass
            class C(B):
                f = field()
2013
2014
class TestDocString(unittest.TestCase):
    # With no explicit docstring, @dataclass generates one of the form
    #  "C(field:type=default, ...)"; these tests pin that format.

    def assertDocStrEqual(self, a, b):
        # inspect.signature formatting differs between 3.6 and 3.7
        #  (bpo-32108), so for now compare with all spaces removed.
        squeezed_a, squeezed_b = (text.replace(' ', '') for text in (a, b))
        self.assertEqual(squeezed_a, squeezed_b)

    def test_existing_docstring_not_overridden(self):
        @dataclass
        class C:
            """Lorem ipsum"""
            x: int

        # A docstring written by the user is never replaced.
        doc = C.__doc__
        self.assertEqual(doc, "Lorem ipsum")

    def test_docstring_no_fields(self):
        @dataclass
        class C:
            pass

        doc = C.__doc__
        self.assertDocStrEqual(doc, "C()")

    def test_docstring_one_field(self):
        @dataclass
        class C:
            x: int

        doc = C.__doc__
        self.assertDocStrEqual(doc, "C(x:int)")

    def test_docstring_two_fields(self):
        @dataclass
        class C:
            x: int
            y: int

        doc = C.__doc__
        self.assertDocStrEqual(doc, "C(x:int, y:int)")

    def test_docstring_three_fields(self):
        @dataclass
        class C:
            x: int
            y: int
            z: str

        doc = C.__doc__
        self.assertDocStrEqual(doc, "C(x:int, y:int, z:str)")

    def test_docstring_one_field_with_default(self):
        @dataclass
        class C:
            x: int = 3

        doc = C.__doc__
        self.assertDocStrEqual(doc, "C(x:int=3)")

    def test_docstring_one_field_with_default_none(self):
        @dataclass
        class C:
            x: Union[int, type(None)] = None

        doc = C.__doc__
        self.assertDocStrEqual(doc, "C(x:Optional[int]=None)")

    def test_docstring_list_field(self):
        @dataclass
        class C:
            x: List[int]

        doc = C.__doc__
        self.assertDocStrEqual(doc, "C(x:List[int])")

    def test_docstring_list_field_with_default_factory(self):
        @dataclass
        class C:
            x: List[int] = field(default_factory=list)

        # Factory defaults are rendered as "<factory>".
        doc = C.__doc__
        self.assertDocStrEqual(doc, "C(x:List[int]=<factory>)")

    def test_docstring_deque_field(self):
        @dataclass
        class C:
            x: deque

        doc = C.__doc__
        self.assertDocStrEqual(doc, "C(x:collections.deque)")

    def test_docstring_deque_field_with_default_factory(self):
        @dataclass
        class C:
            x: deque = field(default_factory=deque)

        doc = C.__doc__
        self.assertDocStrEqual(doc, "C(x:collections.deque=<factory>)")
2102
2103
2104class TestInit(unittest.TestCase):
2105    def test_base_has_init(self):
2106        class B:
2107            def __init__(self):
2108                self.z = 100
2109                pass
2110
2111        # Make sure that declaring this class doesn't raise an error.
2112        #  The issue is that we can't override __init__ in our class,
2113        #  but it should be okay to add __init__ to us if our base has
2114        #  an __init__.
2115        @dataclass
2116        class C(B):
2117            x: int = 0
2118        c = C(10)
2119        self.assertEqual(c.x, 10)
2120        self.assertNotIn('z', vars(c))
2121
2122        # Make sure that if we don't add an init, the base __init__
2123        #  gets called.
2124        @dataclass(init=False)
2125        class C(B):
2126            x: int = 10
2127        c = C()
2128        self.assertEqual(c.x, 10)
2129        self.assertEqual(c.z, 100)
2130
2131    def test_no_init(self):
2132        dataclass(init=False)
2133        class C:
2134            i: int = 0
2135        self.assertEqual(C().i, 0)
2136
2137        dataclass(init=False)
2138        class C:
2139            i: int = 2
2140            def __init__(self):
2141                self.i = 3
2142        self.assertEqual(C().i, 3)
2143
2144    def test_overwriting_init(self):
2145        # If the class has __init__, use it no matter the value of
2146        #  init=.
2147
2148        @dataclass
2149        class C:
2150            x: int
2151            def __init__(self, x):
2152                self.x = 2 * x
2153        self.assertEqual(C(3).x, 6)
2154
2155        @dataclass(init=True)
2156        class C:
2157            x: int
2158            def __init__(self, x):
2159                self.x = 2 * x
2160        self.assertEqual(C(4).x, 8)
2161
2162        @dataclass(init=False)
2163        class C:
2164            x: int
2165            def __init__(self, x):
2166                self.x = 2 * x
2167        self.assertEqual(C(5).x, 10)
2168
2169    def test_inherit_from_protocol(self):
2170        # Dataclasses inheriting from protocol should preserve their own `__init__`.
2171        # See bpo-45081.
2172
2173        class P(Protocol):
2174            a: int
2175
2176        @dataclass
2177        class C(P):
2178            a: int
2179
2180        self.assertEqual(C(5).a, 5)
2181
2182        @dataclass
2183        class D(P):
2184            def __init__(self, a):
2185                self.a = a * 2
2186
2187        self.assertEqual(D(5).a, 10)
2188
2189
class TestRepr(unittest.TestCase):
    # The generated __repr__ uses the class's qualified name, and a
    #  user-defined __repr__ always takes precedence.

    def test_repr(self):
        prefix = 'TestRepr.test_repr.<locals>.'

        @dataclass
        class B:
            x: int

        @dataclass
        class C(B):
            y: int = 10

        # Inherited fields are listed before the subclass's own fields.
        self.assertEqual(repr(C(4)), prefix + 'C(x=4, y=10)')

        # Redeclaring a base field keeps its original position.
        @dataclass
        class D(C):
            x: int = 20
        self.assertEqual(repr(D()), prefix + 'D(x=20, y=10)')

        # Nested dataclasses include the enclosing class in the name.
        @dataclass
        class C:
            @dataclass
            class D:
                i: int
            @dataclass
            class E:
                pass
        self.assertEqual(repr(C.D(0)), prefix + 'C.D(i=0)')
        self.assertEqual(repr(C.E()), prefix + 'C.E()')

    def test_no_repr(self):
        # repr=False without a user __repr__: object's repr is used.
        @dataclass(repr=False)
        class C:
            x: int
        default_repr = repr(C(3))
        self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at',
                      default_repr)

        # repr=False with a user __repr__: the user's method is used.
        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'C-class'
        self.assertEqual(repr(C(3)), 'C-class')

    def test_overwriting_repr(self):
        # A __repr__ defined in the class wins regardless of repr=.
        for decorator in (dataclass, dataclass(repr=True),
                          dataclass(repr=False)):
            @decorator
            class C:
                x: int
                def __repr__(self):
                    return 'x'
            self.assertEqual(repr(C(0)), 'x')
2259
2260
class TestEq(unittest.TestCase):
    # eq= controls the generated __eq__; a user-defined __eq__ is never
    #  replaced.

    def test_no_eq(self):
        # eq=False without a user __eq__: object's identity comparison.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        same = C(3)
        self.assertEqual(same, same)

        # eq=False with a user __eq__: the user's method is used.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # A __eq__ defined in the class wins regardless of eq=.  Each
        #  variant compares equal only to its own target value.
        variants = ((dataclass, 3),
                    (dataclass(eq=True), 4),
                    (dataclass(eq=False), 5))
        for decorator, target in variants:
            @decorator
            class C:
                x: int
                def __eq__(self, other, _target=target):
                    return other == _target
            self.assertEqual(C(1), target)
            self.assertNotEqual(C(1), 1)
2306
2307
class TestOrdering(unittest.TestCase):
    # order= controls the generated rich comparisons.  With order=True, a
    #  user-defined comparison method is a TypeError.

    def test_functools_total_ordering(self):
        # functools.total_ordering must be able to derive the remaining
        #  comparisons from a user __lt__ on a dataclass.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Deliberately inverted, so the assertions below only pass
                #  if this method is the one actually being called.
                return self.x >= other

        subject = C(0)
        self.assertLess(subject, -1)
        self.assertLessEqual(subject, -1)
        self.assertGreater(subject, 1)
        self.assertGreaterEqual(subject, 1)

    def test_no_order(self):
        # order=False adds no ordering methods at all.
        @dataclass(order=False)
        class C:
            x: int
        for name in ('__le__', '__lt__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

        # A user __lt__ is kept and no other ordering methods are added.
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        for name in ('__le__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

    def test_overwriting_order(self):
        def user_method(self):
            pass

        # Each comparison method defined directly on the class conflicts
        #  with order=True.
        for name in ('__lt__', '__le__', '__gt__', '__ge__'):
            class C:
                x: int
            setattr(C, name, user_method)
            with self.assertRaisesRegex(TypeError,
                                        f'Cannot overwrite attribute {name}'
                                        '.*using functools.total_ordering'):
                dataclass(order=True)(C)
2383
class TestHash(unittest.TestCase):
    # How unsafe_hash=, eq= and frozen= combine to decide __hash__.

    def test_unsafe_hash(self):
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        # The generated hash is the hash of the tuple of field values.
        instance = C(1, 'foo')
        self.assertEqual(hash(instance), hash((1, 'foo')))

    def test_hash_rules(self):
        def non_bool(value):
            # Map a bool to a value with the same truthiness that is not
            #  a bool; None passes through unchanged.
            if value is None:
                return None
            return (3,) if value else 0

        def make_class(unsafe_hash, eq, frozen, with_hash):
            # Build the dataclass under test, optionally with its own
            #  __hash__ defined in the class body.
            if with_hash:
                @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                class C:
                    def __hash__(self):
                        return 0
            else:
                @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                class C:
                    pass
            return C

        def test(case, unsafe_hash, eq, frozen, with_hash, result):
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen):
                if result == 'exception':
                    # Creating the class must fail.  That only happens
                    #  when the class defines its own __hash__.
                    assert with_hash
                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
                        make_class(unsafe_hash, eq, frozen, with_hash)
                    return

                C = make_class(unsafe_hash, eq, frozen, with_hash)
                if result == 'fn':
                    # __hash__ holds the function dataclass generated.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])
                elif result == '':
                    # __hash__ is not present in our class.
                    if not with_hash:
                        self.assertNotIn('__hash__', C.__dict__)
                elif result == 'none':
                    # __hash__ is explicitly set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])
                else:
                    assert False, f'unknown result {result!r}'

        # All 8 combinations of unsafe_hash, eq and frozen, each with the
        #  expected result without and with a user-defined __hash__.
        table = [
            # unsafe_hash, eq,  frozen, no __hash__, own __hash__
            (False,        False, False,  '',        ''),
            (False,        False, True,   '',        ''),
            (False,        True,  False,  'none',    ''),
            (False,        True,  True,   'fn',      ''),
            (True,         False, False,  'fn',      'exception'),
            (True,         False, True,   'fn',      'exception'),
            (True,         True,  False,  'fn',      'exception'),
            (True,         True,  True,   'fn',      'exception'),
        ]
        for case, row in enumerate(table, 1):
            unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash = row
            test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
            test(case, unsafe_hash, eq, frozen, True,  res_defined_hash)

            # Non-bool truth values must behave identically, which checks
            #  that the decorator's decision table doesn't rely on bools.
            flags = (non_bool(unsafe_hash), non_bool(eq), non_bool(frozen))
            test(case, *flags, False, res_no_defined_hash)
            test(case, *flags, True,  res_defined_hash)

    def test_eq_only(self):
        # Defining __eq__ in a class body makes Python set __hash__ to
        #  None automatically.  That is ordinary Python behavior, not a
        #  dataclass feature; make sure dataclass doesn't interfere with
        #  it (see bpo-32546).

        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        one, other_one, four = C(1), C(1), C(4)
        self.assertEqual(one, other_one)
        self.assertNotEqual(one, four)

        # With unsafe_hash=True the user __eq__ is kept and a field-based
        #  __hash__ is added.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # The class's own __eq__ is still used when eq=True is explicit.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        # Either way of requesting a hash gives hash(()) for no fields.
        for decorator in (dataclass(frozen=True), dataclass(unsafe_hash=True)):
            @decorator
            class C:
                pass
            self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        # Either way of requesting a hash gives the one-tuple's hash.
        for decorator in (dataclass(frozen=True), dataclass(unsafe_hash=True)):
            @decorator
            class C:
                x: int
            for value in (4, 42):
                self.assertEqual(hash(C(value)), hash((value,)))

    def test_hash_no_args(self):
        # Test dataclasses with no hash= argument.  This guards the
        #  default hashability in case the @dataclass parameter name or
        #  the non-default hashing behavior changes.

        class Base:
            def __hash__(self):
                return 301

        # None for frozen or eq means "don't pass it to the decorator".
        for frozen, eq,    base,   expected       in [
            (None,  None,  object, 'unhashable'),
            (None,  None,  Base,   'unhashable'),
            (None,  False, object, 'object'),
            (None,  False, Base,   'base'),
            (None,  True,  object, 'unhashable'),
            (None,  True,  Base,   'unhashable'),
            (False, None,  object, 'unhashable'),
            (False, None,  Base,   'unhashable'),
            (False, False, object, 'object'),
            (False, False, Base,   'base'),
            (False, True,  object, 'unhashable'),
            (False, True,  Base,   'unhashable'),
            (True,  None,  object, 'tuple'),
            (True,  None,  Base,   'tuple'),
            (True,  False, object, 'object'),
            (True,  False, Base,   'base'),
            (True,  True,  object, 'tuple'),
            (True,  True,  Base,   'tuple'),
            ]:

            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                # Pass only the options this case specifies.
                options = {name: value
                           for name, value in (('frozen', frozen), ('eq', eq))
                           if value is not None}

                @dataclass(**options)
                class C(base):
                    i: int

                # Now check that it hashes as expected.
                if expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))
                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)
                elif expected == 'object':
                    # object's default hash is identity-based, so a hash
                    #  value proves little here; instead check that C
                    #  simply inherits object.__hash__.
                    self.assertIs(C.__hash__, object.__hash__)
                elif expected == 'unhashable':
                    c = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(c)
                else:
                    assert False, f'unknown value for expected={expected!r}'
2602
2603
class TestFrozen(unittest.TestCase):
    """Tests for frozen=True dataclasses and frozen/non-frozen inheritance."""

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            i: int

        c = C(10)
        self.assertEqual(c.i, 10)
        with self.assertRaises(FrozenInstanceError):
            c.i = 5
        # The failed assignment must not have modified the field.
        self.assertEqual(c.i, 10)

    def test_inherit(self):
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        # Both the inherited and the new field are frozen.
        d = D(0, 10)
        with self.assertRaises(FrozenInstanceError):
            d.i = 5
        with self.assertRaises(FrozenInstanceError):
            d.j = 6
        self.assertEqual(d.i, 0)
        self.assertEqual(d.j, 10)

    def test_inherit_nonfrozen_from_empty_frozen(self):
        @dataclass(frozen=True)
        class C:
            pass

        # Even a frozen base with no fields forbids a non-frozen subclass.
        with self.assertRaisesRegex(TypeError,
                                    'cannot inherit non-frozen dataclass from a frozen one'):
            @dataclass
            class D(C):
                j: int

    def test_inherit_nonfrozen_from_empty(self):
        @dataclass
        class C:
            pass

        @dataclass
        class D(C):
            j: int

        d = D(3)
        self.assertEqual(d.j, 3)
        self.assertIsInstance(d, C)

    # Test both ways: with an intermediate normal (non-dataclass)
    #  class and without an intermediate class.
    def test_inherit_nonfrozen_from_frozen(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(I):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass
                class C:
                    i: int

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(I):
                        pass

    def test_inherit_from_normal_class(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                class C:
                    pass

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                @dataclass(frozen=True)
                class D(I):
                    i: int

                # Kept inside the subTest so a failure is reported
                #  against the intermediate_class value that caused it.
                d = D(10)
                with self.assertRaises(FrozenInstanceError):
                    d.i = 5

    def test_non_frozen_normal_derived(self):
        # See bpo-32953.

        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        class S(D):
            pass

        # A plain (non-dataclass) subclass may add new attributes...
        s = S(3)
        self.assertEqual(s.x, 3)
        self.assertEqual(s.y, 10)
        s.cached = True

        # But can't change the frozen attributes.
        with self.assertRaises(FrozenInstanceError):
            s.x = 5
        with self.assertRaises(FrozenInstanceError):
            s.y = 5
        self.assertEqual(s.x, 3)
        self.assertEqual(s.y, 10)
        self.assertEqual(s.cached, True)

    def test_overwriting_frozen(self):
        # frozen uses __setattr__ and __delattr__.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __delattr__(self):
                    pass

        # Without frozen=True a user-defined __setattr__ is kept and is
        #  used by the generated __init__.
        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2
        self.assertEqual(C(10).x, 20)

    def test_frozen_hash(self):
        @dataclass(frozen=True)
        class C:
            x: Any

        # If x is immutable, we can compute the hash.  No exception is
        # raised.
        hash(C(3))

        # If x is mutable, computing the hash is an error.
        with self.assertRaisesRegex(TypeError, 'unhashable type'):
            hash(C({}))
2776
2777
2778class TestSlots(unittest.TestCase):
2779    def test_simple(self):
2780        @dataclass
2781        class C:
2782            __slots__ = ('x',)
2783            x: Any
2784
2785        # There was a bug where a variable in a slot was assumed to
2786        #  also have a default value (of type
2787        #  types.MemberDescriptorType).
2788        with self.assertRaisesRegex(TypeError,
2789                                    r"__init__\(\) missing 1 required positional argument: 'x'"):
2790            C()
2791
2792        # We can create an instance, and assign to x.
2793        c = C(10)
2794        self.assertEqual(c.x, 10)
2795        c.x = 5
2796        self.assertEqual(c.x, 5)
2797
2798        # We can't assign to anything else.
2799        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
2800            c.y = 5
2801
2802    def test_derived_added_field(self):
2803        # See bpo-33100.
2804        @dataclass
2805        class Base:
2806            __slots__ = ('x',)
2807            x: Any
2808
2809        @dataclass
2810        class Derived(Base):
2811            x: int
2812            y: int
2813
2814        d = Derived(1, 2)
2815        self.assertEqual((d.x, d.y), (1, 2))
2816
2817        # We can add a new field to the derived instance.
2818        d.z = 10
2819
2820    def test_generated_slots(self):
2821        @dataclass(slots=True)
2822        class C:
2823            x: int
2824            y: int
2825
2826        c = C(1, 2)
2827        self.assertEqual((c.x, c.y), (1, 2))
2828
2829        c.x = 3
2830        c.y = 4
2831        self.assertEqual((c.x, c.y), (3, 4))
2832
2833        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'z'"):
2834            c.z = 5
2835
2836    def test_add_slots_when_slots_exists(self):
2837        with self.assertRaisesRegex(TypeError, '^C already specifies __slots__$'):
2838            @dataclass(slots=True)
2839            class C:
2840                __slots__ = ('x',)
2841                x: int
2842
2843    def test_generated_slots_value(self):
2844        @dataclass(slots=True)
2845        class Base:
2846            x: int
2847
2848        self.assertEqual(Base.__slots__, ('x',))
2849
2850        @dataclass(slots=True)
2851        class Delivered(Base):
2852            y: int
2853
2854        self.assertEqual(Delivered.__slots__, ('x', 'y'))
2855
2856        @dataclass
2857        class AnotherDelivered(Base):
2858            z: int
2859
2860        self.assertTrue('__slots__' not in AnotherDelivered.__dict__)
2861
2862    def test_returns_new_class(self):
2863        class A:
2864            x: int
2865
2866        B = dataclass(A, slots=True)
2867        self.assertIsNot(A, B)
2868
2869        self.assertFalse(hasattr(A, "__slots__"))
2870        self.assertTrue(hasattr(B, "__slots__"))
2871
2872    # Can't be local to test_frozen_pickle.
2873    @dataclass(frozen=True, slots=True)
2874    class FrozenSlotsClass:
2875        foo: str
2876        bar: int
2877
2878    @dataclass(frozen=True)
2879    class FrozenWithoutSlotsClass:
2880        foo: str
2881        bar: int
2882
2883    def test_frozen_pickle(self):
2884        # bpo-43999
2885
2886        self.assertEqual(self.FrozenSlotsClass.__slots__, ("foo", "bar"))
2887        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
2888            with self.subTest(proto=proto):
2889                obj = self.FrozenSlotsClass("a", 1)
2890                p = pickle.loads(pickle.dumps(obj, protocol=proto))
2891                self.assertIsNot(obj, p)
2892                self.assertEqual(obj, p)
2893
2894                obj = self.FrozenWithoutSlotsClass("a", 1)
2895                p = pickle.loads(pickle.dumps(obj, protocol=proto))
2896                self.assertIsNot(obj, p)
2897                self.assertEqual(obj, p)
2898
2899    def test_slots_with_default_no_init(self):
2900        # Originally reported in bpo-44649.
2901        @dataclass(slots=True)
2902        class A:
2903            a: str
2904            b: str = field(default='b', init=False)
2905
2906        obj = A("a")
2907        self.assertEqual(obj.a, 'a')
2908        self.assertEqual(obj.b, 'b')
2909
2910    def test_slots_with_default_factory_no_init(self):
2911        # Originally reported in bpo-44649.
2912        @dataclass(slots=True)
2913        class A:
2914            a: str
2915            b: str = field(default_factory=lambda:'b', init=False)
2916
2917        obj = A("a")
2918        self.assertEqual(obj.a, 'a')
2919        self.assertEqual(obj.b, 'b')
2920
class TestDescriptors(unittest.TestCase):
    """__set_name__ handling for objects used as dataclass field defaults."""

    def test_set_name(self):
        # See bpo-33141.

        # A real descriptor that records the name it was bound to.
        class NamedDescriptor:
            def __set_name__(self, owner, name):
                self.name = name + 'x'

            def __get__(self, instance, owner):
                # Class-level access yields the descriptor itself;
                #  instance-level access yields 1.
                return self if instance is None else 1

        # Plain default: ordinary class creation calls __set_name__,
        #  no dataclass code is involved in initializing the descriptor.
        @dataclass
        class C:
            c: int = NamedDescriptor()
        self.assertEqual(C.c.name, 'cx')

        # Default wrapped in field(..., init=False) -- the only case
        #  where this really matters, since with init=True the
        #  descriptor would be overwritten by __init__ anyway.
        @dataclass
        class C:
            c: int = field(default=NamedDescriptor(), init=False)
        self.assertEqual(C.c.name, 'cx')
        self.assertEqual(C().c, 1)

    def test_non_descriptor(self):
        # PEP 487 says __set_name__ should work on non-descriptors, so
        #  this class deliberately has no __get__/__set__.
        class NameRecorder:
            def __set_name__(self, owner, name):
                self.name = name + 'x'

        @dataclass
        class C:
            c: int = field(default=NameRecorder(), init=False)
        self.assertEqual(C.c.name, 'cx')

    def test_lookup_on_instance(self):
        # See bpo-33175.  A __set_name__ stored only on the instance
        #  (not on its type) must not be called.
        class Plain:
            pass

        default_value = Plain()
        default_value.__set_name__ = Mock()

        @dataclass
        class C:
            i: int = field(default=default_value, init=False)

        self.assertEqual(default_value.__set_name__.call_count, 0)

    def test_lookup_on_class(self):
        # See bpo-33175.  A __set_name__ defined on the default's type
        #  must be called exactly once.
        class Plain:
            pass
        Plain.__set_name__ = Mock()

        @dataclass
        class C:
            i: int = field(default=Plain(), init=False)

        self.assertEqual(Plain.__set_name__.call_count, 1)
2991
2992
2993class TestStringAnnotations(unittest.TestCase):
2994    def test_classvar(self):
2995        # Some expressions recognized as ClassVar really aren't.  But
2996        #  if you're using string annotations, it's not an exact
2997        #  science.
2998        # These tests assume that both "import typing" and "from
2999        # typing import *" have been run in this file.
3000        for typestr in ('ClassVar[int]',
3001                        'ClassVar [int]',
3002                        ' ClassVar [int]',
3003                        'ClassVar',
3004                        ' ClassVar ',
3005                        'typing.ClassVar[int]',
3006                        'typing.ClassVar[str]',
3007                        ' typing.ClassVar[str]',
3008                        'typing .ClassVar[str]',
3009                        'typing. ClassVar[str]',
3010                        'typing.ClassVar [str]',
3011                        'typing.ClassVar [ str]',
3012
3013                        # Not syntactically valid, but these will
3014                        #  be treated as ClassVars.
3015                        'typing.ClassVar.[int]',
3016                        'typing.ClassVar+',
3017                        ):
3018            with self.subTest(typestr=typestr):
3019                @dataclass
3020                class C:
3021                    x: typestr
3022
3023                # x is a ClassVar, so C() takes no args.
3024                C()
3025
3026                # And it won't appear in the class's dict because it doesn't
3027                # have a default.
3028                self.assertNotIn('x', C.__dict__)
3029
3030    def test_isnt_classvar(self):
3031        for typestr in ('CV',
3032                        't.ClassVar',
3033                        't.ClassVar[int]',
3034                        'typing..ClassVar[int]',
3035                        'Classvar',
3036                        'Classvar[int]',
3037                        'typing.ClassVarx[int]',
3038                        'typong.ClassVar[int]',
3039                        'dataclasses.ClassVar[int]',
3040                        'typingxClassVar[str]',
3041                        ):
3042            with self.subTest(typestr=typestr):
3043                @dataclass
3044                class C:
3045                    x: typestr
3046
3047                # x is not a ClassVar, so C() takes one arg.
3048                self.assertEqual(C(10).x, 10)
3049
3050    def test_initvar(self):
3051        # These tests assume that both "import dataclasses" and "from
3052        #  dataclasses import *" have been run in this file.
3053        for typestr in ('InitVar[int]',
3054                        'InitVar [int]'
3055                        ' InitVar [int]',
3056                        'InitVar',
3057                        ' InitVar ',
3058                        'dataclasses.InitVar[int]',
3059                        'dataclasses.InitVar[str]',
3060                        ' dataclasses.InitVar[str]',
3061                        'dataclasses .InitVar[str]',
3062                        'dataclasses. InitVar[str]',
3063                        'dataclasses.InitVar [str]',
3064                        'dataclasses.InitVar [ str]',
3065
3066                        # Not syntactically valid, but these will
3067                        #  be treated as InitVars.
3068                        'dataclasses.InitVar.[int]',
3069                        'dataclasses.InitVar+',
3070                        ):
3071            with self.subTest(typestr=typestr):
3072                @dataclass
3073                class C:
3074                    x: typestr
3075
3076                # x is an InitVar, so doesn't create a member.
3077                with self.assertRaisesRegex(AttributeError,
3078                                            "object has no attribute 'x'"):
3079                    C(1).x
3080
3081    def test_isnt_initvar(self):
3082        for typestr in ('IV',
3083                        'dc.InitVar',
3084                        'xdataclasses.xInitVar',
3085                        'typing.xInitVar[int]',
3086                        ):
3087            with self.subTest(typestr=typestr):
3088                @dataclass
3089                class C:
3090                    x: typestr
3091
3092                # x is not an InitVar, so there will be a member x.
3093                self.assertEqual(C(10).x, 10)
3094
3095    def test_classvar_module_level_import(self):
3096        from test import dataclass_module_1
3097        from test import dataclass_module_1_str
3098        from test import dataclass_module_2
3099        from test import dataclass_module_2_str
3100
3101        for m in (dataclass_module_1, dataclass_module_1_str,
3102                  dataclass_module_2, dataclass_module_2_str,
3103                  ):
3104            with self.subTest(m=m):
3105                # There's a difference in how the ClassVars are
3106                # interpreted when using string annotations or
3107                # not. See the imported modules for details.
3108                if m.USING_STRINGS:
3109                    c = m.CV(10)
3110                else:
3111                    c = m.CV()
3112                self.assertEqual(c.cv0, 20)
3113
3114
3115                # There's a difference in how the InitVars are
3116                # interpreted when using string annotations or
3117                # not. See the imported modules for details.
3118                c = m.IV(0, 1, 2, 3, 4)
3119
3120                for field_name in ('iv0', 'iv1', 'iv2', 'iv3'):
3121                    with self.subTest(field_name=field_name):
3122                        with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"):
3123                            # Since field_name is an InitVar, it's
3124                            # not an instance field.
3125                            getattr(c, field_name)
3126
3127                if m.USING_STRINGS:
3128                    # iv4 is interpreted as a normal field.
3129                    self.assertIn('not_iv4', c.__dict__)
3130                    self.assertEqual(c.not_iv4, 4)
3131                else:
3132                    # iv4 is interpreted as an InitVar, so it
3133                    # won't exist on the instance.
3134                    self.assertNotIn('not_iv4', c.__dict__)
3135
3136    def test_text_annotations(self):
3137        from test import dataclass_textanno
3138
3139        self.assertEqual(
3140            get_type_hints(dataclass_textanno.Bar),
3141            {'foo': dataclass_textanno.Foo})
3142        self.assertEqual(
3143            get_type_hints(dataclass_textanno.Bar.__init__),
3144            {'foo': dataclass_textanno.Foo,
3145             'return': type(None)})
3146
3147
class TestMakeDataclass(unittest.TestCase):
    """Tests for make_dataclass(), the functional dataclass constructor."""

    def test_simple(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', int, field(default=5))],
                           namespace={'add_one': lambda self: self.x + 1})
        c = C(10)
        self.assertEqual((c.x, c.y), (10, 5))
        self.assertEqual(c.add_one(), 11)

    def test_no_mutate_namespace(self):
        # Make sure a provided namespace isn't mutated.
        ns = {}
        make_dataclass('C',
                       [('x', int),
                        ('y', int, field(default=5))],
                       namespace=ns)
        self.assertEqual(ns, {})

    def test_base(self):
        class Base1:
            pass
        class Base2:
            pass
        C = make_dataclass('C',
                           [('x', int)],
                           bases=(Base1, Base2))
        c = C(2)
        self.assertIsInstance(c, C)
        self.assertIsInstance(c, Base1)
        self.assertIsInstance(c, Base2)

    def test_base_dataclass(self):
        @dataclass
        class Base1:
            x: int
        class Base2:
            pass
        C = make_dataclass('C',
                           [('y', int)],
                           bases=(Base1, Base2))
        # The inherited field x comes first, so one argument is too few.
        with self.assertRaisesRegex(TypeError, 'required positional'):
            C(2)
        c = C(1, 2)
        self.assertIsInstance(c, C)
        self.assertIsInstance(c, Base1)
        self.assertIsInstance(c, Base2)

        self.assertEqual((c.x, c.y), (1, 2))

    def test_init_var(self):
        def post_init(self, y):
            self.x *= y

        C = make_dataclass('C',
                           [('x', int),
                            ('y', InitVar[int]),
                            ],
                           namespace={'__post_init__': post_init},
                           )
        c = C(2, 3)
        self.assertEqual(vars(c), {'x': 6})
        self.assertEqual(len(fields(c)), 1)

    def test_class_var(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20)),
                            ])
        c = C(1)
        self.assertEqual(vars(c), {'x': 1})
        self.assertEqual(len(fields(c)), 1)
        self.assertEqual(C.y, 10)
        self.assertEqual(C.z, 20)

    def test_other_params(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20)),
                            ],
                           init=False)
        # Make sure we have a repr, but no init.
        self.assertNotIn('__init__', vars(C))
        self.assertIn('__repr__', vars(C))

        # Make sure random other params don't work.
        with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
            make_dataclass('C',
                           [],
                           xxinit=False)

    def test_no_types(self):
        C = make_dataclass('Point', ['x', 'y', 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': 'typing.Any',
                                             'z': 'typing.Any'})

        C = make_dataclass('Point', ['x', ('y', int), 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': int,
                                             'z': 'typing.Any'})

    def test_invalid_type_specification(self):
        for bad_field in [(),
                          (1, 2, 3, 4),
                          ]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'Invalid field: '):
                    make_dataclass('C', ['a', bad_field])

        # And test for things with no len().
        for bad_field in [float,
                          lambda x:x,
                          ]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'has no len\(\)'):
                    make_dataclass('C', ['a', bad_field])

    # The loop variable in the next three tests is `name` rather than
    #  `field` so that the imported dataclasses.field is not shadowed.
    def test_duplicate_field_names(self):
        for name in ['a', 'ab']:
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
                    make_dataclass('C', [name, 'a', name])

    def test_keyword_field_names(self):
        for name in ['for', 'async', 'await', 'as']:
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', ['a', name])
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', [name])
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', [name, 'a'])

    def test_non_identifier_field_names(self):
        for name in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']:
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
                    make_dataclass('C', ['a', name])
                with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
                    make_dataclass('C', [name])
                with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
                    make_dataclass('C', [name, 'a'])

    def test_underscore_field_names(self):
        # Unlike namedtuple, it's okay if dataclass field names have
        # an underscore.
        make_dataclass('C', ['_', '_a', 'a_a', 'a_'])

    def test_funny_class_names_names(self):
        # No reason to prevent weird class names, since
        # types.new_class allows them.
        for classname in ['()', 'x,y', '*', '2@3', '']:
            with self.subTest(classname=classname):
                C = make_dataclass(classname, ['a', 'b'])
                self.assertEqual(C.__name__, classname)
3311
class TestReplace(unittest.TestCase):
    """Tests for dataclasses.replace() (plus recursive __repr__ checks)."""

    def test(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual(c1.x, 3)
        self.assertEqual(c1.y, 2)

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int
            z: int = field(init=False, default=10)
            t: int = field(init=False, default=100)

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
        self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))

        # init=False fields can't be passed to replace(), whether alone
        #  or together with other fields.  Each call gets its own
        #  assertRaises block: a second call after one that raises would
        #  never run.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=3, z=20, t=50)
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, z=20)

        # Make sure the result is still frozen.
        with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
            c1.x = 3

        # Make sure we can't replace an attribute that doesn't exist,
        #  if we're also replacing one that does exist.  Test this
        #  here, because setting attributes on frozen instances is
        #  handled slightly differently from non-frozen ones.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                             "keyword argument 'a'"):
            c1 = replace(c, x=20, a=5)

    def test_invalid_field_name(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                    "keyword argument 'z'"):
            c1 = replace(c, z=3)

    def test_invalid_object(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        # Neither a dataclass type nor a non-dataclass object is accepted.
        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(C, x=3)

        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(0, x=3)

    def test_no_init(self):
        @dataclass
        class C:
            x: int
            y: int = field(init=False, default=10)

        c = C(1)
        c.y = 20

        # Make sure y gets the default value.
        c1 = replace(c, x=5)
        self.assertEqual((c1.x, c1.y), (5, 10))

        # Trying to replace y is an error.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=2, y=30)

        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, y=30)

    def test_classvar(self):
        @dataclass
        class C:
            x: int
            y: ClassVar[int] = 1000

        c = C(1)
        d = C(2)

        self.assertIs(c.y, d.y)
        self.assertEqual(c.y, 1000)

        # Trying to replace y is an error: can't replace ClassVars.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an "
                                    "unexpected keyword argument 'y'"):
            replace(c, y=30)

        replace(c, x=5)

    def test_initvar_is_specified(self):
        @dataclass
        class C:
            x: int
            y: InitVar[int]

            def __post_init__(self, y):
                self.x *= y

        c = C(1, 10)
        self.assertEqual(c.x, 10)
        # An InitVar without a default must be passed to replace().
        with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be "
                                    "specified with replace()"):
            replace(c, x=3)
        c = replace(c, x=3, y=5)
        self.assertEqual(c.x, 15)

    def test_initvar_with_default_value(self):
        @dataclass
        class C:
            x: int
            y: InitVar[int] = None
            z: InitVar[int] = 42

            def __post_init__(self, y, z):
                if y is not None:
                    self.x += y
                if z is not None:
                    self.x += z

        c = C(x=1, y=10, z=1)
        self.assertEqual(replace(c), C(x=12))
        self.assertEqual(replace(c, y=4), C(x=12, y=4, z=42))
        self.assertEqual(replace(c, y=4, z=1), C(x=12, y=4, z=1))

    def test_recursive_repr(self):
        @dataclass
        class C:
            f: "C"

        c = C(None)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr.<locals>.C(f=...)")

    def test_recursive_repr_two_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: "C"

        c = C(None, None)
        c.f = c
        c.g = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs"
                                  ".<locals>.C(f=..., g=...)")

    def test_recursive_repr_indirection(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "C"

        c = C(None)
        d = D(None)
        c.f = d
        d.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.D(f=...))")

    def test_recursive_repr_indirection_two(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "E"

        @dataclass
        class E:
            f: "C"

        c = C(None)
        d = D(None)
        e = E(None)
        c.f = d
        d.f = e
        e.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.E(f=...)))")

    def test_recursive_repr_misc_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: int

        c = C(None, 1)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_misc_attrs"
                                  ".<locals>.C(f=..., g=1)")

    ## NOTE(review): this disabled test predates the current replace()
    ##  semantics.  test_initvar_is_specified above shows that an InitVar
    ##  without a default must be passed to replace(), so the final
    ##  replace(c, x=5) below would now raise.
    ## def test_initvar(self):
    ##     @dataclass
    ##     class C:
    ##         x: int
    ##         y: InitVar[int]

    ##     c = C(1, 10)
    ##     d = C(2, 20)

    ##     # In our case, replacing an InitVar is a no-op
    ##     self.assertEqual(c, replace(c, y=5))

    ##     replace(c, x=5)
3539
3540class TestAbstract(unittest.TestCase):
3541    def test_abc_implementation(self):
3542        class Ordered(abc.ABC):
3543            @abc.abstractmethod
3544            def __lt__(self, other):
3545                pass
3546
3547            @abc.abstractmethod
3548            def __le__(self, other):
3549                pass
3550
3551        @dataclass(order=True)
3552        class Date(Ordered):
3553            year: int
3554            month: 'Month'
3555            day: 'int'
3556
3557        self.assertFalse(inspect.isabstract(Date))
3558        self.assertGreater(Date(2020,12,25), Date(2020,8,31))
3559
3560    def test_maintain_abc(self):
3561        class A(abc.ABC):
3562            @abc.abstractmethod
3563            def foo(self):
3564                pass
3565
3566        @dataclass
3567        class Date(A):
3568            year: int
3569            month: 'Month'
3570            day: 'int'
3571
3572        self.assertTrue(inspect.isabstract(Date))
3573        msg = 'class Date with abstract method foo'
3574        self.assertRaisesRegex(TypeError, msg, Date)
3575
3576
3577class TestMatchArgs(unittest.TestCase):
3578    def test_match_args(self):
3579        @dataclass
3580        class C:
3581            a: int
3582        self.assertEqual(C(42).__match_args__, ('a',))
3583
3584    def test_explicit_match_args(self):
3585        ma = ()
3586        @dataclass
3587        class C:
3588            a: int
3589            __match_args__ = ma
3590        self.assertIs(C(42).__match_args__, ma)
3591
3592    def test_bpo_43764(self):
3593        @dataclass(repr=False, eq=False, init=False)
3594        class X:
3595            a: int
3596            b: int
3597            c: int
3598        self.assertEqual(X.__match_args__, ("a", "b", "c"))
3599
3600    def test_match_args_argument(self):
3601        @dataclass(match_args=False)
3602        class X:
3603            a: int
3604        self.assertNotIn('__match_args__', X.__dict__)
3605
3606        @dataclass(match_args=False)
3607        class Y:
3608            a: int
3609            __match_args__ = ('b',)
3610        self.assertEqual(Y.__match_args__, ('b',))
3611
3612        @dataclass(match_args=False)
3613        class Z(Y):
3614            z: int
3615        self.assertEqual(Z.__match_args__, ('b',))
3616
3617        # Ensure parent dataclass __match_args__ is seen, if child class
3618        # specifies match_args=False.
3619        @dataclass
3620        class A:
3621            a: int
3622            z: int
3623        @dataclass(match_args=False)
3624        class B(A):
3625            b: int
3626        self.assertEqual(B.__match_args__, ('a', 'z'))
3627
3628    def test_make_dataclasses(self):
3629        C = make_dataclass('C', [('x', int), ('y', int)])
3630        self.assertEqual(C.__match_args__, ('x', 'y'))
3631
3632        C = make_dataclass('C', [('x', int), ('y', int)], match_args=True)
3633        self.assertEqual(C.__match_args__, ('x', 'y'))
3634
3635        C = make_dataclass('C', [('x', int), ('y', int)], match_args=False)
3636        self.assertNotIn('__match__args__', C.__dict__)
3637
3638        C = make_dataclass('C', [('x', int), ('y', int)], namespace={'__match_args__': ('z',)})
3639        self.assertEqual(C.__match_args__, ('z',))
3640
3641
class TestKeywordArgs(unittest.TestCase):
    # Tests for keyword-only fields: field(kw_only=...), dataclass(kw_only=...)
    # and the KW_ONLY sentinel pseudo-field.

    def test_no_classvar_kwarg(self):
        # Any explicit kw_only on a ClassVar field is rejected, regardless of
        # the class-level kw_only setting.
        msg = 'field a is a ClassVar but specifies kw_only'
        with self.assertRaisesRegex(TypeError, msg):
            @dataclass
            class A:
                a: ClassVar[int] = field(kw_only=True)

        with self.assertRaisesRegex(TypeError, msg):
            @dataclass
            class A:
                a: ClassVar[int] = field(kw_only=False)

        with self.assertRaisesRegex(TypeError, msg):
            @dataclass(kw_only=True)
            class A:
                a: ClassVar[int] = field(kw_only=False)

    def test_field_marked_as_kwonly(self):
        # A field-level kw_only always overrides the class-level default.
        #######################
        # Using dataclass(kw_only=True)
        @dataclass(kw_only=True)
        class A:
            a: int
        self.assertTrue(fields(A)[0].kw_only)

        @dataclass(kw_only=True)
        class A:
            a: int = field(kw_only=True)
        self.assertTrue(fields(A)[0].kw_only)

        @dataclass(kw_only=True)
        class A:
            a: int = field(kw_only=False)
        self.assertFalse(fields(A)[0].kw_only)

        #######################
        # Using dataclass(kw_only=False)
        @dataclass(kw_only=False)
        class A:
            a: int
        self.assertFalse(fields(A)[0].kw_only)

        @dataclass(kw_only=False)
        class A:
            a: int = field(kw_only=True)
        self.assertTrue(fields(A)[0].kw_only)

        @dataclass(kw_only=False)
        class A:
            a: int = field(kw_only=False)
        self.assertFalse(fields(A)[0].kw_only)

        #######################
        # Not specifying dataclass(kw_only)
        @dataclass
        class A:
            a: int
        self.assertFalse(fields(A)[0].kw_only)

        @dataclass
        class A:
            a: int = field(kw_only=True)
        self.assertTrue(fields(A)[0].kw_only)

        @dataclass
        class A:
            a: int = field(kw_only=False)
        self.assertFalse(fields(A)[0].kw_only)

    def test_match_args(self):
        # kw fields don't show up in __match_args__.
        @dataclass(kw_only=True)
        class C:
            a: int
        self.assertEqual(C(a=42).__match_args__, ())

        @dataclass
        class C:
            a: int
            b: int = field(kw_only=True)
        self.assertEqual(C(42, b=10).__match_args__, ('a',))

    def test_KW_ONLY(self):
        # Fields after the KW_ONLY sentinel become keyword-only.
        @dataclass
        class A:
            a: int
            _: KW_ONLY
            b: int
            c: int
        A(3, c=5, b=4)
        msg = "takes 2 positional arguments but 4 were given"
        with self.assertRaisesRegex(TypeError, msg):
            A(3, 4, 5)

        @dataclass(kw_only=True)
        class B:
            a: int
            _: KW_ONLY
            b: int
            c: int
        B(a=3, b=4, c=5)
        msg = "takes 1 positional argument but 4 were given"
        with self.assertRaisesRegex(TypeError, msg):
            B(3, 4, 5)

        # Explicitly make a field that follows KW_ONLY be non-keyword-only.
        @dataclass
        class C:
            a: int
            _: KW_ONLY
            b: int
            c: int = field(kw_only=False)
        # c stays positional (right after a) while b is keyword-only, so each
        # distinct call form below must produce the same field values.
        for obj in (C(1, 2, b=3), C(1, b=3, c=2), C(c=2, b=3, a=1)):
            with self.subTest(obj=obj):
                self.assertEqual(obj.a, 1)
                self.assertEqual(obj.b, 3)
                self.assertEqual(obj.c, 2)

    def test_KW_ONLY_as_string(self):
        # The sentinel is also recognized from a string annotation.
        @dataclass
        class A:
            a: int
            _: 'dataclasses.KW_ONLY'
            b: int
            c: int
        A(3, c=5, b=4)
        msg = "takes 2 positional arguments but 4 were given"
        with self.assertRaisesRegex(TypeError, msg):
            A(3, 4, 5)

    def test_KW_ONLY_twice(self):
        msg = "'Y' is KW_ONLY, but KW_ONLY has already been specified"

        with self.assertRaisesRegex(TypeError, msg):
            @dataclass
            class A:
                a: int
                X: KW_ONLY
                Y: KW_ONLY
                b: int
                c: int

        with self.assertRaisesRegex(TypeError, msg):
            @dataclass
            class A:
                a: int
                X: KW_ONLY
                b: int
                Y: KW_ONLY
                c: int

        with self.assertRaisesRegex(TypeError, msg):
            @dataclass
            class A:
                a: int
                X: KW_ONLY
                b: int
                c: int
                Y: KW_ONLY

        # But this usage is okay, since it's not using KW_ONLY.
        @dataclass
        class A:
            a: int
            _: KW_ONLY
            b: int
            c: int = field(kw_only=True)

        # And if inheriting, it's okay.
        @dataclass
        class A:
            a: int
            _: KW_ONLY
            b: int
            c: int
        @dataclass
        class B(A):
            _: KW_ONLY
            d: int

        # Make sure the error is raised in a derived class.
        with self.assertRaisesRegex(TypeError, msg):
            @dataclass
            class A:
                a: int
                _: KW_ONLY
                b: int
                c: int
            @dataclass
            class B(A):
                X: KW_ONLY
                d: int
                Y: KW_ONLY

    def test_post_init(self):
        # InitVar fields after KW_ONLY are passed to __post_init__ in
        # declaration order, and are excluded from asdict().
        @dataclass
        class A:
            a: int
            _: KW_ONLY
            b: InitVar[int]
            c: int
            d: InitVar[int]
            def __post_init__(self, b, d):
                raise CustomError(f'{b=} {d=}')
        with self.assertRaisesRegex(CustomError, 'b=3 d=4'):
            A(1, c=2, b=3, d=4)

        @dataclass
        class B:
            a: int
            _: KW_ONLY
            b: InitVar[int]
            c: int
            d: InitVar[int]
            def __post_init__(self, b, d):
                self.a = b
                self.c = d
        b = B(1, c=2, b=3, d=4)
        self.assertEqual(asdict(b), {'a': 3, 'c': 4})

    def test_defaults(self):
        # For kwargs, make sure we can have defaults after non-defaults.
        @dataclass
        class A:
            a: int = 0
            _: KW_ONLY
            b: int
            c: int = 1
            d: int

        a = A(d=4, b=3)
        self.assertEqual(a.a, 0)
        self.assertEqual(a.b, 3)
        self.assertEqual(a.c, 1)
        self.assertEqual(a.d, 4)

        # Make sure we still check for non-kwarg non-defaults not following
        # defaults.
        err_regex = "non-default argument 'z' follows default argument"
        with self.assertRaisesRegex(TypeError, err_regex):
            @dataclass
            class A:
                a: int = 0
                z: int
                _: KW_ONLY
                b: int
                c: int = 1
                d: int

    def test_make_dataclass(self):
        # make_dataclass accepts kw_only and per-field overrides.
        A = make_dataclass("A", ['a'], kw_only=True)
        self.assertTrue(fields(A)[0].kw_only)

        B = make_dataclass("B",
                           ['a', ('b', int, field(kw_only=False))],
                           kw_only=True)
        self.assertTrue(fields(B)[0].kw_only)
        self.assertFalse(fields(B)[1].kw_only)
3914
3915
if __name__ == '__main__':
    # Run this module's TestCase classes when it is executed as a script.
    unittest.main()
3918