#-----------------------------------------------------------------------------
#   Copyright (c) 2008 by David P. D. Moss. All rights reserved.
#
#   Released under the BSD license. See the LICENSE file for details.
#-----------------------------------------------------------------------------
"""Set based operations for IP addresses and subnets."""

import itertools as _itertools

from netaddr.ip import (IPNetwork, IPAddress, IPRange, cidr_merge,
    cidr_exclude, iprange_to_cidrs)

from netaddr.compat import _sys_maxint, _dict_keys, _int_type


def _subtract(supernet, subnets, subnet_idx, ranges):
    """Calculate IPSet([supernet]) - IPSet(subnets).

    Assumptions: subnets is sorted, subnet_idx points to the first
    element in subnets that is a subnet of supernet.

    Results are appended to the ranges parameter as tuples in the format
    (version, first, last). Return value is the first subnet_idx that
    does not point to a subnet of supernet (or len(subnets) if all
    subsequent items are a subnet of supernet).
    """
    version = supernet._module.version
    subnet = subnets[subnet_idx]
    if subnet.first > supernet.first:
        ranges.append((version, supernet.first, subnet.first - 1))

    subnet_idx += 1
    prev_subnet = subnet
    while subnet_idx < len(subnets):
        cur_subnet = subnets[subnet_idx]

        if cur_subnet not in supernet:
            break
        # Two adjacent (non-mergeable) IPNetworks leave no gap to record.
        if prev_subnet.last + 1 != cur_subnet.first:
            ranges.append((version, prev_subnet.last + 1, cur_subnet.first - 1))

        subnet_idx += 1
        prev_subnet = cur_subnet

    first = prev_subnet.last + 1
    last = supernet.last
    if first <= last:
        ranges.append((version, first, last))

    return subnet_idx


def _iter_merged_ranges(sorted_ranges):
    """Iterate over sorted_ranges, merging where possible

    Sorted ranges must be a sorted iterable of (version, first, last) tuples.
    Merging occurs for pairs like [(4, 10, 42), (4, 43, 100)] which is merged
    into (4, 10, 100), and leads to return value
    ( IPAddress(10, 4), IPAddress(100, 4) ), which is suitable input for the
    iprange_to_cidrs function.
    """
    if not sorted_ranges:
        return

    current_version, current_start, current_stop = sorted_ranges[0]

    for next_version, next_start, next_stop in sorted_ranges[1:]:
        if next_start == current_stop + 1 and next_version == current_version:
            # Can be merged.
            current_stop = next_stop
            continue
        # Cannot be merged.
        yield (IPAddress(current_start, current_version),
               IPAddress(current_stop, current_version))
        current_start = next_start
        current_stop = next_stop
        current_version = next_version
    yield (IPAddress(current_start, current_version),
           IPAddress(current_stop, current_version))


class IPSet(object):
    """
    Represents an unordered collection (set) of unique IP addresses and
    subnets.

    """
    __slots__ = ('_cidrs', '__weakref__')

    def __init__(self, iterable=None, flags=0):
        """
        Constructor.

        :param iterable: (optional) an iterable containing IP addresses,
            subnets or ranges.

        :param flags: decides which rules are applied to the interpretation
            of the addr value. See the netaddr.core namespace documentation
            for supported constant values.

        """
        if isinstance(iterable, IPNetwork):
            self._cidrs = {iterable.cidr: True}
        elif isinstance(iterable, IPRange):
            self._cidrs = dict.fromkeys(
                iprange_to_cidrs(iterable[0], iterable[-1]), True)
        elif isinstance(iterable, IPSet):
            self._cidrs = dict.fromkeys(iterable.iter_cidrs(), True)
        else:
            self._cidrs = {}
            if iterable is not None:
                mergeable = []
                for addr in iterable:
                    if isinstance(addr, _int_type):
                        addr = IPAddress(addr, flags=flags)
                    mergeable.append(addr)

                for cidr in cidr_merge(mergeable):
                    self._cidrs[cidr] = True

    def __getstate__(self):
        """:return: Pickled state of an ``IPSet`` object."""
        return tuple(cidr.__getstate__() for cidr in self._cidrs)

    def __setstate__(self, state):
        """
        :param state: data used to unpickle a pickled ``IPSet`` object.

        """
        self._cidrs = dict.fromkeys(
            (IPNetwork((value, prefixlen), version=version)
                for value, prefixlen, version in state),
            True)

    def _compact_single_network(self, added_network):
        """
        Same as compact(), but assume that added_network is the only change and
        that this IPSet was properly compacted before added_network was added.
        This allows to perform compaction much faster. added_network must
        already be present in self._cidrs.
        """
        added_first = added_network.first
        added_last = added_network.last
        added_version = added_network.version

        # Check for supernets and subnets of added_network.
        if added_network._prefixlen == added_network._module.width:
            # This is a single IP address, i.e. /32 for IPv4 or /128 for IPv6.
            # It does not have any subnets, so we only need to check for its
            # potential supernets.
            for potential_supernet in added_network.supernet():
                if potential_supernet in self._cidrs:
                    del self._cidrs[added_network]
                    return
        else:
            # IPNetworks from self._cidrs that are subnets of added_network.
            to_remove = []
            for cidr in self._cidrs:
                if (cidr._module.version != added_version or cidr == added_network):
                    # We found added_network or some network of a different version.
                    continue
                first = cidr.first
                last = cidr.last
                if first >= added_first and last <= added_last:
                    # cidr is a subnet of added_network. Remember to remove it.
                    to_remove.append(cidr)
                elif first <= added_first and last >= added_last:
                    # cidr is a supernet of added_network. Remove added_network.
                    del self._cidrs[added_network]
                    # This IPSet was properly compacted before. Since added_network
                    # is removed now, it must again be properly compacted -> done.
                    assert (not to_remove)
                    return
            for item in to_remove:
                del self._cidrs[item]

        # Check if added_network can be merged with another network.

        # Note that merging can only happen between networks of the same
        # prefixlen. This just leaves 2 candidates: The IPNetworks just before
        # and just after the added_network.
        # This can be reduced to 1 candidate: 10.0.0.0/24 and 10.0.1.0/24 can
        # be merged into 10.0.0.0/23. But 10.0.1.0/24 and 10.0.2.0/24
        # cannot be merged. With only 1 candidate, we might as well make a
        # dictionary lookup.
        shift_width = added_network._module.width - added_network.prefixlen
        while added_network.prefixlen != 0:
            # figure out if the least significant bit of the network part is 0 or 1.
            the_bit = (added_network._value >> shift_width) & 1
            if the_bit:
                candidate = added_network.previous()
            else:
                candidate = added_network.next()

            if candidate not in self._cidrs:
                # The only possible merge does not work -> merge done
                return
            # Remove added_network&candidate, add merged network.
            del self._cidrs[candidate]
            del self._cidrs[added_network]
            added_network.prefixlen -= 1
            # Be sure that we set the host bits to 0 when we move the prefixlen.
            # Otherwise, adding 255.255.255.255/32 will result in a merged
            # 255.255.255.255/24 network, but we want 255.255.255.0/24.
            shift_width += 1
            added_network._value = (added_network._value >> shift_width) << shift_width
            self._cidrs[added_network] = True

    def compact(self):
        """
        Compact internal list of `IPNetwork` objects using a CIDR merge.
        """
        cidrs = cidr_merge(self._cidrs)
        self._cidrs = dict.fromkeys(cidrs, True)

    def __hash__(self):
        """
        Raises ``TypeError`` if this method is called.

        .. note:: IPSet objects are not hashable and cannot be used as \
            dictionary keys or as members of other sets. \
        """
        raise TypeError('IP sets are unhashable!')

    def __contains__(self, ip):
        """
        :param ip: An IP address or subnet.

        :return: ``True`` if IP address or subnet is a member of this IP set.
        """
        # Iterating over self._cidrs is an O(n) operation: 1000 items in
        # self._cidrs would mean 1000 loops. Iterating over all possible
        # supernets loops at most 32 times for IPv4 or 128 times for IPv6,
        # no matter how many CIDRs this object contains.
        supernet = IPNetwork(ip)
        if supernet in self._cidrs:
            return True
        while supernet._prefixlen:
            supernet._prefixlen -= 1
            if supernet in self._cidrs:
                return True
        return False

    def __nonzero__(self):
        """Return True if IPSet contains at least one IP, else False"""
        return bool(self._cidrs)

    __bool__ = __nonzero__  # Python 3.x.

    def __iter__(self):
        """
        :return: an iterator over the IP addresses within this IP set.
        """
        return _itertools.chain(*sorted(self._cidrs))

    def iter_cidrs(self):
        """
        :return: an iterator over individual IP subnets within this IP set.
        """
        return sorted(self._cidrs)

    def add(self, addr, flags=0):
        """
        Adds an IP address or subnet or IPRange to this IP set. Has no effect if
        it is already present.

        Note that where possible the IP address or subnet is merged with other
        members of the set to form more concise CIDR blocks.

        :param addr: An IP address or subnet in either string or object form, or
            an IPRange object.

        :param flags: decides which rules are applied to the interpretation
            of the addr value. See the netaddr.core namespace documentation
            for supported constant values.

        """
        if isinstance(addr, IPRange):
            new_cidrs = dict.fromkeys(
                iprange_to_cidrs(addr[0], addr[-1]), True)
            self._cidrs.update(new_cidrs)
            self.compact()
            return
        if isinstance(addr, IPNetwork):
            # Networks like 10.1.2.3/8 need to be normalized to 10.0.0.0/8
            addr = addr.cidr
        elif isinstance(addr, _int_type):
            addr = IPNetwork(IPAddress(addr, flags=flags))
        else:
            addr = IPNetwork(addr)

        self._cidrs[addr] = True
        self._compact_single_network(addr)

    def remove(self, addr, flags=0):
        """
        Removes an IP address or subnet or IPRange from this IP set. Does
        nothing if it is not already a member.

        Note that this method behaves more like discard() found in regular
        Python sets because it doesn't raise KeyError exceptions if the
        IP address or subnet in question does not exist. It doesn't make sense
        to fully emulate that behaviour here as IP sets contain groups of
        individual IP addresses as individual set members using IPNetwork
        objects.

        :param addr: An IP address or subnet, or an IPRange.

        :param flags: decides which rules are applied to the interpretation
            of the addr value. See the netaddr.core namespace documentation
            for supported constant values.

        """
        if isinstance(addr, IPRange):
            cidrs = iprange_to_cidrs(addr[0], addr[-1])
            for cidr in cidrs:
                self.remove(cidr)
            return

        if isinstance(addr, _int_type):
            addr = IPAddress(addr, flags=flags)
        else:
            addr = IPNetwork(addr)

        # This add() is required for address blocks provided that are larger
        # than blocks found within the set but have overlaps. e.g. :-
        #
        # >>> IPSet(['192.0.2.0/24']).remove('192.0.2.0/23')
        # IPSet([])
        #
        self.add(addr)

        remainder = None
        matching_cidr = None

        # Search for a matching CIDR and exclude IP from it.
        for cidr in self._cidrs:
            if addr in cidr:
                remainder = cidr_exclude(cidr, addr)
                matching_cidr = cidr
                break

        # Replace matching CIDR with remaining CIDR elements.
        if remainder is not None:
            del self._cidrs[matching_cidr]
            for cidr in remainder:
                self._cidrs[cidr] = True
            # No call to self.compact() is needed. Removing an IPNetwork cannot
            # create mergable networks.

    def pop(self):
        """
        Removes and returns an arbitrary IP address or subnet from this IP
        set.

        :return: An IP address or subnet.
        """
        return self._cidrs.popitem()[0]

    def isdisjoint(self, other):
        """
        :param other: an IP set.

        :return: ``True`` if this IP set has no elements (IP addresses
            or subnets) in common with other. Intersection *must* be an
            empty set.
        """
        result = self.intersection(other)
        return not result

    def copy(self):
        """:return: a shallow copy of this IP set."""
        obj_copy = self.__class__()
        obj_copy._cidrs.update(self._cidrs)
        return obj_copy

    def update(self, iterable, flags=0):
        """
        Update the contents of this IP set with the union of itself and
        other IP set.

        :param iterable: an iterable containing IP addresses, subnets or ranges.

        :param flags: decides which rules are applied to the interpretation
            of the addr value. See the netaddr.core namespace documentation
            for supported constant values.

        """
        if isinstance(iterable, IPSet):
            self._cidrs = dict.fromkeys(
                cidr_merge(_dict_keys(self._cidrs)
                           + _dict_keys(iterable._cidrs)), True)
            return
        elif isinstance(iterable, (IPNetwork, IPRange)):
            self.add(iterable)
            return

        if not hasattr(iterable, '__iter__'):
            raise TypeError('an iterable was expected!')
        # An iterable containing IP addresses or subnets.
        mergeable = []
        for addr in iterable:
            if isinstance(addr, _int_type):
                addr = IPAddress(addr, flags=flags)
            mergeable.append(addr)

        # cidr_merge() already returns a minimal, non-overlapping list that
        # covers both the existing members and the new ones, so it can replace
        # the current contents directly without a second compact() pass.
        self._cidrs = dict.fromkeys(
            cidr_merge(_dict_keys(self._cidrs) + mergeable), True)

    def clear(self):
        """Remove all IP addresses and subnets from this IP set."""
        self._cidrs = {}

    def __eq__(self, other):
        """
        :param other: an IP set

        :return: ``True`` if this IP set is equivalent to the ``other`` IP set,
            ``False`` otherwise.
        """
        try:
            return self._cidrs == other._cidrs
        except AttributeError:
            return NotImplemented

    def __ne__(self, other):
        """
        :param other: an IP set

        :return: ``False`` if this IP set is equivalent to the ``other`` IP set,
            ``True`` otherwise.
        """
        try:
            return self._cidrs != other._cidrs
        except AttributeError:
            return NotImplemented

    def __lt__(self, other):
        """
        :param other: an IP set

        :return: ``True`` if this IP set is less than the ``other`` IP set,
            ``False`` otherwise.
        """
        if not hasattr(other, '_cidrs'):
            return NotImplemented

        return self.size < other.size and self.issubset(other)

    def issubset(self, other):
        """
        :param other: an IP set.

        :return: ``True`` if every IP address and subnet in this IP set
            is found within ``other``.
        """
        for cidr in self._cidrs:
            if cidr not in other:
                return False
        return True

    __le__ = issubset

    def __gt__(self, other):
        """
        :param other: an IP set.

        :return: ``True`` if this IP set is greater than the ``other`` IP set,
            ``False`` otherwise.
        """
        if not hasattr(other, '_cidrs'):
            return NotImplemented

        return self.size > other.size and self.issuperset(other)

    def issuperset(self, other):
        """
        :param other: an IP set.

        :return: ``True`` if every IP address and subnet in other IP set
            is found within this one.
        """
        # NOTE(review): returning NotImplemented lets ``__ge__`` (an alias of
        # this method) fall back to the reflected operation, but a direct call
        # with a non-IPSet argument gets a truthy NotImplemented back.
        if not hasattr(other, '_cidrs'):
            return NotImplemented

        for cidr in other._cidrs:
            if cidr not in self:
                return False
        return True

    __ge__ = issuperset

    def union(self, other):
        """
        :param other: an IP set.

        :return: the union of this IP set and another as a new IP set
            (combines IP addresses and subnets from both sets).
        """
        ip_set = self.copy()
        ip_set.update(other)
        return ip_set

    __or__ = union

    def intersection(self, other):
        """
        :param other: an IP set.

        :return: the intersection of this IP set and another as a new IP set.
            (IP addresses and subnets common to both sets).
        """
        result_cidrs = {}

        own_nets = sorted(self._cidrs)
        other_nets = sorted(other._cidrs)
        own_idx = 0
        other_idx = 0
        own_len = len(own_nets)
        other_len = len(other_nets)
        while own_idx < own_len and other_idx < other_len:
            own_cur = own_nets[own_idx]
            other_cur = other_nets[other_idx]

            if own_cur == other_cur:
                result_cidrs[own_cur] = True
                own_idx += 1
                other_idx += 1
            elif own_cur in other_cur:
                result_cidrs[own_cur] = True
                own_idx += 1
            elif other_cur in own_cur:
                result_cidrs[other_cur] = True
                other_idx += 1
            else:
                # own_cur and other_cur have nothing in common
                if own_cur < other_cur:
                    own_idx += 1
                else:
                    other_idx += 1

        # We ran out of networks in own_nets or other_nets. Either way, there
        # can be no further result_cidrs.
        result = IPSet()
        result._cidrs = result_cidrs
        return result

    __and__ = intersection

    def symmetric_difference(self, other):
        """
        :param other: an IP set.

        :return: the symmetric difference of this IP set and another as a new
            IP set (all IP addresses and subnets that are in exactly one
            of the sets).
        """
        # In contrast to intersection() and difference(), we cannot construct
        # the result_cidrs easily. Some cidrs may have to be merged, e.g. for
        # IPSet(["10.0.0.0/32"]).symmetric_difference(IPSet(["10.0.0.1/32"])).
        result_ranges = []

        own_nets = sorted(self._cidrs)
        other_nets = sorted(other._cidrs)
        own_idx = 0
        other_idx = 0
        own_len = len(own_nets)
        other_len = len(other_nets)
        while own_idx < own_len and other_idx < other_len:
            own_cur = own_nets[own_idx]
            other_cur = other_nets[other_idx]

            if own_cur == other_cur:
                own_idx += 1
                other_idx += 1
            elif own_cur in other_cur:
                own_idx = _subtract(other_cur, own_nets, own_idx, result_ranges)
                other_idx += 1
            elif other_cur in own_cur:
                other_idx = _subtract(own_cur, other_nets, other_idx, result_ranges)
                own_idx += 1
            else:
                # own_cur and other_cur have nothing in common
                if own_cur < other_cur:
                    result_ranges.append((own_cur._module.version,
                                          own_cur.first, own_cur.last))
                    own_idx += 1
                else:
                    result_ranges.append((other_cur._module.version,
                                          other_cur.first, other_cur.last))
                    other_idx += 1

        # If the above loop terminated because it processed all cidrs of
        # "other", then any remaining cidrs in self must be part of the result.
        while own_idx < own_len:
            own_cur = own_nets[own_idx]
            result_ranges.append((own_cur._module.version,
                                  own_cur.first, own_cur.last))
            own_idx += 1

        # If the above loop terminated because it processed all cidrs of
        # self, then any remaining cidrs in "other" must be part of the result.
        while other_idx < other_len:
            other_cur = other_nets[other_idx]
            result_ranges.append((other_cur._module.version,
                                  other_cur.first, other_cur.last))
            other_idx += 1

        result = IPSet()
        for start, stop in _iter_merged_ranges(result_ranges):
            cidrs = iprange_to_cidrs(start, stop)
            for cidr in cidrs:
                result._cidrs[cidr] = True
        return result

    __xor__ = symmetric_difference

    def difference(self, other):
        """
        :param other: an IP set.

        :return: the difference between this IP set and another as a new IP
            set (all IP addresses and subnets that are in this IP set but
            not found in the other.)
        """
        result_ranges = []
        result_cidrs = {}

        own_nets = sorted(self._cidrs)
        other_nets = sorted(other._cidrs)
        own_idx = 0
        other_idx = 0
        own_len = len(own_nets)
        other_len = len(other_nets)
        while own_idx < own_len and other_idx < other_len:
            own_cur = own_nets[own_idx]
            other_cur = other_nets[other_idx]

            if own_cur == other_cur:
                own_idx += 1
                other_idx += 1
            elif own_cur in other_cur:
                own_idx += 1
            elif other_cur in own_cur:
                other_idx = _subtract(own_cur, other_nets, other_idx,
                                      result_ranges)
                own_idx += 1
            else:
                # own_cur and other_cur have nothing in common
                if own_cur < other_cur:
                    result_cidrs[own_cur] = True
                    own_idx += 1
                else:
                    other_idx += 1

        # If the above loop terminated because it processed all cidrs of
        # "other", then any remaining cidrs in self must be part of the result.
        while own_idx < own_len:
            result_cidrs[own_nets[own_idx]] = True
            own_idx += 1

        for start, stop in _iter_merged_ranges(result_ranges):
            for cidr in iprange_to_cidrs(start, stop):
                result_cidrs[cidr] = True

        result = IPSet()
        result._cidrs = result_cidrs
        return result

    __sub__ = difference

    def __len__(self):
        """
        :return: the cardinality of this IP set (i.e. sum of individual IP \
            addresses). Raises ``IndexError`` if size > maxint (a Python \
            limitation). Use the .size property for subnets of any size.
        """
        size = self.size
        if size > _sys_maxint:
            raise IndexError(
                "range contains more than %d (sys.maxint) IP addresses! "
                "Use the .size property instead." % _sys_maxint)
        return size

    @property
    def size(self):
        """
        The cardinality of this IP set (based on the number of individual IP
        addresses including those implicitly defined in subnets).
        """
        return sum(cidr.size for cidr in self._cidrs)

    def __repr__(self):
        """:return: Python statement to create an equivalent object"""
        return 'IPSet(%r)' % [str(c) for c in sorted(self._cidrs)]

    __str__ = __repr__

    def iscontiguous(self):
        """
        Returns True if the members of the set form a contiguous IP
        address range (with no gaps), False otherwise.

        An empty set or a set holding a single CIDR is contiguous. CIDRs of
        different IP versions are never contiguous with each other.

        :return: ``True`` if the ``IPSet`` object is contiguous.
        """
        cidrs = self.iter_cidrs()
        # Compare plain integer bounds instead of doing IPAddress arithmetic:
        # ``cidr[-1] + 1`` raises IndexError when a CIDR ends at the very top
        # of its address space (e.g. 255.255.255.255).
        for prev, cur in zip(cidrs, cidrs[1:]):
            if (prev._module.version != cur._module.version
                    or prev.last + 1 != cur.first):
                return False
        return True

    def iprange(self):
        """
        Generates an IPRange for this IPSet, if all its members
        form a single contiguous sequence.

        Raises ``ValueError`` if the set is not contiguous.

        :return: An ``IPRange`` for all IPs in the IPSet.
        """
        if self.iscontiguous():
            cidrs = self.iter_cidrs()
            if not cidrs:
                return None
            return IPRange(cidrs[0][0], cidrs[-1][-1])
        else:
            raise ValueError("IPSet is not contiguous")

    def iter_ipranges(self):
        """Generate the merged IPRanges for this IPSet.

        In contrast to self.iprange(), this will work even when the IPSet is
        not contiguous. Adjacent IPRanges will be merged together, so you
        get the minimal number of IPRanges.
        """
        sorted_ranges = [(cidr._module.version, cidr.first, cidr.last) for
                         cidr in self.iter_cidrs()]

        for start, stop in _iter_merged_ranges(sorted_ranges):
            yield IPRange(start, stop)