1import logging
2import math
3import os
4import pathlib
5import re
6import sys
7from contextlib import contextmanager
8from functools import partial
9from hashlib import md5
10from urllib.parse import urlsplit
11
# Default read/write block size: 5 MiB
DEFAULT_BLOCK_SIZE = 5 * 2 ** 20
# True when running on Python 3.6 (used for version-dependent fallbacks)
PY36 = sys.version_info < (3, 7)
14
15
def infer_storage_options(urlpath, inherit_storage_options=None):
    """Infer storage options from URL path and merge it with existing storage
    options.

    Parameters
    ----------
    urlpath: str or unicode
        Either local absolute file path or URL (hdfs://namenode:8020/file.csv)
    inherit_storage_options: dict (optional)
        Its contents will get merged with the inferred information from the
        given path

    Returns
    -------
    Storage options dict.

    Examples
    --------
    >>> infer_storage_options('/mnt/datasets/test.csv')  # doctest: +SKIP
    {"protocol": "file", "path": "/mnt/datasets/test.csv"}
    >>> infer_storage_options(
    ...     'hdfs://username:pwd@node:123/mnt/datasets/test.csv?q=1',
    ...     inherit_storage_options={'extra': 'value'},
    ... )  # doctest: +SKIP
    {"protocol": "hdfs", "username": "username", "password": "pwd",
    "host": "node", "port": 123, "path": "/mnt/datasets/test.csv",
    "url_query": "q=1", "extra": "value"}
    """
    # Handle Windows paths including disk name in this special case
    if (
        re.match(r"^[a-zA-Z]:[\\/]", urlpath)
        or re.match(r"^[a-zA-Z0-9]+://", urlpath) is None
    ):
        # Drive-letter path, or no scheme at all: plain local file
        return {"protocol": "file", "path": urlpath}

    parsed_path = urlsplit(urlpath)
    protocol = parsed_path.scheme or "file"
    # Re-attach any fragment: these paths may legitimately contain '#'
    if parsed_path.fragment:
        path = "#".join([parsed_path.path, parsed_path.fragment])
    else:
        path = parsed_path.path
    if protocol == "file":
        # Special case parsing file protocol URL on Windows according to:
        # https://msdn.microsoft.com/en-us/library/jj710207.aspx
        windows_path = re.match(r"^/([a-zA-Z])[:|]([\\/].*)$", path)
        if windows_path:
            path = "%s:%s" % windows_path.groups()

    if protocol in ["http", "https"]:
        # for HTTP, we don't want to parse, as requests will anyway
        return {"protocol": protocol, "path": urlpath}

    options = {"protocol": protocol, "path": path}

    if parsed_path.netloc:
        # Parse `hostname` from netloc manually because `parsed_path.hostname`
        # lowercases the hostname which is not always desirable (e.g. in S3):
        # https://github.com/dask/dask/issues/1417
        options["host"] = parsed_path.netloc.rsplit("@", 1)[-1].rsplit(":", 1)[0]

        if protocol in ("s3", "s3a", "gcs", "gs"):
            # Bucket-style protocols keep the bucket (netloc) as part of the
            # path.  (The previous `else: options["host"] = options["host"]`
            # no-op branch has been removed.)
            options["path"] = options["host"] + options["path"]
        if parsed_path.port:
            options["port"] = parsed_path.port
        if parsed_path.username:
            options["username"] = parsed_path.username
        if parsed_path.password:
            options["password"] = parsed_path.password

    if parsed_path.query:
        options["url_query"] = parsed_path.query
    if parsed_path.fragment:
        options["url_fragment"] = parsed_path.fragment

    if inherit_storage_options:
        update_storage_options(options, inherit_storage_options)

    return options
96
97
def update_storage_options(options, inherited=None):
    """Merge *inherited* into *options* in place.

    Raises KeyError if both dicts carry the same key with differing values;
    identical duplicate entries are tolerated.
    """
    inherited = inherited or {}
    for key in set(options) & set(inherited):
        if options.get(key) != inherited.get(key):
            raise KeyError(
                "Collision between inferred and specified storage "
                "option:\n%s" % key
            )
    options.update(inherited)
110
111
# Compression extensions registered via fsspec.compression.register_compression
# Maps filename extension (lower-case, without the leading dot) to the
# registered compression name; consulted by infer_compression() below.
compressions = {}
114
115
def infer_compression(filename):
    """Infer compression, if available, from filename.

    Infer a named compression type, if registered and available, from filename
    extension. This includes builtin (gz, bz2, zip) compressions, as well as
    optional compressions. See fsspec.compression.register_compression.

    Returns None when the extension is not registered.
    """
    # Normalise: last extension, no dot, lower-case (registry keys are such)
    suffix = os.path.splitext(filename)[-1].strip(".").lower()
    return compressions.get(suffix)
126
127
def build_name_function(max_int):
    """Returns a function that receives a single integer
    and returns it as a string padded by enough zero characters
    to align with maximum possible integer

    >>> name_f = build_name_function(57)

    >>> name_f(7)
    '07'
    >>> name_f(31)
    '31'
    >>> build_name_function(1000)(42)
    '0042'
    >>> build_name_function(999)(42)
    '042'
    >>> build_name_function(0)(0)
    '0'
    """
    # Tiny epsilon handles the corner cases where max_int is 0 or an exact
    # power of 10 (so log10 rounds the way the examples above require).
    width = int(math.ceil(math.log10(max_int + 1e-8)))

    def name_function(i):
        return str(i).zfill(width)

    return name_function
155
156
def seek_delimiter(file, delimiter, blocksize):
    r"""Seek current file to file start, file end, or byte after delimiter seq.

    Seeks file to next chunk delimiter, where chunks are defined on file start,
    a delimiting sequence, and file end. Use file.tell() to see location afterwards.
    Note that file start is a valid split, so must be at offset > 0 to seek for
    delimiter.

    Parameters
    ----------
    file: a file
    delimiter: bytes
        a delimiter like ``b'\n'`` or message sentinel, matching file .read() type
    blocksize: int
        Number of bytes to read from the file at once.


    Returns
    -------
    Returns True if a delimiter was found, False if at file start or end.

    """
    if file.tell() == 0:
        # File start is itself a valid split point; nothing to seek.
        return False

    # `tail` carries the end of the previous read so a delimiter straddling
    # two reads is still detected.  Initialised to None (not b"") so the first
    # concatenation adopts whatever type file.read() returns (bytes or str).
    tail = None
    while True:
        block = file.read(blocksize)
        if not block:
            # Reached end-of-file without seeing the delimiter.
            return False
        window = tail + block if tail else block
        try:
            hit = window.find(delimiter)
            if hit != -1:
                # Rewind so the cursor sits just past the delimiter.
                file.seek(file.tell() - (len(window) - hit) + len(delimiter))
                return True
            if len(block) < blocksize:
                # Short read means end-of-file; no delimiter found.
                return False
        except (OSError, ValueError):
            pass
        tail = window[-len(delimiter) :]
204
205
def read_block(f, offset, length, delimiter=None, split_before=False):
    """Read a block of bytes from a file

    Parameters
    ----------
    f: File
        Open file
    offset: int
        Byte offset to start read
    length: int
        Number of bytes to read, read through end of file if None
    delimiter: bytes (optional)
        Ensure reading starts and stops at delimiter bytestring
    split_before: bool (optional)
        Start/stop read *before* delimiter bytestring.


    If using the ``delimiter=`` keyword argument we ensure that the read
    starts and stops at delimiter boundaries that follow the locations
    ``offset`` and ``offset + length``.  If ``offset`` is zero then we
    start at zero, regardless of delimiter.  The bytestring returned WILL
    include the terminating delimiter string.

    Examples
    --------

    >>> from io import BytesIO  # doctest: +SKIP
    >>> f = BytesIO(b'Alice, 100\\nBob, 200\\nCharlie, 300')  # doctest: +SKIP
    >>> read_block(f, 0, 13)  # doctest: +SKIP
    b'Alice, 100\\nBo'

    >>> read_block(f, 0, 13, delimiter=b'\\n')  # doctest: +SKIP
    b'Alice, 100\\nBob, 200\\n'

    >>> read_block(f, 10, 10, delimiter=b'\\n')  # doctest: +SKIP
    b'Bob, 200\\nCharlie, 300'
    """
    if delimiter:
        # Snap the start of the read forward to the next delimiter boundary
        f.seek(offset)
        start_on_delim = seek_delimiter(f, delimiter, 2 ** 16)
        if length is None:
            return f.read()
        start = f.tell()
        # Shrink the requested length by however far the start moved
        length -= start - offset

        # Snap the end of the read forward to the following boundary
        f.seek(start + length)
        end_on_delim = seek_delimiter(f, delimiter, 2 ** 16)
        end = f.tell()

        # Shift split points to just *before* the delimiter, but only when a
        # delimiter was actually found (not file start or end).
        if split_before:
            if start_on_delim:
                start -= len(delimiter)
            if end_on_delim:
                end -= len(delimiter)

        offset, length = start, end - start

    f.seek(offset)
    return f.read(length)
269
270
def tokenize(*args, **kwargs):
    """Deterministic token

    (modified from dask.base)

    >>> tokenize([1, 2, '3'])
    '9d71491b50023b06fc76928e6eddb952'

    >>> tokenize('Hello') == tokenize('Hello')
    True
    """
    # Fold keyword arguments into the positional tuple so they affect the hash
    payload = args + ((kwargs,) if kwargs else ())
    return md5(str(payload).encode()).hexdigest()
285
286
def stringify_path(filepath):
    """Attempt to convert a path-like object to a string.

    Parameters
    ----------
    filepath: object to be converted

    Returns
    -------
    filepath_str: maybe a string version of the object

    Notes
    -----
    Objects supporting the fspath protocol (Python 3.6+) are coerced
    according to its __fspath__ method.

    For backwards compatibility with older Python version, pathlib.Path
    objects are specially coerced.

    Any other object is passed through unchanged, which includes bytes,
    strings, buffers, or anything else that's not even path-like.
    """
    if isinstance(filepath, str):
        return filepath
    fspath = getattr(filepath, "__fspath__", None)
    if fspath is not None:
        return fspath()
    if isinstance(filepath, pathlib.Path):
        # pathlib objects predating the fspath protocol
        return str(filepath)
    if hasattr(filepath, "path"):
        return filepath.path
    # Not path-like at all: hand back unchanged
    return filepath
319
320
def make_instance(cls, args, kwargs):
    """Construct ``cls(*args, **kwargs)`` and call its ``_determine_worker``
    hook before returning the new instance."""
    obj = cls(*args, **kwargs)
    obj._determine_worker()
    return obj
325
326
def common_prefix(paths):
    """For a list of paths, find the shortest prefix common to all"""
    split_paths = [p.split("/") for p in paths]
    depth = min(len(sp) for sp in split_paths)
    # Count how many leading components every path shares
    n_common = 0
    for level in range(depth):
        if len({sp[level] for sp in split_paths}) > 1:
            break
        n_common += 1
    return "/".join(split_paths[0][:n_common])
338
339
def other_paths(paths, path2, is_dir=None, exists=False):
    """In bulk file operations, construct a new file tree from a list of files

    Parameters
    ----------
    paths: list of str
        The input file tree
    path2: str or list of str
        Root to construct the new list in. If this is already a list of str, we just
        assert it has the right number of elements.
    is_dir: bool (optional)
        For the special case where the input in one element, whether to regard the value
        as the target path, or as a directory to put a file path within. If None, a
        directory is inferred if the path ends in '/'
    exists: bool (optional)
        For a str destination, it is already exists (and is a dir), files should
        end up inside.

    Returns
    -------
    list of str
    """
    if not isinstance(path2, str):
        # Explicit target list: just sanity-check the length
        assert len(paths) == len(path2)
        return path2

    treat_as_dir = is_dir or path2.endswith("/")
    root = path2.rstrip("/")
    if len(paths) > 1:
        prefix = common_prefix(paths)
        if exists:
            # Destination dir already exists: keep the last source component
            prefix = prefix.rsplit("/", 1)[0]
        return [p.replace(prefix, root, 1) for p in paths]
    if treat_as_dir:
        # Single file into a directory: append the source basename
        return [root + "/" + paths[0].rsplit("/")[-1]]
    return [root]
378
379
def is_exception(obj):
    """Return True when *obj* is an exception *instance* (any BaseException)."""
    return isinstance(obj, BaseException)
382
383
def get_protocol(url):
    """Return the protocol prefix of *url*.

    Splits on the first ``::`` (chained filesystems) or ``://`` (plain URL);
    a URL with neither is a local path, protocol ``"file"``.
    """
    # keyword maxsplit: positional form is deprecated (Python 3.13+); the
    # ':' and '/' characters need no escaping inside a regex.
    parts = re.split(r"(::|://)", url, maxsplit=1)
    if len(parts) > 1:
        return parts[0]
    return "file"
389
390
def can_be_local(path):
    """Can the given URL be used with open_local?"""
    from fsspec import get_filesystem_class

    try:
        fs_class = get_filesystem_class(get_protocol(path))
        return getattr(fs_class, "local_file", False)
    except (ValueError, ImportError):
        # protocol not in the registry, or its implementation failed to import
        return False
400
401
def get_package_version_without_import(name):
    """For given package name, try to find the version without importing it

    Import and package.__version__ is still the backup here, so an import
    *might* happen.

    Returns either the version string, or None if the package
    or the version was not readily found.
    """
    # Cheapest path: module already imported and exposing __version__
    if name in sys.modules:
        mod = sys.modules[name]
        if hasattr(mod, "__version__"):
            return mod.__version__
    if sys.version_info >= (3, 8):
        try:
            import importlib.metadata

            return importlib.metadata.distribution(name).version
        # narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are no longer swallowed; best-effort behaviour is preserved
        except Exception:
            pass
    else:
        try:
            import importlib_metadata

            return importlib_metadata.distribution(name).version
        except Exception:
            pass
    # Last resort: actually import the package
    try:
        import importlib

        mod = importlib.import_module(name)
        return mod.__version__
    except (ImportError, AttributeError):
        return None
436
437
def setup_logging(logger=None, logger_name=None, level="DEBUG", clear=True):
    """Attach a stderr stream handler to *logger* (or the logger named
    *logger_name*), set its level, and return it.

    Exactly one of ``logger`` / ``logger_name`` must be provided; when
    ``clear`` is true, any previously attached handlers are removed first.
    """
    if logger is None and logger_name is None:
        raise ValueError("Provide either logger object or logger name")
    target = logger or logging.getLogger(logger_name)
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(funcName)s -- %(message)s"
        )
    )
    if clear:
        target.handlers.clear()
    target.addHandler(handler)
    target.setLevel(level)
    return target
452
453
454def _unstrip_protocol(name, fs):
455    if isinstance(fs.protocol, str):
456        if name.startswith(fs.protocol):
457            return name
458        return fs.protocol + "://" + name
459    else:
460        if name.startswith(tuple(fs.protocol)):
461            return name
462        return fs.protocol[0] + "://" + name
463
464
def mirror_from(origin_name, methods):
    """Class decorator: for every name in *methods*, install a read-only
    property on the decorated class that forwards the lookup to the instance
    attribute named *origin_name*."""

    def _forward(attr, self):
        # property getter: fetch `attr` from the origin object each access
        return getattr(getattr(self, origin_name), attr)

    def decorator(cls):
        for attr in methods:
            setattr(cls, attr, property(partial(_forward, attr)))
        return cls

    return decorator
481
482
@contextmanager
def nullcontext(obj):
    """No-op context manager that simply yields *obj* unchanged."""
    yield obj
486
487
def merge_offset_ranges(paths, starts, ends, max_gap=0, max_block=None, sort=True):
    """Merge adjacent byte-offset ranges when the inter-range
    gap is <= `max_gap`, and when the merged byte range does not
    exceed `max_block` (if specified). By default, this function
    will re-order the input paths and byte ranges to ensure sorted
    order. If the user can guarantee that the inputs are already
    sorted, passing `sort=False` will skip the re-ordering.

    Parameters
    ----------
    paths: list of str
        File path for each byte range; may contain repeats.
    starts: list of int, or int
        Range start offsets (a scalar is broadcast across all paths).
    ends: list of int, or int
        Range end offsets (a scalar is broadcast across all paths).
    max_gap: int (optional)
        Largest gap between consecutive ranges that may still be merged.
    max_block: int (optional)
        Upper bound on a merged range's size; None means unbounded.
    sort: bool (optional)
        Sort by (path, start, end) before merging.

    Returns
    -------
    (paths, starts, ends) lists with adjacent ranges merged.
    """
    # Check input
    if not isinstance(paths, list):
        raise TypeError("`paths` must be a list")
    if not isinstance(starts, list):
        starts = [starts] * len(paths)
    if not isinstance(ends, list):
        # Bug fix: previously broadcast `starts` here, so a scalar `ends`
        # produced nonsense end offsets.
        ends = [ends] * len(paths)
    if len(starts) != len(paths) or len(ends) != len(paths):
        raise ValueError("`starts` and `ends` must match `paths` in length")

    # Early Return
    if len(starts) <= 1:
        return paths, starts, ends

    # Sort by paths and then ranges if `sort=True`
    if sort:
        paths, starts, ends = [list(v) for v in zip(*sorted(zip(paths, starts, ends)))]

    if paths:
        # Loop through the coupled `paths`, `starts`, and
        # `ends`, and merge adjacent blocks when appropriate
        new_paths = paths[:1]
        new_starts = starts[:1]
        new_ends = ends[:1]
        for i in range(1, len(paths)):
            if (
                paths[i] != paths[i - 1]
                or ((starts[i] - new_ends[-1]) > max_gap)
                or ((max_block is not None and (ends[i] - new_starts[-1]) > max_block))
            ):
                # Cannot merge with previous block.
                # Add new `paths`, `starts`, and `ends` elements
                new_paths.append(paths[i])
                new_starts.append(starts[i])
                new_ends.append(ends[i])
            else:
                # Merge with previous block by updating the
                # last element of `ends`
                new_ends[-1] = ends[i]
        return new_paths, new_starts, new_ends

    # `paths` is empty. Just return input lists
    return paths, starts, ends
540