from itertools import chain
from typing import Any, Collection, Dict, List, Optional, Tuple, Union, cast

from ...error import GraphQLError
from ...language import (
    ArgumentNode,
    FieldNode,
    FragmentDefinitionNode,
    FragmentSpreadNode,
    InlineFragmentNode,
    SelectionSetNode,
    ValueNode,
    print_ast,
)
from ...type import (
    GraphQLCompositeType,
    GraphQLField,
    GraphQLList,
    GraphQLNamedType,
    GraphQLNonNull,
    GraphQLOutputType,
    get_named_type,
    is_interface_type,
    is_leaf_type,
    is_list_type,
    is_non_null_type,
    is_object_type,
)
from ...utilities import type_from_ast
from . import ValidationContext, ValidationRule

MYPY = False

__all__ = ["OverlappingFieldsCanBeMergedRule"]


def reason_message(reason: "ConflictReasonMessage") -> str:
    """Render a conflict reason as human-readable text.

    A reason is either a plain string, which is returned unchanged, or a list of
    ``(response_name, sub_reason)`` pairs describing conflicting sub-fields. Such a
    list is rendered recursively, one clause per pair, joined with " and ".
    """
    if not isinstance(reason, list):
        return reason
    clauses = [
        f"subfields '{response_name}' conflict because {reason_message(sub_reason)}"
        for response_name, sub_reason in reason
    ]
    return " and ".join(clauses)


class OverlappingFieldsCanBeMergedRule(ValidationRule):
    """Overlapping fields can be merged

    A selection set is only valid if all fields (including spreading any fragments)
    either correspond to distinct response names or can be merged without ambiguity.
    """

    def __init__(self, context: ValidationContext):
        super().__init__(context)
        # Memo mapping a selection set to its "field map" and to the list of
        # fragment names it spreads. The same selection set may be asked for this
        # information several times, so caching it speeds up this validator.
        self.cached_fields_and_fragment_names: Dict = {}

        # Memo of fragment pairs already compared "between" each other. The same
        # two fragments may be compared many times, so remembering which pairs were
        # done can dramatically improve the performance of this validator.
        self.compared_fragment_pairs = PairSet()

    def enter_selection_set(self, selection_set: SelectionSetNode, *_args: Any) -> None:
        """Report one error for every conflict found within ``selection_set``."""
        context = self.context
        conflicts = find_conflicts_within_selection_set(
            context,
            self.cached_fields_and_fragment_names,
            self.compared_fragment_pairs,
            context.get_parent_type(),
            selection_set,
        )
        for conflict_reason, nodes1, nodes2 in conflicts:
            response_name, reason = conflict_reason
            message = (
                f"Fields '{response_name}' conflict because {reason_message(reason)}."
                " Use different aliases on the fields to fetch both"
                " if this was intentional."
            )
            self.report_error(GraphQLError(message, nodes1 + nodes2))


Conflict = Tuple["ConflictReason", List[FieldNode], List[FieldNode]]
# A response name paired with the reason for its conflict.
ConflictReason = Tuple[str, "ConflictReasonMessage"]
# A reason is either a string or a nested list of sub-field conflict reasons.
# Recursive types are not fully supported by mypy yet
# (see https://github.com/python/mypy/issues/731), hence the MYPY switch.
if MYPY:
    ConflictReasonMessage = Union[str, List]
else:
    ConflictReasonMessage = Union[str, List[ConflictReason]]
# A field node together with its parent type and its definition (if known).
NodeAndDef = Tuple[GraphQLCompositeType, FieldNode, Optional[GraphQLField]]
# Mapping from response name to every field providing that response name.
NodeAndDefCollection = Dict[str, List[NodeAndDef]]


# Algorithm:
#
# Conflicts occur when two fields exist in a query which will produce the same
# response name, but represent differing values, thus creating a conflict.
# The algorithm below finds all conflicts via making a series of comparisons
# between fields. In order to compare as few fields as possible, this makes
# a series of comparisons "within" sets of fields and "between" sets of fields.
#
# Given any selection set, a collection produces both a set of fields by
# also including all inline fragments, as well as a list of fragments
# referenced by fragment spreads.
#
# A) Each selection set represented in the document first compares "within" its
#    collected set of fields, finding any conflicts between every pair of
#    overlapping fields.
#    Note: This is the *only time* that the fields "within" a set are compared
#    to each other. After this only fields "between" sets are compared.
#
# B) Also, if any fragment is referenced in a selection set, then a
#    comparison is made "between" the original set of fields and the
#    referenced fragment.
#
# C) Also, if multiple fragments are referenced, then comparisons
#    are made "between" each referenced fragment.
#
# D) When comparing "between" a set of fields and a referenced fragment, first
#    a comparison is made between each field in the original set of fields and
#    each field in the referenced set of fields.
#
# E) Also, if any fragment is referenced in the referenced selection set,
#    then a comparison is made "between" the original set of fields and the
#    referenced fragment (recursively referring to step D).
#
# F) When comparing "between" two fragments, first a comparison is made between
#    each field in the first referenced set of fields and each field in the
#    second referenced set of fields.
#
# G) Also, any fragments referenced by the first must be compared to the
#    second, and any fragments referenced by the second must be compared to the
#    first (recursively referring to step F).
#
# H) When comparing two fields, if both have selection sets, then a comparison
#    is made "between" both selection sets, first comparing the set of fields in
#    the first selection set with the set of fields in the second.
#
# I) Also, if any fragment is referenced in either selection set, then a
#    comparison is made "between" the other set of fields and the
#    referenced fragment.
#
# J) Also, if two fragments are referenced in both selection sets, then a
#    comparison is made "between" the two fragments.


def find_conflicts_within_selection_set(
    context: ValidationContext,
    cached_fields_and_fragment_names: Dict,
    compared_fragment_pairs: "PairSet",
    parent_type: Optional[GraphQLNamedType],
    selection_set: SelectionSetNode,
) -> List[Conflict]:
    """Find conflicts within selection set.

    Find all conflicts found "within" a selection set, including those found via
    spreading in fragments.

    Called when visiting each SelectionSet in the GraphQL Document.

    :param parent_type: the type the selection set is evaluated against, or None
        if it is unknown
    :return: the list of conflicts found (empty when there are none)
    """
    conflicts: List[Conflict] = []

    field_map, fragment_names = get_fields_and_fragment_names(
        context, cached_fields_and_fragment_names, parent_type, selection_set
    )

    # (A) Find all conflicts "within" the fields of this selection set.
    # Note: this is the *only place* `collect_conflicts_within` is called.
    collect_conflicts_within(
        context,
        conflicts,
        cached_fields_and_fragment_names,
        compared_fragment_pairs,
        field_map,
    )

    # (B) Then collect conflicts between these fields and those represented by each
    # spread fragment name found.
    for i, fragment_name in enumerate(fragment_names):
        collect_conflicts_between_fields_and_fragment(
            context,
            conflicts,
            cached_fields_and_fragment_names,
            compared_fragment_pairs,
            False,
            field_map,
            fragment_name,
        )
        # (C) Then compare this fragment with all other fragments found in this
        # selection set to collect conflicts within fragments spread together.
        # This compares each item in the list of fragment names to every other
        # item in that same list (except for itself).
        for other_fragment_name in fragment_names[i + 1 :]:
            collect_conflicts_between_fragments(
                context,
                conflicts,
                cached_fields_and_fragment_names,
                compared_fragment_pairs,
                False,
                fragment_name,
                other_fragment_name,
            )

    return conflicts


def collect_conflicts_between_fields_and_fragment(
    context: ValidationContext,
    conflicts: List[Conflict],
    cached_fields_and_fragment_names: Dict,
    compared_fragment_pairs: "PairSet",
    are_mutually_exclusive: bool,
    field_map: NodeAndDefCollection,
    fragment_name: str,
) -> None:
    """Collect conflicts between fields and fragment.

    Collect all conflicts found between a set of fields and a fragment reference
    including via spreading in any nested fragments. Conflicts are appended to
    ``conflicts``.

    Every fragment reachable from ``fragment_name`` is compared with ``field_map``
    at most once per call. This prevents endless recursion when fragments spread
    each other in a cycle (for example ``{ ...A }`` together with
    ``fragment A on Query { ...A }``). Cycles are reported by a different rule, but
    this rule still runs on such documents. It also avoids comparing the same
    fragment twice when it is reachable through more than one spread.
    """
    # Names of fragments already handled during this call (a dict used as a set).
    visited_fragment_names: Dict[str, bool] = {}
    # Explicit depth-first stack instead of recursion. Nested fragment names are
    # pushed in reverse so they are processed in document order.
    pending_fragment_names = [fragment_name]

    while pending_fragment_names:
        current_name = pending_fragment_names.pop()
        if current_name in visited_fragment_names:
            continue
        visited_fragment_names[current_name] = True

        fragment = context.get_fragment(current_name)
        if not fragment:
            continue

        field_map2, fragment_names2 = get_referenced_fields_and_fragment_names(
            context, cached_fields_and_fragment_names, fragment
        )

        # Do not compare a fragment's field map to itself.
        if field_map is field_map2:
            continue

        # (D) First collect any conflicts between the provided collection of fields
        # and the collection of fields represented by the given fragment.
        collect_conflicts_between(
            context,
            conflicts,
            cached_fields_and_fragment_names,
            compared_fragment_pairs,
            are_mutually_exclusive,
            field_map,
            field_map2,
        )

        # (E) Then collect any conflicts between the provided collection of fields
        # and any fragment names found in the given fragment.
        pending_fragment_names.extend(reversed(fragment_names2))


def collect_conflicts_between_fragments(
    context: ValidationContext,
    conflicts: List[Conflict],
    cached_fields_and_fragment_names: Dict,
    compared_fragment_pairs: "PairSet",
    are_mutually_exclusive: bool,
    fragment_name1: str,
    fragment_name2: str,
) -> None:
    """Collect conflicts between fragments.

    Collect all conflicts found between two fragments, including via spreading in any
    nested fragments. Conflicts are appended to ``conflicts``. Each pair of fragment
    names is compared at most once per exclusivity level, as tracked by
    ``compared_fragment_pairs``. This also makes the recursion terminate.
    """
    # No need to compare a fragment to itself.
    if fragment_name1 == fragment_name2:
        return

    # Memoize so two fragments are not compared for conflicts more than once.
    if compared_fragment_pairs.has(
        fragment_name1, fragment_name2, are_mutually_exclusive
    ):
        return
    compared_fragment_pairs.add(fragment_name1, fragment_name2, are_mutually_exclusive)

    fragment1 = context.get_fragment(fragment_name1)
    fragment2 = context.get_fragment(fragment_name2)
    if not fragment1 or not fragment2:
        return

    field_map1, fragment_names1 = get_referenced_fields_and_fragment_names(
        context, cached_fields_and_fragment_names, fragment1
    )
    field_map2, fragment_names2 = get_referenced_fields_and_fragment_names(
        context, cached_fields_and_fragment_names, fragment2
    )

    # (F) First, collect all conflicts between these two collections of fields
    # (not including any nested fragments).
    collect_conflicts_between(
        context,
        conflicts,
        cached_fields_and_fragment_names,
        compared_fragment_pairs,
        are_mutually_exclusive,
        field_map1,
        field_map2,
    )

    # (G) Then collect conflicts between the first fragment and any nested fragments
    # spread in the second fragment.
    for nested_fragment_name2 in fragment_names2:
        collect_conflicts_between_fragments(
            context,
            conflicts,
            cached_fields_and_fragment_names,
            compared_fragment_pairs,
            are_mutually_exclusive,
            fragment_name1,
            nested_fragment_name2,
        )

    # (G) Then collect conflicts between the second fragment and any nested fragments
    # spread in the first fragment.
    for nested_fragment_name1 in fragment_names1:
        collect_conflicts_between_fragments(
            context,
            conflicts,
            cached_fields_and_fragment_names,
            compared_fragment_pairs,
            are_mutually_exclusive,
            nested_fragment_name1,
            fragment_name2,
        )


def find_conflicts_between_sub_selection_sets(
    context: ValidationContext,
    cached_fields_and_fragment_names: Dict,
    compared_fragment_pairs: "PairSet",
    are_mutually_exclusive: bool,
    parent_type1: Optional[GraphQLNamedType],
    selection_set1: SelectionSetNode,
    parent_type2: Optional[GraphQLNamedType],
    selection_set2: SelectionSetNode,
) -> List[Conflict]:
    """Find conflicts between sub selection sets.

    Find all conflicts found between two selection sets, including those found via
    spreading in fragments. Called when determining if conflicts exist between the
    sub-fields of two overlapping fields.

    :return: the list of conflicts found (empty when there are none)
    """
    conflicts: List[Conflict] = []

    field_map1, fragment_names1 = get_fields_and_fragment_names(
        context, cached_fields_and_fragment_names, parent_type1, selection_set1
    )
    field_map2, fragment_names2 = get_fields_and_fragment_names(
        context, cached_fields_and_fragment_names, parent_type2, selection_set2
    )

    # (H) First, collect all conflicts between these two collections of fields.
    collect_conflicts_between(
        context,
        conflicts,
        cached_fields_and_fragment_names,
        compared_fragment_pairs,
        are_mutually_exclusive,
        field_map1,
        field_map2,
    )

    # (I) Then collect conflicts between the first collection of fields and those
    # referenced by each fragment name associated with the second.
    for fragment_name2 in fragment_names2:
        collect_conflicts_between_fields_and_fragment(
            context,
            conflicts,
            cached_fields_and_fragment_names,
            compared_fragment_pairs,
            are_mutually_exclusive,
            field_map1,
            fragment_name2,
        )

    # (I) Then collect conflicts between the second collection of fields and those
    # referenced by each fragment name associated with the first.
    for fragment_name1 in fragment_names1:
        collect_conflicts_between_fields_and_fragment(
            context,
            conflicts,
            cached_fields_and_fragment_names,
            compared_fragment_pairs,
            are_mutually_exclusive,
            field_map2,
            fragment_name1,
        )

    # (J) Also collect conflicts between any fragment names by the first and fragment
    # names by the second. This compares each item in the first set of names to each
    # item in the second set of names.
    for fragment_name1 in fragment_names1:
        for fragment_name2 in fragment_names2:
            collect_conflicts_between_fragments(
                context,
                conflicts,
                cached_fields_and_fragment_names,
                compared_fragment_pairs,
                are_mutually_exclusive,
                fragment_name1,
                fragment_name2,
            )

    return conflicts


def collect_conflicts_within(
    context: ValidationContext,
    conflicts: List[Conflict],
    cached_fields_and_fragment_names: Dict,
    compared_fragment_pairs: "PairSet",
    field_map: NodeAndDefCollection,
) -> None:
    """Collect all Conflicts "within" one collection of fields.

    Conflicts are appended to ``conflicts``.
    """
    # A field map is a keyed collection, where each key represents a response name and
    # the value at that key is a list of all fields which provide that response name.
    # For every response name, if there are multiple fields, they must be compared to
    # find a potential conflict.
    for response_name, fields in field_map.items():
        # This compares every field in the list to every other field in this list
        # (except to itself). If the list only has one item, nothing needs to be
        # compared.
        if len(fields) > 1:
            for i, field in enumerate(fields):
                for other_field in fields[i + 1 :]:
                    conflict = find_conflict(
                        context,
                        cached_fields_and_fragment_names,
                        compared_fragment_pairs,
                        # within one collection is never mutually exclusive
                        False,
                        response_name,
                        field,
                        other_field,
                    )
                    if conflict:
                        conflicts.append(conflict)


def collect_conflicts_between(
    context: ValidationContext,
    conflicts: List[Conflict],
    cached_fields_and_fragment_names: Dict,
    compared_fragment_pairs: "PairSet",
    parent_fields_are_mutually_exclusive: bool,
    field_map1: NodeAndDefCollection,
    field_map2: NodeAndDefCollection,
) -> None:
    """Collect all Conflicts between two collections of fields.

    This is similar to, but different from the :func:`~.collect_conflicts_within`
    function above. This check assumes that :func:`~.collect_conflicts_within` has
    already been called on each provided collection of fields. This is true because
    this validator traverses each individual selection set. Conflicts are appended
    to ``conflicts``.
    """
    # A field map is a keyed collection, where each key represents a response name and
    # the value at that key is a list of all fields which provide that response name.
    # For any response name which appears in both provided field maps, each field from
    # the first field map must be compared to every field in the second field map to
    # find potential conflicts.
    for response_name, fields1 in field_map1.items():
        fields2 = field_map2.get(response_name)
        if fields2:
            for field1 in fields1:
                for field2 in fields2:
                    conflict = find_conflict(
                        context,
                        cached_fields_and_fragment_names,
                        compared_fragment_pairs,
                        parent_fields_are_mutually_exclusive,
                        response_name,
                        field1,
                        field2,
                    )
                    if conflict:
                        conflicts.append(conflict)


def find_conflict(
    context: ValidationContext,
    cached_fields_and_fragment_names: Dict,
    compared_fragment_pairs: "PairSet",
    parent_fields_are_mutually_exclusive: bool,
    response_name: str,
    field1: NodeAndDef,
    field2: NodeAndDef,
) -> Optional[Conflict]:
    """Find conflict.

    Determines if there is a conflict between two particular fields, including comparing
    their sub-fields.

    :return: a single Conflict (nested sub-field conflicts folded into it), or None
        if the two fields can be merged
    """
    parent_type1, node1, def1 = field1
    parent_type2, node2, def2 = field2

    # If it is known that two fields could not possibly apply at the same time, due to
    # the parent types, then it is safe to permit them to diverge in aliased field or
    # arguments used as they will not present any ambiguity by differing. It is known
    # that two parent types could never overlap if they are different Object types.
    # Interface or Union types might overlap - if not in the current state of the
    # schema, then perhaps in some future version, thus may not safely diverge.
    are_mutually_exclusive = parent_fields_are_mutually_exclusive or (
        parent_type1 != parent_type2
        and is_object_type(parent_type1)
        and is_object_type(parent_type2)
    )

    # The return type for each field (None when the field definition is unknown).
    type1 = cast(Optional[GraphQLOutputType], def1 and def1.type)
    type2 = cast(Optional[GraphQLOutputType], def2 and def2.type)

    if not are_mutually_exclusive:
        # Two aliases must refer to the same field.
        name1 = node1.name.value
        name2 = node2.name.value
        if name1 != name2:
            return (
                (response_name, f"'{name1}' and '{name2}' are different fields"),
                [node1],
                [node2],
            )

        # Two field calls must have the same arguments.
        if not same_arguments(node1.arguments or [], node2.arguments or []):
            return (response_name, "they have differing arguments"), [node1], [node2]

    if type1 and type2 and do_types_conflict(type1, type2):
        return (
            (response_name, f"they return conflicting types '{type1}' and '{type2}'"),
            [node1],
            [node2],
        )

    # Collect and compare sub-fields. Each sub-selection is resolved against the
    # named type of its field's return type (None if the field is unknown), and any
    # sub-field conflicts are folded into a single conflict for this response name.
    selection_set1 = node1.selection_set
    selection_set2 = node2.selection_set
    if selection_set1 and selection_set2:
        conflicts = find_conflicts_between_sub_selection_sets(
            context,
            cached_fields_and_fragment_names,
            compared_fragment_pairs,
            are_mutually_exclusive,
            get_named_type(type1),
            selection_set1,
            get_named_type(type2),
            selection_set2,
        )
        return subfield_conflicts(conflicts, response_name, node1, node2)

    return None  # no conflict


def same_arguments(
    arguments1: Collection[ArgumentNode], arguments2: Collection[ArgumentNode]
) -> bool:
    """Check whether two argument lists are equivalent.

    They are equivalent when they have the same length and every argument in the first
    list has a same-named argument in the second list with the same printed value.
    """
    if len(arguments1) != len(arguments2):
        return False
    for argument1 in arguments1:
        for argument2 in arguments2:
            if argument2.name.value == argument1.name.value:
                if not same_value(argument1.value, argument2.value):
                    return False
                break
        else:
            # no argument with this name in the second list
            return False
    return True


def same_value(value1: ValueNode, value2: ValueNode) -> bool:
    """Check whether two value nodes print to the same GraphQL source text."""
    return print_ast(value1) == print_ast(value2)


def do_types_conflict(type1: GraphQLOutputType, type2: GraphQLOutputType) -> bool:
    """Check whether two types conflict

    Two types conflict if both types could not apply to a value simultaneously.
    Composite types are ignored as their individual field types will be compared later
    recursively. However List and Non-Null types must match.
    """
    if is_list_type(type1):
        return (
            do_types_conflict(
                cast(GraphQLList, type1).of_type, cast(GraphQLList, type2).of_type
            )
            if is_list_type(type2)
            else True
        )
    if is_list_type(type2):
        return True
    if is_non_null_type(type1):
        return (
            do_types_conflict(
                cast(GraphQLNonNull, type1).of_type, cast(GraphQLNonNull, type2).of_type
            )
            if is_non_null_type(type2)
            else True
        )
    if is_non_null_type(type2):
        return True
    if is_leaf_type(type1) or is_leaf_type(type2):
        return type1 is not type2
    return False


def get_fields_and_fragment_names(
    context: ValidationContext,
    cached_fields_and_fragment_names: Dict,
    parent_type: Optional[GraphQLNamedType],
    selection_set: SelectionSetNode,
) -> Tuple[NodeAndDefCollection, List[str]]:
    """Get fields and referenced fragment names

    Given a selection set, return the collection of fields (a mapping of response name
    to field nodes and definitions) as well as a list of fragment names referenced via
    fragment spreads. The result is memoized per selection set, so repeated calls
    return the very same objects (callers rely on this for identity checks).
    """
    cached = cached_fields_and_fragment_names.get(selection_set)
    if cached is None:
        node_and_defs: NodeAndDefCollection = {}
        # dict used as an insertion-ordered set of fragment names
        fragment_names: Dict[str, bool] = {}
        collect_fields_and_fragment_names(
            context, parent_type, selection_set, node_and_defs, fragment_names
        )
        cached = (node_and_defs, list(fragment_names))
        cached_fields_and_fragment_names[selection_set] = cached
    return cached


def get_referenced_fields_and_fragment_names(
    context: ValidationContext,
    cached_fields_and_fragment_names: Dict,
    fragment: FragmentDefinitionNode,
) -> Tuple[NodeAndDefCollection, List[str]]:
    """Get referenced fields and nested fragment names

    Given a reference to a fragment, return the represented collection of fields as well
    as a list of nested fragment names referenced via fragment spreads.
    """
    # Short-circuit building a type from the node if possible.
    cached = cached_fields_and_fragment_names.get(fragment.selection_set)
    if cached is not None:
        return cached

    fragment_type = type_from_ast(context.schema, fragment.type_condition)
    return get_fields_and_fragment_names(
        context, cached_fields_and_fragment_names, fragment_type, fragment.selection_set
    )


def collect_fields_and_fragment_names(
    context: ValidationContext,
    parent_type: Optional[GraphQLNamedType],
    selection_set: SelectionSetNode,
    node_and_defs: NodeAndDefCollection,
    fragment_names: Dict[str, bool],
) -> None:
    """Collect fields and spread fragment names of a selection set.

    Fields, including those inside inline fragments, are added to ``node_and_defs``
    keyed by response name. Names of spread fragments are added as keys of
    ``fragment_names``. Both arguments are mutated in place.
    """
    for selection in selection_set.selections:
        if isinstance(selection, FieldNode):
            field_name = selection.name.value
            field_def = (
                parent_type.fields.get(field_name)  # type: ignore
                if is_object_type(parent_type) or is_interface_type(parent_type)
                else None
            )
            response_name = selection.alias.value if selection.alias else field_name
            node_and_defs.setdefault(response_name, []).append(
                cast(NodeAndDef, (parent_type, selection, field_def))
            )
        elif isinstance(selection, FragmentSpreadNode):
            fragment_names[selection.name.value] = True
        elif isinstance(selection, InlineFragmentNode):  # pragma: no cover else
            type_condition = selection.type_condition
            # Without a type condition, an inline fragment keeps the parent type.
            inline_fragment_type = (
                type_from_ast(context.schema, type_condition)
                if type_condition
                else parent_type
            )
            collect_fields_and_fragment_names(
                context,
                inline_fragment_type,
                selection.selection_set,
                node_and_defs,
                fragment_names,
            )


def subfield_conflicts(
    conflicts: List[Conflict], response_name: str, node1: FieldNode, node2: FieldNode
) -> Optional[Conflict]:
    """Check whether there are conflicts between sub-fields.

    Given a series of Conflicts which occurred between two sub-fields, generate a single
    Conflict. Returns None if ``conflicts`` is empty.
    """
    if conflicts:
        return (
            (response_name, [conflict[0] for conflict in conflicts]),
            list(chain([node1], *[conflict[1] for conflict in conflicts])),
            list(chain([node2], *[conflict[2] for conflict in conflicts])),
        )
    return None  # no conflict


class PairSet:
    """Pair set

    A way to keep track of pairs of things when the ordering of the pair does not
    matter. We do this by maintaining a sort of double adjacency sets.
    """

    __slots__ = ("_data",)

    def __init__(self) -> None:
        # _data[a][b] holds the are_mutually_exclusive flag the pair was added with.
        self._data: Dict[str, Dict[str, bool]] = {}

    def has(self, a: str, b: str, are_mutually_exclusive: bool) -> bool:
        """Check whether the pair was added at a compatible exclusivity level."""
        first = self._data.get(a)
        result = first.get(b) if first else None
        if result is None:
            return False
        # `are_mutually_exclusive` being False is a superset of being True, hence if we
        # want to know if this PairSet "has" these two with no exclusivity, we have to
        # ensure it was added as such.
        if not are_mutually_exclusive:
            return not result
        return True

    def add(self, a: str, b: str, are_mutually_exclusive: bool) -> "PairSet":
        """Record the unordered pair ``(a, b)`` and return this PairSet."""
        self._pair_set_add(a, b, are_mutually_exclusive)
        self._pair_set_add(b, a, are_mutually_exclusive)
        return self

    def _pair_set_add(self, a: str, b: str, are_mutually_exclusive: bool) -> None:
        self._data.setdefault(a, {})[b] = are_mutually_exclusive