from itertools import product

from funcy import group_by, join_with, lcat, lmap

from django.db.models import Subquery
from django.db.models.query import QuerySet
from django.db.models.sql import AND, OR
from django.db.models.sql.query import Query, ExtraWhere
from django.db.models.sql.where import NothingNode, SubqueryConstraint
from django.db.models.lookups import Lookup, Exact, In, IsNull
from django.db.models.expressions import BaseExpression, Exists

from .conf import settings
from .invalidation import serializable_fields
14
15
def dnfs(qs):
    """
    Converts a queryset's condition tree into a DNF of equality conditions,
    grouped per table.

    Only ``__exact``, ``__in`` (below CACHEOPS_LONG_DISJUNCTION values) and
    ``__isnull=True`` lookups on serializable fields are kept.
    ``__in`` is converted into ``= or = or = ...``.
    Any other condition (other lookups, negated terms, extra where, subquery
    constraints, Exists, non-field left-hand sides, complex right-hand sides)
    is dropped from its conjunction, which can only widen the result.

    :param qs: QuerySet to analyse. Its combined queries (union etc.) and
        any ``Subquery`` annotations are taken into account too.
    :returns: dict mapping a table name to a list of conjunctions. Each
        conjunction is a dict ``{attname: value}`` of equality conditions,
        and the list is their OR. ``[{}]`` means any row of that table
        matches; ``[]`` means the conditions can never hold.
    """
    SOME = object()
    # An opaque "some condition" term: it survives negation and AND/OR
    # expansion, and is stripped in clean_conj (i.e. treated as always true).
    SOME_TREE = [[(None, None, SOME, True)]]

    def negate(term):
        return (term[0], term[1], term[2], not term[3])

    def _dnf(where):
        """
        Constructs DNF of where tree consisting of terms in form:
            (alias, attribute, value, negation)
        meaning `alias.attribute = value`
         or `not alias.attribute = value` if negation is False
        (i.e. the last element is True for a positive term).

        Any conditions other than eq are replaced with opaque SOME terms.
        """
        if isinstance(where, Lookup):
            # If where.lhs doesn't refer to a field then don't bother
            if not hasattr(where.lhs, 'target'):
                return SOME_TREE
            # Don't bother with complex right hand side either
            if isinstance(where.rhs, (QuerySet, Query, BaseExpression)):
                return SOME_TREE
            # Skip conditions on non-serialized fields
            if where.lhs.target not in serializable_fields(where.lhs.target.model):
                return SOME_TREE

            attname = where.lhs.target.attname
            if isinstance(where, Exact):
                return [[(where.lhs.alias, attname, where.rhs, True)]]
            elif isinstance(where, IsNull):
                # isnull=True -> positive "attname = None"; isnull=False -> negated term
                return [[(where.lhs.alias, attname, None, where.rhs)]]
            elif isinstance(where, In) and len(where.rhs) < settings.CACHEOPS_LONG_DISJUNCTION:
                return [[(where.lhs.alias, attname, v, True)] for v in where.rhs]
            else:
                return SOME_TREE
        elif isinstance(where, NothingNode):
            # Always false: an empty disjunction
            return []
        elif isinstance(where, (ExtraWhere, SubqueryConstraint, Exists)):
            return SOME_TREE
        elif len(where) == 0:
            # No children: always true, an empty conjunction
            return [[]]
        else:
            children_dnfs = lmap(_dnf, where.children)

            if len(children_dnfs) == 1:
                result = children_dnfs[0]
            elif where.connector == OR:
                # Just unite children joined with OR
                result = lcat(children_dnfs)
            elif where.connector == AND:
                # Use Cartesian product to AND children
                result = lmap(lcat, product(*children_dnfs))
            else:
                # Other connectors (e.g. XOR, Django 4.1+) are not expanded:
                # treating XOR as AND would be narrower than the real condition
                # and miss invalidations. An opaque SOME term stays safe even
                # if an ancestor node is negated.
                result = SOME_TREE

            # Negating and expanding brackets
            if where.negated:
                result = [lmap(negate, p) for p in product(*result)]

            return result

    def clean_conj(conj, for_alias):
        conds = {}
        for alias, attname, value, negation in conj:
            # "SOME" conds, negated conds and conds for other aliases should be stripped
            if value is not SOME and negation and alias == for_alias:
                # Conjs with fields eq 2 different values will never cause invalidation
                if attname in conds and conds[attname] != value:
                    return None
                conds[attname] = value
        return conds

    def clean_dnf(tree, aliases):
        cleaned = [clean_conj(conj, alias) for conj in tree for alias in aliases]
        # Remove deleted conjunctions
        cleaned = [conj for conj in cleaned if conj is not None]
        # Any empty conjunction eats up the rest
        # NOTE: a more elaborate DNF reduction is not really needed,
        #       just keep your querysets sane.
        if not all(cleaned):
            return [{}]
        return cleaned

    def query_dnf(query):
        def table_for(alias):
            if alias == main_alias:
                return alias
            return query.alias_map[alias].table_name

        dnf = _dnf(query.where)

        # NOTE: we exclude content_type as it never changes and will hold dead invalidation info
        # NOTE(review): `-` binds tighter than `|`, so only `{main_alias}` is
        #   filtered below; joined `django_content_type` aliases are still kept.
        #   Confirm whether `(... | {main_alias}) - {...}` was intended.
        main_alias = query.model._meta.db_table
        aliases = {alias for alias, join in query.alias_map.items()
                   if query.alias_refcount[alias]} \
                | {main_alias} - {'django_content_type'}
        tables = group_by(table_for, aliases)
        return {table: clean_dnf(dnf, table_aliases) for table, table_aliases in tables.items()}

    if qs.query.combined_queries:
        dnfs_ = join_with(lcat, (query_dnf(q) for q in qs.query.combined_queries))
    else:
        dnfs_ = query_dnf(qs.query)

    # Add any subqueries used for annotation. Their DNFs are OR-ed with the
    # main ones per table; overwriting (dict.update) would drop the main
    # query's conditions on any table both reference and miss invalidations.
    if qs.query.annotations:
        subqueries = [
            # Django 3.0+ sets Subquery.query, older versions only Subquery.queryset
            query_dnf(getattr(q, 'query', None) or q.queryset.query)
            for q in qs.query.annotations.values()
            if type(q) is Subquery
        ]
        dnfs_ = join_with(lcat, [dnfs_] + subqueries)

    return dnfs_
141