1# Copyright (c) 2016 Red Hat Inc
2# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause)
3
4# General networking tools that may be used by all modules
5
6from __future__ import (absolute_import, division, print_function)
7__metaclass__ = type
8
9import re
10from struct import pack
11from socket import inet_ntoa
12
13from ansible.module_utils.six.moves import zip
14
15
# Octet values a dotted-quad netmask may contain: 255, 254, 252, 248, 240, 224, 192, 128, 0
VALID_MASKS = [2**8 - 2**i for i in range(0, 9)]


def is_netmask(val):
    """
    Validate a dotted-quad IPv4 netmask
    Args:
        val: value to validate; it is converted with str() before parsing

    Returns: (Boolean) True if val is four valid mask octets forming one contiguous
        run of leading one-bits (e.g. 255.255.254.0), otherwise False
    """
    parts = str(val).split('.')
    if len(parts) != 4:
        return False

    zero_bit_seen = False
    for part in parts:
        try:
            octet = int(part)
        except ValueError:
            return False
        if octet not in VALID_MASKS:
            return False
        # Once an octet contains a zero bit, every following octet must be 0;
        # otherwise the mask is non-contiguous (e.g. 255.0.255.0).
        if zero_bit_seen and octet != 0:
            return False
        if octet != 255:
            zero_bit_seen = True
    return True
30
31
def is_masklen(val):
    """
    Validate an IPv4 prefix length
    Args:
        val: value to validate; anything int() accepts (integer or numeric string)

    Returns: (Boolean) True if val converts to an integer between 0 and 32 inclusive,
        otherwise False
    """
    try:
        return 0 <= int(val) <= 32
    except (TypeError, ValueError):
        # int() raises ValueError for non-numeric strings and TypeError for
        # None or containers; neither is a valid mask length.
        return False
37
38
def to_netmask(val):
    """
    Convert a mask length to a dotted-quad netmask
    Args:
        val: mask length accepted by is_masklen (0 to 32)

    Returns: (String) the netmask in dotted-quad notation, e.g. 24 -> 255.255.255.0

    Raises: ValueError if val is not a valid mask length
    """
    if not is_masklen(val):
        raise ValueError('invalid value for masklen')

    masklen = int(val)
    # Keep the top `masklen` bits of a 32-bit word set and clear the rest.
    bits = (0xFFFFFFFF << (32 - masklen)) & 0xFFFFFFFF

    # Pack as a big-endian unsigned 32-bit integer so inet_ntoa can render it.
    packed = pack('>I', bits)
    return inet_ntoa(packed)
49
50
def to_masklen(val):
    """
    Convert a dotted-quad netmask to a mask length
    Args:
        val: netmask string accepted by is_netmask

    Returns: (Integer) number of one-bits in the netmask, e.g. 255.255.255.0 -> 24

    Raises: ValueError if val is not a valid netmask
    """
    if not is_netmask(val):
        raise ValueError('invalid value for netmask: %s' % val)

    # The mask length is the total count of set bits across the four octets.
    return sum(bin(int(octet)).count('1') for octet in val.split('.'))
62
63
def to_subnet(addr, mask, dotted_notation=False):
    """
    Convert an addr / mask pair to a subnet
    Args:
        addr: IPv4 address in dotted-quad notation
        mask: either a mask length (e.g. 24) or a dotted-quad netmask
        dotted_notation: when True return "<network> <netmask>" instead of CIDR notation

    Returns: (String) "<network>/<masklen>" by default, or "<network> <netmask>"
        when dotted_notation is True

    Raises: ValueError if mask is neither a valid mask length nor a valid netmask
    """
    if is_masklen(mask):
        # mask was given as a prefix length; derive the dotted form from it
        cidr = int(mask)
        mask = to_netmask(mask)
    else:
        # mask was given in dotted form; to_masklen raises ValueError if it is invalid
        cidr = to_masklen(mask)

    # AND each address octet with the matching mask octet to get the network address
    octet_pairs = zip(addr.split('.'), mask.split('.'))
    network = '.'.join(str(int(a_octet) & int(m_octet)) for a_octet, m_octet in octet_pairs)

    if not dotted_notation:
        return '%s/%s' % (network, cidr)
    return '%s %s' % (network, to_netmask(cidr))
84
85
def to_ipv6_subnet(addr):
    """
    Return the subnet portion of an IPv6 address

    IPv6 addresses are eight groupings. The first four groupings (64 bits) comprise the subnet address.
    Args:
        addr: IPv6 address string, optionally using '::' zero compression

    Returns: (String) the first four groups followed by '::'. Zero groups hidden by '::'
        are left compressed when they run to the end of the subnet portion
        (e.g. 2001:db8::1 -> 2001:db8::) and written as 0 when explicit groups
        follow them inside the subnet portion (e.g. a::b:c:d:e:f -> a:0:0:b::)
    """

    # https://tools.ietf.org/rfc/rfc2374.txt

    subnet_groups = 4
    total_groups = 8

    # Split by :: to identify omitted zeros
    head, compressed, tail = addr.partition('::')

    # Get the first four groups, or as many as are found before ::
    found_groups = head.split(':')[:subnet_groups]

    if compressed and tail:
        head_len = len(found_groups) if head else 0
        tail_groups = tail.split(':')
        # A trailing dotted-quad IPv4 address occupies two 16-bit groups
        tail_len = len(tail_groups) + (1 if '.' in tail_groups[-1] else 0)
        zero_groups = total_groups - head_len - tail_len
        # If :: hides too few zeros to reach the subnet boundary, groups written
        # after it are part of the subnet and the zeros must be spelled out.
        if zero_groups > 0 and head_len + zero_groups < subnet_groups:
            expanded = (found_groups if head else []) + ['0'] * zero_groups + tail_groups
            found_groups = expanded[:subnet_groups]

    # Concatenate network address parts
    network_addr = ''.join(str(group) + ':' for group in found_groups)

    # Ensure network address ends with ::
    if len(found_groups) < subnet_groups or not network_addr.endswith('::'):
        network_addr += ':'
    return network_addr
114
115
def to_ipv6_network(addr):
    """
    Return the network portion of an IPv6 address

    IPv6 addresses are eight groupings. The first three groupings (48 bits) comprise the network address.
    Args:
        addr: IPv6 address string, optionally using '::' zero compression

    Returns: (String) the first three groups followed by '::'. Zero groups hidden by '::'
        are left compressed when they run to the end of the network portion
        (e.g. 2001:db8::1 -> 2001:db8::) and written as 0 when explicit groups
        follow them inside the network portion (e.g. 2001::1:2:3:4:5:6 -> 2001:0:1::)
    """

    network_groups = 3
    total_groups = 8

    # Split by :: to identify omitted zeros
    head, compressed, tail = addr.partition('::')

    # Get the first three groups, or as many as are found before ::
    found_groups = head.split(':')[:network_groups]

    if compressed and tail:
        head_len = len(found_groups) if head else 0
        tail_groups = tail.split(':')
        # A trailing dotted-quad IPv4 address occupies two 16-bit groups
        tail_len = len(tail_groups) + (1 if '.' in tail_groups[-1] else 0)
        zero_groups = total_groups - head_len - tail_len
        # If :: hides too few zeros to reach the network boundary, groups written
        # after it are part of the network and the zeros must be spelled out.
        if zero_groups > 0 and head_len + zero_groups < network_groups:
            expanded = (found_groups if head else []) + ['0'] * zero_groups + tail_groups
            found_groups = expanded[:network_groups]

    # Concatenate network address parts
    network_addr = ''.join(str(group) + ':' for group in found_groups)

    # Ensure network address ends with ::
    if len(found_groups) < network_groups or not network_addr.endswith('::'):
        network_addr += ':'
    return network_addr
142
143
def to_bits(val):
    """
    Convert a dotted-quad netmask to its bit string
    Args:
        val: string of dot-separated decimal octets, e.g. 255.255.255.0

    Returns: (String) each octet rendered as eight binary digits, concatenated
        in order, e.g. 255.255.255.0 -> 11111111111111111111111100000000
    """
    # bin() yields '0b...'; drop the prefix and left-pad each octet to 8 digits
    return ''.join(bin(int(octet))[2:].zfill(8) for octet in val.split('.'))
150
151
def is_mac(mac_address):
    """
    Validate MAC address for given string
    Args:
        mac_address: string to validate as MAC address

    Returns: (Boolean) True if string is valid MAC address, otherwise False

    A valid address is six pairs of hex digits (case-insensitive) separated by
    either '-' or ':', with the same separator used throughout.
    """
    # \1 forces every separator to equal the first one. \Z anchors at the true
    # end of the string; '$' would also accept a trailing newline.
    mac_addr_regex = re.compile(r'[0-9a-f]{2}([-:])[0-9a-f]{2}(\1[0-9a-f]{2}){4}\Z')
    return bool(mac_addr_regex.match(mac_address.lower()))
162