1# Copyright (c) Thomas Kluyver and contributors
2# Distributed under the terms of the MIT license; see LICENSE file.
3
4import os.path as osp
5import pytest
6import warnings
7from zipfile import ZipFile
8
9import entrypoints
10
# Directory holding the sample installations that these tests search.
samples_dir = osp.join(osp.dirname(__file__), 'samples')

# Search path shared by most tests: each sample package directory, each
# followed by the .egg directory nested inside it.
sample_path = [
    osp.join(samples_dir, *parts)
    for parts in (
        ('packages1',),
        ('packages1', 'baz-0.3.egg'),
        ('packages2',),
        ('packages2', 'qux-0.4.egg'),
    )
]
19
def test_get_group_all():
    """get_group_all returns every entry point in the group, repeats included."""
    eps = entrypoints.get_group_all('entrypoints.test1', sample_path)
    print(eps)
    # Five entries but only four distinct names: one name occurs twice.
    assert len(eps) == 5
    assert {ep.name for ep in eps} == {'abc', 'rew', 'opo', 'njn'}
25
def test_get_group_named():
    """get_group_named maps each entry point name to one resolved entry."""
    named = entrypoints.get_group_named('entrypoints.test1', sample_path)
    print(named)
    assert len(named) == 4

    abc_ep = named['abc']
    assert abc_ep.module_name == 'foo'
    assert abc_ep.object_name == 'abc'
32
def test_get_single():
    """get_single resolves a named entry point to its module and object."""
    # (entry point name, expected module_name, expected object_name)
    cases = [
        ('abc', 'foo', 'abc'),
        ('njn', 'qux.extn', 'Njn.load'),
    ]
    for ep_name, module, obj in cases:
        found = entrypoints.get_single('entrypoints.test1', ep_name, sample_path)
        assert found.module_name == module
        assert found.object_name == obj
41
def test_dot_prefix():
    """Entry point names may start with a dot and still carry extras."""
    rst_parser = entrypoints.get_single('blogtool.parsers', '.rst', sample_path)
    assert rst_parser.object_name == 'SomeClass.some_classmethod'
    assert rst_parser.extras == ['reST']

    # The dotted name must also survive when listing the whole group.
    parsers = entrypoints.get_group_named('blogtool.parsers', sample_path)
    parser_names = set(parsers.keys())
    assert parser_names == {'.rst'}
49
def test_case_sensitive():
    """Entry point names differing only in letter case are kept apart."""
    named = entrypoints.get_group_named('test.case_sensitive', sample_path)
    expected = {'Ptangle', 'ptangle'}
    assert set(named.keys()) == expected
53
def test_load_zip(tmpdir):
    """Entry points are found in both .dist-info and .egg-info inside a zip."""
    archive = str(tmpdir / 'parmesan-1.2.whl')
    with ZipFile(archive, 'w') as zf:
        zf.writestr('parmesan-1.2.dist-info/entry_points.txt',
                    b'[entrypoints.test.inzip]\na = edam:gouda')
        zf.writestr('gruyere-2!1b4.dev0.egg-info/entry_points.txt',
                    b'[entrypoints.test.inzip]\nb = wensleydale:gouda')

    # entry_points.txt holds no name/version, so the distro details checked
    # below can only come from the metadata directory names, including the
    # '2!' epoch and '.dev0' suffix of the egg-info one.
    from_dist_info = entrypoints.get_single('entrypoints.test.inzip', 'a', [archive])
    assert from_dist_info.module_name == 'edam'
    assert from_dist_info.object_name == 'gouda'
    assert from_dist_info.distro.name == 'parmesan'
    assert from_dist_info.distro.version == '1.2'

    from_egg_info = entrypoints.get_single('entrypoints.test.inzip', 'b', [archive])
    assert from_egg_info.module_name == 'wensleydale'
    assert from_egg_info.object_name == 'gouda'
    assert from_egg_info.distro.name == 'gruyere'
    assert from_egg_info.distro.version == '2!1b4.dev0'
73
def test_load():
    """EntryPoint.load imports the module and resolves the optional object."""
    # With an object name, load() yields that attribute of the module.
    with_object = entrypoints.EntryPoint('get_ep', 'entrypoints', 'get_single', None)
    assert with_object.load() is entrypoints.get_single

    # The object part is optional (e.g. pytest plugins use just a module ref);
    # load() then yields the module itself.
    module_only = entrypoints.EntryPoint('ep_mod', 'entrypoints', None)
    assert module_only.load() is entrypoints
83
def test_bad():
    """Malformed entry points are skipped with a warning instead of raising.

    The sample ``packages3`` directory is expected to declare exactly one
    unparseable ``entrypoints.test1`` entry named ``bad``. Both group lookup
    and single lookup must drop it and emit one warning.
    """
    bad_path = [osp.join(samples_dir, 'packages3')]

    with warnings.catch_warnings(record=True) as w:
        # 'always' so the warning is recorded even if an identical one was
        # already emitted (and deduplicated) earlier in the session.
        warnings.simplefilter('always')
        group = entrypoints.get_group_named('entrypoints.test1', bad_path)

    assert 'bad' not in group
    assert len(w) == 1

    # Search bad_path here too: without it get_single looks at sys.path,
    # never sees the malformed entry, and the warning check proves nothing.
    with warnings.catch_warnings(record=True) as w2, \
            pytest.raises(entrypoints.NoSuchEntryPoint):
        warnings.simplefilter('always')
        entrypoints.get_single('entrypoints.test1', 'bad', bad_path)

    # Check the second recorder; the original re-checked `w` by mistake.
    assert len(w2) == 1
98
def test_missing():
    """A lookup matching nothing raises NoSuchEntryPoint carrying the query."""
    wanted_group, wanted_name = 'no.such.group', 'no_such_name'
    with pytest.raises(entrypoints.NoSuchEntryPoint) as excinfo:
        entrypoints.get_single(wanted_group, wanted_name, sample_path)

    err = excinfo.value
    assert err.group == wanted_group
    assert err.name == wanted_name
105
def test_parse():
    """from_string splits 'module:object [extras]' into its components."""
    spec = 'some.module:some.attr [extra1,extra2]'
    parsed = entrypoints.EntryPoint.from_string(spec, 'foo')
    assert (parsed.module_name, parsed.object_name) == ('some.module', 'some.attr')
    assert parsed.extras == ['extra1', 'extra2']
113
def test_parse_bad():
    """A spec string that cannot be parsed raises BadEntryPoint."""
    unparseable = "this won't work"
    with pytest.raises(entrypoints.BadEntryPoint):
        entrypoints.EntryPoint.from_string(unparseable, 'foo')
117