1# -*- coding: utf-8 -*-
2# This code is part of Ansible, but is an independent component.
3# This particular file snippet, and this file snippet only, is BSD licensed.
4# Modules you write using this snippet, which is embedded dynamically by Ansible
5# still belong to the author of the module, and may assign their own license
6# to the complete work.
7#
8# Copyright: (c) 2018, Johannes Brunswicker <johannes.brunswicker@gmail.com>
9#
10# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause)
11
12from __future__ import (absolute_import, division, print_function)
13__metaclass__ = type
14
15import json
16
17from ansible.module_utils.common.text.converters import to_native
18from ansible.module_utils.basic import AnsibleModule
19from ansible.module_utils.urls import fetch_url
20
21
class UTMModuleConfigurationError(Exception):
    """
    Error in the configuration of a UTM module.

    Besides the message the error keeps arbitrary keyword arguments, which do_fail hands on to the
    fail_json method of a module.
    """

    def __init__(self, msg, **args):
        """
        :param msg: The error message
        :param args: Additional keyword arguments, reported by do_fail under the key "other"
        """
        # Only the message is handed to Exception. Passing self along as well would make the exception an
        # element of its own args tuple, so str() of the error would no longer be the plain message.
        super(UTMModuleConfigurationError, self).__init__(msg)
        self.msg = msg
        self.module_fail_args = args

    def do_fail(self, module):
        """
        Fail the given module with the message and the additional arguments of this error
        :param module: The module whose fail_json method is called
        """
        module.fail_json(msg=self.msg, other=self.module_fail_args)
31
32
class UTMModule(AnsibleModule):
    """
    This is a helper class to construct any UTM Module. This will automatically add the headers, utm host, port,
    token, protocol, validate_certs and state field to the module. If you want to implement your own sophos utm
    module just initialize this UTMModule class and define the Payload fields that are needed for your module.
    See the other modules like utm_aaa_group for example.
    """

    def __init__(self, argument_spec, bypass_checks=False, no_log=False,
                 mutually_exclusive=None, required_together=None, required_one_of=None, add_file_common_args=False,
                 supports_check_mode=False, required_if=None):
        """
        Initialize the module with the default UTM fields merged into the given argument spec
        :param argument_spec: The module specific fields; a field named like a default field replaces the default
        :param bypass_checks: Handed unchanged to AnsibleModule
        :param no_log: Handed unchanged to AnsibleModule
        :param mutually_exclusive: Handed unchanged to AnsibleModule
        :param required_together: Handed unchanged to AnsibleModule
        :param required_one_of: Handed unchanged to AnsibleModule
        :param add_file_common_args: Handed unchanged to AnsibleModule
        :param supports_check_mode: Handed unchanged to AnsibleModule
        :param required_if: Handed unchanged to AnsibleModule
        """
        default_specs = dict(
            headers=dict(type='dict', required=False, default={}),
            utm_host=dict(type='str', required=True),
            utm_port=dict(type='int', default=4444),
            utm_token=dict(type='str', required=True, no_log=True),
            utm_protocol=dict(type='str', required=False, default="https", choices=["https", "http"]),
            validate_certs=dict(type='bool', required=False, default=True),
            state=dict(default='present', choices=['present', 'absent'])
        )
        # Every option is forwarded by keyword: a positional call silently assigns the values to the wrong
        # parameters whenever the parameter order of AnsibleModule.__init__ is not the one assumed here.
        super(UTMModule, self).__init__(
            argument_spec=self._merge_specs(default_specs, argument_spec),
            bypass_checks=bypass_checks,
            no_log=no_log,
            mutually_exclusive=mutually_exclusive,
            required_together=required_together,
            required_one_of=required_one_of,
            add_file_common_args=add_file_common_args,
            supports_check_mode=supports_check_mode,
            required_if=required_if)

    def _merge_specs(self, default_specs, custom_specs):
        """
        Merge two argument specs without modifying either of them
        :param default_specs: The fields every UTM module gets
        :param custom_specs: The module specific fields; they win when a key exists in both specs
        :return: A new dict containing the fields of both specs
        """
        result = default_specs.copy()
        result.update(custom_specs)
        return result
61
62
class UTM:
    """
    Helper that runs the lookup, create, update, remove and info requests of a UTM module against the
    "/api/objects/<endpoint>/" url of the configured UTM host.
    """

    # Module parameters that only describe the connection or the desired state. They are stripped from an
    # object returned by the UTM before that object is reported as the module result.
    _IRRELEVANT_RESULT_KEYS = ('utm_host', 'utm_port', 'utm_token', 'utm_protocol', 'validate_certs',
                               'url_username', 'url_password', 'state')

    def __init__(self, module, endpoint, change_relevant_keys, info_only=False):
        """
        Initialize UTM Class
        :param module: The Ansible module
        :param endpoint: The corresponding endpoint to the module
        :param change_relevant_keys: The keys of the object to check for changes
        :param info_only: When implementing an info module, set this to true. Will allow access to the info method only
        :raises UTMModuleConfigurationError: If one of the change_relevant_keys is not a parameter of the module
        """
        self.info_only = info_only
        self.module = module
        self.request_url = module.params.get('utm_protocol') + "://" + module.params.get('utm_host') + ":" + to_native(
            module.params.get('utm_port')) + "/api/objects/" + endpoint + "/"

        # The change_relevant_keys will be checked for changes to determine whether the object needs to be updated
        self.change_relevant_keys = change_relevant_keys
        # The token is stored as password of the fixed user "token"; presumably fetch_url reads these two
        # parameters to authenticate the requests - confirm against the fetch_url documentation.
        self.module.params['url_username'] = 'token'
        self.module.params['url_password'] = module.params.get('utm_token')
        # Every key that is compared later on has to be a parameter of the module, otherwise the comparison in
        # _is_object_changed would be done against a value that can never be set.
        if not all(key in module.params for key in self.change_relevant_keys):
            raise UTMModuleConfigurationError(
                "The keys " + to_native(
                    self.change_relevant_keys) + " to check are not in the modules keys:\n" + to_native(
                    list(module.params.keys())))

    def execute(self):
        """
        Run the action of the module: the info lookup for an info module, otherwise add/update for the state
        "present" and remove for the state "absent". Any exception raised on the way fails the module.
        """
        try:
            if self.info_only:
                self._info()
            elif self.module.params.get('state') == 'present':
                self._add()
            elif self.module.params.get('state') == 'absent':
                self._remove()
        except Exception as e:
            self.module.fail_json(msg=to_native(e))

    @staticmethod
    def _request_failed(info):
        """
        Check the info dict of a request for a failure
        A negative status counts as failure as well; fetch_url is understood to report requests that never got
        an HTTP response that way (assumption - confirm against the fetch_url documentation).
        :param info: The info dict returned by fetch_url
        :return: True if the status is missing, negative or 400 and above
        """
        status = info.get("status", -1)
        return status < 0 or status >= 400

    @staticmethod
    def _error_message(info):
        """
        Build the failure message for a failed request
        :param info: The info dict returned by fetch_url
        :return: The parsed json body if there is one, else the raw body, else the "msg" entry of the info dict
        """
        body = info.get("body")
        if body:
            try:
                return json.loads(body)
            except (TypeError, ValueError):
                # The body is no json document, report it as it is
                return to_native(body)
        return info.get("msg") or to_native(info)

    def _info(self):
        """
        returns the info for an object in utm
        """
        info, result = self._lookup_entry(self.module, self.request_url)
        if self._request_failed(info):
            self.module.fail_json(msg=self._error_message(info))
        elif result is None:
            self.module.exit_json(changed=False)
        else:
            self.module.exit_json(result=result, changed=False)

    def _add(self):
        """
        adds or updates a host object on utm
        """
        combined_headers = self._combine_headers()

        info, result = self._lookup_entry(self.module, self.request_url)
        if self._request_failed(info):
            self.module.fail_json(msg=self._error_message(info))
        else:
            is_changed = False
            data_as_json_string = self.module.jsonify(self.module.params)
            if result is None:
                # No object with that name yet: create it
                result = self._write_entry(self.request_url, "POST", combined_headers, data_as_json_string)
                is_changed = True
            elif self._is_object_changed(self.change_relevant_keys, self.module, result):
                # The object exists but differs in a change relevant key: update it under its reference
                result = self._write_entry(self.request_url + result['_ref'], "PUT", combined_headers,
                                           data_as_json_string)
                is_changed = True
            self.module.exit_json(result=result, changed=is_changed)

    def _write_entry(self, url, method, headers, data):
        """
        Send a creating or updating request and fail the module if the request fails
        :param url: The url to send the request to
        :param method: The HTTP method, "POST" or "PUT"
        :param headers: The headers of the request
        :param data: The json string to send as body
        :return: The object of the response without the irrelevant fields
        """
        response, info = fetch_url(self.module, url, method=method, headers=headers, data=data)
        if self._request_failed(info):
            self.module.fail_json(msg=self._error_message(info))
        return self._clean_result(json.loads(response.read()))

    def _combine_headers(self):
        """
        This will combine a header default with headers that come from the module declaration
        :return: A combined headers dict
        """
        combined = {"Accept": "application/json", "Content-type": "application/json"}
        custom_headers = self.module.params.get('headers')
        if custom_headers is not None:
            # Headers of the module win over the defaults
            combined.update(custom_headers)
        return combined

    def _remove(self):
        """
        removes an object from utm
        """
        is_changed = False
        info, result = self._lookup_entry(self.module, self.request_url)
        if self._request_failed(info):
            # A failed lookup must not be reported as "nothing to remove"
            self.module.fail_json(msg=self._error_message(info))
        elif result is not None:
            response, info = fetch_url(self.module, self.request_url + result['_ref'], method="DELETE",
                                       headers={"Accept": "application/json", "X-Restd-Err-Ack": "all"},
                                       data=self.module.jsonify(self.module.params))
            if self._request_failed(info):
                self.module.fail_json(msg=self._error_message(info))
            else:
                is_changed = True
        self.module.exit_json(changed=is_changed)

    def _lookup_entry(self, module, request_url):
        """
        Lookup for existing entry
        :param module: The module; its parameter "name" selects the entry
        :param request_url: The url that returns the list of entries
        :return: A tuple of the info dict of the request and the first entry with the name of the module, the
                 entry is None if there is no response or no entry with that name
        """
        response, info = fetch_url(module, request_url, method="GET", headers={"Accept": "application/json"})
        result = None
        if response is not None:
            results = json.loads(response.read())
            wanted_name = module.params.get('name')
            result = next((entry for entry in results if entry['name'] == wanted_name), None)
        return info, result

    def _clean_result(self, result):
        """
        Will clean the result from irrelevant fields
        :param result: The result from the query, it is modified in place
        :return: The modified result
        """
        for key in self._IRRELEVANT_RESULT_KEYS:
            # pop instead of del: a field that is not part of the result is no reason to fail
            result.pop(key, None)
        return result

    def _is_object_changed(self, keys, module, result):
        """
        Check if my object is changed
        :param keys: The keys that will determine if an object is changed
        :param module: The module
        :param result: The result from the query
        :return: True if the module parameter and the result differ for at least one of the keys
        """
        return any(module.params.get(key) != result[key] for key in keys)
218