1from . import util as test_util
2machinery = test_util.import_importlib('importlib.machinery')
3
4import os
5import re
6import sys
7import unittest
8from test import support
9from distutils.util import get_platform
10from contextlib import contextmanager
11from .util import temp_module
12
13support.import_module('winreg', required_on=['win'])
14from winreg import (
15    CreateKey, HKEY_CURRENT_USER,
16    SetValue, REG_SZ, KEY_ALL_ACCESS,
17    EnumKey, CloseKey, DeleteKey, OpenKey
18)
19
def delete_registry_tree(root, subkey):
    """Delete the registry key *subkey* of *root*, including all its subkeys.

    Nothing happens if *subkey* cannot be opened (typically because it does
    not exist), so this is safe to call unconditionally from cleanup code.
    *root* may be a predefined key such as HKEY_CURRENT_USER or an open key
    handle (the function recurses with the latter).

    If a subkey survives its recursive deletion, the final ``DeleteKey``
    call fails with OSError, because DeleteKey cannot remove a key that
    still has subkeys.
    """
    try:
        hkey = OpenKey(root, subkey, access=KEY_ALL_ACCESS)
    except OSError:
        # subkey does not exist
        return
    try:
        previous = None
        while True:
            try:
                # Always ask for index 0: every child is deleted before the
                # next query, so the remaining children shift down.
                child = EnumKey(hkey, 0)
            except OSError:
                # no more subkeys
                break
            if child == previous:
                # The previous recursive call left this child in place (e.g.
                # it could not be opened and was skipped); stop instead of
                # looping forever and let DeleteKey below report the failure.
                break
            delete_registry_tree(hkey, child)
            previous = child
    finally:
        # Release the handle even if deleting a child raised.
        CloseKey(hkey)
    DeleteKey(root, subkey)
35
@contextmanager
def setup_module(machinery, name, path=None):
    """Temporarily register module *name* with WindowsRegistryFinder.

    ``temp_module`` is asked to create module *name* with source ``a = 1``.
    Under HKEY_CURRENT_USER, the finder's registry key for *name* is created
    and its default value is set to that module's ``<location>.py`` path, or
    to *path* when one is given (e.g. to register something that is not a
    module file). The registry key is deleted on exit, even if the ``with``
    body raised.
    """
    finder = machinery.WindowsRegistryFinder
    # Debug builds use a different key template.
    if finder.DEBUG_BUILD:
        root = finder.REGISTRY_KEY_DEBUG
    else:
        root = finder.REGISTRY_KEY
    key = root.format(fullname=name,
                      sys_version='%d.%d' % sys.version_info[:2])
    try:
        with temp_module(name, "a = 1") as location:
            subkey = CreateKey(HKEY_CURRENT_USER, key)
            try:
                if path is None:
                    path = location + ".py"
                SetValue(subkey, "", REG_SZ, path)
            finally:
                # Close the handle now instead of leaving it open (and
                # relying on garbage collection) for the whole test.
                CloseKey(subkey)
            yield
    finally:
        if finder.DEBUG_BUILD:
            # In debug builds, deletion starts at the parent of the formatted
            # key, presumably the per-module key that the debug template
            # nests under -- TODO confirm against REGISTRY_KEY_DEBUG.
            key = os.path.dirname(key)
        # NOTE(review): CreateKey also creates any missing ancestors of
        # ``key``; those are not removed here.
        delete_registry_tree(HKEY_CURRENT_USER, key)
55
56
@unittest.skipUnless(sys.platform.startswith('win'), 'requires Windows')
class WindowsRegistryFinderTests:
    """Tests for ``machinery.WindowsRegistryFinder``.

    This is a mixin, not a TestCase: ``self.machinery`` is expected to be
    provided by the concrete variants (see the ``test_util.test_both`` call
    that follows the class).
    """

    # The module name is process-specific, allowing for
    # simultaneous runs of the same test on a single machine.
    test_module = "spamham{}".format(os.getpid())

    def test_find_spec_missing(self):
        spec = self.machinery.WindowsRegistryFinder.find_spec('spam')
        self.assertIsNone(spec)

    def test_find_module_missing(self):
        loader = self.machinery.WindowsRegistryFinder.find_module('spam')
        self.assertIsNone(loader)

    def test_module_found(self):
        with setup_module(self.machinery, self.test_module):
            loader = self.machinery.WindowsRegistryFinder.find_module(self.test_module)
            spec = self.machinery.WindowsRegistryFinder.find_spec(self.test_module)
            self.assertIsNotNone(loader)
            self.assertIsNotNone(spec)

    def test_module_not_found(self):
        # The registry entry points at "." rather than a module file, so
        # the finder is expected to report nothing.
        with setup_module(self.machinery, self.test_module, path="."):
            loader = self.machinery.WindowsRegistryFinder.find_module(self.test_module)
            spec = self.machinery.WindowsRegistryFinder.find_spec(self.test_module)
            self.assertIsNone(loader)
            self.assertIsNone(spec)
84
# Build the concrete test classes from the mixin; the names suggest one
# variant per importlib implementation (frozen and source).
[Frozen_WindowsRegistryFinderTests,
 Source_WindowsRegistryFinderTests] = test_util.test_both(
    WindowsRegistryFinderTests, machinery=machinery)
88
@unittest.skipUnless(sys.platform.startswith('win'), 'requires Windows')
class WindowsExtensionSuffixTests:
    """Check the version/platform-tagged ``.pyd`` suffix.

    This is a mixin, not a TestCase: ``self.machinery`` is expected to be
    provided by the concrete variants (see the ``test_util.test_both`` call
    that follows the class).
    """

    def test_tagged_suffix(self):
        suffixes = self.machinery.EXTENSION_SUFFIXES
        # Every non-alphanumeric character of the platform name becomes '_'.
        platform_tag = re.sub('[^a-zA-Z0-9]', '_', get_platform())
        tagged = ".cp{0.major}{0.minor}-{1}.pyd".format(sys.version_info,
                                                        platform_tag)
        try:
            untagged_pos = suffixes.index(".pyd")
        except ValueError:
            # No plain ".pyd": expect the "_d"-prefixed variants instead
            # (presumably a debug build).
            untagged_pos = suffixes.index("_d.pyd")
            tagged = "_d" + tagged

        self.assertIn(tagged, suffixes)

        # The tagged suffix must appear earlier in the list than the
        # untagged one.
        tagged_pos = suffixes.index(tagged)
        self.assertLess(tagged_pos, untagged_pos)
106
# Build the concrete test classes from the mixin; the names suggest one
# variant per importlib implementation (frozen and source).
[Frozen_WindowsExtensionSuffixTests,
 Source_WindowsExtensionSuffixTests] = test_util.test_both(
    WindowsExtensionSuffixTests, machinery=machinery)
110
111
@unittest.skipUnless(sys.platform.startswith('win'), 'requires Windows')
class WindowsBootstrapPathTests(unittest.TestCase):
    """Tests for ``importlib._bootstrap_external._path_join`` on Windows."""

    def check_join(self, expected, *inputs):
        """Assert that ``_path_join(*inputs)`` matches *expected*.

        Matching ignores case (via ``str.casefold``). ``assertEqual`` runs
        only when the case-folded strings differ, so a failure message shows
        the original spellings.
        """
        from importlib._bootstrap_external import _path_join
        actual = _path_join(*inputs)
        if expected.casefold() != actual.casefold():
            self.assertEqual(expected, actual)

    def test_path_join(self):
        # Each entry is (expected, *inputs). Entries are checked in order and
        # the first mismatch fails the test.
        cases = (
            (r"C:\A\B", "C:\\", "A", "B"),
            (r"C:\A\B", "D:\\", "D", "C:\\", "A", "B"),
            (r"C:\A\B", "C:\\", "A", "C:B"),
            (r"C:\A\B", "C:\\", "A\\B"),
            (r"C:\A\B", r"C:\A\B"),

            ("D:A", r"D:", "A"),
            ("D:A", r"C:\B\C", "D:", "A"),
            ("D:A", r"C:\B\C", r"D:A"),

            (r"A\B\C", "A", "B", "C"),
            (r"A\B\C", "A", r"B\C"),
            (r"A\B/C", "A", "B/C"),
            (r"A\B\C", "A/", "B\\", "C"),

            # Dots are not normalised by this function
            (r"A\../C", "A", "../C"),
            (r"A.\.\B", "A.", ".", "B"),

            (r"\\Server\Share\A\B\C", r"\\Server\Share", "A", "B", "C"),
            (r"\\Server\Share\A\B\C", r"\\Server\Share", "D", r"\A", "B", "C"),
            (r"\\Server\Share\A\B\C", r"\\Server2\Share2", "D",
             r"\\Server\Share", "A", "B", "C"),
            (r"\\Server\Share\A\B\C", r"\\Server", r"\Share", "A", "B", "C"),
            (r"\\Server\Share", r"\\Server\Share"),
            (r"\\Server\Share\\", r"\\Server\Share\\"),

            # Handle edge cases with empty segments
            ("C:\\A", "C:/A", ""),
            ("C:\\", "C:/", ""),
            ("C:", "C:", ""),
            ("//Server/Share\\", "//Server/Share/", ""),
            ("//Server/Share\\", "//Server/Share", ""),
        )
        for expected, *inputs in cases:
            self.check_join(expected, *inputs)
155