1import os
2
3import pytest
4
5
@pytest.fixture
def example():
    """Provide a loader for TOML example files.

    The returned callable takes a base name (no extension) and returns the
    UTF-8 text of ``examples/<name>.toml`` located next to this file.
    """
    base = os.path.join(os.path.dirname(__file__), "examples")

    def _read(name):
        path = os.path.join(base, name + ".toml")
        with open(path, encoding="utf-8") as handle:
            text = handle.read()
        return text

    return _read
16
17
@pytest.fixture
def json_example():
    """Provide a loader for JSON example files.

    The returned callable takes a base name (no extension) and returns the
    UTF-8 text of ``examples/json/<name>.json`` located next to this file.
    """
    base = os.path.join(os.path.dirname(__file__), "examples", "json")

    def _read(name):
        path = os.path.join(base, name + ".json")
        with open(path, encoding="utf-8") as handle:
            text = handle.read()
        return text

    return _read
28
29
@pytest.fixture
def invalid_example():
    """Provide a loader for invalid TOML example files.

    The returned callable takes a base name (no extension) and returns the
    UTF-8 text of ``examples/invalid/<name>.toml`` located next to this file.
    """
    base = os.path.join(os.path.dirname(__file__), "examples", "invalid")

    def _read(name):
        path = os.path.join(base, name + ".toml")
        with open(path, encoding="utf-8") as handle:
            text = handle.read()
        return text

    return _read
42
43
# Root of the toml-test suite, expected at ``toml-test/tests`` next to this file.
TEST_DIR = os.path.join(
    os.path.dirname(__file__),
    "toml-test",
    "tests",
)
# Per-category (sub-directory name) file names that get_tomltest_cases skips.
# NOTE(review): presumably skipped because TOML 1.0 allows mixed-type arrays,
# so these "invalid" cases are accepted by the parser — confirm.
IGNORED_TESTS = dict(
    invalid=[
        "array-mixed-types-strings-and-ints.toml",
        "array-mixed-types-arrays-and-ints.toml",
        "array-mixed-types-ints-and-floats.toml",
    ],
)
52
53
def get_tomltest_cases(test_dir=None, ignored_tests=None):
    """Load the toml-test suite into a nested mapping.

    Args:
        test_dir: Root directory of the suite; defaults to ``TEST_DIR``. It
            must contain exactly the ``invalid``, ``invalid-encoder`` and
            ``valid`` sub-directories.
        ignored_tests: Mapping of category (sub-directory) name to file names
            to skip; defaults to ``IGNORED_TESTS``.

    Returns:
        ``{category: {case_name: {extension: file_text}}}``, for example
        ``rv["valid"]["some-case"]["toml"]``. Case names within a category
        are inserted in sorted file-name order, so parametrized test ids are
        identical on every run (as pytest-xdist workers require).

    Raises:
        AssertionError: if the sub-directories of ``test_dir`` are not exactly
            the three expected categories.
    """
    if test_dir is None:
        test_dir = TEST_DIR
    if ignored_tests is None:
        ignored_tests = IGNORED_TESTS

    dirs = sorted(
        f for f in os.listdir(test_dir) if os.path.isdir(os.path.join(test_dir, f))
    )
    assert dirs == ["invalid", "invalid-encoder", "valid"], (
        "unexpected toml-test layout in %s: %r" % (test_dir, dirs)
    )
    rv = {}
    for d in dirs:
        category_dir = os.path.join(test_dir, d)
        ignored = set(ignored_tests.get(d, ()))
        cases = rv[d] = {}
        # os.listdir order is arbitrary; sort so collection order is stable.
        for f in sorted(os.listdir(category_dir)):
            path = os.path.join(category_dir, f)
            if f in ignored or not os.path.isfile(path):
                continue
            bn, ext = os.path.splitext(f)
            if not ext:
                # No extension, or a dotfile such as ".DS_Store": not a case.
                continue
            with open(path, encoding="utf-8") as inp:
                # ext includes the leading dot; key by the bare extension.
                cases.setdefault(bn, {})[ext[1:]] = inp.read()
    return rv
74
75
def pytest_generate_tests(metafunc):
    """Parametrize toml-test fixtures with the cases from the toml-test suite.

    A test requesting ``valid_case``, ``invalid_decode_case`` or
    ``invalid_encode_case`` is parametrized over the ``valid``, ``invalid``
    or ``invalid-encoder`` category respectively, with case names as ids.
    Only the first matching fixture (in that order) is parametrized.

    The suite is read from disk only for tests that request one of these
    fixtures. Other tests are left untouched, so they still collect when the
    toml-test checkout is absent.
    """
    # Checked in order; mirrors the original if/elif precedence.
    fixture_categories = (
        ("valid_case", "valid"),
        ("invalid_decode_case", "invalid"),
        ("invalid_encode_case", "invalid-encoder"),
    )
    for fixture, category in fixture_categories:
        if fixture in metafunc.fixturenames:
            break
    else:
        # Nothing to parametrize; skip loading the whole suite.
        return

    cases = get_tomltest_cases()[category]
    metafunc.parametrize(fixture, list(cases.values()), ids=list(cases.keys()))
96