1import io
2import email
3import unittest
4from email.message import Message, EmailMessage
5from email.policy import default
6from test.test_email import TestEmailBase
7
8
class TestCustomMessage(TestEmailBase):
    """Check that the message_from_* functions pass the policy they were
    given to a custom message class whose constructor accepts one."""

    class MyMessage(Message):
        # Record the policy handed to the constructor so that the tests can
        # check which policy object the parser passed in.
        def __init__(self, policy):
            self.check_policy = policy
            super().__init__()

    # A clone of the base policy: a distinct object, so the identity check in
    # _check_custom_message() cannot pass by accident.
    MyPolicy = TestEmailBase.policy.clone(linesep='boo')

    def _check_custom_message(self, msg):
        # Shared assertions: the parser built the message with our class and
        # passed it the very policy object it was called with.
        self.assertIsInstance(msg, self.MyMessage)
        self.assertIs(msg.check_policy, self.MyPolicy)

    def test_custom_message_gets_policy_if_possible_from_string(self):
        msg = email.message_from_string("Subject: bogus\n\nmsg\n",
                                        self.MyMessage,
                                        policy=self.MyPolicy)
        self._check_custom_message(msg)

    def test_custom_message_gets_policy_if_possible_from_file(self):
        source_file = io.StringIO("Subject: bogus\n\nmsg\n")
        msg = email.message_from_file(source_file,
                                      self.MyMessage,
                                      policy=self.MyPolicy)
        self._check_custom_message(msg)

    def test_custom_message_gets_policy_if_possible_from_bytes(self):
        msg = email.message_from_bytes(b"Subject: bogus\n\nmsg\n",
                                       self.MyMessage,
                                       policy=self.MyPolicy)
        self._check_custom_message(msg)

    def test_custom_message_gets_policy_if_possible_from_binary_file(self):
        source_file = io.BytesIO(b"Subject: bogus\n\nmsg\n")
        msg = email.message_from_binary_file(source_file,
                                             self.MyMessage,
                                             policy=self.MyPolicy)
        self._check_custom_message(msg)

    # XXX add tests for the Parser, BytesParser and FeedParser classes, which
    # also take a message factory argument.
34
35
class TestParserBase:
    """Parser tests shared by the str-based and bytes-based test cases.

    This is a mixin: a concrete subclass must define ``parsers``, an iterable
    of callables that accept the message source as a str followed by the
    optional message factory and ``policy`` arguments.
    """

    class MyMessage(EmailMessage):
        pass

    def test_only_split_on_cr_lf(self):
        # The unicode line splitter recognises many more line boundaries than
        # the email RFCs allow.  Every value below embeds one of those extra
        # characters, and none of them may split a header: only CR and LF
        # are line endings.
        expected_headers = [
            ("Next-Line", "not\x85broken"),
            ("Null", "not\x00broken"),
            ("Vertical-Tab", "not\vbroken"),
            ("Form-Feed", "not\fbroken"),
            ("File-Separator", "not\x1Cbroken"),
            ("Group-Separator", "not\x1Dbroken"),
            ("Record-Separator", "not\x1Ebroken"),
            ("Line-Separator", "not\u2028broken"),
            ("Paragraph-Separator", "not\u2029broken"),
        ]
        header_lines = [
            f"{name}: {value}\r\n" for name, value in expected_headers
        ]
        # The trailing blank line ends the header section; the body is empty.
        source = "".join(header_lines) + "\r\n"
        for parse in self.parsers:
            with self.subTest(parser=parse.__name__):
                parsed = parse(source, policy=default)
                self.assertEqual(parsed.items(), expected_headers)
                self.assertEqual(parsed.get_payload(), "")

    def test_custom_message_factory_on_policy(self):
        # Without an explicit factory argument, the message class comes from
        # the policy's message_factory.
        factory_policy = default.clone(message_factory=self.MyMessage)
        for parse in self.parsers:
            with self.subTest(parser=parse.__name__):
                parsed = parse("To: foo\n\ntest", policy=factory_policy)
                self.assertIsInstance(parsed, self.MyMessage)

    def test_factory_arg_overrides_policy(self):
        # An explicit factory argument takes precedence over the policy's
        # message_factory.
        factory_policy = default.clone(message_factory=self.MyMessage)
        for parse in self.parsers:
            with self.subTest(parser=parse.__name__):
                parsed = parse("To: foo\n\ntest", Message,
                               policy=factory_policy)
                self.assertNotIsInstance(parsed, self.MyMessage)
                self.assertIsInstance(parsed, Message)
87
88# Play some games to get nice output in subTest.  This code could be clearer
89# if staticmethod supported __name__.
90
def message_from_file(s, *args, **kw):
    """Parse the str *s* with email.message_from_file().

    The text is wrapped in an in-memory text file; any further positional
    and keyword arguments are passed through unchanged.
    """
    return email.message_from_file(io.StringIO(s), *args, **kw)
94
class TestParser(TestParserBase, TestEmailBase):
    """Run the shared parser tests against the str-based entry points."""

    parsers = (
        email.message_from_string,
        message_from_file,
    )
97
def message_from_bytes(s, *args, **kw):
    """Parse the str *s* with email.message_from_bytes().

    The text is encoded with str.encode()'s default encoding first; any
    further positional and keyword arguments are passed through unchanged.
    """
    data = s.encode()
    return email.message_from_bytes(data, *args, **kw)
100
def message_from_binary_file(s, *args, **kw):
    """Parse the str *s* with email.message_from_binary_file().

    The text is encoded with str.encode()'s default encoding and wrapped in
    an in-memory binary file; any further positional and keyword arguments
    are passed through unchanged.
    """
    data = s.encode()
    binary_file = io.BytesIO(data)
    return email.message_from_binary_file(binary_file, *args, **kw)
104
class TestBytesParser(TestParserBase, TestEmailBase):
    """Run the shared parser tests against the bytes-based entry points."""

    parsers = (
        message_from_bytes,
        message_from_binary_file,
    )
107
108
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
111