1#!/usr/bin/env python
2# Copyright: (c) 2017, Ansible Project
3# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
4from __future__ import (absolute_import, division, print_function)
5
6__metaclass__ = type
7__requires__ = ['ansible']
8
9
10import fcntl
11import hashlib
12import os
13import signal
14import socket
15import sys
16import time
17import traceback
18import errno
19import json
20
21from contextlib import contextmanager
22
23from ansible import constants as C
24from ansible.module_utils._text import to_bytes, to_text
25from ansible.module_utils.six import PY3
26from ansible.module_utils.six.moves import cPickle, StringIO
27from ansible.module_utils.connection import Connection, ConnectionError, send_data, recv_data
28from ansible.module_utils.service import fork_process
29from ansible.parsing.ajson import AnsibleJSONEncoder, AnsibleJSONDecoder
30from ansible.playbook.play_context import PlayContext
31from ansible.plugins.loader import connection_loader
32from ansible.utils.path import unfrackpath, makedirs_safe
33from ansible.utils.display import Display
34from ansible.utils.jsonrpc import JsonRpcServer
35
36
def read_stream(byte_stream):
    """
    Read one framed, checksummed payload from a binary stream.

    The frame parsed here is: a decimal byte count on its own line, exactly
    that many payload bytes, then the hex SHA-1 digest of those payload bytes
    on its own line. The checksum is verified against the bytes exactly as
    received; afterwards every escaped ``\\r`` sequence is turned back into a
    literal carriage return.

    :arg byte_stream: binary file-like object providing ``readline()`` and
        ``read(size)``.
    :returns: the payload as bytes, with escaped carriage returns restored.
    :raises ValueError: if the stream ends before the size line, or the size
        line is not a non-negative integer.
    :raises Exception: if the stream ends before the payload is complete, or
        the payload does not match its checksum.
    """
    raw_size = byte_stream.readline()
    if not raw_size:
        # Without this check int(b'') fails with an opaque "invalid literal" message.
        raise ValueError("EOF found before data size was read")

    size_line = raw_size.strip()
    try:
        size = int(size_line)
    except ValueError:
        raise ValueError("Invalid data size {0!r} found in stream".format(size_line))
    if size < 0:
        # read() with a negative size would silently consume the rest of the stream.
        raise ValueError("Invalid data size {0!r} found in stream".format(size_line))

    data = byte_stream.read(size)
    if len(data) < size:
        raise Exception("EOF found before data was complete")

    # The digest line is ASCII hex, so it can be compared directly as bytes.
    data_hash = byte_stream.readline().strip()
    if data_hash != hashlib.sha1(data).hexdigest().encode('ascii'):
        raise Exception("Read {0} bytes, but data did not match checksum".format(size))

    # restore escaped loose \r characters
    data = data.replace(br'\r', b'\r')

    return data
52
53
@contextmanager
def file_lock(lock_path):
    """
    Uses contextmanager to create and release a file lock based on the
    given path. This allows us to create locks using `with file_lock()`
    to prevent deadlocks related to failure to unlock properly.

    The lock file is created (mode 0600) if missing and is left in place
    afterwards. The exclusive ``lockf`` lock is released and the descriptor
    closed even when the body of the ``with`` statement raises.

    :arg lock_path: path of the file used as the lock.
    """
    lock_fd = os.open(lock_path, os.O_RDWR | os.O_CREAT, 0o600)
    try:
        fcntl.lockf(lock_fd, fcntl.LOCK_EX)
        try:
            yield
        finally:
            # An exception thrown into the generator at ``yield`` would
            # otherwise skip the unlock and leave the lock held.
            fcntl.lockf(lock_fd, fcntl.LOCK_UN)
    finally:
        os.close(lock_fd)
67
68
class ConnectionProcess(object):
    '''
    The connection process wraps around a Connection object that manages
    the connection to a remote device that persists over the playbook
    '''
    def __init__(self, fd, play_context, socket_path, original_path, task_uuid=None, ansible_playbook_pid=None):
        '''
        :arg fd: writable text file object; ``start()`` writes one JSON result
            document to it and then closes it.
        :arg play_context: PlayContext describing the remote connection.
        :arg socket_path: path of the local domain socket to listen on.
        :arg original_path: directory that a relative ``private_key_file`` is
            resolved against.
        :arg task_uuid: passed through to the connection plugin.
        :arg ansible_playbook_pid: passed through to the connection plugin.
        '''
        self.play_context = play_context
        self.socket_path = socket_path
        self.original_path = original_path
        self._task_uuid = task_uuid

        self.fd = fd
        # Formatted traceback of the last unexpected error seen by run(), if any.
        self.exception = None

        self.srv = JsonRpcServer()
        self.sock = None

        # Set by start(); stays None if the connection plugin could not be loaded.
        self.connection = None
        self._ansible_playbook_pid = ansible_playbook_pid

    def start(self, variables):
        '''
        Load the connection plugin, register it with the JSON-RPC server and
        start listening on ``self.socket_path``.

        Setup failures are not raised. They are recorded under ``error`` and
        ``exception`` in the result dict. The result, including the collected
        ``messages``, is always written to ``self.fd`` as JSON before the
        descriptor is closed.

        :arg variables: task variables handed to the plugin's ``set_options``.
        '''
        messages = list()
        result = {}
        try:
            messages.append(('vvvv', 'control socket path is %s' % self.socket_path))

            # If this is a relative path (~ gets expanded later) then plug the
            # key's path on to the directory we originally came from, so we can
            # find it now that our cwd is /
            if self.play_context.private_key_file and self.play_context.private_key_file[0] not in '~/':
                self.play_context.private_key_file = os.path.join(self.original_path, self.play_context.private_key_file)
            self.connection = connection_loader.get(self.play_context.connection, self.play_context, '/dev/null',
                                                    task_uuid=self._task_uuid, ansible_playbook_pid=self._ansible_playbook_pid)
            try:
                self.connection.set_options(var_options=variables)
            except ConnectionError as exc:
                messages.append(('debug', to_text(exc)))
                raise ConnectionError('Unable to decode JSON from response set_options. See the debug log for more information.')

            self.connection._socket_path = self.socket_path
            self.srv.register(self.connection)
            # sys.stdout is expected to be the StringIO capture installed by main().
            messages.extend([('vvvv', msg) for msg in sys.stdout.getvalue().splitlines()])

            self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            self.sock.bind(self.socket_path)
            self.sock.listen(1)
            messages.append(('vvvv', 'local domain socket listeners started successfully'))
        except Exception as exc:
            # Record the failure first so it reaches the parent process even if
            # draining the plugin's messages below fails.
            result['error'] = to_text(exc)
            result['exception'] = traceback.format_exc()
            # self.connection is still None when the plugin could not be loaded;
            # calling pop_messages() on it would raise and hide the real error.
            if self.connection is not None:
                messages.extend(self.connection.pop_messages())
        finally:
            result['messages'] = messages
            self.fd.write(json.dumps(result, cls=AnsibleJSONEncoder))
            self.fd.close()

    def run(self):
        '''
        Serve JSON-RPC requests on the listening socket until the connection
        plugin reports that it is closed.

        Each accepted client may send several requests. Each request is handed
        to the JSON-RPC server and the response is sent back to the client.

        SIGALRM enforces ``persistent_connect_timeout`` while waiting in
        ``accept()``, and ``persistent_command_timeout`` while a request is
        handled. SIGTERM aborts the loop.

        Unexpected errors are stored as a formatted traceback in
        ``self.exception``. ``shutdown()`` always runs at the end.
        '''
        try:
            while not self.connection._conn_closed:
                signal.signal(signal.SIGALRM, self.connect_timeout)
                signal.signal(signal.SIGTERM, self.handler)
                signal.alarm(self.connection.get_option('persistent_connect_timeout'))

                self.exception = None
                (s, addr) = self.sock.accept()
                signal.alarm(0)
                signal.signal(signal.SIGALRM, self.command_timeout)
                while True:
                    data = recv_data(s)
                    if not data:
                        break
                    log_messages = self.connection.get_option('persistent_log_messages')

                    if log_messages:
                        display.display("jsonrpc request: %s" % data, log_only=True)

                    request = json.loads(to_text(data, errors='surrogate_or_strict'))
                    # Connect lazily: only commands that reach the device need a live session.
                    if request.get('method') == "exec_command" and not self.connection.connected:
                        self.connection._connect()

                    signal.alarm(self.connection.get_option('persistent_command_timeout'))

                    resp = self.srv.handle_request(data)
                    signal.alarm(0)

                    if log_messages:
                        display.display("jsonrpc response: %s" % resp, log_only=True)

                    send_data(s, to_bytes(resp))

                s.close()

        except Exception as e:
            # socket.accept() will raise EINTR if the socket.close() is called
            if hasattr(e, 'errno'):
                if e.errno != errno.EINTR:
                    self.exception = traceback.format_exc()
            else:
                self.exception = traceback.format_exc()

        finally:
            # allow time for any exception msg send over socket to receive at other end before shutting down
            time.sleep(0.1)

            # when done, close the connection properly and cleanup the socket file so it can be recreated
            self.shutdown()

    def connect_timeout(self, signum, frame):
        '''SIGALRM handler used while waiting for a client in ``run()``; logs and raises.'''
        msg = 'persistent connection idle timeout triggered, timeout value is %s secs.\nSee the timeout setting options in the Network Debug and ' \
              'Troubleshooting Guide.' % self.connection.get_option('persistent_connect_timeout')
        display.display(msg, log_only=True)
        raise Exception(msg)

    def command_timeout(self, signum, frame):
        '''SIGALRM handler used while a request is being handled in ``run()``; logs and raises.'''
        msg = 'command timeout triggered, timeout value is %s secs.\nSee the timeout setting options in the Network Debug and Troubleshooting Guide.'\
              % self.connection.get_option('persistent_command_timeout')
        display.display(msg, log_only=True)
        raise Exception(msg)

    def handler(self, signum, frame):
        '''SIGTERM handler installed by ``run()``; logs the signal and raises to end the loop.'''
        msg = 'signal handler called with signal %s.' % signum
        display.display(msg, log_only=True)
        raise Exception(msg)

    def shutdown(self):
        """ Shuts down the local domain socket

        If the socket file exists, this method:

        - closes the listening socket and the connection plugin, flushing
          plugin messages to the log when ``persistent_log_messages`` is set
          (errors here are ignored);
        - removes the socket file.

        The matching ``.ansible_pc_lock_*`` file is then removed if present.
        """
        lock_path = unfrackpath("%s/.ansible_pc_lock_%s" % os.path.split(self.socket_path))
        if os.path.exists(self.socket_path):
            try:
                if self.sock:
                    self.sock.close()
                if self.connection:
                    self.connection.close()
                    if self.connection.get_option("persistent_log_messages"):
                        for _level, message in self.connection.pop_messages():
                            display.display(message, log_only=True)
            except Exception:
                # Best effort: a failure while closing must not prevent the
                # socket and lock files from being removed below.
                pass
            finally:
                if os.path.exists(self.socket_path):
                    os.remove(self.socket_path)
                    # setattr() on None would raise from inside this finally block.
                    if self.connection is not None:
                        setattr(self.connection, '_socket_path', None)
                        setattr(self.connection, '_connected', False)

        if os.path.exists(lock_path):
            os.remove(lock_path)

        display.display('shutdown complete', log_only=True)
218        display.display('shutdown complete', log_only=True)
219
220
def main():
    """ Called to initiate the connect to the remote device

    Reads two framed pickles from stdin (task variables, then the serialized
    PlayContext) and derives the control socket path.

    Under a file lock, it either forks a ConnectionProcess when no socket
    exists yet, or pushes the options and play context to the existing one.

    The result dict (including collected ``messages`` and ``socket_path``)
    is written as JSON: to stderr when it contains an ``exception``,
    otherwise to stdout. The process then exits with the matching return
    code.
    """
    rc = 0
    result = {}
    messages = list()
    socket_path = None

    # Need stdin as a byte stream
    if PY3:
        stdin = sys.stdin.buffer
    else:
        stdin = sys.stdin

    # Note: update the below log capture code after Display.display() is refactored.
    saved_stdout = sys.stdout
    sys.stdout = StringIO()

    try:
        # read the play context data via stdin, which means depickling it
        # NOTE(review): unpickling is only safe because stdin is written by the
        # controlling ansible process; never feed this untrusted data.
        vars_data = read_stream(stdin)
        init_data = read_stream(stdin)

        if PY3:
            pc_data = cPickle.loads(init_data, encoding='bytes')
            variables = cPickle.loads(vars_data, encoding='bytes')
        else:
            pc_data = cPickle.loads(init_data)
            variables = cPickle.loads(vars_data)

        play_context = PlayContext()
        play_context.deserialize(pc_data)
        display.verbosity = play_context.verbosity

    except Exception as e:
        rc = 1
        result.update({
            'error': to_text(e),
            'exception': traceback.format_exc()
        })

    if rc == 0:
        ssh = connection_loader.get('ssh', class_only=True)
        ansible_playbook_pid = sys.argv[1]
        task_uuid = sys.argv[2]
        cp = ssh._create_control_path(play_context.remote_addr, play_context.port, play_context.remote_user, play_context.connection, ansible_playbook_pid)
        # create the persistent connection dir if need be and create the paths
        # which we will be using later
        tmp_path = unfrackpath(C.PERSISTENT_CONTROL_PATH_DIR)
        makedirs_safe(tmp_path)

        socket_path = unfrackpath(cp % dict(directory=tmp_path))
        lock_path = unfrackpath("%s/.ansible_pc_lock_%s" % os.path.split(socket_path))

        with file_lock(lock_path):
            if not os.path.exists(socket_path):
                messages.append(('vvvv', 'local domain socket does not exist, starting it'))
                original_path = os.getcwd()
                r, w = os.pipe()
                pid = fork_process()

                if pid == 0:
                    # Child: run the persistent connection process; start()
                    # reports its setup result to the parent through the pipe.
                    process = None
                    try:
                        os.close(r)
                        wfd = os.fdopen(w, 'w')
                        process = ConnectionProcess(wfd, play_context, socket_path, original_path, task_uuid, ansible_playbook_pid)
                        process.start(variables)
                    except Exception:
                        messages.append(('error', traceback.format_exc()))
                        rc = 1

                    if rc == 0:
                        process.run()
                    elif process is not None:
                        # If the constructor itself failed there is nothing to shut down.
                        process.shutdown()

                    sys.exit(rc)

                else:
                    # Parent: wait for the child's JSON setup report.
                    os.close(w)
                    with os.fdopen(r, 'r') as rfd:
                        data = json.loads(rfd.read(), cls=AnsibleJSONDecoder)
                    messages.extend(data.pop('messages'))
                    result.update(data)

            else:
                messages.append(('vvvv', 'found existing local domain socket, using it!'))
                conn = Connection(socket_path)
                try:
                    conn.set_options(var_options=variables)
                except ConnectionError as exc:
                    messages.append(('debug', to_text(exc)))
                    raise ConnectionError('Unable to decode JSON from response set_options. See the debug log for more information.')
                pc_data = to_text(init_data)
                try:
                    conn.update_play_context(pc_data)
                    conn.set_check_prompt(task_uuid)
                except Exception as exc:
                    # Only network_cli has update_play context and set_check_prompt, so missing this is
                    # not fatal e.g. netconf
                    if isinstance(exc, ConnectionError) and getattr(exc, 'code', None) == -32601:
                        pass
                    else:
                        result.update({
                            'error': to_text(exc),
                            'exception': traceback.format_exc()
                        })

    # socket_path is still None when reading stdin failed; os.path.exists(None)
    # raises TypeError and would hide the error collected above.
    if socket_path and os.path.exists(socket_path):
        messages.extend(Connection(socket_path).pop_messages())
    messages.append(('vvvv', sys.stdout.getvalue()))
    result.update({
        'messages': messages,
        'socket_path': socket_path
    })

    sys.stdout = saved_stdout
    if 'exception' in result:
        rc = 1
        sys.stderr.write(json.dumps(result, cls=AnsibleJSONEncoder))
    else:
        rc = 0
        sys.stdout.write(json.dumps(result, cls=AnsibleJSONEncoder))

    sys.exit(rc)
346
347
# ``display`` is a module-level global used by ConnectionProcess and main();
# bind it at import time so it also exists when this file is imported rather
# than executed as a script.
display = Display()

if __name__ == '__main__':
    main()
351