1from __future__ import absolute_import
2from __future__ import division
3from __future__ import unicode_literals
4
5import concurrent.futures
6import contextlib
7import math
8import os
9import subprocess
10import sys
11
12import six
13
14from pre_commit import parse_shebang
15from pre_commit.util import cmd_output_b
16from pre_commit.util import cmd_output_p
17
18
def _environ_size(_env=None):
    """Estimate the number of bytes taken up by the environment block.

    Every entry is charged 8 bytes for its pointer in `envp`, plus the
    lengths of its key and value and 2 extra bytes for the c string
    (presumably the `=` separator and the trailing NUL).

    _env: mapping to measure in place of the process environment; when it is
        `None`, `os.environb` is used if it exists, otherwise `os.environ`.
    """
    if _env is None:
        environ = getattr(os, 'environb', os.environ)
    else:
        environ = _env

    pointer_bytes = 8 * len(environ)  # number of pointers in `envp`
    # c strings in `envp`
    string_bytes = sum(len(k) + len(v) + 2 for k, v in environ.items())
    return pointer_bytes + string_bytes
25
26
def _get_platform_max_length():  # pragma: no cover (platform specific)
    """Return the maximum command length to use on the current platform.

    - `nt`: a fixed budget of `2 ** 15 - 2048`.
    - `posix`: `SC_ARG_MAX` minus 2048 bytes of headroom and the size of the
      current environment, clamped to the range `[2 ** 12, 2 ** 17]`.
    - anything else: `2 ** 12`.
    """
    if os.name == 'nt':
        return 2 ** 15 - 2048  # UNICODE_STRING max - headroom

    if os.name != 'posix':
        # posix minimum
        return 2 ** 12

    # `str(...)` keeps the name a native string despite `unicode_literals`
    available = os.sysconf(str('SC_ARG_MAX')) - 2048 - _environ_size()
    # never exceed 2 ** 17, never drop below 2 ** 12
    return max(2 ** 12, min(available, 2 ** 17))
37
38
def _command_length(*cmd):
    """Return the length of `cmd` once joined with single spaces.

    The unit depends on the platform:

    - win32, python 2: number of UTF-8 bytes
    - win32, python 3: number of UTF-16 code units
    - everything else: number of bytes in the filesystem encoding
    """
    joined = ' '.join(cmd)

    if sys.platform != 'win32':
        return len(joined.encode(sys.getfilesystemencoding()))

    # win32 uses the amount of characters, more details at:
    # https://github.com/pre-commit/pre-commit/pull/839
    if six.PY2:
        # the python2.x apis require bytes, we encode as UTF-8
        return len(joined.encode('utf-8'))

    # two bytes per UTF-16 code unit
    return len(joined.encode('utf-16le')) // 2
52
53
class ArgumentTooLongError(RuntimeError):
    """Raised by `partition` when one argument cannot fit in a command."""
56
57
def partition(cmd, varargs, target_concurrency, _max_length=None):
    """Split `varargs` into command lines that each start with `cmd`.

    cmd: the command (and fixed leading arguments) prefixed to each partition
    varargs: the arguments to distribute, kept in their original order
    target_concurrency: number of partitions to aim for
    _max_length: maximum length of a single command; a falsy value selects
        the platform default

    Returns a tuple of tuples, one per command to run.  When `varargs` is
    empty the result is a single partition holding only `cmd`.

    Raises `ArgumentTooLongError` when an argument does not fit in a command
    even on its own.
    """
    _max_length = _max_length or _get_platform_max_length()

    # Aim for an even split across `target_concurrency` partitions, but never
    # cap a partition below 4 arguments so we avoid a bunch of tiny ones.
    max_args = max(4, math.ceil(len(varargs) / target_concurrency))

    cmd = tuple(cmd)
    # every partition starts out paying for the command itself
    base_length = _command_length(*cmd) + 1

    partitions = []
    current = []
    current_length = base_length

    for arg in varargs:
        arg_length = _command_length(arg) + 1
        fits = (
            current_length + arg_length <= _max_length and
            len(current) < max_args
        )
        if not fits:
            if not current:
                raise ArgumentTooLongError(arg)

            # The current partition is full: emit it and start a fresh one
            partitions.append(cmd + tuple(current))
            current = []
            current_length = base_length

            # an empty partition can only reject the argument for its length
            if current_length + arg_length > _max_length:
                raise ArgumentTooLongError(arg)

        current.append(arg)
        current_length += arg_length

    partitions.append(cmd + tuple(current))

    return tuple(partitions)
95
96
@contextlib.contextmanager
def _thread_mapper(maxsize):
    """Context manager yielding a `map`-like callable.

    With `maxsize == 1` the builtin `map` is yielded and no threads are
    created.  Otherwise the `map` method of a `ThreadPoolExecutor` built with
    `maxsize` is yielded; the executor is used as a context manager, so it is
    cleaned up when the block exits.
    """
    if maxsize != 1:
        pool = concurrent.futures.ThreadPoolExecutor(maxsize)
        with pool as executor:
            yield executor.map
    else:
        yield map
104
105
def xargs(cmd, varargs, **kwargs):
    """A simplified implementation of xargs.

    `varargs` is split into partitions (see `partition`), `cmd` is run once
    per partition and the results are merged.

    cmd: the command (and any fixed leading arguments) to run
    varargs: the arguments to distribute across invocations of `cmd`
    color: Make a pty if on a platform that supports it
    negate: Make nonzero successful and zero a failure
    target_concurrency: Target number of partitions to run concurrently
    _max_length: override for the maximum length of a single command

    Any other keyword arguments are forwarded to the command runner.

    Returns `(retcode, stdout)`: `stdout` is the concatenated output of every
    invocation (stderr is redirected into stdout) and `retcode` is the return
    code with the largest magnitude, or `0` when every invocation succeeded.
    If the executable cannot be found, the first two items of the error's
    `to_output()` are returned instead.
    """
    color = kwargs.pop('color', False)
    negate = kwargs.pop('negate', False)
    target_concurrency = kwargs.pop('target_concurrency', 1)
    # `partition` falls back to the platform limit for a falsy value, so the
    # platform lookup is not performed needlessly when a limit is supplied.
    max_length = kwargs.pop('_max_length', None)
    cmd_fn = cmd_output_p if color else cmd_output_b
    retcode = 0
    outputs = []

    try:
        cmd = parse_shebang.normalize_cmd(cmd)
    except parse_shebang.ExecutableNotFoundError as e:
        return e.to_output()[:2]

    partitions = partition(cmd, varargs, target_concurrency, max_length)

    def run_cmd_partition(run_cmd):
        return cmd_fn(
            *run_cmd, retcode=None, stderr=subprocess.STDOUT, **kwargs
        )

    threads = min(len(partitions), target_concurrency)
    with _thread_mapper(threads) as thread_map:
        results = thread_map(run_cmd_partition, partitions)

        # consume inside the `with`: the mapper may be lazy / pool-backed
        for proc_retcode, proc_out, _ in results:
            if negate:
                # normalize to an int so the result is never a bool
                proc_retcode = int(not proc_retcode)
            # Compare magnitudes rather than using `max`: a process killed
            # by a signal reports a negative code, which `max(0, ...)` would
            # silently turn into success.
            # NOTE(review): assumes `cmd_fn` passes the raw process return
            # code through unchanged -- confirm in `pre_commit.util`.
            if abs(proc_retcode) > abs(retcode):
                retcode = proc_retcode
            outputs.append(proc_out)

    return retcode, b''.join(outputs)
144