import os

import pytest

# Directory containing this conftest; example data lives beneath it.
_HERE = os.path.dirname(__file__)


def _read_example_file(*parts):
    """Return the UTF-8 text of ``examples/<parts...>`` relative to this file."""
    with open(os.path.join(_HERE, "examples", *parts), encoding="utf-8") as f:
        return f.read()


@pytest.fixture
def example():
    """Provide a loader for ``examples/<name>.toml``.

    The fixture value is a callable that takes the example's base name and
    returns the file's text.
    """

    def _example(name):
        return _read_example_file(name + ".toml")

    return _example


@pytest.fixture
def json_example():
    """Provide a loader for ``examples/json/<name>.json``.

    The fixture value is a callable that takes the example's base name and
    returns the file's text.
    """

    def _example(name):
        return _read_example_file("json", name + ".json")

    return _example


@pytest.fixture
def invalid_example():
    """Provide a loader for ``examples/invalid/<name>.toml``.

    The fixture value is a callable that takes the example's base name and
    returns the file's text.
    """

    def _example(name):
        return _read_example_file("invalid", name + ".toml")

    return _example


TEST_DIR = os.path.join(_HERE, "toml-test", "tests")
# Per-category file names from the toml-test suite that are not loaded.
IGNORED_TESTS = {
    "invalid": [
        "array-mixed-types-strings-and-ints.toml",
        "array-mixed-types-arrays-and-ints.toml",
        "array-mixed-types-ints-and-floats.toml",
    ]
}


def get_tomltest_cases():
    """Load the toml-test suite as ``{category: {case: {extension: text}}}``.

    ``category`` is a directory directly under ``TEST_DIR``. ``case`` is a
    file name with its final extension removed. Each extension (e.g. "toml",
    "json") maps to that file's UTF-8 contents. Files listed in
    ``IGNORED_TESTS`` for a category are skipped.

    Directory listings are sorted, so the returned ordering (and therefore
    test collection order and ids) is deterministic instead of depending on
    ``os.listdir``'s arbitrary order.

    Raises:
        AssertionError: if the category directories are not exactly
            "invalid", "invalid-encoder" and "valid".
    """
    dirs = sorted(
        f for f in os.listdir(TEST_DIR) if os.path.isdir(os.path.join(TEST_DIR, f))
    )
    # Fail loudly if the suite layout differs from what the fixtures expect.
    assert dirs == ["invalid", "invalid-encoder", "valid"]
    rv = {}
    for d in dirs:
        category = rv[d] = {}
        ignored = set(IGNORED_TESTS.get(d, ()))
        category_dir = os.path.join(TEST_DIR, d)
        for f in sorted(os.listdir(category_dir)):
            if f in ignored:
                continue
            # Split on the last dot only: "a.b.toml" -> case "a.b", ext "toml".
            bn, ext = f.rsplit(".", 1)
            with open(os.path.join(category_dir, f), encoding="utf-8") as inp:
                category.setdefault(bn, {})[ext] = inp.read()
    return rv


# (fixture name, toml-test category) pairs, checked in order. A test function
# requesting several of these fixtures is parametrized by the first match only.
_TOMLTEST_FIXTURES = (
    ("valid_case", "valid"),
    ("invalid_decode_case", "invalid"),
    ("invalid_encode_case", "invalid-encoder"),
)


def pytest_generate_tests(metafunc):
    """Parametrize toml-test fixtures with one test per suite case.

    The suite is read from disk only for test functions that request one of
    the fixtures in ``_TOMLTEST_FIXTURES``. Other tests therefore neither pay
    the I/O cost nor fail to collect when ``TEST_DIR`` is missing. Each case
    is passed as its ``{extension: text}`` dict, with the case name as the
    test id.
    """
    for fixture_name, category in _TOMLTEST_FIXTURES:
        if fixture_name in metafunc.fixturenames:
            cases = get_tomltest_cases()[category]
            metafunc.parametrize(
                fixture_name,
                list(cases.values()),
                ids=list(cases.keys()),
            )
            break