1#
2# (c) 2018 Red Hat, Inc.
3#
4# This file is part of Ansible
5#
6# Ansible is free software: you can redistribute it and/or modify
7# it under the terms of the GNU General Public License as published by
8# the Free Software Foundation, either version 3 of the License, or
9# (at your option) any later version.
10#
11# Ansible is distributed in the hope that it will be useful,
12# but WITHOUT ANY WARRANTY; without even the implied warranty of
13# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14# GNU General Public License for more details.
15#
16# You should have received a copy of the GNU General Public License
17# along with Ansible.  If not, see <http://www.gnu.org/licenses/>.
18#
19from __future__ import absolute_import, division, print_function
20
# Ansible module_utils boilerplate: on Python 2 this makes classes defined in
# this module new-style; it has no effect on Python 3.
__metaclass__ = type
22import json
23
24from copy import deepcopy
25from contextlib import contextmanager
26
27try:
28    from lxml.etree import fromstring, tostring
29except ImportError:
30    from xml.etree.ElementTree import fromstring, tostring
31
32from ansible.module_utils._text import to_text, to_bytes
33from ansible.module_utils.connection import Connection, ConnectionError
34from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.netconf import (
35    NetconfConnection,
36)
37
38
# Attribute names that sanitize_xml() keeps on every element; any attribute
# whose name is not listed here is stripped. Empty means all attributes are
# removed.
IGNORE_XML_ATTRIBUTE = ()
40
41
def get_connection(module):
    """Return the NetconfConnection cached on *module*, creating it on first use.

    The connection is stored as ``module._netconf_connection`` so later calls
    reuse it. It is only created when the device capabilities report
    ``network_api == "netconf"``; any other value is reported through
    ``module.fail_json``.
    """
    if not hasattr(module, "_netconf_connection"):
        api = get_capabilities(module).get("network_api")
        if api == "netconf":
            module._netconf_connection = NetconfConnection(module._socket_path)
        else:
            module.fail_json(msg="Invalid connection type %s" % api)
    return module._netconf_connection
54
55
def get_capabilities(module):
    """Return the device capabilities for *module*, fetching them once.

    The JSON string returned by the persistent connection's
    ``get_capabilities()`` is decoded and cached as
    ``module._netconf_capabilities``; subsequent calls return the cached value.
    """
    if not hasattr(module, "_netconf_capabilities"):
        raw = Connection(module._socket_path).get_capabilities()
        module._netconf_capabilities = json.loads(raw)
    return module._netconf_capabilities
63
64
def lock_configuration(module, target=None):
    """Lock *target* through the module's connection and return the reply.

    *target* is forwarded unchanged to the connection's ``lock`` method.
    """
    return get_connection(module).lock(target=target)
68
69
def unlock_configuration(module, target=None):
    """Unlock *target* through the module's connection and return the reply.

    *target* is forwarded unchanged to the connection's ``unlock`` method.
    """
    return get_connection(module).unlock(target=target)
73
74
@contextmanager
def locked_config(module, target=None):
    """Hold a configuration lock on *target* for the duration of a ``with`` body.

    The lock is acquired before the body runs and released when it exits,
    whether the body completes normally or raises.

    If acquiring the lock raises, the exception propagates unchanged and no
    unlock is attempted. Unlocking a lock that was never taken would itself
    fail, and that error would mask the original one.
    """
    lock_configuration(module, target=target)
    try:
        yield
    finally:
        unlock_configuration(module, target=target)
82
83
def get_config(module, source, filter=None, lock=False):
    """Read configuration from *source* over the module's connection.

    :param module: module object whose cached connection is used.
    :param source: datastore to read; forwarded as ``source`` to
        ``get_config`` and as ``target`` to ``lock``/``unlock``.
    :param filter: optional filter forwarded unchanged to ``get_config``.
    :param lock: when true, lock *source* before the read and unlock it
        afterwards, even if the read raises.
    :returns: whatever the connection's ``get_config`` returns.

    Any ConnectionError raised while locking, reading or unlocking is
    reported once through ``module.fail_json``.
    """
    conn = get_connection(module)
    try:
        if lock:
            conn.lock(target=source)
        # Only reached once the lock (if requested) has been acquired, so the
        # inner finally never unlocks a lock that was not taken.
        try:
            response = conn.get_config(source=source, filter=filter)
        finally:
            # Unlock inside the outer try so an unlock failure is reported
            # through fail_json rather than escaping as a raw exception.
            if lock:
                conn.unlock(target=source)
    except ConnectionError as e:
        module.fail_json(
            msg=to_text(e, errors="surrogate_then_replace").strip()
        )

    return response
103
104
def get(module, filter, lock=False):
    """Run a ``get`` with *filter* over the module's connection.

    :param module: module object whose cached connection is used.
    :param filter: filter forwarded unchanged to the connection's ``get``.
    :param lock: when true, lock the ``running`` target before the read and
        unlock it afterwards, even if the read raises.
    :returns: whatever the connection's ``get`` returns.

    Any ConnectionError raised while locking, reading or unlocking is
    reported once through ``module.fail_json``.
    """
    conn = get_connection(module)
    try:
        if lock:
            conn.lock(target="running")
        # Only reached once the lock (if requested) has been acquired, so the
        # inner finally never unlocks a lock that was not taken.
        try:
            response = conn.get(filter=filter)
        finally:
            # Unlock inside the outer try so an unlock failure is reported
            # through fail_json rather than escaping as a raw exception.
            if lock:
                conn.unlock(target="running")
    except ConnectionError as e:
        module.fail_json(
            msg=to_text(e, errors="surrogate_then_replace").strip()
        )

    return response
125
126
def dispatch(module, request):
    """Send *request* over the module's connection and return the reply.

    *request* is forwarded unchanged to the connection's ``dispatch`` method.
    A ConnectionError is reported through ``module.fail_json``.
    """
    connection = get_connection(module)
    try:
        reply = connection.dispatch(request)
    except ConnectionError as exc:
        message = to_text(exc, errors="surrogate_then_replace").strip()
        module.fail_json(msg=message)
    return reply
137
138
def sanitize_xml(data):
    """Return *data* re-serialised as XML text with element attributes removed.

    Every element in the parsed tree loses every attribute whose name is not
    listed in ``IGNORE_XML_ATTRIBUTE``. The result is decoded to text and
    stripped of surrounding whitespace.
    """
    root = fromstring(
        to_bytes(deepcopy(data), errors="surrogate_then_replace")
    )
    for node in root.iter():
        attrs = node.attrib
        # Collect names first: the mapping cannot be mutated while iterating.
        unwanted = [name for name in attrs if name not in IGNORE_XML_ATTRIBUTE]
        for name in unwanted:
            attrs.pop(name)
    return to_text(tostring(root), errors="surrogate_then_replace").strip()
151