from datalad.support.exceptions import IncompleteResultsError, CommandError
from filelock import Timeout, FileLock
from multiprocessing import Process
from pathlib import Path
from time import sleep
from typing import Union
import attr
import datalad.api as datalad
import datetime as dt
import logging
import os
import pytest
import random
import shutil
import tempfile

from afni_test_utils import misc, tools
from afni_test_utils.tools import get_current_test_name

DATA_FETCH_LOCK_PATH = Path(tempfile.gettempdir()) / "afni_tests_data.lock"
dl_lock = FileLock(DATA_FETCH_LOCK_PATH, timeout=300)

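# A minimal sketch of how this module-level lock is intended to be used
# (illustrative only; the helpers below acquire and release it explicitly,
# but FileLock also supports the context-manager protocol):
#
#     with dl_lock:
#         datalad.Dataset(str(tests_data_dir)).get(path=["some/file"])
#
# Using the context manager guarantees the lock is released even if the
# fetch raises.
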
def get_test_data_path(config_obj):
    """Return the path to the test data directory, given a pytest config
    object or any object (such as a fixture request) that carries one.
    """
    if hasattr(config_obj, "rootdir"):
        return Path(config_obj.rootdir) / "afni_ci_test_data"
    elif hasattr(config_obj, "config"):
        return Path(config_obj.config.rootdir) / "afni_ci_test_data"
    else:
        raise ValueError("A pytest config object was expected")

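# Illustrative call sites (hypothetical): both a pytest Config object and an
# object exposing one via a ``.config`` attribute (e.g. a fixture request)
# are accepted:
#
#     data_path = get_test_data_path(pytestconfig)
#     data_path = get_test_data_path(request)
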
def get_tests_data_dir(config_obj):
    """Get the path to the test data directory. If the test data directory
    does not exist or is not populated, install it with datalad.
    """
    logger = logging.getLogger("Test data setup")

    tests_data_dir = get_test_data_path(config_obj)

    # remote should be configured or something is badly amiss...
    dl_dset = datalad.Dataset(str(tests_data_dir))
    if (
        dl_dset.is_installed()
        and "remote.afni_ci_test_data.url" not in dl_dset.config.keys()
    ):
        # make the dataset contents user-writeable so they can be removed
        for f in dl_dset.pathobj.glob("**/*"):
            try:
                f.chmod(0o700)
            except FileNotFoundError:
                # missing symlink, nothing to worry about
                pass
        message = (
            "The test data dataset is installed but its afni_ci_test_data "
            f"remote is not configured. Consider removing {dl_dset.pathobj} "
            "and rerunning."
        )
        logger.warning(message)
        raise ValueError(message)
        # shutil.rmtree(dl_dset.pathobj)

    # datalad is required and the datalad repository is used for data.
    if not (tests_data_dir / ".datalad").exists():
        try:
            dl_lock.acquire()
            # re-check after acquiring the lock; another process may have
            # installed the dataset while this one was waiting
            if not (tests_data_dir / ".datalad").exists():
                logger.warning("Installing test data")
                datalad.install(
                    str(tests_data_dir),
                    "https://github.com/afni/afni_ci_test_data.git",
                    recursive=True,
                    on_failure="stop",
                )
        finally:
            dl_lock.release()
    # These files need to be user-writeable:
    some_files = [".git/logs/HEAD"]
    for f in some_files:
        data_file = tests_data_dir / f
        if not data_file.exists():
            raise ValueError(
                f"{f} does not exist (parent exists: {data_file.parent.exists()})"
            )
        if not os.access(data_file, os.W_OK):
            raise ValueError(f"{f} is not user writeable ({os.getuid()})")
    return tests_data_dir

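# Typical usage (hypothetical sketch): called once per session from
# conftest.py so the dataset is installed before any test asks for data:
#
#     def pytest_configure(config):
#         config.tests_data_dir = get_tests_data_dir(config)
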
def save_output_to_repo(config_obj):
    base_comparison_dir_path = get_base_comparison_dir_path(config_obj)

    update_msg = "Update data with test run on {d}".format(
        d=dt.datetime.today().strftime("%Y-%m-%d")
    )

    datalad.save(update_msg, str(base_comparison_dir_path), on_failure="stop")

    sample_test_output = get_test_data_path(config_obj) / "sample_test_output"
    data_message = (
        "New sample output was saved to {sample_test_output} for "
        "future comparisons. Consider publishing this new data to "
        "the publicly accessible servers. "
    )
    print(data_message.format(**locals()))

def get_test_comparison_dir_path(base_comparison_dir_path, mod: Union[str, Path]):
    """Get the full path of the comparison directory for a specific test"""
    # coerce to Path so that .name also works for str input
    mod = Path(mod)
    return base_comparison_dir_path / mod.name / get_current_test_name()

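# For example (hypothetical names), with a module output directory named
# "afni_proc" and a currently running test called "test_basic", this
# resolves to:
#
#     base_comparison_dir_path / "afni_proc" / "test_basic"
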
def get_base_comparison_dir_path(config_obj):
    """If the user does not provide a comparison directory, a default in the
    test data directory is used. The user can specify a directory containing
    the output of a previous test run, or the "sample" output that is created
    by a previous test run when the "--create-sample-output" flag was provided.
    """
    comparison_dir = config_obj.getoption("--diff-with-sample")
    if comparison_dir is not None:
        return Path(comparison_dir).absolute()
    else:
        return get_test_data_path(config_obj) / "sample_test_output"

def get_data_fixture(pytestconfig, request, output_dir):
    """A function-scoped test fixture used for AFNI's testing. The fixture
    sets up output directories as required and provides the "data" object to
    the calling function. The data object contains some fields convenient
    for writing tests, such as the output directory. Finally, the data
    fixture handles test input data: files listed in a data_paths dictionary
    (if defined within the test module) are downloaded to a local datalad
    repository as required. Paths should be listed relative to the
    repository base-directory.

    Args:
        request (pytest.fixture): A function-level pytest request object
            providing information about the calling test function.

    Returns:
        An attrs-generated data object for conveniently handling the test
        specification.
    """
    test_name = get_current_test_name()
    tests_data_dir = get_test_data_path(pytestconfig)

    current_test_module = Path(request.module.__file__)
    module_outdir = output_dir / current_test_module.stem.replace("test_", "")
    test_logdir = module_outdir / get_current_test_name() / "captured_output"
    os.makedirs(test_logdir, exist_ok=True)

    # Add stream and file logging as requested
    logger = logging.getLogger(test_name)
    logger = tools.logger_config(
        logger,
        file=output_dir / "all_tests.log",
        log_file_level=pytestconfig.getoption("--log-file-level"),
    )

    # Set module specific values:
    try:
        data_paths = request.module.data_paths
    except AttributeError:
        data_paths = {}
    # start creating output dict, downloading test data as required
    out_dict = {
        k: process_path_obj(v, tests_data_dir, logger) for k, v in data_paths.items()
    }

    # This will be created as required later
    sampdir = convert_to_sample_dir_path(test_logdir.parent)

    # Get the comparison directory and check if it needs to be downloaded
    base_comparison_dir_path = get_base_comparison_dir_path(pytestconfig)
    comparison_dir = get_test_comparison_dir_path(
        base_comparison_dir_path, module_outdir
    )
    # Define output for calling module and get data as required:
    out_dict.update(
        {
            "module_outdir": module_outdir,
            "logger": logger,
            "current_test_module": current_test_module,
            "outdir": module_outdir / get_current_test_name(),
            "sampdir": sampdir,
            "logdir": test_logdir,
            "comparison_dir": comparison_dir,
            "base_comparison_dir": base_comparison_dir_path,
            "base_outdir": output_dir,
            "tests_data_dir": tests_data_dir,
            "test_name": test_name,
            "rootdir": pytestconfig.rootdir,
            "create_sample_output": pytestconfig.getoption("--create-sample-output"),
            "save_sample_output": pytestconfig.getoption("--save-sample-output"),
        }
    )

    DataClass = attr.make_class(test_name + "_data", list(out_dict), slots=True)
    return DataClass(*out_dict.values())

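# A minimal sketch of a test module that consumes this fixture (hypothetical
# paths; the fixture is typically exposed to tests as "data" via conftest.py):
#
#     data_paths = {"anat": "mini_data/anat_3mm.nii.gz"}
#
#     def test_example(data):
#         assert data.anat.exists()
#         # data.outdir, data.logger, data.comparison_dir, etc. are also set
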
def check_file_exists(file_path, test_data_dir):
    full_path = test_data_dir / file_path
    no_file_error = (
        f"Could not find {full_path}. You have specified the path "
        f"{file_path} for an input datafile but this path does not exist "
        "in the test data directory that has been created in "
        f"{test_data_dir} "
    )

    if not (full_path.exists() or full_path.is_symlink()):
        if "sample_test_output" in full_path.parts:
            raise ValueError(
                "Cannot specify input data that is located in the "
                "sample_test_output directory. "
            )
        else:
            raise ValueError(no_file_error)

def generate_fetch_list(path_obj, test_data_dir):
    """Provided with path_obj, a list of pathlib.Path objects, expands each
    path to include any partner files and determines whether a fetch is
    needed.

    Args:
        path_obj (list of pathlib.Path): may be a list of paths, a path, or
            a path that contains a glob pattern at the end
        test_data_dir (pathlib.Path): the root of the test data repository

    Returns:
        List: list of paths as str type (including HEAD files if BRIK is used)
        Bool: needs_fetching, True if not all of the data has been downloaded
    """
    needs_fetching = False
    fetch_list = []
    for p in path_obj:
        with_partners = add_partner_files(test_data_dir, p)
        for pp in with_partners:
            # fetch if any file does not "exist" (i.e. is a broken symlink)
            needs_fetching = needs_fetching or not (test_data_dir / pp).exists()
        fetch_list += with_partners

    return [str(f) for f in fetch_list], needs_fetching

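# Sketch of the expected behaviour (hypothetical file names): a BRIK path is
# expanded to include its partner files before the need to fetch is assessed:
#
#     fetch_list, needs_fetching = generate_fetch_list(
#         [Path("mini_data/epi+orig.BRIK")], test_data_dir
#     )
#     # fetch_list also contains the matching .HEAD file; needs_fetching is
#     # True if any listed file is still a broken annex symlink
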
def glob_if_necessary(test_data_dir, path_obj):
    """
    Check that path/paths exist in test_data_dir. Paths may be a glob, so
    globbing is attempted before raising an error that a path does not
    exist. Return the list of paths.
    """
    if isinstance(path_obj, str):
        path_obj = [Path(path_obj)]
    elif isinstance(path_obj, Path):
        path_obj = [path_obj]
    else:
        try:
            path_obj = [Path(p) for p in path_obj]
        except TypeError:
            raise TypeError(
                "data_paths must contain paths (values that are of type str, "
                "pathlib.Path) or a non-str iterable type containing paths, "
                "i.e. list, tuple... "
            )

    outfiles = []

    for file_in in path_obj:
        try:
            # file should be found even if just as an unresolved symlink
            check_file_exists(file_in, test_data_dir)
            outfiles.append(file_in)
        except ValueError as err:
            # the path may be a glob pattern; try to expand it
            globbed = list((test_data_dir / file_in.parent).glob(file_in.name))
            if not globbed:
                raise err
            outfiles += globbed

    return outfiles

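# Illustrative inputs (hypothetical paths): all of these are accepted and
# normalised to a list of pathlib.Path objects, with globs expanded:
#
#     glob_if_necessary(test_data_dir, "mini_data/anat_3mm.nii.gz")
#     glob_if_necessary(test_data_dir, Path("mini_data") / "*.nii.gz")
#     glob_if_necessary(test_data_dir, ["mini_data/a.nii", "mini_data/b.nii"])
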
def add_partner_files(test_data_dir, path_in):
    """
    If the path is a BRIK or a HEAD file, the pair is returned for the
    purposes of fetching the data via datalad
    """
    try:
        from afnipy import afni_base as ab
    except ImportError:
        ab = misc.try_to_import_afni_module("afni_base")
    files_out = [path_in]
    brik_pats = [".HEAD", ".BRIK"]
    if any(pat in path_in.name for pat in brik_pats):
        parsed_obj = ab.parse_afni_name(str(test_data_dir / path_in))
        if parsed_obj["type"] == "BRIK":
            globbed = Path(parsed_obj["path"]).glob(parsed_obj["prefix"] + "*")
            files_out += list(globbed)
            # de-duplicate in case the glob re-matched path_in
            files_out = list(set(files_out))

    return files_out

def process_path_obj(path_obj, test_data_dir, logger=None):
    """
    Process paths that have been defined in the data_paths dictionary of
    test modules. Globs are resolved, and the data is fetched using datalad.
    If HEAD files are provided, the corresponding BRIK files are also
    downloaded.

    Args:
        path_obj (str/pathlib.Path or iterable): Paths as str/pathlib.Path
            or non-str iterables with elements of these types can be passed
            as arguments for conversion to Path objects. Globbing at the
            final element of the path is also supported and will be resolved
            before being returned.
        test_data_dir (pathlib.Path): An existing datalad repository
            containing the test data.

    Returns:
        Single pathlib.Path object or list of pathlib.Path objects, fetched
        as required.
    """

    # Resolve all globs and return a list of pathlib objects
    path_obj = glob_if_necessary(test_data_dir, path_obj)
    # Search for any files that might be missing, e.g. HEAD for a BRIK
    files_to_fetch, needs_fetching = generate_fetch_list(path_obj, test_data_dir)

    # Fetching the data
    if needs_fetching:
        attempt_count = 0
        while attempt_count < 2:
            # fetch data with the global dl_lock held
            fetch_status = try_data_download(files_to_fetch, test_data_dir, logger)
            if fetch_status:
                break
            else:
                attempt_count += 1
        else:
            # both datalad download attempts failed (while-else: no break)
            pytest.exit(
                f"Datalad download failed {attempt_count} times, you may "
                "not be connected to the internet "
            )
        if logger:
            logger.info(f"Downloaded data for {test_data_dir}")
    path_obj = [test_data_dir / p for p in path_obj]
    if len(path_obj) == 1:
        return path_obj[0]
    else:
        return path_obj

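# Illustrative usage (hypothetical paths): a single input collapses to one
# Path, while iterables come back as a list of Paths:
#
#     anat = process_path_obj("mini_data/anat_3mm.nii.gz", tests_data_dir)
#     pair = process_path_obj(["mini_data/a.nii", "mini_data/b.nii"], tests_data_dir)
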
def try_data_download(file_fetch_list, test_data_dir, logger):
    try:
        # poll_interval was spelled "poll_intervall" in older filelock releases
        dl_lock.acquire(poll_interval=1)
        dl_dset = datalad.Dataset(str(test_data_dir))
        # Fetch the data in a separate process so the attempt can be timed out
        process_for_fetching_data = Process(
            target=dl_dset.get, kwargs={"path": [str(p) for p in file_fetch_list]}
        )

        # attempts should be timed-out to deal with unpredictable stalls.
        process_for_fetching_data.start()
        # logger.debug(f"Fetching data for {test_data_dir}")
        process_for_fetching_data.join(timeout=60)
        if process_for_fetching_data.is_alive():
            # terminate the process.
            process_for_fetching_data.terminate()
            # logger.warning(f"Data fetching timed out for {file_fetch_list}")
            return False
        elif process_for_fetching_data.exitcode != 0:
            # logger.warning(f"Data fetching failed for {file_fetch_list}")
            return False
        else:
            return True
    except (
        IncompleteResultsError,
        ValueError,
        CommandError,
        TimeoutError,
        Timeout,
    ) as err:
        if logger:
            logger.warning(
                f"Datalad download failure ({type(err)}) for {test_data_dir}. "
                "Will try again"
            )
        return False

    finally:
        dl_lock.release()
        # random back-off so concurrent workers do not retry in lockstep
        sleep(random.randint(1, 10))

def convert_to_sample_dir_path(output_dir):
    """Map an output directory to its sample-output counterpart by replacing
    the "output_" prefix with "sample_output_".
    """
    sampdir = Path(str(output_dir).replace("output_", "sample_output_"))
    return sampdir