1import pytest
2
3from mitmproxy.coretypes import multidict
4
5
class _TMulti:
    """Mixin whose key conversion folds keys to lower case.

    Mixed in ahead of ``MultiDict`` so that keys differing only in case
    are treated as the same key.
    """

    @staticmethod
    def _kconv(key):
        """Return the lower-cased form of *key*."""
        folded = key.lower()
        return folded
10
11
class TMultiDict(_TMulti, multidict.MultiDict):
    """``MultiDict`` test double whose keys compare case-insensitively via ``_TMulti``."""
14
15
class TestMultiDict:
    """Behavioral tests for ``MultiDict`` using the case-insensitive ``TMultiDict``.

    The ``_multi`` fixture holds ``"bar"`` and ``"Bar"`` as two separate
    fields that fold to the same key, so accessors are exercised with a
    key that has more than one value.
    """

    @staticmethod
    def _multi():
        """Return a fresh ``TMultiDict`` with one key ("bar"/"Bar") holding two values."""
        return TMultiDict((
            ("foo", "bar"),
            ("bar", "baz"),
            ("Bar", "bam"),
        ))

    def test_init(self):
        md = TMultiDict()
        assert len(md) == 0

        md = TMultiDict([("foo", "bar")])
        assert len(md) == 1
        assert md.fields == (("foo", "bar"),)

    def test_repr(self):
        assert repr(self._multi()) == (
            "TMultiDict[('foo', 'bar'), ('bar', 'baz'), ('Bar', 'bam')]"
        )

    def test_getitem(self):
        md = TMultiDict([("foo", "bar")])
        assert "foo" in md
        assert "Foo" in md
        assert md["foo"] == "bar"
        # Item access must honor the case-insensitive key conversion too,
        # not only membership tests.
        assert md["Foo"] == "bar"

        # The subscript itself is what should raise; wrapping it in
        # ``assert`` would suggest a truthiness check that never runs.
        with pytest.raises(KeyError):
            md["bar"]

        # With several values for one key, item access yields the first.
        md_multi = TMultiDict(
            [("foo", "a"), ("foo", "b")]
        )
        assert md_multi["foo"] == "a"

    def test_setitem(self):
        md = TMultiDict()
        md["foo"] = "bar"
        assert md.fields == (("foo", "bar"),)

        # Assigning to an existing key replaces its value in place.
        md["foo"] = "baz"
        assert md.fields == (("foo", "baz"),)

        md["bar"] = "bam"
        assert md.fields == (("foo", "baz"), ("bar", "bam"))

    def test_delitem(self):
        md = self._multi()
        del md["foo"]
        assert "foo" not in md
        assert "bar" in md

        with pytest.raises(KeyError):
            del md["foo"]

        # Deleting "bar" also removes the case-variant "Bar" field.
        del md["bar"]
        assert md.fields == ()

    def test_iter(self):
        md = self._multi()
        # Iteration yields each distinct (case-folded) key once, first
        # spelling wins.
        assert list(md) == ["foo", "bar"]

    def test_len(self):
        md = TMultiDict()
        assert len(md) == 0

        # Length counts distinct keys, not fields.
        md = self._multi()
        assert len(md) == 2

    def test_eq(self):
        assert TMultiDict() == TMultiDict()
        assert not (TMultiDict() == 42)

        md1 = self._multi()
        md2 = self._multi()
        assert md1 == md2
        # Rotating the fields keeps the same pairs but changes their order,
        # which must make the dicts compare unequal.
        md1.fields = md1.fields[1:] + md1.fields[:1]
        assert not (md1 == md2)

    def test_hash(self):
        """
        If a class defines mutable objects and implements an __eq__() method,
        it should not implement __hash__(), since the implementation of hashable
        collections requires that a key's hash value is immutable.
        """
        with pytest.raises(TypeError):
            hash(TMultiDict())

    def test_get_all(self):
        md = self._multi()
        assert md.get_all("foo") == ["bar"]
        assert md.get_all("bar") == ["baz", "bam"]
        assert md.get_all("baz") == []

    def test_set_all(self):
        md = TMultiDict()
        md.set_all("foo", ["bar", "baz"])
        assert md.fields == (("foo", "bar"), ("foo", "baz"))

        md = TMultiDict((
            ("a", "b"),
            ("x", "x"),
            ("c", "d"),
            ("X", "X"),
            ("e", "f"),
        ))
        # Existing fields are overwritten in place (keeping their key
        # spelling); surplus values are appended at the end.
        md.set_all("x", ["1", "2", "3"])
        assert md.fields == (
            ("a", "b"),
            ("x", "1"),
            ("c", "d"),
            ("X", "2"),
            ("e", "f"),
            ("x", "3"),
        )
        # With fewer values than fields, the leftover fields are dropped.
        md.set_all("x", ["4"])
        assert md.fields == (
            ("a", "b"),
            ("x", "4"),
            ("c", "d"),
            ("e", "f"),
        )

    def test_add(self):
        md = self._multi()
        md.add("foo", "foo")
        assert md.fields == (
            ("foo", "bar"),
            ("bar", "baz"),
            ("Bar", "bam"),
            ("foo", "foo"),
        )

    def test_insert(self):
        md = TMultiDict([("b", "b")])
        md.insert(0, "a", "a")
        md.insert(2, "c", "c")
        assert md.fields == (("a", "a"), ("b", "b"), ("c", "c"))

    def test_keys(self):
        md = self._multi()
        assert list(md.keys()) == ["foo", "bar"]
        assert list(md.keys(multi=True)) == ["foo", "bar", "Bar"]

    def test_values(self):
        md = self._multi()
        assert list(md.values()) == ["bar", "baz"]
        assert list(md.values(multi=True)) == ["bar", "baz", "bam"]

    def test_items(self):
        md = self._multi()
        assert list(md.items()) == [("foo", "bar"), ("bar", "baz")]
        assert list(md.items(multi=True)) == [("foo", "bar"), ("bar", "baz"), ("Bar", "bam")]

    def test_state(self):
        md = self._multi()
        assert len(md.get_state()) == 3
        assert md == TMultiDict.from_state(md.get_state())

        md2 = TMultiDict()
        assert md != md2
        md2.set_state(md.get_state())
        assert md == md2
180
181
class TParent:
    """Minimal owner object that stores the fields backing a ``MultiDictView``.

    ``getter`` and ``setter`` are passed to the view as plain callables,
    which is why they are ordinary methods rather than a property.
    """

    def __init__(self):
        # Start with no fields; ``()`` is the idiomatic empty tuple literal.
        self.vals = ()

    def setter(self, vals):
        """Replace the stored fields with *vals*."""
        self.vals = vals

    def getter(self):
        """Return the currently stored fields."""
        return self.vals
191
192
class TestMultiDictView:
    """Tests for ``MultiDictView``, which reads and writes fields through callbacks."""

    def test_modify(self):
        p = TParent()
        tv = multidict.MultiDictView(p.getter, p.setter)
        assert len(tv) == 0
        # Every write through the view must land in the parent's storage.
        tv["a"] = "b"
        assert p.vals == (("a", "b"),)
        tv["c"] = "b"
        assert p.vals == (("a", "b"), ("c", "b"))
        assert tv["a"] == "b"

    def test_copy(self):
        p = TParent()
        tv = multidict.MultiDictView(p.getter, p.setter)
        # Populate the view first; copying an empty view would not show
        # whether the contents are actually carried over.
        tv["a"] = "b"
        c = tv.copy()
        assert isinstance(c, multidict.MultiDict)
        assert tv.items() == c.items()
        assert c.fields == (("a", "b"),)

        c["foo"] = "bar"
        assert tv.items() != c.items()
        # The copy is detached: writing to it must not reach the parent.
        assert p.vals == (("a", "b"),)
212