1# Deliberately use "from dataclasses import *". Every name in __all__ 2# is tested, so they all must be present. This is a way to catch 3# missing ones. 4 5from dataclasses import * 6 7import abc 8import pickle 9import inspect 10import builtins 11import types 12import unittest 13from unittest.mock import Mock 14from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol 15from typing import get_type_hints 16from collections import deque, OrderedDict, namedtuple 17from functools import total_ordering 18 19import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation. 20import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation. 21 22# Just any custom exception we can catch. 23class CustomError(Exception): pass 24 25class TestCase(unittest.TestCase): 26 def test_no_fields(self): 27 @dataclass 28 class C: 29 pass 30 31 o = C() 32 self.assertEqual(len(fields(C)), 0) 33 34 def test_no_fields_but_member_variable(self): 35 @dataclass 36 class C: 37 i = 0 38 39 o = C() 40 self.assertEqual(len(fields(C)), 0) 41 42 def test_one_field_no_default(self): 43 @dataclass 44 class C: 45 x: int 46 47 o = C(42) 48 self.assertEqual(o.x, 42) 49 50 def test_field_default_default_factory_error(self): 51 msg = "cannot specify both default and default_factory" 52 with self.assertRaisesRegex(ValueError, msg): 53 @dataclass 54 class C: 55 x: int = field(default=1, default_factory=int) 56 57 def test_field_repr(self): 58 int_field = field(default=1, init=True, repr=False) 59 int_field.name = "id" 60 repr_output = repr(int_field) 61 expected_output = "Field(name='id',type=None," \ 62 f"default=1,default_factory={MISSING!r}," \ 63 "init=True,repr=False,hash=None," \ 64 "compare=True,metadata=mappingproxy({})," \ 65 f"kw_only={MISSING!r}," \ 66 "_field_type=None)" 67 68 self.assertEqual(repr_output, expected_output) 69 70 def test_named_init_params(self): 71 @dataclass 72 class C: 73 x: int 74 75 o = C(x=32) 76 self.assertEqual(o.x, 32) 77 78 def test_two_fields_one_default(self): 79 @dataclass 80 class C: 81 x: int 82 y: int = 0 83 84 o = C(3) 85 self.assertEqual((o.x, o.y), (3, 0)) 86 87 # Non-defaults following defaults. 88 with self.assertRaisesRegex(TypeError, 89 "non-default argument 'y' follows " 90 "default argument"): 91 @dataclass 92 class C: 93 x: int = 0 94 y: int 95 96 # A derived class adds a non-default field after a default one. 97 with self.assertRaisesRegex(TypeError, 98 "non-default argument 'y' follows " 99 "default argument"): 100 @dataclass 101 class B: 102 x: int = 0 103 104 @dataclass 105 class C(B): 106 y: int 107 108 # Override a base class field and add a default to 109 # a field which didn't use to have a default. 110 with self.assertRaisesRegex(TypeError, 111 "non-default argument 'y' follows " 112 "default argument"): 113 @dataclass 114 class B: 115 x: int 116 y: int 117 118 @dataclass 119 class C(B): 120 x: int = 0 121 122 def test_overwrite_hash(self): 123 # Test that declaring this class isn't an error. It should 124 # use the user-provided __hash__. 125 @dataclass(frozen=True) 126 class C: 127 x: int 128 def __hash__(self): 129 return 301 130 self.assertEqual(hash(C(100)), 301) 131 132 # Test that declaring this class isn't an error. It should 133 # use the generated __hash__. 134 @dataclass(frozen=True) 135 class C: 136 x: int 137 def __eq__(self, other): 138 return False 139 self.assertEqual(hash(C(100)), hash((100,))) 140 141 # But this one should generate an exception, because with 142 # unsafe_hash=True, it's an error to have a __hash__ defined. 143 with self.assertRaisesRegex(TypeError, 144 'Cannot overwrite attribute __hash__'): 145 @dataclass(unsafe_hash=True) 146 class C: 147 def __hash__(self): 148 pass 149 150 # Creating this class should not generate an exception, 151 # because even though __hash__ exists before @dataclass is 152 # called, (due to __eq__ being defined), since it's None 153 # that's okay. 154 @dataclass(unsafe_hash=True) 155 class C: 156 x: int 157 def __eq__(self): 158 pass 159 # The generated hash function works as we'd expect. 160 self.assertEqual(hash(C(10)), hash((10,))) 161 162 # Creating this class should generate an exception, because 163 # __hash__ exists and is not None, which it would be if it 164 # had been auto-generated due to __eq__ being defined. 165 with self.assertRaisesRegex(TypeError, 166 'Cannot overwrite attribute __hash__'): 167 @dataclass(unsafe_hash=True) 168 class C: 169 x: int 170 def __eq__(self): 171 pass 172 def __hash__(self): 173 pass 174 175 def test_overwrite_fields_in_derived_class(self): 176 # Note that x from C1 replaces x in Base, but the order remains 177 # the same as defined in Base. 178 @dataclass 179 class Base: 180 x: Any = 15.0 181 y: int = 0 182 183 @dataclass 184 class C1(Base): 185 z: int = 10 186 x: int = 15 187 188 o = Base() 189 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)') 190 191 o = C1() 192 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)') 193 194 o = C1(x=5) 195 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)') 196 197 def test_field_named_self(self): 198 @dataclass 199 class C: 200 self: str 201 c=C('foo') 202 self.assertEqual(c.self, 'foo') 203 204 # Make sure the first parameter is not named 'self'. 205 sig = inspect.signature(C.__init__) 206 first = next(iter(sig.parameters)) 207 self.assertNotEqual('self', first) 208 209 # But we do use 'self' if no field named self. 210 @dataclass 211 class C: 212 selfx: str 213 214 # Make sure the first parameter is named 'self'. 215 sig = inspect.signature(C.__init__) 216 first = next(iter(sig.parameters)) 217 self.assertEqual('self', first) 218 219 def test_field_named_object(self): 220 @dataclass 221 class C: 222 object: str 223 c = C('foo') 224 self.assertEqual(c.object, 'foo') 225 226 def test_field_named_object_frozen(self): 227 @dataclass(frozen=True) 228 class C: 229 object: str 230 c = C('foo') 231 self.assertEqual(c.object, 'foo') 232 233 def test_field_named_like_builtin(self): 234 # Attribute names can shadow built-in names 235 # since code generation is used. 236 # Ensure that this is not happening. 237 exclusions = {'None', 'True', 'False'} 238 builtins_names = sorted( 239 b for b in builtins.__dict__.keys() 240 if not b.startswith('__') and b not in exclusions 241 ) 242 attributes = [(name, str) for name in builtins_names] 243 C = make_dataclass('C', attributes) 244 245 c = C(*[name for name in builtins_names]) 246 247 for name in builtins_names: 248 self.assertEqual(getattr(c, name), name) 249 250 def test_field_named_like_builtin_frozen(self): 251 # Attribute names can shadow built-in names 252 # since code generation is used. 253 # Ensure that this is not happening 254 # for frozen data classes. 255 exclusions = {'None', 'True', 'False'} 256 builtins_names = sorted( 257 b for b in builtins.__dict__.keys() 258 if not b.startswith('__') and b not in exclusions 259 ) 260 attributes = [(name, str) for name in builtins_names] 261 C = make_dataclass('C', attributes, frozen=True) 262 263 c = C(*[name for name in builtins_names]) 264 265 for name in builtins_names: 266 self.assertEqual(getattr(c, name), name) 267 268 def test_0_field_compare(self): 269 # Ensure that order=False is the default. 270 @dataclass 271 class C0: 272 pass 273 274 @dataclass(order=False) 275 class C1: 276 pass 277 278 for cls in [C0, C1]: 279 with self.subTest(cls=cls): 280 self.assertEqual(cls(), cls()) 281 for idx, fn in enumerate([lambda a, b: a < b, 282 lambda a, b: a <= b, 283 lambda a, b: a > b, 284 lambda a, b: a >= b]): 285 with self.subTest(idx=idx): 286 with self.assertRaisesRegex(TypeError, 287 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"): 288 fn(cls(), cls()) 289 290 @dataclass(order=True) 291 class C: 292 pass 293 self.assertLessEqual(C(), C()) 294 self.assertGreaterEqual(C(), C()) 295 296 def test_1_field_compare(self): 297 # Ensure that order=False is the default. 298 @dataclass 299 class C0: 300 x: int 301 302 @dataclass(order=False) 303 class C1: 304 x: int 305 306 for cls in [C0, C1]: 307 with self.subTest(cls=cls): 308 self.assertEqual(cls(1), cls(1)) 309 self.assertNotEqual(cls(0), cls(1)) 310 for idx, fn in enumerate([lambda a, b: a < b, 311 lambda a, b: a <= b, 312 lambda a, b: a > b, 313 lambda a, b: a >= b]): 314 with self.subTest(idx=idx): 315 with self.assertRaisesRegex(TypeError, 316 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"): 317 fn(cls(0), cls(0)) 318 319 @dataclass(order=True) 320 class C: 321 x: int 322 self.assertLess(C(0), C(1)) 323 self.assertLessEqual(C(0), C(1)) 324 self.assertLessEqual(C(1), C(1)) 325 self.assertGreater(C(1), C(0)) 326 self.assertGreaterEqual(C(1), C(0)) 327 self.assertGreaterEqual(C(1), C(1)) 328 329 def test_simple_compare(self): 330 # Ensure that order=False is the default. 331 @dataclass 332 class C0: 333 x: int 334 y: int 335 336 @dataclass(order=False) 337 class C1: 338 x: int 339 y: int 340 341 for cls in [C0, C1]: 342 with self.subTest(cls=cls): 343 self.assertEqual(cls(0, 0), cls(0, 0)) 344 self.assertEqual(cls(1, 2), cls(1, 2)) 345 self.assertNotEqual(cls(1, 0), cls(0, 0)) 346 self.assertNotEqual(cls(1, 0), cls(1, 1)) 347 for idx, fn in enumerate([lambda a, b: a < b, 348 lambda a, b: a <= b, 349 lambda a, b: a > b, 350 lambda a, b: a >= b]): 351 with self.subTest(idx=idx): 352 with self.assertRaisesRegex(TypeError, 353 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"): 354 fn(cls(0, 0), cls(0, 0)) 355 356 @dataclass(order=True) 357 class C: 358 x: int 359 y: int 360 361 for idx, fn in enumerate([lambda a, b: a == b, 362 lambda a, b: a <= b, 363 lambda a, b: a >= b]): 364 with self.subTest(idx=idx): 365 self.assertTrue(fn(C(0, 0), C(0, 0))) 366 367 for idx, fn in enumerate([lambda a, b: a < b, 368 lambda a, b: a <= b, 369 lambda a, b: a != b]): 370 with self.subTest(idx=idx): 371 self.assertTrue(fn(C(0, 0), C(0, 1))) 372 self.assertTrue(fn(C(0, 1), C(1, 0))) 373 self.assertTrue(fn(C(1, 0), C(1, 1))) 374 375 for idx, fn in enumerate([lambda a, b: a > b, 376 lambda a, b: a >= b, 377 lambda a, b: a != b]): 378 with self.subTest(idx=idx): 379 self.assertTrue(fn(C(0, 1), C(0, 0))) 380 self.assertTrue(fn(C(1, 0), C(0, 1))) 381 self.assertTrue(fn(C(1, 1), C(1, 0))) 382 383 def test_compare_subclasses(self): 384 # Comparisons fail for subclasses, even if no fields 385 # are added. 386 @dataclass 387 class B: 388 i: int 389 390 @dataclass 391 class C(B): 392 pass 393 394 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False), 395 (lambda a, b: a != b, True)]): 396 with self.subTest(idx=idx): 397 self.assertEqual(fn(B(0), C(0)), expected) 398 399 for idx, fn in enumerate([lambda a, b: a < b, 400 lambda a, b: a <= b, 401 lambda a, b: a > b, 402 lambda a, b: a >= b]): 403 with self.subTest(idx=idx): 404 with self.assertRaisesRegex(TypeError, 405 "not supported between instances of 'B' and 'C'"): 406 fn(B(0), C(0)) 407 408 def test_eq_order(self): 409 # Test combining eq and order. 410 for (eq, order, result ) in [ 411 (False, False, 'neither'), 412 (False, True, 'exception'), 413 (True, False, 'eq_only'), 414 (True, True, 'both'), 415 ]: 416 with self.subTest(eq=eq, order=order): 417 if result == 'exception': 418 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'): 419 @dataclass(eq=eq, order=order) 420 class C: 421 pass 422 else: 423 @dataclass(eq=eq, order=order) 424 class C: 425 pass 426 427 if result == 'neither': 428 self.assertNotIn('__eq__', C.__dict__) 429 self.assertNotIn('__lt__', C.__dict__) 430 self.assertNotIn('__le__', C.__dict__) 431 self.assertNotIn('__gt__', C.__dict__) 432 self.assertNotIn('__ge__', C.__dict__) 433 elif result == 'both': 434 self.assertIn('__eq__', C.__dict__) 435 self.assertIn('__lt__', C.__dict__) 436 self.assertIn('__le__', C.__dict__) 437 self.assertIn('__gt__', C.__dict__) 438 self.assertIn('__ge__', C.__dict__) 439 elif result == 'eq_only': 440 self.assertIn('__eq__', C.__dict__) 441 self.assertNotIn('__lt__', C.__dict__) 442 self.assertNotIn('__le__', C.__dict__) 443 self.assertNotIn('__gt__', C.__dict__) 444 self.assertNotIn('__ge__', C.__dict__) 445 else: 446 assert False, f'unknown result {result!r}' 447 448 def test_field_no_default(self): 449 @dataclass 450 class C: 451 x: int = field() 452 453 self.assertEqual(C(5).x, 5) 454 455 with self.assertRaisesRegex(TypeError, 456 r"__init__\(\) missing 1 required " 457 "positional argument: 'x'"): 458 C() 459 460 def test_field_default(self): 461 default = object() 462 @dataclass 463 class C: 464 x: object = field(default=default) 465 466 self.assertIs(C.x, default) 467 c = C(10) 468 self.assertEqual(c.x, 10) 469 470 # If we delete the instance attribute, we should then see the 471 # class attribute. 472 del c.x 473 self.assertIs(c.x, default) 474 475 self.assertIs(C().x, default) 476 477 def test_not_in_repr(self): 478 @dataclass 479 class C: 480 x: int = field(repr=False) 481 with self.assertRaises(TypeError): 482 C() 483 c = C(10) 484 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()') 485 486 @dataclass 487 class C: 488 x: int = field(repr=False) 489 y: int 490 c = C(10, 20) 491 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)') 492 493 def test_not_in_compare(self): 494 @dataclass 495 class C: 496 x: int = 0 497 y: int = field(compare=False, default=4) 498 499 self.assertEqual(C(), C(0, 20)) 500 self.assertEqual(C(1, 10), C(1, 20)) 501 self.assertNotEqual(C(3), C(4, 10)) 502 self.assertNotEqual(C(3, 10), C(4, 10)) 503 504 def test_hash_field_rules(self): 505 # Test all 6 cases of: 506 # hash=True/False/None 507 # compare=True/False 508 for (hash_, compare, result ) in [ 509 (True, False, 'field' ), 510 (True, True, 'field' ), 511 (False, False, 'absent'), 512 (False, True, 'absent'), 513 (None, False, 'absent'), 514 (None, True, 'field' ), 515 ]: 516 with self.subTest(hash=hash_, compare=compare): 517 @dataclass(unsafe_hash=True) 518 class C: 519 x: int = field(compare=compare, hash=hash_, default=5) 520 521 if result == 'field': 522 # __hash__ contains the field. 523 self.assertEqual(hash(C(5)), hash((5,))) 524 elif result == 'absent': 525 # The field is not present in the hash. 526 self.assertEqual(hash(C(5)), hash(())) 527 else: 528 assert False, f'unknown result {result!r}' 529 530 def test_init_false_no_default(self): 531 # If init=False and no default value, then the field won't be 532 # present in the instance. 533 @dataclass 534 class C: 535 x: int = field(init=False) 536 537 self.assertNotIn('x', C().__dict__) 538 539 @dataclass 540 class C: 541 x: int 542 y: int = 0 543 z: int = field(init=False) 544 t: int = 10 545 546 self.assertNotIn('z', C(0).__dict__) 547 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0}) 548 549 def test_class_marker(self): 550 @dataclass 551 class C: 552 x: int 553 y: str = field(init=False, default=None) 554 z: str = field(repr=False) 555 556 the_fields = fields(C) 557 # the_fields is a tuple of 3 items, each value 558 # is in __annotations__. 559 self.assertIsInstance(the_fields, tuple) 560 for f in the_fields: 561 self.assertIs(type(f), Field) 562 self.assertIn(f.name, C.__annotations__) 563 564 self.assertEqual(len(the_fields), 3) 565 566 self.assertEqual(the_fields[0].name, 'x') 567 self.assertEqual(the_fields[0].type, int) 568 self.assertFalse(hasattr(C, 'x')) 569 self.assertTrue (the_fields[0].init) 570 self.assertTrue (the_fields[0].repr) 571 self.assertEqual(the_fields[1].name, 'y') 572 self.assertEqual(the_fields[1].type, str) 573 self.assertIsNone(getattr(C, 'y')) 574 self.assertFalse(the_fields[1].init) 575 self.assertTrue (the_fields[1].repr) 576 self.assertEqual(the_fields[2].name, 'z') 577 self.assertEqual(the_fields[2].type, str) 578 self.assertFalse(hasattr(C, 'z')) 579 self.assertTrue (the_fields[2].init) 580 self.assertFalse(the_fields[2].repr) 581 582 def test_field_order(self): 583 @dataclass 584 class B: 585 a: str = 'B:a' 586 b: str = 'B:b' 587 c: str = 'B:c' 588 589 @dataclass 590 class C(B): 591 b: str = 'C:b' 592 593 self.assertEqual([(f.name, f.default) for f in fields(C)], 594 [('a', 'B:a'), 595 ('b', 'C:b'), 596 ('c', 'B:c')]) 597 598 @dataclass 599 class D(B): 600 c: str = 'D:c' 601 602 self.assertEqual([(f.name, f.default) for f in fields(D)], 603 [('a', 'B:a'), 604 ('b', 'B:b'), 605 ('c', 'D:c')]) 606 607 @dataclass 608 class E(D): 609 a: str = 'E:a' 610 d: str = 'E:d' 611 612 self.assertEqual([(f.name, f.default) for f in fields(E)], 613 [('a', 'E:a'), 614 ('b', 'B:b'), 615 ('c', 'D:c'), 616 ('d', 'E:d')]) 617 618 def test_class_attrs(self): 619 # We only have a class attribute if a default value is 620 # specified, either directly or via a field with a default. 621 default = object() 622 @dataclass 623 class C: 624 x: int 625 y: int = field(repr=False) 626 z: object = default 627 t: int = field(default=100) 628 629 self.assertFalse(hasattr(C, 'x')) 630 self.assertFalse(hasattr(C, 'y')) 631 self.assertIs (C.z, default) 632 self.assertEqual(C.t, 100) 633 634 def test_disallowed_mutable_defaults(self): 635 # For the known types, don't allow mutable default values. 636 for typ, empty, non_empty in [(list, [], [1]), 637 (dict, {}, {0:1}), 638 (set, set(), set([1])), 639 ]: 640 with self.subTest(typ=typ): 641 # Can't use a zero-length value. 642 with self.assertRaisesRegex(ValueError, 643 f'mutable default {typ} for field ' 644 'x is not allowed'): 645 @dataclass 646 class Point: 647 x: typ = empty 648 649 650 # Nor a non-zero-length value 651 with self.assertRaisesRegex(ValueError, 652 f'mutable default {typ} for field ' 653 'y is not allowed'): 654 @dataclass 655 class Point: 656 y: typ = non_empty 657 658 # Check subtypes also fail. 659 class Subclass(typ): pass 660 661 with self.assertRaisesRegex(ValueError, 662 f"mutable default .*Subclass'>" 663 ' for field z is not allowed' 664 ): 665 @dataclass 666 class Point: 667 z: typ = Subclass() 668 669 # Because this is a ClassVar, it can be mutable. 670 @dataclass 671 class C: 672 z: ClassVar[typ] = typ() 673 674 # Because this is a ClassVar, it can be mutable. 675 @dataclass 676 class C: 677 x: ClassVar[typ] = Subclass() 678 679 def test_deliberately_mutable_defaults(self): 680 # If a mutable default isn't in the known list of 681 # (list, dict, set), then it's okay. 682 class Mutable: 683 def __init__(self): 684 self.l = [] 685 686 @dataclass 687 class C: 688 x: Mutable 689 690 # These 2 instances will share this value of x. 691 lst = Mutable() 692 o1 = C(lst) 693 o2 = C(lst) 694 self.assertEqual(o1, o2) 695 o1.x.l.extend([1, 2]) 696 self.assertEqual(o1, o2) 697 self.assertEqual(o1.x.l, [1, 2]) 698 self.assertIs(o1.x, o2.x) 699 700 def test_no_options(self): 701 # Call with dataclass(). 702 @dataclass() 703 class C: 704 x: int 705 706 self.assertEqual(C(42).x, 42) 707 708 def test_not_tuple(self): 709 # Make sure we can't be compared to a tuple. 710 @dataclass 711 class Point: 712 x: int 713 y: int 714 self.assertNotEqual(Point(1, 2), (1, 2)) 715 716 # And that we can't compare to another unrelated dataclass. 717 @dataclass 718 class C: 719 x: int 720 y: int 721 self.assertNotEqual(Point(1, 3), C(1, 3)) 722 723 def test_not_other_dataclass(self): 724 # Test that some of the problems with namedtuple don't happen 725 # here. 726 @dataclass 727 class Point3D: 728 x: int 729 y: int 730 z: int 731 732 @dataclass 733 class Date: 734 year: int 735 month: int 736 day: int 737 738 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3)) 739 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3)) 740 741 # Make sure we can't unpack. 742 with self.assertRaisesRegex(TypeError, 'unpack'): 743 x, y, z = Point3D(4, 5, 6) 744 745 # Make sure another class with the same field names isn't 746 # equal. 747 @dataclass 748 class Point3Dv1: 749 x: int = 0 750 y: int = 0 751 z: int = 0 752 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1()) 753 754 def test_function_annotations(self): 755 # Some dummy class and instance to use as a default. 756 class F: 757 pass 758 f = F() 759 760 def validate_class(cls): 761 # First, check __annotations__, even though they're not 762 # function annotations. 763 self.assertEqual(cls.__annotations__['i'], int) 764 self.assertEqual(cls.__annotations__['j'], str) 765 self.assertEqual(cls.__annotations__['k'], F) 766 self.assertEqual(cls.__annotations__['l'], float) 767 self.assertEqual(cls.__annotations__['z'], complex) 768 769 # Verify __init__. 770 771 signature = inspect.signature(cls.__init__) 772 # Check the return type, should be None. 773 self.assertIs(signature.return_annotation, None) 774 775 # Check each parameter. 776 params = iter(signature.parameters.values()) 777 param = next(params) 778 # This is testing an internal name, and probably shouldn't be tested. 779 self.assertEqual(param.name, 'self') 780 param = next(params) 781 self.assertEqual(param.name, 'i') 782 self.assertIs (param.annotation, int) 783 self.assertEqual(param.default, inspect.Parameter.empty) 784 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) 785 param = next(params) 786 self.assertEqual(param.name, 'j') 787 self.assertIs (param.annotation, str) 788 self.assertEqual(param.default, inspect.Parameter.empty) 789 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) 790 param = next(params) 791 self.assertEqual(param.name, 'k') 792 self.assertIs (param.annotation, F) 793 # Don't test for the default, since it's set to MISSING. 794 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) 795 param = next(params) 796 self.assertEqual(param.name, 'l') 797 self.assertIs (param.annotation, float) 798 # Don't test for the default, since it's set to MISSING. 799 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) 800 self.assertRaises(StopIteration, next, params) 801 802 803 @dataclass 804 class C: 805 i: int 806 j: str 807 k: F = f 808 l: float=field(default=None) 809 z: complex=field(default=3+4j, init=False) 810 811 validate_class(C) 812 813 # Now repeat with __hash__. 814 @dataclass(frozen=True, unsafe_hash=True) 815 class C: 816 i: int 817 j: str 818 k: F = f 819 l: float=field(default=None) 820 z: complex=field(default=3+4j, init=False) 821 822 validate_class(C) 823 824 def test_missing_default(self): 825 # Test that MISSING works the same as a default not being 826 # specified. 827 @dataclass 828 class C: 829 x: int=field(default=MISSING) 830 with self.assertRaisesRegex(TypeError, 831 r'__init__\(\) missing 1 required ' 832 'positional argument'): 833 C() 834 self.assertNotIn('x', C.__dict__) 835 836 @dataclass 837 class D: 838 x: int 839 with self.assertRaisesRegex(TypeError, 840 r'__init__\(\) missing 1 required ' 841 'positional argument'): 842 D() 843 self.assertNotIn('x', D.__dict__) 844 845 def test_missing_default_factory(self): 846 # Test that MISSING works the same as a default factory not 847 # being specified (which is really the same as a default not 848 # being specified, too). 849 @dataclass 850 class C: 851 x: int=field(default_factory=MISSING) 852 with self.assertRaisesRegex(TypeError, 853 r'__init__\(\) missing 1 required ' 854 'positional argument'): 855 C() 856 self.assertNotIn('x', C.__dict__) 857 858 @dataclass 859 class D: 860 x: int=field(default=MISSING, default_factory=MISSING) 861 with self.assertRaisesRegex(TypeError, 862 r'__init__\(\) missing 1 required ' 863 'positional argument'): 864 D() 865 self.assertNotIn('x', D.__dict__) 866 867 def test_missing_repr(self): 868 self.assertIn('MISSING_TYPE object', repr(MISSING)) 869 870 def test_dont_include_other_annotations(self): 871 @dataclass 872 class C: 873 i: int 874 def foo(self) -> int: 875 return 4 876 @property 877 def bar(self) -> int: 878 return 5 879 self.assertEqual(list(C.__annotations__), ['i']) 880 self.assertEqual(C(10).foo(), 4) 881 self.assertEqual(C(10).bar, 5) 882 self.assertEqual(C(10).i, 10) 883 884 def test_post_init(self): 885 # Just make sure it gets called 886 @dataclass 887 class C: 888 def __post_init__(self): 889 raise CustomError() 890 with self.assertRaises(CustomError): 891 C() 892 893 @dataclass 894 class C: 895 i: int = 10 896 def __post_init__(self): 897 if self.i == 10: 898 raise CustomError() 899 with self.assertRaises(CustomError): 900 C() 901 # post-init gets called, but doesn't raise. This is just 902 # checking that self is used correctly. 903 C(5) 904 905 # If there's not an __init__, then post-init won't get called. 906 @dataclass(init=False) 907 class C: 908 def __post_init__(self): 909 raise CustomError() 910 # Creating the class won't raise 911 C() 912 913 @dataclass 914 class C: 915 x: int = 0 916 def __post_init__(self): 917 self.x *= 2 918 self.assertEqual(C().x, 0) 919 self.assertEqual(C(2).x, 4) 920 921 # Make sure that if we're frozen, post-init can't set 922 # attributes. 923 @dataclass(frozen=True) 924 class C: 925 x: int = 0 926 def __post_init__(self): 927 self.x *= 2 928 with self.assertRaises(FrozenInstanceError): 929 C() 930 931 def test_post_init_super(self): 932 # Make sure super() post-init isn't called by default. 933 class B: 934 def __post_init__(self): 935 raise CustomError() 936 937 @dataclass 938 class C(B): 939 def __post_init__(self): 940 self.x = 5 941 942 self.assertEqual(C().x, 5) 943 944 # Now call super(), and it will raise. 945 @dataclass 946 class C(B): 947 def __post_init__(self): 948 super().__post_init__() 949 950 with self.assertRaises(CustomError): 951 C() 952 953 # Make sure post-init is called, even if not defined in our 954 # class. 955 @dataclass 956 class C(B): 957 pass 958 959 with self.assertRaises(CustomError): 960 C() 961 962 def test_post_init_staticmethod(self): 963 flag = False 964 @dataclass 965 class C: 966 x: int 967 y: int 968 @staticmethod 969 def __post_init__(): 970 nonlocal flag 971 flag = True 972 973 self.assertFalse(flag) 974 c = C(3, 4) 975 self.assertEqual((c.x, c.y), (3, 4)) 976 self.assertTrue(flag) 977 978 def test_post_init_classmethod(self): 979 @dataclass 980 class C: 981 flag = False 982 x: int 983 y: int 984 @classmethod 985 def __post_init__(cls): 986 cls.flag = True 987 988 self.assertFalse(C.flag) 989 c = C(3, 4) 990 self.assertEqual((c.x, c.y), (3, 4)) 991 self.assertTrue(C.flag) 992 993 def test_class_var(self): 994 # Make sure ClassVars are ignored in __init__, __repr__, etc. 995 @dataclass 996 class C: 997 x: int 998 y: int = 10 999 z: ClassVar[int] = 1000 1000 w: ClassVar[int] = 2000 1001 t: ClassVar[int] = 3000 1002 s: ClassVar = 4000 1003 1004 c = C(5) 1005 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)') 1006 self.assertEqual(len(fields(C)), 2) # We have 2 fields. 1007 self.assertEqual(len(C.__annotations__), 6) # And 4 ClassVars. 1008 self.assertEqual(c.z, 1000) 1009 self.assertEqual(c.w, 2000) 1010 self.assertEqual(c.t, 3000) 1011 self.assertEqual(c.s, 4000) 1012 C.z += 1 1013 self.assertEqual(c.z, 1001) 1014 c = C(20) 1015 self.assertEqual((c.x, c.y), (20, 10)) 1016 self.assertEqual(c.z, 1001) 1017 self.assertEqual(c.w, 2000) 1018 self.assertEqual(c.t, 3000) 1019 self.assertEqual(c.s, 4000) 1020 1021 def test_class_var_no_default(self): 1022 # If a ClassVar has no default value, it should not be set on the class. 1023 @dataclass 1024 class C: 1025 x: ClassVar[int] 1026 1027 self.assertNotIn('x', C.__dict__) 1028 1029 def test_class_var_default_factory(self): 1030 # It makes no sense for a ClassVar to have a default factory. When 1031 # would it be called? Call it yourself, since it's class-wide. 1032 with self.assertRaisesRegex(TypeError, 1033 'cannot have a default factory'): 1034 @dataclass 1035 class C: 1036 x: ClassVar[int] = field(default_factory=int) 1037 1038 self.assertNotIn('x', C.__dict__) 1039 1040 def test_class_var_with_default(self): 1041 # If a ClassVar has a default value, it should be set on the class. 1042 @dataclass 1043 class C: 1044 x: ClassVar[int] = 10 1045 self.assertEqual(C.x, 10) 1046 1047 @dataclass 1048 class C: 1049 x: ClassVar[int] = field(default=10) 1050 self.assertEqual(C.x, 10) 1051 1052 def test_class_var_frozen(self): 1053 # Make sure ClassVars work even if we're frozen. 1054 @dataclass(frozen=True) 1055 class C: 1056 x: int 1057 y: int = 10 1058 z: ClassVar[int] = 1000 1059 w: ClassVar[int] = 2000 1060 t: ClassVar[int] = 3000 1061 1062 c = C(5) 1063 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)') 1064 self.assertEqual(len(fields(C)), 2) # We have 2 fields 1065 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars 1066 self.assertEqual(c.z, 1000) 1067 self.assertEqual(c.w, 2000) 1068 self.assertEqual(c.t, 3000) 1069 # We can still modify the ClassVar, it's only instances that are 1070 # frozen. 1071 C.z += 1 1072 self.assertEqual(c.z, 1001) 1073 c = C(20) 1074 self.assertEqual((c.x, c.y), (20, 10)) 1075 self.assertEqual(c.z, 1001) 1076 self.assertEqual(c.w, 2000) 1077 self.assertEqual(c.t, 3000) 1078 1079 def test_init_var_no_default(self): 1080 # If an InitVar has no default value, it should not be set on the class. 1081 @dataclass 1082 class C: 1083 x: InitVar[int] 1084 1085 self.assertNotIn('x', C.__dict__) 1086 1087 def test_init_var_default_factory(self): 1088 # It makes no sense for an InitVar to have a default factory. When 1089 # would it be called? Call it yourself, since it's class-wide. 1090 with self.assertRaisesRegex(TypeError, 1091 'cannot have a default factory'): 1092 @dataclass 1093 class C: 1094 x: InitVar[int] = field(default_factory=int) 1095 1096 self.assertNotIn('x', C.__dict__) 1097 1098 def test_init_var_with_default(self): 1099 # If an InitVar has a default value, it should be set on the class. 1100 @dataclass 1101 class C: 1102 x: InitVar[int] = 10 1103 self.assertEqual(C.x, 10) 1104 1105 @dataclass 1106 class C: 1107 x: InitVar[int] = field(default=10) 1108 self.assertEqual(C.x, 10) 1109 1110 def test_init_var(self): 1111 @dataclass 1112 class C: 1113 x: int = None 1114 init_param: InitVar[int] = None 1115 1116 def __post_init__(self, init_param): 1117 if self.x is None: 1118 self.x = init_param*2 1119 1120 c = C(init_param=10) 1121 self.assertEqual(c.x, 20) 1122 1123 def test_init_var_preserve_type(self): 1124 self.assertEqual(InitVar[int].type, int) 1125 1126 # Make sure the repr is correct. 1127 self.assertEqual(repr(InitVar[int]), 'dataclasses.InitVar[int]') 1128 self.assertEqual(repr(InitVar[List[int]]), 1129 'dataclasses.InitVar[typing.List[int]]') 1130 self.assertEqual(repr(InitVar[list[int]]), 1131 'dataclasses.InitVar[list[int]]') 1132 self.assertEqual(repr(InitVar[int|str]), 1133 'dataclasses.InitVar[int | str]') 1134 1135 def test_init_var_inheritance(self): 1136 # Note that this deliberately tests that a dataclass need not 1137 # have a __post_init__ function if it has an InitVar field. 1138 # It could just be used in a derived class, as shown here. 1139 @dataclass 1140 class Base: 1141 x: int 1142 init_base: InitVar[int] 1143 1144 # We can instantiate by passing the InitVar, even though 1145 # it's not used. 1146 b = Base(0, 10) 1147 self.assertEqual(vars(b), {'x': 0}) 1148 1149 @dataclass 1150 class C(Base): 1151 y: int 1152 init_derived: InitVar[int] 1153 1154 def __post_init__(self, init_base, init_derived): 1155 self.x = self.x + init_base 1156 self.y = self.y + init_derived 1157 1158 c = C(10, 11, 50, 51) 1159 self.assertEqual(vars(c), {'x': 21, 'y': 101}) 1160 1161 def test_default_factory(self): 1162 # Test a factory that returns a new list. 1163 @dataclass 1164 class C: 1165 x: int 1166 y: list = field(default_factory=list) 1167 1168 c0 = C(3) 1169 c1 = C(3) 1170 self.assertEqual(c0.x, 3) 1171 self.assertEqual(c0.y, []) 1172 self.assertEqual(c0, c1) 1173 self.assertIsNot(c0.y, c1.y) 1174 self.assertEqual(astuple(C(5, [1])), (5, [1])) 1175 1176 # Test a factory that returns a shared list. 1177 l = [] 1178 @dataclass 1179 class C: 1180 x: int 1181 y: list = field(default_factory=lambda: l) 1182 1183 c0 = C(3) 1184 c1 = C(3) 1185 self.assertEqual(c0.x, 3) 1186 self.assertEqual(c0.y, []) 1187 self.assertEqual(c0, c1) 1188 self.assertIs(c0.y, c1.y) 1189 self.assertEqual(astuple(C(5, [1])), (5, [1])) 1190 1191 # Test various other field flags. 1192 # repr 1193 @dataclass 1194 class C: 1195 x: list = field(default_factory=list, repr=False) 1196 self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()') 1197 self.assertEqual(C().x, []) 1198 1199 # hash 1200 @dataclass(unsafe_hash=True) 1201 class C: 1202 x: list = field(default_factory=list, hash=False) 1203 self.assertEqual(astuple(C()), ([],)) 1204 self.assertEqual(hash(C()), hash(())) 1205 1206 # init (see also test_default_factory_with_no_init) 1207 @dataclass 1208 class C: 1209 x: list = field(default_factory=list, init=False) 1210 self.assertEqual(astuple(C()), ([],)) 1211 1212 # compare 1213 @dataclass 1214 class C: 1215 x: list = field(default_factory=list, compare=False) 1216 self.assertEqual(C(), C([1])) 1217 1218 def test_default_factory_with_no_init(self): 1219 # We need a factory with a side effect. 1220 factory = Mock() 1221 1222 @dataclass 1223 class C: 1224 x: list = field(default_factory=factory, init=False) 1225 1226 # Make sure the default factory is called for each new instance. 1227 C().x 1228 self.assertEqual(factory.call_count, 1) 1229 C().x 1230 self.assertEqual(factory.call_count, 2) 1231 1232 def test_default_factory_not_called_if_value_given(self): 1233 # We need a factory that we can test if it's been called. 1234 factory = Mock() 1235 1236 @dataclass 1237 class C: 1238 x: int = field(default_factory=factory) 1239 1240 # Make sure that if a field has a default factory function, 1241 # it's not called if a value is specified. 1242 C().x 1243 self.assertEqual(factory.call_count, 1) 1244 self.assertEqual(C(10).x, 10) 1245 self.assertEqual(factory.call_count, 1) 1246 C().x 1247 self.assertEqual(factory.call_count, 2) 1248 1249 def test_default_factory_derived(self): 1250 # See bpo-32896. 1251 @dataclass 1252 class Foo: 1253 x: dict = field(default_factory=dict) 1254 1255 @dataclass 1256 class Bar(Foo): 1257 y: int = 1 1258 1259 self.assertEqual(Foo().x, {}) 1260 self.assertEqual(Bar().x, {}) 1261 self.assertEqual(Bar().y, 1) 1262 1263 @dataclass 1264 class Baz(Foo): 1265 pass 1266 self.assertEqual(Baz().x, {}) 1267 1268 def test_intermediate_non_dataclass(self): 1269 # Test that an intermediate class that defines 1270 # annotations does not define fields. 1271 1272 @dataclass 1273 class A: 1274 x: int 1275 1276 class B(A): 1277 y: int 1278 1279 @dataclass 1280 class C(B): 1281 z: int 1282 1283 c = C(1, 3) 1284 self.assertEqual((c.x, c.z), (1, 3)) 1285 1286 # .y was not initialized. 1287 with self.assertRaisesRegex(AttributeError, 1288 'object has no attribute'): 1289 c.y 1290 1291 # And if we again derive a non-dataclass, no fields are added. 1292 class D(C): 1293 t: int 1294 d = D(4, 5) 1295 self.assertEqual((d.x, d.z), (4, 5)) 1296 1297 def test_classvar_default_factory(self): 1298 # It's an error for a ClassVar to have a factory function. 1299 with self.assertRaisesRegex(TypeError, 1300 'cannot have a default factory'): 1301 @dataclass 1302 class C: 1303 x: ClassVar[int] = field(default_factory=int) 1304 1305 def test_is_dataclass(self): 1306 class NotDataClass: 1307 pass 1308 1309 self.assertFalse(is_dataclass(0)) 1310 self.assertFalse(is_dataclass(int)) 1311 self.assertFalse(is_dataclass(NotDataClass)) 1312 self.assertFalse(is_dataclass(NotDataClass())) 1313 1314 @dataclass 1315 class C: 1316 x: int 1317 1318 @dataclass 1319 class D: 1320 d: C 1321 e: int 1322 1323 c = C(10) 1324 d = D(c, 4) 1325 1326 self.assertTrue(is_dataclass(C)) 1327 self.assertTrue(is_dataclass(c)) 1328 self.assertFalse(is_dataclass(c.x)) 1329 self.assertTrue(is_dataclass(d.d)) 1330 self.assertFalse(is_dataclass(d.e)) 1331 1332 def test_is_dataclass_when_getattr_always_returns(self): 1333 # See bpo-37868. 1334 class A: 1335 def __getattr__(self, key): 1336 return 0 1337 self.assertFalse(is_dataclass(A)) 1338 a = A() 1339 1340 # Also test for an instance attribute. 1341 class B: 1342 pass 1343 b = B() 1344 b.__dataclass_fields__ = [] 1345 1346 for obj in a, b: 1347 with self.subTest(obj=obj): 1348 self.assertFalse(is_dataclass(obj)) 1349 1350 # Indirect tests for _is_dataclass_instance(). 1351 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'): 1352 asdict(obj) 1353 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'): 1354 astuple(obj) 1355 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'): 1356 replace(obj, x=0) 1357 1358 def test_is_dataclass_genericalias(self): 1359 @dataclass 1360 class A(types.GenericAlias): 1361 origin: type 1362 args: type 1363 self.assertTrue(is_dataclass(A)) 1364 a = A(list, int) 1365 self.assertTrue(is_dataclass(type(a))) 1366 self.assertTrue(is_dataclass(a)) 1367 1368 1369 def test_helper_fields_with_class_instance(self): 1370 # Check that we can call fields() on either a class or instance, 1371 # and get back the same thing. 1372 @dataclass 1373 class C: 1374 x: int 1375 y: float 1376 1377 self.assertEqual(fields(C), fields(C(0, 0.0))) 1378 1379 def test_helper_fields_exception(self): 1380 # Check that TypeError is raised if not passed a dataclass or 1381 # instance. 1382 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'): 1383 fields(0) 1384 1385 class C: pass 1386 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'): 1387 fields(C) 1388 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'): 1389 fields(C()) 1390 1391 def test_helper_asdict(self): 1392 # Basic tests for asdict(), it should return a new dictionary. 1393 @dataclass 1394 class C: 1395 x: int 1396 y: int 1397 c = C(1, 2) 1398 1399 self.assertEqual(asdict(c), {'x': 1, 'y': 2}) 1400 self.assertEqual(asdict(c), asdict(c)) 1401 self.assertIsNot(asdict(c), asdict(c)) 1402 c.x = 42 1403 self.assertEqual(asdict(c), {'x': 42, 'y': 2}) 1404 self.assertIs(type(asdict(c)), dict) 1405 1406 def test_helper_asdict_raises_on_classes(self): 1407 # asdict() should raise on a class object. 1408 @dataclass 1409 class C: 1410 x: int 1411 y: int 1412 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 1413 asdict(C) 1414 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 1415 asdict(int) 1416 1417 def test_helper_asdict_copy_values(self): 1418 @dataclass 1419 class C: 1420 x: int 1421 y: List[int] = field(default_factory=list) 1422 initial = [] 1423 c = C(1, initial) 1424 d = asdict(c) 1425 self.assertEqual(d['y'], initial) 1426 self.assertIsNot(d['y'], initial) 1427 c = C(1) 1428 d = asdict(c) 1429 d['y'].append(1) 1430 self.assertEqual(c.y, []) 1431 1432 def test_helper_asdict_nested(self): 1433 @dataclass 1434 class UserId: 1435 token: int 1436 group: int 1437 @dataclass 1438 class User: 1439 name: str 1440 id: UserId 1441 u = User('Joe', UserId(123, 1)) 1442 d = asdict(u) 1443 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}}) 1444 self.assertIsNot(asdict(u), asdict(u)) 1445 u.id.group = 2 1446 self.assertEqual(asdict(u), {'name': 'Joe', 1447 'id': {'token': 123, 'group': 2}}) 1448 1449 def test_helper_asdict_builtin_containers(self): 1450 @dataclass 1451 class User: 1452 name: str 1453 id: int 1454 @dataclass 1455 class GroupList: 1456 id: int 1457 users: List[User] 1458 @dataclass 1459 class GroupTuple: 1460 id: int 1461 users: Tuple[User, ...] 1462 @dataclass 1463 class GroupDict: 1464 id: int 1465 users: Dict[str, User] 1466 a = User('Alice', 1) 1467 b = User('Bob', 2) 1468 gl = GroupList(0, [a, b]) 1469 gt = GroupTuple(0, (a, b)) 1470 gd = GroupDict(0, {'first': a, 'second': b}) 1471 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1}, 1472 {'name': 'Bob', 'id': 2}]}) 1473 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1}, 1474 {'name': 'Bob', 'id': 2})}) 1475 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1}, 1476 'second': {'name': 'Bob', 'id': 2}}}) 1477 1478 def test_helper_asdict_builtin_object_containers(self): 1479 @dataclass 1480 class Child: 1481 d: object 1482 1483 @dataclass 1484 class Parent: 1485 child: Child 1486 1487 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}}) 1488 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}}) 1489 1490 def test_helper_asdict_factory(self): 1491 @dataclass 1492 class C: 1493 x: int 1494 y: int 1495 c = C(1, 2) 1496 d = asdict(c, dict_factory=OrderedDict) 1497 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)])) 1498 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict)) 1499 c.x = 42 1500 d = asdict(c, dict_factory=OrderedDict) 1501 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)])) 1502 self.assertIs(type(d), OrderedDict) 1503 1504 def test_helper_asdict_namedtuple(self): 1505 T = namedtuple('T', 'a b c') 1506 @dataclass 1507 class C: 1508 x: str 1509 y: T 1510 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2)) 1511 1512 d = asdict(c) 1513 self.assertEqual(d, {'x': 'outer', 1514 'y': T(1, 1515 {'x': 'inner', 1516 'y': T(11, 12, 13)}, 1517 2), 1518 } 1519 ) 1520 1521 # Now with a dict_factory. OrderedDict is convenient, but 1522 # since it compares to dicts, we also need to have separate 1523 # assertIs tests. 1524 d = asdict(c, dict_factory=OrderedDict) 1525 self.assertEqual(d, {'x': 'outer', 1526 'y': T(1, 1527 {'x': 'inner', 1528 'y': T(11, 12, 13)}, 1529 2), 1530 } 1531 ) 1532 1533 # Make sure that the returned dicts are actually OrderedDicts. 1534 self.assertIs(type(d), OrderedDict) 1535 self.assertIs(type(d['y'][1]), OrderedDict) 1536 1537 def test_helper_asdict_namedtuple_key(self): 1538 # Ensure that a field that contains a dict which has a 1539 # namedtuple as a key works with asdict(). 1540 1541 @dataclass 1542 class C: 1543 f: dict 1544 T = namedtuple('T', 'a') 1545 1546 c = C({T('an a'): 0}) 1547 1548 self.assertEqual(asdict(c), {'f': {T(a='an a'): 0}}) 1549 1550 def test_helper_asdict_namedtuple_derived(self): 1551 class T(namedtuple('Tbase', 'a')): 1552 def my_a(self): 1553 return self.a 1554 1555 @dataclass 1556 class C: 1557 f: T 1558 1559 t = T(6) 1560 c = C(t) 1561 1562 d = asdict(c) 1563 self.assertEqual(d, {'f': T(a=6)}) 1564 # Make sure that t has been copied, not used directly. 1565 self.assertIsNot(d['f'], t) 1566 self.assertEqual(d['f'].my_a(), 6) 1567 1568 def test_helper_astuple(self): 1569 # Basic tests for astuple(), it should return a new tuple. 1570 @dataclass 1571 class C: 1572 x: int 1573 y: int = 0 1574 c = C(1) 1575 1576 self.assertEqual(astuple(c), (1, 0)) 1577 self.assertEqual(astuple(c), astuple(c)) 1578 self.assertIsNot(astuple(c), astuple(c)) 1579 c.y = 42 1580 self.assertEqual(astuple(c), (1, 42)) 1581 self.assertIs(type(astuple(c)), tuple) 1582 1583 def test_helper_astuple_raises_on_classes(self): 1584 # astuple() should raise on a class object. 1585 @dataclass 1586 class C: 1587 x: int 1588 y: int 1589 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 1590 astuple(C) 1591 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 1592 astuple(int) 1593 1594 def test_helper_astuple_copy_values(self): 1595 @dataclass 1596 class C: 1597 x: int 1598 y: List[int] = field(default_factory=list) 1599 initial = [] 1600 c = C(1, initial) 1601 t = astuple(c) 1602 self.assertEqual(t[1], initial) 1603 self.assertIsNot(t[1], initial) 1604 c = C(1) 1605 t = astuple(c) 1606 t[1].append(1) 1607 self.assertEqual(c.y, []) 1608 1609 def test_helper_astuple_nested(self): 1610 @dataclass 1611 class UserId: 1612 token: int 1613 group: int 1614 @dataclass 1615 class User: 1616 name: str 1617 id: UserId 1618 u = User('Joe', UserId(123, 1)) 1619 t = astuple(u) 1620 self.assertEqual(t, ('Joe', (123, 1))) 1621 self.assertIsNot(astuple(u), astuple(u)) 1622 u.id.group = 2 1623 self.assertEqual(astuple(u), ('Joe', (123, 2))) 1624 1625 def test_helper_astuple_builtin_containers(self): 1626 @dataclass 1627 class User: 1628 name: str 1629 id: int 1630 @dataclass 1631 class GroupList: 1632 id: int 1633 users: List[User] 1634 @dataclass 1635 class GroupTuple: 1636 id: int 1637 users: Tuple[User, ...] 1638 @dataclass 1639 class GroupDict: 1640 id: int 1641 users: Dict[str, User] 1642 a = User('Alice', 1) 1643 b = User('Bob', 2) 1644 gl = GroupList(0, [a, b]) 1645 gt = GroupTuple(0, (a, b)) 1646 gd = GroupDict(0, {'first': a, 'second': b}) 1647 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)])) 1648 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2)))) 1649 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)})) 1650 1651 def test_helper_astuple_builtin_object_containers(self): 1652 @dataclass 1653 class Child: 1654 d: object 1655 1656 @dataclass 1657 class Parent: 1658 child: Child 1659 1660 self.assertEqual(astuple(Parent(Child([1]))), (([1],),)) 1661 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),)) 1662 1663 def test_helper_astuple_factory(self): 1664 @dataclass 1665 class C: 1666 x: int 1667 y: int 1668 NT = namedtuple('NT', 'x y') 1669 def nt(lst): 1670 return NT(*lst) 1671 c = C(1, 2) 1672 t = astuple(c, tuple_factory=nt) 1673 self.assertEqual(t, NT(1, 2)) 1674 self.assertIsNot(t, astuple(c, tuple_factory=nt)) 1675 c.x = 42 1676 t = astuple(c, tuple_factory=nt) 1677 self.assertEqual(t, NT(42, 2)) 1678 self.assertIs(type(t), NT) 1679 1680 def test_helper_astuple_namedtuple(self): 1681 T = namedtuple('T', 'a b c') 1682 @dataclass 1683 class C: 1684 x: str 1685 y: T 1686 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2)) 1687 1688 t = astuple(c) 1689 self.assertEqual(t, ('outer', T(1, ('inner', (11, 12, 13)), 2))) 1690 1691 # Now, using a tuple_factory. list is convenient here. 1692 t = astuple(c, tuple_factory=list) 1693 self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)]) 1694 1695 def test_dynamic_class_creation(self): 1696 cls_dict = {'__annotations__': {'x': int, 'y': int}, 1697 } 1698 1699 # Create the class. 1700 cls = type('C', (), cls_dict) 1701 1702 # Make it a dataclass. 1703 cls1 = dataclass(cls) 1704 1705 self.assertEqual(cls1, cls) 1706 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2}) 1707 1708 def test_dynamic_class_creation_using_field(self): 1709 cls_dict = {'__annotations__': {'x': int, 'y': int}, 1710 'y': field(default=5), 1711 } 1712 1713 # Create the class. 1714 cls = type('C', (), cls_dict) 1715 1716 # Make it a dataclass. 1717 cls1 = dataclass(cls) 1718 1719 self.assertEqual(cls1, cls) 1720 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5}) 1721 1722 def test_init_in_order(self): 1723 @dataclass 1724 class C: 1725 a: int 1726 b: int = field() 1727 c: list = field(default_factory=list, init=False) 1728 d: list = field(default_factory=list) 1729 e: int = field(default=4, init=False) 1730 f: int = 4 1731 1732 calls = [] 1733 def setattr(self, name, value): 1734 calls.append((name, value)) 1735 1736 C.__setattr__ = setattr 1737 c = C(0, 1) 1738 self.assertEqual(('a', 0), calls[0]) 1739 self.assertEqual(('b', 1), calls[1]) 1740 self.assertEqual(('c', []), calls[2]) 1741 self.assertEqual(('d', []), calls[3]) 1742 self.assertNotIn(('e', 4), calls) 1743 self.assertEqual(('f', 4), calls[4]) 1744 1745 def test_items_in_dicts(self): 1746 @dataclass 1747 class C: 1748 a: int 1749 b: list = field(default_factory=list, init=False) 1750 c: list = field(default_factory=list) 1751 d: int = field(default=4, init=False) 1752 e: int = 0 1753 1754 c = C(0) 1755 # Class dict 1756 self.assertNotIn('a', C.__dict__) 1757 self.assertNotIn('b', C.__dict__) 1758 self.assertNotIn('c', C.__dict__) 1759 self.assertIn('d', C.__dict__) 1760 self.assertEqual(C.d, 4) 1761 self.assertIn('e', C.__dict__) 1762 self.assertEqual(C.e, 0) 1763 # Instance dict 1764 self.assertIn('a', c.__dict__) 1765 self.assertEqual(c.a, 0) 1766 self.assertIn('b', c.__dict__) 1767 self.assertEqual(c.b, []) 1768 self.assertIn('c', c.__dict__) 1769 self.assertEqual(c.c, []) 1770 self.assertNotIn('d', c.__dict__) 1771 self.assertIn('e', c.__dict__) 1772 self.assertEqual(c.e, 0) 1773 1774 def test_alternate_classmethod_constructor(self): 1775 # Since __post_init__ can't take params, use a classmethod 1776 # alternate constructor. This is mostly an example to show 1777 # how to use this technique. 1778 @dataclass 1779 class C: 1780 x: int 1781 @classmethod 1782 def from_file(cls, filename): 1783 # In a real example, create a new instance 1784 # and populate 'x' from contents of a file. 1785 value_in_file = 20 1786 return cls(value_in_file) 1787 1788 self.assertEqual(C.from_file('filename').x, 20) 1789 1790 def test_field_metadata_default(self): 1791 # Make sure the default metadata is read-only and of 1792 # zero length. 1793 @dataclass 1794 class C: 1795 i: int 1796 1797 self.assertFalse(fields(C)[0].metadata) 1798 self.assertEqual(len(fields(C)[0].metadata), 0) 1799 with self.assertRaisesRegex(TypeError, 1800 'does not support item assignment'): 1801 fields(C)[0].metadata['test'] = 3 1802 1803 def test_field_metadata_mapping(self): 1804 # Make sure only a mapping can be passed as metadata 1805 # zero length. 1806 with self.assertRaises(TypeError): 1807 @dataclass 1808 class C: 1809 i: int = field(metadata=0) 1810 1811 # Make sure an empty dict works. 1812 d = {} 1813 @dataclass 1814 class C: 1815 i: int = field(metadata=d) 1816 self.assertFalse(fields(C)[0].metadata) 1817 self.assertEqual(len(fields(C)[0].metadata), 0) 1818 # Update should work (see bpo-35960). 1819 d['foo'] = 1 1820 self.assertEqual(len(fields(C)[0].metadata), 1) 1821 self.assertEqual(fields(C)[0].metadata['foo'], 1) 1822 with self.assertRaisesRegex(TypeError, 1823 'does not support item assignment'): 1824 fields(C)[0].metadata['test'] = 3 1825 1826 # Make sure a non-empty dict works. 1827 d = {'test': 10, 'bar': '42', 3: 'three'} 1828 @dataclass 1829 class C: 1830 i: int = field(metadata=d) 1831 self.assertEqual(len(fields(C)[0].metadata), 3) 1832 self.assertEqual(fields(C)[0].metadata['test'], 10) 1833 self.assertEqual(fields(C)[0].metadata['bar'], '42') 1834 self.assertEqual(fields(C)[0].metadata[3], 'three') 1835 # Update should work. 1836 d['foo'] = 1 1837 self.assertEqual(len(fields(C)[0].metadata), 4) 1838 self.assertEqual(fields(C)[0].metadata['foo'], 1) 1839 with self.assertRaises(KeyError): 1840 # Non-existent key. 1841 fields(C)[0].metadata['baz'] 1842 with self.assertRaisesRegex(TypeError, 1843 'does not support item assignment'): 1844 fields(C)[0].metadata['test'] = 3 1845 1846 def test_field_metadata_custom_mapping(self): 1847 # Try a custom mapping. 1848 class SimpleNameSpace: 1849 def __init__(self, **kw): 1850 self.__dict__.update(kw) 1851 1852 def __getitem__(self, item): 1853 if item == 'xyzzy': 1854 return 'plugh' 1855 return getattr(self, item) 1856 1857 def __len__(self): 1858 return self.__dict__.__len__() 1859 1860 @dataclass 1861 class C: 1862 i: int = field(metadata=SimpleNameSpace(a=10)) 1863 1864 self.assertEqual(len(fields(C)[0].metadata), 1) 1865 self.assertEqual(fields(C)[0].metadata['a'], 10) 1866 with self.assertRaises(AttributeError): 1867 fields(C)[0].metadata['b'] 1868 # Make sure we're still talking to our custom mapping. 1869 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh') 1870 1871 def test_generic_dataclasses(self): 1872 T = TypeVar('T') 1873 1874 @dataclass 1875 class LabeledBox(Generic[T]): 1876 content: T 1877 label: str = '<unknown>' 1878 1879 box = LabeledBox(42) 1880 self.assertEqual(box.content, 42) 1881 self.assertEqual(box.label, '<unknown>') 1882 1883 # Subscripting the resulting class should work, etc. 1884 Alias = List[LabeledBox[int]] 1885 1886 def test_generic_extending(self): 1887 S = TypeVar('S') 1888 T = TypeVar('T') 1889 1890 @dataclass 1891 class Base(Generic[T, S]): 1892 x: T 1893 y: S 1894 1895 @dataclass 1896 class DataDerived(Base[int, T]): 1897 new_field: str 1898 Alias = DataDerived[str] 1899 c = Alias(0, 'test1', 'test2') 1900 self.assertEqual(astuple(c), (0, 'test1', 'test2')) 1901 1902 class NonDataDerived(Base[int, T]): 1903 def new_method(self): 1904 return self.y 1905 Alias = NonDataDerived[float] 1906 c = Alias(10, 1.0) 1907 self.assertEqual(c.new_method(), 1.0) 1908 1909 def test_generic_dynamic(self): 1910 T = TypeVar('T') 1911 1912 @dataclass 1913 class Parent(Generic[T]): 1914 x: T 1915 Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)], 1916 bases=(Parent[int], Generic[T]), namespace={'other': 42}) 1917 self.assertIs(Child[int](1, 2).z, None) 1918 self.assertEqual(Child[int](1, 2, 3).z, 3) 1919 self.assertEqual(Child[int](1, 2, 3).other, 42) 1920 # Check that type aliases work correctly. 1921 Alias = Child[T] 1922 self.assertEqual(Alias[int](1, 2).x, 1) 1923 # Check MRO resolution. 1924 self.assertEqual(Child.__mro__, (Child, Parent, Generic, object)) 1925 1926 def test_dataclasses_pickleable(self): 1927 global P, Q, R 1928 @dataclass 1929 class P: 1930 x: int 1931 y: int = 0 1932 @dataclass 1933 class Q: 1934 x: int 1935 y: int = field(default=0, init=False) 1936 @dataclass 1937 class R: 1938 x: int 1939 y: List[int] = field(default_factory=list) 1940 q = Q(1) 1941 q.y = 2 1942 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])] 1943 for sample in samples: 1944 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 1945 with self.subTest(sample=sample, proto=proto): 1946 new_sample = pickle.loads(pickle.dumps(sample, proto)) 1947 self.assertEqual(sample.x, new_sample.x) 1948 self.assertEqual(sample.y, new_sample.y) 1949 self.assertIsNot(sample, new_sample) 1950 new_sample.x = 42 1951 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto)) 1952 self.assertEqual(new_sample.x, another_new_sample.x) 1953 self.assertEqual(sample.y, another_new_sample.y) 1954 1955 def test_dataclasses_qualnames(self): 1956 @dataclass(order=True, unsafe_hash=True, frozen=True) 1957 class A: 1958 x: int 1959 y: int 1960 1961 self.assertEqual(A.__init__.__name__, "__init__") 1962 for function in ( 1963 '__eq__', 1964 '__lt__', 1965 '__le__', 1966 '__gt__', 1967 '__ge__', 1968 '__hash__', 1969 '__init__', 1970 '__repr__', 1971 '__setattr__', 1972 '__delattr__', 1973 ): 1974 self.assertEqual(getattr(A, function).__qualname__, f"TestCase.test_dataclasses_qualnames.<locals>.A.{function}") 1975 1976 with self.assertRaisesRegex(TypeError, r"A\.__init__\(\) missing"): 1977 A() 1978 1979 1980class TestFieldNoAnnotation(unittest.TestCase): 1981 def test_field_without_annotation(self): 1982 with self.assertRaisesRegex(TypeError, 1983 "'f' is a field but has no type annotation"): 1984 @dataclass 1985 class C: 1986 f = field() 1987 1988 def test_field_without_annotation_but_annotation_in_base(self): 1989 @dataclass 1990 class B: 1991 f: int 1992 1993 with self.assertRaisesRegex(TypeError, 1994 "'f' is a field but has no type annotation"): 1995 # This is still an error: make sure we don't pick up the 1996 # type annotation in the base class. 1997 @dataclass 1998 class C(B): 1999 f = field() 2000 2001 def test_field_without_annotation_but_annotation_in_base_not_dataclass(self): 2002 # Same test, but with the base class not a dataclass. 2003 class B: 2004 f: int 2005 2006 with self.assertRaisesRegex(TypeError, 2007 "'f' is a field but has no type annotation"): 2008 # This is still an error: make sure we don't pick up the 2009 # type annotation in the base class. 2010 @dataclass 2011 class C(B): 2012 f = field() 2013 2014 2015class TestDocString(unittest.TestCase): 2016 def assertDocStrEqual(self, a, b): 2017 # Because 3.6 and 3.7 differ in how inspect.signature work 2018 # (see bpo #32108), for the time being just compare them with 2019 # whitespace stripped. 2020 self.assertEqual(a.replace(' ', ''), b.replace(' ', '')) 2021 2022 def test_existing_docstring_not_overridden(self): 2023 @dataclass 2024 class C: 2025 """Lorem ipsum""" 2026 x: int 2027 2028 self.assertEqual(C.__doc__, "Lorem ipsum") 2029 2030 def test_docstring_no_fields(self): 2031 @dataclass 2032 class C: 2033 pass 2034 2035 self.assertDocStrEqual(C.__doc__, "C()") 2036 2037 def test_docstring_one_field(self): 2038 @dataclass 2039 class C: 2040 x: int 2041 2042 self.assertDocStrEqual(C.__doc__, "C(x:int)") 2043 2044 def test_docstring_two_fields(self): 2045 @dataclass 2046 class C: 2047 x: int 2048 y: int 2049 2050 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)") 2051 2052 def test_docstring_three_fields(self): 2053 @dataclass 2054 class C: 2055 x: int 2056 y: int 2057 z: str 2058 2059 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)") 2060 2061 def test_docstring_one_field_with_default(self): 2062 @dataclass 2063 class C: 2064 x: int = 3 2065 2066 self.assertDocStrEqual(C.__doc__, "C(x:int=3)") 2067 2068 def test_docstring_one_field_with_default_none(self): 2069 @dataclass 2070 class C: 2071 x: Union[int, type(None)] = None 2072 2073 self.assertDocStrEqual(C.__doc__, "C(x:Optional[int]=None)") 2074 2075 def test_docstring_list_field(self): 2076 @dataclass 2077 class C: 2078 x: List[int] 2079 2080 self.assertDocStrEqual(C.__doc__, "C(x:List[int])") 2081 2082 def test_docstring_list_field_with_default_factory(self): 2083 @dataclass 2084 class C: 2085 x: List[int] = field(default_factory=list) 2086 2087 self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)") 2088 2089 def test_docstring_deque_field(self): 2090 @dataclass 2091 class C: 2092 x: deque 2093 2094 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)") 2095 2096 def test_docstring_deque_field_with_default_factory(self): 2097 @dataclass 2098 class C: 2099 x: deque = field(default_factory=deque) 2100 2101 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)") 2102 2103 2104class TestInit(unittest.TestCase): 2105 def test_base_has_init(self): 2106 class B: 2107 def __init__(self): 2108 self.z = 100 2109 pass 2110 2111 # Make sure that declaring this class doesn't raise an error. 2112 # The issue is that we can't override __init__ in our class, 2113 # but it should be okay to add __init__ to us if our base has 2114 # an __init__. 2115 @dataclass 2116 class C(B): 2117 x: int = 0 2118 c = C(10) 2119 self.assertEqual(c.x, 10) 2120 self.assertNotIn('z', vars(c)) 2121 2122 # Make sure that if we don't add an init, the base __init__ 2123 # gets called. 2124 @dataclass(init=False) 2125 class C(B): 2126 x: int = 10 2127 c = C() 2128 self.assertEqual(c.x, 10) 2129 self.assertEqual(c.z, 100) 2130 2131 def test_no_init(self): 2132 dataclass(init=False) 2133 class C: 2134 i: int = 0 2135 self.assertEqual(C().i, 0) 2136 2137 dataclass(init=False) 2138 class C: 2139 i: int = 2 2140 def __init__(self): 2141 self.i = 3 2142 self.assertEqual(C().i, 3) 2143 2144 def test_overwriting_init(self): 2145 # If the class has __init__, use it no matter the value of 2146 # init=. 2147 2148 @dataclass 2149 class C: 2150 x: int 2151 def __init__(self, x): 2152 self.x = 2 * x 2153 self.assertEqual(C(3).x, 6) 2154 2155 @dataclass(init=True) 2156 class C: 2157 x: int 2158 def __init__(self, x): 2159 self.x = 2 * x 2160 self.assertEqual(C(4).x, 8) 2161 2162 @dataclass(init=False) 2163 class C: 2164 x: int 2165 def __init__(self, x): 2166 self.x = 2 * x 2167 self.assertEqual(C(5).x, 10) 2168 2169 def test_inherit_from_protocol(self): 2170 # Dataclasses inheriting from protocol should preserve their own `__init__`. 2171 # See bpo-45081. 2172 2173 class P(Protocol): 2174 a: int 2175 2176 @dataclass 2177 class C(P): 2178 a: int 2179 2180 self.assertEqual(C(5).a, 5) 2181 2182 @dataclass 2183 class D(P): 2184 def __init__(self, a): 2185 self.a = a * 2 2186 2187 self.assertEqual(D(5).a, 10) 2188 2189 2190class TestRepr(unittest.TestCase): 2191 def test_repr(self): 2192 @dataclass 2193 class B: 2194 x: int 2195 2196 @dataclass 2197 class C(B): 2198 y: int = 10 2199 2200 o = C(4) 2201 self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)') 2202 2203 @dataclass 2204 class D(C): 2205 x: int = 20 2206 self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)') 2207 2208 @dataclass 2209 class C: 2210 @dataclass 2211 class D: 2212 i: int 2213 @dataclass 2214 class E: 2215 pass 2216 self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)') 2217 self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()') 2218 2219 def test_no_repr(self): 2220 # Test a class with no __repr__ and repr=False. 2221 @dataclass(repr=False) 2222 class C: 2223 x: int 2224 self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at', 2225 repr(C(3))) 2226 2227 # Test a class with a __repr__ and repr=False. 2228 @dataclass(repr=False) 2229 class C: 2230 x: int 2231 def __repr__(self): 2232 return 'C-class' 2233 self.assertEqual(repr(C(3)), 'C-class') 2234 2235 def test_overwriting_repr(self): 2236 # If the class has __repr__, use it no matter the value of 2237 # repr=. 2238 2239 @dataclass 2240 class C: 2241 x: int 2242 def __repr__(self): 2243 return 'x' 2244 self.assertEqual(repr(C(0)), 'x') 2245 2246 @dataclass(repr=True) 2247 class C: 2248 x: int 2249 def __repr__(self): 2250 return 'x' 2251 self.assertEqual(repr(C(0)), 'x') 2252 2253 @dataclass(repr=False) 2254 class C: 2255 x: int 2256 def __repr__(self): 2257 return 'x' 2258 self.assertEqual(repr(C(0)), 'x') 2259 2260 2261class TestEq(unittest.TestCase): 2262 def test_no_eq(self): 2263 # Test a class with no __eq__ and eq=False. 2264 @dataclass(eq=False) 2265 class C: 2266 x: int 2267 self.assertNotEqual(C(0), C(0)) 2268 c = C(3) 2269 self.assertEqual(c, c) 2270 2271 # Test a class with an __eq__ and eq=False. 2272 @dataclass(eq=False) 2273 class C: 2274 x: int 2275 def __eq__(self, other): 2276 return other == 10 2277 self.assertEqual(C(3), 10) 2278 2279 def test_overwriting_eq(self): 2280 # If the class has __eq__, use it no matter the value of 2281 # eq=. 2282 2283 @dataclass 2284 class C: 2285 x: int 2286 def __eq__(self, other): 2287 return other == 3 2288 self.assertEqual(C(1), 3) 2289 self.assertNotEqual(C(1), 1) 2290 2291 @dataclass(eq=True) 2292 class C: 2293 x: int 2294 def __eq__(self, other): 2295 return other == 4 2296 self.assertEqual(C(1), 4) 2297 self.assertNotEqual(C(1), 1) 2298 2299 @dataclass(eq=False) 2300 class C: 2301 x: int 2302 def __eq__(self, other): 2303 return other == 5 2304 self.assertEqual(C(1), 5) 2305 self.assertNotEqual(C(1), 1) 2306 2307 2308class TestOrdering(unittest.TestCase): 2309 def test_functools_total_ordering(self): 2310 # Test that functools.total_ordering works with this class. 2311 @total_ordering 2312 @dataclass 2313 class C: 2314 x: int 2315 def __lt__(self, other): 2316 # Perform the test "backward", just to make 2317 # sure this is being called. 2318 return self.x >= other 2319 2320 self.assertLess(C(0), -1) 2321 self.assertLessEqual(C(0), -1) 2322 self.assertGreater(C(0), 1) 2323 self.assertGreaterEqual(C(0), 1) 2324 2325 def test_no_order(self): 2326 # Test that no ordering functions are added by default. 2327 @dataclass(order=False) 2328 class C: 2329 x: int 2330 # Make sure no order methods are added. 2331 self.assertNotIn('__le__', C.__dict__) 2332 self.assertNotIn('__lt__', C.__dict__) 2333 self.assertNotIn('__ge__', C.__dict__) 2334 self.assertNotIn('__gt__', C.__dict__) 2335 2336 # Test that __lt__ is still called 2337 @dataclass(order=False) 2338 class C: 2339 x: int 2340 def __lt__(self, other): 2341 return False 2342 # Make sure other methods aren't added. 2343 self.assertNotIn('__le__', C.__dict__) 2344 self.assertNotIn('__ge__', C.__dict__) 2345 self.assertNotIn('__gt__', C.__dict__) 2346 2347 def test_overwriting_order(self): 2348 with self.assertRaisesRegex(TypeError, 2349 'Cannot overwrite attribute __lt__' 2350 '.*using functools.total_ordering'): 2351 @dataclass(order=True) 2352 class C: 2353 x: int 2354 def __lt__(self): 2355 pass 2356 2357 with self.assertRaisesRegex(TypeError, 2358 'Cannot overwrite attribute __le__' 2359 '.*using functools.total_ordering'): 2360 @dataclass(order=True) 2361 class C: 2362 x: int 2363 def __le__(self): 2364 pass 2365 2366 with self.assertRaisesRegex(TypeError, 2367 'Cannot overwrite attribute __gt__' 2368 '.*using functools.total_ordering'): 2369 @dataclass(order=True) 2370 class C: 2371 x: int 2372 def __gt__(self): 2373 pass 2374 2375 with self.assertRaisesRegex(TypeError, 2376 'Cannot overwrite attribute __ge__' 2377 '.*using functools.total_ordering'): 2378 @dataclass(order=True) 2379 class C: 2380 x: int 2381 def __ge__(self): 2382 pass 2383 2384class TestHash(unittest.TestCase): 2385 def test_unsafe_hash(self): 2386 @dataclass(unsafe_hash=True) 2387 class C: 2388 x: int 2389 y: str 2390 self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo'))) 2391 2392 def test_hash_rules(self): 2393 def non_bool(value): 2394 # Map to something else that's True, but not a bool. 2395 if value is None: 2396 return None 2397 if value: 2398 return (3,) 2399 return 0 2400 2401 def test(case, unsafe_hash, eq, frozen, with_hash, result): 2402 with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq, 2403 frozen=frozen): 2404 if result != 'exception': 2405 if with_hash: 2406 @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen) 2407 class C: 2408 def __hash__(self): 2409 return 0 2410 else: 2411 @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen) 2412 class C: 2413 pass 2414 2415 # See if the result matches what's expected. 2416 if result == 'fn': 2417 # __hash__ contains the function we generated. 2418 self.assertIn('__hash__', C.__dict__) 2419 self.assertIsNotNone(C.__dict__['__hash__']) 2420 2421 elif result == '': 2422 # __hash__ is not present in our class. 2423 if not with_hash: 2424 self.assertNotIn('__hash__', C.__dict__) 2425 2426 elif result == 'none': 2427 # __hash__ is set to None. 2428 self.assertIn('__hash__', C.__dict__) 2429 self.assertIsNone(C.__dict__['__hash__']) 2430 2431 elif result == 'exception': 2432 # Creating the class should cause an exception. 2433 # This only happens with with_hash==True. 2434 assert(with_hash) 2435 with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'): 2436 @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen) 2437 class C: 2438 def __hash__(self): 2439 return 0 2440 2441 else: 2442 assert False, f'unknown result {result!r}' 2443 2444 # There are 8 cases of: 2445 # unsafe_hash=True/False 2446 # eq=True/False 2447 # frozen=True/False 2448 # And for each of these, a different result if 2449 # __hash__ is defined or not. 2450 for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate([ 2451 (False, False, False, '', ''), 2452 (False, False, True, '', ''), 2453 (False, True, False, 'none', ''), 2454 (False, True, True, 'fn', ''), 2455 (True, False, False, 'fn', 'exception'), 2456 (True, False, True, 'fn', 'exception'), 2457 (True, True, False, 'fn', 'exception'), 2458 (True, True, True, 'fn', 'exception'), 2459 ], 1): 2460 test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash) 2461 test(case, unsafe_hash, eq, frozen, True, res_defined_hash) 2462 2463 # Test non-bool truth values, too. This is just to 2464 # make sure the data-driven table in the decorator 2465 # handles non-bool values. 2466 test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash) 2467 test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True, res_defined_hash) 2468 2469 2470 def test_eq_only(self): 2471 # If a class defines __eq__, __hash__ is automatically added 2472 # and set to None. This is normal Python behavior, not 2473 # related to dataclasses. Make sure we don't interfere with 2474 # that (see bpo=32546). 2475 2476 @dataclass 2477 class C: 2478 i: int 2479 def __eq__(self, other): 2480 return self.i == other.i 2481 self.assertEqual(C(1), C(1)) 2482 self.assertNotEqual(C(1), C(4)) 2483 2484 # And make sure things work in this case if we specify 2485 # unsafe_hash=True. 2486 @dataclass(unsafe_hash=True) 2487 class C: 2488 i: int 2489 def __eq__(self, other): 2490 return self.i == other.i 2491 self.assertEqual(C(1), C(1.0)) 2492 self.assertEqual(hash(C(1)), hash(C(1.0))) 2493 2494 # And check that the classes __eq__ is being used, despite 2495 # specifying eq=True. 2496 @dataclass(unsafe_hash=True, eq=True) 2497 class C: 2498 i: int 2499 def __eq__(self, other): 2500 return self.i == 3 and self.i == other.i 2501 self.assertEqual(C(3), C(3)) 2502 self.assertNotEqual(C(1), C(1)) 2503 self.assertEqual(hash(C(1)), hash(C(1.0))) 2504 2505 def test_0_field_hash(self): 2506 @dataclass(frozen=True) 2507 class C: 2508 pass 2509 self.assertEqual(hash(C()), hash(())) 2510 2511 @dataclass(unsafe_hash=True) 2512 class C: 2513 pass 2514 self.assertEqual(hash(C()), hash(())) 2515 2516 def test_1_field_hash(self): 2517 @dataclass(frozen=True) 2518 class C: 2519 x: int 2520 self.assertEqual(hash(C(4)), hash((4,))) 2521 self.assertEqual(hash(C(42)), hash((42,))) 2522 2523 @dataclass(unsafe_hash=True) 2524 class C: 2525 x: int 2526 self.assertEqual(hash(C(4)), hash((4,))) 2527 self.assertEqual(hash(C(42)), hash((42,))) 2528 2529 def test_hash_no_args(self): 2530 # Test dataclasses with no hash= argument. This exists to 2531 # make sure that if the @dataclass parameter name is changed 2532 # or the non-default hashing behavior changes, the default 2533 # hashability keeps working the same way. 2534 2535 class Base: 2536 def __hash__(self): 2537 return 301 2538 2539 # If frozen or eq is None, then use the default value (do not 2540 # specify any value in the decorator). 2541 for frozen, eq, base, expected in [ 2542 (None, None, object, 'unhashable'), 2543 (None, None, Base, 'unhashable'), 2544 (None, False, object, 'object'), 2545 (None, False, Base, 'base'), 2546 (None, True, object, 'unhashable'), 2547 (None, True, Base, 'unhashable'), 2548 (False, None, object, 'unhashable'), 2549 (False, None, Base, 'unhashable'), 2550 (False, False, object, 'object'), 2551 (False, False, Base, 'base'), 2552 (False, True, object, 'unhashable'), 2553 (False, True, Base, 'unhashable'), 2554 (True, None, object, 'tuple'), 2555 (True, None, Base, 'tuple'), 2556 (True, False, object, 'object'), 2557 (True, False, Base, 'base'), 2558 (True, True, object, 'tuple'), 2559 (True, True, Base, 'tuple'), 2560 ]: 2561 2562 with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected): 2563 # First, create the class. 2564 if frozen is None and eq is None: 2565 @dataclass 2566 class C(base): 2567 i: int 2568 elif frozen is None: 2569 @dataclass(eq=eq) 2570 class C(base): 2571 i: int 2572 elif eq is None: 2573 @dataclass(frozen=frozen) 2574 class C(base): 2575 i: int 2576 else: 2577 @dataclass(frozen=frozen, eq=eq) 2578 class C(base): 2579 i: int 2580 2581 # Now, make sure it hashes as expected. 2582 if expected == 'unhashable': 2583 c = C(10) 2584 with self.assertRaisesRegex(TypeError, 'unhashable type'): 2585 hash(c) 2586 2587 elif expected == 'base': 2588 self.assertEqual(hash(C(10)), 301) 2589 2590 elif expected == 'object': 2591 # I'm not sure what test to use here. object's 2592 # hash isn't based on id(), so calling hash() 2593 # won't tell us much. So, just check the 2594 # function used is object's. 2595 self.assertIs(C.__hash__, object.__hash__) 2596 2597 elif expected == 'tuple': 2598 self.assertEqual(hash(C(42)), hash((42,))) 2599 2600 else: 2601 assert False, f'unknown value for expected={expected!r}' 2602 2603 2604class TestFrozen(unittest.TestCase): 2605 def test_frozen(self): 2606 @dataclass(frozen=True) 2607 class C: 2608 i: int 2609 2610 c = C(10) 2611 self.assertEqual(c.i, 10) 2612 with self.assertRaises(FrozenInstanceError): 2613 c.i = 5 2614 self.assertEqual(c.i, 10) 2615 2616 def test_inherit(self): 2617 @dataclass(frozen=True) 2618 class C: 2619 i: int 2620 2621 @dataclass(frozen=True) 2622 class D(C): 2623 j: int 2624 2625 d = D(0, 10) 2626 with self.assertRaises(FrozenInstanceError): 2627 d.i = 5 2628 with self.assertRaises(FrozenInstanceError): 2629 d.j = 6 2630 self.assertEqual(d.i, 0) 2631 self.assertEqual(d.j, 10) 2632 2633 def test_inherit_nonfrozen_from_empty_frozen(self): 2634 @dataclass(frozen=True) 2635 class C: 2636 pass 2637 2638 with self.assertRaisesRegex(TypeError, 2639 'cannot inherit non-frozen dataclass from a frozen one'): 2640 @dataclass 2641 class D(C): 2642 j: int 2643 2644 def test_inherit_nonfrozen_from_empty(self): 2645 @dataclass 2646 class C: 2647 pass 2648 2649 @dataclass 2650 class D(C): 2651 j: int 2652 2653 d = D(3) 2654 self.assertEqual(d.j, 3) 2655 self.assertIsInstance(d, C) 2656 2657 # Test both ways: with an intermediate normal (non-dataclass) 2658 # class and without an intermediate class. 2659 def test_inherit_nonfrozen_from_frozen(self): 2660 for intermediate_class in [True, False]: 2661 with self.subTest(intermediate_class=intermediate_class): 2662 @dataclass(frozen=True) 2663 class C: 2664 i: int 2665 2666 if intermediate_class: 2667 class I(C): pass 2668 else: 2669 I = C 2670 2671 with self.assertRaisesRegex(TypeError, 2672 'cannot inherit non-frozen dataclass from a frozen one'): 2673 @dataclass 2674 class D(I): 2675 pass 2676 2677 def test_inherit_frozen_from_nonfrozen(self): 2678 for intermediate_class in [True, False]: 2679 with self.subTest(intermediate_class=intermediate_class): 2680 @dataclass 2681 class C: 2682 i: int 2683 2684 if intermediate_class: 2685 class I(C): pass 2686 else: 2687 I = C 2688 2689 with self.assertRaisesRegex(TypeError, 2690 'cannot inherit frozen dataclass from a non-frozen one'): 2691 @dataclass(frozen=True) 2692 class D(I): 2693 pass 2694 2695 def test_inherit_from_normal_class(self): 2696 for intermediate_class in [True, False]: 2697 with self.subTest(intermediate_class=intermediate_class): 2698 class C: 2699 pass 2700 2701 if intermediate_class: 2702 class I(C): pass 2703 else: 2704 I = C 2705 2706 @dataclass(frozen=True) 2707 class D(I): 2708 i: int 2709 2710 d = D(10) 2711 with self.assertRaises(FrozenInstanceError): 2712 d.i = 5 2713 2714 def test_non_frozen_normal_derived(self): 2715 # See bpo-32953. 2716 2717 @dataclass(frozen=True) 2718 class D: 2719 x: int 2720 y: int = 10 2721 2722 class S(D): 2723 pass 2724 2725 s = S(3) 2726 self.assertEqual(s.x, 3) 2727 self.assertEqual(s.y, 10) 2728 s.cached = True 2729 2730 # But can't change the frozen attributes. 2731 with self.assertRaises(FrozenInstanceError): 2732 s.x = 5 2733 with self.assertRaises(FrozenInstanceError): 2734 s.y = 5 2735 self.assertEqual(s.x, 3) 2736 self.assertEqual(s.y, 10) 2737 self.assertEqual(s.cached, True) 2738 2739 def test_overwriting_frozen(self): 2740 # frozen uses __setattr__ and __delattr__. 2741 with self.assertRaisesRegex(TypeError, 2742 'Cannot overwrite attribute __setattr__'): 2743 @dataclass(frozen=True) 2744 class C: 2745 x: int 2746 def __setattr__(self): 2747 pass 2748 2749 with self.assertRaisesRegex(TypeError, 2750 'Cannot overwrite attribute __delattr__'): 2751 @dataclass(frozen=True) 2752 class C: 2753 x: int 2754 def __delattr__(self): 2755 pass 2756 2757 @dataclass(frozen=False) 2758 class C: 2759 x: int 2760 def __setattr__(self, name, value): 2761 self.__dict__['x'] = value * 2 2762 self.assertEqual(C(10).x, 20) 2763 2764 def test_frozen_hash(self): 2765 @dataclass(frozen=True) 2766 class C: 2767 x: Any 2768 2769 # If x is immutable, we can compute the hash. No exception is 2770 # raised. 2771 hash(C(3)) 2772 2773 # If x is mutable, computing the hash is an error. 2774 with self.assertRaisesRegex(TypeError, 'unhashable type'): 2775 hash(C({})) 2776 2777 2778class TestSlots(unittest.TestCase): 2779 def test_simple(self): 2780 @dataclass 2781 class C: 2782 __slots__ = ('x',) 2783 x: Any 2784 2785 # There was a bug where a variable in a slot was assumed to 2786 # also have a default value (of type 2787 # types.MemberDescriptorType). 2788 with self.assertRaisesRegex(TypeError, 2789 r"__init__\(\) missing 1 required positional argument: 'x'"): 2790 C() 2791 2792 # We can create an instance, and assign to x. 2793 c = C(10) 2794 self.assertEqual(c.x, 10) 2795 c.x = 5 2796 self.assertEqual(c.x, 5) 2797 2798 # We can't assign to anything else. 2799 with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"): 2800 c.y = 5 2801 2802 def test_derived_added_field(self): 2803 # See bpo-33100. 2804 @dataclass 2805 class Base: 2806 __slots__ = ('x',) 2807 x: Any 2808 2809 @dataclass 2810 class Derived(Base): 2811 x: int 2812 y: int 2813 2814 d = Derived(1, 2) 2815 self.assertEqual((d.x, d.y), (1, 2)) 2816 2817 # We can add a new field to the derived instance. 2818 d.z = 10 2819 2820 def test_generated_slots(self): 2821 @dataclass(slots=True) 2822 class C: 2823 x: int 2824 y: int 2825 2826 c = C(1, 2) 2827 self.assertEqual((c.x, c.y), (1, 2)) 2828 2829 c.x = 3 2830 c.y = 4 2831 self.assertEqual((c.x, c.y), (3, 4)) 2832 2833 with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'z'"): 2834 c.z = 5 2835 2836 def test_add_slots_when_slots_exists(self): 2837 with self.assertRaisesRegex(TypeError, '^C already specifies __slots__$'): 2838 @dataclass(slots=True) 2839 class C: 2840 __slots__ = ('x',) 2841 x: int 2842 2843 def test_generated_slots_value(self): 2844 @dataclass(slots=True) 2845 class Base: 2846 x: int 2847 2848 self.assertEqual(Base.__slots__, ('x',)) 2849 2850 @dataclass(slots=True) 2851 class Delivered(Base): 2852 y: int 2853 2854 self.assertEqual(Delivered.__slots__, ('x', 'y')) 2855 2856 @dataclass 2857 class AnotherDelivered(Base): 2858 z: int 2859 2860 self.assertTrue('__slots__' not in AnotherDelivered.__dict__) 2861 2862 def test_returns_new_class(self): 2863 class A: 2864 x: int 2865 2866 B = dataclass(A, slots=True) 2867 self.assertIsNot(A, B) 2868 2869 self.assertFalse(hasattr(A, "__slots__")) 2870 self.assertTrue(hasattr(B, "__slots__")) 2871 2872 # Can't be local to test_frozen_pickle. 2873 @dataclass(frozen=True, slots=True) 2874 class FrozenSlotsClass: 2875 foo: str 2876 bar: int 2877 2878 @dataclass(frozen=True) 2879 class FrozenWithoutSlotsClass: 2880 foo: str 2881 bar: int 2882 2883 def test_frozen_pickle(self): 2884 # bpo-43999 2885 2886 self.assertEqual(self.FrozenSlotsClass.__slots__, ("foo", "bar")) 2887 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 2888 with self.subTest(proto=proto): 2889 obj = self.FrozenSlotsClass("a", 1) 2890 p = pickle.loads(pickle.dumps(obj, protocol=proto)) 2891 self.assertIsNot(obj, p) 2892 self.assertEqual(obj, p) 2893 2894 obj = self.FrozenWithoutSlotsClass("a", 1) 2895 p = pickle.loads(pickle.dumps(obj, protocol=proto)) 2896 self.assertIsNot(obj, p) 2897 self.assertEqual(obj, p) 2898 2899 def test_slots_with_default_no_init(self): 2900 # Originally reported in bpo-44649. 2901 @dataclass(slots=True) 2902 class A: 2903 a: str 2904 b: str = field(default='b', init=False) 2905 2906 obj = A("a") 2907 self.assertEqual(obj.a, 'a') 2908 self.assertEqual(obj.b, 'b') 2909 2910 def test_slots_with_default_factory_no_init(self): 2911 # Originally reported in bpo-44649. 2912 @dataclass(slots=True) 2913 class A: 2914 a: str 2915 b: str = field(default_factory=lambda:'b', init=False) 2916 2917 obj = A("a") 2918 self.assertEqual(obj.a, 'a') 2919 self.assertEqual(obj.b, 'b') 2920 2921class TestDescriptors(unittest.TestCase): 2922 def test_set_name(self): 2923 # See bpo-33141. 2924 2925 # Create a descriptor. 2926 class D: 2927 def __set_name__(self, owner, name): 2928 self.name = name + 'x' 2929 def __get__(self, instance, owner): 2930 if instance is not None: 2931 return 1 2932 return self 2933 2934 # This is the case of just normal descriptor behavior, no 2935 # dataclass code is involved in initializing the descriptor. 2936 @dataclass 2937 class C: 2938 c: int=D() 2939 self.assertEqual(C.c.name, 'cx') 2940 2941 # Now test with a default value and init=False, which is the 2942 # only time this is really meaningful. If not using 2943 # init=False, then the descriptor will be overwritten, anyway. 2944 @dataclass 2945 class C: 2946 c: int=field(default=D(), init=False) 2947 self.assertEqual(C.c.name, 'cx') 2948 self.assertEqual(C().c, 1) 2949 2950 def test_non_descriptor(self): 2951 # PEP 487 says __set_name__ should work on non-descriptors. 2952 # Create a descriptor. 2953 2954 class D: 2955 def __set_name__(self, owner, name): 2956 self.name = name + 'x' 2957 2958 @dataclass 2959 class C: 2960 c: int=field(default=D(), init=False) 2961 self.assertEqual(C.c.name, 'cx') 2962 2963 def test_lookup_on_instance(self): 2964 # See bpo-33175. 2965 class D: 2966 pass 2967 2968 d = D() 2969 # Create an attribute on the instance, not type. 2970 d.__set_name__ = Mock() 2971 2972 # Make sure d.__set_name__ is not called. 2973 @dataclass 2974 class C: 2975 i: int=field(default=d, init=False) 2976 2977 self.assertEqual(d.__set_name__.call_count, 0) 2978 2979 def test_lookup_on_class(self): 2980 # See bpo-33175. 2981 class D: 2982 pass 2983 D.__set_name__ = Mock() 2984 2985 # Make sure D.__set_name__ is called. 2986 @dataclass 2987 class C: 2988 i: int=field(default=D(), init=False) 2989 2990 self.assertEqual(D.__set_name__.call_count, 1) 2991 2992 2993class TestStringAnnotations(unittest.TestCase): 2994 def test_classvar(self): 2995 # Some expressions recognized as ClassVar really aren't. But 2996 # if you're using string annotations, it's not an exact 2997 # science. 2998 # These tests assume that both "import typing" and "from 2999 # typing import *" have been run in this file. 3000 for typestr in ('ClassVar[int]', 3001 'ClassVar [int]', 3002 ' ClassVar [int]', 3003 'ClassVar', 3004 ' ClassVar ', 3005 'typing.ClassVar[int]', 3006 'typing.ClassVar[str]', 3007 ' typing.ClassVar[str]', 3008 'typing .ClassVar[str]', 3009 'typing. ClassVar[str]', 3010 'typing.ClassVar [str]', 3011 'typing.ClassVar [ str]', 3012 3013 # Not syntactically valid, but these will 3014 # be treated as ClassVars. 3015 'typing.ClassVar.[int]', 3016 'typing.ClassVar+', 3017 ): 3018 with self.subTest(typestr=typestr): 3019 @dataclass 3020 class C: 3021 x: typestr 3022 3023 # x is a ClassVar, so C() takes no args. 3024 C() 3025 3026 # And it won't appear in the class's dict because it doesn't 3027 # have a default. 3028 self.assertNotIn('x', C.__dict__) 3029 3030 def test_isnt_classvar(self): 3031 for typestr in ('CV', 3032 't.ClassVar', 3033 't.ClassVar[int]', 3034 'typing..ClassVar[int]', 3035 'Classvar', 3036 'Classvar[int]', 3037 'typing.ClassVarx[int]', 3038 'typong.ClassVar[int]', 3039 'dataclasses.ClassVar[int]', 3040 'typingxClassVar[str]', 3041 ): 3042 with self.subTest(typestr=typestr): 3043 @dataclass 3044 class C: 3045 x: typestr 3046 3047 # x is not a ClassVar, so C() takes one arg. 3048 self.assertEqual(C(10).x, 10) 3049 3050 def test_initvar(self): 3051 # These tests assume that both "import dataclasses" and "from 3052 # dataclasses import *" have been run in this file. 3053 for typestr in ('InitVar[int]', 3054 'InitVar [int]' 3055 ' InitVar [int]', 3056 'InitVar', 3057 ' InitVar ', 3058 'dataclasses.InitVar[int]', 3059 'dataclasses.InitVar[str]', 3060 ' dataclasses.InitVar[str]', 3061 'dataclasses .InitVar[str]', 3062 'dataclasses. InitVar[str]', 3063 'dataclasses.InitVar [str]', 3064 'dataclasses.InitVar [ str]', 3065 3066 # Not syntactically valid, but these will 3067 # be treated as InitVars. 3068 'dataclasses.InitVar.[int]', 3069 'dataclasses.InitVar+', 3070 ): 3071 with self.subTest(typestr=typestr): 3072 @dataclass 3073 class C: 3074 x: typestr 3075 3076 # x is an InitVar, so doesn't create a member. 3077 with self.assertRaisesRegex(AttributeError, 3078 "object has no attribute 'x'"): 3079 C(1).x 3080 3081 def test_isnt_initvar(self): 3082 for typestr in ('IV', 3083 'dc.InitVar', 3084 'xdataclasses.xInitVar', 3085 'typing.xInitVar[int]', 3086 ): 3087 with self.subTest(typestr=typestr): 3088 @dataclass 3089 class C: 3090 x: typestr 3091 3092 # x is not an InitVar, so there will be a member x. 3093 self.assertEqual(C(10).x, 10) 3094 3095 def test_classvar_module_level_import(self): 3096 from test import dataclass_module_1 3097 from test import dataclass_module_1_str 3098 from test import dataclass_module_2 3099 from test import dataclass_module_2_str 3100 3101 for m in (dataclass_module_1, dataclass_module_1_str, 3102 dataclass_module_2, dataclass_module_2_str, 3103 ): 3104 with self.subTest(m=m): 3105 # There's a difference in how the ClassVars are 3106 # interpreted when using string annotations or 3107 # not. See the imported modules for details. 3108 if m.USING_STRINGS: 3109 c = m.CV(10) 3110 else: 3111 c = m.CV() 3112 self.assertEqual(c.cv0, 20) 3113 3114 3115 # There's a difference in how the InitVars are 3116 # interpreted when using string annotations or 3117 # not. See the imported modules for details. 3118 c = m.IV(0, 1, 2, 3, 4) 3119 3120 for field_name in ('iv0', 'iv1', 'iv2', 'iv3'): 3121 with self.subTest(field_name=field_name): 3122 with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"): 3123 # Since field_name is an InitVar, it's 3124 # not an instance field. 3125 getattr(c, field_name) 3126 3127 if m.USING_STRINGS: 3128 # iv4 is interpreted as a normal field. 3129 self.assertIn('not_iv4', c.__dict__) 3130 self.assertEqual(c.not_iv4, 4) 3131 else: 3132 # iv4 is interpreted as an InitVar, so it 3133 # won't exist on the instance. 3134 self.assertNotIn('not_iv4', c.__dict__) 3135 3136 def test_text_annotations(self): 3137 from test import dataclass_textanno 3138 3139 self.assertEqual( 3140 get_type_hints(dataclass_textanno.Bar), 3141 {'foo': dataclass_textanno.Foo}) 3142 self.assertEqual( 3143 get_type_hints(dataclass_textanno.Bar.__init__), 3144 {'foo': dataclass_textanno.Foo, 3145 'return': type(None)}) 3146 3147 3148class TestMakeDataclass(unittest.TestCase): 3149 def test_simple(self): 3150 C = make_dataclass('C', 3151 [('x', int), 3152 ('y', int, field(default=5))], 3153 namespace={'add_one': lambda self: self.x + 1}) 3154 c = C(10) 3155 self.assertEqual((c.x, c.y), (10, 5)) 3156 self.assertEqual(c.add_one(), 11) 3157 3158 3159 def test_no_mutate_namespace(self): 3160 # Make sure a provided namespace isn't mutated. 3161 ns = {} 3162 C = make_dataclass('C', 3163 [('x', int), 3164 ('y', int, field(default=5))], 3165 namespace=ns) 3166 self.assertEqual(ns, {}) 3167 3168 def test_base(self): 3169 class Base1: 3170 pass 3171 class Base2: 3172 pass 3173 C = make_dataclass('C', 3174 [('x', int)], 3175 bases=(Base1, Base2)) 3176 c = C(2) 3177 self.assertIsInstance(c, C) 3178 self.assertIsInstance(c, Base1) 3179 self.assertIsInstance(c, Base2) 3180 3181 def test_base_dataclass(self): 3182 @dataclass 3183 class Base1: 3184 x: int 3185 class Base2: 3186 pass 3187 C = make_dataclass('C', 3188 [('y', int)], 3189 bases=(Base1, Base2)) 3190 with self.assertRaisesRegex(TypeError, 'required positional'): 3191 c = C(2) 3192 c = C(1, 2) 3193 self.assertIsInstance(c, C) 3194 self.assertIsInstance(c, Base1) 3195 self.assertIsInstance(c, Base2) 3196 3197 self.assertEqual((c.x, c.y), (1, 2)) 3198 3199 def test_init_var(self): 3200 def post_init(self, y): 3201 self.x *= y 3202 3203 C = make_dataclass('C', 3204 [('x', int), 3205 ('y', InitVar[int]), 3206 ], 3207 namespace={'__post_init__': post_init}, 3208 ) 3209 c = C(2, 3) 3210 self.assertEqual(vars(c), {'x': 6}) 3211 self.assertEqual(len(fields(c)), 1) 3212 3213 def test_class_var(self): 3214 C = make_dataclass('C', 3215 [('x', int), 3216 ('y', ClassVar[int], 10), 3217 ('z', ClassVar[int], field(default=20)), 3218 ]) 3219 c = C(1) 3220 self.assertEqual(vars(c), {'x': 1}) 3221 self.assertEqual(len(fields(c)), 1) 3222 self.assertEqual(C.y, 10) 3223 self.assertEqual(C.z, 20) 3224 3225 def test_other_params(self): 3226 C = make_dataclass('C', 3227 [('x', int), 3228 ('y', ClassVar[int], 10), 3229 ('z', ClassVar[int], field(default=20)), 3230 ], 3231 init=False) 3232 # Make sure we have a repr, but no init. 3233 self.assertNotIn('__init__', vars(C)) 3234 self.assertIn('__repr__', vars(C)) 3235 3236 # Make sure random other params don't work. 3237 with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'): 3238 C = make_dataclass('C', 3239 [], 3240 xxinit=False) 3241 3242 def test_no_types(self): 3243 C = make_dataclass('Point', ['x', 'y', 'z']) 3244 c = C(1, 2, 3) 3245 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3}) 3246 self.assertEqual(C.__annotations__, {'x': 'typing.Any', 3247 'y': 'typing.Any', 3248 'z': 'typing.Any'}) 3249 3250 C = make_dataclass('Point', ['x', ('y', int), 'z']) 3251 c = C(1, 2, 3) 3252 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3}) 3253 self.assertEqual(C.__annotations__, {'x': 'typing.Any', 3254 'y': int, 3255 'z': 'typing.Any'}) 3256 3257 def test_invalid_type_specification(self): 3258 for bad_field in [(), 3259 (1, 2, 3, 4), 3260 ]: 3261 with self.subTest(bad_field=bad_field): 3262 with self.assertRaisesRegex(TypeError, r'Invalid field: '): 3263 make_dataclass('C', ['a', bad_field]) 3264 3265 # And test for things with no len(). 3266 for bad_field in [float, 3267 lambda x:x, 3268 ]: 3269 with self.subTest(bad_field=bad_field): 3270 with self.assertRaisesRegex(TypeError, r'has no len\(\)'): 3271 make_dataclass('C', ['a', bad_field]) 3272 3273 def test_duplicate_field_names(self): 3274 for field in ['a', 'ab']: 3275 with self.subTest(field=field): 3276 with self.assertRaisesRegex(TypeError, 'Field name duplicated'): 3277 make_dataclass('C', [field, 'a', field]) 3278 3279 def test_keyword_field_names(self): 3280 for field in ['for', 'async', 'await', 'as']: 3281 with self.subTest(field=field): 3282 with self.assertRaisesRegex(TypeError, 'must not be keywords'): 3283 make_dataclass('C', ['a', field]) 3284 with self.assertRaisesRegex(TypeError, 'must not be keywords'): 3285 make_dataclass('C', [field]) 3286 with self.assertRaisesRegex(TypeError, 'must not be keywords'): 3287 make_dataclass('C', [field, 'a']) 3288 3289 def test_non_identifier_field_names(self): 3290 for field in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']: 3291 with self.subTest(field=field): 3292 with self.assertRaisesRegex(TypeError, 'must be valid identifiers'): 3293 make_dataclass('C', ['a', field]) 3294 with self.assertRaisesRegex(TypeError, 'must be valid identifiers'): 3295 make_dataclass('C', [field]) 3296 with self.assertRaisesRegex(TypeError, 'must be valid identifiers'): 3297 make_dataclass('C', [field, 'a']) 3298 3299 def test_underscore_field_names(self): 3300 # Unlike namedtuple, it's okay if dataclass field names have 3301 # an underscore. 3302 make_dataclass('C', ['_', '_a', 'a_a', 'a_']) 3303 3304 def test_funny_class_names_names(self): 3305 # No reason to prevent weird class names, since 3306 # types.new_class allows them. 3307 for classname in ['()', 'x,y', '*', '2@3', '']: 3308 with self.subTest(classname=classname): 3309 C = make_dataclass(classname, ['a', 'b']) 3310 self.assertEqual(C.__name__, classname) 3311 3312class TestReplace(unittest.TestCase): 3313 def test(self): 3314 @dataclass(frozen=True) 3315 class C: 3316 x: int 3317 y: int 3318 3319 c = C(1, 2) 3320 c1 = replace(c, x=3) 3321 self.assertEqual(c1.x, 3) 3322 self.assertEqual(c1.y, 2) 3323 3324 def test_frozen(self): 3325 @dataclass(frozen=True) 3326 class C: 3327 x: int 3328 y: int 3329 z: int = field(init=False, default=10) 3330 t: int = field(init=False, default=100) 3331 3332 c = C(1, 2) 3333 c1 = replace(c, x=3) 3334 self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100)) 3335 self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100)) 3336 3337 3338 with self.assertRaisesRegex(ValueError, 'init=False'): 3339 replace(c, x=3, z=20, t=50) 3340 with self.assertRaisesRegex(ValueError, 'init=False'): 3341 replace(c, z=20) 3342 replace(c, x=3, z=20, t=50) 3343 3344 # Make sure the result is still frozen. 3345 with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"): 3346 c1.x = 3 3347 3348 # Make sure we can't replace an attribute that doesn't exist, 3349 # if we're also replacing one that does exist. Test this 3350 # here, because setting attributes on frozen instances is 3351 # handled slightly differently from non-frozen ones. 3352 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected " 3353 "keyword argument 'a'"): 3354 c1 = replace(c, x=20, a=5) 3355 3356 def test_invalid_field_name(self): 3357 @dataclass(frozen=True) 3358 class C: 3359 x: int 3360 y: int 3361 3362 c = C(1, 2) 3363 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected " 3364 "keyword argument 'z'"): 3365 c1 = replace(c, z=3) 3366 3367 def test_invalid_object(self): 3368 @dataclass(frozen=True) 3369 class C: 3370 x: int 3371 y: int 3372 3373 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 3374 replace(C, x=3) 3375 3376 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 3377 replace(0, x=3) 3378 3379 def test_no_init(self): 3380 @dataclass 3381 class C: 3382 x: int 3383 y: int = field(init=False, default=10) 3384 3385 c = C(1) 3386 c.y = 20 3387 3388 # Make sure y gets the default value. 3389 c1 = replace(c, x=5) 3390 self.assertEqual((c1.x, c1.y), (5, 10)) 3391 3392 # Trying to replace y is an error. 3393 with self.assertRaisesRegex(ValueError, 'init=False'): 3394 replace(c, x=2, y=30) 3395 3396 with self.assertRaisesRegex(ValueError, 'init=False'): 3397 replace(c, y=30) 3398 3399 def test_classvar(self): 3400 @dataclass 3401 class C: 3402 x: int 3403 y: ClassVar[int] = 1000 3404 3405 c = C(1) 3406 d = C(2) 3407 3408 self.assertIs(c.y, d.y) 3409 self.assertEqual(c.y, 1000) 3410 3411 # Trying to replace y is an error: can't replace ClassVars. 3412 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an " 3413 "unexpected keyword argument 'y'"): 3414 replace(c, y=30) 3415 3416 replace(c, x=5) 3417 3418 def test_initvar_is_specified(self): 3419 @dataclass 3420 class C: 3421 x: int 3422 y: InitVar[int] 3423 3424 def __post_init__(self, y): 3425 self.x *= y 3426 3427 c = C(1, 10) 3428 self.assertEqual(c.x, 10) 3429 with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be " 3430 "specified with replace()"): 3431 replace(c, x=3) 3432 c = replace(c, x=3, y=5) 3433 self.assertEqual(c.x, 15) 3434 3435 def test_initvar_with_default_value(self): 3436 @dataclass 3437 class C: 3438 x: int 3439 y: InitVar[int] = None 3440 z: InitVar[int] = 42 3441 3442 def __post_init__(self, y, z): 3443 if y is not None: 3444 self.x += y 3445 if z is not None: 3446 self.x += z 3447 3448 c = C(x=1, y=10, z=1) 3449 self.assertEqual(replace(c), C(x=12)) 3450 self.assertEqual(replace(c, y=4), C(x=12, y=4, z=42)) 3451 self.assertEqual(replace(c, y=4, z=1), C(x=12, y=4, z=1)) 3452 3453 def test_recursive_repr(self): 3454 @dataclass 3455 class C: 3456 f: "C" 3457 3458 c = C(None) 3459 c.f = c 3460 self.assertEqual(repr(c), "TestReplace.test_recursive_repr.<locals>.C(f=...)") 3461 3462 def test_recursive_repr_two_attrs(self): 3463 @dataclass 3464 class C: 3465 f: "C" 3466 g: "C" 3467 3468 c = C(None, None) 3469 c.f = c 3470 c.g = c 3471 self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs" 3472 ".<locals>.C(f=..., g=...)") 3473 3474 def test_recursive_repr_indirection(self): 3475 @dataclass 3476 class C: 3477 f: "D" 3478 3479 @dataclass 3480 class D: 3481 f: "C" 3482 3483 c = C(None) 3484 d = D(None) 3485 c.f = d 3486 d.f = c 3487 self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection" 3488 ".<locals>.C(f=TestReplace.test_recursive_repr_indirection" 3489 ".<locals>.D(f=...))") 3490 3491 def test_recursive_repr_indirection_two(self): 3492 @dataclass 3493 class C: 3494 f: "D" 3495 3496 @dataclass 3497 class D: 3498 f: "E" 3499 3500 @dataclass 3501 class E: 3502 f: "C" 3503 3504 c = C(None) 3505 d = D(None) 3506 e = E(None) 3507 c.f = d 3508 d.f = e 3509 e.f = c 3510 self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection_two" 3511 ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two" 3512 ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two" 3513 ".<locals>.E(f=...)))") 3514 3515 def test_recursive_repr_misc_attrs(self): 3516 @dataclass 3517 class C: 3518 f: "C" 3519 g: int 3520 3521 c = C(None, 1) 3522 c.f = c 3523 self.assertEqual(repr(c), "TestReplace.test_recursive_repr_misc_attrs" 3524 ".<locals>.C(f=..., g=1)") 3525 3526 ## def test_initvar(self): 3527 ## @dataclass 3528 ## class C: 3529 ## x: int 3530 ## y: InitVar[int] 3531 3532 ## c = C(1, 10) 3533 ## d = C(2, 20) 3534 3535 ## # In our case, replacing an InitVar is a no-op 3536 ## self.assertEqual(c, replace(c, y=5)) 3537 3538 ## replace(c, x=5) 3539 3540class TestAbstract(unittest.TestCase): 3541 def test_abc_implementation(self): 3542 class Ordered(abc.ABC): 3543 @abc.abstractmethod 3544 def __lt__(self, other): 3545 pass 3546 3547 @abc.abstractmethod 3548 def __le__(self, other): 3549 pass 3550 3551 @dataclass(order=True) 3552 class Date(Ordered): 3553 year: int 3554 month: 'Month' 3555 day: 'int' 3556 3557 self.assertFalse(inspect.isabstract(Date)) 3558 self.assertGreater(Date(2020,12,25), Date(2020,8,31)) 3559 3560 def test_maintain_abc(self): 3561 class A(abc.ABC): 3562 @abc.abstractmethod 3563 def foo(self): 3564 pass 3565 3566 @dataclass 3567 class Date(A): 3568 year: int 3569 month: 'Month' 3570 day: 'int' 3571 3572 self.assertTrue(inspect.isabstract(Date)) 3573 msg = 'class Date with abstract method foo' 3574 self.assertRaisesRegex(TypeError, msg, Date) 3575 3576 3577class TestMatchArgs(unittest.TestCase): 3578 def test_match_args(self): 3579 @dataclass 3580 class C: 3581 a: int 3582 self.assertEqual(C(42).__match_args__, ('a',)) 3583 3584 def test_explicit_match_args(self): 3585 ma = () 3586 @dataclass 3587 class C: 3588 a: int 3589 __match_args__ = ma 3590 self.assertIs(C(42).__match_args__, ma) 3591 3592 def test_bpo_43764(self): 3593 @dataclass(repr=False, eq=False, init=False) 3594 class X: 3595 a: int 3596 b: int 3597 c: int 3598 self.assertEqual(X.__match_args__, ("a", "b", "c")) 3599 3600 def test_match_args_argument(self): 3601 @dataclass(match_args=False) 3602 class X: 3603 a: int 3604 self.assertNotIn('__match_args__', X.__dict__) 3605 3606 @dataclass(match_args=False) 3607 class Y: 3608 a: int 3609 __match_args__ = ('b',) 3610 self.assertEqual(Y.__match_args__, ('b',)) 3611 3612 @dataclass(match_args=False) 3613 class Z(Y): 3614 z: int 3615 self.assertEqual(Z.__match_args__, ('b',)) 3616 3617 # Ensure parent dataclass __match_args__ is seen, if child class 3618 # specifies match_args=False. 3619 @dataclass 3620 class A: 3621 a: int 3622 z: int 3623 @dataclass(match_args=False) 3624 class B(A): 3625 b: int 3626 self.assertEqual(B.__match_args__, ('a', 'z')) 3627 3628 def test_make_dataclasses(self): 3629 C = make_dataclass('C', [('x', int), ('y', int)]) 3630 self.assertEqual(C.__match_args__, ('x', 'y')) 3631 3632 C = make_dataclass('C', [('x', int), ('y', int)], match_args=True) 3633 self.assertEqual(C.__match_args__, ('x', 'y')) 3634 3635 C = make_dataclass('C', [('x', int), ('y', int)], match_args=False) 3636 self.assertNotIn('__match__args__', C.__dict__) 3637 3638 C = make_dataclass('C', [('x', int), ('y', int)], namespace={'__match_args__': ('z',)}) 3639 self.assertEqual(C.__match_args__, ('z',)) 3640 3641 3642class TestKeywordArgs(unittest.TestCase): 3643 def test_no_classvar_kwarg(self): 3644 msg = 'field a is a ClassVar but specifies kw_only' 3645 with self.assertRaisesRegex(TypeError, msg): 3646 @dataclass 3647 class A: 3648 a: ClassVar[int] = field(kw_only=True) 3649 3650 with self.assertRaisesRegex(TypeError, msg): 3651 @dataclass 3652 class A: 3653 a: ClassVar[int] = field(kw_only=False) 3654 3655 with self.assertRaisesRegex(TypeError, msg): 3656 @dataclass(kw_only=True) 3657 class A: 3658 a: ClassVar[int] = field(kw_only=False) 3659 3660 def test_field_marked_as_kwonly(self): 3661 ####################### 3662 # Using dataclass(kw_only=True) 3663 @dataclass(kw_only=True) 3664 class A: 3665 a: int 3666 self.assertTrue(fields(A)[0].kw_only) 3667 3668 @dataclass(kw_only=True) 3669 class A: 3670 a: int = field(kw_only=True) 3671 self.assertTrue(fields(A)[0].kw_only) 3672 3673 @dataclass(kw_only=True) 3674 class A: 3675 a: int = field(kw_only=False) 3676 self.assertFalse(fields(A)[0].kw_only) 3677 3678 ####################### 3679 # Using dataclass(kw_only=False) 3680 @dataclass(kw_only=False) 3681 class A: 3682 a: int 3683 self.assertFalse(fields(A)[0].kw_only) 3684 3685 @dataclass(kw_only=False) 3686 class A: 3687 a: int = field(kw_only=True) 3688 self.assertTrue(fields(A)[0].kw_only) 3689 3690 @dataclass(kw_only=False) 3691 class A: 3692 a: int = field(kw_only=False) 3693 self.assertFalse(fields(A)[0].kw_only) 3694 3695 ####################### 3696 # Not specifying dataclass(kw_only) 3697 @dataclass 3698 class A: 3699 a: int 3700 self.assertFalse(fields(A)[0].kw_only) 3701 3702 @dataclass 3703 class A: 3704 a: int = field(kw_only=True) 3705 self.assertTrue(fields(A)[0].kw_only) 3706 3707 @dataclass 3708 class A: 3709 a: int = field(kw_only=False) 3710 self.assertFalse(fields(A)[0].kw_only) 3711 3712 def test_match_args(self): 3713 # kw fields don't show up in __match_args__. 3714 @dataclass(kw_only=True) 3715 class C: 3716 a: int 3717 self.assertEqual(C(a=42).__match_args__, ()) 3718 3719 @dataclass 3720 class C: 3721 a: int 3722 b: int = field(kw_only=True) 3723 self.assertEqual(C(42, b=10).__match_args__, ('a',)) 3724 3725 def test_KW_ONLY(self): 3726 @dataclass 3727 class A: 3728 a: int 3729 _: KW_ONLY 3730 b: int 3731 c: int 3732 A(3, c=5, b=4) 3733 msg = "takes 2 positional arguments but 4 were given" 3734 with self.assertRaisesRegex(TypeError, msg): 3735 A(3, 4, 5) 3736 3737 3738 @dataclass(kw_only=True) 3739 class B: 3740 a: int 3741 _: KW_ONLY 3742 b: int 3743 c: int 3744 B(a=3, b=4, c=5) 3745 msg = "takes 1 positional argument but 4 were given" 3746 with self.assertRaisesRegex(TypeError, msg): 3747 B(3, 4, 5) 3748 3749 # Explicitly make a field that follows KW_ONLY be non-keyword-only. 3750 @dataclass 3751 class C: 3752 a: int 3753 _: KW_ONLY 3754 b: int 3755 c: int = field(kw_only=False) 3756 c = C(1, 2, b=3) 3757 self.assertEqual(c.a, 1) 3758 self.assertEqual(c.b, 3) 3759 self.assertEqual(c.c, 2) 3760 c = C(1, b=3, c=2) 3761 self.assertEqual(c.a, 1) 3762 self.assertEqual(c.b, 3) 3763 self.assertEqual(c.c, 2) 3764 c = C(1, b=3, c=2) 3765 self.assertEqual(c.a, 1) 3766 self.assertEqual(c.b, 3) 3767 self.assertEqual(c.c, 2) 3768 c = C(c=2, b=3, a=1) 3769 self.assertEqual(c.a, 1) 3770 self.assertEqual(c.b, 3) 3771 self.assertEqual(c.c, 2) 3772 3773 def test_KW_ONLY_as_string(self): 3774 @dataclass 3775 class A: 3776 a: int 3777 _: 'dataclasses.KW_ONLY' 3778 b: int 3779 c: int 3780 A(3, c=5, b=4) 3781 msg = "takes 2 positional arguments but 4 were given" 3782 with self.assertRaisesRegex(TypeError, msg): 3783 A(3, 4, 5) 3784 3785 def test_KW_ONLY_twice(self): 3786 msg = "'Y' is KW_ONLY, but KW_ONLY has already been specified" 3787 3788 with self.assertRaisesRegex(TypeError, msg): 3789 @dataclass 3790 class A: 3791 a: int 3792 X: KW_ONLY 3793 Y: KW_ONLY 3794 b: int 3795 c: int 3796 3797 with self.assertRaisesRegex(TypeError, msg): 3798 @dataclass 3799 class A: 3800 a: int 3801 X: KW_ONLY 3802 b: int 3803 Y: KW_ONLY 3804 c: int 3805 3806 with self.assertRaisesRegex(TypeError, msg): 3807 @dataclass 3808 class A: 3809 a: int 3810 X: KW_ONLY 3811 b: int 3812 c: int 3813 Y: KW_ONLY 3814 3815 # But this usage is okay, since it's not using KW_ONLY. 3816 @dataclass 3817 class A: 3818 a: int 3819 _: KW_ONLY 3820 b: int 3821 c: int = field(kw_only=True) 3822 3823 # And if inheriting, it's okay. 3824 @dataclass 3825 class A: 3826 a: int 3827 _: KW_ONLY 3828 b: int 3829 c: int 3830 @dataclass 3831 class B(A): 3832 _: KW_ONLY 3833 d: int 3834 3835 # Make sure the error is raised in a derived class. 3836 with self.assertRaisesRegex(TypeError, msg): 3837 @dataclass 3838 class A: 3839 a: int 3840 _: KW_ONLY 3841 b: int 3842 c: int 3843 @dataclass 3844 class B(A): 3845 X: KW_ONLY 3846 d: int 3847 Y: KW_ONLY 3848 3849 3850 def test_post_init(self): 3851 @dataclass 3852 class A: 3853 a: int 3854 _: KW_ONLY 3855 b: InitVar[int] 3856 c: int 3857 d: InitVar[int] 3858 def __post_init__(self, b, d): 3859 raise CustomError(f'{b=} {d=}') 3860 with self.assertRaisesRegex(CustomError, 'b=3 d=4'): 3861 A(1, c=2, b=3, d=4) 3862 3863 @dataclass 3864 class B: 3865 a: int 3866 _: KW_ONLY 3867 b: InitVar[int] 3868 c: int 3869 d: InitVar[int] 3870 def __post_init__(self, b, d): 3871 self.a = b 3872 self.c = d 3873 b = B(1, c=2, b=3, d=4) 3874 self.assertEqual(asdict(b), {'a': 3, 'c': 4}) 3875 3876 def test_defaults(self): 3877 # For kwargs, make sure we can have defaults after non-defaults. 3878 @dataclass 3879 class A: 3880 a: int = 0 3881 _: KW_ONLY 3882 b: int 3883 c: int = 1 3884 d: int 3885 3886 a = A(d=4, b=3) 3887 self.assertEqual(a.a, 0) 3888 self.assertEqual(a.b, 3) 3889 self.assertEqual(a.c, 1) 3890 self.assertEqual(a.d, 4) 3891 3892 # Make sure we still check for non-kwarg non-defaults not following 3893 # defaults. 3894 err_regex = "non-default argument 'z' follows default argument" 3895 with self.assertRaisesRegex(TypeError, err_regex): 3896 @dataclass 3897 class A: 3898 a: int = 0 3899 z: int 3900 _: KW_ONLY 3901 b: int 3902 c: int = 1 3903 d: int 3904 3905 def test_make_dataclass(self): 3906 A = make_dataclass("A", ['a'], kw_only=True) 3907 self.assertTrue(fields(A)[0].kw_only) 3908 3909 B = make_dataclass("B", 3910 ['a', ('b', int, field(kw_only=False))], 3911 kw_only=True) 3912 self.assertTrue(fields(B)[0].kw_only) 3913 self.assertFalse(fields(B)[1].kw_only) 3914 3915 3916if __name__ == '__main__': 3917 unittest.main() 3918