1# -*- coding: utf-8 -*-
2#
3# (c) 2017, Ansible by Red Hat, inc
4#
5# This file is part of Ansible by Red Hat
6#
7# Ansible is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# Ansible is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with Ansible.  If not, see <http://www.gnu.org/licenses/>.
19#
20import json
21
22from ansible.module_utils._text import to_text
23from ansible.module_utils.connection import Connection, ConnectionError
24from ansible.module_utils.network.common.utils import to_list, EntityCollection
25
# Module-level state shared by the helpers below.
_DEVICE_CONFIGS = {}  # not referenced in this part of the file; presumably kept for compatibility — TODO confirm
_CONNECTION = None  # cached Connection, populated lazily by get_connection()

# Entity spec passed to EntityCollection by to_commands(); every command is
# normalized to these three keys, with 'command' being the key field.
_COMMAND_SPEC = {
    'command': {'key': True},
    'prompt': {},
    'answer': {},
}
34
35
def get_connection(module):
    """Return the cached device Connection, creating it on first use.

    The Connection is built from ``module._socket_path`` on the first call and
    stored in the module-level ``_CONNECTION``; subsequent calls return that
    same object regardless of the ``module`` argument.
    """
    global _CONNECTION
    if not _CONNECTION:
        _CONNECTION = Connection(module._socket_path)
    return _CONNECTION
42
43
def to_commands(module, commands):
    """Normalize a list of commands through EntityCollection using _COMMAND_SPEC.

    :param module: the Ansible module, handed to EntityCollection.
    :param commands: must be a ``list``; any other type is rejected.
    :returns: the transformed command entities.
    :raises AssertionError: if ``commands`` is not a list (kept for callers
        that may rely on this exception type).
    """
    if isinstance(commands, list):
        return EntityCollection(module, _COMMAND_SPEC)(commands)
    raise AssertionError('argument must be of type <list>')
51
52
def run_commands(module, commands, check_rc=True):
    """Run one or more CLI commands on the device and collect their output.

    :param module: the Ansible module; used to reach the device connection.
    :param commands: a single command or a list of commands, normalized via
        to_commands() (presumably strings or dicts with the
        command/prompt/answer keys of _COMMAND_SPEC).
    :param check_rc: accepted for interface compatibility; not used here.
    :returns: list of text responses, in the same order as the commands.

    A ConnectionError raised by the connection is not caught here.
    """
    conn = get_connection(module)
    normalized = to_commands(module, to_list(commands))
    return [to_text(conn.get(**entry), errors='surrogate_then_replace')
            for entry in normalized]
65
66
def get_config(module, source='running'):
    """Fetch the ``source`` configuration from the device as stripped text."""
    raw_config = get_connection(module).get_config(source)
    return to_text(raw_config, errors='surrogate_then_replace').strip()
72
73
def load_config(module, config):
    """Push ``config`` to the device through the connection's edit_config.

    Any ConnectionError (including one raised while obtaining the connection)
    is reported through ``module.fail_json`` with the error text as ``msg``.
    """
    try:
        get_connection(module).edit_config(config)
    except ConnectionError as err:
        module.fail_json(msg=to_text(err))
80
81
def _parse_json_output(out):
    """Extract the JSON document embedded in raw CLI output.

    The document is assumed to start on the first line whose stripped text
    begins with ``[`` or ``{``. For multi-line output it ends on the last later
    line whose stripped text begins with the matching closing bracket. The
    selected lines are concatenated without separators.

    :param out: raw command output (text).
    :returns: the JSON text to decode. Special cases:
        - ``"null"`` when no line opens a JSON array/object;
        - the opening line itself when the whole document sits on that single
          line (it ends with the matching closing bracket);
        - an empty ``"[]"``/``"{}"`` when no closing bracket can be located.
    """
    out_list = out.split('\n')
    lines_count = len(out_list)

    # Locate the line that opens the JSON document.
    first_index = 0
    opening_char = None
    while first_index < lines_count:
        first_line = out_list[first_index].strip()
        if first_line and first_line[0] in ("[", "{"):
            opening_char = first_line[0]
            break
        first_index += 1
    if not opening_char:
        return "null"

    closing_char = ']' if opening_char == '[' else '}'

    # Scan backwards for a later line that starts with the closing bracket;
    # this skips trailing non-JSON lines such as the CLI prompt.
    last_index = lines_count - 1
    while last_index > first_index:
        last_line = out_list[last_index].strip()
        if last_line and last_line[0] == closing_char:
            return "".join(out_list[first_index:last_index + 1])
        last_index -= 1

    # No closing line after the opening one: the document may be entirely on
    # the opening line (e.g. '{"a": 1}'). Previously this case was silently
    # replaced by an empty object/array.
    if out_list[first_index].strip().endswith(closing_char):
        return out_list[first_index]
    return opening_char + closing_char
109
110
def show_cmd(module, cmd, json_fmt=True, fail_on_error=True):
    """Run a single show command and return its output.

    :param module: the Ansible module; used for the connection and fail_json.
    :param cmd: the CLI command text.
    :param json_fmt: when True, ``" | json-print"`` is appended to the command.
        The JSON portion of the output is then extracted with
        _parse_json_output and decoded. Otherwise the stripped text output is
        returned.
    :param fail_on_error: when False, a ConnectionError while running the
        command makes this return None instead of propagating.
    :returns: decoded JSON (json_fmt) or stripped text, or None as above.

    Undecodable JSON is reported via ``module.fail_json`` with the offending
    text in ``stderr``.
    """
    full_cmd = cmd + " | json-print" if json_fmt else cmd
    connection = get_connection(module)
    command = to_commands(module, to_list(full_cmd))[0]
    try:
        raw = connection.get(**command)
    except ConnectionError:
        if not fail_on_error:
            return None
        raise

    if not json_fmt:
        return to_text(raw, errors='surrogate_then_replace').strip()

    json_text = _parse_json_output(raw)
    try:
        cfg = json.loads(json_text)
    except ValueError:
        module.fail_json(
            msg="got invalid json",
            stderr=to_text(json_text, errors='surrogate_then_replace'))
    return cfg
133
134
def get_interfaces_config(module, interface_type, flags=None, json_fmt=True):
    """Return the output of ``show interfaces <interface_type> [<flags>]``.

    The command is executed through show_cmd, so the result is decoded JSON
    when ``json_fmt`` is True and stripped text otherwise.
    """
    suffix = (" %s" % flags) if flags else ""
    return show_cmd(module, ("show interfaces %s" % interface_type) + suffix,
                    json_fmt)
140
141
def get_bgp_summary(module):
    """Return the BGP section of the running config as text.

    A connection error yields None rather than an exception
    (``fail_on_error=False``).
    """
    return show_cmd(module, "show running-config protocol bgp",
                    json_fmt=False, fail_on_error=False)
145
146
def get_capabilities(module):
    """Return the platform capabilities of the remote device.

    The JSON returned by the connection's get_capabilities() is decoded once
    and cached on ``module._capabilities``; later calls reuse the cache. A
    ConnectionError is reported through ``module.fail_json``.
    """
    if not hasattr(module, '_capabilities'):
        conn = get_connection(module)
        try:
            raw_caps = conn.get_capabilities()
        except ConnectionError as err:
            module.fail_json(msg=to_text(err, errors='surrogate_then_replace'))
        module._capabilities = json.loads(raw_caps)
    return module._capabilities
161
162
class BaseOnyxModule(object):
    """Skeleton for Onyx modules.

    The flow is: compute the desired state, load the current state, generate
    the CLI commands for the difference, push them and report. Subclasses
    override the hook methods called by run(): init_module,
    get_required_config, load_current_config, generate_commands and
    optionally check_declarative_intent_params.
    """
    # Not referenced in this class; presumably used by subclasses for version
    # comparisons — TODO confirm.
    ONYX_API_VERSION = "3.6.6000"

    def __init__(self):
        # Module object (fail_json/exit_json/params/check_mode are used below);
        # presumably assigned by a subclass's init_module().
        self._module = None
        # CLI commands accumulated by generate_commands() and pushed by run().
        self._commands = list()
        # Device state / desired state, filled by the corresponding hooks.
        self._current_config = None
        self._required_config = None
        # Not populated here; subclasses may cache _get_os_version() in it.
        self._os_version = None

    def init_module(self):
        """Hook called first by run(); no-op in the base class."""
        pass

    def load_current_config(self):
        """Hook: load the device's current state; no-op in the base class."""
        pass

    def get_required_config(self):
        """Hook: compute the desired state; no-op in the base class."""
        pass

    def _get_os_version(self):
        """Return capabilities['device_info']['network_os_version']."""
        capabilities = get_capabilities(self._module)
        device_info = capabilities['device_info']
        return device_info['network_os_version']

    # pylint: disable=unused-argument
    def check_declarative_intent_params(self, result):
        """Hook: return failed conditions (truthy fails the module in run())."""
        return None

    def _validate_key(self, param, key):
        """Call ``self.validate_<key>(param.get(key))`` if it is callable.

        Raises AttributeError when no ``validate_<key>`` attribute exists;
        validate_param_values relies on that to skip keys without validators.
        """
        validator = getattr(self, 'validate_%s' % key)
        if callable(validator):
            validator(param.get(key))

    def validate_param_values(self, obj, param=None):
        """Run the per-key validator for every key in ``obj``.

        :param obj: iterable of parameter names to validate.
        :param param: mapping of values; defaults to ``self._module.params``.
        """
        if param is None:
            param = self._module.params
        for key in obj:
            # validate the param value (if validator func exists)
            # NOTE(review): this also hides AttributeErrors raised *inside* a
            # validator; kept because subclasses may override _validate_key
            # and rely on this behaviour.
            try:
                self._validate_key(param, key)
            except AttributeError:
                pass

    @classmethod
    def get_config_attr(cls, item, arg):
        """Return ``item.get(arg)``; subclasses may override the lookup."""
        return item.get(arg)

    @classmethod
    def get_mtu(cls, item):
        """Return the MTU of ``item`` as an int, or None if it can't be parsed.

        Only the first whitespace-separated token of the "MTU" attribute is
        parsed as an integer. A missing, empty or blank attribute, or a
        non-numeric first token, yields None.
        """
        mtu = cls.get_config_attr(item, "MTU")
        if not mtu:
            # Previously None raised AttributeError and "" raised IndexError.
            return None
        mtu_parts = mtu.split()
        try:
            return int(mtu_parts[0])
        except (IndexError, ValueError):
            # IndexError: whitespace-only value; ValueError: non-numeric token.
            return None

    def _validate_range(self, attr_name, min_val, max_val, value):
        """Fail the module unless ``min_val <= int(value) <= max_val``.

        Returns True when ``value`` is None (nothing to validate); otherwise
        returns None after a successful check.
        """
        if value is None:
            return True
        if not min_val <= int(value) <= max_val:
            msg = '%s must be between %s and %s' % (
                attr_name, min_val, max_val)
            self._module.fail_json(msg=msg)

    def validate_mtu(self, value):
        """Validate that an MTU lies within 1500..9612 (inclusive)."""
        self._validate_range('mtu', 1500, 9612, value)

    def generate_commands(self):
        """Hook: fill self._commands; no-op in the base class."""
        pass

    def run(self):
        """Drive the module end to end and exit via exit_json/fail_json.

        Order: init_module, get_required_config, load_current_config,
        generate_commands. The generated commands are always reported. They
        are pushed with load_config only outside check mode. A non-empty
        command list marks the result as changed. A truthy return from
        check_declarative_intent_params fails the module.
        """
        self.init_module()

        result = {'changed': False}

        self.get_required_config()
        self.load_current_config()

        self.generate_commands()
        result['commands'] = self._commands

        if self._commands:
            if not self._module.check_mode:
                load_config(self._module, self._commands)
            result['changed'] = True

        failed_conditions = self.check_declarative_intent_params(result)

        if failed_conditions:
            msg = 'One or more conditional statements have not been satisfied'
            self._module.fail_json(msg=msg,
                                   failed_conditions=failed_conditions)

        self._module.exit_json(**result)

    @classmethod
    def main(cls):
        """Instantiate the module class and run it."""
        app = cls()
        app.run()
262