1# XXX TypeErrors on calling handlers, or on bad return values from a
2# handler, are obscure and unhelpful.
3
4from io import BytesIO
5import os
6import sys
7import sysconfig
8import unittest
9import traceback
10
11from xml.parsers import expat
12from xml.parsers.expat import errors
13
14from test.support import sortdict
15
16
class SetAttributeTest(unittest.TestCase):
    """Check the boolean option attributes exposed by a parser object."""

    def setUp(self):
        # Each test works on its own freshly created parser.
        self.parser = expat.ParserCreate(namespace_separator='!')

    def _check_flag(self, attr_name):
        """Assert *attr_name* starts as False and reads back as a bool."""
        parser = self.parser
        self.assertIs(getattr(parser, attr_name), False)
        for value in (0, 1, 2, 0):
            setattr(parser, attr_name, value)
            self.assertIs(getattr(parser, attr_name), bool(value))

    def test_buffer_text(self):
        self._check_flag('buffer_text')

    def test_namespace_prefixes(self):
        self._check_flag('namespace_prefixes')

    def test_ordered_attributes(self):
        self._check_flag('ordered_attributes')

    def test_specified_attributes(self):
        self._check_flag('specified_attributes')

    def test_invalid_attributes(self):
        # Both writing and reading returns_unicode must raise AttributeError.
        with self.assertRaises(AttributeError):
            self.parser.returns_unicode = 1
        with self.assertRaises(AttributeError):
            self.parser.returns_unicode

        # Issue #25019
        # An attribute name that is not a string must raise TypeError.
        bad_name = range(0xF)
        with self.assertRaises(TypeError):
            setattr(self.parser, bad_name, 0)
        with self.assertRaises(TypeError):
            self.parser.__setattr__(bad_name, 0)
        with self.assertRaises(TypeError):
            getattr(self.parser, bad_name)
55
56
# Sample document shared by the parsing tests below.  It is a bytes literal
# declared as iso-8859-1 and contains an XML declaration, a processing
# instruction, a comment, a DOCTYPE with an internal subset (element,
# attribute-list, notation and entity declarations), a namespaced element,
# a CDATA section, references to a declared external entity and to an
# undeclared entity, and the non-ASCII byte \xb5.
data = b'''\
<?xml version="1.0" encoding="iso-8859-1" standalone="no"?>
<?xml-stylesheet href="stylesheet.css"?>
<!-- comment data -->
<!DOCTYPE quotations SYSTEM "quotations.dtd" [
<!ELEMENT root ANY>
<!ATTLIST root attr1 CDATA #REQUIRED attr2 CDATA #IMPLIED>
<!NOTATION notation SYSTEM "notation.jpeg">
<!ENTITY acirc "&#226;">
<!ENTITY external_entity SYSTEM "entity.file">
<!ENTITY unparsed_entity SYSTEM "entity.file" NDATA notation>
%unparsed_entity;
]>

<root attr1="value1" attr2="value2&#8000;">
<myns:subelement xmlns:myns="http://www.python.org/namespace">
     Contents of subelements
</myns:subelement>
<sub2><![CDATA[contents of CDATA section]]></sub2>
&external_entity;
&skipped_entity;
\xb5
</root>
'''
81
82
83# Produce UTF-8 output
class ParseTest(unittest.TestCase):
    """Parse the sample document and check the sequence of callbacks fired."""

    class Outputter:
        """Handler collection that records every callback in ``self.out``."""

        def __init__(self):
            self.out = []

        def StartElementHandler(self, name, attrs):
            self.out.append('Start element: ' + repr(name) + ' ' +
                            sortdict(attrs))

        def EndElementHandler(self, name):
            self.out.append('End element: ' + repr(name))

        def CharacterDataHandler(self, data):
            # Whitespace-only chunks are not recorded.
            data = data.strip()
            if data:
                self.out.append('Character data: ' + repr(data))

        def ProcessingInstructionHandler(self, target, data):
            self.out.append('PI: ' + repr(target) + ' ' + repr(data))

        def StartNamespaceDeclHandler(self, prefix, uri):
            self.out.append('NS decl: ' + repr(prefix) + ' ' + repr(uri))

        def EndNamespaceDeclHandler(self, prefix):
            self.out.append('End of NS decl: ' + repr(prefix))

        def StartCdataSectionHandler(self):
            self.out.append('Start of CDATA section')

        def EndCdataSectionHandler(self):
            self.out.append('End of CDATA section')

        def CommentHandler(self, text):
            self.out.append('Comment: ' + repr(text))

        def NotationDeclHandler(self, *args):
            # The unpacking doubles as a check on the number of arguments.
            name, base, sysid, pubid = args
            self.out.append('Notation declared: %s' %(args,))

        def UnparsedEntityDeclHandler(self, *args):
            # The unpacking doubles as a check on the number of arguments.
            entityName, base, systemId, publicId, notationName = args
            self.out.append('Unparsed entity decl: %s' %(args,))

        def NotStandaloneHandler(self):
            self.out.append('Not standalone')
            # Per the pyexpat documentation a zero return value makes the
            # parser raise an error, so report success.
            return 1

        def ExternalEntityRefHandler(self, *args):
            # The unpacking doubles as a check on the number of arguments.
            context, base, sysId, pubId = args
            # The context argument is left out of the recorded event.
            self.out.append('External entity ref: %s' %(args[1:],))
            # Per the pyexpat documentation a zero return value makes the
            # parser raise an error, so report success.
            return 1

        def StartDoctypeDeclHandler(self, *args):
            self.out.append(('Start doctype', args))
            return 1

        def EndDoctypeDeclHandler(self):
            self.out.append("End doctype")
            return 1

        def EntityDeclHandler(self, *args):
            self.out.append(('Entity declaration', args))
            return 1

        def XmlDeclHandler(self, *args):
            self.out.append(('XML declaration', args))
            return 1

        def ElementDeclHandler(self, *args):
            self.out.append(('Element declaration', args))
            return 1

        def AttlistDeclHandler(self, *args):
            self.out.append(('Attribute list declaration', args))
            return 1

        def SkippedEntityHandler(self, *args):
            self.out.append(("Skipped entity", args))
            return 1

        def DefaultHandler(self, userData):
            pass

        def DefaultHandlerExpand(self, userData):
            pass

    # Names of the parser attributes that _hookup_callbacks() installs.
    handler_names = [
        'StartElementHandler', 'EndElementHandler', 'CharacterDataHandler',
        'ProcessingInstructionHandler', 'UnparsedEntityDeclHandler',
        'NotationDeclHandler', 'StartNamespaceDeclHandler',
        'EndNamespaceDeclHandler', 'CommentHandler',
        'StartCdataSectionHandler', 'EndCdataSectionHandler', 'DefaultHandler',
        'DefaultHandlerExpand', 'NotStandaloneHandler',
        'ExternalEntityRefHandler', 'StartDoctypeDeclHandler',
        'EndDoctypeDeclHandler', 'EntityDeclHandler', 'XmlDeclHandler',
        'ElementDeclHandler', 'AttlistDeclHandler', 'SkippedEntityHandler',
        ]

    def _hookup_callbacks(self, parser, handler):
        """
        Set each of the callbacks defined on handler and named in
        self.handler_names on the given parser.
        """
        for name in self.handler_names:
            setattr(parser, name, getattr(handler, name))

    def _new_parser(self):
        """Return a ``(parser, outputter)`` pair with all callbacks hooked up."""
        out = self.Outputter()
        parser = expat.ParserCreate(namespace_separator='!')
        self._hookup_callbacks(parser, out)
        return parser, out

    def _verify_parse_output(self, operations):
        """Assert that *operations* is exactly the expected list of events."""
        expected_operations = [
            ('XML declaration', ('1.0', 'iso-8859-1', 0)),
            'PI: \'xml-stylesheet\' \'href="stylesheet.css"\'',
            "Comment: ' comment data '",
            "Not standalone",
            ("Start doctype", ('quotations', 'quotations.dtd', None, 1)),
            ('Element declaration', ('root', (2, 0, None, ()))),
            ('Attribute list declaration', ('root', 'attr1', 'CDATA', None,
                1)),
            ('Attribute list declaration', ('root', 'attr2', 'CDATA', None,
                0)),
            "Notation declared: ('notation', None, 'notation.jpeg', None)",
            ('Entity declaration', ('acirc', 0, '\xe2', None, None, None, None)),
            ('Entity declaration', ('external_entity', 0, None, None,
                'entity.file', None, None)),
            "Unparsed entity decl: ('unparsed_entity', None, 'entity.file', None, 'notation')",
            "Not standalone",
            "End doctype",
            "Start element: 'root' {'attr1': 'value1', 'attr2': 'value2\u1f40'}",
            "NS decl: 'myns' 'http://www.python.org/namespace'",
            "Start element: 'http://www.python.org/namespace!subelement' {}",
            "Character data: 'Contents of subelements'",
            "End element: 'http://www.python.org/namespace!subelement'",
            "End of NS decl: 'myns'",
            "Start element: 'sub2' {}",
            'Start of CDATA section',
            "Character data: 'contents of CDATA section'",
            'End of CDATA section',
            "End element: 'sub2'",
            "External entity ref: (None, 'entity.file', None)",
            ('Skipped entity', ('skipped_entity', 0)),
            "Character data: '\xb5'",
            "End element: 'root'",
        ]
        # Compare the complete lists.  Pairing the items up with zip() would
        # stop at the shorter list and let missing or extra events slip by.
        self.assertEqual(operations, expected_operations)

    def test_parse_bytes(self):
        parser, out = self._new_parser()

        parser.Parse(data, 1)

        self._verify_parse_output(out.out)
        # Issue #6697.
        self.assertRaises(AttributeError, getattr, parser, '\uD800')

    def test_parse_str(self):
        parser, out = self._new_parser()

        parser.Parse(data.decode('iso-8859-1'), 1)

        self._verify_parse_output(out.out)

    def test_parse_file(self):
        # Try parsing a file
        parser, out = self._new_parser()
        file = BytesIO(data)

        parser.ParseFile(file)

        self._verify_parse_output(out.out)

    def test_parse_again(self):
        parser = expat.ParserCreate()
        file = BytesIO(data)
        parser.ParseFile(file)
        # Issue 6676: ensure a meaningful exception is raised when attempting
        # to parse more than one XML document per xmlparser instance,
        # a limitation of the Expat library.
        with self.assertRaises(expat.error) as cm:
            parser.ParseFile(file)
        self.assertEqual(expat.ErrorString(cm.exception.code),
                          expat.errors.XML_ERROR_FINISHED)
273
class NamespaceSeparatorTest(unittest.TestCase):
    """Check which namespace_separator values ParserCreate() accepts."""

    def test_legal(self):
        # Tests that make sure we get errors when the namespace_separator value
        # is illegal, and that we don't for good values:
        expat.ParserCreate()
        expat.ParserCreate(namespace_separator=None)
        expat.ParserCreate(namespace_separator=' ')

    def test_illegal(self):
        # The interpreter refers to the offending argument either by position
        # ("argument 2") or by name ("argument 'namespace_separator'"),
        # depending on its version; accept both spellings of the message.
        type_error = (r"^ParserCreate\(\) argument (2|'namespace_separator') "
                      r"must be str or None, not int$")
        with self.assertRaisesRegex(TypeError, type_error):
            expat.ParserCreate(namespace_separator=42)

        with self.assertRaises(ValueError) as cm:
            expat.ParserCreate(namespace_separator='too long')
        self.assertEqual(str(cm.exception),
            'namespace_separator must be at most one character, omitted, or None')

    def test_zero_length(self):
        # ParserCreate() needs to accept a namespace_separator of zero length
        # to satisfy the requirements of RDF applications that are required
        # to simply glue together the namespace URI and the localname.  Though
        # considered a wart of the RDF specifications, it needs to be supported.
        #
        # See XML-SIG mailing list thread starting with
        # http://mail.python.org/pipermail/xml-sig/2001-April/005202.html
        #
        expat.ParserCreate(namespace_separator='') # too short
307
308
class InterningTest(unittest.TestCase):
    """Tests around string interning and external entity parsers."""

    def test(self):
        # Test the interning machinery.
        parser = expat.ParserCreate()
        seen_names = []

        def record_name(name, *args):
            seen_names.append(name)

        parser.StartElementHandler = record_name
        parser.EndElementHandler = record_name
        parser.Parse(b"<e> <e/> <e></e> </e>", 1)
        first = seen_names[0]
        self.assertEqual(len(seen_names), 6)
        # Every reported name should be the very same string object.
        for name in seen_names:
            self.assertIs(first, name)

    def test_issue9402(self):
        # create an ExternalEntityParserCreate with buffer text
        main_parser = expat.ParserCreate(namespace_separator='!')
        main_parser.buffer_text = 1
        sub_result = None

        def on_external_entity(context, base, sysId, pubId):
            # Parse an empty document with a sub-parser of the main parser
            # and remember what Parse() returned.
            nonlocal sub_result
            sub_parser = main_parser.ExternalEntityParserCreate("")
            sub_result = sub_parser.Parse(b"", 1)
            return 1

        main_parser.ExternalEntityRefHandler = on_external_entity
        main_parser.Parse(data, 1)
        self.assertEqual(sub_result, 1)
343
344
class BufferTextTest(unittest.TestCase):
    """Check how character data is delivered when buffer_text is in use."""

    def setUp(self):
        # self.stuff collects everything the handlers below are given.
        self.stuff = []
        self.parser = expat.ParserCreate()
        self.parser.buffer_text = 1
        self.parser.CharacterDataHandler = self.CharacterDataHandler

    def check(self, expected, label):
        """Assert that the collected events equal *expected*.

        *label* is put at the start of the failure message.
        """
        # Materialise the map() so the message shows the values rather than
        # the repr of a lazy map object.
        self.assertEqual(self.stuff, expected,
                "%s\nstuff    = %r\nexpected = %r"
                % (label, self.stuff, list(map(str, expected))))

    def CharacterDataHandler(self, text):
        self.stuff.append(text)

    def StartElementHandler(self, name, attrs):
        self.stuff.append("<%s>" % name)
        # A "buffer-text" attribute switches buffering from inside the
        # callback.
        bt = attrs.get("buffer-text")
        if bt == "yes":
            self.parser.buffer_text = 1
        elif bt == "no":
            self.parser.buffer_text = 0

    def EndElementHandler(self, name):
        self.stuff.append("</%s>" % name)

    def CommentHandler(self, data):
        self.stuff.append("<!--%s-->" % data)

    def setHandlers(self, handlers=()):
        """Install the methods of this test case named in *handlers*."""
        for name in handlers:
            setattr(self.parser, name, getattr(self, name))

    def test_default_to_disabled(self):
        parser = expat.ParserCreate()
        self.assertFalse(parser.buffer_text)

    def test_buffering_enabled(self):
        # Make sure buffering is turned on
        self.assertTrue(self.parser.buffer_text)
        self.parser.Parse(b"<a>1<b/>2<c/>3</a>", 1)
        self.assertEqual(self.stuff, ['123'],
                         "buffered text not properly collapsed")

    def test1(self):
        # XXX This test exposes more detail of Expat's text chunking than we
        # XXX like, but it tests what we need to concisely.
        self.setHandlers(["StartElementHandler"])
        self.parser.Parse(b"<a>1<b buffer-text='no'/>2\n3<c buffer-text='yes'/>4\n5</a>", 1)
        self.assertEqual(self.stuff,
                         ["<a>", "1", "<b>", "2", "\n", "3", "<c>", "4\n5"],
                         "buffering control not reacting as expected")

    def test2(self):
        self.parser.Parse(b"<a>1<b/>&lt;2&gt;<c/>&#32;\n&#x20;3</a>", 1)
        self.assertEqual(self.stuff, ["1<2> \n 3"],
                         "buffered text not properly collapsed")

    def test3(self):
        self.setHandlers(["StartElementHandler"])
        self.parser.Parse(b"<a>1<b/>2<c/>3</a>", 1)
        self.assertEqual(self.stuff, ["<a>", "1", "<b>", "2", "<c>", "3"],
                         "buffered text not properly split")

    def test4(self):
        self.setHandlers(["StartElementHandler", "EndElementHandler"])
        self.parser.CharacterDataHandler = None
        self.parser.Parse(b"<a>1<b/>2<c/>3</a>", 1)
        self.assertEqual(self.stuff,
                         ["<a>", "<b>", "</b>", "<c>", "</c>", "</a>"])

    def test5(self):
        self.setHandlers(["StartElementHandler", "EndElementHandler"])
        self.parser.Parse(b"<a>1<b></b>2<c/>3</a>", 1)
        self.assertEqual(self.stuff,
            ["<a>", "1", "<b>", "</b>", "2", "<c>", "</c>", "3", "</a>"])

    def test6(self):
        self.setHandlers(["CommentHandler", "EndElementHandler",
                    "StartElementHandler"])
        self.parser.Parse(b"<a>1<b/>2<c></c>345</a> ", 1)
        self.assertEqual(self.stuff,
            ["<a>", "1", "<b>", "</b>", "2", "<c>", "</c>", "345", "</a>"],
            "buffered text not properly split")

    def test7(self):
        self.setHandlers(["CommentHandler", "EndElementHandler",
                    "StartElementHandler"])
        self.parser.Parse(b"<a>1<b/>2<c></c>3<!--abc-->4<!--def-->5</a> ", 1)
        self.assertEqual(self.stuff,
                         ["<a>", "1", "<b>", "</b>", "2", "<c>", "</c>", "3",
                          "<!--abc-->", "4", "<!--def-->", "5", "</a>"],
                         "buffered text not properly split")
438
439
440# Test handling of exception from callback:
class HandlerExceptionTest(unittest.TestCase):
    """Check that an exception raised by a handler propagates out of Parse()."""

    def StartElementHandler(self, name, attrs):
        # Fail on the first element so the test can tell which one it was.
        raise RuntimeError(name)

    def check_traceback_entry(self, entry, filename, funcname):
        """Assert that traceback *entry* names *filename* and *funcname*.

        Only the base name of the entry's file is compared.
        """
        self.assertEqual(os.path.basename(entry[0]), filename)
        self.assertEqual(entry[2], funcname)

    def test_exception(self):
        # Take the file name from the module itself rather than hard-coding
        # it, so the checks still hold if the module is run under another
        # name.
        this_file = os.path.basename(__file__)
        parser = expat.ParserCreate()
        parser.StartElementHandler = self.StartElementHandler
        try:
            parser.Parse(b"<a><b><c/></b></a>", 1)
            self.fail()
        except RuntimeError as e:
            self.assertEqual(e.args[0], 'a',
                             "Expected RuntimeError for element 'a', but"
                             " found %r" % (e.args[0],))
            # Check that the traceback contains the relevant line in pyexpat.c
            entries = traceback.extract_tb(e.__traceback__)
            self.assertEqual(len(entries), 3)
            self.check_traceback_entry(entries[0],
                                       this_file, "test_exception")
            self.check_traceback_entry(entries[1],
                                       "pyexpat.c", "StartElement")
            self.check_traceback_entry(entries[2],
                                       this_file, "StartElementHandler")
            # The source text of the C frame is only checked when running
            # from a source build.
            if sysconfig.is_python_build():
                self.assertIn('call_with_frame("StartElement"', entries[1][3])
470
471
472# Test Current* members:
class PositionTest(unittest.TestCase):
    """Check the Current* position attributes as seen from inside handlers."""

    def StartElementHandler(self, name, attrs):
        self.check_pos('s')

    def EndElementHandler(self, name):
        self.check_pos('e')

    def check_pos(self, event):
        """Compare the parser's current position with the next expected one.

        *event* is 's' for a start tag and 'e' for an end tag.  self.upto is
        the index of the next entry of self.expected_list to compare against.
        """
        pos = (event,
               self.parser.CurrentByteIndex,
               self.parser.CurrentLineNumber,
               self.parser.CurrentColumnNumber)
        self.assertTrue(self.upto < len(self.expected_list),
                        'too many parser events')
        expected = self.expected_list[self.upto]
        self.assertEqual(pos, expected,
                'Expected position %s, got position %s' %(expected, pos))
        self.upto += 1

    def test(self):
        self.parser = expat.ParserCreate()
        self.parser.StartElementHandler = self.StartElementHandler
        self.parser.EndElementHandler = self.EndElementHandler
        self.upto = 0
        # (event, byte index, line number, column number) for each callback.
        self.expected_list = [('s', 0, 1, 0), ('s', 5, 2, 1), ('s', 11, 3, 2),
                              ('e', 15, 3, 6), ('e', 17, 4, 1), ('e', 22, 5, 0)]

        xml = b'<a>\n <b>\n  <c/>\n </b>\n</a>'
        self.parser.Parse(xml, 1)
        # check_pos() only guards against surplus events; make sure none of
        # the expected ones went missing either.
        self.assertEqual(self.upto, len(self.expected_list),
                         'too few parser events')
502
503
class sf1296433Test(unittest.TestCase):
    """Regression test for SourceForge issue 1296433."""

    def test_parse_only_xml_data(self):
        # http://python.org/sf/1296433
        #
        xml = "<?xml version='1.0' encoding='iso8859'?><s>%s</s>" % ('a' * 1025)
        # this one doesn't crash
        #xml = "<?xml version='1.0'?><s>%s</s>" % ('a' * 10000)

        class SpecificException(Exception):
            pass

        def handler(text):
            raise SpecificException

        parser = expat.ParserCreate()
        parser.CharacterDataHandler = handler

        # Require the handler's own exception.  Accepting any Exception
        # would also let an unrelated failure pass the test.
        self.assertRaises(SpecificException,
                          parser.Parse, xml.encode('iso8859'))
522
class ChardataBufferTest(unittest.TestCase):
    """
    Tests for the size of the character data buffer (parser.buffer_size).
    """

    def counting_handler(self, text):
        # Character data callback: only counts how often it is invoked.
        self.n += 1

    def small_buffer_test(self, buffer_len):
        """Parse *buffer_len* characters of text using a 1024 buffer.

        Returns the number of times the character data handler was called.
        """
        document = (b"<?xml version='1.0' encoding='iso8859'?><s>"
                    + b'a' * buffer_len
                    + b'</s>')
        parser = expat.ParserCreate()
        parser.CharacterDataHandler = self.counting_handler
        parser.buffer_size = 1024
        parser.buffer_text = 1

        self.n = 0
        parser.Parse(document)
        return self.n

    def test_1025_bytes(self):
        calls = self.small_buffer_test(1025)
        self.assertEqual(calls, 2)

    def test_1000_bytes(self):
        calls = self.small_buffer_test(1000)
        self.assertEqual(calls, 1)

    def test_wrong_size(self):
        parser = expat.ParserCreate()
        parser.buffer_text = 1
        # Negative and zero sizes are expected to raise ValueError.
        for bad_size in (-1, 0):
            with self.assertRaises(ValueError):
                parser.buffer_size = bad_size
        with self.assertRaises((ValueError, OverflowError)):
            parser.buffer_size = sys.maxsize + 1
        with self.assertRaises(TypeError):
            parser.buffer_size = 512.0

    def test_unchanged_size(self):
        head = b"<?xml version='1.0' encoding='iso8859'?><s>" + b'a' * 512
        tail = b'a' * 512 + b'</s>'
        parser = expat.ParserCreate()
        parser.CharacterDataHandler = self.counting_handler
        parser.buffer_size = 512
        parser.buffer_text = 1

        # Feed 512 bytes of character data: the handler should be called
        # once.
        self.n = 0
        parser.Parse(head)
        self.assertEqual(self.n, 1)

        # Reassign to buffer_size, but assign the same size.
        current_size = parser.buffer_size
        parser.buffer_size = current_size
        self.assertEqual(self.n, 1)

        # Try parsing rest of the document
        parser.Parse(tail)
        self.assertEqual(self.n, 2)

    def test_disabling_buffer(self):
        first_chunk = b"<?xml version='1.0' encoding='iso8859'?><a>" + b'a' * 512
        middle_chunk = b'b' * 1024
        last_chunk = b'c' * 1024 + b'</a>'
        parser = expat.ParserCreate()
        parser.CharacterDataHandler = self.counting_handler
        parser.buffer_text = 1
        parser.buffer_size = 1024
        self.assertEqual(parser.buffer_size, 1024)

        # Parse one chunk of XML
        self.n = 0
        parser.Parse(first_chunk, False)
        self.assertEqual(parser.buffer_size, 1024)
        self.assertEqual(self.n, 1)

        # Turn off buffering and parse the next chunk.
        parser.buffer_text = 0
        self.assertFalse(parser.buffer_text)
        self.assertEqual(parser.buffer_size, 1024)
        for _ in range(10):
            parser.Parse(middle_chunk, False)
        self.assertEqual(self.n, 11)

        # Turn buffering back on for the final chunk.
        parser.buffer_text = 1
        self.assertTrue(parser.buffer_text)
        self.assertEqual(parser.buffer_size, 1024)
        parser.Parse(last_chunk, True)
        self.assertEqual(self.n, 12)

    def test_change_size_1(self):
        first_chunk = (b"<?xml version='1.0' encoding='iso8859'?><a><s>"
                       + b'a' * 1024)
        second_chunk = b'aaa</s><s>' + b'a' * 1025 + b'</s></a>'
        parser = expat.ParserCreate()
        parser.CharacterDataHandler = self.counting_handler
        parser.buffer_text = 1
        parser.buffer_size = 1024
        self.assertEqual(parser.buffer_size, 1024)

        self.n = 0
        parser.Parse(first_chunk, False)
        # Double the buffer between the two chunks.
        parser.buffer_size = parser.buffer_size * 2
        self.assertEqual(parser.buffer_size, 2048)
        parser.Parse(second_chunk, True)
        self.assertEqual(self.n, 2)

    def test_change_size_2(self):
        first_chunk = (b"<?xml version='1.0' encoding='iso8859'?><a>a<s>"
                       + b'a' * 1023)
        second_chunk = b'aaa</s><s>' + b'a' * 1025 + b'</s></a>'
        parser = expat.ParserCreate()
        parser.CharacterDataHandler = self.counting_handler
        parser.buffer_text = 1
        parser.buffer_size = 2048
        self.assertEqual(parser.buffer_size, 2048)

        self.n = 0
        parser.Parse(first_chunk, False)
        # Halve the buffer between the two chunks.
        parser.buffer_size = parser.buffer_size // 2
        self.assertEqual(parser.buffer_size, 1024)
        parser.Parse(second_chunk, True)
        self.assertEqual(self.n, 4)
644
class MalformedInputTest(unittest.TestCase):
    """Check the errors reported for input that is not well-formed."""

    def test1(self):
        document = b"\0\r\n"
        parser = expat.ParserCreate()
        with self.assertRaises(expat.ExpatError) as caught:
            parser.Parse(document, True)
        self.assertEqual(str(caught.exception),
                         'unclosed token: line 2, column 0')

    def test2(self):
        # \xc2\x85 is UTF-8 encoded U+0085 (NEXT LINE)
        document = b"<?xml version\xc2\x85='1.0'?>\r\n"
        parser = expat.ParserCreate()
        # The column number is not pinned down, only the rest of the message.
        self.assertRaisesRegex(
            expat.ExpatError,
            r'XML declaration not well-formed: line 1, column \d+',
            parser.Parse, document, True)
662
class ErrorMessageTest(unittest.TestCase):
    """Check the error tables of xml.parsers.expat.errors."""

    def test_codes(self):
        # verify mapping of errors.codes and errors.messages
        syntax_message = errors.XML_ERROR_SYNTAX
        syntax_code = errors.codes[syntax_message]
        self.assertEqual(syntax_message, errors.messages[syntax_code])

    def test_expaterror(self):
        # A lone "<" must be reported with the "unclosed token" error code.
        parser = expat.ParserCreate()
        with self.assertRaises(expat.ExpatError) as caught:
            parser.Parse(b'<', True)
        expected_code = errors.codes[errors.XML_ERROR_UNCLOSED_TOKEN]
        self.assertEqual(caught.exception.code, expected_code)
678
679
class ForeignDTDTests(unittest.TestCase):
    """
    Tests for the UseForeignDTD method of expat parser objects.
    """

    def _collect_entity_refs(self, document, *use_foreign_dtd_args):
        """Parse *document* and return the external entity ids requested.

        UseForeignDTD() is called with *use_foreign_dtd_args* before parsing.
        The result holds one ``(public_id, system_id)`` pair for every call
        made to ExternalEntityRefHandler.
        """
        seen = []

        def resolve_entity(context, base, system_id, public_id):
            seen.append((public_id, system_id))
            return 1

        parser = expat.ParserCreate()
        parser.UseForeignDTD(*use_foreign_dtd_args)
        parser.SetParamEntityParsing(expat.XML_PARAM_ENTITY_PARSING_ALWAYS)
        parser.ExternalEntityRefHandler = resolve_entity
        parser.Parse(document)
        return seen

    def test_use_foreign_dtd(self):
        """
        With UseForeignDTD enabled and no external entity reference in the
        document, ExternalEntityRefHandler is called once with None for both
        the public and the system id.
        """
        document = b"<?xml version='1.0'?><element/>"
        self.assertEqual(self._collect_entity_refs(document, True),
                         [(None, None)])

        # test UseForeignDTD() is equal to UseForeignDTD(True)
        self.assertEqual(self._collect_entity_refs(document),
                         [(None, None)])

    def test_ignore_use_foreign_dtd(self):
        """
        With UseForeignDTD enabled and a document that has its own external
        entity reference, ExternalEntityRefHandler is called with the public
        and system ids taken from the document.
        """
        document = (
            b"<?xml version='1.0'?><!DOCTYPE foo PUBLIC 'bar' 'baz'><element/>")
        self.assertEqual(self._collect_entity_refs(document, True),
                         [("bar", "baz")])
730
731
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
734