1import warnings
2from django.core.exceptions import FieldError
3from django.db.models import Lookup, Transform, IntegerField
4from django.db.models.lookups import EndsWith, IEndsWith, StartsWith, IStartsWith, Regex, IRegex
5import ipaddress
6from netfields.fields import InetAddressField, CidrAddressField
7
8
class InvalidLookup(Lookup):
    """
    Lookup that always refuses to compile.

    Mimics the error Django 1.9 raises for lookups a field does not
    support: compiling it raises ``FieldError`` naming the lookup.
    """

    def as_sql(self, qn, connection):
        message = "Unsupported lookup '%s'" % self.lookup_name
        raise FieldError(message)
15
16
class InvalidSearchLookup(Lookup):
    """
    ``search`` lookup that always refuses to compile.

    Mimics Django 1.9's behaviour for backends without full-text search:
    compiling it raises ``NotImplementedError``.
    """
    lookup_name = 'search'

    def as_sql(self, qn, connection):
        message = "Full-text search is not implemented for this database backend"
        raise NotImplementedError(message)
25
26
class NetFieldDecoratorMixin(object):
    """
    Mixin that adapts Django's built-in text lookups to network columns.

    Combined with the pattern lookups (``startswith``, ``iendswith``,
    ``regex``, ...) so the comparison is made against a textual rendering
    of inet/cidr values rather than the raw column.
    """

    def process_lhs(self, qn, connection, lhs=None):
        """
        Compile the left-hand side of the lookup.

        ``InetAddressField`` columns are wrapped in ``HOST()`` and
        ``CidrAddressField`` columns in ``TEXT()``. The backend's
        ``lookup_cast`` for this lookup is then applied, exactly as Django's
        own ``BuiltinLookup.process_lhs`` does. This keeps the left-hand
        side consistent with the backend operator used for the right-hand
        side, e.g. the case folding of ``istartswith``/``iendswith``.

        :param qn: the SQL compiler used to compile ``lhs``.
        :param connection: the database connection.
        :param lhs: expression to compile; defaults to ``self.lhs``.
        :returns: ``(sql, params)`` for the left-hand side.
        """
        lhs = lhs or self.lhs
        lhs_string, lhs_params = qn.compile(lhs)
        # Some expressions expose the model field as ``source``; otherwise
        # fall back to ``output_field``. Resolve it once for both checks.
        field = lhs.source if hasattr(lhs, 'source') else lhs.output_field
        if isinstance(field, InetAddressField):
            lhs_string = 'HOST(%s)' % lhs_string
        elif isinstance(field, CidrAddressField):
            lhs_string = 'TEXT(%s)' % lhs_string
        # The backend operators for case-insensitive pattern lookups
        # (PostgreSQL: ``LIKE UPPER(%s)``) fold only the right-hand side.
        # Without the matching cast here, lower-case IPv6 hex digits never
        # matched ``istartswith``/``iendswith``.
        lhs_string = connection.ops.lookup_cast(self.lookup_name) % lhs_string
        return lhs_string, lhs_params
36
37
class EndsWith(NetFieldDecoratorMixin, EndsWith):
    """Django's ``endswith`` with the lhs rewritten by ``NetFieldDecoratorMixin``."""
40
41
class IEndsWith(NetFieldDecoratorMixin, IEndsWith):
    """Django's ``iendswith`` with the lhs rewritten by ``NetFieldDecoratorMixin``."""
44
45
class StartsWith(NetFieldDecoratorMixin, StartsWith):
    """Django's ``startswith`` with the lhs rewritten by ``NetFieldDecoratorMixin``."""
48
49
class IStartsWith(NetFieldDecoratorMixin, IStartsWith):
    """Django's ``istartswith`` with the lhs rewritten by ``NetFieldDecoratorMixin``."""
52
53
class Regex(NetFieldDecoratorMixin, Regex):
    """Django's ``regex`` with the lhs rewritten by ``NetFieldDecoratorMixin``."""
56
57
class IRegex(NetFieldDecoratorMixin, IRegex):
    """Django's ``iregex`` with the lhs rewritten by ``NetFieldDecoratorMixin``."""
60
61
class NetworkLookup(object):
    """
    Mixin for lookups whose right-hand side is a network.

    Normalises the right-hand side to the canonical string form of a
    network before it is sent to the database.
    """

    def get_prep_lookup(self):
        """
        Return the prepared right-hand side.

        * Query expressions (anything with ``resolve_expression``) are
          passed through untouched.
        * ``IPv4Network``/``IPv6Network`` objects are stringified.
        * Anything else is parsed with :func:`ipaddress.ip_network`. That
          call raises ``ValueError`` for invalid input, including values
          with host bits set.
        """
        if hasattr(self.rhs, 'resolve_expression'):
            return self.rhs
        # Check the public network classes rather than the private
        # ``ipaddress._BaseNetwork`` base class they share.
        if isinstance(self.rhs, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
            return str(self.rhs)
        return str(ipaddress.ip_network(self.rhs))
69
70
class AddressLookup(object):
    """
    Mixin for lookups whose right-hand side is a host address.

    Normalises the right-hand side to its canonical string form before it
    is sent to the database.
    """

    def get_prep_lookup(self):
        """
        Return the prepared right-hand side.

        * Query expressions (anything with ``resolve_expression``) are
          passed through untouched.
        * Address objects, including ``IPv4Interface``/``IPv6Interface``
          (which subclass the address classes), are stringified.
        * Anything else is parsed with :func:`ipaddress.ip_interface`.
          That call raises ``ValueError`` for invalid input.
        """
        if hasattr(self.rhs, 'resolve_expression'):
            return self.rhs
        # Check the public address classes rather than the private
        # ``ipaddress._BaseAddress`` base class they share.
        if isinstance(self.rhs, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
            return str(self.rhs)
        return str(ipaddress.ip_interface(self.rhs))
78
79
class NetContains(AddressLookup, Lookup):
    """
    ``net_contains`` lookup, compiled to the PostgreSQL ``>>`` operator
    (column value strictly contains the given address).
    """
    lookup_name = 'net_contains'

    def as_sql(self, qn, connection):
        left_sql, left_params = self.process_lhs(qn, connection)
        right_sql, right_params = self.process_rhs(qn, connection)
        return '%s >> %s' % (left_sql, right_sql), left_params + right_params
88
89
class NetContained(NetworkLookup, Lookup):
    """
    ``net_contained`` lookup, compiled to the PostgreSQL ``<<`` operator
    (column value is strictly contained by the given network).
    """
    lookup_name = 'net_contained'

    def as_sql(self, qn, connection):
        left_sql, left_params = self.process_lhs(qn, connection)
        right_sql, right_params = self.process_rhs(qn, connection)
        return '%s << %s' % (left_sql, right_sql), left_params + right_params
98
99
class NetContainsOrEquals(AddressLookup, Lookup):
    """
    ``net_contains_or_equals`` lookup, compiled to the PostgreSQL ``>>=``
    operator (column value contains or equals the given address).
    """
    lookup_name = 'net_contains_or_equals'

    def as_sql(self, qn, connection):
        left_sql, left_params = self.process_lhs(qn, connection)
        right_sql, right_params = self.process_rhs(qn, connection)
        return '%s >>= %s' % (left_sql, right_sql), left_params + right_params
108
109
class NetContainedOrEqual(NetworkLookup, Lookup):
    """
    ``net_contained_or_equal`` lookup, compiled to the PostgreSQL ``<<=``
    operator (column value is contained by or equals the given network).
    """
    lookup_name = 'net_contained_or_equal'

    def as_sql(self, qn, connection):
        left_sql, left_params = self.process_lhs(qn, connection)
        right_sql, right_params = self.process_rhs(qn, connection)
        return '%s <<= %s' % (left_sql, right_sql), left_params + right_params
118
119
class NetOverlaps(NetworkLookup, Lookup):
    """
    ``net_overlaps`` lookup, compiled to the PostgreSQL ``&&`` operator
    (either network contains or is contained by the other).
    """
    lookup_name = 'net_overlaps'

    def as_sql(self, qn, connection):
        left_sql, left_params = self.process_lhs(qn, connection)
        right_sql, right_params = self.process_rhs(qn, connection)
        return '%s && %s' % (left_sql, right_sql), left_params + right_params
128
129
class HostMatches(AddressLookup, Lookup):
    """
    ``host`` lookup: compares ``HOST()`` of the column with ``HOST()`` of
    the given address for equality, ignoring the netmask on both sides.
    """
    lookup_name = 'host'

    def as_sql(self, qn, connection):
        left_sql, left_params = self.process_lhs(qn, connection)
        right_sql, right_params = self.process_rhs(qn, connection)
        sql = 'HOST(%s) = HOST(%s)' % (left_sql, right_sql)
        return sql, left_params + right_params
138
139
class Family(Transform):
    """
    ``family`` transform: wraps the value in the SQL ``family()`` function
    and exposes the result as an integer.
    """
    lookup_name = 'family'

    def as_sql(self, compiler, connection):
        inner_sql, inner_params = compiler.compile(self.lhs)
        sql = "family(%s)" % (inner_sql,)
        return sql, inner_params

    @property
    def output_field(self):
        # Further lookups chained on this transform treat it as an integer.
        return IntegerField()
150
151
class _PrefixlenMixin(object):
    """
    Shared implementation of the legacy ``min_prefixlen``/``max_prefixlen``
    lookups, which compare ``MASKLEN(column)`` against an integer.

    Subclasses must set ``format_string`` to a two-slot ``%`` template that
    receives the left-hand SQL and the right-hand SQL.
    """
    # Two-slot SQL template such as '%s <= %s'; subclasses must override it.
    format_string = None

    def as_sql(self, qn, connection):
        """
        Compile the lookup to SQL, emitting a ``DeprecationWarning``.

        :returns: ``(sql, params)``.
        :raises NotImplementedError: if the subclass did not set
            ``format_string``.
        """
        warnings.warn(
            'min_prefixlen and max_prefixlen will be deprecated in the future; '
            'use prefixlen__gte and prefixlen__lte respectively',
            DeprecationWarning
        )
        # Explicit check instead of ``assert`` so the guard survives ``-O``.
        if self.format_string is None:
            raise NotImplementedError("Prefixlen lookups must specify a format_string")
        lhs, lhs_params = self.process_lhs(qn, connection)
        rhs, rhs_params = self.process_rhs(qn, connection)
        params = lhs_params + rhs_params
        return self.format_string % (lhs, rhs), params

    def process_lhs(self, qn, connection, lhs=None):
        """Compile ``lhs`` (default ``self.lhs``) wrapped in ``MASKLEN()``."""
        lhs = lhs or self.lhs
        lhs_string, lhs_params = qn.compile(lhs)
        return 'MASKLEN(%s)' % lhs_string, lhs_params

    def get_prep_lookup(self):
        """Coerce the right-hand side with ``int()`` and return it as a string."""
        return str(int(self.rhs))
175
176
class MaxPrefixlen(_PrefixlenMixin, Lookup):
    """Legacy ``max_prefixlen`` lookup: ``MASKLEN(column) <= value``."""
    format_string = '%s <= %s'
    lookup_name = 'max_prefixlen'
180
181
class MinPrefixlen(_PrefixlenMixin, Lookup):
    """Legacy ``min_prefixlen`` lookup: ``MASKLEN(column) >= value``."""
    format_string = '%s >= %s'
    lookup_name = 'min_prefixlen'
185
186
class Prefixlen(Transform):
    """
    ``prefixlen`` transform: wraps the value in the SQL ``masklen()``
    function and exposes the result as an integer, so that e.g.
    ``prefixlen__gte`` can be chained on it.
    """
    lookup_name = 'prefixlen'

    def as_sql(self, compiler, connection):
        inner_sql, inner_params = compiler.compile(self.lhs)
        sql = "masklen(%s)" % (inner_sql,)
        return sql, inner_params

    @property
    def output_field(self):
        # Further lookups chained on this transform treat it as an integer.
        return IntegerField()
197