1"""Test the parser and generator are inverses.
2
3Note that this is only strictly true if we are parsing RFC valid messages and
4producing RFC valid messages.
5"""
6
7import io
8import unittest
9from email import policy, message_from_bytes
10from email.message import EmailMessage
11from email.generator import BytesGenerator
12from test.test_email import TestEmailBase, parameterize
13
def dedent(bstr):
    r"""Dedent a bytes literal and rejoin its lines with ``\r\n``.

    This is like :func:`textwrap.dedent` for bytes, except that:

    * the amount of indentation to remove is taken from the first line only,
      and exactly that many bytes are cut from the front of every line (so a
      less-indented line loses non-whitespace characters too);
    * the rebuilt string uses ``\r\n`` (the CRLF separator used by RFC 5322
      messages) between lines.

    Any line no longer than the first line's indentation becomes empty.  In
    particular the whitespace-only line before a closing triple quote turns
    into ``b''``, so an indented literal ending in a newline yields a result
    ending in ``\r\n``.

    :param bstr: the indented bytes to dedent.
    :returns: the dedented bytes, with lines joined by ``\r\n``.
    :raises ValueError: if *bstr* is empty or its first line is blank.
    """
    lines = bstr.splitlines()
    # The first line determines how much to strip, so it must contain text.
    # Checking ``not lines`` as well turns empty input into the same clear
    # ValueError instead of an IndexError from ``lines[0]``.
    if not lines or not lines[0].strip():
        raise ValueError("First line must contain text")
    stripamt = len(lines[0]) - len(lines[0].lstrip())
    # Slicing past the end of a short line already yields b'', so no explicit
    # length check is needed.
    return b'\r\n'.join([line[stripamt:] for line in lines])
24
25
@parameterize
class TestInversion(TestEmailBase):
    """Round-trip tests: parsing then generating must reproduce the input.

    NOTE(review): ``@parameterize`` (from test.test_email) appears to pair each
    ``<prefix>_params`` mapping with the ``<prefix>_as_*`` method sharing its
    prefix, generating one test per mapping entry — confirm against its
    docstring before adding attributes whose names contain ``_as_`` or end in
    ``_params``.
    """

    policy = policy.default
    message = EmailMessage

    def msg_as_input(self, msg):
        """Parse raw *msg* bytes and check that flattening gives them back."""
        # Inside a method, ``policy`` resolves to the email.policy module, not
        # the class attribute above; SMTP matches dedent()'s \r\n separators.
        parsed = message_from_bytes(msg, policy=policy.SMTP)
        out = io.BytesIO()
        # No policy is passed to the generator; per the email.generator docs
        # it then uses the policy of the message being flattened.
        BytesGenerator(out).flatten(parsed)
        self.assertEqual(out.getvalue(), msg)

    # XXX: spaces are not preserved correctly here yet in the general case.
    msg_params = {
        # ``\x20`` spells the single-space header body explicitly so that
        # editors stripping trailing whitespace cannot silently remove it.
        'header_with_one_space_body': (dedent(b"""\
            From: abc@xyz.com
            X-Status:\x20
            Subject: test

            foo
            """),),
    }

    # Keyword arguments handed to payload_as_body; ``cte`` selects the
    # Content-Transfer-Encoding used by set_content().
    payload_params = {
        'plain_text': {'payload': 'This is a test\n' * 20},
        'base64_text': {'payload': ('xy a' * 40 + '\n') * 5,
                        'cte': 'base64'},
        'qp_text': {'payload': ('xy a' * 40 + '\n') * 5,
                    'cte': 'quoted-printable'},
    }

    def payload_as_body(self, payload, **kw):
        """Check that a body set from *payload* survives serialize/parse.

        Both the serialized bytes and the decoded content must round-trip.
        *kw* is forwarded to ``set_content`` (e.g. ``cte``).
        """
        # _make_message() comes from TestEmailBase; presumably it builds a
        # ``self.message`` with ``self.policy`` — confirm there.
        original = self._make_message()
        header_pairs = (
            ('From', 'foo'),
            ('To', 'bar'),
            ('Subject', 'payload round trip test'),
        )
        for header_name, header_value in header_pairs:
            original[header_name] = header_value
        original.set_content(payload, **kw)
        serialized = bytes(original)
        reparsed = message_from_bytes(serialized, policy=self.policy)
        self.assertEqual(bytes(reparsed), serialized)
        self.assertEqual(reparsed.get_content(), payload)
67
68
if __name__ == '__main__':
    # Run this module's tests when the file is executed directly as a script.
    unittest.main()
71