1from itertools import chain
2from typing import Any, Collection, Dict, List, Optional, Tuple, Union, cast
3
4from ...error import GraphQLError
5from ...language import (
6    ArgumentNode,
7    FieldNode,
8    FragmentDefinitionNode,
9    FragmentSpreadNode,
10    InlineFragmentNode,
11    SelectionSetNode,
12    ValueNode,
13    print_ast,
14)
15from ...type import (
16    GraphQLCompositeType,
17    GraphQLField,
18    GraphQLList,
19    GraphQLNamedType,
20    GraphQLNonNull,
21    GraphQLOutputType,
22    get_named_type,
23    is_interface_type,
24    is_leaf_type,
25    is_list_type,
26    is_non_null_type,
27    is_object_type,
28)
29from ...utilities import type_from_ast
30from . import ValidationContext, ValidationRule
31
# Always False at runtime. mypy evaluates a constant named MYPY as true while type
# checking (a legacy spelling of typing.TYPE_CHECKING), so the `if MYPY:` branch
# further down is only taken by the type checker.
MYPY = False

# Only the validation rule itself is part of this module's public API.
__all__ = ["OverlappingFieldsCanBeMergedRule"]
35
36
def reason_message(reason: "ConflictReasonMessage") -> str:
    """Render a conflict reason as human-readable text.

    A plain string reason is returned unchanged. A list reason holds
    ``(response_name, sub_reason)`` pairs for conflicting subfields; each pair is
    rendered (recursing into ``sub_reason``) and the pieces are joined with " and ".
    """
    if not isinstance(reason, list):
        return reason
    parts = []
    for response_name, sub_reason in reason:
        explanation = reason_message(sub_reason)
        parts.append(f"subfields '{response_name}' conflict because {explanation}")
    return " and ".join(parts)
45
46
class OverlappingFieldsCanBeMergedRule(ValidationRule):
    """Overlapping fields can be merged

    A selection set is only valid if all fields (including spreading any fragments)
    either correspond to distinct response names or can be merged without ambiguity.
    """

    def __init__(self, context: ValidationContext):
        super().__init__(context)
        # Remembers which fragment pairs were already compared "between" each
        # other. The same pair can come up many times while walking a document,
        # so skipping repeats keeps this validator fast.
        self.compared_fragment_pairs = PairSet()

        # Per-selection-set cache of (field map, spread fragment names). The same
        # selection set is queried repeatedly, so the result is computed once.
        self.cached_fields_and_fragment_names: Dict = {}

    def enter_selection_set(self, selection_set: SelectionSetNode, *_args: Any) -> None:
        """Report one error per conflict found within the entered selection set."""
        parent_type = self.context.get_parent_type()
        conflicts = find_conflicts_within_selection_set(
            self.context,
            self.cached_fields_and_fragment_names,
            self.compared_fragment_pairs,
            parent_type,
            selection_set,
        )
        for conflict_reason, nodes1, nodes2 in conflicts:
            reason_name, reason = conflict_reason
            message = (
                f"Fields '{reason_name}' conflict because {reason_message(reason)}."
                " Use different aliases on the fields to fetch both"
                " if this was intentional."
            )
            self.report_error(GraphQLError(message, nodes1 + nodes2))
84
85
# A single conflict: the (response name, reason) pair, plus the field nodes from
# each side that take part in it.
Conflict = Tuple["ConflictReason", List[FieldNode], List[FieldNode]]
# Response name and reason.
ConflictReason = Tuple[str, "ConflictReasonMessage"]
# Reason is a string, or a nested list of conflicts.
# Recursive types are not fully supported by mypy yet
# (https://github.com/python/mypy/issues/731), so it sees a looser alias.
if MYPY:
    ConflictReasonMessage = Union[str, List]
else:
    ConflictReasonMessage = Union[str, List[ConflictReason]]
# Tuple defining a field node in a context: (parent type, field node, field
# definition, or None when the definition could not be resolved).
NodeAndDef = Tuple[GraphQLCompositeType, FieldNode, Optional[GraphQLField]]
# Dictionary of lists of those, keyed by response name (alias or field name).
NodeAndDefCollection = Dict[str, List[NodeAndDef]]
98
99
100# Algorithm:
101#
102# Conflicts occur when two fields exist in a query which will produce the same
103# response name, but represent differing values, thus creating a conflict.
104# The algorithm below finds all conflicts via making a series of comparisons
105# between fields. In order to compare as few fields as possible, this makes
106# a series of comparisons "within" sets of fields and "between" sets of fields.
107#
108# Given any selection set, a collection produces both a set of fields by
109# also including all inline fragments, as well as a list of fragments
110# referenced by fragment spreads.
111#
112# A) Each selection set represented in the document first compares "within" its
113# collected set of fields, finding any conflicts between every pair of
114# overlapping fields.
# Note: This is the *only time* that the fields "within" a set are compared
116# to each other. After this only fields "between" sets are compared.
117#
118# B) Also, if any fragment is referenced in a selection set, then a
119# comparison is made "between" the original set of fields and the
120# referenced fragment.
121#
122# C) Also, if multiple fragments are referenced, then comparisons
123# are made "between" each referenced fragment.
124#
125# D) When comparing "between" a set of fields and a referenced fragment, first
126# a comparison is made between each field in the original set of fields and
# each field in the referenced set of fields.
128#
129# E) Also, if any fragment is referenced in the referenced selection set,
130# then a comparison is made "between" the original set of fields and the
131# referenced fragment (recursively referring to step D).
132#
133# F) When comparing "between" two fragments, first a comparison is made between
# each field in the first referenced set of fields and each field in the
135# second referenced set of fields.
136#
137# G) Also, any fragments referenced by the first must be compared to the
138# second, and any fragments referenced by the second must be compared to the
139# first (recursively referring to step F).
140#
141# H) When comparing two fields, if both have selection sets, then a comparison
142# is made "between" both selection sets, first comparing the set of fields in
143# the first selection set with the set of fields in the second.
144#
145# I) Also, if any fragment is referenced in either selection set, then a
146# comparison is made "between" the other set of fields and the
147# referenced fragment.
148#
149# J) Also, if two fragments are referenced in both selection sets, then a
150# comparison is made "between" the two fragments.
151
152
def find_conflicts_within_selection_set(
    context: ValidationContext,
    cached_fields_and_fragment_names: Dict,
    compared_fragment_pairs: "PairSet",
    parent_type: Optional[GraphQLNamedType],
    selection_set: SelectionSetNode,
) -> List[Conflict]:
    """Find conflicts within selection set.

    Return every conflict found "within" ``selection_set``: between its own fields,
    between those fields and each fragment it spreads, and between every pair of
    fragments it spreads. Used for each selection set visited in the document.
    """
    found: List[Conflict] = []

    field_map, fragment_names = get_fields_and_fragment_names(
        context, cached_fields_and_fragment_names, parent_type, selection_set
    )

    # (A) Compare the fields of this selection set with each other. This is the
    # only caller of `collect_conflicts_within`.
    collect_conflicts_within(
        context,
        found,
        cached_fields_and_fragment_names,
        compared_fragment_pairs,
        field_map,
    )

    # Fields gathered from one selection set are never mutually exclusive.
    exclusive = False
    for position, fragment_name in enumerate(fragment_names):
        # (B) Compare these fields with the fields of each spread fragment.
        collect_conflicts_between_fields_and_fragment(
            context,
            found,
            cached_fields_and_fragment_names,
            compared_fragment_pairs,
            exclusive,
            field_map,
            fragment_name,
        )
        # (C) Compare this fragment with every fragment spread after it, so each
        # unordered pair of spread fragments is visited once by this loop.
        for later_fragment_name in fragment_names[position + 1 :]:
            collect_conflicts_between_fragments(
                context,
                found,
                cached_fields_and_fragment_names,
                compared_fragment_pairs,
                exclusive,
                fragment_name,
                later_fragment_name,
            )

    return found
212
213
def collect_conflicts_between_fields_and_fragment(
    context: ValidationContext,
    conflicts: List[Conflict],
    cached_fields_and_fragment_names: Dict,
    compared_fragment_pairs: "PairSet",
    are_mutually_exclusive: bool,
    field_map: NodeAndDefCollection,
    fragment_name: str,
) -> None:
    """Collect conflicts between fields and fragment.

    Collect all conflicts found between a set of fields and a fragment reference
    including via spreading in any nested fragments. Conflicts are appended to
    ``conflicts`` in place. Unknown fragment names are silently skipped.

    Each (nested fragment, referencing fragment) pair is recorded in
    ``compared_fragment_pairs`` before recursing, so cyclic fragment spreads
    cannot make the recursion run forever.
    """
    fragment = context.get_fragment(fragment_name)
    if not fragment:
        return

    field_map2, referenced_fragment_names = get_referenced_fields_and_fragment_names(
        context, cached_fields_and_fragment_names, fragment
    )

    # Do not compare a fragment's field map to itself.
    if field_map is field_map2:
        return

    # (D) First collect any conflicts between the provided collection of fields and
    # the collection of fields represented by the given fragment.
    collect_conflicts_between(
        context,
        conflicts,
        cached_fields_and_fragment_names,
        compared_fragment_pairs,
        are_mutually_exclusive,
        field_map,
        field_map2,
    )

    # (E) Then collect any conflicts between the provided collection of fields and
    # any fragment names found in the given fragment.
    for referenced_fragment_name in referenced_fragment_names:
        # Memoize each (referenced, referencing) fragment pair. Besides saving
        # work, this guarantees termination: nothing in this rule rejects cyclic
        # spreads such as `{ ...A } fragment A on T { ...B } fragment B on T
        # { ...A }`, and without this check the recursion below would alternate
        # between A and B until RecursionError.
        if compared_fragment_pairs.has(
            referenced_fragment_name, fragment_name, are_mutually_exclusive
        ):
            continue
        compared_fragment_pairs.add(
            referenced_fragment_name, fragment_name, are_mutually_exclusive
        )

        collect_conflicts_between_fields_and_fragment(
            context,
            conflicts,
            cached_fields_and_fragment_names,
            compared_fragment_pairs,
            are_mutually_exclusive,
            field_map,
            referenced_fragment_name,
        )
264
265
def collect_conflicts_between_fragments(
    context: ValidationContext,
    conflicts: List[Conflict],
    cached_fields_and_fragment_names: Dict,
    compared_fragment_pairs: "PairSet",
    are_mutually_exclusive: bool,
    fragment_name1: str,
    fragment_name2: str,
) -> None:
    """Collect conflicts between fragments.

    Append to ``conflicts`` every conflict between the two named fragments,
    including those reached through fragments nested inside either of them.
    Unknown fragment names are silently skipped.
    """
    # A fragment is never compared with itself.
    if fragment_name1 == fragment_name2:
        return

    # Each pair is compared at most once per exclusivity mode (see PairSet.has).
    pairs = compared_fragment_pairs
    if pairs.has(fragment_name1, fragment_name2, are_mutually_exclusive):
        return
    pairs.add(fragment_name1, fragment_name2, are_mutually_exclusive)

    fragment1 = context.get_fragment(fragment_name1)
    fragment2 = context.get_fragment(fragment_name2)
    if not (fragment1 and fragment2):
        return

    field_map1, nested_names1 = get_referenced_fields_and_fragment_names(
        context, cached_fields_and_fragment_names, fragment1
    )
    field_map2, nested_names2 = get_referenced_fields_and_fragment_names(
        context, cached_fields_and_fragment_names, fragment2
    )

    # (F) Compare the two fragments' own fields (nested fragments excluded).
    collect_conflicts_between(
        context,
        conflicts,
        cached_fields_and_fragment_names,
        compared_fragment_pairs,
        are_mutually_exclusive,
        field_map1,
        field_map2,
    )

    # (G) Compare the first fragment with each fragment nested in the second ...
    for nested_name in nested_names2:
        collect_conflicts_between_fragments(
            context,
            conflicts,
            cached_fields_and_fragment_names,
            compared_fragment_pairs,
            are_mutually_exclusive,
            fragment_name1,
            nested_name,
        )

    # (G) ... and each fragment nested in the first with the second fragment.
    for nested_name in nested_names1:
        collect_conflicts_between_fragments(
            context,
            conflicts,
            cached_fields_and_fragment_names,
            compared_fragment_pairs,
            are_mutually_exclusive,
            nested_name,
            fragment_name2,
        )
341
342
def find_conflicts_between_sub_selection_sets(
    context: ValidationContext,
    cached_fields_and_fragment_names: Dict,
    compared_fragment_pairs: "PairSet",
    are_mutually_exclusive: bool,
    parent_type1: Optional[GraphQLNamedType],
    selection_set1: SelectionSetNode,
    parent_type2: Optional[GraphQLNamedType],
    selection_set2: SelectionSetNode,
) -> List[Conflict]:
    """Find conflicts between sub selection sets.

    Return every conflict between the two given selection sets, including those
    reached through fragment spreads on either side. Used by ``find_conflict`` to
    compare the sub-fields of two overlapping fields.
    """
    conflicts: List[Conflict] = []

    field_map1, fragment_names1 = get_fields_and_fragment_names(
        context, cached_fields_and_fragment_names, parent_type1, selection_set1
    )
    field_map2, fragment_names2 = get_fields_and_fragment_names(
        context, cached_fields_and_fragment_names, parent_type2, selection_set2
    )

    # Leading positional arguments shared by every collector call below; the
    # collectors append to `conflicts` in place.
    shared = (
        context,
        conflicts,
        cached_fields_and_fragment_names,
        compared_fragment_pairs,
        are_mutually_exclusive,
    )

    # (H) Compare the two field collections directly.
    collect_conflicts_between(*shared, field_map1, field_map2)

    # (I) Compare the first collection with each fragment spread in the second ...
    for fragment_name2 in fragment_names2:
        collect_conflicts_between_fields_and_fragment(
            *shared, field_map1, fragment_name2
        )

    # (I) ... and the second collection with each fragment spread in the first.
    for fragment_name1 in fragment_names1:
        collect_conflicts_between_fields_and_fragment(
            *shared, field_map2, fragment_name1
        )

    # (J) Compare every fragment spread in the first with every fragment spread
    # in the second.
    for fragment_name1 in fragment_names1:
        for fragment_name2 in fragment_names2:
            collect_conflicts_between_fragments(
                *shared, fragment_name1, fragment_name2
            )

    return conflicts
423
424
def collect_conflicts_within(
    context: ValidationContext,
    conflicts: List[Conflict],
    cached_fields_and_fragment_names: Dict,
    compared_fragment_pairs: "PairSet",
    field_map: NodeAndDefCollection,
) -> None:
    """Collect all Conflicts "within" one collection of fields.

    ``field_map`` maps each response name to every field producing it; any two
    fields sharing a response name are compared, and conflicts are appended to
    ``conflicts`` in place.
    """
    for response_name, fields in field_map.items():
        # Visit each unordered pair (i < j) exactly once; a list with fewer than
        # two fields yields no pairs at all.
        for index, field in enumerate(fields):
            for other_field in fields[index + 1 :]:
                conflict = find_conflict(
                    context,
                    cached_fields_and_fragment_names,
                    compared_fragment_pairs,
                    False,  # fields of a single collection are never exclusive
                    response_name,
                    field,
                    other_field,
                )
                if conflict:
                    conflicts.append(conflict)
456
457
def collect_conflicts_between(
    context: ValidationContext,
    conflicts: List[Conflict],
    cached_fields_and_fragment_names: Dict,
    compared_fragment_pairs: "PairSet",
    parent_fields_are_mutually_exclusive: bool,
    field_map1: NodeAndDefCollection,
    field_map2: NodeAndDefCollection,
) -> None:
    """Collect all Conflicts between two collections of fields.

    This is similar to, but different from the :func:`~.collect_conflicts_within`
    function above. This check assumes that :func:`~.collect_conflicts_within` has
    already been called on each provided collection of fields. This is true because
    this validator traverses each individual selection set.

    For each response name present in both maps, every field of the first map is
    compared with every field of the second; conflicts are appended in place.
    """
    for response_name, fields1 in field_map1.items():
        fields2 = field_map2.get(response_name)
        if not fields2:
            # Response name only appears on one side: nothing to compare.
            continue
        for field1 in fields1:
            for field2 in fields2:
                conflict = find_conflict(
                    context,
                    cached_fields_and_fragment_names,
                    compared_fragment_pairs,
                    parent_fields_are_mutually_exclusive,
                    response_name,
                    field1,
                    field2,
                )
                if conflict:
                    conflicts.append(conflict)
495
496
def find_conflict(
    context: ValidationContext,
    cached_fields_and_fragment_names: Dict,
    compared_fragment_pairs: "PairSet",
    parent_fields_are_mutually_exclusive: bool,
    response_name: str,
    field1: NodeAndDef,
    field2: NodeAndDef,
) -> Optional[Conflict]:
    """Find conflict.

    Decide whether two fields sharing ``response_name`` conflict, recursing into
    their sub-selections when both have one. Return the conflict, or None.
    """
    parent_type1, node1, def1 = field1
    parent_type2, node2, def2 = field2

    # Fields that can never apply at the same time may differ in field name and
    # arguments without ambiguity. That is only known for certain when the parents
    # are two different Object types; Interface or Union parents might overlap,
    # now or in a future schema version.
    are_mutually_exclusive = parent_fields_are_mutually_exclusive
    if not are_mutually_exclusive:
        are_mutually_exclusive = (
            parent_type1 != parent_type2
            and is_object_type(parent_type1)
            and is_object_type(parent_type2)
        )

    # Return type of each field, or a falsy value when its definition is unknown.
    type1 = cast(Optional[GraphQLOutputType], def1 and def1.type)
    type2 = cast(Optional[GraphQLOutputType], def2 and def2.type)

    if not are_mutually_exclusive:
        name1, name2 = node1.name.value, node2.name.value
        # Both aliases must point at the same underlying field ...
        if name1 != name2:
            return (
                (response_name, f"'{name1}' and '{name2}' are different fields"),
                [node1],
                [node2],
            )
        # ... called with the same arguments.
        if not same_arguments(node1.arguments or [], node2.arguments or []):
            return (response_name, "they have differing arguments"), [node1], [node2]

    if type1 and type2 and do_types_conflict(type1, type2):
        message = f"they return conflicting types '{type1}' and '{type2}'"
        return (response_name, message), [node1], [node2]

    selection_set1 = node1.selection_set
    selection_set2 = node2.selection_set
    if not (selection_set1 and selection_set2):
        return None  # no conflict

    # Both fields have sub-selections: compare them and fold any sub-conflicts
    # into a single conflict for this response name.
    sub_conflicts = find_conflicts_between_sub_selection_sets(
        context,
        cached_fields_and_fragment_names,
        compared_fragment_pairs,
        are_mutually_exclusive,
        get_named_type(type1),
        selection_set1,
        get_named_type(type2),
        selection_set2,
    )
    return subfield_conflicts(sub_conflicts, response_name, node1, node2)
571
572
def same_arguments(
    arguments1: Collection[ArgumentNode], arguments2: Collection[ArgumentNode]
) -> bool:
    """Tell whether two argument lists are equivalent.

    Both lists must have the same length, and every argument of the first must
    have a same-named counterpart in the second (the first one found) whose value
    is equal according to ``same_value``.
    """
    if len(arguments1) != len(arguments2):
        return False
    for argument1 in arguments1:
        name = argument1.name.value
        counterpart = next(
            (argument2 for argument2 in arguments2 if argument2.name.value == name),
            None,
        )
        if counterpart is None or not same_value(argument1.value, counterpart.value):
            return False
    return True
587
588
def same_value(value1: ValueNode, value2: ValueNode) -> bool:
    """Tell whether two value nodes print to identical GraphQL source text."""
    printed1 = print_ast(value1)
    printed2 = print_ast(value2)
    return printed1 == printed2
591
592
def do_types_conflict(type1: GraphQLOutputType, type2: GraphQLOutputType) -> bool:
    """Check whether two types conflict

    Two types conflict if both types could not apply to a value simultaneously.
    Composite types are ignored as their individual field types will be compared later
    recursively. However List and Non-Null types must match.
    """
    # Peel matching List / Non-Null wrappers layer by layer; any mismatch in
    # wrapping is a conflict.
    while True:
        if is_list_type(type1):
            if not is_list_type(type2):
                return True
            type1 = cast(GraphQLList, type1).of_type
            type2 = cast(GraphQLList, type2).of_type
            continue
        if is_list_type(type2):
            return True
        if is_non_null_type(type1):
            if not is_non_null_type(type2):
                return True
            type1 = cast(GraphQLNonNull, type1).of_type
            type2 = cast(GraphQLNonNull, type2).of_type
            continue
        if is_non_null_type(type2):
            return True
        # Unwrapped: leaf types must be the very same type object.
        if is_leaf_type(type1) or is_leaf_type(type2):
            return type1 is not type2
        return False
623
624
def get_fields_and_fragment_names(
    context: ValidationContext,
    cached_fields_and_fragment_names: Dict,
    parent_type: Optional[GraphQLNamedType],
    selection_set: SelectionSetNode,
) -> Tuple[NodeAndDefCollection, List[str]]:
    """Get fields and referenced fragment names

    Given a selection set, return the collection of fields (a mapping of response name
    to field nodes and definitions) as well as a list of fragment names referenced via
    fragment spreads. Results are memoized per selection set in
    ``cached_fields_and_fragment_names``.
    """
    cached = cached_fields_and_fragment_names.get(selection_set)
    if cached:
        return cached

    node_and_defs: NodeAndDefCollection = {}
    # A dict rather than a set: keys stay unique while keeping the order in which
    # the spreads were first encountered.
    fragment_names: Dict[str, bool] = {}
    collect_fields_and_fragment_names(
        context, parent_type, selection_set, node_and_defs, fragment_names
    )
    result = (node_and_defs, list(fragment_names))
    cached_fields_and_fragment_names[selection_set] = result
    return result
647
648
def get_referenced_fields_and_fragment_names(
    context: ValidationContext,
    cached_fields_and_fragment_names: Dict,
    fragment: FragmentDefinitionNode,
) -> Tuple[NodeAndDefCollection, List[str]]:
    """Get referenced fields and nested fragment names

    Given a fragment definition, return its collection of fields and the names of
    the fragments spread in its selection set, using the shared per-selection-set
    cache.
    """
    selection_set = fragment.selection_set
    cached = cached_fields_and_fragment_names.get(selection_set)
    if not cached:
        # Only resolve the type condition on a cache miss.
        fragment_type = type_from_ast(context.schema, fragment.type_condition)
        cached = get_fields_and_fragment_names(
            context, cached_fields_and_fragment_names, fragment_type, selection_set
        )
    return cached
668
669
def collect_fields_and_fragment_names(
    context: ValidationContext,
    parent_type: Optional[GraphQLNamedType],
    selection_set: SelectionSetNode,
    node_and_defs: NodeAndDefCollection,
    fragment_names: Dict[str, bool],
) -> None:
    """Gather the fields and spread fragment names of a selection set.

    Fields are appended to ``node_and_defs`` under their response name (alias if
    present, else field name). Names of spread fragments are recorded as keys of
    ``fragment_names``. Inline fragments are flattened into the same collections,
    using their type condition (if any) as the parent type. Both mappings are
    mutated in place.
    """
    for selection in selection_set.selections:
        if isinstance(selection, FieldNode):
            field_name = selection.name.value
            # Only Object and Interface types expose field definitions here.
            if is_object_type(parent_type) or is_interface_type(parent_type):
                field_def = parent_type.fields.get(field_name)  # type: ignore
            else:
                field_def = None
            response_name = selection.alias.value if selection.alias else field_name
            node_and_defs.setdefault(response_name, []).append(
                cast(NodeAndDef, (parent_type, selection, field_def))
            )
        elif isinstance(selection, FragmentSpreadNode):
            fragment_names[selection.name.value] = True
        elif isinstance(selection, InlineFragmentNode):  # pragma: no cover else
            type_condition = selection.type_condition
            if type_condition:
                inline_fragment_type = type_from_ast(context.schema, type_condition)
            else:
                inline_fragment_type = parent_type
            collect_fields_and_fragment_names(
                context,
                inline_fragment_type,
                selection.selection_set,
                node_and_defs,
                fragment_names,
            )
707
708
def subfield_conflicts(
    conflicts: List[Conflict], response_name: str, node1: FieldNode, node2: FieldNode
) -> Optional[Conflict]:
    """Check whether there are conflicts between sub-fields.

    Fold the conflicts found between two fields' sub-selections into a single
    Conflict for ``response_name``: its reason lists every sub-reason, and each
    side lists the parent node followed by all sub-conflict nodes of that side.
    Return None when there are no sub-conflicts.
    """
    if not conflicts:
        return None  # no conflict
    reasons, sub_fields1, sub_fields2 = zip(*conflicts)
    return (
        (response_name, list(reasons)),
        [node1, *chain.from_iterable(sub_fields1)],
        [node2, *chain.from_iterable(sub_fields2)],
    )
724
725
class PairSet:
    """Pair set

    A way to keep track of pairs of things when the ordering of the pair does not
    matter. Each pair is stored under both orderings (a double adjacency map), and
    carries the mutual-exclusivity flag it was recorded with.
    """

    __slots__ = ("_data",)

    def __init__(self) -> None:
        # _data[a][b] is the exclusivity flag stored for the pair (a, b).
        self._data: Dict[str, Dict[str, bool]] = {}

    def has(self, a: str, b: str, are_mutually_exclusive: bool) -> bool:
        """Tell whether the pair was recorded in a way that covers this query."""
        try:
            stored_exclusive = self._data[a][b]
        except KeyError:
            return False
        # A pair recorded as *not* mutually exclusive covers both kinds of query.
        # A pair recorded as mutually exclusive only covers exclusive queries.
        return True if are_mutually_exclusive else not stored_exclusive

    def add(self, a: str, b: str, are_mutually_exclusive: bool) -> "PairSet":
        """Record the pair in both orderings and return self for chaining."""
        for first, second in ((a, b), (b, a)):
            self._pair_set_add(first, second, are_mutually_exclusive)
        return self

    def _pair_set_add(self, a: str, b: str, are_mutually_exclusive: bool) -> None:
        self._data.setdefault(a, {})[b] = are_mutually_exclusive
760