1import pytest
2
3import xarray
4from xarray import concat, merge
5from xarray.backends.file_manager import FILE_CACHE
6from xarray.core.options import OPTIONS, _get_keep_attrs
7from xarray.tests.test_dataset import create_test_data
8
9
def test_invalid_option_raises() -> None:
    """Unknown option names passed to ``set_options`` raise ValueError."""
    bogus = {"not_a_valid_options": True}
    with pytest.raises(ValueError):
        xarray.set_options(**bogus)
13
14
def test_display_width() -> None:
    """``display_width`` rejects zero, negative and non-integer widths."""
    for bad_width in (0, -10, 3.5):
        with pytest.raises(ValueError):
            xarray.set_options(display_width=bad_width)
22
23
def test_arithmetic_join() -> None:
    """``arithmetic_join`` rejects unknown values and is restored on exit."""
    with pytest.raises(ValueError):
        xarray.set_options(arithmetic_join="invalid")
    original = OPTIONS["arithmetic_join"]
    with xarray.set_options(arithmetic_join="exact"):
        assert OPTIONS["arithmetic_join"] == "exact"
    # Like the other option tests, leaving the context must restore the
    # previous value rather than leak "exact" into later tests.
    assert OPTIONS["arithmetic_join"] == original
29
30
def test_enable_cftimeindex() -> None:
    """``enable_cftimeindex`` rejects None and warns that setting it is a no-op."""
    with pytest.raises(ValueError):
        xarray.set_options(enable_cftimeindex=None)
    # The warning context is entered first, so it captures the FutureWarning
    # emitted when the option is applied.
    with pytest.warns(FutureWarning, match="no-op"), xarray.set_options(
        enable_cftimeindex=True
    ):
        assert OPTIONS["enable_cftimeindex"]
37
38
def test_file_cache_maxsize() -> None:
    """``file_cache_maxsize`` must be positive and is pushed into FILE_CACHE."""
    with pytest.raises(ValueError):
        xarray.set_options(file_cache_maxsize=0)
    size_before = FILE_CACHE.maxsize
    temporary_size = 123
    with xarray.set_options(file_cache_maxsize=temporary_size):
        assert FILE_CACHE.maxsize == temporary_size
    # Leaving the context must restore the cache's previous capacity.
    assert FILE_CACHE.maxsize == size_before
46
47
def test_keep_attrs() -> None:
    """Validate ``keep_attrs`` values and how "default" defers to callers."""
    with pytest.raises(ValueError):
        xarray.set_options(keep_attrs="invalid_str")
    for flag in (True, False):
        with xarray.set_options(keep_attrs=flag):
            assert bool(OPTIONS["keep_attrs"]) is flag
    # With "default", _get_keep_attrs falls back to the caller-supplied default.
    with xarray.set_options(keep_attrs="default"):
        assert _get_keep_attrs(default=True)
        assert not _get_keep_attrs(default=False)
58
59
def test_nested_options() -> None:
    """Nested ``set_options`` contexts unwind one level at a time."""
    baseline = OPTIONS["display_width"]
    outer, inner = 1, 2
    with xarray.set_options(display_width=outer):
        assert OPTIONS["display_width"] == outer
        with xarray.set_options(display_width=inner):
            assert OPTIONS["display_width"] == inner
        # Exiting the inner block restores the outer value, not the baseline.
        assert OPTIONS["display_width"] == outer
    assert OPTIONS["display_width"] == baseline
68
69
def test_display_style() -> None:
    """``display_style`` defaults to "html", validates input and restores on exit."""
    default_style = "html"
    key = "display_style"
    assert OPTIONS[key] == default_style
    with pytest.raises(ValueError):
        xarray.set_options(display_style="invalid_str")
    with xarray.set_options(display_style="text"):
        assert OPTIONS[key] == "text"
    assert OPTIONS[key] == default_style
78
79
def create_test_attrs() -> dict:
    """Return a fresh copy of the attrs shared by the attr-retention fixtures.

    The dict (including its nested dict) is built anew on every call, so two
    fixtures never share the same attrs object.
    """
    return {"attr1": 5, "attr2": "history", "attr3": {"nested": "more_info"}}


def create_test_dataset_attrs(seed: int = 0):
    """Return ``create_test_data(seed)`` with the standard test attrs attached."""
    ds = create_test_data(seed)
    ds.attrs = create_test_attrs()
    return ds


def create_test_dataarray_attrs(seed: int = 0, var: str = "var1"):
    """Return variable ``var`` of ``create_test_data(seed)`` with the standard test attrs."""
    da = create_test_data(seed)[var]
    da.attrs = create_test_attrs()
    return da
90
91
class TestAttrRetention:
    """Attribute retention across operations, plus the ``display_style`` HTML repr."""

    def test_dataset_attr_retention(self) -> None:
        # Use .mean() for all tests: a typical reduction operation
        ds = create_test_dataset_attrs()
        # Snapshot the attrs so comparisons can't pass merely through aliasing.
        original_attrs = dict(ds.attrs)

        # Test default behaviour
        result = ds.mean()
        assert result.attrs == {}
        with xarray.set_options(keep_attrs="default"):
            result = ds.mean()
            assert result.attrs == {}

        with xarray.set_options(keep_attrs=True):
            result = ds.mean()
            assert result.attrs == original_attrs

        with xarray.set_options(keep_attrs=False):
            result = ds.mean()
            assert result.attrs == {}

    def test_dataarray_attr_retention(self) -> None:
        # Use .mean() for all tests: a typical reduction operation
        da = create_test_dataarray_attrs()
        # Snapshot the attrs so comparisons can't pass merely through aliasing.
        original_attrs = dict(da.attrs)

        # Test default behaviour
        result = da.mean()
        assert result.attrs == {}
        with xarray.set_options(keep_attrs="default"):
            result = da.mean()
            assert result.attrs == {}

        with xarray.set_options(keep_attrs=True):
            result = da.mean()
            assert result.attrs == original_attrs

        with xarray.set_options(keep_attrs=False):
            result = da.mean()
            assert result.attrs == {}

    def test_groupby_attr_retention(self) -> None:
        da = xarray.DataArray([1, 2, 3], [("x", [1, 1, 2])])
        da.attrs = {"attr1": 5, "attr2": "history", "attr3": {"nested": "more_info"}}
        original_attrs = dict(da.attrs)

        # An explicit keep_attrs=True argument keeps the attrs, both with the
        # global option unset and with keep_attrs="default"
        result = da.groupby("x").sum(keep_attrs=True)
        assert result.attrs == original_attrs
        with xarray.set_options(keep_attrs="default"):
            result = da.groupby("x").sum(keep_attrs=True)
            assert result.attrs == original_attrs

        # Without an explicit argument, the global option decides
        with xarray.set_options(keep_attrs=True):
            result = da.groupby("x").sum()
            assert result.attrs == original_attrs

        with xarray.set_options(keep_attrs=False):
            result = da.groupby("x").sum()
            assert result.attrs == {}

    def test_concat_attr_retention(self) -> None:
        ds1 = create_test_dataset_attrs()
        ds2 = create_test_dataset_attrs()
        ds2.attrs = {"wrong": "attributes"}
        original_attrs = dict(ds1.attrs)

        # Test default behaviour of keeping the attrs of the first
        # dataset in the supplied list
        # global keep_attrs option currently doesn't affect concat
        result = concat([ds1, ds2], dim="dim1")
        assert result.attrs == original_attrs

    @pytest.mark.xfail
    def test_merge_attr_retention(self) -> None:
        da1 = create_test_dataarray_attrs(var="var1")
        da2 = create_test_dataarray_attrs(var="var2")
        da2.attrs = {"wrong": "attributes"}
        original_attrs = dict(da1.attrs)

        # merge currently discards attrs, and the global keep_attrs
        # option doesn't affect this
        result = merge([da1, da2])
        assert result.attrs == original_attrs

    def test_display_style_text(self) -> None:
        ds = create_test_dataset_attrs()
        with xarray.set_options(display_style="text"):
            text = ds._repr_html_()
            assert text.startswith("<pre>")
            assert "&#x27;nested&#x27;" in text
            assert "&lt;xarray.Dataset&gt;" in text

    def test_display_style_html(self) -> None:
        ds = create_test_dataset_attrs()
        with xarray.set_options(display_style="html"):
            html = ds._repr_html_()
            assert html.startswith("<div>")
            assert "&#x27;nested&#x27;" in html

    def test_display_dataarray_style_text(self) -> None:
        da = create_test_dataarray_attrs()
        with xarray.set_options(display_style="text"):
            text = da._repr_html_()
            assert text.startswith("<pre>")
            assert "&lt;xarray.DataArray &#x27;var1&#x27;" in text

    def test_display_dataarray_style_html(self) -> None:
        da = create_test_dataarray_attrs()
        with xarray.set_options(display_style="html"):
            html = da._repr_html_()
            assert html.startswith("<div>")
            # Full escaped entity, matching the Dataset HTML test above.
            assert "&#x27;nested&#x27;" in html
206
207
@pytest.mark.parametrize("set_value", ["left", "exact"])
def test_get_options_retention(set_value):
    """``get_options`` reflects values applied through ``set_options``."""
    with xarray.set_options(arithmetic_join=set_value):
        current = xarray.get_options()
        assert current["arithmetic_join"] == set_value
217