1import os
2import stat
3import sys
4from collections import OrderedDict
5from contextlib import contextmanager
6
7import pytest
8
9import dask.config
10from dask.config import (
11    canonical_name,
12    collect,
13    collect_env,
14    collect_yaml,
15    config,
16    deserialize,
17    ensure_file,
18    expand_environment_variables,
19    get,
20    merge,
21    refresh,
22    rename,
23    serialize,
24    set,
25    update,
26    update_defaults,
27)
28from dask.utils import tmpfile
29
# PyYAML is needed to write/read the YAML config files used throughout this
# module; skip the whole module when it is not installed.
yaml = pytest.importorskip("yaml")
31
32
def test_canonical_name():
    """Hyphen/underscore spellings resolve to the form already used in the config."""
    existing = {"foo-bar": 1, "fizz_buzz": 2}
    cases = [
        ("foo-bar", "foo-bar"),
        ("foo_bar", "foo-bar"),
        ("fizz-buzz", "fizz_buzz"),
        ("fizz_buzz", "fizz_buzz"),
        # Keys absent from the config come back exactly as requested.
        ("new-key", "new-key"),
        ("new_key", "new_key"),
    ]
    for requested, canonical in cases:
        assert canonical_name(requested, existing) == canonical
41
42
def test_update():
    """``update`` merges nested dicts in place; ``priority`` decides who wins."""
    # By default the values of the second argument overwrite the first.
    incoming = {"x": 1, "y": {"a": 1}}
    target = {"x": 2, "z": 3, "y": OrderedDict({"b": 2})}
    update(target, incoming)
    assert target == {"x": 1, "y": {"a": 1, "b": 2}, "z": 3}

    # With priority="old" the values already in the target are kept.
    incoming = {"x": 1, "y": {"a": 1}}
    target = {"x": 2, "z": 3, "y": {"a": 3, "b": 2}}
    update(target, incoming, priority="old")
    assert target == {"x": 2, "y": {"a": 3, "b": 2}, "z": 3}
53
54
def test_merge():
    """Later mappings win on conflicts while nested dicts are combined."""
    first = {"x": 1, "y": {"a": 1}}
    second = {"x": 2, "z": 3, "y": {"b": 2}}
    assert merge(first, second) == {"x": 2, "y": {"a": 1, "b": 2}, "z": 3}
63
64
def test_collect_yaml_paths():
    """YAML files passed as explicit paths are loaded in order and merge cleanly."""
    first = {"x": 1, "y": {"a": 1}}
    second = {"x": 2, "z": 3, "y": {"b": 2}}

    with tmpfile(extension="yaml") as path_a, tmpfile(extension="yaml") as path_b:
        for path, data in ((path_a, first), (path_b, second)):
            with open(path, "w") as f:
                yaml.dump(data, f)

        result = merge(*collect_yaml(paths=[path_a, path_b]))
        assert result == {"x": 2, "y": {"a": 1, "b": 2}, "z": 3}
80
81
def test_collect_yaml_dir():
    """Every YAML file inside a directory passed as a path is collected."""
    files = {
        "a.yaml": {"x": 1, "y": {"a": 1}},
        "b.yaml": {"x": 2, "z": 3, "y": {"b": 2}},
    }

    with tmpfile() as dirname:
        os.mkdir(dirname)
        for filename, data in files.items():
            with open(os.path.join(dirname, filename), mode="w") as f:
                yaml.dump(data, f)

        result = merge(*collect_yaml(paths=[dirname]))
        assert result == {"x": 2, "y": {"a": 1, "b": 2}, "z": 3}
97
98
@contextmanager
def no_read_permissions(path):
    """Temporarily remove the owner-read permission bit from ``path``.

    The original permission bits are restored on exit, even if the body of
    the ``with`` statement raises.

    Parameters
    ----------
    path : str
        File or directory whose permissions are modified.
    """
    perm_orig = stat.S_IMODE(os.stat(path).st_mode)
    # Clear (rather than toggle) the read bit: XOR-ing would *grant* read
    # permission to a path that was already unreadable.
    perm_new = perm_orig & ~stat.S_IREAD
    try:
        os.chmod(path, perm_new)
        yield
    finally:
        os.chmod(path, perm_orig)
108
109
@pytest.mark.skipif(
    sys.platform == "win32", reason="Can't make writeonly file on windows"
)
@pytest.mark.parametrize("kind", ["directory", "file"])
def test_collect_yaml_permission_errors(tmpdir, kind):
    """Unreadable config files or directories are left out instead of raising."""
    dir_path = str(tmpdir)
    a_path = os.path.join(dir_path, "a.yaml")
    b_path = os.path.join(dir_path, "b.yaml")
    a = {"x": 1, "y": 2}
    b = {"y": 3, "z": 4}

    for path, data in ((a_path, a), (b_path, b)):
        with open(path, mode="w") as f:
            yaml.dump(data, f)

    # An unreadable directory contributes nothing at all; an unreadable file
    # is skipped, so only b.yaml contributes.
    cant_read, expected = (dir_path, {}) if kind == "directory" else (a_path, b)

    with no_read_permissions(cant_read):
        assert merge(*collect_yaml(paths=[dir_path])) == expected
137
138
def test_env():
    """DASK_* variables are lowercased, ``__`` nests, and values parse as literals when possible."""
    environment = {
        "DASK_A_B": "123",
        "DASK_C": "True",
        "DASK_D": "hello",
        "DASK_E__X": "123",
        "DASK_E__Y": "456",
        "DASK_F": '[1, 2, "3"]',
        "DASK_G": "/not/parsable/as/literal",
        # Variables without the DASK_ prefix are ignored.
        "FOO": "not included",
    }

    assert collect_env(environment) == {
        "a_b": 123,
        "c": True,
        "d": "hello",
        "e": {"x": 123, "y": 456},
        "f": [1, 2, "3"],
        # Unparsable values fall back to the raw string.
        "g": "/not/parsable/as/literal",
    }
162
163
def test_collect():
    """``collect`` combines the given YAML files with values from ``env``."""
    a = {"x": 1, "y": {"a": 1}}
    b = {"x": 2, "z": 3, "y": {"b": 2}}
    # Real environment values are always strings (os.environ holds only str),
    # so pass "4" and let collect_env's literal parsing produce the int 4.
    env = {"DASK_W": "4"}

    expected = {"w": 4, "x": 2, "y": {"a": 1, "b": 2}, "z": 3}

    with tmpfile(extension="yaml") as fn1:
        with tmpfile(extension="yaml") as fn2:
            with open(fn1, "w") as f:
                yaml.dump(a, f)
            with open(fn2, "w") as f:
                yaml.dump(b, f)

            result = collect([fn1, fn2], env=env)
            assert result == expected
180
181
def test_collect_env_none(monkeypatch):
    """Without an explicit ``env``, ``collect`` picks up variables from ``os.environ``."""
    # Drop any DASK_* variables inherited from the surrounding shell/CI so the
    # exact-equality assertion below only sees the variable set here.
    for name in list(os.environ):
        if name.startswith("DASK_"):
            monkeypatch.delenv(name)
    monkeypatch.setenv("DASK_FOO", "bar")
    result = collect([])
    assert result == {"foo": "bar"}
186
187
def test_get():
    """Dotted keys walk nested dicts; a missing key needs a default or raises."""
    nested = {"x": 1, "y": {"a": 2}}

    for key, value in (("x", 1), ("y.a", 2)):
        assert get(key, config=nested) == value

    assert get("y.b", 123, config=nested) == 123
    with pytest.raises(KeyError):
        get("y.b", config=nested)
196
197
def test_ensure_file(tmpdir):
    """``ensure_file`` copies a config once and never overwrites an existing copy."""
    root = str(tmpdir)
    source = os.path.join(root, "source.yaml")
    dest = os.path.join(root, "dest")
    destination = os.path.join(dest, "source.yaml")
    original = {"x": 1, "y": {"a": 1}}
    changed = {"x": 123}

    def write_source(data):
        with open(source, "w") as f:
            yaml.dump(data, f)

    def load_destination():
        with open(destination) as f:
            return yaml.safe_load(f)

    write_source(original)
    ensure_file(source=source, destination=dest, comment=False)
    assert load_destination() == original

    # A later call must leave the existing destination file untouched.
    write_source(changed)
    ensure_file(source=source, destination=dest, comment=False)
    assert load_destination() == original

    os.remove(destination)

    # With comment=True the new content appears in the text but parses to a
    # falsy value, i.e. it is commented out.
    ensure_file(source=source, destination=dest, comment=True)
    with open(destination) as f:
        text = f.read()
    assert "123" in text
    assert not load_destination()
237
238
def test_set():
    """``set`` applies overrides inside a ``with`` block and undoes them on exit."""
    # Keyword form, including a nested override of the same key.
    with set(abc=123):
        assert config["abc"] == 123
        with set(abc=456):
            assert config["abc"] == 456
        assert config["abc"] == 123
    assert "abc" not in config

    # Mapping form, both flat and with dotted keys that build nested dicts.
    for overrides, expected in (
        ({"abc": 123}, 123),
        ({"abc.x": 1, "abc.y": 2, "abc.z.a": 3}, {"x": 1, "y": 2, "z": {"a": 3}}),
    ):
        with set(overrides):
            assert config["abc"] == expected
        assert "abc" not in config

    # Called on an explicit dict, it writes the nested value into that dict.
    target = {}
    set({"abc.x": 123}, config=target)
    assert target["abc"]["x"] == 123
259
260
def test_set_kwargs():
    """``__`` in keyword names nests; keywords override a positional mapping."""
    with set(foo__bar=1, foo__baz=2):
        assert config["foo"] == {"bar": 1, "baz": 2}
    assert "foo" not in config

    # The result is the same whether the mapping uses dotted keys or a nested dict.
    combined = {"bar": 4, "baz": 2, "buzz": 3}
    for mapping in ({"foo.bar": 1, "foo.baz": 2}, {"foo": {"bar": 1, "baz": 2}}):
        with set(mapping, foo__buzz=3, foo__bar=4):
            assert config["foo"] == combined
        assert "foo" not in config
275
276
def test_set_nested():
    """An inner dotted override extends the outer nested dict, then is rolled back."""
    outer = {"x": 123}
    with set({"abc": {"x": 123}}):
        assert config["abc"] == outer
        with set({"abc.y": 456}):
            assert config["abc"] == {**outer, "y": 456}
        assert config["abc"] == outer
    assert "abc" not in config
284
285
def test_set_hard_to_copyables():
    """Nesting ``set`` works when the config already holds a hard-to-copy value."""
    from threading import Lock

    with set(x=Lock()):
        # The inner call only needs to succeed with the lock already present.
        with set(y=1):
            pass
292
293
@pytest.mark.parametrize("mkdir", [True, False])
def test_ensure_file_directory(mkdir, tmpdir):
    """A destination directory works whether or not it exists beforehand."""
    root = str(tmpdir)
    source = os.path.join(root, "source.yaml")
    dest = os.path.join(root, "dest")

    with open(source, "w") as f:
        yaml.dump({"x": 1, "y": {"a": 1}}, f)

    if mkdir:
        os.mkdir(dest)

    ensure_file(source=source, destination=dest)

    assert os.path.isdir(dest)
    assert os.path.exists(os.path.join(dest, "source.yaml"))
311
312
def test_ensure_file_defaults_to_DASK_CONFIG_directory(tmpdir):
    """Without ``destination``, ``ensure_file`` writes into ``dask.config.PATH``."""
    root = str(tmpdir)
    source = os.path.join(root, "source.yaml")
    with open(source, "w") as f:
        yaml.dump({"x": 1, "y": {"a": 1}}, f)

    destination = os.path.join(root, "dask")
    saved_path = dask.config.PATH
    dask.config.PATH = destination
    try:
        ensure_file(source=source)
    finally:
        # Always restore the module-level setting for subsequent tests.
        dask.config.PATH = saved_path

    assert os.path.isdir(destination)
    [fn] = os.listdir(destination)
    assert os.path.basename(fn) == os.path.basename(source)
330
331
def test_rename():
    """An aliased key moves to its new dotted location, even if spelled with hyphens."""
    cfg = {"foo-bar": 123}
    rename({"foo_bar": "foo.bar"}, config=cfg)
    assert cfg == {"foo": {"bar": 123}}
337
338
def test_refresh():
    """``refresh`` rebuilds from defaults plus the given env; older env values vanish."""
    defaults = []
    cfg = {}

    update_defaults({"a": 1}, config=cfg, defaults=defaults)
    assert cfg == {"a": 1}

    for env, expected in (
        ({"DASK_B": "2"}, {"a": 1, "b": 2}),
        # "b" from the previous env is gone: refreshes do not accumulate.
        ({"DASK_C": "3"}, {"a": 1, "c": 3}),
    ):
        refresh(paths=[], env=env, config=cfg, defaults=defaults)
        assert cfg == expected
351
352
@pytest.mark.parametrize(
    "inp,out",
    [
        ("1", "1"),
        (1, 1),
        ("$FOO", "foo"),
        ([1, "$FOO"], [1, "foo"]),
        ((1, "$FOO"), (1, "foo")),
        ({1, "$FOO"}, {1, "foo"}),
        ({"a": "$FOO"}, {"a": "foo"}),
        ({"a": "A", "b": [1, "2", "$FOO"]}, {"a": "A", "b": [1, "2", "foo"]}),
    ],
)
def test_expand_environment_variables(monkeypatch, inp, out):
    """``$VAR`` strings expand, including inside lists, tuples, sets and dicts."""
    monkeypatch.setenv("FOO", "foo")
    expanded = expand_environment_variables(inp)
    assert expanded == out
369
370
def test_env_var_canonical_name(monkeypatch):
    """A value from the environment is reachable via either key spelling."""
    monkeypatch.setenv("DASK_A_B", "3")
    refreshed = {}
    dask.config.refresh(config=refreshed)
    for key in ("a_b", "a-b"):
        assert get(key, config=refreshed) == 3
378
379
def test_get_set_canonical_name():
    """``get``/``set`` treat ``-`` and ``_`` as interchangeable at each nesting level."""
    c = {"x-y": {"a_b": 123}}
    spellings = ("x_y.a_b", "x-y.a-b", "x_y.a-b")

    def check_all(expected):
        for key in spellings:
            assert dask.config.get(key, config=c) == expected

    check_all(123)
    with dask.config.set({"x_y": {"a-b": 456}}, config=c):
        check_all(456)

    # Keys inside a newly set nested value are not renamed.
    with dask.config.set({"x_y": {"a-b": {"c_d": 1}, "e-f": 2}}, config=c):
        assert dask.config.get("x_y.a-b", config=c) == {"c_d": 1}
        assert dask.config.get("x_y.e_f", config=c) == 2
395
396
@pytest.mark.parametrize("key", ["custom_key", "custom-key"])
def test_get_set_roundtrip(key):
    """A value set under one spelling is readable under both."""
    with dask.config.set({key: 123}):
        for spelling in ("custom_key", "custom-key"):
            assert dask.config.get(spelling) == 123
403
404
def test_merge_None_to_dict():
    """A dict in a later mapping replaces ``None`` from an earlier one."""
    merged = dask.config.merge({"a": None, "c": 0}, {"a": {"b": 1}})
    assert merged == {"a": {"b": 1}, "c": 0}
410
411
def test_core_file():
    """Keys expected from dask's default configuration are present globally."""
    for section in ("temporary-directory", "dataframe"):
        assert section in dask.config.config
    assert "shuffle-compression" in dask.config.get("dataframe")
416
417
def test_schema():
    """The shipped ``dask.yaml`` validates against ``dask-schema.yaml``."""
    jsonschema = pytest.importorskip("jsonschema")

    package_dir = os.path.join(os.path.dirname(__file__), "..")

    def load(name):
        with open(os.path.join(package_dir, name)) as f:
            return yaml.safe_load(f)

    jsonschema.validate(load("dask.yaml"), load("dask-schema.yaml"))
431
432
def test_schema_is_complete():
    """Every key in ``dask.yaml`` has a schema entry, in the same order, recursively."""
    package_dir = os.path.join(os.path.dirname(__file__), "..")

    def load(name):
        with open(os.path.join(package_dir, name)) as f:
            return yaml.safe_load(f)

    cfg = load("dask.yaml")
    schema = load("dask-schema.yaml")

    def check(c, s):
        # Empty mappings have nothing to compare; otherwise key order must match.
        if c and list(c) != list(s["properties"]):
            raise ValueError(
                "\nThe dask.yaml and dask-schema.yaml files are not in sync.\n"
                "This usually happens when we add a new configuration value,\n"
                "but don't add the schema of that value to the dask-schema.yaml file\n"
                "Please modify these files to include the missing values: \n\n"
                "    dask.yaml:        {}\n"
                "    dask-schema.yaml: {}\n\n"
                "Examples in these files should be a good start, \n"
                "even if you are not familiar with the jsonschema spec".format(
                    sorted(c), sorted(s["properties"])
                )
            )
        for k, v in c.items():
            if isinstance(v, dict):
                check(v, s["properties"][k])

    check(cfg, schema)
462
463
def test_deprecations():
    """An old option name still works, and the warning names its replacement."""
    with pytest.warns(Warning) as record:
        with dask.config.set(fuse_ave_width=123):
            assert dask.config.get("optimization.fuse.ave-width") == 123

    first_message = str(record[0].message)
    assert "optimization.fuse.ave-width" in first_message
470
471
def test_get_override_with():
    """A non-None ``override_with`` is returned instead of the configured value."""
    with dask.config.set({"foo": "bar"}):
        # None means "no override": the configured value is used.
        assert dask.config.get("foo") == "bar"
        assert dask.config.get("foo", override_with=None) == "bar"

        # Any other value passes straight through, falsy ones included.
        assert dask.config.get("foo", override_with=False) is False
        assert dask.config.get("foo", override_with=True) is True
        for override, expected in (
            ("baz", "baz"),
            (123, 123),
            ({"hello": "world"}, {"hello": "world"}),
            (["one"], ["one"]),
        ):
            assert dask.config.get("foo", override_with=override) == expected
487
488
def test_config_serialization():
    """A config fragment survives a serialize/deserialize round trip."""
    # Re-setting the current value means leaving the context restores it,
    # which undoes the global ``update`` performed inside.
    current_size = dask.config.get("array.svg.size")
    with dask.config.set({"array.svg.size": current_size}):
        roundtripped = deserialize(serialize({"array": {"svg": {"size": 150}}}))
        dask.config.update(dask.config.global_config, roundtripped)
        assert dask.config.get("array.svg.size") == 150
499
500
def test_config_inheritance():
    """A serialized config in DASK_INTERNAL_INHERIT_CONFIG is unpacked by collect_env."""
    payload = serialize({"array": {"svg": {"size": 150}}})
    inherited = collect_env({"DASK_INTERNAL_INHERIT_CONFIG": payload})
    assert dask.config.get("array.svg.size", config=inherited) == 150
506