1from unittest import mock
2
3import pkg_resources
4import pytest
5
6from xarray.backends import common, plugins
7
8
class DummyBackendEntrypointArgs(common.BackendEntrypoint):
    """Dummy backend whose ``open_dataset`` takes variadic positional args.

    ``open_dataset`` intentionally has no ``self`` parameter: on an instance,
    ``filename_or_obj`` binds to the instance itself, so the bound method's
    signature exposes only ``*args``. The tests expect
    ``plugins.set_missing_parameters`` to raise ``TypeError`` for it unless
    ``open_dataset_parameters`` is set explicitly.
    """

    def open_dataset(filename_or_obj, *args):
        return None
12
13
class DummyBackendEntrypointKwargs(common.BackendEntrypoint):
    """Dummy backend whose ``open_dataset`` takes arbitrary keyword args.

    As with the ``*args`` variant, ``self`` is intentionally omitted, so the
    bound method's signature exposes only ``**kwargs``. The tests expect
    ``plugins.set_missing_parameters`` to raise ``TypeError`` for it unless
    ``open_dataset_parameters`` is set explicitly.
    """

    def open_dataset(filename_or_obj, **kwargs):
        return None
17
18
class DummyBackendEntrypoint1(common.BackendEntrypoint):
    """Well-formed dummy backend with a single keyword-only ``decoder`` option."""

    def open_dataset(self, filename_or_obj, *, decoder):
        return None
22
23
class DummyBackendEntrypoint2(common.BackendEntrypoint):
    """Second well-formed dummy backend; same signature as the first one."""

    def open_dataset(self, filename_or_obj, *, decoder):
        return None
27
28
@pytest.fixture
def dummy_duplicated_entrypoints():
    """Entry points in which each engine name is registered twice.

    ``engine1`` and ``engine2`` each point at two different targets, giving
    duplicate removal two name conflicts to resolve.
    """
    return [
        pkg_resources.EntryPoint.parse(spec)
        for spec in (
            "engine1 = xarray.tests.test_plugins:backend_1",
            "engine1 = xarray.tests.test_plugins:backend_2",
            "engine2 = xarray.tests.test_plugins:backend_1",
            "engine2 = xarray.tests.test_plugins:backend_2",
        )
    ]
39
40
@pytest.mark.filterwarnings("ignore:Found")
def test_remove_duplicates(dummy_duplicated_entrypoints) -> None:
    """Duplicated engine names collapse to one entry each, with a warning."""
    with pytest.warns(RuntimeWarning):
        unique = plugins.remove_duplicates(dummy_duplicated_entrypoints)

    # Two distinct names (engine1, engine2) in the fixture.
    assert len(unique) == 2
46
47
def test_broken_plugin() -> None:
    """A failing entry point yields exactly one RuntimeWarning naming it."""
    # Presumably fails to load: no ``backend_1`` is defined in this module.
    broken = pkg_resources.EntryPoint.parse(
        "broken_backend = xarray.tests.test_plugins:backend_1"
    )

    with pytest.warns(RuntimeWarning) as caught:
        plugins.build_engines([broken])

    assert len(caught) == 1
    assert "Engine 'broken_backend'" in str(caught[0].message)
57
58
def test_remove_duplicates_warnings(dummy_duplicated_entrypoints) -> None:
    """One RuntimeWarning per duplicated name, each mentioning 'entrypoints'."""
    with pytest.warns(RuntimeWarning) as caught:
        plugins.remove_duplicates(dummy_duplicated_entrypoints)

    assert len(caught) == 2
    for warning in caught:
        assert "entrypoints" in str(warning.message)
69
70
@mock.patch("pkg_resources.EntryPoint.load", mock.MagicMock(return_value=None))
def test_backends_dict_from_pkg() -> None:
    """Each entry point becomes one engine keyed by its entry-point name."""
    # ``EntryPoint.load`` is patched above, so the targets need not exist.
    entrypoints = [
        pkg_resources.EntryPoint.parse(spec)
        for spec in (
            "engine1 = xarray.tests.test_plugins:backend_1",
            "engine2 = xarray.tests.test_plugins:backend_2",
        )
    ]

    engines = plugins.backends_dict_from_pkg(entrypoints)

    assert len(engines) == 2
    assert engines.keys() == {"engine1", "engine2"}
81
82
def test_set_missing_parameters() -> None:
    """Missing ``open_dataset_parameters`` are detected; explicit ones are kept.

    Local subclasses are used so that the class-level assignments made here
    (directly, and by ``set_missing_parameters``) do not leak into the
    module-level dummy backends shared with other tests. They also reset
    ``open_dataset_parameters`` to ``None`` so that detection is really
    exercised even if another test already populated the parent classes.
    """

    class Backend1(DummyBackendEntrypoint1):
        # NOTE(review): assumes ``None`` is the "not yet detected" marker,
        # as relied on by ``test_set_missing_parameters_raise_error``.
        open_dataset_parameters = None

    class Backend2(DummyBackendEntrypoint2):
        open_dataset_parameters = None

    backend_1 = Backend1
    backend_2 = Backend2
    backend_2.open_dataset_parameters = ("filename_or_obj",)
    engines = {"engine_1": backend_1, "engine_2": backend_2}
    plugins.set_missing_parameters(engines)

    assert len(engines) == 2
    # Detected from the ``open_dataset`` signature (``self`` is not listed).
    assert backend_1.open_dataset_parameters == ("filename_or_obj", "decoder")
    # An explicitly declared value is left untouched.
    assert backend_2.open_dataset_parameters == ("filename_or_obj",)

    # Explicit parameters bypass detection, so ``**kwargs`` is not an error.
    backend = DummyBackendEntrypointKwargs()
    backend.open_dataset_parameters = ("filename_or_obj", "decoder")
    plugins.set_missing_parameters({"engine": backend})
    assert backend.open_dataset_parameters == ("filename_or_obj", "decoder")

    # Likewise for ``*args``.
    backend_args = DummyBackendEntrypointArgs()
    backend_args.open_dataset_parameters = ("filename_or_obj", "decoder")
    plugins.set_missing_parameters({"engine": backend_args})
    assert backend_args.open_dataset_parameters == ("filename_or_obj", "decoder")
103
104
def test_set_missing_parameters_raise_error() -> None:
    """``set_missing_parameters`` raises TypeError for ``**kwargs``/``*args``.

    Neither dummy backend declares ``open_dataset_parameters``, so the
    parameters would have to be detected from a variadic signature.
    """
    for backend_cls in (DummyBackendEntrypointKwargs, DummyBackendEntrypointArgs):
        backend = backend_cls()
        with pytest.raises(TypeError):
            plugins.set_missing_parameters({"engine": backend})
114
115
@mock.patch(
    "pkg_resources.EntryPoint.load",
    mock.MagicMock(return_value=DummyBackendEntrypoint1),
)
def test_build_engines() -> None:
    """An external entry point is loaded, instantiated and gets its parameters."""
    entrypoint = pkg_resources.EntryPoint.parse(
        "cfgrib = xarray.tests.test_plugins:backend_1"
    )

    engines = plugins.build_engines([entrypoint])
    cfgrib = engines["cfgrib"]

    assert isinstance(cfgrib, DummyBackendEntrypoint1)
    assert cfgrib.open_dataset_parameters == ("filename_or_obj", "decoder")
131
132
@mock.patch(
    "pkg_resources.EntryPoint.load",
    mock.MagicMock(return_value=DummyBackendEntrypoint1),
)
def test_build_engines_sorted() -> None:
    """Standard backends lead in priority order; the rest are sorted by name."""
    dummy_pkg_entrypoints = [
        pkg_resources.EntryPoint.parse(
            "dummy2 = xarray.tests.test_plugins:backend_1",
        ),
        pkg_resources.EntryPoint.parse(
            "dummy1 = xarray.tests.test_plugins:backend_1",
        ),
    ]
    engine_names = list(plugins.build_engines(dummy_pkg_entrypoints))

    # Peel off the installed standard backends in priority order. Each one
    # must be at the front of what remains; only some may be installed.
    for name in plugins.STANDARD_BACKENDS_ORDER:
        if name in engine_names:
            assert engine_names.index(name) == 0
            engine_names.remove(name)

    # Everything else (here: the dummy plugins) is in alphabetical order.
    assert engine_names == sorted(engine_names)
160
161
@mock.patch(
    "xarray.backends.plugins.list_engines",
    mock.MagicMock(return_value={"dummy": DummyBackendEntrypointArgs()}),
)
def test_no_matching_engine_found() -> None:
    """``guess_engine`` raises ValueError when no listed engine matches.

    The error message differs between an unrecognised input and a ``.nc``
    path; presumably the latter is matched by standard backends that are not
    in the patched engine list.
    """
    cases = (
        ("not-valid", r"did not find a match in any"),
        ("foo.nc", r"found the following matches with the input"),
    )
    for store_spec, pattern in cases:
        with pytest.raises(ValueError, match=pattern):
            plugins.guess_engine(store_spec)
172
173
@mock.patch(
    "xarray.backends.plugins.list_engines",
    mock.MagicMock(return_value={}),
)
def test_engines_not_installed() -> None:
    """With no engines listed at all, ``guess_engine`` raises ValueError."""
    cases = (
        ("not-valid", r"xarray is unable to open"),
        ("foo.nc", r"found the following matches with the input"),
    )
    for store_spec, pattern in cases:
        with pytest.raises(ValueError, match=pattern):
            plugins.guess_engine(store_spec)
184