1#!/usr/bin/env python
2# encoding: utf-8
3#
4# spyne - Copyright (C) Spyne contributors.
5#
6# This library is free software; you can redistribute it and/or
7# modify it under the terms of the GNU Lesser General Public
8# License as published by the Free Software Foundation; either
9# version 2.1 of the License, or (at your option) any later version.
10#
11# This library is distributed in the hope that it will be useful,
12# but WITHOUT ANY WARRANTY; without even the implied warranty of
13# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14# Lesser General Public License for more details.
15#
16# You should have received a copy of the GNU Lesser General Public
17# License along with this library; if not, write to the Free Software
18# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301
19#
20
21from __future__ import print_function
22
23import logging
24logging.basicConfig(level=logging.DEBUG)
25
26import sys
27import unittest
28import decimal
29import datetime
30
31from pprint import pprint
32from base64 import b64encode
33
34from lxml import etree
35from lxml.builder import E
36
37from spyne import MethodContext, rpc, ByteArray, File, AnyXml
38from spyne.context import FakeContext
39from spyne.const import RESULT_SUFFIX
40from spyne.service import Service
41from spyne.server import ServerBase
42from spyne.application import Application
43from spyne.decorator import srpc
44from spyne.util.six import BytesIO
45from spyne.model import Fault, Integer, Decimal, Unicode, Date, DateTime, \
46    XmlData, Array, ComplexModel, XmlAttribute, Mandatory as M
47from spyne.protocol.xml import XmlDocument, SchemaValidationError
48
49from spyne.util import six
50from spyne.util.xml import get_xml_as_object, get_object_as_xml, \
51    get_object_as_xml_polymorphic, get_xml_as_object_polymorphic
52from spyne.server.wsgi import WsgiApplication
53from spyne.const.xml import NS_XSI
54
55
class TestXml(unittest.TestCase):
    """Tests for spyne's ``XmlDocument`` protocol and the ``spyne.util.xml``
    (de)serialization helpers."""

    def test_empty_string(self):
        """An empty ``<b/>`` element deserializes to ``''`` for a Unicode field."""
        class a(ComplexModel):
            b = Unicode

        elt = etree.fromstring('<a><b/></a>')
        o = get_xml_as_object(elt, a)

        assert o.b == ''

    def test_xml_data(self):
        """XmlData text and XmlAttribute values survive a full server round trip."""
        class C(ComplexModel):
            a = XmlData(Unicode)
            b = XmlAttribute(Unicode)

        class SomeService(Service):
            @srpc(C, _returns=C)
            def some_call(c):
                assert c.a == 'a'
                assert c.b == 'b'
                return c

        app = Application([SomeService], "tns", name="test_xml_data",
                        in_protocol=XmlDocument(), out_protocol=XmlDocument())
        server = ServerBase(app)
        ctx = self._get_ctx(server, [
            b'<some_call xmlns="tns">'
                b'<c b="b">a</c>'
            b'</some_call>'
        ])
        server.get_out_object(ctx)
        server.get_out_string(ctx)

        print(ctx.out_string)
        pprint(app.interface.nsmap)

        ret = etree.fromstring(b''.join(ctx.out_string)).xpath(
            '//tns:some_call' + RESULT_SUFFIX, namespaces=app.interface.nsmap)[0]

        print(etree.tostring(ret, pretty_print=True))

        assert ret.text == "a"
        assert ret.attrib['b'] == "b"

    def test_wrapped_array(self):
        """An ``Array`` is serialized as a wrapper element around its items."""
        parent = etree.Element('parent')
        val = ['a', 'b']
        cls = Array(Unicode, namespace='tns')
        XmlDocument().to_parent(None, cls, val, parent, 'tns')
        print(etree.tostring(parent, pretty_print=True))
        xpath = parent.xpath('//x:stringArray/x:string/text()',
                                                        namespaces={'x': 'tns'})
        assert xpath == val

    def test_simple_array(self):
        """A ``max_occurs='unbounded'`` field is serialized as repeated
        sibling elements without a wrapper."""
        class cls(ComplexModel):
            __namespace__ = 'tns'
            s = Unicode(max_occurs='unbounded')
        val = cls(s=['a', 'b'])

        parent = etree.Element('parent')
        XmlDocument().to_parent(None, cls, val, parent, 'tns')
        print(etree.tostring(parent, pretty_print=True))
        xpath = parent.xpath('//x:cls/x:s/text()', namespaces={'x': 'tns'})
        assert xpath == val.s

    def test_decimal(self):
        """A Decimal with a large exponent round-trips with its exact type and
        textual representation preserved."""
        d = decimal.Decimal('1e100')

        class SomeService(Service):
            @srpc(Decimal(120, 4), _returns=Decimal)
            def some_call(p):
                print(p)
                print(type(p))
                # exact type check on purpose: a subclass or float would be wrong
                assert type(p) is decimal.Decimal
                assert d == p
                return p

        app = Application([SomeService], "tns", in_protocol=XmlDocument(),
                                                out_protocol=XmlDocument())
        server = ServerBase(app)
        ctx = self._get_ctx(server, [
            b'<some_call xmlns="tns"><p>',
            str(d).encode('ascii'),
            b'</p></some_call>'
        ])
        server.get_out_object(ctx)
        server.get_out_string(ctx)

        elt = etree.fromstring(b''.join(ctx.out_string))

        print(etree.tostring(elt, pretty_print=True))
        target = elt.xpath('//tns:some_callResult/text()',
                                              namespaces=app.interface.nsmap)[0]
        assert target == str(d)

    def test_subs(self):
        """``sub_name`` / ``sub_ns`` rename and re-namespace child elements,
        and deserialization honours the same mapping."""
        m = {
            "s0": "aa",
            "s2": "cc",
            "s3": "dd",
        }

        class C(ComplexModel):
            __namespace__ = "aa"
            a = Integer
            b = Integer(sub_name="bb")
            c = Integer(sub_ns="cc")
            d = Integer(sub_ns="dd", sub_name="dd")

        elt = get_object_as_xml(C(a=1, b=2, c=3, d=4), C)
        print(etree.tostring(elt, pretty_print=True))

        assert elt.xpath("s0:a/text()",  namespaces=m) == ["1"]
        assert elt.xpath("s0:bb/text()", namespaces=m) == ["2"]
        assert elt.xpath("s2:c/text()",  namespaces=m) == ["3"]
        assert elt.xpath("s3:dd/text()", namespaces=m) == ["4"]

        c = get_xml_as_object(elt, C)
        print(c)
        assert c.a == 1
        assert c.b == 2
        assert c.c == 3
        assert c.d == 4

    def test_sub_attributes(self):
        """``sub_name`` / ``sub_ns`` also apply to XmlAttribute fields, in both
        directions."""
        m = {
            "s0": "aa",
            "s2": "cc",
            "s3": "dd",
        }

        class C(ComplexModel):
            __namespace__ = "aa"
            a = XmlAttribute(Integer)
            b = XmlAttribute(Integer(sub_name="bb"))
            c = XmlAttribute(Integer(sub_ns="cc"))
            d = XmlAttribute(Integer(sub_ns="dd", sub_name="dd"))

        elt = get_object_as_xml(C(a=1, b=2, c=3, d=4), C)
        print(etree.tostring(elt, pretty_print=True))

        assert elt.xpath("//*/@a")  == ["1"]
        assert elt.xpath("//*/@bb") == ["2"]
        assert elt.xpath("//*/@s2:c", namespaces=m)  == ["3"]
        assert elt.xpath("//*/@s3:dd", namespaces=m) == ["4"]

        c = get_xml_as_object(elt, C)
        print(c)
        assert c.a == 1
        assert c.b == 2
        assert c.c == 3
        assert c.d == 4

    def test_dates(self):
        """Dates with and without a timezone suffix all parse to the same
        calendar date."""
        d = Date
        xml_dates = [
            etree.fromstring(b'<d>2013-04-05</d>'),
            etree.fromstring(b'<d>2013-04-05+02:00</d>'),
            etree.fromstring(b'<d>2013-04-05-02:00</d>'),
            etree.fromstring(b'<d>2013-04-05Z</d>'),
        ]

        for xml_date in xml_dates:
            c = get_xml_as_object(xml_date, d)
            assert isinstance(c, datetime.date)
            assert c.year == 2013
            assert c.month == 4
            assert c.day == 5

    def test_datetime_usec(self):
        """Fractional seconds beyond six digits are rounded to microseconds."""
        fs = etree.fromstring
        d = get_xml_as_object(fs('<d>2013-04-05T06:07:08.123456</d>'), DateTime)
        assert d.microsecond == 123456

        # rounds up
        d = get_xml_as_object(fs('<d>2013-04-05T06:07:08.1234567</d>'), DateTime)
        assert d.microsecond == 123457

        # rounds down
        d = get_xml_as_object(fs('<d>2013-04-05T06:07:08.1234564</d>'), DateTime)
        assert d.microsecond == 123456

        # exact half: the result differs between Python versions
        d = get_xml_as_object(fs('<d>2013-04-05T06:07:08.1234565</d>'), DateTime)
        # NOTE(review): likely because Python 3's round() rounds half to even
        # (123456.5 -> 123456) while Python 2 rounds half away from zero
        # (-> 123457); float representation of the fraction may also play a
        # part. TODO confirm against the DateTime parser.
        if not six.PY2:
            assert d.microsecond == 123456
        else:
            assert d.microsecond == 123457

    def _get_ctx(self, server, in_string):
        """Build a server-side context for ``in_string`` and deserialize it.

        :param server: the ``ServerBase`` to run the request through.
        :param in_string: iterable of byte chunks forming the request body.
        :returns: the single generated context, after ``get_in_object``.
        """
        initial_ctx = MethodContext(server, MethodContext.SERVER)
        initial_ctx.in_string = in_string
        ctx, = server.generate_contexts(initial_ctx)
        server.get_in_object(ctx)
        return ctx

    def test_mandatory_elements(self):
        """With the lxml validator, a missing mandatory argument is rejected."""
        class SomeService(Service):
            @srpc(M(Unicode), _returns=Unicode)
            def some_call(s):
                assert s == 'hello'
                return s

        app = Application([SomeService], "tns", name="test_mandatory_elements",
                          in_protocol=XmlDocument(validator='lxml'),
                          out_protocol=XmlDocument())
        server = ServerBase(app)

        # Valid call with all mandatory elements in
        ctx = self._get_ctx(server, [
            b'<some_call xmlns="tns">'
                b'<s>hello</s>'
            b'</some_call>'
        ])
        server.get_out_object(ctx)
        server.get_out_string(ctx)
        ret = etree.fromstring(b''.join(ctx.out_string)).xpath(
            '//tns:some_call%s/text()' % RESULT_SUFFIX,
            namespaces=app.interface.nsmap)[0]
        assert ret == 'hello'

        # Invalid call
        ctx = self._get_ctx(server, [
            b'<some_call xmlns="tns">'
                # no mandatory elements here...
            b'</some_call>'
        ])
        self.assertRaises(SchemaValidationError, server.get_out_object, ctx)

    def test_unicode_chars_in_exception(self):
        """Non-ASCII input that fails validation is echoed back in the fault
        response without being mangled."""
        class SomeService(Service):
            @srpc(Unicode(pattern=u'x'), _returns=Unicode)
            def some_call(s):
                # validation must reject the request before we get here
                raise AssertionError("should never reach here")

        app = Application([SomeService], "tns",
                          name="test_unicode_chars_in_exception",
                          in_protocol=XmlDocument(validator='lxml'),
                          out_protocol=XmlDocument())
        server = WsgiApplication(app)

        req = (
            u'<some_call xmlns="tns">'
                u'<s>Ğ</s>'
            u'</some_call>'
        ).encode('utf8')

        resp = server({
            'QUERY_STRING': '',
            'PATH_INFO': '/',
            'REQUEST_METHOD': 'POST',
            'SERVER_NAME': 'localhost',
            'SERVER_PORT': '80',
            'wsgi.input': BytesIO(req),
            "wsgi.url_scheme": 'http',
        }, lambda x, y: print(x,y))

        assert u'Ğ'.encode('utf8') in b''.join(resp)

    def test_mandatory_subelements(self):
        """A mandatory child element of a complex argument is enforced, as is
        the presence of the complex argument itself."""
        class C(ComplexModel):
            foo = M(Unicode)

        class SomeService(Service):
            @srpc(C.customize(min_occurs=1), _returns=Unicode)
            def some_call(c):
                assert c is not None
                assert c.foo == 'hello'
                return c.foo

        app = Application(
            [SomeService], "tns", name="test_mandatory_subelements",
            in_protocol=XmlDocument(validator='lxml'),
            out_protocol=XmlDocument())
        server = ServerBase(app)

        ctx = self._get_ctx(server, [
            b'<some_call xmlns="tns">'
                # no mandatory elements at all...
            b'</some_call>'
        ])
        self.assertRaises(SchemaValidationError, server.get_out_object, ctx)

        ctx = self._get_ctx(server, [
            b'<some_call xmlns="tns">'
                b'<c>'
                    # no mandatory elements here...
                b'</c>'
            b'</some_call>'
        ])
        self.assertRaises(SchemaValidationError, server.get_out_object, ctx)

    def test_mandatory_element_attributes(self):
        """A mandatory XmlAttribute of a complex argument is enforced, as is
        the presence of the complex argument itself."""
        class C(ComplexModel):
            bar = XmlAttribute(M(Unicode))

        class SomeService(Service):
            @srpc(C.customize(min_occurs=1), _returns=Unicode)
            def some_call(c):
                # C's only field is the ``bar`` attribute
                assert c is not None
                assert c.bar == 'hello'
                return c.bar

        app = Application(
            [SomeService], "tns", name="test_mandatory_element_attributes",
            in_protocol=XmlDocument(validator='lxml'),
            out_protocol=XmlDocument())
        server = ServerBase(app)

        ctx = self._get_ctx(server, [
            b'<some_call xmlns="tns">'
                # no mandatory elements at all...
            b'</some_call>'
        ])
        self.assertRaises(SchemaValidationError, server.get_out_object, ctx)

        ctx = self._get_ctx(server, [
            b'<some_call xmlns="tns">'
                b'<c>'
                    # no mandatory attribute here...
                b'</c>'
            b'</some_call>'
        ])
        self.assertRaises(SchemaValidationError, server.get_out_object, ctx)

    def test_bare_sub_name_ns(self):
        """Class-level ``sub_ns`` / ``sub_name`` set the root element's
        namespace and tag."""
        class Action(ComplexModel):
            class Attributes(ComplexModel.Attributes):
                sub_ns = "SOME_NS"
                sub_name = "Action"
            data = XmlData(Unicode)
            must_understand = XmlAttribute(Unicode)

        elt = get_object_as_xml(Action("x", must_understand="y"), Action)
        eltstr = etree.tostring(elt)
        print(eltstr)
        assert eltstr == b'<ns0:Action xmlns:ns0="SOME_NS" must_understand="y">x</ns0:Action>'

    def test_null_mandatory_attribute(self):
        """Serializing a None mandatory attribute omits it rather than failing."""
        class Action (ComplexModel):
            data = XmlAttribute(M(Unicode))

        elt = get_object_as_xml(Action(), Action)
        eltstr = etree.tostring(elt)
        print(eltstr)
        assert eltstr == b'<Action/>'

    def test_bytearray(self):
        """ByteArray values are serialized as base64 text."""
        v = b'aaaa'
        elt = get_object_as_xml([v], ByteArray, 'B')
        eltstr = etree.tostring(elt)
        print(eltstr)
        assert elt.text == b64encode(v).decode('ascii')

    def test_any_xml_text(self):
        """AnyXml given as a unicode string is embedded as parsed XML."""
        v = u"<roots><bloody/></roots>"
        elt = get_object_as_xml(v, AnyXml, 'B', no_namespace=True)
        eltstr = etree.tostring(elt)
        print(eltstr)
        assert etree.tostring(elt[0], encoding="unicode") == v

    def test_any_xml_bytes(self):
        """AnyXml given as bytes is embedded as parsed XML."""
        v = b"<roots><bloody/></roots>"

        elt = get_object_as_xml(v, AnyXml, 'B', no_namespace=True)
        eltstr = etree.tostring(elt)
        print(eltstr)
        assert etree.tostring(elt[0]) == v

    def test_any_xml_elt(self):
        """AnyXml given as an lxml element is embedded unchanged."""
        v = E.roots(E.bloody(E.roots()))
        elt = get_object_as_xml(v, AnyXml, 'B')
        eltstr = etree.tostring(elt)
        print(eltstr)
        assert etree.tostring(elt[0]) == etree.tostring(v)

    def test_file(self):
        """File values are read from their handle and serialized as base64."""
        v = b'aaaa'
        f = BytesIO(v)
        elt = get_object_as_xml(File.Value(handle=f), File, 'B')
        eltstr = etree.tostring(elt)
        print(eltstr)
        assert elt.text == b64encode(v).decode('ascii')

    def test_fault_detail_as_dict(self):
        """A dict given as ``Fault.detail`` is serialized as child elements."""
        elt = get_object_as_xml(Fault(detail={"this": "that"}), Fault)
        eltstr = etree.tostring(elt)
        print(eltstr)
        assert b'<detail><this>that</this></detail>' in eltstr

    def test_xml_encoding(self):
        """The ``encoding`` argument controls the output byte encoding."""
        ctx = FakeContext(out_document=E.rain(u"yağmur"))
        XmlDocument(encoding='iso-8859-9').create_out_string(ctx)
        s = b''.join(ctx.out_string)
        assert u"ğ".encode('iso-8859-9') in s

    def test_default(self):
        """Field defaults are applied to absent and (optionally) nil elements."""
        class SomeComplexModel(ComplexModel):
            _type_info = [
                ('a', Unicode),
                ('b', Unicode(default='default')),
            ]

        obj = XmlDocument().from_element(
            None, SomeComplexModel,
            etree.fromstring("""
                <hey>
                    <a>string</a>
                </hey>
            """)
        )

        # xml schema says an absent element should be None, but spyne fills
        # in the field's default
        assert obj.b == 'default'

        obj = XmlDocument().from_element(
            None, SomeComplexModel,
            etree.fromstring("""
                <hey>
                    <a>string</a>
                    <b xsi:nil="true" xmlns:xsi="%s"/>
                </hey>
            """ % NS_XSI)
        )

        # by default, an explicit xsi:nil is also replaced with the default
        assert obj.b == 'default'

        obj = XmlDocument(replace_null_with_default=False).from_element(
            None, SomeComplexModel,
            etree.fromstring("""
                <hey>
                    <a>string</a>
                    <b xsi:nil="true" xmlns:xsi="%s"/>
                </hey>
            """ % NS_XSI)
        )

        # with replace_null_with_default=False the explicit nil stays None
        assert obj.b is None

    def test_polymorphic_roundtrip(self):
        """A subclass instance stored in a base-class field survives a
        serialize -> text -> parse -> deserialize round trip."""

        class B(ComplexModel):
            __namespace__ = 'some_ns'
            _type_info = {
                '_b': Unicode,
            }

            def __init__(self):
                super(B, self).__init__()
                self._b = "b"

        class C(B):
            __namespace__ = 'some_ns'
            _type_info = {
                '_c': Unicode,
            }

            def __init__(self):
                super(C, self).__init__()
                self._c = "c"

        class A(ComplexModel):
            __namespace__ = 'some_ns'
            _type_info = {
                '_a': Unicode,
                '_b': B,
            }

            def __init__(self, b=None):
                super(A, self).__init__()
                self._a = 'a'
                self._b = b

        a = A(b=C())
        elt = get_object_as_xml_polymorphic(a, A)
        xml_string = etree.tostring(elt, pretty_print=True)
        if six.PY2:
            print(xml_string, end="")
        else:
            sys.stdout.buffer.write(xml_string)

        # deserialize from the re-parsed text, not the in-memory element, so
        # the round trip really goes through the serialized form
        parsed = etree.fromstring(xml_string)
        new_a = get_xml_as_object_polymorphic(parsed, A)

        assert new_a._a == a._a, (a._a, new_a._a)
        assert new_a._b._b == a._b._b, (a._b._b, new_a._b._b)
        assert new_a._b._c == a._b._c, (a._b._c, new_a._b._c)
565
566
class TestIncremental(unittest.TestCase):
    """Tests for ``XmlDocument.serialize`` writing a response directly to
    ``ctx.out_stream``."""

    def test_one(self):
        """A single complex return value is serialized inside ``getResult``."""
        class SomeComplexModel(ComplexModel):
            s = Unicode
            i = Integer

        # the plain model instance; previously a trailing comma silently
        # turned this into a 1-tuple
        v = SomeComplexModel(s='a', i=1)

        class SomeService(Service):
            @rpc(_returns=SomeComplexModel)
            def get(ctx):
                return v

        desc = SomeService.public_methods['get']
        # out_object is a sequence of return values, same shape as in test_many
        ctx = FakeContext(out_object=[v], descriptor=desc)
        ostr = ctx.out_stream = BytesIO()
        XmlDocument(Application([SomeService], __name__)) \
                             .serialize(ctx, XmlDocument.RESPONSE)

        elt = etree.fromstring(ostr.getvalue())
        print(etree.tostring(elt, pretty_print=True))

        assert elt.xpath('x:getResult/x:i/text()',
                                            namespaces={'x': __name__}) == ['1']
        assert elt.xpath('x:getResult/x:s/text()',
                                            namespaces={'x': __name__}) == ['a']

    def test_many(self):
        """An array return value is serialized as repeated items inside
        ``getResult``, in order."""
        class SomeComplexModel(ComplexModel):
            s = Unicode
            i = Integer

        v = [
            SomeComplexModel(s='a', i=1),
            SomeComplexModel(s='b', i=2),
            SomeComplexModel(s='c', i=3),
            SomeComplexModel(s='d', i=4),
            SomeComplexModel(s='e', i=5),
        ]

        class SomeService(Service):
            @rpc(_returns=Array(SomeComplexModel))
            def get(ctx):
                return v

        desc = SomeService.public_methods['get']
        ctx = FakeContext(out_object=[v], descriptor=desc)
        ostr = ctx.out_stream = BytesIO()
        XmlDocument(Application([SomeService], __name__)) \
                            .serialize(ctx, XmlDocument.RESPONSE)

        elt = etree.fromstring(ostr.getvalue())
        print(etree.tostring(elt, pretty_print=True))

        assert elt.xpath('x:getResult/x:SomeComplexModel/x:i/text()',
                        namespaces={'x': __name__}) == ['1', '2', '3', '4', '5']
        assert elt.xpath('x:getResult/x:SomeComplexModel/x:s/text()',
                        namespaces={'x': __name__}) == ['a', 'b', 'c', 'd', 'e']
625
626
if "__main__" == __name__:
    # Running this file directly executes every TestCase defined above.
    unittest.main()
629