"""Tests for the container types in ``werkzeug.datastructures``."""
import io
import pickle
import tempfile
import typing as t
from contextlib import contextmanager
from copy import copy
from copy import deepcopy

import pytest

from werkzeug import datastructures as ds
from werkzeug import http
from werkzeug.exceptions import BadRequestKeyError


class TestNativeItermethods:
    """Sanity check of a dict-like object whose view methods return iterators."""

    def test_basic(self):
        class StupidDict:
            def keys(self, multi=1):
                return iter(["a", "b", "c"] * multi)

            def values(self, multi=1):
                return iter([1, 2, 3] * multi)

            def items(self, multi=1):
                return iter(
                    zip(iter(self.keys(multi=multi)), iter(self.values(multi=multi)))
                )

        d = StupidDict()
        expected_keys = ["a", "b", "c"]
        expected_values = [1, 2, 3]
        expected_items = list(zip(expected_keys, expected_values))

        assert list(d.keys()) == expected_keys
        assert list(d.values()) == expected_values
        assert list(d.items()) == expected_items

        assert list(d.keys(2)) == expected_keys * 2
        assert list(d.values(2)) == expected_values * 2
        assert list(d.items(2)) == expected_items * 2


class _MutableMultiDictTests:
    """Shared tests for mutable multi dicts.

    Concrete ``Test*`` subclasses set ``storage_class`` to the type under test.
    """

    storage_class: t.Type["ds.MultiDict"]

    def test_pickle(self):
        cls = self.storage_class

        def create_instance(module=None):
            if module is None:
                d = cls()
            else:
                # Temporarily pretend the class lives in another module; always
                # restore the real value, even if construction fails.
                old = cls.__module__
                cls.__module__ = module
                try:
                    d = cls()
                finally:
                    cls.__module__ = old
            d.setlist(b"foo", [1, 2, 3, 4])
            d.setlist(b"bar", b"foo bar baz".split())
            return d

        for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
            d = create_instance()
            s = pickle.dumps(d, protocol)
            ud = pickle.loads(s)
            assert type(ud) is type(d)
            assert ud == d
            alternative = pickle.dumps(create_instance("werkzeug"), protocol)
            assert pickle.loads(alternative) == d
            ud[b"newkey"] = b"bla"
            assert ud != d

    def test_multidict_dict_interop(self):
        # https://github.com/pallets/werkzeug/pull/2043
        md = self.storage_class([("a", 1), ("a", 2)])
        assert dict(md)["a"] != [1, 2]
        assert dict(md)["a"] == 1
        assert dict(md) == {**md} == {"a": 1}

    def test_basic_interface(self):
        md = self.storage_class()
        assert isinstance(md, dict)

        mapping = [
            ("a", 1),
            ("b", 2),
            ("a", 2),
            ("d", 3),
            ("a", 1),
            ("a", 3),
            ("d", 4),
            ("c", 3),
        ]
        md = self.storage_class(mapping)

        # simple getitem gives the first value
        assert md["a"] == 1
        assert md["c"] == 3
        with pytest.raises(KeyError):
            md["e"]
        assert md.get("a") == 1

        # list getitem
        assert md.getlist("a") == [1, 2, 1, 3]
        assert md.getlist("d") == [3, 4]
        # do not raise if key not found
        assert md.getlist("x") == []

        # simple setitem overwrites all values
        md["a"] = 42
        assert md.getlist("a") == [42]

        # list setitem
        md.setlist("a", [1, 2, 3])
        assert md["a"] == 1
        assert md.getlist("a") == [1, 2, 3]

        # verify that it does not change original lists
        l1 = [1, 2, 3]
        md.setlist("a", l1)
        del l1[:]
        assert md["a"] == 1

        # setdefault, setlistdefault
        assert md.setdefault("u", 23) == 23
        assert md.getlist("u") == [23]
        del md["u"]

        md.setlist("u", [-1, -2])

        # delitem
        del md["u"]
        with pytest.raises(KeyError):
            md["u"]
        del md["d"]
        assert md.getlist("d") == []

        # keys, values, items, lists
        assert list(sorted(md.keys())) == ["a", "b", "c"]
        assert list(sorted(md.keys())) == ["a", "b", "c"]

        assert list(sorted(md.values())) == [1, 2, 3]
        assert list(sorted(md.values())) == [1, 2, 3]

        assert list(sorted(md.items())) == [("a", 1), ("b", 2), ("c", 3)]
        assert list(sorted(md.items(multi=True))) == [
            ("a", 1),
            ("a", 2),
            ("a", 3),
            ("b", 2),
            ("c", 3),
        ]
        assert list(sorted(md.items())) == [("a", 1), ("b", 2), ("c", 3)]
        assert list(sorted(md.items(multi=True))) == [
            ("a", 1),
            ("a", 2),
            ("a", 3),
            ("b", 2),
            ("c", 3),
        ]

        assert list(sorted(md.lists())) == [("a", [1, 2, 3]), ("b", [2]), ("c", [3])]
        assert list(sorted(md.lists())) == [("a", [1, 2, 3]), ("b", [2]), ("c", [3])]

        # copy method
        c = md.copy()
        assert c["a"] == 1
        assert c.getlist("a") == [1, 2, 3]

        # copy method 2
        c = copy(md)
        assert c["a"] == 1
        assert c.getlist("a") == [1, 2, 3]

        # deepcopy method
        c = md.deepcopy()
        assert c["a"] == 1
        assert c.getlist("a") == [1, 2, 3]

        # deepcopy method 2
        c = deepcopy(md)
        assert c["a"] == 1
        assert c.getlist("a") == [1, 2, 3]

        # update with a multidict
        od = self.storage_class([("a", 4), ("a", 5), ("y", 0)])
        md.update(od)
        assert md.getlist("a") == [1, 2, 3, 4, 5]
        assert md.getlist("y") == [0]

        # update with a regular dict
        md = c
        od = {"a": 4, "y": 0}
        md.update(od)
        assert md.getlist("a") == [1, 2, 3, 4]
        assert md.getlist("y") == [0]

        # pop, poplist, popitem, popitemlist
        assert md.pop("y") == 0
        assert "y" not in md
        assert md.poplist("a") == [1, 2, 3, 4]
        assert "a" not in md
        assert md.poplist("missing") == []

        # remaining: b=2, c=3
        popped = md.popitem()
        assert popped in [("b", 2), ("c", 3)]
        popped = md.popitemlist()
        assert popped in [("b", [2]), ("c", [3])]

        # type conversion
        md = self.storage_class({"a": "4", "b": ["2", "3"]})
        assert md.get("a", type=int) == 4
        assert md.getlist("b", type=int) == [2, 3]

        # repr
        md = self.storage_class([("a", 1), ("a", 2), ("b", 3)])
        assert "('a', 1)" in repr(md)
        assert "('a', 2)" in repr(md)
        assert "('b', 3)" in repr(md)

        # add and getlist
        md.add("c", "42")
        md.add("c", "23")
        assert md.getlist("c") == ["42", "23"]
        md.add("c", "blah")
        assert md.getlist("c", type=int) == [42, 23]

        # setdefault
        md = self.storage_class()
        md.setdefault("x", []).append(42)
        md.setdefault("x", []).append(23)
        assert md["x"] == [42, 23]

        # to dict
        md = self.storage_class()
        md["foo"] = 42
        md.add("bar", 1)
        md.add("bar", 2)
        assert md.to_dict() == {"foo": 42, "bar": 1}
        assert md.to_dict(flat=False) == {"foo": [42], "bar": [1, 2]}

        # popitem from empty dict
        with pytest.raises(KeyError):
            self.storage_class().popitem()

        with pytest.raises(KeyError):
            self.storage_class().popitemlist()

        # key errors are of a special type
        with pytest.raises(BadRequestKeyError):
            self.storage_class()[42]

        # setlist works
        md = self.storage_class()
        md["foo"] = 42
        md.setlist("foo", [1, 2])
        assert md.getlist("foo") == [1, 2]


class _ImmutableDictTests:
    """Shared tests for immutable dicts.

    Concrete ``Test*`` subclasses set ``storage_class`` to the type under test.
    """

    storage_class: t.Type[dict]

    def test_follows_dict_interface(self):
        cls = self.storage_class

        data = {"foo": 1, "bar": 2, "baz": 3}
        d = cls(data)

        assert d["foo"] == 1
        assert d["bar"] == 2
        assert d["baz"] == 3
        assert sorted(d.keys()) == ["bar", "baz", "foo"]
        assert "foo" in d
        assert "foox" not in d
        assert len(d) == 3

    def test_copies_are_mutable(self):
        cls = self.storage_class
        immutable = cls({"a": 1})
        with pytest.raises(TypeError):
            immutable.pop("a")

        mutable = immutable.copy()
        mutable.pop("a")
        assert "a" in immutable
        assert mutable is not immutable
        assert copy(immutable) is immutable

    def test_dict_is_hashable(self):
        cls = self.storage_class
        immutable = cls({"a": 1, "b": 2})
        immutable2 = cls({"a": 2, "b": 2})
        x = {immutable}
        assert immutable in x
        assert immutable2 not in x
        x.discard(immutable)
        assert immutable not in x
        assert immutable2 not in x
        x.add(immutable2)
        assert immutable not in x
        assert immutable2 in x
        x.add(immutable)
        assert immutable in x
        assert immutable2 in x


class TestImmutableTypeConversionDict(_ImmutableDictTests):
    storage_class = ds.ImmutableTypeConversionDict


class TestImmutableMultiDict(_ImmutableDictTests):
    storage_class = ds.ImmutableMultiDict

    def test_multidict_is_hashable(self):
        cls = self.storage_class
        immutable = cls({"a": [1, 2], "b": 2})
        immutable2 = cls({"a": [1], "b": 2})
        x = {immutable}
        assert immutable in x
        assert immutable2 not in x
        x.discard(immutable)
        assert immutable not in x
        assert immutable2 not in x
        x.add(immutable2)
        assert immutable not in x
        assert immutable2 in x
        x.add(immutable)
        assert immutable in x
        assert immutable2 in x


class TestImmutableDict(_ImmutableDictTests):
    storage_class = ds.ImmutableDict


class TestImmutableOrderedMultiDict(_ImmutableDictTests):
    storage_class = ds.ImmutableOrderedMultiDict

    def test_ordered_multidict_is_hashable(self):
        a = self.storage_class([("a", 1), ("b", 1), ("a", 2)])
        b = self.storage_class([("a", 1), ("a", 2), ("b", 1)])
        assert hash(a) != hash(b)


class TestMultiDict(_MutableMultiDictTests):
    storage_class = ds.MultiDict

    def test_multidict_pop(self):
        def make_d():
            return self.storage_class({"foo": [1, 2, 3, 4]})

        d = make_d()
        assert d.pop("foo") == 1
        assert not d
        d = make_d()
        assert d.pop("foo", 32) == 1
        assert not d
        d = make_d()
        assert d.pop("foos", 32) == 32
        assert d

        with pytest.raises(KeyError):
            d.pop("foos")

    def test_multidict_pop_raise_badrequestkeyerror_for_empty_list_value(self):
        mapping = [("a", "b"), ("a", "c")]
        md = self.storage_class(mapping)

        md.setlistdefault("empty", [])

        # Assert the specific exception type named by the test.
        with pytest.raises(BadRequestKeyError):
            md.pop("empty")

    def test_multidict_popitem_raise_badrequestkeyerror_for_empty_list_value(self):
        mapping = []
        md = self.storage_class(mapping)

        md.setlistdefault("empty", [])

        with pytest.raises(BadRequestKeyError):
            md.popitem()

    def test_setlistdefault(self):
        md = self.storage_class()
        assert md.setlistdefault("u", [-1, -2]) == [-1, -2]
        assert md.getlist("u") == [-1, -2]
        assert md["u"] == -1

    def test_iter_interfaces(self):
        mapping = [
            ("a", 1),
            ("b", 2),
            ("a", 2),
            ("d", 3),
            ("a", 1),
            ("a", 3),
            ("d", 4),
            ("c", 3),
        ]
        md = self.storage_class(mapping)
        assert list(zip(md.keys(), md.listvalues())) == list(md.lists())
        assert list(zip(md, md.listvalues())) == list(md.lists())
        assert list(zip(md.keys(), md.listvalues())) == list(md.lists())

    def test_getitem_raise_badrequestkeyerror_for_empty_list_value(self):
        mapping = [("a", "b"), ("a", "c")]
        md = self.storage_class(mapping)

        md.setlistdefault("empty", [])

        # Assert the specific exception type named by the test.
        with pytest.raises(BadRequestKeyError):
            md["empty"]


class TestOrderedMultiDict(_MutableMultiDictTests):
    storage_class = ds.OrderedMultiDict

    def test_ordered_interface(self):
        cls = self.storage_class

        d = cls()
        assert not d
        d.add("foo", "bar")
        assert len(d) == 1
        d.add("foo", "baz")
        assert len(d) == 1
        assert list(d.items()) == [("foo", "bar")]
        assert list(d) == ["foo"]
        assert list(d.items(multi=True)) == [("foo", "bar"), ("foo", "baz")]
        del d["foo"]
        assert not d
        assert len(d) == 0
        assert list(d) == []

        d.update([("foo", 1), ("foo", 2), ("bar", 42)])
        d.add("foo", 3)
        assert d.getlist("foo") == [1, 2, 3]
        assert d.getlist("bar") == [42]
        assert list(d.items()) == [("foo", 1), ("bar", 42)]

        expected = ["foo", "bar"]

        assert list(d.keys()) == expected
        assert list(d) == expected
        assert list(d.keys()) == expected

        assert list(d.items(multi=True)) == [
            ("foo", 1),
            ("foo", 2),
            ("bar", 42),
            ("foo", 3),
        ]
        assert len(d) == 2

        assert d.pop("foo") == 1
        assert d.pop("blafasel", None) is None
        assert d.pop("blafasel", 42) == 42
        assert len(d) == 1
        assert d.poplist("bar") == [42]
        assert not d

        assert d.get("missingkey") is None

        d.add("foo", 42)
        d.add("foo", 23)
        d.add("bar", 2)
        d.add("foo", 42)
        assert d == ds.MultiDict(d)
        # Named ``clone`` rather than ``id`` to avoid shadowing the builtin.
        clone = self.storage_class(d)
        assert d == clone
        d.add("foo", 2)
        assert d != clone

        d.update({"blah": [1, 2, 3]})
        assert d["blah"] == 1
        assert d.getlist("blah") == [1, 2, 3]

        # setlist works
        d = self.storage_class()
        d["foo"] = 42
        d.setlist("foo", [1, 2])
        assert d.getlist("foo") == [1, 2]
        with pytest.raises(BadRequestKeyError):
            d.pop("missing")

        with pytest.raises(BadRequestKeyError):
            d["missing"]

        # popping
        d = self.storage_class()
        d.add("foo", 23)
        d.add("foo", 42)
        d.add("foo", 1)
        assert d.popitem() == ("foo", 23)
        with pytest.raises(BadRequestKeyError):
            d.popitem()
        assert not d

        d.add("foo", 23)
        d.add("foo", 42)
        d.add("foo", 1)
        assert d.popitemlist() == ("foo", [23, 42, 1])

        with pytest.raises(BadRequestKeyError):
            d.popitemlist()

        # Unhashable
        d = self.storage_class()
        d.add("foo", 23)
        pytest.raises(TypeError, hash, d)

    def test_iterables(self):
        a = ds.MultiDict((("key_a", "value_a"),))
        b = ds.MultiDict((("key_b", "value_b"),))
        ab = ds.CombinedMultiDict((a, b))

        assert sorted(ab.lists()) == [("key_a", ["value_a"]), ("key_b", ["value_b"])]
        assert sorted(ab.listvalues()) == [["value_a"], ["value_b"]]
        assert sorted(ab.keys()) == ["key_a", "key_b"]

        assert sorted(ab.lists()) == [("key_a", ["value_a"]), ("key_b", ["value_b"])]
        assert sorted(ab.listvalues()) == [["value_a"], ["value_b"]]
        assert sorted(ab.keys()) == ["key_a", "key_b"]

    def test_get_description(self):
        data = ds.OrderedMultiDict()

        with pytest.raises(BadRequestKeyError) as exc_info:
            data["baz"]

        assert "baz" not in exc_info.value.get_description()
        exc_info.value.show_exception = True
        assert "baz" in exc_info.value.get_description()

        with pytest.raises(BadRequestKeyError) as exc_info:
            data.pop("baz")

        exc_info.value.show_exception = True
        assert "baz" in exc_info.value.get_description()
        exc_info.value.args = ()
        assert "baz" not in exc_info.value.get_description()


class TestTypeConversionDict:
    storage_class = ds.TypeConversionDict

    def test_value_conversion(self):
        d = self.storage_class(foo="1")
        assert d.get("foo", type=int) == 1

    def test_return_default_when_conversion_is_not_possible(self):
        d = self.storage_class(foo="bar")
        assert d.get("foo", default=-1, type=int) == -1

    def test_propagate_exceptions_in_conversion(self):
        d = self.storage_class(foo="bar")
        switch = {"a": 1}
        with pytest.raises(KeyError):
            d.get("foo", type=lambda x: switch[x])


class TestCombinedMultiDict:
    storage_class = ds.CombinedMultiDict

    def test_basic_interface(self):
        d1 = ds.MultiDict([("foo", "1")])
        d2 = ds.MultiDict([("bar", "2"), ("bar", "3")])
        d = self.storage_class([d1, d2])

        # lookup
        assert d["foo"] == "1"
        assert d["bar"] == "2"
        assert d.getlist("bar") == ["2", "3"]

        assert sorted(d.items()) == [("bar", "2"), ("foo", "1")]
        assert sorted(d.items(multi=True)) == [("bar", "2"), ("bar", "3"), ("foo", "1")]
        assert "missingkey" not in d
        assert "foo" in d

        # type lookup
        assert d.get("foo", type=int) == 1
        assert d.getlist("bar", type=int) == [2, 3]

        # get key errors for missing stuff
        with pytest.raises(KeyError):
            d["missing"]

        # make sure that they are immutable
        with pytest.raises(TypeError):
            d["foo"] = "blub"

        # copies are mutable
        d = d.copy()
        d["foo"] = "blub"

        # make sure lists merges
        md1 = ds.MultiDict((("foo", "bar"), ("foo", "baz")))
        md2 = ds.MultiDict((("foo", "blafasel"),))
        x = self.storage_class((md1, md2))
        assert list(x.lists()) == [("foo", ["bar", "baz", "blafasel"])]

        # make sure dicts are created properly
        assert x.to_dict() == {"foo": "bar"}
        assert x.to_dict(flat=False) == {"foo": ["bar", "baz", "blafasel"]}

    def test_length(self):
        d1 = ds.MultiDict([("foo", "1")])
        d2 = ds.MultiDict([("bar", "2")])
        assert len(d1) == len(d2) == 1
        d = self.storage_class([d1, d2])
        assert len(d) == 2
        d1.clear()
        assert len(d1) == 0
        assert len(d) == 1


class TestHeaders:
    storage_class = ds.Headers

    def test_basic_interface(self):
        headers = self.storage_class()
        headers.add("Content-Type", "text/plain")
        headers.add("X-Foo", "bar")
        assert "x-Foo" in headers
        assert "Content-type" in headers

        headers["Content-Type"] = "foo/bar"
        assert headers["Content-Type"] == "foo/bar"
        assert len(headers.getlist("Content-Type")) == 1

        # list conversion
        assert headers.to_wsgi_list() == [("Content-Type", "foo/bar"), ("X-Foo", "bar")]
        assert str(headers) == "Content-Type: foo/bar\r\nX-Foo: bar\r\n\r\n"
        assert str(self.storage_class()) == "\r\n"

        # extended add
        headers.add("Content-Disposition", "attachment", filename="foo")
        assert headers["Content-Disposition"] == "attachment; filename=foo"

        headers.add("x", "y", z='"')
        assert headers["x"] == r'y; z="\""'

    def test_defaults_and_conversion(self):
        # defaults
        headers = self.storage_class(
            [
                ("Content-Type", "text/plain"),
                ("X-Foo", "bar"),
                ("X-Bar", "1"),
                ("X-Bar", "2"),
            ]
        )
        assert headers.getlist("x-bar") == ["1", "2"]
        assert headers.get("x-Bar") == "1"
        assert headers.get("Content-Type") == "text/plain"

        assert headers.setdefault("X-Foo", "nope") == "bar"
        assert headers.setdefault("X-Bar", "nope") == "1"
        assert headers.setdefault("X-Baz", "quux") == "quux"
        assert headers.setdefault("X-Baz", "nope") == "quux"
        headers.pop("X-Baz")

        # type conversion
        assert headers.get("x-bar", type=int) == 1
        assert headers.getlist("x-bar", type=int) == [1, 2]

        # list like operations
        assert headers[0] == ("Content-Type", "text/plain")
        assert headers[:1] == self.storage_class([("Content-Type", "text/plain")])
        del headers[:2]
        del headers[-1]
        assert headers == self.storage_class([("X-Bar", "1")])

    def test_copying(self):
        a = self.storage_class([("foo", "bar")])
        b = a.copy()
        a.add("foo", "baz")
        assert a.getlist("foo") == ["bar", "baz"]
        assert b.getlist("foo") == ["bar"]

    def test_popping(self):
        headers = self.storage_class([("a", 1)])
        assert headers.pop("a") == 1
        assert headers.pop("b", 2) == 2

        with pytest.raises(KeyError):
            headers.pop("c")

    def test_set_arguments(self):
        a = self.storage_class()
        a.set("Content-Disposition", "useless")
        a.set("Content-Disposition", "attachment", filename="foo")
        assert a["Content-Disposition"] == "attachment; filename=foo"

    def test_reject_newlines(self):
        h = self.storage_class()

        for variation in "foo\nbar", "foo\r\nbar", "foo\rbar":
            with pytest.raises(ValueError):
                h["foo"] = variation
            with pytest.raises(ValueError):
                h.add("foo", variation)
            with pytest.raises(ValueError):
                h.add("foo", "test", option=variation)
            with pytest.raises(ValueError):
                h.set("foo", variation)
            with pytest.raises(ValueError):
                h.set("foo", "test", option=variation)

    def test_slicing(self):
        # there's nothing wrong with these being native strings
        # Headers doesn't care about the data types
        h = self.storage_class()
        h.set("X-Foo-Poo", "bleh")
        h.set("Content-Type", "application/whocares")
        h.set("X-Forwarded-For", "192.168.0.123")
        h[:] = [(k, v) for k, v in h if k.startswith("X-")]
        assert list(h) == [("X-Foo-Poo", "bleh"), ("X-Forwarded-For", "192.168.0.123")]

    def test_bytes_operations(self):
        h = self.storage_class()
        h.set("X-Foo-Poo", "bleh")
        h.set("X-Whoops", b"\xff")
        h.set(b"X-Bytes", b"something")

        assert h.get("x-foo-poo", as_bytes=True) == b"bleh"
        assert h.get("x-whoops", as_bytes=True) == b"\xff"
        assert h.get("x-bytes") == "something"

    def test_extend(self):
        h = self.storage_class([("a", "0"), ("b", "1"), ("c", "2")])
        h.extend(ds.Headers([("a", "3"), ("a", "4")]))
        assert h.getlist("a") == ["0", "3", "4"]
        h.extend(b=["5", "6"])
        assert h.getlist("b") == ["1", "5", "6"]
        h.extend({"c": "7", "d": ["8", "9"]}, c="10")
        assert h.getlist("c") == ["2", "7", "10"]
        assert h.getlist("d") == ["8", "9"]

        with pytest.raises(TypeError):
            h.extend({"x": "x"}, {"x": "x"})

    def test_update(self):
        h = self.storage_class([("a", "0"), ("b", "1"), ("c", "2")])
        h.update(ds.Headers([("a", "3"), ("a", "4")]))
        assert h.getlist("a") == ["3", "4"]
        h.update(b=["5", "6"])
        assert h.getlist("b") == ["5", "6"]
        h.update({"c": "7", "d": ["8", "9"]})
        assert h.getlist("c") == ["7"]
        assert h.getlist("d") == ["8", "9"]
        h.update({"c": "10"}, c="11")
        assert h.getlist("c") == ["11"]

        # Exercise ``update`` itself here; ``extend`` is covered by test_extend.
        with pytest.raises(TypeError):
            h.update({"x": "x"}, {"x": "x"})

    def test_setlist(self):
        h = self.storage_class([("a", "0"), ("b", "1"), ("c", "2")])
        h.setlist("b", ["3", "4"])
        assert h[1] == ("b", "3")
        assert h[-1] == ("b", "4")
        h.setlist("b", [])
        assert "b" not in h
        h.setlist("d", ["5"])
        assert h["d"] == "5"

    def test_setlistdefault(self):
        h = self.storage_class([("a", "0"), ("b", "1"), ("c", "2")])
        assert h.setlistdefault("a", ["3"]) == ["0"]
        assert h.setlistdefault("d", ["4", "5"]) == ["4", "5"]

    def test_to_wsgi_list(self):
        h = self.storage_class()
        h.set("Key", "Value")
        for key, value in h.to_wsgi_list():
            assert key == "Key"
            assert value == "Value"

    def test_to_wsgi_list_bytes(self):
        h = self.storage_class()
        h.set(b"Key", b"Value")
        for key, value in h.to_wsgi_list():
            assert key == "Key"
            assert value == "Value"

    def test_equality(self):
        # test equality, given keys are case insensitive
        h1 = self.storage_class()
        h1.add("X-Foo", "foo")
        h1.add("X-Bar", "bah")
        h1.add("X-Bar", "humbug")

        h2 = self.storage_class()
        h2.add("x-foo", "foo")
        h2.add("x-bar", "bah")
        h2.add("x-bar", "humbug")

        assert h1 == h2


class TestEnvironHeaders:
    storage_class = ds.EnvironHeaders

    def test_basic_interface(self):
        # this happens in multiple WSGI servers because they
        # use a very naive way to convert the headers;
        broken_env = {
            "HTTP_CONTENT_TYPE": "text/html",
            "CONTENT_TYPE": "text/html",
            "HTTP_CONTENT_LENGTH": "0",
            "CONTENT_LENGTH": "0",
            "HTTP_ACCEPT": "*",
            "wsgi.version": (1, 0),
        }
        headers = self.storage_class(broken_env)
        assert headers
        assert len(headers) == 3
        assert sorted(headers) == [
            ("Accept", "*"),
            ("Content-Length", "0"),
            ("Content-Type", "text/html"),
        ]
        assert not self.storage_class({"wsgi.version": (1, 0)})
        assert len(self.storage_class({"wsgi.version": (1, 0)})) == 0
        assert 42 not in headers

    def test_skip_empty_special_vars(self):
        env = {"HTTP_X_FOO": "42", "CONTENT_TYPE": "", "CONTENT_LENGTH": ""}
        headers = self.storage_class(env)
        assert dict(headers) == {"X-Foo": "42"}

        env = {"HTTP_X_FOO": "42", "CONTENT_TYPE": "", "CONTENT_LENGTH": "0"}
        headers = self.storage_class(env)
        assert dict(headers) == {"X-Foo": "42", "Content-Length": "0"}

    def test_return_type_is_str(self):
        headers = self.storage_class({"HTTP_FOO": "\xe2\x9c\x93"})
        assert headers["Foo"] == "\xe2\x9c\x93"
        assert next(iter(headers)) == ("Foo", "\xe2\x9c\x93")

    def test_bytes_operations(self):
        foo_val = "\xff"
        h = self.storage_class({"HTTP_X_FOO": foo_val})

        assert h.get("x-foo", as_bytes=True) == b"\xff"
        assert h.get("x-foo") == "\xff"


class TestHeaderSet:
    storage_class = ds.HeaderSet

    def test_basic_interface(self):
        hs = self.storage_class()
        hs.add("foo")
        hs.add("bar")
        assert "Bar" in hs
        assert hs.find("foo") == 0
        assert hs.find("BAR") == 1
        assert hs.find("baz") < 0
        hs.discard("missing")
        hs.discard("foo")
        assert hs.find("foo") < 0
        assert hs.find("bar") == 0

        with pytest.raises(IndexError):
            hs.index("missing")

        assert hs.index("bar") == 0
        assert hs
        hs.clear()
        assert not hs


class TestImmutableList:
    storage_class = ds.ImmutableList

    def test_list_hashable(self):
        data = (1, 2, 3, 4)
        store = self.storage_class(data)
        assert hash(data) == hash(store)
        assert data != store


def make_call_asserter(func=None):
    """Utility to assert a certain number of function calls.

    :param func: Additional callback for each function call.
    :return: A ``(asserter, wrapped)`` pair. ``wrapped`` counts every call
        made to it and forwards to ``func`` if one was given.
        ``asserter(count, msg=None)`` is a context manager that resets the
        counter on entry and asserts on exit that ``wrapped`` was called
        exactly ``count`` times, using ``msg`` as the failure message.

    .. code-block:: python

        assert_calls, func = make_call_asserter()
        with assert_calls(2):
            func()
            func()
    """
    # A one-element list so the nested functions can mutate the counter.
    calls = [0]

    @contextmanager
    def asserter(count, msg=None):
        calls[0] = 0
        yield
        actual = calls[0]
        # Report the caller's explanation instead of silently dropping it.
        assert actual == count, msg or f"expected {count} call(s), got {actual}"

    def wrapped(*args, **kwargs):
        calls[0] += 1
        if func is not None:
            return func(*args, **kwargs)

    return asserter, wrapped


class TestCallbackDict:
    storage_class = ds.CallbackDict

    def test_callback_dict_reads(self):
        assert_calls, func = make_call_asserter()
        initial = {"a": "foo", "b": "bar"}
        dct = self.storage_class(initial=initial, on_update=func)
        with assert_calls(0, "callback triggered by read-only method"):
            # read-only methods
            dct["a"]
            dct.get("a")
            pytest.raises(KeyError, lambda: dct["x"])
            assert "a" in dct
            list(iter(dct))
            dct.copy()
        with assert_calls(0, "callback triggered without modification"):
            # methods that may write but don't
            dct.pop("z", None)
            dct.setdefault("a")

    def test_callback_dict_writes(self):
        assert_calls, func = make_call_asserter()
        initial = {"a": "foo", "b": "bar"}
        dct = self.storage_class(initial=initial, on_update=func)
        with assert_calls(8, "callback not triggered by write method"):
            # always-write methods
            dct["z"] = 123
            dct["z"] = 123  # must trigger again
            del dct["z"]
            dct.pop("b", None)
            dct.setdefault("x")
            dct.popitem()
            dct.update([])
            dct.clear()
        with assert_calls(0, "callback triggered by failed del"):
            pytest.raises(KeyError, lambda: dct.__delitem__("x"))
        with assert_calls(0, "callback triggered by failed pop"):
            pytest.raises(KeyError, lambda: dct.pop("x"))


class TestCacheControl:
    def test_repr(self):
        cc = ds.RequestCacheControl([("max-age", "0"), ("private", "True")])
        assert repr(cc) == "<RequestCacheControl max-age='0' private='True'>"

    def test_set_none(self):
        cc = ds.ResponseCacheControl([("max-age", "0")])
        assert cc.no_cache is None
        cc.no_cache = None
        assert cc.no_cache is None


class TestContentSecurityPolicy:
    def test_construct(self):
        csp = ds.ContentSecurityPolicy([("font-src", "'self'"), ("media-src", "*")])
        assert csp.font_src == "'self'"
        assert csp.media_src == "*"
        policies = [policy.strip() for policy in csp.to_header().split(";")]
        assert "font-src 'self'" in policies
        assert "media-src *" in policies

    def test_properties(self):
        csp = ds.ContentSecurityPolicy()
        csp.default_src = "* 'self' quart.com"
        csp.img_src = "'none'"
        policies = [policy.strip() for policy in csp.to_header().split(";")]
        assert "default-src * 'self' quart.com" in policies
        assert "img-src 'none'" in policies


class TestAccept:
    storage_class = ds.Accept

    def test_accept_basic(self):
        accept = self.storage_class(
            [("tinker", 0), ("tailor", 0.333), ("soldier", 0.667), ("sailor", 1)]
        )
        # check __getitem__ on indices
        assert accept[3] == ("tinker", 0)
        assert accept[2] == ("tailor", 0.333)
        assert accept[1] == ("soldier", 0.667)
        # Compare with ``==``; ``assert value, other`` only checks truthiness.
        assert accept[0] == ("sailor", 1)
        # check __getitem__ on string
        assert accept["tinker"] == 0
        assert accept["tailor"] == 0.333
        assert accept["soldier"] == 0.667
        assert accept["sailor"] == 1
        assert accept["spy"] == 0
        # check quality method
        assert accept.quality("tinker") == 0
        assert accept.quality("tailor") == 0.333
        assert accept.quality("soldier") == 0.667
        assert accept.quality("sailor") == 1
        assert accept.quality("spy") == 0
        # check __contains__
        assert "sailor" in accept
        assert "spy" not in accept
        # check index method
        assert accept.index("tinker") == 3
        assert accept.index("tailor") == 2
        assert accept.index("soldier") == 1
        assert accept.index("sailor") == 0
        with pytest.raises(ValueError):
            accept.index("spy")
        # check find method
        assert accept.find("tinker") == 3
        assert accept.find("tailor") == 2
        assert accept.find("soldier") == 1
        assert accept.find("sailor") == 0
        assert accept.find("spy") == -1
        # check to_header method
        assert accept.to_header() == "sailor,soldier;q=0.667,tailor;q=0.333,tinker;q=0"
        # check best_match method
        assert (
            accept.best_match(["tinker", "tailor", "soldier", "sailor"], default=None)
            == "sailor"
        )
        assert (
            accept.best_match(["tinker", "tailor", "soldier"], default=None)
            == "soldier"
        )
        assert accept.best_match(["tinker", "tailor"], default=None) == "tailor"
        assert accept.best_match(["tinker"], default=None) is None
        assert accept.best_match(["tinker"], default="x") == "x"

    def test_accept_wildcard(self):
        accept = self.storage_class([("*", 0), ("asterisk", 1)])
        assert "*" in accept
        assert accept.best_match(["asterisk", "star"], default=None) == "asterisk"
        assert accept.best_match(["star"], default=None) is None

    def test_accept_keep_order(self):
        accept = self.storage_class([("*", 1)])
        assert accept.best_match(["alice", "bob"]) == "alice"
        assert accept.best_match(["bob", "alice"]) == "bob"
        accept = self.storage_class([("alice", 1), ("bob", 1)])
        assert accept.best_match(["alice", "bob"]) == "alice"
        assert accept.best_match(["bob", "alice"]) == "bob"

    def test_accept_wildcard_specificity(self):
        accept = self.storage_class([("asterisk", 0), ("star", 0.5), ("*", 1)])
        assert accept.best_match(["star", "asterisk"], default=None) == "star"
        assert accept.best_match(["asterisk", "star"], default=None) == "star"
        assert accept.best_match(["asterisk", "times"], default=None) == "times"
        assert accept.best_match(["asterisk"], default=None) is None

    def test_accept_equal_quality(self):
        accept = self.storage_class([("a", 1), ("b", 1)])
        assert accept.best == "a"


class TestMIMEAccept:
    @pytest.mark.parametrize(
        ("values", "matches", "default", "expect"),
        [
            ([("text/*", 1)], ["text/html"], None, "text/html"),
            ([("text/*", 1)], ["image/png"], "text/plain", "text/plain"),
            ([("text/*", 1)], ["image/png"], None, None),
            (
                [("*/*", 1), ("text/html", 1)],
                ["image/png", "text/html"],
                None,
                "text/html",
            ),
            (
                [("*/*", 1), ("text/html", 1)],
                ["image/png", "text/plain"],
                None,
                "image/png",
            ),
            (
                [("*/*", 1), ("text/html", 1), ("image/*", 1)],
                ["image/png", "text/html"],
                None,
                "text/html",
            ),
            (
                [("*/*", 1), ("text/html", 1), ("image/*", 1)],
                ["text/plain", "image/png"],
                None,
                "image/png",
            ),
            (
                [("text/html", 1), ("text/html; level=1", 1)],
                ["text/html;level=1"],
                None,
                "text/html;level=1",
            ),
        ],
    )
    def test_mime_accept(self, values, matches, default, expect):
        accept = ds.MIMEAccept(values)
        match = accept.best_match(matches, default=default)
        assert match == expect


class TestLanguageAccept:
    @pytest.mark.parametrize(
        ("values", "matches", "default", "expect"),
        (
            ([("en-us", 1)], ["en"], None, "en"),
            ([("en", 1)], ["en_US"], None, "en_US"),
            ([("en-GB", 1)], ["en-US"], None, None),
            ([("de_AT", 1), ("de", 0.9)], ["en"], None, None),
            ([("de_AT", 1), ("de", 0.9), ("en-US", 0.8)], ["de", "en"], None, "de"),
            ([("de_AT", 0.9), ("en-US", 1)], ["en"], None, "en"),
            ([("en-us", 1)], ["en-us"], None, "en-us"),
            ([("en-us", 1)], ["en-us", "en"], None, "en-us"),
            ([("en-GB", 1)], ["en-US", "en"], "en-US", "en"),
            ([("de_AT", 1)], ["en-US", "en"], "en-US", "en-US"),
            ([("aus-EN", 1)], ["aus"], None, "aus"),
            ([("aus", 1)], ["aus-EN"], None, "aus-EN"),
        ),
    )
    def test_best_match_fallback(self, values, matches, default, expect):
        accept = ds.LanguageAccept(values)
        best = accept.best_match(matches, default=default)
        assert best == expect


class TestFileStorage:
    storage_class = ds.FileStorage

    def test_mimetype_always_lowercase(self):
        file_storage = self.storage_class(content_type="APPLICATION/JSON")
        assert file_storage.mimetype == "application/json"

    @pytest.mark.parametrize("data", [io.StringIO("one\ntwo"), io.BytesIO(b"one\ntwo")])
    def test_bytes_proper_sentinel(self, data):
        # iterate over new lines and don't enter an infinite loop
        storage = self.storage_class(data)
        idx = -1

        for idx, _line in enumerate(storage):
            assert idx < 2

        assert idx == 1

    @pytest.mark.parametrize("stream", (tempfile.SpooledTemporaryFile, io.BytesIO))
    def test_proxy_can_access_stream_attrs(self, stream):
        """``SpooledTemporaryFile`` doesn't implement some of
        ``IOBase``. Ensure that ``FileStorage`` can still access the
        attributes from the backing file object.

        https://github.com/pallets/werkzeug/issues/1344
        https://github.com/python/cpython/pull/3249
        """
        file_storage = self.storage_class(stream=stream())

        for name in ("fileno", "writable", "readable", "seekable"):
            assert hasattr(file_storage, name)

    @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
    def test_save_to_pathlib_dst(self, tmp_path):
        src = tmp_path / "src.txt"
        src.write_text("test")
        dst = tmp_path / "dst.txt"

        # Close the source handle deterministically instead of leaking it.
        with src.open("rb") as stream:
            storage = self.storage_class(stream)
            storage.save(dst)

        assert dst.read_text() == "test"

    def test_save_to_bytes_io(self):
        storage = self.storage_class(io.BytesIO(b"one\ntwo"))
        dst = io.BytesIO()
        storage.save(dst)
        assert dst.getvalue() == b"one\ntwo"

    def test_save_to_file(self, tmp_path):
        path = tmp_path / "file.data"
        storage = self.storage_class(io.BytesIO(b"one\ntwo"))
        with path.open("wb") as dst:
            storage.save(dst)
        with path.open("rb") as src:
            assert src.read() == b"one\ntwo"


@pytest.mark.parametrize("ranges", ([(0, 1), (-5, None)], [(5, None)]))
def test_range_to_header(ranges):
    """A ``Range`` rendered to a header parses back to the same ranges."""
    header = ds.Range("bytes", ranges).to_header()
    r = http.parse_range_header(header)
    assert r.ranges == ranges


@pytest.mark.parametrize(
    "ranges", ([(0, 0)], [(None, 1)], [(1, 0)], [(0, 1), (-5, 10)])
)
def test_range_validates_ranges(ranges):
    """Constructing a ``Range`` from any of these range lists raises."""
    with pytest.raises(ValueError):
        ds.Range("bytes", ranges)