1import imghdr
2import io
3import os
4import pathlib
5import unittest
6import warnings
7from test.support import findfile
8from test.support.os_helper import TESTFN, unlink
9
10
# Sample images looked up in the 'imghdrdata' test-data directory, each paired
# with the type name that imghdr.what() is expected to report for it.
TEST_FILES = (
    ('python.png', 'png'),
    ('python.gif', 'gif'),
    ('python.bmp', 'bmp'),
    ('python.ppm', 'ppm'),
    ('python.pgm', 'pgm'),
    ('python.pbm', 'pbm'),
    ('python.jpg', 'jpeg'),
    ('python-raw.jpg', 'jpeg'),  # raw JPEG without JFIF/EXIF markers
    ('python.ras', 'rast'),
    ('python.sgi', 'rgb'),
    ('python.tiff', 'tiff'),
    ('python.xbm', 'xbm'),
    ('python.webp', 'webp'),
    ('python.exr', 'exr'),
)
27
class UnseekableIO(io.FileIO):
    """A real binary file whose position can be neither queried nor changed.

    Reading works as for a normal ``io.FileIO``, but both ``seek()`` and
    ``tell()`` raise ``io.UnsupportedOperation`` so tests can check how
    ``imghdr.what`` behaves on such a stream.
    """

    def seek(self, *_args, **_kwargs):
        # Refuse repositioning, whatever offset/whence were requested.
        raise io.UnsupportedOperation()

    def tell(self):
        # Refuse to report the current offset.
        raise io.UnsupportedOperation()
34
class TestImghdr(unittest.TestCase):
    """Exercise imghdr.what() with paths, streams and raw header bytes."""

    @classmethod
    def setUpClass(cls):
        # One known PNG sample, shared by the tests that just need some valid
        # image, both as a file path and as raw bytes.
        cls.testfile = findfile('python.png', subdir='imghdrdata')
        with open(cls.testfile, 'rb') as stream:
            cls.testdata = stream.read()

    def tearDown(self):
        # Several tests write TESTFN; remove it after every test.
        unlink(TESTFN)

    def test_data(self):
        # Each sample must be detected identically whether it is passed as a
        # path, as an open binary stream, or as bytes/bytearray header data.
        # subTest keeps reporting the remaining samples after one fails.
        for filename, expected in TEST_FILES:
            with self.subTest(filename=filename):
                path = findfile(filename, subdir='imghdrdata')
                self.assertEqual(imghdr.what(path), expected)
                with open(path, 'rb') as stream:
                    self.assertEqual(imghdr.what(stream), expected)
                with open(path, 'rb') as stream:
                    data = stream.read()
                self.assertEqual(imghdr.what(None, data), expected)
                self.assertEqual(imghdr.what(None, bytearray(data)),
                                 expected)

    def test_pathlike_filename(self):
        # A pathlib.Path is accepted as the file argument.
        for filename, expected in TEST_FILES:
            with self.subTest(filename=filename):
                path = findfile(filename, subdir='imghdrdata')
                self.assertEqual(imghdr.what(pathlib.Path(path)), expected)

    def test_register_test(self):
        # A detector appended to imghdr.tests is consulted by what().  The
        # cleanup pops the last registry entry (ours) so later tests see the
        # registry as it was.
        def test_jumbo(h, file):
            if h.startswith(b'eggs'):
                return 'ham'
        imghdr.tests.append(test_jumbo)
        self.addCleanup(imghdr.tests.pop)
        self.assertEqual(imghdr.what(None, b'eggs'), 'ham')

    def test_file_pos(self):
        # Detection reads from the stream's current position (the PNG data
        # follows some junk) and leaves that position unchanged afterwards.
        with open(TESTFN, 'wb') as stream:
            stream.write(b'ababagalamaga')
            pos = stream.tell()
            stream.write(self.testdata)
        with open(TESTFN, 'rb') as stream:
            stream.seek(pos)
            self.assertEqual(imghdr.what(stream), 'png')
            self.assertEqual(stream.tell(), pos)

    def test_bad_args(self):
        # Missing or wrongly typed arguments raise rather than return None.
        with self.assertRaises(TypeError):
            imghdr.what()
        with self.assertRaises(AttributeError):
            imghdr.what(None)
        with self.assertRaises(TypeError):
            imghdr.what(self.testfile, 1)
        # Neither a bytes path nor a raw file descriptor is accepted as file.
        with self.assertRaises(AttributeError):
            imghdr.what(os.fsencode(self.testfile))
        with open(self.testfile, 'rb') as f:
            with self.assertRaises(AttributeError):
                imghdr.what(f.fileno())

    def test_invalid_headers(self):
        # None of these headers matches a known format, so what() returns
        # None for each of them.
        for header in (b'\211PN\r\n',
                       b'\001\331',
                       b'\x59\xA6',
                       b'cutecat',
                       b'000000JFI',
                       b'GIF80'):
            with self.subTest(header=header):
                self.assertIsNone(imghdr.what(None, header))

    def test_string_data(self):
        # Text (str) data is rejected with TypeError, both through a text
        # stream and as the header argument.  BytesWarning is silenced
        # because str/bytes comparisons can emit it when Python runs with -b.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", BytesWarning)
            for filename, _ in TEST_FILES:
                with self.subTest(filename=filename):
                    path = findfile(filename, subdir='imghdrdata')
                    with open(path, 'rb') as stream:
                        data = stream.read().decode('latin1')
                    with self.assertRaises(TypeError):
                        imghdr.what(io.StringIO(data))
                    with self.assertRaises(TypeError):
                        imghdr.what(None, data)

    def test_missing_file(self):
        # A nonexistent path propagates the open() error.
        with self.assertRaises(FileNotFoundError):
            imghdr.what('missing')

    def test_closed_file(self):
        # Closed streams, whether real files or in-memory buffers, raise
        # ValueError.
        stream = open(self.testfile, 'rb')
        stream.close()
        with self.assertRaises(ValueError):
            imghdr.what(stream)
        stream = io.BytesIO(self.testdata)
        stream.close()
        with self.assertRaises(ValueError):
            imghdr.what(stream)

    def test_unseekable(self):
        # A stream that refuses tell()/seek() lets io.UnsupportedOperation
        # propagate out of what().
        with open(TESTFN, 'wb') as stream:
            stream.write(self.testdata)
        with UnseekableIO(TESTFN, 'rb') as stream:
            with self.assertRaises(io.UnsupportedOperation):
                imghdr.what(stream)

    def test_output_stream(self):
        # A write-only stream cannot be read for detection; expect OSError.
        with open(TESTFN, 'wb') as stream:
            stream.write(self.testdata)
            stream.seek(0)
            with self.assertRaises(OSError):
                imghdr.what(stream)
141
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
144