1"""
2Common code used in Docker integration tests
3"""
4
5import functools
6import random
7import string
8
9from salt._compat import ipaddress
10from salt.exceptions import CommandExecutionError
11
12
def random_name(prefix="", length=8):
    """
    Return ``prefix`` followed by random lowercase ASCII letters.

    prefix
        String placed at the start of the returned name. Defaults to an
        empty string.

    length
        Number of random letters appended after ``prefix``. Defaults to 8.
    """
    # Build the suffix in one pass with join() rather than repeated ``+=``.
    suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(length))
    return prefix + suffix
18
19
20class Network:
21    def __init__(self, name, **kwargs):
22        self.kwargs = kwargs
23        self.name = name
24        try:
25            self.net = ipaddress.ip_network(self.kwargs["subnet"])
26            self._rand_indexes = random.sample(
27                range(2, self.net.num_addresses - 1), self.net.num_addresses - 3
28            )
29            self.ip_arg = "ipv{}_address".format(self.net.version)
30        except KeyError:
31            # No explicit subnet passed
32            self.net = self.ip_arg = None
33
34    def __getitem__(self, index):
35        try:
36            return self.net[self._rand_indexes[index]].compressed
37        except (TypeError, AttributeError):
38            raise ValueError(
39                "Indexing not supported for networks without a custom subnet"
40            )
41
42    def arg_map(self, arg_name):
43        return {
44            "ipv4_address": "IPv4Address",
45            "ipv6_address": "IPv6Address",
46            "links": "Links",
47            "aliases": "Aliases",
48        }[arg_name]
49
50    @property
51    def subnet(self):
52        try:
53            return self.net.compressed
54        except AttributeError:
55            return None
56
57    @property
58    def gateway(self):
59        try:
60            return self.kwargs["gateway"]
61        except KeyError:
62            try:
63                return self.net[1].compressed
64            except (AttributeError, IndexError):
65                return None
66
67
class with_network:
    """
    Generate a network for the test. Information about the network will be
    passed to the wrapped function.

    The decorated function is called as ``func(testcase, network, *args,
    **kwargs)`` where ``network`` is a :class:`Network` instance with a
    randomly generated name starting with ``salt_net_``.

    Keyword arguments:

    create
        When true, ``docker.create_network`` is run through
        ``testcase.run_function`` before the wrapped function is called.
        Defaults to ``False``.

    All remaining keyword arguments are passed to :class:`Network` and are
    also forwarded to ``docker.create_network``. If a subnet was given and
    ``enable_ipv6`` was not, ``enable_ipv6`` is set to whether the subnet is
    an IPv6 network.
    """

    def __init__(self, **kwargs):
        self.create = kwargs.pop("create", False)
        self.network = Network(random_name(prefix="salt_net_"), **kwargs)
        if self.network.net is not None:
            # Only fill in a default; an explicit enable_ipv6 is respected.
            kwargs.setdefault("enable_ipv6", self.network.net.version == 6)
        self.kwargs = kwargs

    def __call__(self, func):
        """
        Decorate ``func`` so that it runs with the network set up and cleaned
        up around it.
        """
        # Kept so that code calling wrap() directly keeps working.
        self.func = func

        # The wrapper closes over its own ``func`` rather than reading
        # ``self.func`` at call time; otherwise reusing one decorator
        # instance for several functions would make every wrapper call the
        # most recently decorated one.
        @functools.wraps(func)
        def wrapper(testcase, *args, **kwargs):
            return self._run(func, testcase, *args, **kwargs)

        return wrapper

    def wrap(self, testcase, *args, **kwargs):
        """
        Run the most recently decorated function with network setup and
        cleanup around it.
        """
        return self._run(self.func, testcase, *args, **kwargs)

    def _run(self, func, testcase, *args, **kwargs):
        """
        Optionally create the network, call ``func`` and always clean the
        network up afterwards. Returns whatever ``func`` returns.
        """
        if self.create:
            testcase.run_function(
                "docker.create_network", [self.network.name], **self.kwargs
            )
        try:
            return func(testcase, self.network, *args, **kwargs)
        finally:
            try:
                testcase.run_function(
                    "docker.disconnect_all_containers_from_network",
                    [self.network.name],
                )
            except CommandExecutionError as exc:
                # An error mentioning "404" is tolerated and the removal
                # below is skipped -- presumably the network does not exist
                # (any more); anything else is a real failure.
                if "404" not in str(exc):
                    raise
            else:
                testcase.run_function("docker.remove_network", [self.network.name])
107