1import io
2import pickle
3import tempfile
4import typing as t
5from contextlib import contextmanager
6from copy import copy
7from copy import deepcopy
8
9import pytest
10
11from werkzeug import datastructures as ds
12from werkzeug import http
13from werkzeug.exceptions import BadRequestKeyError
14
15
class TestNativeItermethods:
    """Sanity check for a dict-like object with ``multi``-aware iterators."""

    def test_basic(self):
        """``keys``/``values``/``items`` repeat their payload ``multi`` times."""

        class StupidDict:
            # Minimal stand-in: every iterator method repeats its data.
            def keys(self, multi=1):
                return iter(["a", "b", "c"] * multi)

            def values(self, multi=1):
                return iter([1, 2, 3] * multi)

            def items(self, multi=1):
                # zip() is already an iterator, no extra iter() wrapping needed.
                return zip(self.keys(multi=multi), self.values(multi=multi))

        fake = StupidDict()
        keys = ["a", "b", "c"]
        values = [1, 2, 3]
        pairs = [(key, value) for key, value in zip(keys, values)]

        # default multiplier
        assert list(fake.keys()) == keys
        assert list(fake.values()) == values
        assert list(fake.items()) == pairs

        # explicit multiplier passed positionally
        twice = 2
        assert list(fake.keys(twice)) == keys * twice
        assert list(fake.values(twice)) == values * twice
        assert list(fake.items(twice)) == pairs * twice
42
43
class _MutableMultiDictTests:
    """Shared tests for mutable multi-dicts; subclasses set ``storage_class``."""

    # Concrete multi-dict class under test, supplied by each subclass.
    storage_class: t.Type["ds.MultiDict"]

    def test_pickle(self):
        """Instances round-trip through every pickle protocol.

        ``create_instance(module)`` temporarily overrides the class's
        ``__module__`` while constructing one instance; the real value is
        restored before that instance is pickled.
        """
        cls = self.storage_class

        def create_instance(module=None):
            if module is None:
                d = cls()
            else:
                old = cls.__module__
                cls.__module__ = module
                try:
                    d = cls()
                finally:
                    # Always restore the real module path, even if construction
                    # fails, so the class is not left corrupted for later tests.
                    cls.__module__ = old
            d.setlist(b"foo", [1, 2, 3, 4])
            d.setlist(b"bar", b"foo bar baz".split())
            return d

        for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
            d = create_instance()
            s = pickle.dumps(d, protocol)
            ud = pickle.loads(s)
            assert type(ud) is type(d)
            assert ud == d
            alternative = pickle.dumps(create_instance("werkzeug"), protocol)
            assert pickle.loads(alternative) == d
            # the unpickled copy is independent of the original
            ud[b"newkey"] = b"bla"
            assert ud != d

    def test_multidict_dict_interop(self):
        """Converting to a plain dict keeps only the first value per key."""
        # https://github.com/pallets/werkzeug/pull/2043
        md = self.storage_class([("a", 1), ("a", 2)])
        assert dict(md)["a"] != [1, 2]
        assert dict(md)["a"] == 1
        assert dict(md) == {**md} == {"a": 1}

    def test_basic_interface(self):
        """Exercise the read/write API step by step on one evolving instance."""
        md = self.storage_class()
        assert isinstance(md, dict)

        mapping = [
            ("a", 1),
            ("b", 2),
            ("a", 2),
            ("d", 3),
            ("a", 1),
            ("a", 3),
            ("d", 4),
            ("c", 3),
        ]
        md = self.storage_class(mapping)

        # simple getitem gives the first value
        assert md["a"] == 1
        assert md["c"] == 3
        with pytest.raises(KeyError):
            md["e"]
        assert md.get("a") == 1

        # list getitem
        assert md.getlist("a") == [1, 2, 1, 3]
        assert md.getlist("d") == [3, 4]
        # do not raise if key not found
        assert md.getlist("x") == []

        # simple setitem overwrites all values
        md["a"] = 42
        assert md.getlist("a") == [42]

        # list setitem
        md.setlist("a", [1, 2, 3])
        assert md["a"] == 1
        assert md.getlist("a") == [1, 2, 3]

        # verify that it does not change original lists
        l1 = [1, 2, 3]
        md.setlist("a", l1)
        del l1[:]
        assert md["a"] == 1

        # setdefault stores and returns the default for a missing key
        assert md.setdefault("u", 23) == 23
        assert md.getlist("u") == [23]
        del md["u"]

        md.setlist("u", [-1, -2])

        # delitem removes every value of the key
        del md["u"]
        with pytest.raises(KeyError):
            md["u"]
        del md["d"]
        assert md.getlist("d") == []

        # keys, values, items, lists (state: a=[1, 2, 3], b=[2], c=[3])
        assert sorted(md.keys()) == ["a", "b", "c"]
        assert sorted(md.values()) == [1, 2, 3]
        assert sorted(md.items()) == [("a", 1), ("b", 2), ("c", 3)]
        assert sorted(md.items(multi=True)) == [
            ("a", 1),
            ("a", 2),
            ("a", 3),
            ("b", 2),
            ("c", 3),
        ]
        assert sorted(md.lists()) == [("a", [1, 2, 3]), ("b", [2]), ("c", [3])]

        # copy method
        c = md.copy()
        assert c["a"] == 1
        assert c.getlist("a") == [1, 2, 3]

        # copy method 2
        c = copy(md)
        assert c["a"] == 1
        assert c.getlist("a") == [1, 2, 3]

        # deepcopy method
        c = md.deepcopy()
        assert c["a"] == 1
        assert c.getlist("a") == [1, 2, 3]

        # deepcopy method 2
        c = deepcopy(md)
        assert c["a"] == 1
        assert c.getlist("a") == [1, 2, 3]

        # update with a multidict appends values
        od = self.storage_class([("a", 4), ("a", 5), ("y", 0)])
        md.update(od)
        assert md.getlist("a") == [1, 2, 3, 4, 5]
        assert md.getlist("y") == [0]

        # update with a regular dict (continue on the deepcopy ``c`` from above)
        md = c
        od = {"a": 4, "y": 0}
        md.update(od)
        assert md.getlist("a") == [1, 2, 3, 4]
        assert md.getlist("y") == [0]

        # pop, poplist, popitem, popitemlist
        assert md.pop("y") == 0
        assert "y" not in md
        assert md.poplist("a") == [1, 2, 3, 4]
        assert "a" not in md
        assert md.poplist("missing") == []

        # remaining: b=2, c=3
        popped = md.popitem()
        assert popped in [("b", 2), ("c", 3)]
        popped = md.popitemlist()
        assert popped in [("b", [2]), ("c", [3])]

        # type conversion
        md = self.storage_class({"a": "4", "b": ["2", "3"]})
        assert md.get("a", type=int) == 4
        assert md.getlist("b", type=int) == [2, 3]

        # repr
        md = self.storage_class([("a", 1), ("a", 2), ("b", 3)])
        assert "('a', 1)" in repr(md)
        assert "('a', 2)" in repr(md)
        assert "('b', 3)" in repr(md)

        # add and getlist; values failing conversion are skipped
        md.add("c", "42")
        md.add("c", "23")
        assert md.getlist("c") == ["42", "23"]
        md.add("c", "blah")
        assert md.getlist("c", type=int) == [42, 23]

        # setdefault returns the stored (mutable) value
        md = self.storage_class()
        md.setdefault("x", []).append(42)
        md.setdefault("x", []).append(23)
        assert md["x"] == [42, 23]

        # to dict
        md = self.storage_class()
        md["foo"] = 42
        md.add("bar", 1)
        md.add("bar", 2)
        assert md.to_dict() == {"foo": 42, "bar": 1}
        assert md.to_dict(flat=False) == {"foo": [42], "bar": [1, 2]}

        # popitem from empty dict
        with pytest.raises(KeyError):
            self.storage_class().popitem()

        with pytest.raises(KeyError):
            self.storage_class().popitemlist()

        # key errors are of a special type
        with pytest.raises(BadRequestKeyError):
            self.storage_class()[42]

        # setlist works
        md = self.storage_class()
        md["foo"] = 42
        md.setlist("foo", [1, 2])
        assert md.getlist("foo") == [1, 2]
259
260
class _ImmutableDictTests:
    """Shared tests for immutable dict types; subclasses set ``storage_class``."""

    # Concrete immutable mapping class under test, supplied by each subclass.
    storage_class: t.Type[dict]

    def test_follows_dict_interface(self):
        """Lookup, key listing, membership and ``len`` behave like ``dict``."""
        cls = self.storage_class

        data = {"foo": 1, "bar": 2, "baz": 3}
        d = cls(data)

        assert d["foo"] == 1
        assert d["bar"] == 2
        assert d["baz"] == 3
        assert sorted(d.keys()) == ["bar", "baz", "foo"]
        assert "foo" in d
        assert "foox" not in d
        assert len(d) == 3

    def test_copies_are_mutable(self):
        """``.copy()`` gives an independent mutable copy; ``copy.copy`` does not.

        ``pop`` on the immutable instance raises ``TypeError``, while
        ``copy.copy`` returns the very same object.
        """
        cls = self.storage_class
        immutable = cls({"a": 1})
        with pytest.raises(TypeError):
            immutable.pop("a")

        mutable = immutable.copy()
        mutable.pop("a")
        # popping from the copy leaves the original untouched
        assert "a" in immutable
        assert mutable is not immutable
        assert copy(immutable) is immutable

    def test_dict_is_hashable(self):
        """Instances work as set members; different contents are distinct."""
        cls = self.storage_class
        immutable = cls({"a": 1, "b": 2})
        immutable2 = cls({"a": 2, "b": 2})
        # only the first instance is a member
        x = {immutable}
        assert immutable in x
        assert immutable2 not in x
        # after discarding, neither is a member
        x.discard(immutable)
        assert immutable not in x
        assert immutable2 not in x
        # only the second instance is a member
        x.add(immutable2)
        assert immutable not in x
        assert immutable2 in x
        # both can coexist in the same set
        x.add(immutable)
        assert immutable in x
        assert immutable2 in x
306
307
class TestImmutableTypeConversionDict(_ImmutableDictTests):
    """Shared immutable-dict tests for ``ImmutableTypeConversionDict``."""

    storage_class = ds.ImmutableTypeConversionDict
310
311
class TestImmutableMultiDict(_ImmutableDictTests):
    """Shared immutable-dict tests plus multi-value hashing for ``ImmutableMultiDict``."""

    storage_class = ds.ImmutableMultiDict

    def test_multidict_is_hashable(self):
        """Instances differing only in a key's value list are distinct set members."""
        cls = self.storage_class
        immutable = cls({"a": [1, 2], "b": 2})
        immutable2 = cls({"a": [1], "b": 2})
        # only the first instance is a member
        x = {immutable}
        assert immutable in x
        assert immutable2 not in x
        # after discarding, neither is a member
        x.discard(immutable)
        assert immutable not in x
        assert immutable2 not in x
        # only the second instance is a member
        x.add(immutable2)
        assert immutable not in x
        assert immutable2 in x
        # both can coexist in the same set
        x.add(immutable)
        assert immutable in x
        assert immutable2 in x
331
332
class TestImmutableDict(_ImmutableDictTests):
    """Shared immutable-dict tests for ``ImmutableDict``."""

    storage_class = ds.ImmutableDict
335
336
class TestImmutableOrderedMultiDict(_ImmutableDictTests):
    """Shared immutable-dict tests plus order-sensitive hashing checks."""

    storage_class = ds.ImmutableOrderedMultiDict

    def test_ordered_multidict_is_hashable(self):
        """Hashes depend on item order and agree for equal instances."""
        a = self.storage_class([("a", 1), ("b", 1), ("a", 2)])
        b = self.storage_class([("a", 1), ("a", 2), ("b", 1)])
        # same items in a different order hash differently ...
        assert hash(a) != hash(b)
        # ... while an independently built instance with the same ordered
        # items is equal and must hash the same (the core hash contract).
        a_again = self.storage_class([("a", 1), ("b", 1), ("a", 2)])
        assert a == a_again
        assert hash(a) == hash(a_again)
344
345
class TestMultiDict(_MutableMultiDictTests):
    """``MultiDict``-specific tests on top of the shared mutable suite."""

    storage_class = ds.MultiDict

    def test_multidict_pop(self):
        """``pop`` returns the first value and removes the whole key."""

        def make_d():
            return self.storage_class({"foo": [1, 2, 3, 4]})

        d = make_d()
        assert d.pop("foo") == 1
        assert not d
        d = make_d()
        assert d.pop("foo", 32) == 1
        assert not d
        d = make_d()
        assert d.pop("foos", 32) == 32
        assert d

        with pytest.raises(KeyError):
            d.pop("foos")

    def test_multidict_pop_raise_badrequestkeyerror_for_empty_list_value(self):
        """Popping a key whose value list is empty raises ``BadRequestKeyError``."""
        mapping = [("a", "b"), ("a", "c")]
        md = self.storage_class(mapping)

        md.setlistdefault("empty", [])

        # BadRequestKeyError subclasses KeyError; assert the specific type the
        # test name promises (consistent with the popitem test below).
        with pytest.raises(BadRequestKeyError):
            md.pop("empty")

    def test_multidict_popitem_raise_badrequestkeyerror_for_empty_list_value(self):
        """``popitem`` on a dict holding only an empty list raises ``BadRequestKeyError``."""
        mapping = []
        md = self.storage_class(mapping)

        md.setlistdefault("empty", [])

        with pytest.raises(BadRequestKeyError):
            md.popitem()

    def test_setlistdefault(self):
        """``setlistdefault`` stores and returns the default list for a new key."""
        md = self.storage_class()
        assert md.setlistdefault("u", [-1, -2]) == [-1, -2]
        assert md.getlist("u") == [-1, -2]
        assert md["u"] == -1

    def test_iter_interfaces(self):
        """``keys()``/iteration pair up with ``listvalues()`` like ``lists()``."""
        mapping = [
            ("a", 1),
            ("b", 2),
            ("a", 2),
            ("d", 3),
            ("a", 1),
            ("a", 3),
            ("d", 4),
            ("c", 3),
        ]
        md = self.storage_class(mapping)
        assert list(zip(md.keys(), md.listvalues())) == list(md.lists())
        assert list(zip(md, md.listvalues())) == list(md.lists())

    def test_getitem_raise_badrequestkeyerror_for_empty_list_value(self):
        """Indexing a key whose value list is empty raises ``BadRequestKeyError``."""
        mapping = [("a", "b"), ("a", "c")]
        md = self.storage_class(mapping)

        md.setlistdefault("empty", [])

        # assert the specific subclass the test name promises
        with pytest.raises(BadRequestKeyError):
            md["empty"]
414
415
class TestOrderedMultiDict(_MutableMultiDictTests):
    """``OrderedMultiDict``-specific tests on top of the shared mutable suite."""

    storage_class = ds.OrderedMultiDict

    def test_ordered_interface(self):
        """Insertion order is preserved across add/update/pop operations."""
        cls = self.storage_class

        d = cls()
        assert not d
        d.add("foo", "bar")
        assert len(d) == 1
        d.add("foo", "baz")
        assert len(d) == 1
        assert list(d.items()) == [("foo", "bar")]
        assert list(d) == ["foo"]
        assert list(d.items(multi=True)) == [("foo", "bar"), ("foo", "baz")]
        del d["foo"]
        assert not d
        assert len(d) == 0
        assert list(d) == []

        d.update([("foo", 1), ("foo", 2), ("bar", 42)])
        d.add("foo", 3)
        assert d.getlist("foo") == [1, 2, 3]
        assert d.getlist("bar") == [42]
        assert list(d.items()) == [("foo", 1), ("bar", 42)]

        expected = ["foo", "bar"]

        assert list(d.keys()) == expected
        assert list(d) == expected

        # multi items follow exact insertion order, interleaving keys
        assert list(d.items(multi=True)) == [
            ("foo", 1),
            ("foo", 2),
            ("bar", 42),
            ("foo", 3),
        ]
        assert len(d) == 2

        assert d.pop("foo") == 1
        assert d.pop("blafasel", None) is None
        assert d.pop("blafasel", 42) == 42
        assert len(d) == 1
        assert d.poplist("bar") == [42]
        assert not d

        assert d.get("missingkey") is None

        d.add("foo", 42)
        d.add("foo", 23)
        d.add("bar", 2)
        d.add("foo", 42)
        assert d == ds.MultiDict(d)
        copied = self.storage_class(d)
        assert d == copied
        d.add("foo", 2)
        assert d != copied

        d.update({"blah": [1, 2, 3]})
        assert d["blah"] == 1
        assert d.getlist("blah") == [1, 2, 3]

        # setlist works
        d = self.storage_class()
        d["foo"] = 42
        d.setlist("foo", [1, 2])
        assert d.getlist("foo") == [1, 2]
        with pytest.raises(BadRequestKeyError):
            d.pop("missing")

        with pytest.raises(BadRequestKeyError):
            d["missing"]

        # popping
        d = self.storage_class()
        d.add("foo", 23)
        d.add("foo", 42)
        d.add("foo", 1)
        assert d.popitem() == ("foo", 23)
        with pytest.raises(BadRequestKeyError):
            d.popitem()
        assert not d

        d.add("foo", 23)
        d.add("foo", 42)
        d.add("foo", 1)
        assert d.popitemlist() == ("foo", [23, 42, 1])

        with pytest.raises(BadRequestKeyError):
            d.popitemlist()

        # Unhashable
        d = self.storage_class()
        d.add("foo", 23)
        pytest.raises(TypeError, hash, d)

    def test_iterables(self):
        """``CombinedMultiDict`` over plain ``MultiDict`` parts exposes all lists."""
        a = ds.MultiDict((("key_a", "value_a"),))
        b = ds.MultiDict((("key_b", "value_b"),))
        ab = ds.CombinedMultiDict((a, b))

        assert sorted(ab.lists()) == [("key_a", ["value_a"]), ("key_b", ["value_b"])]
        assert sorted(ab.listvalues()) == [["value_a"], ["value_b"]]
        assert sorted(ab.keys()) == ["key_a", "key_b"]

    def test_get_description(self):
        """The missing key appears in the description only when enabled."""
        data = ds.OrderedMultiDict()

        with pytest.raises(BadRequestKeyError) as exc_info:
            data["baz"]

        assert "baz" not in exc_info.value.get_description()
        exc_info.value.show_exception = True
        assert "baz" in exc_info.value.get_description()

        with pytest.raises(BadRequestKeyError) as exc_info:
            data.pop("baz")

        exc_info.value.show_exception = True
        assert "baz" in exc_info.value.get_description()
        # without args there is no key to report
        exc_info.value.args = ()
        assert "baz" not in exc_info.value.get_description()
543
544
class TestTypeConversionDict:
    """``TypeConversionDict.get`` applies ``type`` and handles its failures."""

    storage_class = ds.TypeConversionDict

    def test_value_conversion(self):
        converted = self.storage_class(foo="1").get("foo", type=int)
        assert converted == 1

    def test_return_default_when_conversion_is_not_possible(self):
        # int("bar") fails, so the supplied default is returned instead.
        fallback = self.storage_class(foo="bar").get("foo", default=-1, type=int)
        assert fallback == -1

    def test_propagate_exceptions_in_conversion(self):
        # A KeyError raised by the converter itself must reach the caller.
        lookup = {"a": 1}
        data = self.storage_class(foo="bar")
        with pytest.raises(KeyError):
            data.get("foo", type=lambda value: lookup[value])
561
562
class TestCombinedMultiDict:
    """Read-only view over several ``MultiDict`` instances."""

    storage_class = ds.CombinedMultiDict

    def test_basic_interface(self):
        """Lookup, conversion, immutability, copying and list merging."""
        d1 = ds.MultiDict([("foo", "1")])
        d2 = ds.MultiDict([("bar", "2"), ("bar", "3")])
        d = self.storage_class([d1, d2])

        # lookup
        assert d["foo"] == "1"
        assert d["bar"] == "2"
        assert d.getlist("bar") == ["2", "3"]

        assert sorted(d.items()) == [("bar", "2"), ("foo", "1")]
        assert sorted(d.items(multi=True)) == [("bar", "2"), ("bar", "3"), ("foo", "1")]
        assert "missingkey" not in d
        assert "foo" in d

        # type lookup
        assert d.get("foo", type=int) == 1
        assert d.getlist("bar", type=int) == [2, 3]

        # get key errors for missing stuff
        with pytest.raises(KeyError):
            d["missing"]

        # make sure that they are immutable
        with pytest.raises(TypeError):
            d["foo"] = "blub"

        # copies are mutable, and writing to them leaves the original alone
        mutable = d.copy()
        mutable["foo"] = "blub"
        assert mutable["foo"] == "blub"
        assert d["foo"] == "1"

        # make sure lists merges
        md1 = ds.MultiDict((("foo", "bar"), ("foo", "baz")))
        md2 = ds.MultiDict((("foo", "blafasel"),))
        x = self.storage_class((md1, md2))
        assert list(x.lists()) == [("foo", ["bar", "baz", "blafasel"])]

        # make sure dicts are created properly
        assert x.to_dict() == {"foo": "bar"}
        assert x.to_dict(flat=False) == {"foo": ["bar", "baz", "blafasel"]}

    def test_length(self):
        """``len`` reflects the current state of the wrapped dicts."""
        d1 = ds.MultiDict([("foo", "1")])
        d2 = ds.MultiDict([("bar", "2")])
        assert len(d1) == len(d2) == 1
        d = self.storage_class([d1, d2])
        assert len(d) == 2
        # clearing a wrapped dict is visible through the combined view
        d1.clear()
        assert len(d1) == 0
        assert len(d) == 1
616
617
class TestHeaders:
    """Tests for the case-insensitive, ordered ``Headers`` container."""

    storage_class = ds.Headers

    def test_basic_interface(self):
        """Case-insensitive membership, replacement, serialization, extended add."""
        headers = self.storage_class()
        headers.add("Content-Type", "text/plain")
        headers.add("X-Foo", "bar")
        assert "x-Foo" in headers
        assert "Content-type" in headers

        headers["Content-Type"] = "foo/bar"
        assert headers["Content-Type"] == "foo/bar"
        assert len(headers.getlist("Content-Type")) == 1

        # list conversion
        assert headers.to_wsgi_list() == [("Content-Type", "foo/bar"), ("X-Foo", "bar")]
        assert str(headers) == "Content-Type: foo/bar\r\nX-Foo: bar\r\n\r\n"
        assert str(self.storage_class()) == "\r\n"

        # extended add
        headers.add("Content-Disposition", "attachment", filename="foo")
        assert headers["Content-Disposition"] == "attachment; filename=foo"

        # option values containing quotes are escaped
        headers.add("x", "y", z='"')
        assert headers["x"] == r'y; z="\""'

    def test_defaults_and_conversion(self):
        """``get``/``setdefault`` defaults, type conversion, list-like access."""
        # defaults
        headers = self.storage_class(
            [
                ("Content-Type", "text/plain"),
                ("X-Foo", "bar"),
                ("X-Bar", "1"),
                ("X-Bar", "2"),
            ]
        )
        assert headers.getlist("x-bar") == ["1", "2"]
        assert headers.get("x-Bar") == "1"
        assert headers.get("Content-Type") == "text/plain"

        assert headers.setdefault("X-Foo", "nope") == "bar"
        assert headers.setdefault("X-Bar", "nope") == "1"
        assert headers.setdefault("X-Baz", "quux") == "quux"
        assert headers.setdefault("X-Baz", "nope") == "quux"
        headers.pop("X-Baz")

        # type conversion
        assert headers.get("x-bar", type=int) == 1
        assert headers.getlist("x-bar", type=int) == [1, 2]

        # list like operations
        assert headers[0] == ("Content-Type", "text/plain")
        assert headers[:1] == self.storage_class([("Content-Type", "text/plain")])
        del headers[:2]
        del headers[-1]
        assert headers == self.storage_class([("X-Bar", "1")])

    def test_copying(self):
        """``copy`` produces an independent instance."""
        a = self.storage_class([("foo", "bar")])
        b = a.copy()
        a.add("foo", "baz")
        assert a.getlist("foo") == ["bar", "baz"]
        assert b.getlist("foo") == ["bar"]

    def test_popping(self):
        """``pop`` returns the value, a default, or raises ``KeyError``."""
        headers = self.storage_class([("a", 1)])
        assert headers.pop("a") == 1
        assert headers.pop("b", 2) == 2

        with pytest.raises(KeyError):
            headers.pop("c")

    def test_set_arguments(self):
        """``set`` accepts header options like ``add`` and replaces the value."""
        a = self.storage_class()
        a.set("Content-Disposition", "useless")
        a.set("Content-Disposition", "attachment", filename="foo")
        assert a["Content-Disposition"] == "attachment; filename=foo"

    def test_reject_newlines(self):
        """Every write path rejects values or options containing line breaks."""
        h = self.storage_class()

        for variation in "foo\nbar", "foo\r\nbar", "foo\rbar":
            with pytest.raises(ValueError):
                h["foo"] = variation
            with pytest.raises(ValueError):
                h.add("foo", variation)
            with pytest.raises(ValueError):
                h.add("foo", "test", option=variation)
            with pytest.raises(ValueError):
                h.set("foo", variation)
            with pytest.raises(ValueError):
                h.set("foo", "test", option=variation)

    def test_slicing(self):
        """Slice assignment replaces the stored header list."""
        # there's nothing wrong with these being native strings
        # Headers doesn't care about the data types
        h = self.storage_class()
        h.set("X-Foo-Poo", "bleh")
        h.set("Content-Type", "application/whocares")
        h.set("X-Forwarded-For", "192.168.0.123")
        h[:] = [(k, v) for k, v in h if k.startswith("X-")]
        assert list(h) == [("X-Foo-Poo", "bleh"), ("X-Forwarded-For", "192.168.0.123")]

    def test_bytes_operations(self):
        """Bytes keys/values are accepted and ``as_bytes`` returns bytes."""
        h = self.storage_class()
        h.set("X-Foo-Poo", "bleh")
        h.set("X-Whoops", b"\xff")
        h.set(b"X-Bytes", b"something")

        assert h.get("x-foo-poo", as_bytes=True) == b"bleh"
        assert h.get("x-whoops", as_bytes=True) == b"\xff"
        assert h.get("x-bytes") == "something"

    def test_extend(self):
        """``extend`` appends values and rejects more than one positional arg."""
        h = self.storage_class([("a", "0"), ("b", "1"), ("c", "2")])
        h.extend(ds.Headers([("a", "3"), ("a", "4")]))
        assert h.getlist("a") == ["0", "3", "4"]
        h.extend(b=["5", "6"])
        assert h.getlist("b") == ["1", "5", "6"]
        h.extend({"c": "7", "d": ["8", "9"]}, c="10")
        assert h.getlist("c") == ["2", "7", "10"]
        assert h.getlist("d") == ["8", "9"]

        with pytest.raises(TypeError):
            h.extend({"x": "x"}, {"x": "x"})

    def test_update(self):
        """``update`` replaces values and rejects more than one positional arg."""
        h = self.storage_class([("a", "0"), ("b", "1"), ("c", "2")])
        h.update(ds.Headers([("a", "3"), ("a", "4")]))
        assert h.getlist("a") == ["3", "4"]
        h.update(b=["5", "6"])
        assert h.getlist("b") == ["5", "6"]
        h.update({"c": "7", "d": ["8", "9"]})
        assert h.getlist("c") == ["7"]
        assert h.getlist("d") == ["8", "9"]
        # keyword arguments are applied after the positional mapping
        h.update({"c": "10"}, c="11")
        assert h.getlist("c") == ["11"]

        # this test is about update(); it previously called extend() by mistake
        with pytest.raises(TypeError):
            h.update({"x": "x"}, {"x": "x"})

    def test_setlist(self):
        """``setlist`` replaces in place, removes on empty, appends new keys."""
        h = self.storage_class([("a", "0"), ("b", "1"), ("c", "2")])
        h.setlist("b", ["3", "4"])
        assert h[1] == ("b", "3")
        assert h[-1] == ("b", "4")
        h.setlist("b", [])
        assert "b" not in h
        h.setlist("d", ["5"])
        assert h["d"] == "5"

    def test_setlistdefault(self):
        """``setlistdefault`` returns existing values or stores the default."""
        h = self.storage_class([("a", "0"), ("b", "1"), ("c", "2")])
        assert h.setlistdefault("a", ["3"]) == ["0"]
        assert h.setlistdefault("d", ["4", "5"]) == ["4", "5"]

    def test_to_wsgi_list(self):
        """``to_wsgi_list`` yields ``str`` pairs for ``str`` input."""
        h = self.storage_class()
        h.set("Key", "Value")
        for key, value in h.to_wsgi_list():
            assert key == "Key"
            assert value == "Value"

    def test_to_wsgi_list_bytes(self):
        """``to_wsgi_list`` yields ``str`` pairs for ``bytes`` input."""
        h = self.storage_class()
        h.set(b"Key", b"Value")
        for key, value in h.to_wsgi_list():
            assert key == "Key"
            assert value == "Value"

    def test_equality(self):
        """Equality ignores header-name case."""
        # test equality, given keys are case insensitive
        h1 = self.storage_class()
        h1.add("X-Foo", "foo")
        h1.add("X-Bar", "bah")
        h1.add("X-Bar", "humbug")

        h2 = self.storage_class()
        h2.add("x-foo", "foo")
        h2.add("x-bar", "bah")
        h2.add("x-bar", "humbug")

        assert h1 == h2
801
802
class TestEnvironHeaders:
    """``EnvironHeaders`` exposes the HTTP headers held in a WSGI environ."""

    storage_class = ds.EnvironHeaders

    def test_basic_interface(self):
        # Several WSGI servers convert headers very naively, so both the
        # HTTP_CONTENT_* and the CONTENT_* keys can show up in the environ;
        # each header must still be reported only once.
        broken_env = {
            "HTTP_CONTENT_TYPE": "text/html",
            "CONTENT_TYPE": "text/html",
            "HTTP_CONTENT_LENGTH": "0",
            "CONTENT_LENGTH": "0",
            "HTTP_ACCEPT": "*",
            "wsgi.version": (1, 0),
        }
        headers = self.storage_class(broken_env)
        assert headers
        assert len(headers) == 3
        expected = [
            ("Accept", "*"),
            ("Content-Length", "0"),
            ("Content-Type", "text/html"),
        ]
        assert sorted(headers) == expected

        # non-header environ keys alone yield an empty, falsy container
        only_meta = {"wsgi.version": (1, 0)}
        assert not self.storage_class(only_meta)
        assert len(self.storage_class(only_meta)) == 0
        assert 42 not in headers

    def test_skip_empty_special_vars(self):
        # Empty CONTENT_TYPE / CONTENT_LENGTH values are left out.
        cases = [
            (
                {"HTTP_X_FOO": "42", "CONTENT_TYPE": "", "CONTENT_LENGTH": ""},
                {"X-Foo": "42"},
            ),
            (
                {"HTTP_X_FOO": "42", "CONTENT_TYPE": "", "CONTENT_LENGTH": "0"},
                {"X-Foo": "42", "Content-Length": "0"},
            ),
        ]
        for env, expected in cases:
            assert dict(self.storage_class(env)) == expected

    def test_return_type_is_str(self):
        check_mark = "\xe2\x9c\x93"
        headers = self.storage_class({"HTTP_FOO": check_mark})
        assert headers["Foo"] == check_mark
        assert next(iter(headers)) == ("Foo", check_mark)

    def test_bytes_operations(self):
        h = self.storage_class({"HTTP_X_FOO": "\xff"})

        assert h.get("x-foo", as_bytes=True) == b"\xff"
        assert h.get("x-foo") == "\xff"
849
850
class TestHeaderSet:
    """Tests for ``HeaderSet``, a set-like collection of header values."""

    storage_class = ds.HeaderSet

    def test_basic_interface(self):
        """Add, case-insensitive lookup, discard, index and clear."""
        hs = self.storage_class()
        hs.add("foo")
        hs.add("bar")
        # membership and find() ignore case
        assert "Bar" in hs
        assert hs.find("foo") == 0
        assert hs.find("BAR") == 1
        # find() reports a missing value with a negative index
        assert hs.find("baz") < 0
        # discard() of a missing value does not raise
        hs.discard("missing")
        hs.discard("foo")
        # the remaining value moves to position 0 after the removal
        assert hs.find("foo") < 0
        assert hs.find("bar") == 0

        # unlike find(), index() raises for missing values
        with pytest.raises(IndexError):
            hs.index("missing")

        assert hs.index("bar") == 0
        assert hs
        hs.clear()
        assert not hs
874
875
class TestImmutableList:
    """``ImmutableList`` hashes like the equivalent tuple."""

    storage_class = ds.ImmutableList

    def test_list_hashable(self):
        values = (1, 2, 3, 4)
        frozen = self.storage_class(values)
        # same hash as the tuple it was built from ...
        assert hash(values) == hash(frozen)
        # ... yet it does not compare equal to that tuple
        assert values != frozen
884
885
def make_call_asserter(func=None):
    """Utility to assert a certain number of function calls.

    :param func: Additional callback for each function call.
    :return: A ``(asserter, wrapped)`` pair. ``asserter(count, msg=None)`` is
        a context manager that resets the call counter on entry. When the
        ``with`` body completes without raising, it asserts that ``wrapped``
        was called exactly ``count`` times, prefixing the failure message
        with ``msg`` if given. ``wrapped`` counts each call and forwards it
        (and its return value) to ``func`` when ``func`` is not ``None``.

    .. code-block:: python

        assert_calls, func = make_call_asserter()
        with assert_calls(2):
            func()
            func()
    """
    calls = 0

    @contextmanager
    def asserter(count, msg=None):
        nonlocal calls
        calls = 0
        yield
        # ``msg`` used to be accepted but silently dropped; report it along
        # with the actual numbers so failures are diagnosable.
        detail = f"expected {count} call(s), got {calls}"
        assert calls == count, detail if msg is None else f"{msg}: {detail}"

    def wrapped(*args, **kwargs):
        nonlocal calls
        calls += 1
        if func is not None:
            return func(*args, **kwargs)

    return asserter, wrapped
911
912
class TestCallbackDict:
    """``CallbackDict`` calls ``on_update`` for writes but not for reads."""

    storage_class = ds.CallbackDict

    def test_callback_dict_reads(self):
        assert_calls, on_update = make_call_asserter()
        data = self.storage_class(
            initial={"a": "foo", "b": "bar"}, on_update=on_update
        )
        with assert_calls(0, "callback triggered by read-only method"):
            # read-only methods
            data["a"]
            data.get("a")
            with pytest.raises(KeyError):
                data["x"]
            assert "a" in data
            list(iter(data))
            data.copy()
        with assert_calls(0, "callback triggered without modification"):
            # may write, but "z" is absent and "a" already exists
            data.pop("z", None)
            data.setdefault("a")

    def test_callback_dict_writes(self):
        assert_calls, on_update = make_call_asserter()
        data = self.storage_class(
            initial={"a": "foo", "b": "bar"}, on_update=on_update
        )
        with assert_calls(8, "callback not triggered by write method"):
            # each of these eight calls is expected to notify once
            data["z"] = 123
            data["z"] = 123  # re-assigning the same value must notify again
            del data["z"]
            data.pop("b", None)
            data.setdefault("x")
            data.popitem()
            data.update([])
            data.clear()
        with assert_calls(0, "callback triggered by failed del"):
            with pytest.raises(KeyError):
                del data["x"]
        with assert_calls(0, "callback triggered by failed pop"):
            with pytest.raises(KeyError):
                data.pop("x")
951
952
class TestCacheControl:
    """Tests for the request/response ``Cache-Control`` containers."""

    def test_repr(self):
        """The repr shows the class name and each directive with its value."""
        directives = [("max-age", "0"), ("private", "True")]
        expected = "<RequestCacheControl max-age='0' private='True'>"
        assert repr(ds.RequestCacheControl(directives)) == expected

    def test_set_none(self):
        """Assigning ``None`` to an unset ``no_cache`` still reads back ``None``."""
        cache_control = ds.ResponseCacheControl([("max-age", "0")])
        assert cache_control.no_cache is None

        cache_control.no_cache = None
        assert cache_control.no_cache is None
963
964
class TestContentSecurityPolicy:
    """Tests for building a ``Content-Security-Policy`` header."""

    @staticmethod
    def _directives(csp):
        """Split the serialized header into whitespace-stripped directives."""
        return [part.strip() for part in csp.to_header().split(";")]

    def test_construct(self):
        """Directives passed to the constructor are exposed and serialized."""
        csp = ds.ContentSecurityPolicy([("font-src", "'self'"), ("media-src", "*")])
        assert csp.font_src == "'self'"
        assert csp.media_src == "*"

        directives = self._directives(csp)
        assert "font-src 'self'" in directives
        assert "media-src *" in directives

    def test_properties(self):
        """Directives assigned through attributes end up in the header."""
        csp = ds.ContentSecurityPolicy()
        csp.default_src = "* 'self' quart.com"
        csp.img_src = "'none'"

        directives = self._directives(csp)
        for expected in ("default-src * 'self' quart.com", "img-src 'none'"):
            assert expected in directives
981
982
class TestAccept:
    """Tests for the generic ``Accept`` quality-value container."""

    storage_class = ds.Accept

    def test_accept_basic(self):
        """Indexing, quality lookup, membership, search, header and best_match."""
        accept = self.storage_class(
            [("tinker", 0), ("tailor", 0.333), ("soldier", 0.667), ("sailor", 1)]
        )
        # check __getitem__ on indices; entries come back highest quality first
        assert accept[3] == ("tinker", 0)
        assert accept[2] == ("tailor", 0.333)
        assert accept[1] == ("soldier", 0.667)
        # Must be ``==``: with a comma the tuple would become the assertion
        # message and only the truthiness of accept[0] would be checked.
        assert accept[0] == ("sailor", 1)
        # check __getitem__ on string
        assert accept["tinker"] == 0
        assert accept["tailor"] == 0.333
        assert accept["soldier"] == 0.667
        assert accept["sailor"] == 1
        assert accept["spy"] == 0
        # check quality method
        assert accept.quality("tinker") == 0
        assert accept.quality("tailor") == 0.333
        assert accept.quality("soldier") == 0.667
        assert accept.quality("sailor") == 1
        assert accept.quality("spy") == 0
        # check __contains__
        assert "sailor" in accept
        assert "spy" not in accept
        # check index method (raises for unknown values)
        assert accept.index("tinker") == 3
        assert accept.index("tailor") == 2
        assert accept.index("soldier") == 1
        assert accept.index("sailor") == 0
        with pytest.raises(ValueError):
            accept.index("spy")
        # check find method (returns -1 for unknown values)
        assert accept.find("tinker") == 3
        assert accept.find("tailor") == 2
        assert accept.find("soldier") == 1
        assert accept.find("sailor") == 0
        assert accept.find("spy") == -1
        # check to_header method
        assert accept.to_header() == "sailor,soldier;q=0.667,tailor;q=0.333,tinker;q=0"
        # check best_match method; a quality of 0 never matches
        assert (
            accept.best_match(["tinker", "tailor", "soldier", "sailor"], default=None)
            == "sailor"
        )
        assert (
            accept.best_match(["tinker", "tailor", "soldier"], default=None)
            == "soldier"
        )
        assert accept.best_match(["tinker", "tailor"], default=None) == "tailor"
        assert accept.best_match(["tinker"], default=None) is None
        assert accept.best_match(["tinker"], default="x") == "x"

    def test_accept_wildcard(self):
        """A ``*`` entry with quality 0 does not match unlisted values."""
        accept = self.storage_class([("*", 0), ("asterisk", 1)])
        assert "*" in accept
        assert accept.best_match(["asterisk", "star"], default=None) == "asterisk"
        assert accept.best_match(["star"], default=None) is None

    def test_accept_keep_order(self):
        """On equal quality, the first candidate in the caller's list wins."""
        accept = self.storage_class([("*", 1)])
        assert accept.best_match(["alice", "bob"]) == "alice"
        assert accept.best_match(["bob", "alice"]) == "bob"
        accept = self.storage_class([("alice", 1), ("bob", 1)])
        assert accept.best_match(["alice", "bob"]) == "alice"
        assert accept.best_match(["bob", "alice"]) == "bob"

    def test_accept_wildcard_specificity(self):
        """An explicit entry's quality overrides ``*`` for that value."""
        accept = self.storage_class([("asterisk", 0), ("star", 0.5), ("*", 1)])
        assert accept.best_match(["star", "asterisk"], default=None) == "star"
        assert accept.best_match(["asterisk", "star"], default=None) == "star"
        assert accept.best_match(["asterisk", "times"], default=None) == "times"
        assert accept.best_match(["asterisk"], default=None) is None

    def test_accept_equal_quality(self):
        """``best`` returns the first of several equally-ranked entries."""
        accept = self.storage_class([("a", 1), ("b", 1)])
        assert accept.best == "a"
1062
1063
class TestMIMEAccept:
    """``MIMEAccept.best_match`` with wildcards and media-type parameters."""

    @pytest.mark.parametrize(
        ("values", "matches", "default", "expect"),
        (
            # subtype wildcard
            ([("text/*", 1)], ["text/html"], None, "text/html"),
            ([("text/*", 1)], ["image/png"], "text/plain", "text/plain"),
            ([("text/*", 1)], ["image/png"], None, None),
            # full wildcard alongside explicit types
            (
                [("*/*", 1), ("text/html", 1)],
                ["image/png", "text/html"],
                None,
                "text/html",
            ),
            (
                [("*/*", 1), ("text/html", 1)],
                ["image/png", "text/plain"],
                None,
                "image/png",
            ),
            (
                [("*/*", 1), ("text/html", 1), ("image/*", 1)],
                ["image/png", "text/html"],
                None,
                "text/html",
            ),
            (
                [("*/*", 1), ("text/html", 1), ("image/*", 1)],
                ["text/plain", "image/png"],
                None,
                "image/png",
            ),
            # media-type parameters
            (
                [("text/html", 1), ("text/html; level=1", 1)],
                ["text/html;level=1"],
                None,
                "text/html;level=1",
            ),
        ),
    )
    def test_mime_accept(self, values, matches, default, expect):
        """The chosen candidate (or ``default``) matches the expectation."""
        best = ds.MIMEAccept(values).best_match(matches, default=default)
        assert best == expect
1107
1108
class TestLanguageAccept:
    """``LanguageAccept.best_match`` including its fallback behavior."""

    @pytest.mark.parametrize(
        ("values", "matches", "default", "expect"),
        [
            ([("en-us", 1)], ["en"], None, "en"),
            ([("en", 1)], ["en_US"], None, "en_US"),
            ([("en-GB", 1)], ["en-US"], None, None),
            ([("de_AT", 1), ("de", 0.9)], ["en"], None, None),
            ([("de_AT", 1), ("de", 0.9), ("en-US", 0.8)], ["de", "en"], None, "de"),
            ([("de_AT", 0.9), ("en-US", 1)], ["en"], None, "en"),
            ([("en-us", 1)], ["en-us"], None, "en-us"),
            ([("en-us", 1)], ["en-us", "en"], None, "en-us"),
            ([("en-GB", 1)], ["en-US", "en"], "en-US", "en"),
            ([("de_AT", 1)], ["en-US", "en"], "en-US", "en-US"),
            ([("aus-EN", 1)], ["aus"], None, "aus"),
            ([("aus", 1)], ["aus-EN"], None, "aus-EN"),
        ],
    )
    def test_best_match_fallback(self, values, matches, default, expect):
        """The chosen language (or ``default``) matches the expectation."""
        assert (
            ds.LanguageAccept(values).best_match(matches, default=default) == expect
        )
1131
1132
class TestFileStorage:
    """Tests for ``FileStorage`` attribute proxying, iteration and saving."""

    storage_class = ds.FileStorage

    def test_mimetype_always_lowercase(self):
        file_storage = self.storage_class(content_type="APPLICATION/JSON")
        assert file_storage.mimetype == "application/json"

    # Pass factories rather than stream instances. Parametrize values are built
    # once at collection time, so a shared stream would already be exhausted if
    # the same test item ran a second time (e.g. under a rerun plugin).
    @pytest.mark.parametrize(
        "data",
        [lambda: io.StringIO("one\ntwo"), lambda: io.BytesIO(b"one\ntwo")],
    )
    def test_bytes_proper_sentinel(self, data):
        # iterate over new lines and don't enter an infinite loop, for both
        # text and binary streams
        storage = self.storage_class(data())
        idx = -1

        for idx, _line in enumerate(storage):
            assert idx < 2

        assert idx == 1

    @pytest.mark.parametrize("stream", (tempfile.SpooledTemporaryFile, io.BytesIO))
    def test_proxy_can_access_stream_attrs(self, stream):
        """``SpooledTemporaryFile`` doesn't implement some of
        ``IOBase``. Ensure that ``FileStorage`` can still access the
        attributes from the backing file object.

        https://github.com/pallets/werkzeug/issues/1344
        https://github.com/python/cpython/pull/3249
        """
        # Close the backing stream when done instead of leaving it to the GC.
        with stream() as backing:
            file_storage = self.storage_class(stream=backing)

            for name in ("fileno", "writable", "readable", "seekable"):
                assert hasattr(file_storage, name)

    def test_save_to_pathlib_dst(self, tmp_path):
        """``save`` accepts a ``pathlib.Path`` destination."""
        src = tmp_path / "src.txt"
        src.write_text("test")
        dst = tmp_path / "dst.txt"

        # Close the source handle deterministically. Leaking it produced an
        # unraisable ResourceWarning that previously had to be filtered out.
        with src.open("rb") as src_file:
            storage = self.storage_class(src_file)
            storage.save(dst)

        assert dst.read_text() == "test"

    def test_save_to_bytes_io(self):
        """``save`` writes the full content into an in-memory buffer."""
        storage = self.storage_class(io.BytesIO(b"one\ntwo"))
        dst = io.BytesIO()
        storage.save(dst)
        assert dst.getvalue() == b"one\ntwo"

    def test_save_to_file(self, tmp_path):
        """``save`` writes the full content into an already-open file."""
        path = tmp_path / "file.data"
        storage = self.storage_class(io.BytesIO(b"one\ntwo"))
        with path.open("wb") as dst:
            storage.save(dst)
        with path.open("rb") as src:
            assert src.read() == b"one\ntwo"
1187
1188
@pytest.mark.parametrize("ranges", ([(0, 1), (-5, None)], [(5, None)]))
def test_range_to_header(ranges):
    """A ``Range`` serialized with ``to_header`` parses back to the same value."""
    # "bytes" is the unit real clients send; the previous "byes" typo meant
    # the common case was never exercised.
    header = ds.Range("bytes", ranges).to_header()
    parsed = http.parse_range_header(header)
    assert parsed.units == "bytes"
    assert parsed.ranges == ranges
1194
1195
@pytest.mark.parametrize(
    "ranges",
    (
        [(0, 0)],  # start equal to end
        [(None, 1)],  # missing start with an explicit end
        [(1, 0)],  # start greater than end
        [(0, 1), (-5, 10)],  # negative start combined with an explicit end
    ),
)
def test_range_validates_ranges(ranges):
    """Constructing a ``Range`` from any of these range lists raises ``ValueError``."""
    with pytest.raises(ValueError):
        ds.Range("bytes", ranges)
1202