1"""
2Functions for manipulating or otherwise processing strings
3"""
4
5
6import base64
7import difflib
8import errno
9import fnmatch
10import logging
11import os
12import re
13import shlex
14import time
15import unicodedata
16
17from salt.utils.decorators.jinja import jinja_filter
18
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
20
21
22@jinja_filter("to_bytes")
def to_bytes(s, encoding=None, errors="strict"):
    """
    Given bytes, bytearray, or str, return bytes.

    s
        The value to convert. ``bytes`` is returned as-is, ``bytearray`` is
        copied into a new ``bytes`` object and ``str`` is encoded.

    encoding
        A codec name, or a tuple/list of codec names which are tried in
        order. When omitted, utf-8 is tried first, followed by the detected
        system encoding.

    errors
        Error handler handed to ``str.encode()``.

    Raises ``ValueError`` when ``encoding`` is an empty tuple/list,
    ``TypeError`` when ``s`` is of an unsupported type, and the last
    ``UnicodeEncodeError`` seen when no candidate codec can encode ``s``.
    """
    if encoding is None:
        # Try utf-8 first, and fall back to detected encoding
        candidates = ("utf-8", __salt_system_encoding__)
    elif isinstance(encoding, (tuple, list)):
        candidates = encoding
    else:
        candidates = (encoding,)

    # The encoding argument is validated before the type of ``s`` is inspected
    if not candidates:
        raise ValueError("encoding cannot be empty")

    if isinstance(s, bytes):
        return s
    if isinstance(s, bytearray):
        return bytes(s)
    if not isinstance(s, str):
        raise TypeError("expected str, bytes, or bytearray not {}".format(type(s)))

    last_error = None
    for codec in candidates:
        try:
            return s.encode(codec, errors)
        except UnicodeEncodeError as err:
            # Remember the failure and move on to the next candidate
            last_error = err
    # Every candidate raised UnicodeEncodeError; surface the most recent one
    raise last_error  # pylint: disable=raising-bad-type
54
55
def to_str(s, encoding=None, errors="strict", normalize=False):
    """
    Given str, bytes, or bytearray, return str.

    s
        The value to convert. ``str`` is passed through, ``bytes`` and
        ``bytearray`` are decoded.

    encoding
        A codec name, or a tuple/list of codec names which are tried in
        order. When omitted, utf-8 is tried first, followed by the detected
        system encoding.

    errors
        Error handler handed to ``decode()``.

    normalize
        When ``True`` the result is NFC-normalized.

    Raises ``ValueError`` when ``encoding`` is an empty tuple/list,
    ``TypeError`` when ``s`` is of an unsupported type, and the last
    ``UnicodeDecodeError`` seen when no candidate codec can decode ``s``.
    """
    if encoding is None:
        # Try utf-8 first, and fall back to detected encoding
        candidates = ("utf-8", __salt_system_encoding__)
    elif isinstance(encoding, (tuple, list)):
        candidates = encoding
    else:
        candidates = (encoding,)

    if not candidates:
        raise ValueError("encoding cannot be empty")

    def _finish(text):
        # Optional NFC normalization; a TypeError from normalize() leaves the
        # value untouched.
        if not normalize:
            return text
        try:
            return unicodedata.normalize("NFC", text)
        except TypeError:
            return text

    if isinstance(s, str):
        return _finish(s)
    if not isinstance(s, (bytes, bytearray)):
        raise TypeError("expected str, bytes, or bytearray not {}".format(type(s)))

    last_error = None
    for codec in candidates:
        try:
            return _finish(s.decode(codec, errors))
        except UnicodeDecodeError as err:
            # Remember the failure and move on to the next candidate
            last_error = err
    # Every candidate raised UnicodeDecodeError; surface the most recent one
    raise last_error  # pylint: disable=raising-bad-type
92
93
def to_unicode(s, encoding=None, errors="strict", normalize=False):
    """
    Given str, bytes, or bytearray, return str.

    ``str`` input is used directly; ``bytes`` and ``bytearray`` are decoded
    via ``to_str()`` using ``encoding`` (a codec name or a tuple/list of
    names, defaulting to utf-8 followed by the detected system encoding) and
    ``errors``. When ``normalize`` is ``True`` the result is NFC-normalized.

    Raises ``ValueError`` when ``encoding`` is an empty tuple/list and
    ``TypeError`` when ``s`` is of an unsupported type.
    """
    if encoding is None:
        # Try utf-8 first, and fall back to detected encoding
        encoding = ("utf-8", __salt_system_encoding__)
    elif not isinstance(encoding, (tuple, list)):
        encoding = (encoding,)

    if not encoding:
        raise ValueError("encoding cannot be empty")

    if isinstance(s, str):
        text = s
    elif isinstance(s, (bytes, bytearray)):
        text = to_str(s, encoding, errors)
    else:
        raise TypeError("expected str, bytes, or bytearray not {}".format(type(s)))

    if normalize:
        return unicodedata.normalize("NFC", text)
    return text
116
117
118@jinja_filter("str_to_num")
119@jinja_filter("to_num")
def to_num(text):
    """
    Convert a string to a number.

    ``int()`` is attempted first, then ``float()``. The first conversion that
    does not raise ``ValueError`` supplies the return value; when both fail
    the input is handed back unchanged.
    """
    for convert in (int, float):
        try:
            return convert(text)
        except ValueError:
            # Not parseable by this converter, try the next one
            continue
    return text
134
135
def to_none(text):
    """
    Return ``None`` when the string form of ``text`` is empty or consists
    only of whitespace; otherwise return ``text`` itself, unchanged.
    """
    has_content = bool(str(text).strip())
    return text if has_content else None
143
144
def is_quoted(value):
    """
    Return a single or double quote, if a string is wrapped in extra quotes.
    Otherwise return an empty string.

    value
        The value to inspect. Anything that is not a ``str``, as well as the
        empty string, is reported as not quoted.

    Returns the quote character (``'`` or ``"``) when ``value`` starts with
    a quote and ends with that same character, else ``""``.
    """
    # Guard against the empty string: indexing value[0] would raise IndexError
    if not isinstance(value, str) or not value:
        return ""

    first = value[0]
    # NOTE: a lone quote character still counts as quoted here (its first and
    # last character are the same); this keeps the long-standing behavior.
    if first in ("'", '"') and value[-1] == first:
        return first
    return ""
158
159
def dequote(value):
    """
    Strip one layer of surrounding quotes from ``value``.

    When ``is_quoted()`` reports that ``value`` is quoted, the first and last
    characters are dropped; otherwise ``value`` is returned untouched.
    """
    return value[1:-1] if is_quoted(value) else value
167
168
169@jinja_filter("is_hex")
def is_hex(value):
    """
    Returns True if value is a hexadecimal string, otherwise returns False.

    The check is delegated to ``int(value, 16)``: a ``TypeError`` or
    ``ValueError`` from that call means "not hexadecimal".
    """
    try:
        int(value, 16)
    except (TypeError, ValueError):
        return False
    return True
179
180
def is_binary(data):
    """
    Detects if the passed string of data is binary or text

    data
        The ``str`` or ``bytes`` value to inspect. Empty values and values of
        any other type are reported as not binary.

    Returns ``True`` when ``data`` contains a NUL character/byte, or when
    more than 30% of it falls outside the "text" set (printable ASCII 32-126
    plus newline, carriage return, tab and backspace). Returns ``False``
    otherwise.
    """
    if not data or not isinstance(data, (str, bytes)):
        return False

    # Printable ASCII plus \n, \r, \t and \b. Built directly as bytes so no
    # helper module has to be imported just to encode a constant ASCII table.
    text_bytes = bytes(range(32, 127)) + b"\n\r\t\b"

    if isinstance(data, bytes):
        if b"\0" in data:
            return True
        # bytes.translate(None, delete) removes every byte listed in delete
        nontext = data.translate(None, text_bytes)
    else:
        if "\0" in data:
            return True
        # Map every text character to None, leaving only non-text characters
        nontext = data.translate(str.maketrans("", "", text_bytes.decode("ascii")))

    # If more than 30% non-text characters, then this is considered binary
    # data
    return len(nontext) / len(data) > 0.30
210
211
212@jinja_filter("random_str")
def random(size=32):
    """
    Return a random string built from ``os.urandom()`` output.

    ``size`` bytes are read from ``os.urandom()``, base64-encoded, and the
    encoded text is truncated to ``size`` characters before being converted
    to ``str`` with ``to_unicode()``.
    """
    raw = os.urandom(size)
    encoded = base64.b64encode(raw).replace(b"\n", b"")
    return to_unicode(encoded[:size])
216
217
218@jinja_filter("contains_whitespace")
def contains_whitespace(text):
    """
    Returns True if there are any whitespace characters in the string.

    Each element of ``text`` is tested with its ``isspace()`` method and the
    scan stops at the first hit.
    """
    for char in text:
        if char.isspace():
            return True
    return False
224
225
def human_to_bytes(size):
    """
    Given a human-readable byte string (e.g. 2G, 30M),
    return the number of bytes.  Will return 0 if the argument has
    unexpected form.

    size
        A string made of digits followed by a single unit letter. The
        recognized units are ``M``, ``G``, ``T`` and ``P`` (powers of 1024).

    Returns the size in bytes as an ``int``, or ``0`` when ``size`` is empty,
    has a non-numeric prefix, or uses an unrecognized unit.

    .. versionadded:: 2018.3.0
    """
    multipliers = {
        "M": 1024 ** 2,
        "G": 1024 ** 3,
        "T": 1024 ** 4,
        "P": 1024 ** 5,
    }

    # Slicing (rather than indexing) keeps an empty string from raising
    # IndexError; it simply yields an empty, unrecognized unit.
    number, unit = size[:-1], size[-1:]

    factor = multipliers.get(unit)
    if factor is None or not number.isdigit():
        return 0

    try:
        return int(number) * factor
    except ValueError:
        # isdigit() accepts characters such as superscript digits that int()
        # cannot parse; treat those as "unexpected form" as well.
        return 0
251
252
def build_whitespace_split_regex(text):
    """
    Create a regular expression at runtime which should match ignoring the
    addition or deletion of white space or line breaks.

    Every line of ``text`` is split into tokens with ``shlex`` (whitespace
    splitting enabled, no comment characters). Each token is escaped with
    ``re.escape()`` and the tokens are glued together with an optional
    whitespace group, which is also placed before and after every line. The
    per-line expressions are concatenated and wrapped in ``(?m)^`` ... ``$``.

    Example:

    .. code-block:: python

        >>> import re
        >>> import salt.utils.stringutils
        >>> regex = salt.utils.stringutils.build_whitespace_split_regex(
        ...     "foo  bar"
        ... )
        >>> bool(re.search(regex, "  foo bar"))
        True
    """
    optional_ws = r"(?:[\s]+)?"

    def _tokenize(line):
        lexer = shlex.shlex(line)
        lexer.whitespace_split = True
        lexer.commenters = ""
        # Pick which characters the lexer treats as quotes based on what the
        # line contains; when none of the cases apply the shlex default stays.
        if r"'\"" in line:
            lexer.quotes = ""
        elif "'" in line:
            lexer.quotes = '"'
        elif '"' in line:
            lexer.quotes = "'"
        return list(lexer)

    line_patterns = []
    for line in text.splitlines():
        escaped_tokens = [re.escape(token) for token in _tokenize(line)]
        line_patterns.append(
            optional_ws + optional_ws.join(escaped_tokens) + optional_ws
        )
    return "(?m)^" + "".join(line_patterns) + "$"
300
301
def expr_match(line, expr):
    """
    Checks whether or not the passed value matches the specified expression.
    Tries to match expr first as a glob using fnmatch.fnmatch(), and then tries
    to match expr as a regular expression. Originally designed to match minion
    IDs for whitelists/blacklists.

    The regular expression must match the *whole* value. ``re.fullmatch()``
    is used so that expressions with top-level alternation are anchored as a
    unit (``foo|bar`` matches only ``foo`` or ``bar``, not ``foo123``) and so
    that expressions starting with inline flags such as ``(?i)`` stay valid.

    Note that this also does exact matches, as fnmatch.fnmatch() will return
    ``True`` when no glob characters are used and the string is an exact match:

    .. code-block:: python

        >>> fnmatch.fnmatch('foo', 'foo')
        True

    line
        The value to test.

    expr
        A glob or regular expression.

    Returns ``True`` on a glob or regex match and ``False`` otherwise. An
    invalid regular expression counts as "no match". A ``TypeError`` (raised
    when the value or expression is not a string) is logged and also yields
    ``False``.
    """
    try:
        if fnmatch.fnmatch(line, expr):
            return True
        try:
            return re.fullmatch(expr, line) is not None
        except re.error:
            # expr is not a valid regular expression; treat it as a non-match
            return False
    except TypeError:
        log.exception("Value %r or expression %r is not a string", line, expr)
    return False
328
329
330@jinja_filter("check_whitelist_blacklist")
def check_whitelist_blacklist(value, whitelist=None, blacklist=None):
    """
    Check a whitelist and/or blacklist to see if the value matches it.

    value
        The item to check the whitelist and/or blacklist against.

    whitelist
        The list of items that are white-listed. If ``value`` is found
        in the whitelist, then the function returns ``True``. Otherwise,
        it returns ``False``.

    blacklist
        The list of items that are black-listed. If ``value`` is found
        in the blacklist, then the function returns ``False``. Otherwise,
        it returns ``True``.

    A single string is accepted for either list and is treated as a
    one-element list. A non-empty, non-iterable list argument raises
    ``TypeError``.

    If both a whitelist and a blacklist are provided, the value must not
    match the blacklist and must match the whitelist for the function to
    return ``True``. If neither is provided, the function returns ``True``.
    """

    def _as_patterns(patterns, error_template):
        # Normalize the input so that we have something iterable
        if not patterns:
            return []
        if isinstance(patterns, str):
            patterns = [patterns]
        if not hasattr(patterns, "__iter__"):
            raise TypeError(error_template.format(type(patterns).__name__, patterns))
        return patterns

    # The blacklist is validated first, then the whitelist
    blacklist = _as_patterns(
        blacklist, "Expecting iterable blacklist, but got {} ({})"
    )
    whitelist = _as_patterns(
        whitelist, "Expecting iterable whitelist, but got {} ({})"
    )

    # Both lists are always evaluated, regardless of which result is used
    in_blacklist = any(expr_match(value, expr) for expr in blacklist)
    in_whitelist = any(expr_match(value, expr) for expr in whitelist)

    if not blacklist and not whitelist:
        # No blacklist or whitelist passed
        return True
    if not whitelist:
        # Blacklist but no whitelist
        return not in_blacklist
    if not blacklist:
        # Whitelist but no blacklist
        return in_whitelist
    # Both whitelist and blacklist
    return not in_blacklist and in_whitelist
393
394
def check_include_exclude(path_str, include_pat=None, exclude_pat=None):
    """
    Check for glob or regexp patterns for include_pat and exclude_pat in the
    'path_str' string and return True/False conditions as follows.
      - Default: return 'True' if no include_pat or exclude_pat patterns are
        supplied
      - If only include_pat or exclude_pat is supplied: return 'True' if string
        passes the include_pat test or fails exclude_pat test respectively
      - If both include_pat and exclude_pat are supplied: return 'True' if
        include_pat matches AND exclude_pat does not match

    A pattern prefixed with ``E@`` is searched for as a regular expression
    (the prefix is stripped first); any other pattern is matched as a glob.
    Either argument may be a single pattern or a list of patterns, in which
    case one matching entry is enough for the list to count as matching.
    """

    def _matches(pattern):
        # Regexp (E@...) or glob (default)
        if re.match("E@", pattern):
            return bool(re.search(pattern[2:], path_str))
        return bool(fnmatch.fnmatch(path_str, pattern))

    def _any_match(patterns):
        if isinstance(patterns, list):
            return any(_matches(entry) for entry in patterns)
        return _matches(patterns)

    # A missing include pattern includes everything; a missing exclude
    # pattern excludes nothing.
    included = _any_match(include_pat) if include_pat else True
    excluded = _any_match(exclude_pat) if exclude_pat else False
    return included and not excluded
444
445
def print_cli(msg, retries=10, step=0.01):
    """
    Wrapper around print() that suppresses tracebacks on broken pipes (i.e.
    when salt output is piped to less and less is stopped prematurely).

    msg
        The message to print. If printing raises ``UnicodeEncodeError`` the
        utf-8 encoded form of ``msg`` is printed instead.

    retries
        How many times a "temporarily unavailable" / ``EAGAIN`` failure is
        retried. With ``retries`` of ``0`` nothing is printed at all, and
        once the retries are used up the function returns without raising.

    step
        Seconds to sleep before each retry.

    An ``OSError`` with errno ``EPIPE`` is silently ignored. Any other
    ``OSError`` that is not retried is re-raised.
    """
    while retries:
        try:
            try:
                print(msg)
            except UnicodeEncodeError:
                # Fall back to printing the utf-8 encoded bytes
                print(msg.encode("utf-8"))
        except OSError as exc:
            err = "{}".format(exc)
            # EPIPE (broken pipe) is swallowed: fall through to the break below
            if exc.errno != errno.EPIPE:
                if (
                    "temporarily unavailable" in err or exc.errno in (errno.EAGAIN,)
                ) and retries:
                    # Transient failure: wait a little and try again
                    time.sleep(step)
                    retries -= 1
                    continue
                else:
                    raise
        # Printed successfully (or hit a broken pipe): stop looping
        break
469
470
def get_context(template, line, num_lines=5, marker=None):
    """
    Returns debugging context around a line in a given string

    template
        The full text to take the context from.

    line
        The 1-based number of the line of interest.

    num_lines
        How many lines to show on either side of ``line``.

    marker
        Optional text appended to the line of interest.

    Returns:: string

    The selected lines are wrapped in ``---`` delimiters, with ``[...]``
    standing in for any lines cut off before or after the window. If ``line``
    is larger than the number of lines in ``template``, the whole template is
    returned untouched.
    """
    lines = template.splitlines()
    total = len(lines)

    # In test mode, a single line template would return a crazy line number
    # like 357. If the given line is obviously wrong, just return the entire
    # template.
    if line > total:
        return template

    first = max(0, line - num_lines - 1)  # minus 1 for 0-based indexing
    last = min(total, line + num_lines)
    window = lines[first:last]
    target = line - first - 1  # position of ``line`` inside the window

    if first > 0:
        # Lines were cut off at the top; the placeholder shifts the target
        window.insert(0, "[...]")
        target += 1
    if last < total:
        window.append("[...]")

    if marker:
        window[target] += marker

    return "---\n{}\n---".format("\n".join(window))
504
505
def get_diff(a, b, *args, **kwargs):
    """
    Perform diff on two iterables containing lines from two files, and return
    the diff as a string.

    Both inputs are first passed through ``salt.utils.data.decode_list()``
    with utf-8, latin-1 and the system encoding as candidate encodings. Any
    extra positional and keyword arguments are forwarded to
    ``difflib.unified_diff()``, whose output lines are joined into one string.
    """
    candidate_encodings = ("utf-8", "latin-1", __salt_system_encoding__)
    # Late import to avoid circular import
    import salt.utils.data

    lines_a = salt.utils.data.decode_list(a, encoding=candidate_encodings)
    lines_b = salt.utils.data.decode_list(b, encoding=candidate_encodings)
    diff_lines = difflib.unified_diff(lines_a, lines_b, *args, **kwargs)
    return "".join(diff_lines)
524
525
526@jinja_filter("to_snake_case")
def camel_to_snake_case(camel_input):
    """
    Converts camelCase (or CamelCase) to snake_case.
    From https://codereview.stackexchange.com/questions/185966/functions-to-convert-camelcase-strings-to-snake-case

    An underscore is inserted before an uppercase letter when the previous
    character is lowercase, or when the following character is lowercase
    (so runs of capitals such as ``HTTPServer`` become ``http_server``).
    An empty string is returned unchanged.

    :param str camel_input: The camelcase or CamelCase string to convert to snake_case

    :return str
    """
    # Indexing the first character of an empty string would raise IndexError
    if camel_input == "":
        return ""

    last_index = len(camel_input) - 1
    pieces = [camel_input[0].lower()]
    for i, letter in enumerate(camel_input[1:], 1):
        if letter.isupper():
            after_lower = camel_input[i - 1].islower()
            before_lower = i != last_index and camel_input[i + 1].islower()
            if after_lower or before_lower:
                pieces.append("_")
        pieces.append(letter.lower())
    return "".join(pieces)
545
546
547@jinja_filter("to_camelcase")
def snake_to_camel_case(snake_input, uppercamel=False):
    """
    Converts snake_case to camelCase (or CamelCase if uppercamel is ``True``).
    Inspired by https://codereview.stackexchange.com/questions/85311/transform-snake-case-to-camelcase

    The input is split on underscores. Every word after the first is
    capitalized; the first word is left as-is unless ``uppercamel`` is set,
    in which case it is capitalized too.

    :param str snake_input: The input snake_case string to convert to camelCase
    :param bool uppercamel: Whether or not to convert to CamelCase instead

    :return str
    """
    head, *tail = snake_input.split("_")
    if uppercamel:
        head = head.capitalize()
    capitalized_tail = [word.capitalize() for word in tail]
    return head + "".join(capitalized_tail)
562