1#!/usr/bin/python
2# -*- coding: utf-8 -*-
3
4# Copyright: (c) 2012, Jeroen Hoekx <jeroen@hoekx.be>
5# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
6
7from __future__ import absolute_import, division, print_function
8__metaclass__ = type
9
10
11DOCUMENTATION = r'''
12---
13module: wait_for
14short_description: Waits for a condition before continuing
15description:
     - You can wait for a set amount of time C(timeout); this is the default if nothing is specified or only C(timeout) is specified.
       This does not produce an error.
18     - Waiting for a port to become available is useful for when services are not immediately available after their init scripts return
19       which is true of certain Java application servers.
20     - It is also useful when starting guests with the M(community.libvirt.virt) module and needing to pause until they are ready.
     - This module can also be used to wait for a string matching a regex to be present in a file.
22     - In Ansible 1.6 and later, this module can also be used to wait for a file to be available or
23       absent on the filesystem.
24     - In Ansible 1.8 and later, this module can also be used to wait for active connections to be closed before continuing, useful if a node
25       is being rotated out of a load balancer pool.
26     - For Windows targets, use the M(ansible.windows.win_wait_for) module instead.
27version_added: "0.7"
28options:
29  host:
30    description:
31      - A resolvable hostname or IP address to wait for.
32    type: str
33    default: 127.0.0.1
34  timeout:
35    description:
      - Maximum number of seconds to wait for. When used with another condition, the module fails if the condition is not met within this time.
      - When used without other conditions it is equivalent to just sleeping.
38    type: int
39    default: 300
40  connect_timeout:
41    description:
42      - Maximum number of seconds to wait for a connection to happen before closing and retrying.
43    type: int
44    default: 5
45  delay:
46    description:
47      - Number of seconds to wait before starting to poll.
48    type: int
49    default: 0
50  port:
51    description:
52      - Port number to poll.
53      - C(path) and C(port) are mutually exclusive parameters.
54    type: int
55  active_connection_states:
56    description:
57      - The list of TCP connection states which are counted as active connections.
58    type: list
59    default: [ ESTABLISHED, FIN_WAIT1, FIN_WAIT2, SYN_RECV, SYN_SENT, TIME_WAIT ]
60    version_added: "2.3"
61  state:
62    description:
      - One of C(present), C(started), C(stopped), C(absent), or C(drained).
64      - When checking a port C(started) will ensure the port is open, C(stopped) will check that it is closed, C(drained) will check for active connections.
      - When checking for a file or a search string C(present) or C(started) will ensure that the file or string is present before continuing;
        C(absent) will check that the file is absent or removed.
67    type: str
68    choices: [ absent, drained, present, started, stopped ]
69    default: started
70  path:
71    description:
72      - Path to a file on the filesystem that must exist before continuing.
73      - C(path) and C(port) are mutually exclusive parameters.
74    type: path
75    version_added: "1.4"
76  search_regex:
77    description:
78      - Can be used to match a string in either a file or a socket connection.
      - The regex is compiled in multiline mode, so C(^) and C($) also match at line boundaries.
80    type: str
81    version_added: "1.4"
82  exclude_hosts:
83    description:
84      - List of hosts or IPs to ignore when looking for active TCP connections for C(drained) state.
85    type: list
86    version_added: "1.8"
87  sleep:
88    description:
89      - Number of seconds to sleep between checks.
90      - Before Ansible 2.3 this was hardcoded to 1 second.
91    type: int
92    default: 1
93    version_added: "2.3"
94  msg:
95    description:
96      - This overrides the normal error message from a failure to meet the required conditions.
97    type: str
98    version_added: "2.4"
99notes:
100  - The ability to use search_regex with a port connection was added in Ansible 1.7.
101  - Prior to Ansible 2.4, testing for the absence of a directory or UNIX socket did not work correctly.
102  - Prior to Ansible 2.4, testing for the presence of a file did not work correctly if the remote user did not have read access to that file.
103  - Under some circumstances when using mandatory access control, a path may always be treated as being absent even if it exists, but
104    can't be modified or created by the remote user either.
105  - When waiting for a path, symbolic links will be followed.  Many other modules that manipulate files do not follow symbolic links,
106    so operations on the path using other modules may not work exactly as expected.
107seealso:
108- module: ansible.builtin.wait_for_connection
109- module: ansible.windows.win_wait_for
110- module: community.windows.win_wait_for_process
111author:
112    - Jeroen Hoekx (@jhoekx)
113    - John Jarvis (@jarv)
114    - Andrii Radyk (@AnderEnder)
115'''
116
117EXAMPLES = r'''
118- name: Sleep for 300 seconds and continue with play
119  wait_for:
120    timeout: 300
121  delegate_to: localhost
122
123- name: Wait for port 8000 to become open on the host, don't start checking for 10 seconds
124  wait_for:
125    port: 8000
126    delay: 10
127
128- name: Waits for port 8000 of any IP to close active connections, don't start checking for 10 seconds
129  wait_for:
130    host: 0.0.0.0
131    port: 8000
132    delay: 10
133    state: drained
134
135- name: Wait for port 8000 of any IP to close active connections, ignoring connections for specified hosts
136  wait_for:
137    host: 0.0.0.0
138    port: 8000
139    state: drained
140    exclude_hosts: 10.2.1.2,10.2.1.3
141
142- name: Wait until the file /tmp/foo is present before continuing
143  wait_for:
144    path: /tmp/foo
145
146- name: Wait until the string "completed" is in the file /tmp/foo before continuing
147  wait_for:
148    path: /tmp/foo
149    search_regex: completed
150
151- name: Wait until regex pattern matches in the file /tmp/foo and print the matched group
152  wait_for:
153    path: /tmp/foo
154    search_regex: completed (?P<task>\w+)
155  register: waitfor
156- debug:
157    msg: Completed {{ waitfor['match_groupdict']['task'] }}
158
159- name: Wait until the lock file is removed
160  wait_for:
161    path: /var/lock/file.lock
162    state: absent
163
164- name: Wait until the process is finished and pid was destroyed
165  wait_for:
166    path: /proc/3466/status
167    state: absent
168
169- name: Output customized message when failed
170  wait_for:
171    path: /tmp/foo
172    state: present
173    msg: Timeout to find file /tmp/foo
174
175# Do not assume the inventory_hostname is resolvable and delay 10 seconds at start
176- name: Wait 300 seconds for port 22 to become open and contain "OpenSSH"
177  wait_for:
178    port: 22
179    host: '{{ (ansible_ssh_host|default(ansible_host))|default(inventory_hostname) }}'
180    search_regex: OpenSSH
181    delay: 10
182  connection: local
183
184# Same as above but you normally have ansible_connection set in inventory, which overrides 'connection'
185- name: Wait 300 seconds for port 22 to become open and contain "OpenSSH"
186  wait_for:
187    port: 22
188    host: '{{ (ansible_ssh_host|default(ansible_host))|default(inventory_hostname) }}'
189    search_regex: OpenSSH
190    delay: 10
191  vars:
192    ansible_connection: local
193'''
194
195RETURN = r'''
196elapsed:
197  description: The number of seconds that elapsed while waiting
198  returned: always
199  type: int
200  sample: 23
201match_groups:
202  description: Tuple containing all the subgroups of the match as returned by U(https://docs.python.org/2/library/re.html#re.MatchObject.groups)
203  returned: always
204  type: list
205  sample: ['match 1', 'match 2']
206match_groupdict:
207  description: Dictionary containing all the named subgroups of the match, keyed by the subgroup name,
208    as returned by U(https://docs.python.org/2/library/re.html#re.MatchObject.groupdict)
209  returned: always
210  type: dict
211  sample:
212    {
213      'group': 'match'
214    }
215'''
216
217import binascii
218import datetime
219import errno
220import math
221import os
222import re
223import select
224import socket
225import time
226import traceback
227
228from ansible.module_utils.basic import AnsibleModule, missing_required_lib
229from ansible.module_utils.common.sys_info import get_platform_subclass
230from ansible.module_utils._text import to_native
231
232
# psutil is optional: it is only needed by the generic (non-procfs)
# TCPConnectionInfo strategy, so an import failure is recorded here and
# reported later via missing_required_lib() instead of breaking module load.
PSUTIL_IMP_ERR = None
try:
    import psutil
except ImportError:
    HAS_PSUTIL = False
    PSUTIL_IMP_ERR = traceback.format_exc()
else:
    HAS_PSUTIL = True
241
242
class TCPConnectionInfo(object):
    """
    This is a generic TCP Connection Info strategy class that relies
    on the psutil module, which is not ideal for targets, but necessary
    for cross platform support.

    A subclass may wish to override some or all of these methods.
      - _get_exclude_ips()
      - get_active_connections_count()

    All subclasses MUST define platform and distribution (which may be None).

    Instantiating this class returns an instance of the subclass chosen by
    get_platform_subclass() for the current platform.
    """
    platform = 'Generic'
    distribution = None

    # Wildcard ("any address") per address family: if the watched host resolves
    # to one of these, every local address on the port is considered a match.
    match_all_ips = {
        socket.AF_INET: '0.0.0.0',
        socket.AF_INET6: '::',
    }
    # IPv4 peers on dual-stack sockets are reported as IPv4-mapped IPv6 addresses.
    ipv4_mapped_ipv6_address = {
        'prefix': '::ffff',
        'match_all': '::ffff:0.0.0.0'
    }

    def __new__(cls, *args, **kwargs):
        new_cls = get_platform_subclass(TCPConnectionInfo)
        # Anchor super() on the base class so the lookup always continues past
        # this __new__ to object.__new__. Using super(cls, ...) re-enters this
        # method (and recurses forever) when a subclass is instantiated directly.
        return super(TCPConnectionInfo, new_cls).__new__(new_cls)

    def __init__(self, module):
        """
        Args:
            module: the AnsibleModule; reads params 'host', 'port' and 'exclude_hosts'.
        """
        self.module = module
        # Fail fast: nothing below is usable without psutil.
        if not HAS_PSUTIL:
            module.fail_json(msg=missing_required_lib('psutil'), exception=PSUTIL_IMP_ERR)
        self.ips = _convert_host_to_ip(module.params['host'])
        self.port = int(self.module.params['port'])
        self.exclude_ips = self._get_exclude_ips()

    def _get_exclude_ips(self):
        """Return (family, ip) tuples for every resolved address of params['exclude_hosts']."""
        exclude_hosts = self.module.params['exclude_hosts']
        exclude_ips = []
        if exclude_hosts is not None:
            for host in exclude_hosts:
                exclude_ips.extend(_convert_host_to_ip(host))
        return exclude_ips

    @staticmethod
    def _process_connections(process):
        """Return the inet connections of a psutil process, whichever API name it provides."""
        # psutil >= 6.0 renamed connections() to net_connections();
        # psutil < 2.0 only had get_connections(). Prefer the newest name.
        for method_name in ('net_connections', 'connections', 'get_connections'):
            method = getattr(process, method_name, None)
            if method is not None:
                return method(kind='inet')
        return []

    @staticmethod
    def _get_address(conn, legacy_name, name):
        """
        Return (ip, port) for one endpoint of a psutil connection.

        Old psutil releases used 'local_address'/'remote_address' instead of
        'laddr'/'raddr'. An endpoint that is not set (psutil reports an empty
        tuple) yields (None, None) instead of failing to unpack.
        """
        if hasattr(conn, legacy_name):
            address = getattr(conn, legacy_name)
        else:
            address = getattr(conn, name)
        if not address:
            return None, None
        return address[0], address[1]

    def get_active_connections_count(self):
        """
        Count connections on self.port whose state is listed in
        params['active_connection_states'], whose local address is a watched
        address, and whose remote address is not excluded.
        """
        active_states = self.module.params['active_connection_states']
        active_connections = 0
        for process in psutil.process_iter():
            try:
                connections = self._process_connections(process)
            except psutil.Error:
                # Process is Zombie, vanished, or otherwise inaccessible
                continue
            for conn in connections:
                if conn.status not in active_states:
                    continue
                local_ip, local_port = self._get_address(conn, 'local_address', 'laddr')
                if self.port != local_port:
                    continue
                remote_ip, _remote_port = self._get_address(conn, 'remote_address', 'raddr')
                if (conn.family, remote_ip) in self.exclude_ips:
                    continue
                if any((
                    (conn.family, local_ip) in self.ips,
                    (conn.family, self.match_all_ips[conn.family]) in self.ips,
                    local_ip.startswith(self.ipv4_mapped_ipv6_address['prefix']) and
                        (conn.family, self.ipv4_mapped_ipv6_address['match_all']) in self.ips,
                )):
                    active_connections += 1
        return active_connections
321
322
323# ===========================================
324# Subclass: Linux
325
class LinuxTCPConnectionInfo(TCPConnectionInfo):
    """
    This is a TCP Connection Info evaluation strategy class
    that utilizes information from Linux's procfs. While less universal,
    does allow Linux targets to not require an additional library.
    """
    platform = 'Linux'
    distribution = None

    source_file = {
        socket.AF_INET: '/proc/net/tcp',
        socket.AF_INET6: '/proc/net/tcp6'
    }
    # Hex-encoded wildcard addresses, in the format produced by _convert_host_to_hex().
    match_all_ips = {
        socket.AF_INET: '00000000',
        socket.AF_INET6: '00000000000000000000000000000000',
    }
    ipv4_mapped_ipv6_address = {
        'prefix': '0000000000000000FFFF0000',
        'match_all': '0000000000000000FFFF000000000000'
    }
    # Indexes of the whitespace-separated columns read from each /proc/net/tcp* row.
    local_address_field = 1
    remote_address_field = 2
    connection_state_field = 3

    def __init__(self, module):
        """
        Args:
            module: the AnsibleModule; reads params 'host', 'port' and 'exclude_hosts'.
        """
        self.module = module
        self.ips = _convert_host_to_hex(module.params['host'])
        # Ports appear as 4-digit upper-case hex after the ':' in the address columns.
        self.port = "%0.4X" % int(module.params['port'])
        self.exclude_ips = self._get_exclude_ips()

    def _get_exclude_ips(self):
        """Return (family, hex_ip) tuples for every resolved address of params['exclude_hosts']."""
        exclude_hosts = self.module.params['exclude_hosts']
        exclude_ips = []
        if exclude_hosts is not None:
            for host in exclude_hosts:
                exclude_ips.extend(_convert_host_to_hex(host))
        return exclude_ips

    def get_active_connections_count(self):
        """
        Count rows of /proc/net/tcp and /proc/net/tcp6 that are on self.port,
        are in one of params['active_connection_states'], have a watched local
        address and a non-excluded remote address.

        A missing or unreadable table contributes nothing to the count.
        """
        # Translate the state names once instead of once per /proc row.
        active_state_ids = [get_connection_state_id(_connection_state)
                            for _connection_state in self.module.params['active_connection_states']]
        active_connections = 0
        for family, source_file in self.source_file.items():
            if not os.path.isfile(source_file):
                continue
            try:
                with open(source_file) as f:
                    lines = f.readlines()
            except IOError:
                continue
            for line in lines:
                tcp_connection = line.split()
                if len(tcp_connection) <= self.connection_state_field:
                    # Blank or truncated row
                    continue
                if tcp_connection[self.local_address_field] == 'local_address':
                    # Header row
                    continue
                if tcp_connection[self.connection_state_field] not in active_state_ids:
                    continue
                (local_ip, local_port) = tcp_connection[self.local_address_field].split(':')
                if self.port != local_port:
                    continue
                (remote_ip, remote_port) = tcp_connection[self.remote_address_field].split(':')
                if (family, remote_ip) in self.exclude_ips:
                    continue
                if any((
                    (family, local_ip) in self.ips,
                    (family, self.match_all_ips[family]) in self.ips,
                    local_ip.startswith(self.ipv4_mapped_ipv6_address['prefix']) and
                        (family, self.ipv4_mapped_ipv6_address['match_all']) in self.ips,
                )):
                    active_connections += 1

        return active_connections
398
399
400def _convert_host_to_ip(host):
401    """
402    Perform forward DNS resolution on host, IP will give the same IP
403
404    Args:
405        host: String with either hostname, IPv4, or IPv6 address
406
407    Returns:
408        List of tuples containing address family and IP
409    """
410    addrinfo = socket.getaddrinfo(host, 80, 0, 0, socket.SOL_TCP)
411    ips = []
412    for family, socktype, proto, canonname, sockaddr in addrinfo:
413        ip = sockaddr[0]
414        ips.append((family, ip))
415        if family == socket.AF_INET:
416            ips.append((socket.AF_INET6, "::ffff:" + ip))
417    return ips
418
419
420def _convert_host_to_hex(host):
421    """
422    Convert the provided host to the format in /proc/net/tcp*
423
424    /proc/net/tcp uses little-endian four byte hex for ipv4
425    /proc/net/tcp6 uses little-endian per 4B word for ipv6
426
427    Args:
428        host: String with either hostname, IPv4, or IPv6 address
429
430    Returns:
431        List of tuples containing address family and the
432        little-endian converted host
433    """
434    ips = []
435    if host is not None:
436        for family, ip in _convert_host_to_ip(host):
437            hexip_nf = binascii.b2a_hex(socket.inet_pton(family, ip))
438            hexip_hf = ""
439            for i in range(0, len(hexip_nf), 8):
440                ipgroup_nf = hexip_nf[i:i + 8]
441                ipgroup_hf = socket.ntohl(int(ipgroup_nf, base=16))
442                hexip_hf = "%s%08X" % (hexip_hf, ipgroup_hf)
443            ips.append((family, hexip_hf))
444    return ips
445
446
447def _timedelta_total_seconds(timedelta):
448    return (
449        timedelta.microseconds + 0.0 +
450        (timedelta.seconds + timedelta.days * 24 * 3600) * 10 ** 6) / 10 ** 6
451
452
# TCP state name -> two-digit upper-case hex state code, as written in the
# "st" column of /proc/net/tcp and /proc/net/tcp6. The names match the status
# strings psutil reports, so the same list serves both strategies. Built once
# at import time instead of on every lookup.
_CONNECTION_STATE_IDS = {
    'ESTABLISHED': '01',
    'SYN_SENT': '02',
    'SYN_RECV': '03',
    'FIN_WAIT1': '04',
    'FIN_WAIT2': '05',
    'TIME_WAIT': '06',
    'CLOSE_WAIT': '08',
    'LAST_ACK': '09',
    'CLOSING': '0B',
}


def get_connection_state_id(state):
    """
    Map a TCP connection state name to its /proc/net/tcp* state code.

    Args:
        state: state name such as 'ESTABLISHED' or 'TIME_WAIT'

    Returns:
        Two-character hex code string, e.g. '01'

    Raises:
        KeyError: if state is not a supported state name
    """
    return _CONNECTION_STATE_IDS[state]
463
464
def _shutdown_and_close(sock):
    """
    Shut down both directions of sock and always close it.

    ENOTCONN from shutdown() (the peer already dropped the connection) is
    ignored. Any other socket error is re-raised, but only after the socket
    has been closed, so the descriptor is never leaked.
    """
    try:
        sock.shutdown(socket.SHUT_RDWR)
    except socket.error as e:
        if e.errno != errno.ENOTCONN:
            raise
    finally:
        sock.close()


def main():
    """
    Entry point of the wait_for module.

    Validates the option combination and optionally sleeps for 'delay'
    seconds. It then polls, every 'sleep' seconds, until the requested
    condition on 'path' or 'port' holds or 'timeout' expires. With neither
    port nor path (and state other than drained) it simply sleeps for
    'timeout' seconds. Always terminates via module.exit_json() or
    module.fail_json().
    """

    module = AnsibleModule(
        argument_spec=dict(
            host=dict(type='str', default='127.0.0.1'),
            timeout=dict(type='int', default=300),
            connect_timeout=dict(type='int', default=5),
            delay=dict(type='int', default=0),
            port=dict(type='int'),
            active_connection_states=dict(type='list', default=['ESTABLISHED', 'FIN_WAIT1', 'FIN_WAIT2', 'SYN_RECV', 'SYN_SENT', 'TIME_WAIT']),
            path=dict(type='path'),
            search_regex=dict(type='str'),
            state=dict(type='str', default='started', choices=['absent', 'drained', 'present', 'started', 'stopped']),
            exclude_hosts=dict(type='list'),
            sleep=dict(type='int', default=1),
            msg=dict(type='str'),
        ),
    )

    host = module.params['host']
    timeout = module.params['timeout']
    connect_timeout = module.params['connect_timeout']
    delay = module.params['delay']
    port = module.params['port']
    state = module.params['state']
    path = module.params['path']
    search_regex = module.params['search_regex']
    msg = module.params['msg']
    sleep = module.params['sleep']

    if search_regex is not None:
        compiled_search_re = re.compile(search_regex, re.MULTILINE)
    else:
        compiled_search_re = None

    match_groupdict = {}
    match_groups = ()

    if port and path:
        module.fail_json(msg="port and path parameter can not both be passed to wait_for", elapsed=0)
    if path and state == 'stopped':
        module.fail_json(msg="state=stopped should only be used for checking a port in the wait_for module", elapsed=0)
    if path and state == 'drained':
        module.fail_json(msg="state=drained should only be used for checking a port in the wait_for module", elapsed=0)
    if state == 'drained' and port is None:
        # The connection-info strategies need a port (they call int(port)).
        module.fail_json(msg="state=drained requires the port parameter in the wait_for module", elapsed=0)
    if module.params['exclude_hosts'] is not None and state != 'drained':
        module.fail_json(msg="exclude_hosts should only be with state=drained", elapsed=0)
    for _connection_state in module.params['active_connection_states']:
        try:
            get_connection_state_id(_connection_state)
        except (KeyError, TypeError):
            # KeyError: unknown name; TypeError: unhashable list item
            module.fail_json(msg="unknown active_connection_state (%s) defined" % _connection_state, elapsed=0)

    start = datetime.datetime.utcnow()

    def elapsed_seconds():
        # Whole seconds since start; unlike timedelta.seconds this does not wrap after one day.
        return int(_timedelta_total_seconds(datetime.datetime.utcnow() - start))

    if delay:
        time.sleep(delay)

    if not port and not path and state != 'drained':
        time.sleep(timeout)
    elif state in ['absent', 'stopped']:
        # first wait for the stop condition
        end = start + datetime.timedelta(seconds=timeout)

        while datetime.datetime.utcnow() < end:
            if path:
                try:
                    if not os.access(path, os.F_OK):
                        break
                except IOError:
                    break
            elif port:
                try:
                    s = socket.create_connection((host, port), connect_timeout)
                    try:
                        s.shutdown(socket.SHUT_RDWR)
                    finally:
                        s.close()
                except Exception:
                    # Could not connect (or the connection dropped at once): port is stopped.
                    break
            # Conditions not yet met, wait and try again
            time.sleep(sleep)
        else:
            if port:
                module.fail_json(msg=msg or "Timeout when waiting for %s:%s to stop." % (host, port), elapsed=elapsed_seconds())
            elif path:
                module.fail_json(msg=msg or "Timeout when waiting for %s to be absent." % (path), elapsed=elapsed_seconds())

    elif state in ['started', 'present']:
        # wait for start condition
        end = start + datetime.timedelta(seconds=timeout)
        while datetime.datetime.utcnow() < end:
            if path:
                try:
                    os.stat(path)
                except OSError as e:
                    # If anything except file not present, throw an error
                    if e.errno != errno.ENOENT:
                        module.fail_json(msg=msg or "Failed to stat %s, %s" % (path, e.strerror), elapsed=elapsed_seconds())
                    # file doesn't exist yet, so continue
                else:
                    # File exists.  Are there additional things to check?
                    if not compiled_search_re:
                        # nope, succeed!
                        break
                    try:
                        with open(path) as f:
                            search = compiled_search_re.search(f.read())
                    except IOError:
                        search = None
                    if search:
                        if search.groupdict():
                            match_groupdict = search.groupdict()
                        if search.groups():
                            match_groups = search.groups()
                        break
            elif port:
                alt_connect_timeout = math.ceil(_timedelta_total_seconds(end - datetime.datetime.utcnow()))
                try:
                    s = socket.create_connection((host, port), min(connect_timeout, alt_connect_timeout))
                except Exception:
                    # Failed to connect by connect_timeout. wait and try again
                    pass
                else:
                    # Connected -- are there additional conditions?
                    if compiled_search_re:
                        data = ''
                        search = None
                        try:
                            while datetime.datetime.utcnow() < end:
                                max_timeout = max(0, math.ceil(_timedelta_total_seconds(end - datetime.datetime.utcnow())))
                                (readable, _writable, _errored) = select.select([s], [], [], max_timeout)
                                if not readable:
                                    # No new data; the loop condition re-checks the deadline
                                    continue
                                try:
                                    response = s.recv(1024)
                                except socket.error:
                                    # Connection reset by the server: not ready yet
                                    break
                                if not response:
                                    # Server shutdown
                                    break
                                data += to_native(response, errors='surrogate_or_strict')
                                search = compiled_search_re.search(data)
                                if search:
                                    break
                        finally:
                            # Shutdown the client socket and release it in every case
                            _shutdown_and_close(s)
                        if search:
                            # Found our string, success! Expose groups as the path check does.
                            if search.groupdict():
                                match_groupdict = search.groupdict()
                            if search.groups():
                                match_groups = search.groups()
                            break
                    else:
                        # Connection established, success! A peer that already
                        # closed its end (ENOTCONN on shutdown) still counts.
                        _shutdown_and_close(s)
                        break

            # Conditions not yet met, wait and try again
            time.sleep(sleep)

        else:   # while-else
            # Timeout expired
            if port:
                if search_regex:
                    module.fail_json(msg=msg or "Timeout when waiting for search string %s in %s:%s" % (search_regex, host, port), elapsed=elapsed_seconds())
                else:
                    module.fail_json(msg=msg or "Timeout when waiting for %s:%s" % (host, port), elapsed=elapsed_seconds())
            elif path:
                if search_regex:
                    module.fail_json(msg=msg or "Timeout when waiting for search string %s in %s" % (search_regex, path), elapsed=elapsed_seconds())
                else:
                    module.fail_json(msg=msg or "Timeout when waiting for file %s" % (path), elapsed=elapsed_seconds())

    elif state == 'drained':
        # wait until all active connections are gone
        end = start + datetime.timedelta(seconds=timeout)
        tcpconns = TCPConnectionInfo(module)
        while datetime.datetime.utcnow() < end:
            if tcpconns.get_active_connections_count() == 0:
                break

            # Conditions not yet met, wait and try again
            time.sleep(sleep)
        else:
            module.fail_json(msg=msg or "Timeout when waiting for %s:%s to drain" % (host, port), elapsed=elapsed_seconds())

    module.exit_json(state=state, port=port, search_regex=search_regex, match_groups=match_groups, match_groupdict=match_groupdict, path=path,
                     elapsed=elapsed_seconds())
669
670
671if __name__ == '__main__':
672    main()
673