1import io
2import types
3import textwrap
4import unittest
5import email.errors
6import email.policy
7import email.parser
8import email.generator
9import email.message
10from email import headerregistry
11
def make_defaults(base_defaults, differences):
    """Return a new mapping: *base_defaults* overlaid with *differences*.

    *base_defaults* itself is left untouched; the result is a copy of it
    updated with the entries from *differences*.
    """
    merged = base_defaults.copy()
    merged.update(differences)
    return merged
16
class PolicyAPITests(unittest.TestCase):

    # Show the custom failure message in addition to the standard one.
    longMessage = True

    # Base default values.
    compat32_defaults = {
        'max_line_length':          78,
        'linesep':                  '\n',
        'cte_type':                 '8bit',
        'raise_on_defect':          False,
        'mangle_from_':             True,
        'message_factory':          None,
        }
    # These default values are the ones set on email.policy.default.
    # If any of these defaults change, the docs must be updated.
    policy_defaults = compat32_defaults.copy()
    policy_defaults.update({
        'utf8':                     False,
        'raise_on_defect':          False,
        'header_factory':           email.policy.EmailPolicy.header_factory,
        'refold_source':            'long',
        'content_manager':          email.policy.EmailPolicy.content_manager,
        'mangle_from_':             False,
        'message_factory':          email.message.EmailMessage,
        })

    # For each policy under test, we give here what we expect the defaults to
    # be for that policy.  The second argument to make_defaults is the
    # difference between the base defaults and that for the particular policy.
    # Note that the keys of this mapping are policy *instances*.
    new_policy = email.policy.EmailPolicy()
    policies = {
        email.policy.compat32: make_defaults(compat32_defaults, {}),
        email.policy.default: make_defaults(policy_defaults, {}),
        email.policy.SMTP: make_defaults(policy_defaults,
                                         {'linesep': '\r\n'}),
        email.policy.SMTPUTF8: make_defaults(policy_defaults,
                                             {'linesep': '\r\n',
                                              'utf8': True}),
        email.policy.HTTP: make_defaults(policy_defaults,
                                         {'linesep': '\r\n',
                                          'max_line_length': None}),
        email.policy.strict: make_defaults(policy_defaults,
                                           {'raise_on_defect': True}),
        new_policy: make_defaults(policy_defaults, {}),
        }
    # Creating a new policy creates a new header factory.  There is a test
    # later that proves this.
    policies[new_policy]['header_factory'] = new_policy.header_factory

    def test_defaults(self):
        # Each policy must expose exactly the expected value for every
        # attribute listed in its defaults table.
        for policy, expected in self.policies.items():
            for attr, value in expected.items():
                with self.subTest(policy=policy, attr=attr):
                    self.assertEqual(getattr(policy, attr), value,
                                    ("change {} docs/docstrings if defaults have "
                                    "changed").format(policy))

    def test_all_attributes_covered(self):
        # Every public attribute of a policy that is not a plain function on
        # EmailPolicy must appear in that policy's defaults table, so that a
        # newly added attribute cannot go untested.
        for policy, expected in self.policies.items():
            for attr in dir(policy):
                with self.subTest(policy=policy, attr=attr):
                    if (attr.startswith('_') or
                            isinstance(getattr(email.policy.EmailPolicy, attr),
                                  types.FunctionType)):
                        continue
                    else:
                        self.assertIn(attr, expected,
                                      "{} is not fully tested".format(attr))

    def test_abc(self):
        # Policy is abstract: instantiating it must fail, and the error must
        # name every abstract method.
        with self.assertRaises(TypeError) as cm:
            email.policy.Policy()
        msg = str(cm.exception)
        abstract_methods = ('fold',
                            'fold_binary',
                            'header_fetch_parse',
                            'header_source_parse',
                            'header_store_parse')
        for method in abstract_methods:
            self.assertIn(method, msg)

    def test_policy_is_immutable(self):
        # Neither known attributes nor brand new ones may be assigned on an
        # existing policy object.
        for policy, defaults in self.policies.items():
            for attr in defaults:
                with self.assertRaisesRegex(AttributeError, attr+".*read-only"):
                    setattr(policy, attr, None)
            with self.assertRaisesRegex(AttributeError, 'no attribute.*foo'):
                policy.foo = None

    def test_set_policy_attrs_when_cloned(self):
        # Set every attribute to None in the clone call and check that the
        # clone reports None for each of them.  (A few defaults are already
        # None -- compat32's message_factory and HTTP's max_line_length -- so
        # for those this only shows that the value survives cloning.)
        for policy, defaults in self.policies.items():
            testattrdict = {attr: None for attr in defaults}
            cloned = policy.clone(**testattrdict)
            for attr in defaults:
                self.assertIsNone(getattr(cloned, attr))

    def test_reject_non_policy_keyword_when_called(self):
        # The keys of self.policies are policy instances, and instances are
        # not callable: calling one raises TypeError regardless of the
        # arguments.  Call the instance's class instead so that it is the
        # constructor's keyword validation that is being checked.
        for policy in self.policies:
            policyclass = type(policy)
            with self.subTest(policy=policy):
                with self.assertRaises(TypeError):
                    policyclass(this_keyword_should_not_be_valid=None)
                with self.assertRaises(TypeError):
                    # 'newtline' is deliberately not a valid attribute name.
                    policyclass(newtline=None)

    def test_policy_addition(self):
        # When two policies are added, non-default values from the right
        # hand operand win over those from the left hand operand.
        expected = self.policy_defaults.copy()
        p1 = email.policy.default.clone(max_line_length=100)
        p2 = email.policy.default.clone(max_line_length=50)
        added = p1 + p2
        expected.update(max_line_length=50)
        for attr, value in expected.items():
            self.assertEqual(getattr(added, attr), value)
        added = p2 + p1
        expected.update(max_line_length=100)
        for attr, value in expected.items():
            self.assertEqual(getattr(added, attr), value)
        # Adding the unmodified default policy must not undo the override.
        added = added + email.policy.default
        for attr, value in expected.items():
            self.assertEqual(getattr(added, attr), value)

    def test_fold_zero_max_line_length(self):
        # max_line_length of 0 and of None must produce the same folding.
        expected = 'Subject: =?utf-8?q?=C3=A1?=\n'

        msg = email.message.EmailMessage()
        msg['Subject'] = 'á'

        p1 = email.policy.default.clone(max_line_length=0)
        p2 = email.policy.default.clone(max_line_length=None)

        self.assertEqual(p1.fold('Subject', msg['Subject']), expected)
        self.assertEqual(p2.fold('Subject', msg['Subject']), expected)

    def test_register_defect(self):
        # register_defect appends each defect to the object's defects list,
        # preserving order.
        class Dummy:
            def __init__(self):
                self.defects = []
        obj = Dummy()
        defect = object()
        policy = email.policy.EmailPolicy()
        policy.register_defect(obj, defect)
        self.assertEqual(obj.defects, [defect])
        defect2 = object()
        policy.register_defect(obj, defect2)
        self.assertEqual(obj.defects, [defect, defect2])

    # Minimal stand-in for an object that can accumulate defects.
    class MyObj:
        def __init__(self):
            self.defects = []

    # Defect type used by the handle_defect tests below.
    class MyDefect(Exception):
        pass

    def test_handle_defect_raises_on_strict(self):
        foo = self.MyObj()
        defect = self.MyDefect("the telly is broken")
        with self.assertRaisesRegex(self.MyDefect, "the telly is broken"):
            email.policy.strict.handle_defect(foo, defect)

    def test_handle_defect_registers_defect(self):
        foo = self.MyObj()
        defect1 = self.MyDefect("one")
        email.policy.default.handle_defect(foo, defect1)
        self.assertEqual(foo.defects, [defect1])
        defect2 = self.MyDefect("two")
        email.policy.default.handle_defect(foo, defect2)
        self.assertEqual(foo.defects, [defect1, defect2])

    # Policy that records defects on itself instead of on the object.
    class MyPolicy(email.policy.EmailPolicy):
        # Class-level placeholder for the per-instance list that __init__
        # passes to the base constructor as the 'defects' keyword.
        defects = None
        def __init__(self, *args, **kw):
            super().__init__(*args, defects=[], **kw)
        def register_defect(self, obj, defect):
            self.defects.append(defect)

    def test_overridden_register_defect_still_raises(self):
        foo = self.MyObj()
        defect = self.MyDefect("the telly is broken")
        with self.assertRaisesRegex(self.MyDefect, "the telly is broken"):
            self.MyPolicy(raise_on_defect=True).handle_defect(foo, defect)

    def test_overridden_register_defect_works(self):
        # The overriding register_defect stores defects on the policy, so the
        # object's own list must stay empty.
        foo = self.MyObj()
        defect1 = self.MyDefect("one")
        my_policy = self.MyPolicy()
        my_policy.handle_defect(foo, defect1)
        self.assertEqual(my_policy.defects, [defect1])
        self.assertEqual(foo.defects, [])
        defect2 = self.MyDefect("two")
        my_policy.handle_defect(foo, defect2)
        self.assertEqual(my_policy.defects, [defect1, defect2])
        self.assertEqual(foo.defects, [])

    def test_default_header_factory(self):
        h = email.policy.default.header_factory('Test', 'test')
        self.assertEqual(h.name, 'Test')
        self.assertIsInstance(h, headerregistry.UnstructuredHeader)
        self.assertIsInstance(h, headerregistry.BaseHeader)

    # Header class registered with a header factory in the tests below.
    class Foo:
        parse = headerregistry.UnstructuredHeader.parse

    def test_each_Policy_gets_unique_factory(self):
        # A mapping registered on one new policy's factory must not be
        # visible through another new policy's factory.
        policy1 = email.policy.EmailPolicy()
        policy2 = email.policy.EmailPolicy()
        policy1.header_factory.map_to_type('foo', self.Foo)
        h = policy1.header_factory('foo', 'test')
        self.assertIsInstance(h, self.Foo)
        self.assertNotIsInstance(h, headerregistry.UnstructuredHeader)
        h = policy2.header_factory('foo', 'test')
        self.assertNotIsInstance(h, self.Foo)
        self.assertIsInstance(h, headerregistry.UnstructuredHeader)

    def test_clone_copies_factory(self):
        # A clone shares the original's factory: a mapping registered on the
        # original after cloning is visible through the clone as well.
        policy1 = email.policy.EmailPolicy()
        policy2 = policy1.clone()
        policy1.header_factory.map_to_type('foo', self.Foo)
        h = policy1.header_factory('foo', 'test')
        self.assertIsInstance(h, self.Foo)
        h = policy2.header_factory('foo', 'test')
        self.assertIsInstance(h, self.Foo)

    def test_new_factory_overrides_default(self):
        # The factory of a freshly created policy wins over the factory of
        # email.policy.strict regardless of the order of addition.
        mypolicy = email.policy.EmailPolicy()
        myfactory = mypolicy.header_factory
        newpolicy = mypolicy + email.policy.strict
        self.assertEqual(newpolicy.header_factory, myfactory)
        newpolicy = email.policy.strict + mypolicy
        self.assertEqual(newpolicy.header_factory, myfactory)

    def test_adding_default_policies_preserves_default_factory(self):
        newpolicy = email.policy.default + email.policy.strict
        self.assertEqual(newpolicy.header_factory,
                         email.policy.EmailPolicy.header_factory)
        self.assertEqual(newpolicy.__dict__, {'raise_on_defect': True})

    def test_non_ascii_chars_do_not_cause_inf_loop(self):
        policy = email.policy.default.clone(max_line_length=20)
        actual = policy.fold('Subject', 'ą' * 12)
        self.assertEqual(
            actual,
            'Subject: \n' +
            12 * ' =?utf-8?q?=C4=85?=\n')

    def test_short_maxlen_error(self):
        # RFC 2047 chrome takes up 7 characters, plus the length of the charset
        # name, so folding should fail if maxlen is lower than the minimum
        # required length for a line.

        # Note: This is only triggered when there is a single word longer than
        # max_line_length, hence the 1234567890 at the end of this whimsical
        # subject. This is because when we encounter a word longer than
        # max_line_length, it is broken down into encoded words to fit
        # max_line_length. If the max_line_length isn't large enough to even
        # contain the RFC 2047 chrome (`=?<charset>?q??=`), we fail.
        subject = "Melt away the pounds with this one simple trick! 1234567890"

        for maxlen in [3, 7, 9]:
            with self.subTest(maxlen=maxlen):
                policy = email.policy.default.clone(max_line_length=maxlen)
                with self.assertRaises(email.errors.HeaderParseError):
                    policy.fold("Subject", subject)

    # XXX: Need subclassing tests.
    # For adding subclassed objects, make sure the usual rules apply (subclass
    # wins), but that the order still works (right overrides left).
283
284
class TestException(Exception):
    """Raised by the stub policy below to show that the policy was used."""


class TestPolicyPropagation(unittest.TestCase):

    # The abstract methods are used by the parser but not by the wrapper
    # functions that call it, so if the exception gets raised we know that the
    # policy was actually propagated all the way to feedparser.
    class MyPolicy(email.policy.Policy):
        def badmethod(self, *args, **kw):
            raise TestException("test")
        # Every abstract method of Policy must be overridden here (note the
        # spelling 'header_fetch_parse'), otherwise the class stays abstract
        # and cannot be instantiated.
        fold = fold_binary = header_fetch_parse = badmethod
        header_source_parse = header_store_parse = badmethod

    # The parsing entry points below are given a policy *instance*, which is
    # what their 'policy' argument is meant to receive.

    def test_message_from_string(self):
        with self.assertRaisesRegex(TestException, "^test$"):
            email.message_from_string("Subject: test\n\n",
                                      policy=self.MyPolicy())

    def test_message_from_bytes(self):
        with self.assertRaisesRegex(TestException, "^test$"):
            email.message_from_bytes(b"Subject: test\n\n",
                                     policy=self.MyPolicy())

    def test_message_from_file(self):
        f = io.StringIO('Subject: test\n\n')
        with self.assertRaisesRegex(TestException, "^test$"):
            email.message_from_file(f, policy=self.MyPolicy())

    def test_message_from_binary_file(self):
        f = io.BytesIO(b'Subject: test\n\n')
        with self.assertRaisesRegex(TestException, "^test$"):
            email.message_from_binary_file(f, policy=self.MyPolicy())

    # These are redundant, but we need them for black-box completeness.

    def test_parser(self):
        p = email.parser.Parser(policy=self.MyPolicy())
        with self.assertRaisesRegex(TestException, "^test$"):
            p.parsestr('Subject: test\n\n')

    def test_bytes_parser(self):
        p = email.parser.BytesParser(policy=self.MyPolicy())
        with self.assertRaisesRegex(TestException, "^test$"):
            p.parsebytes(b'Subject: test\n\n')

    # Now that we've established that all the parse methods get the
    # policy in to feedparser, we can use message_from_string for
    # the rest of the propagation tests.

    def _make_msg(self, source='Subject: test\n\n', policy=None):
        # Parse *source* with *policy* (a fresh clone of the default policy
        # when none is given) and remember the policy used on self.policy so
        # that the calling test can compare against it.
        self.policy = email.policy.default.clone() if policy is None else policy
        return email.message_from_string(source, policy=self.policy)

    def test_parser_propagates_policy_to_message(self):
        msg = self._make_msg()
        self.assertIs(msg.policy, self.policy)

    def test_parser_propagates_policy_to_sub_messages(self):
        # The boundary parameter must be separated from the content type by
        # a semicolon; with any other separator no boundary is found, the
        # message is not split into parts and this test would only ever look
        # at the top-level message.
        msg = self._make_msg(textwrap.dedent("""\
            Subject: mime test
            MIME-Version: 1.0
            Content-Type: multipart/mixed; boundary="XXX"

            --XXX
            Content-Type: text/plain

            test
            --XXX
            Content-Type: text/plain

            test2
            --XXX--
            """))
        parts = list(msg.walk())
        # The container plus its two text parts: make sure the sub-messages
        # really were parsed before checking their policy.
        self.assertEqual(len(parts), 3)
        for part in parts:
            self.assertIs(part.policy, self.policy)

    def test_message_policy_propagates_to_generator(self):
        # A generator created without an explicit policy must use the
        # message's policy, observable here through the unusual linesep.
        msg = self._make_msg("Subject: test\nTo: foo\n\n",
                             policy=email.policy.default.clone(linesep='X'))
        s = io.StringIO()
        g = email.generator.Generator(s)
        g.flatten(msg)
        self.assertEqual(s.getvalue(), "Subject: testXTo: fooXX")

    def test_message_policy_used_by_as_string(self):
        msg = self._make_msg("Subject: test\nTo: foo\n\n",
                             policy=email.policy.default.clone(linesep='X'))
        self.assertEqual(msg.as_string(), "Subject: testXTo: fooXX")
374
375
class TestConcretePolicies(unittest.TestCase):

    def test_header_store_parse_rejects_newlines(self):
        # Storing a header value that contains a line break must be refused
        # with ValueError.
        policy = email.policy.EmailPolicy()
        with self.assertRaises(ValueError):
            policy.header_store_parse('From', 'spam\negg@foo.py')
383
384
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
387