1#-----------------------------------------------------------------------------
2#   Copyright (c) 2008 by David P. D. Moss. All rights reserved.
3#
4#   Released under the BSD license. See the LICENSE file for details.
5#-----------------------------------------------------------------------------
6"""Set based operations for IP addresses and subnets."""
7
8import itertools as _itertools
9
10from netaddr.ip import (IPNetwork, IPAddress, IPRange, cidr_merge,
11    cidr_exclude, iprange_to_cidrs)
12
13from netaddr.compat import _sys_maxint, _dict_keys, _int_type
14
15
def _subtract(supernet, subnets, subnet_idx, ranges):
    """Calculate IPSet([supernet]) - IPSet(subnets) for one supernet.

    Only the consecutive run of ``subnets`` starting at ``subnet_idx`` that
    lies inside ``supernet`` is consumed.

    Assumptions: ``subnets`` is sorted and ``subnets[subnet_idx]`` is a subnet
    of ``supernet``.

    Every address range of ``supernet`` not covered by that run is appended
    to ``ranges`` as an integer tuple ``(version, first, last)``, in
    ascending order.

    :return: the index of the first entry after ``subnet_idx`` that is not
        contained in ``supernet``, or ``len(subnets)`` if all subsequent
        entries are contained in it.
    """
    version = supernet._module.version

    # Gap between the start of the supernet and the first subnet.
    head = subnets[subnet_idx]
    if head.first > supernet.first:
        ranges.append((version, supernet.first, head.first - 1))

    previous = head
    idx = subnet_idx + 1
    total = len(subnets)
    while idx < total and subnets[idx] in supernet:
        current = subnets[idx]
        gap_start = previous.last + 1
        # Directly adjacent subnets (e.g. two non-mergeable IPNetworks that
        # touch) leave no gap between them.
        if gap_start != current.first:
            ranges.append((version, gap_start, current.first - 1))
        previous = current
        idx += 1

    # Gap between the last consumed subnet and the end of the supernet.
    tail_start = previous.last + 1
    if tail_start <= supernet.last:
        ranges.append((version, tail_start, supernet.last))

    return idx
54
55
def _iter_merged_ranges(sorted_ranges):
    """Yield IPAddress bounds for sorted_ranges, merging adjacent entries.

    ``sorted_ranges`` must be a sorted, indexable sequence of
    ``(version, first, last)`` integer tuples. Consecutive entries of the
    same version where one starts immediately after the previous one ends
    are combined. For example ``[(4, 10, 42), (4, 43, 100)]`` produces the
    single pair ``(IPAddress(10, 4), IPAddress(100, 4))``. That pair is
    suitable input for ``iprange_to_cidrs``.

    Yields nothing for an empty sequence.
    """
    if not sorted_ranges:
        return

    version, start, stop = sorted_ranges[0]
    for nxt_version, nxt_start, nxt_stop in sorted_ranges[1:]:
        adjacent = nxt_version == version and nxt_start == stop + 1
        if adjacent:
            # Extend the pending range instead of emitting it.
            stop = nxt_stop
        else:
            yield (IPAddress(start, version), IPAddress(stop, version))
            version, start, stop = nxt_version, nxt_start, nxt_stop

    # Flush the final pending range.
    yield (IPAddress(start, version), IPAddress(stop, version))
83
84
class IPSet(object):
    """
    Represents an unordered collection (set) of unique IP addresses and
    subnets.

    Members are stored as the keys of the ``_cidrs`` dict (the values are
    always ``True``), one ``IPNetwork`` per CIDR block. The mutating methods
    keep these blocks merged into their most concise form. ``__eq__``
    relies on this when it compares two sets by their ``_cidrs`` dicts.
    """
    __slots__ = ('_cidrs', '__weakref__')

    def __init__(self, iterable=None, flags=0):
        """
        Constructor.

        :param iterable: (optional) an iterable containing IP addresses,
            subnets or ranges. A single ``IPNetwork``, ``IPRange`` or
            ``IPSet`` object is also accepted.

        :param flags: decides which rules are applied to the interpretation
            of the addr value. See the netaddr.core namespace documentation
            for supported constant values. Only applied to integer members
            of a generic iterable.

        """
        if isinstance(iterable, IPNetwork):
            # .cidr normalizes networks like 10.1.2.3/8 to 10.0.0.0/8.
            self._cidrs = {iterable.cidr: True}
        elif isinstance(iterable, IPRange):
            self._cidrs = dict.fromkeys(
                iprange_to_cidrs(iterable[0], iterable[-1]), True)
        elif isinstance(iterable, IPSet):
            self._cidrs = dict.fromkeys(iterable.iter_cidrs(), True)
        else:
            self._cidrs = {}
            if iterable is not None:
                mergeable = []
                for addr in iterable:
                    if isinstance(addr, _int_type):
                        addr = IPAddress(addr, flags=flags)
                    mergeable.append(addr)

                for cidr in cidr_merge(mergeable):
                    self._cidrs[cidr] = True

    def __getstate__(self):
        """:return: Pickled state of an ``IPSet`` object."""
        # One IPNetwork state tuple per CIDR, in (value, prefixlen, version)
        # order as consumed by __setstate__ below.
        return tuple([cidr.__getstate__() for cidr in self._cidrs])

    def __setstate__(self, state):
        """
        :param state: data used to unpickle a pickled ``IPSet`` object.

        """
        self._cidrs = dict.fromkeys(
            (IPNetwork((value, prefixlen), version=version)
             for value, prefixlen, version in state),
            True)

    def _compact_single_network(self, added_network):
        """
        Same as compact(), but assume that added_network is the only change and
        that this IPSet was properly compacted before added_network was added.
        This allows to perform compaction much faster. added_network must
        already be present in self._cidrs.

        :param added_network: the ``IPNetwork`` that was just inserted into
            ``self._cidrs``. When it is merged with a neighbour, this object
            is modified in place (its prefixlen and value change), so it
            should not be an object shared with other sets.
        """
        added_first = added_network.first
        added_last = added_network.last
        added_version = added_network.version

        # Check for supernets and subnets of added_network.
        if added_network._prefixlen == added_network._module.width:
            # This is a single IP address, i.e. /32 for IPv4 or /128 for IPv6.
            # It does not have any subnets, so we only need to check for its
            # potential supernets.
            for potential_supernet in added_network.supernet():
                if potential_supernet in self._cidrs:
                    del self._cidrs[added_network]
                    return
        else:
            # IPNetworks from self._cidrs that are subnets of added_network.
            to_remove = []
            for cidr in self._cidrs:
                if (cidr._module.version != added_version or cidr == added_network):
                    # We found added_network or some network of a different version.
                    continue
                first = cidr.first
                last = cidr.last
                if first >= added_first and last <= added_last:
                    # cidr is a subnet of added_network. Remember to remove it.
                    to_remove.append(cidr)
                elif first <= added_first and last >= added_last:
                    # cidr is a supernet of added_network. Remove added_network.
                    del self._cidrs[added_network]
                    # This IPSet was properly compacted before. Since added_network
                    # is removed now, it must again be properly compacted -> done.
                    assert (not to_remove)
                    return
            for item in to_remove:
                del self._cidrs[item]

        # Check if added_network can be merged with another network.

        # Note that merging can only happen between networks of the same
        # prefixlen. This just leaves 2 candidates: The IPNetworks just before
        # and just after the added_network.
        # This can be reduced to 1 candidate: 10.0.0.0/24 and 10.0.1.0/24 can
        # be merged into 10.0.0.0/23. But 10.0.1.0/24 and 10.0.2.0/24
        # cannot be merged. With only 1 candidate, we might as well make a
        # dictionary lookup.
        shift_width = added_network._module.width - added_network.prefixlen
        while added_network.prefixlen != 0:
            # figure out if the least significant bit of the network part is 0 or 1.
            the_bit = (added_network._value >> shift_width) & 1
            if the_bit:
                candidate = added_network.previous()
            else:
                candidate = added_network.next()

            if candidate not in self._cidrs:
                # The only possible merge does not work -> merge done
                return
            # Remove added_network&candidate, add merged network.
            del self._cidrs[candidate]
            del self._cidrs[added_network]
            added_network.prefixlen -= 1
            # Be sure that we set the host bits to 0 when we move the prefixlen.
            # Otherwise, adding 255.255.255.255/32 will result in a merged
            # 255.255.255.255/24 network, but we want 255.255.255.0/24.
            shift_width += 1
            added_network._value = (added_network._value >> shift_width) << shift_width
            self._cidrs[added_network] = True

    def compact(self):
        """
        Compact internal list of `IPNetwork` objects using a CIDR merge.
        """
        cidrs = cidr_merge(self._cidrs)
        self._cidrs = dict.fromkeys(cidrs, True)

    def __hash__(self):
        """
        Raises ``TypeError`` if this method is called.

        .. note:: IPSet objects are not hashable and cannot be used as \
            dictionary keys or as members of other sets. \
        """
        raise TypeError('IP sets are unhashable!')

    def __contains__(self, ip):
        """
        :param ip: An IP address or subnet.

        :return: ``True`` if IP address or subnet is a member of this IP set.
        """
        # Iterating over self._cidrs is an O(n) operation: 1000 items in
        # self._cidrs would mean 1000 loops. Iterating over all possible
        # supernets loops at most 32 times for IPv4 or 128 times for IPv6,
        # no matter how many CIDRs this object contains.
        supernet = IPNetwork(ip)
        if supernet in self._cidrs:
            return True
        # supernet is a private copy, so widening it in place is safe.
        while supernet._prefixlen:
            supernet._prefixlen -= 1
            if supernet in self._cidrs:
                return True
        return False

    def __nonzero__(self):
        """Return True if IPSet contains at least one IP, else False"""
        return bool(self._cidrs)

    __bool__ = __nonzero__  #   Python 3.x.

    def __iter__(self):
        """
        :return: an iterator over the IP addresses within this IP set, in
            ascending CIDR order.
        """
        return _itertools.chain(*sorted(self._cidrs))

    def iter_cidrs(self):
        """
        :return: a sorted list of the individual IP subnets (``IPNetwork``
            objects) within this IP set.
        """
        return sorted(self._cidrs)

    def add(self, addr, flags=0):
        """
        Adds an IP address or subnet or IPRange to this IP set. Has no effect if
        it is already present.

        Note that where possible the IP address or subnet is merged with other
        members of the set to form more concise CIDR blocks.

        :param addr: An IP address or subnet in either string or object form, or
            an IPRange object.

        :param flags: decides which rules are applied to the interpretation
            of the addr value. See the netaddr.core namespace documentation
            for supported constant values. Only applied to integer values.

        """
        if isinstance(addr, IPRange):
            new_cidrs = dict.fromkeys(
                iprange_to_cidrs(addr[0], addr[-1]), True)
            self._cidrs.update(new_cidrs)
            self.compact()
            return
        # Every branch below produces a fresh IPNetwork object. This matters
        # because _compact_single_network() may modify it in place.
        if isinstance(addr, IPNetwork):
            # Networks like 10.1.2.3/8 need to be normalized to 10.0.0.0/8
            addr = addr.cidr
        elif isinstance(addr, _int_type):
            addr = IPNetwork(IPAddress(addr, flags=flags))
        else:
            addr = IPNetwork(addr)

        self._cidrs[addr] = True
        self._compact_single_network(addr)

    def remove(self, addr, flags=0):
        """
        Removes an IP address or subnet or IPRange from this IP set. Does
        nothing if it is not already a member.

        Note that this method behaves more like discard() found in regular
        Python sets because it doesn't raise KeyError exceptions if the
        IP address or subnet in question does not exist. It doesn't make sense
        to fully emulate that behaviour here as IP sets contain groups of
        individual IP addresses as individual set members using IPNetwork
        objects.

        :param addr: An IP address or subnet, or an IPRange.

        :param flags: decides which rules are applied to the interpretation
            of the addr value. See the netaddr.core namespace documentation
            for supported constant values. Only applied to integer values.

        """
        if isinstance(addr, IPRange):
            cidrs = iprange_to_cidrs(addr[0], addr[-1])
            for cidr in cidrs:
                self.remove(cidr)
            return

        if isinstance(addr, _int_type):
            addr = IPAddress(addr, flags=flags)
        else:
            addr = IPNetwork(addr)

        #   This add() is required for address blocks provided that are larger
        #   than blocks found within the set but have overlaps. e.g. :-
        #
        #   >>> IPSet(['192.0.2.0/24']).remove('192.0.2.0/23')
        #   IPSet([])
        #
        self.add(addr)

        remainder = None
        matching_cidr = None

        #   Search for a matching CIDR and exclude IP from it.
        for cidr in self._cidrs:
            if addr in cidr:
                remainder = cidr_exclude(cidr, addr)
                matching_cidr = cidr
                break

        #   Replace matching CIDR with remaining CIDR elements.
        if remainder is not None:
            del self._cidrs[matching_cidr]
            for cidr in remainder:
                self._cidrs[cidr] = True
                # No call to self.compact() is needed. Removing an IPNetwork cannot
                # create mergeable networks.

    def pop(self):
        """
        Removes and returns an arbitrary subnet (``IPNetwork``) from this IP
        set.

        :return: An IP subnet. A single address is returned as a /32 (IPv4)
            or /128 (IPv6) network.

        Raises ``KeyError`` if the set is empty.
        """
        return self._cidrs.popitem()[0]

    def isdisjoint(self, other):
        """
        :param other: an IP set.

        :return: ``True`` if this IP set has no elements (IP addresses
            or subnets) in common with other. Intersection *must* be an
            empty set.
        """
        result = self.intersection(other)
        return not result

    def copy(self):
        """:return: a shallow copy of this IP set."""
        obj_copy = self.__class__()
        obj_copy._cidrs.update(self._cidrs)
        return obj_copy

    def update(self, iterable, flags=0):
        """
        Update the contents of this IP set with the union of itself and
        other IP set.

        :param iterable: an iterable containing IP addresses, subnets or
            ranges, or a single ``IPSet``, ``IPNetwork`` or ``IPRange``.

        :param flags: decides which rules are applied to the interpretation
            of the addr value. See the netaddr.core namespace documentation
            for supported constant values. Only applied to integer members.

        Raises ``TypeError`` if ``iterable`` is not iterable.

        """
        if isinstance(iterable, IPSet):
            self._cidrs = dict.fromkeys(
                (ip for ip in cidr_merge(_dict_keys(self._cidrs)
                                         + _dict_keys(iterable._cidrs))), True)
            return
        elif isinstance(iterable, (IPNetwork, IPRange)):
            self.add(iterable)
            return

        if not hasattr(iterable, '__iter__'):
            raise TypeError('an iterable was expected!')
        #   An iterable containing IP addresses or subnets.
        mergeable = []
        for addr in iterable:
            if isinstance(addr, _int_type):
                addr = IPAddress(addr, flags=flags)
            mergeable.append(addr)

        # cidr_merge() already returns the merged form of the old and new
        # members together, so it can replace _cidrs directly. No separate
        # compact() pass is needed. _cidrs is left untouched if merging raises.
        self._cidrs = dict.fromkeys(
            cidr_merge(_dict_keys(self._cidrs) + mergeable), True)

    def clear(self):
        """Remove all IP addresses and subnets from this IP set."""
        self._cidrs = {}

    def __eq__(self, other):
        """
        :param other: an IP set

        :return: ``True`` if this IP set is equivalent to the ``other`` IP set,
            ``False`` otherwise. ``NotImplemented`` if ``other`` has no
            ``_cidrs`` attribute.
        """
        try:
            return self._cidrs == other._cidrs
        except AttributeError:
            return NotImplemented

    def __ne__(self, other):
        """
        :param other: an IP set

        :return: ``False`` if this IP set is equivalent to the ``other`` IP set,
            ``True`` otherwise. ``NotImplemented`` if ``other`` has no
            ``_cidrs`` attribute.
        """
        try:
            return self._cidrs != other._cidrs
        except AttributeError:
            return NotImplemented

    def __lt__(self, other):
        """
        :param other: an IP set

        :return: ``True`` if this IP set is a proper subset of the ``other``
            IP set, ``False`` otherwise.
        """
        if not hasattr(other, '_cidrs'):
            return NotImplemented

        return self.size < other.size and self.issubset(other)

    def issubset(self, other):
        """
        :param other: an IP set.

        :return: ``True`` if every IP address and subnet in this IP set
            is found within ``other``.
        """
        for cidr in self._cidrs:
            if cidr not in other:
                return False
        return True

    __le__ = issubset

    def __gt__(self, other):
        """
        :param other: an IP set.

        :return: ``True`` if this IP set is a proper superset of the
            ``other`` IP set, ``False`` otherwise.
        """
        if not hasattr(other, '_cidrs'):
            return NotImplemented

        return self.size > other.size and self.issuperset(other)

    def issuperset(self, other):
        """
        :param other: an IP set.

        :return: ``True`` if every IP address and subnet in other IP set
            is found within this one. Returns ``NotImplemented`` (which is
            truthy) when ``other`` has no ``_cidrs`` attribute, so that the
            ``>=`` operator alias can defer to the other operand.
        """
        if not hasattr(other, '_cidrs'):
            return NotImplemented

        for cidr in other._cidrs:
            if cidr not in self:
                return False
        return True

    __ge__ = issuperset

    def union(self, other):
        """
        :param other: an IP set.

        :return: the union of this IP set and another as a new IP set
            (combines IP addresses and subnets from both sets).
        """
        ip_set = self.copy()
        ip_set.update(other)
        return ip_set

    __or__ = union

    def intersection(self, other):
        """
        :param other: an IP set.

        :return: the intersection of this IP set and another as a new IP set.
            (IP addresses and subnets common to both sets).
        """
        result_cidrs = {}

        # Walk both sorted CIDR lists in step, like a merge.
        own_nets = sorted(self._cidrs)
        other_nets = sorted(other._cidrs)
        own_idx = 0
        other_idx = 0
        own_len = len(own_nets)
        other_len = len(other_nets)
        while own_idx < own_len and other_idx < other_len:
            own_cur = own_nets[own_idx]
            other_cur = other_nets[other_idx]

            if own_cur == other_cur:
                result_cidrs[own_cur] = True
                own_idx += 1
                other_idx += 1
            elif own_cur in other_cur:
                result_cidrs[own_cur] = True
                own_idx += 1
            elif other_cur in own_cur:
                result_cidrs[other_cur] = True
                other_idx += 1
            else:
                # own_cur and other_cur have nothing in common
                if own_cur < other_cur:
                    own_idx += 1
                else:
                    other_idx += 1

        # We ran out of networks in own_nets or other_nets. Either way, there
        # can be no further result_cidrs.
        result = IPSet()
        result._cidrs = result_cidrs
        return result

    __and__ = intersection

    def symmetric_difference(self, other):
        """
        :param other: an IP set.

        :return: the symmetric difference of this IP set and another as a new
            IP set (all IP addresses and subnets that are in exactly one
            of the sets).
        """
        # In contrast to intersection() and difference(), we cannot construct
        # the result_cidrs easily. Some cidrs may have to be merged, e.g. for
        # IPSet(["10.0.0.0/32"]).symmetric_difference(IPSet(["10.0.0.1/32"])).
        result_ranges = []

        own_nets = sorted(self._cidrs)
        other_nets = sorted(other._cidrs)
        own_idx = 0
        other_idx = 0
        own_len = len(own_nets)
        other_len = len(other_nets)
        while own_idx < own_len and other_idx < other_len:
            own_cur = own_nets[own_idx]
            other_cur = other_nets[other_idx]

            if own_cur == other_cur:
                own_idx += 1
                other_idx += 1
            elif own_cur in other_cur:
                own_idx = _subtract(other_cur, own_nets, own_idx, result_ranges)
                other_idx += 1
            elif other_cur in own_cur:
                other_idx = _subtract(own_cur, other_nets, other_idx, result_ranges)
                own_idx += 1
            else:
                # own_cur and other_cur have nothing in common
                if own_cur < other_cur:
                    result_ranges.append((own_cur._module.version,
                                          own_cur.first, own_cur.last))
                    own_idx += 1
                else:
                    result_ranges.append((other_cur._module.version,
                                          other_cur.first, other_cur.last))
                    other_idx += 1

        # If the above loop terminated because it processed all cidrs of
        # "other", then any remaining cidrs in self must be part of the result.
        while own_idx < own_len:
            own_cur = own_nets[own_idx]
            result_ranges.append((own_cur._module.version,
                                  own_cur.first, own_cur.last))
            own_idx += 1

        # If the above loop terminated because it processed all cidrs of
        # self, then any remaining cidrs in "other" must be part of the result.
        while other_idx < other_len:
            other_cur = other_nets[other_idx]
            result_ranges.append((other_cur._module.version,
                                  other_cur.first, other_cur.last))
            other_idx += 1

        result = IPSet()
        for start, stop in _iter_merged_ranges(result_ranges):
            cidrs = iprange_to_cidrs(start, stop)
            for cidr in cidrs:
                result._cidrs[cidr] = True
        return result

    __xor__ = symmetric_difference

    def difference(self, other):
        """
        :param other: an IP set.

        :return: the difference between this IP set and another as a new IP
            set (all IP addresses and subnets that are in this IP set but
            not found in the other.)
        """
        result_ranges = []
        result_cidrs = {}

        own_nets = sorted(self._cidrs)
        other_nets = sorted(other._cidrs)
        own_idx = 0
        other_idx = 0
        own_len = len(own_nets)
        other_len = len(other_nets)
        while own_idx < own_len and other_idx < other_len:
            own_cur = own_nets[own_idx]
            other_cur = other_nets[other_idx]

            if own_cur == other_cur:
                own_idx += 1
                other_idx += 1
            elif own_cur in other_cur:
                own_idx += 1
            elif other_cur in own_cur:
                other_idx = _subtract(own_cur, other_nets, other_idx,
                                      result_ranges)
                own_idx += 1
            else:
                # own_cur and other_cur have nothing in common
                if own_cur < other_cur:
                    result_cidrs[own_cur] = True
                    own_idx += 1
                else:
                    other_idx += 1

        # If the above loop terminated because it processed all cidrs of
        # "other", then any remaining cidrs in self must be part of the result.
        while own_idx < own_len:
            result_cidrs[own_nets[own_idx]] = True
            own_idx += 1

        for start, stop in _iter_merged_ranges(result_ranges):
            for cidr in iprange_to_cidrs(start, stop):
                result_cidrs[cidr] = True

        result = IPSet()
        result._cidrs = result_cidrs
        return result

    __sub__ = difference

    def __len__(self):
        """
        :return: the cardinality of this IP set (i.e. sum of individual IP \
            addresses). Raises ``IndexError`` if size > maxint (a Python \
            limitation). Use the .size property for subnets of any size.
        """
        size = self.size
        if size > _sys_maxint:
            raise IndexError(
                "range contains more than %d (sys.maxint) IP addresses! "
                "Use the .size property instead." % _sys_maxint)
        return size

    @property
    def size(self):
        """
        The cardinality of this IP set (based on the number of individual IP
        addresses including those implicitly defined in subnets).
        """
        return sum(cidr.size for cidr in self._cidrs)

    def __repr__(self):
        """:return: Python statement to create an equivalent object"""
        return 'IPSet(%r)' % [str(c) for c in sorted(self._cidrs)]

    __str__ = __repr__

    def iscontiguous(self):
        """
        Returns True if the members of the set form a contiguous IP
        address range (with no gaps), False otherwise. An empty set or a
        set holding a single CIDR is contiguous. A set mixing IPv4 and IPv6
        members is not.

        :return: ``True`` if the ``IPSet`` object is contiguous.
        """
        cidrs = self.iter_cidrs()
        # Compare neighbouring CIDRs by version and integer bounds rather
        # than with IPAddress arithmetic. ``cidr[-1] + 1`` raises IndexError
        # when a CIDR ends at the top of its address space, for example
        # 255.255.255.255.
        for prev, cur in zip(cidrs, cidrs[1:]):
            if (prev._module.version != cur._module.version
                    or prev.last + 1 != cur.first):
                return False
        return True

    def iprange(self):
        """
        Generates an IPRange for this IPSet, if all its members
        form a single contiguous sequence.

        Raises ``ValueError`` if the set is not contiguous.

        :return: An ``IPRange`` for all IPs in the IPSet, or ``None`` if the
            set is empty.
        """
        if self.iscontiguous():
            cidrs = self.iter_cidrs()
            if not cidrs:
                return None
            return IPRange(cidrs[0][0], cidrs[-1][-1])
        else:
            raise ValueError("IPSet is not contiguous")

    def iter_ipranges(self):
        """Generate the merged IPRanges for this IPSet.

        In contrast to self.iprange(), this will work even when the IPSet is
        not contiguous. Adjacent IPRanges will be merged together, so you
        get the minimal number of IPRanges.
        """
        sorted_ranges = [(cidr._module.version, cidr.first, cidr.last) for
                         cidr in self.iter_cidrs()]

        for start, stop in _iter_merged_ranges(sorted_ranges):
            yield IPRange(start, stop)
749