1"""Tests for C-implemented GenericAlias."""
2
3import unittest
4import pickle
5import copy
6from collections import (
7    defaultdict, deque, OrderedDict, Counter, UserDict, UserList
8)
9from collections.abc import *
10from concurrent.futures import Future
11from concurrent.futures.thread import _WorkItem
12from contextlib import AbstractContextManager, AbstractAsyncContextManager
13from contextvars import ContextVar, Token
14from dataclasses import Field
15from functools import partial, partialmethod, cached_property
16from graphlib import TopologicalSorter
17from mailbox import Mailbox, _PartialFile
18try:
19    import ctypes
20except ImportError:
21    ctypes = None
22from difflib import SequenceMatcher
23from filecmp import dircmp
24from fileinput import FileInput
25from itertools import chain
26from http.cookies import Morsel
27from multiprocessing.managers import ValueProxy
28from multiprocessing.pool import ApplyResult
29try:
30    from multiprocessing.shared_memory import ShareableList
31except ImportError:
32    # multiprocessing.shared_memory is not available on e.g. Android
33    ShareableList = None
34from multiprocessing.queues import SimpleQueue as MPSimpleQueue
35from os import DirEntry
36from re import Pattern, Match
37from types import GenericAlias, MappingProxyType, AsyncGeneratorType
38from tempfile import TemporaryDirectory, SpooledTemporaryFile
39from urllib.parse import SplitResult, ParseResult
40from unittest.case import _AssertRaisesContext
41from queue import Queue, SimpleQueue
42from weakref import WeakSet, ReferenceType, ref
43import typing
44
45from typing import TypeVar
# Module-level type variables shared by the tests below as placeholders in
# parameterized aliases such as list[T] and dict[K, V].
T, K, V = TypeVar('T'), TypeVar('K'), TypeVar('V')
49
class BaseTest(unittest.TestCase):
    """Test basics.

    Covers subscription of builtin and stdlib classes (e.g. ``list[int]``),
    instantiation through an alias, type-variable substitution, equality,
    repr, pickling/copying, and subclassing of ``types.GenericAlias``.
    """
    # Classes expected to support ``cls[...]``. An entry may be None when an
    # optional import failed at module level (e.g. ShareableList); the looping
    # tests skip such entries.
    generic_types = [type, tuple, list, dict, set, frozenset, enumerate,
                     defaultdict, deque,
                     SequenceMatcher,
                     dircmp,
                     FileInput,
                     OrderedDict, Counter, UserDict, UserList,
                     Pattern, Match,
                     partial, partialmethod, cached_property,
                     TopologicalSorter,
                     AbstractContextManager, AbstractAsyncContextManager,
                     Awaitable, Coroutine,
                     AsyncIterable, AsyncIterator,
                     AsyncGenerator, Generator,
                     Iterable, Iterator,
                     Reversible,
                     Container, Collection,
                     Mailbox, _PartialFile,
                     ContextVar, Token,
                     Field,
                     Set, MutableSet,
                     Mapping, MutableMapping, MappingView,
                     KeysView, ItemsView, ValuesView,
                     Sequence, MutableSequence,
                     MappingProxyType, AsyncGeneratorType,
                     DirEntry,
                     chain,
                     TemporaryDirectory, SpooledTemporaryFile,
                     Queue, SimpleQueue,
                     _AssertRaisesContext,
                     SplitResult, ParseResult,
                     ValueProxy, ApplyResult,
                     WeakSet, ReferenceType, ref,
                     ShareableList, MPSimpleQueue,
                     Future, _WorkItem,
                     Morsel]
    if ctypes is not None:
        generic_types.extend((ctypes.Array, ctypes.LibraryLoader))

    def test_subscriptable(self):
        for t in self.generic_types:
            if t is None:
                continue
            tname = t.__name__
            with self.subTest(f"Testing {tname}"):
                alias = t[int]
                self.assertIs(alias.__origin__, t)
                self.assertEqual(alias.__args__, (int,))
                self.assertEqual(alias.__parameters__, ())

    def test_unsubscriptable(self):
        for t in int, str, float, Sized, Hashable:
            tname = t.__name__
            with self.subTest(f"Testing {tname}"):
                with self.assertRaises(TypeError):
                    t[int]

    def test_instantiate(self):
        for t in tuple, list, dict, set, frozenset, defaultdict, deque:
            tname = t.__name__
            with self.subTest(f"Testing {tname}"):
                alias = t[int]
                self.assertEqual(alias(), t())
                if t is dict:
                    self.assertEqual(alias(iter([('a', 1), ('b', 2)])), dict(a=1, b=2))
                    self.assertEqual(alias(a=1, b=2), dict(a=1, b=2))
                elif t is defaultdict:
                    def default():
                        return 'value'
                    a = alias(default)
                    d = defaultdict(default)
                    self.assertEqual(a['test'], d['test'])
                else:
                    self.assertEqual(alias(iter((1, 2, 3))), t((1, 2, 3)))

    def test_unbound_methods(self):
        # Attribute access on the alias is forwarded to the origin class, so
        # list methods can be called through list[int] as unbound functions.
        t = list[int]
        a = t()
        t.append(a, 'foo')
        self.assertEqual(a, ['foo'])
        x = t.__getitem__(a, 0)
        self.assertEqual(x, 'foo')
        self.assertEqual(t.__len__(a), 1)

    def test_subclassing(self):
        # Subclassing an alias uses the origin class as the real base.
        class C(list[int]):
            pass
        self.assertEqual(C.__bases__, (list,))
        self.assertEqual(C.__class__, type)

    def test_class_methods(self):
        t = dict[int, None]
        self.assertEqual(dict.fromkeys(range(2)), {0: None, 1: None})  # This works
        self.assertEqual(t.fromkeys(range(2)), {0: None, 1: None})  # Should be equivalent

    def test_no_chaining(self):
        t = list[int]
        with self.assertRaises(TypeError):
            t[int]

    def test_generic_subclass(self):
        class MyList(list):
            pass
        t = MyList[int]
        self.assertIs(t.__origin__, MyList)
        self.assertEqual(t.__args__, (int,))
        self.assertEqual(t.__parameters__, ())

    def test_repr(self):
        # MyList must stay local to this method: the expected repr embeds its
        # qualified name ('BaseTest.test_repr.<locals>.MyList').
        class MyList(list):
            pass
        self.assertEqual(repr(list[str]), 'list[str]')
        self.assertEqual(repr(list[()]), 'list[()]')
        self.assertEqual(repr(tuple[int, ...]), 'tuple[int, ...]')
        self.assertTrue(repr(MyList[int]).endswith('.BaseTest.test_repr.<locals>.MyList[int]'))
        self.assertEqual(repr(list[str]()), '[]')  # instances should keep their normal repr

    def test_exposed_type(self):
        import types
        a = types.GenericAlias(list, int)
        self.assertEqual(str(a), 'list[int]')
        self.assertIs(a.__origin__, list)
        self.assertEqual(a.__args__, (int,))
        self.assertEqual(a.__parameters__, ())

    def test_parameters(self):
        # __parameters__ collects the unique type variables found in
        # __args__, in first-seen order, including inside nested aliases.
        from typing import List, Dict, Callable
        D0 = dict[str, int]
        self.assertEqual(D0.__args__, (str, int))
        self.assertEqual(D0.__parameters__, ())
        D1a = dict[str, V]
        self.assertEqual(D1a.__args__, (str, V))
        self.assertEqual(D1a.__parameters__, (V,))
        D1b = dict[K, int]
        self.assertEqual(D1b.__args__, (K, int))
        self.assertEqual(D1b.__parameters__, (K,))
        D2a = dict[K, V]
        self.assertEqual(D2a.__args__, (K, V))
        self.assertEqual(D2a.__parameters__, (K, V))
        D2b = dict[T, T]
        self.assertEqual(D2b.__args__, (T, T))
        self.assertEqual(D2b.__parameters__, (T,))
        L0 = list[str]
        self.assertEqual(L0.__args__, (str,))
        self.assertEqual(L0.__parameters__, ())
        L1 = list[T]
        self.assertEqual(L1.__args__, (T,))
        self.assertEqual(L1.__parameters__, (T,))
        L2 = list[list[T]]
        self.assertEqual(L2.__args__, (list[T],))
        self.assertEqual(L2.__parameters__, (T,))
        L3 = list[List[T]]
        self.assertEqual(L3.__args__, (List[T],))
        self.assertEqual(L3.__parameters__, (T,))
        L4a = list[Dict[K, V]]
        self.assertEqual(L4a.__args__, (Dict[K, V],))
        self.assertEqual(L4a.__parameters__, (K, V))
        L4b = list[Dict[T, int]]
        self.assertEqual(L4b.__args__, (Dict[T, int],))
        self.assertEqual(L4b.__parameters__, (T,))
        L5 = list[Callable[[K, V], K]]
        self.assertEqual(L5.__args__, (Callable[[K, V], K],))
        self.assertEqual(L5.__parameters__, (K, V))

    def test_parameter_chaining(self):
        from typing import List, Dict, Union, Callable
        self.assertEqual(list[T][int], list[int])
        self.assertEqual(dict[str, T][int], dict[str, int])
        self.assertEqual(dict[T, int][str], dict[str, int])
        self.assertEqual(dict[K, V][str, int], dict[str, int])
        self.assertEqual(dict[T, T][int], dict[int, int])

        self.assertEqual(list[list[T]][int], list[list[int]])
        self.assertEqual(list[dict[T, int]][str], list[dict[str, int]])
        self.assertEqual(list[dict[str, T]][int], list[dict[str, int]])
        self.assertEqual(list[dict[K, V]][str, int], list[dict[str, int]])
        self.assertEqual(dict[T, list[int]][str], dict[str, list[int]])

        self.assertEqual(list[List[T]][int], list[List[int]])
        self.assertEqual(list[Dict[K, V]][str, int], list[Dict[str, int]])
        self.assertEqual(list[Union[K, V]][str, int], list[Union[str, int]])
        self.assertEqual(list[Callable[[K, V], K]][str, int],
                         list[Callable[[str, int], str]])
        self.assertEqual(dict[T, List[int]][str], dict[str, List[int]])

        # Each invalid substitution gets its own assertRaises block: inside a
        # single block the first statement that raises ends the block, and
        # any statements after it would never be executed.
        with self.assertRaises(TypeError):
            list[int][int]
        with self.assertRaises(TypeError):
            dict[T, int][str, int]
        with self.assertRaises(TypeError):
            dict[str, T][str, int]
        with self.assertRaises(TypeError):
            dict[T, T][str, int]

    def test_equality(self):
        self.assertEqual(list[int], list[int])
        self.assertEqual(dict[str, int], dict[str, int])
        self.assertNotEqual(dict[str, int], dict[str, str])
        self.assertNotEqual(list, list[int])
        self.assertNotEqual(list[int], list)

    def test_isinstance(self):
        self.assertTrue(isinstance([], list))
        with self.assertRaises(TypeError):
            isinstance([], list[str])

    def test_issubclass(self):
        class L(list): ...
        self.assertTrue(issubclass(L, list))
        with self.assertRaises(TypeError):
            issubclass(L, list[str])

    def test_type_generic(self):
        # Calling type[int] behaves like calling type itself (both the
        # three-argument class-creation form and the one-argument form).
        t = type[int]
        Test = t('Test', (), {})
        self.assertTrue(isinstance(Test, type))
        test = Test()
        self.assertEqual(t(test), Test)
        self.assertEqual(t(0), int)

    def test_type_subclass_generic(self):
        class MyType(type):
            pass
        with self.assertRaises(TypeError):
            MyType[int]

    def test_pickle(self):
        alias = GenericAlias(list, T)
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(proto=proto):
                s = pickle.dumps(alias, proto)
                loaded = pickle.loads(s)
                self.assertEqual(loaded.__origin__, alias.__origin__)
                self.assertEqual(loaded.__args__, alias.__args__)
                self.assertEqual(loaded.__parameters__, alias.__parameters__)

    def test_copy(self):
        # X overrides __copy__/__deepcopy__; copying an alias over X must
        # still produce an equivalent alias rather than delegating to them.
        class X(list):
            def __copy__(self):
                return self
            def __deepcopy__(self, memo):
                return self

        for origin in list, deque, X:
            with self.subTest(origin=origin):
                alias = GenericAlias(origin, T)
                copied = copy.copy(alias)
                self.assertEqual(copied.__origin__, alias.__origin__)
                self.assertEqual(copied.__args__, alias.__args__)
                self.assertEqual(copied.__parameters__, alias.__parameters__)
                copied = copy.deepcopy(alias)
                self.assertEqual(copied.__origin__, alias.__origin__)
                self.assertEqual(copied.__args__, alias.__args__)
                self.assertEqual(copied.__parameters__, alias.__parameters__)

    def test_union(self):
        a = typing.Union[list[int], list[str]]
        self.assertEqual(a.__args__, (list[int], list[str]))
        self.assertEqual(a.__parameters__, ())

    def test_union_generic(self):
        a = typing.Union[list[T], tuple[T, ...]]
        self.assertEqual(a.__args__, (list[T], tuple[T, ...]))
        self.assertEqual(a.__parameters__, (T,))

    def test_dir(self):
        dir_of_gen_alias = set(dir(list[int]))
        self.assertTrue(dir_of_gen_alias.issuperset(dir(list)))
        for generic_alias_property in ("__origin__", "__args__", "__parameters__"):
            self.assertIn(generic_alias_property, dir_of_gen_alias)

    def test_weakref(self):
        for t in self.generic_types:
            if t is None:
                continue
            tname = t.__name__
            with self.subTest(f"Testing {tname}"):
                alias = t[int]
                self.assertEqual(ref(alias)(), alias)

    def test_no_kwargs(self):
        # bpo-42576
        with self.assertRaises(TypeError):
            GenericAlias(bad=float)

    def test_subclassing_types_genericalias(self):
        class SubClass(GenericAlias): ...
        alias = SubClass(list, int)
        class Bad(GenericAlias):
            def __new__(cls, *args, **kwargs):
                # Forward kwargs unchanged so the base constructor's keyword
                # rejection is what the assertion below exercises.
                return super().__new__(cls, *args, **kwargs)

        self.assertEqual(alias, list[int])
        with self.assertRaises(TypeError):
            Bad(list, int, bad=int)
341
342
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
345