1#!/usr/bin/env python
2
3from __future__ import print_function
4
5import json
6import decimal
7import unittest
8
9import pytz
10import sqlalchemy
11
12from pprint import pprint
13from decimal import Decimal as D
14from datetime import datetime
15
16from lxml import etree
17
18from spyne.const import MAX_STRING_FIELD_LENGTH
19
20from spyne.decorator import srpc
21from spyne.application import Application
22
23from spyne.model.complex import XmlAttribute, TypeInfo
24from spyne.model.complex import ComplexModel
25from spyne.model.complex import Iterable
26from spyne.model.complex import Array
27from spyne.model.primitive import Decimal
28from spyne.model.primitive import DateTime
29from spyne.model.primitive import Integer
30from spyne.model.primitive import Unicode
31
32from spyne.service import Service
33
34from spyne.util import AttrDict, AttrDictColl, get_version
35from spyne.util import memoize, memoize_ignore_none, memoize_ignore, memoize_id
36
37from spyne.util.protocol import deserialize_request_string
38
39from spyne.util.dictdoc import get_dict_as_object, get_object_as_yaml, \
40    get_object_as_json
41from spyne.util.dictdoc import get_object_as_dict
42from spyne.util.tdict import tdict
43from spyne.util.tlist import tlist
44
45from spyne.util.xml import get_object_as_xml
46from spyne.util.xml import get_xml_as_object
47from spyne.util.xml import get_schema_documents
48from spyne.util.xml import get_validation_schema
49
50
class TestUtil(unittest.TestCase):
    """Checks for the ``get_version`` helper."""

    def test_version(self):
        """``get_version`` gives the same answer for a package name and for
        the imported module, and its components joined with dots equal
        ``sqlalchemy.__version__``."""
        by_name = get_version('sqlalchemy')
        by_module = get_version(sqlalchemy)
        assert by_name == by_module

        dotted = '.'.join(map(str, get_version('sqlalchemy')))
        assert dotted == sqlalchemy.__version__
56
57
class TestTypeInfo(unittest.TestCase):
    """Exercises ``TypeInfo`` lookups by key and by integer position."""

    @staticmethod
    def _assert_entry(info, position, key, value):
        """Assert that ``info[position]`` and ``info[key]`` both equal
        ``value``."""
        assert info[position] == info[key] == value

    def test_insert(self):
        """Inserting a new key at position 0 moves the existing entry to 1."""
        info = TypeInfo()

        info['a'] = 1
        self._assert_entry(info, 0, 'a', 1)

        info.insert(0, ('b', 2))

        self._assert_entry(info, 1, 'a', 1)
        self._assert_entry(info, 0, 'b', 2)

    def test_insert_existing(self):
        """Inserting an existing key at position 0 moves it there and
        replaces its value."""
        info = TypeInfo()

        info["a"] = 1
        info["b"] = 2
        self._assert_entry(info, 1, 'b', 2)

        info.insert(0, ('b', 3))
        self._assert_entry(info, 1, 'a', 1)
        self._assert_entry(info, 0, 'b', 3)

    def test_update(self):
        """``update`` adds the new key after the existing one."""
        info = TypeInfo()
        info["a"] = 1
        info.update([('b', 2)])
        self._assert_entry(info, 0, 'a', 1)
        self._assert_entry(info, 1, 'b', 2)
87
88
class TestXml(unittest.TestCase):
    """Conversions between ComplexModel instances and lxml elements."""

    def test_serialize(self):
        """The root tag uses the class name or the explicit root name, and
        carries the namespace unless ``no_namespace`` is set."""

        class C(ComplexModel):
            __namespace__ = "tns"
            i = Integer
            s = Unicode

        inst = C(i=5, s="x")

        # (extra positional arguments, keyword arguments, expected root tag)
        cases = (
            ((), {}, "{tns}C"),
            (("X",), {}, "{tns}X"),
            (("X",), {'no_namespace': True}, "X"),
            ((), {'no_namespace': True}, "C"),
        )

        for extra, options, expected_tag in cases:
            elt = get_object_as_xml(inst, C, *extra, **options)
            print(etree.tostring(elt))
            assert elt.tag == expected_tag

    def test_deserialize(self):
        """Schema documents are generated per namespace, and an object
        serialized to xml reads back with equal field values."""

        class Punk(ComplexModel):
            __namespace__ = 'some_namespace'

            a = Unicode
            b = Integer
            c = Decimal
            d = DateTime

        class Foo(ComplexModel):
            __namespace__ = 'some_other_namespace'

            a = Unicode
            b = Integer
            c = Decimal
            d = DateTime
            e = XmlAttribute(Integer)

            def __eq__(self, other):
                # Test-only comparison: stop at the first field that differs
                # by raising AssertionError instead of returning False.
                for field in ('a', 'b', 'c', 'd', 'e'):
                    assert getattr(self, field) == getattr(other, field)

                return True

        schemas = get_schema_documents([Punk, Foo])
        pprint(schemas)

        xsd_schema_tag = '{http://www.w3.org/2001/XMLSchema}schema'
        assert schemas['s0'].tag == xsd_schema_tag
        assert schemas['tns'].tag == xsd_schema_tag
        print()

        # 'tns' is expected to hold Punk's namespace, 's0' Foo's.
        expected_namespaces = (
            ('tns', 'some_namespace'),
            ('s0', 'some_other_namespace'),
        )
        for prefix, namespace in expected_namespaces:
            schema = schemas[prefix]
            print("the other namespace %r:" % schema.attrib['targetNamespace'])
            assert schema.attrib['targetNamespace'] == namespace
            print(etree.tostring(schema, pretty_print=True))
            print()

        when = datetime(2011, 2, 20, tzinfo=pytz.utc)
        foo = Foo(a=u'a', b=1, c=decimal.Decimal('3.4'), d=when, e=5)

        as_xml = get_object_as_xml(foo, Foo)
        print(etree.tostring(as_xml, pretty_print=True))
        foo_back = get_xml_as_object(as_xml, Foo)

        assert foo_back == foo

        # Building the validation schema only has to complete without raising.
        get_validation_schema([Punk, Foo])
171
172
class TestCDict(unittest.TestCase):
    """Checks class-keyed lookups in ``cdict``."""

    def test_cdict(self):
        """A class that is not a key itself resolves to the value stored for
        one of its base classes; an unrelated class raises KeyError."""
        from spyne.util.cdict import cdict

        class A(object):
            pass

        class B(A):
            pass

        class E(B):
            pass

        class F(E):
            pass

        class C(object):
            pass

        d = cdict({A: "fun", F: 'zan'})

        assert d[A] == 'fun'
        # B is not a key; it subclasses A.
        assert d[B] == 'fun'
        # F has its own entry even though A is among its bases.
        assert d[F] == 'zan'

        # C shares no registered base class, so the lookup has to fail.
        # assertRaises reports a missing KeyError as a test failure instead of
        # an error caused by a generic Exception.
        with self.assertRaises(KeyError):
            d[C]
203
204
class TestTDict(unittest.TestCase):
    """Checks the optional key and value type enforcement of ``tdict``."""

    def test_tdict_notype(self):
        """Without type arguments, item assignment, ``update`` and
        ``fromkeys`` all store the value as given."""
        d = tdict()
        d[0] = 1
        assert d[0] == 1

        d = tdict()
        d.update({0: 1})
        assert d[0] == 1

        d = tdict.fromkeys([0], 1)
        assert d[0] == 1

    def test_tdict_k(self):
        """``tdict(str)`` raises TypeError for an int key and accepts a str
        key."""
        d = tdict(str)
        # assertRaises turns a missing TypeError into a regular test failure.
        with self.assertRaises(TypeError):
            d[0] = 1

        d = tdict(str)
        d['s'] = 1
        assert d['s'] == 1

    def test_tdict_v(self):
        """``tdict(vt=str)`` raises TypeError for an int value and accepts a
        str value."""
        d = tdict(vt=str)
        with self.assertRaises(TypeError):
            d[0] = 1

        d = tdict(vt=str)
        d[0] = 's'
        assert d[0] == 's'
243
244
class TestLogRepr(unittest.TestCase):
    """Checks the abbreviated representations produced by ``log_repr``."""

    @staticmethod
    def _assert_reprs(type_, cases):
        """Assert ``log_repr(value, type_) == expected`` for every
        ``(value, expected)`` pair in ``cases``, in order."""
        from spyne.util.web import log_repr

        for value, expected in cases:
            assert log_repr(value, type_) == expected

    def test_log_repr_simple(self):
        """A string longer than MAX_STRING_FIELD_LENGTH is cut at that length
        and a three-item array shows two items followed by ``(...)``."""
        from spyne.model.complex import ComplexModel
        from spyne.model.primitive import String
        from spyne.util.web import log_repr

        class Z(ComplexModel):
            z = String

        overlong = "a" * (MAX_STRING_FIELD_LENGTH + 100)
        expected = "Z(z='%s'(...))" % ('a' * MAX_STRING_FIELD_LENGTH)

        print(log_repr(Z(z=overlong)))
        print(expected)

        assert log_repr(Z(z=overlong)) == expected
        assert log_repr(['a', 'b', 'c'], Array(String)) == "['a', 'b', (...)]"

    def test_log_repr_complex(self):
        """Fields declared with ``logged=False`` do not appear in the
        output."""
        from spyne.model import ByteArray, File
        from spyne.model.complex import ComplexModel
        from spyne.model.primitive import String
        from spyne.util.web import log_repr

        class Z(ComplexModel):
            _type_info = [
                ('f', File(logged=False)),
                ('t', ByteArray(logged=False)),
                ('z', Array(String)),
            ]

        item_count = MAX_STRING_FIELD_LENGTH + 100
        val = Z(z=["abc"] * item_count, t=['t'],
                f=File.Value(name='aaa', data=['t']))
        print(repr(val))

        assert log_repr(val) == "Z(z=['abc', 'abc', (...)])"

    def test_log_repr_dict_vanilla(self):
        """Plain ``AnyDict``: dicts and lists show at most two entries."""
        from spyne.model import AnyDict

        self._assert_reprs(AnyDict, (
            ({1: 1}, "{1: 1}"),
            ({1: 1, 2: 2}, "{1: 1, 2: 2}"),
            ({1: 1, 2: 2, 3: 3}, "{1: 1, 2: 2, (...)}"),
            ([1], "[1]"),
            ([1, 2], "[1, 2]"),
            ([1, 2, 3], "[1, 2, (...)]"),
        ))

    def test_log_repr_dict_keys(self):
        """``logged='keys'``: dict values are replaced by ``(...)``."""
        from spyne.model import AnyDict

        self._assert_reprs(AnyDict(logged='keys'), (
            ({1: 1}, "{1: (...)}"),
            ([1], "[1]"),
        ))

    def test_log_repr_dict_values(self):
        """``logged='values'``: dict keys are replaced by ``(...)``."""
        from spyne.model import AnyDict

        self._assert_reprs(AnyDict(logged='values'), (
            ({1: 1}, "{(...): 1}"),
            ([1], "[1]"),
        ))

    def test_log_repr_dict_full(self):
        """``logged='full'``: all three entries are shown."""
        from spyne.model import AnyDict

        self._assert_reprs(AnyDict(logged='full'), (
            ({1: 1, 2: 2, 3: 3}, "{1: 1, 2: 2, 3: 3}"),
            ([1, 2, 3], "[1, 2, 3]"),
        ))

    def test_log_repr_dict_keys_full(self):
        """``logged='keys-full'``: every key is shown, every value is
        replaced by ``(...)``."""
        from spyne.model import AnyDict

        self._assert_reprs(AnyDict(logged='keys-full'), (
            ({1: 1, 2: 2, 3: 3}, "{1: (...), 2: (...), 3: (...)}"),
            ([1, 2, 3], "[1, 2, 3]"),
        ))

    def test_log_repr_dict_values_full(self):
        """``logged='values-full'``: every value is shown, every key is
        replaced by ``(...)``."""
        from spyne.model import AnyDict

        self._assert_reprs(AnyDict(logged='values-full'), (
            ({1: 1, 2: 2, 3: 3}, "{(...): 1, (...): 2, (...): 3}"),
            ([1, 2, 3], "[1, 2, 3]"),
        ))
341
342
class TestDeserialize(unittest.TestCase):
    """Turning a raw SOAP request string into a request object."""

    def test_deserialize(self):
        """The ``yo`` argument in the envelope comes back as the integer
        that was formatted into it."""
        from spyne.protocol.soap import Soap11

        class SomeService(Service):
            @srpc(Integer, _returns=Iterable(Integer))
            def some_call(yo):
                return range(yo)

        soap_app = Application([SomeService], 'tns',
                               in_protocol=Soap11(), out_protocol=Soap11())

        meat = 30

        envelope_template = """
            <x:Envelope xmlns:x="http://schemas.xmlsoap.org/soap/envelope/">
                <x:Body>
                    <tns:some_call xmlns:tns="tns">
                        <tns:yo>%s</tns:yo>
                    </tns:some_call>
                </x:Body>
            </x:Envelope>
        """
        request = envelope_template % meat

        parsed = deserialize_request_string(request, soap_app)

        assert parsed.yo == meat
370
371
class TestEtreeDict(unittest.TestCase):
    """Checks the dict to etree conversion done by ``root_dict_to_etree``."""

    longMessage = True

    def test_simple(self):
        """A nested dict of strings becomes nested elements."""
        from lxml.etree import tostring
        from spyne.util.etreeconv import root_dict_to_etree

        rendered = tostring(root_dict_to_etree({'a': {'b': 'c'}}))
        assert rendered == b'<a><b>c</b></a>'

    def test_not_sized(self):
        """Integers, None and strings are rendered both as a direct root value
        and one level down."""
        from lxml.etree import tostring
        from spyne.util.etreeconv import root_dict_to_etree

        int_msg = "The integer should be properly rendered in the etree"
        none_msg = "None should not be rendered in the etree"
        str_msg = "A string should be rendered as a string"

        # (input dict, expected serialization, failure message)
        cases = (
            ({'a': {'b': 1}}, b'<a><b>1</b></a>', int_msg),
            ({'a': {'b': None}}, b'<a><b/></a>', none_msg),
            ({'a': 1}, b'<a>1</a>', int_msg),
            ({'a': None}, b'<a/>', none_msg),
            ({'a': 'lol'}, b'<a>lol</a>', str_msg),
            ({'a': {'b': 'lol'}}, b'<a><b>lol</b></a>', str_msg),
        )

        for source, expected, message in cases:
            tree = root_dict_to_etree(source)
            self.assertEqual(tostring(tree), expected, message)
408
409
class TestDictDoc(unittest.TestCase):
    """Round trip between a ComplexModel instance and its dict document."""

    def test_the(self):
        """An object converted with ``get_object_as_dict`` and read back with
        ``get_dict_as_object`` compares equal, for both ``complex_as``
        containers."""

        class C(ComplexModel):
            __namespace__ = "tns"
            i = Integer
            s = Unicode
            a = Array(DateTime)

            def __eq__(self, other):
                print("Yaaay!")
                return (self.i == other.i
                        and self.s == other.s
                        and self.a == other.a)

        original = C(i=5, s="x", a=[datetime(2011, 12, 22, tzinfo=pytz.utc)])

        # The first element of each pair is not passed to either call; only
        # the container type varies between iterations.
        combos = ((False, dict), (True, dict), (False, list), (True, list))

        for _flag, container in combos:
            print()
            print('complex_as:', container)
            as_doc = get_object_as_dict(original, C, complex_as=container)
            print(as_doc)
            restored = get_dict_as_object(as_doc, C, complex_as=container)
            print(restored)
            print(original)
            assert restored == original
435
436
class TestAttrDict(unittest.TestCase):
    """Checks ``AttrDict`` and the classes created by ``AttrDictColl``."""

    def test_attr_dict(self):
        """A keyword passed to the constructor is readable by subscript."""
        attrs = AttrDict(a=1)
        assert attrs['a'] == 1

    def test_attr_dict_coll(self):
        """A collection exposes a class under the given name; the class and
        its instances carry that name as ``NAME``."""
        generated_cls = AttrDictColl('SomeDict').SomeDict
        assert generated_cls.NAME == 'SomeDict'

        first = AttrDictColl('SomeDict').SomeDict(a=1)
        assert first['a'] == 1

        second = AttrDictColl('SomeDict').SomeDict(a=1)
        assert second.NAME == 'SomeDict'
445
446
class TestYaml(unittest.TestCase):
    """Checks the yaml output of ``get_object_as_yaml``."""

    def test_deser(self):
        """The object is emitted under its class name, with the decimal
        rendered as a quoted string."""

        class C(ComplexModel):
            a = Unicode
            b = Decimal

        obj = C(a='burak', b=D(30))
        rendered = get_object_as_yaml(obj, C)

        expected = (
            b"C:\n"
            b"    a: burak\n"
            b"    b: '30'\n"
        )
        assert rendered == expected
458
459
class TestJson(unittest.TestCase):
    """Checks the json output of ``get_object_as_json``."""

    def test_deser(self):
        """By default the object is emitted as a json array of its field
        values; with ``complex_as=dict`` it is emitted as a json object."""

        class C(ComplexModel):
            _type_info = [
                ('a', Unicode),
                ('b', Decimal),
            ]

        as_array = get_object_as_json(C(a='burak', b=D(30)), C)
        assert as_array == b'["burak", "30"]'

        as_object = get_object_as_json(C(a='burak', b=D(30)), C,
                                       complex_as=dict)
        # Compare parsed documents so key order and spacing do not matter.
        parsed = json.loads(as_object.decode('utf8'))
        wanted = json.loads(u'{"a": "burak", "b": "30"}')
        assert parsed == wanted
473
474
class TestFifo(unittest.TestCase):
    """Checks ``msgpack.Unpacker`` when messages arrive in pieces."""

    def test_msgpack_fifo(self):
        """Complete messages come out in feed order; a partially fed message
        yields nothing until its remaining bytes are fed."""
        import msgpack

        v1 = [1, 2, 3, 4]
        v2 = [5, 6, 7, 8]
        v3 = {b"a": 9, b"b": 10, b"c": 11}

        s1 = msgpack.packb(v1)
        s2 = msgpack.packb(v2)
        s3 = msgpack.packb(v3)

        unpacker = msgpack.Unpacker()
        unpacker.feed(s1)
        unpacker.feed(s2)
        # Only the first four bytes of the third message are fed for now.
        unpacker.feed(s3[:4])

        assert next(iter(unpacker)) == v1
        assert next(iter(unpacker)) == v2

        # The third message is still incomplete, so iteration must stop.
        # assertRaises reports a missing StopIteration as a test failure
        # instead of an error caused by a generic Exception.
        with self.assertRaises(StopIteration):
            next(iter(unpacker))

        unpacker.feed(s3[4:])
        assert next(iter(unpacker)) == v3
503
504
class TestTlist(unittest.TestCase):
    """Checks that ``tlist`` accepts items of its declared type and rejects
    others."""

    def test_tlist(self):
        """Operations with ints on an int-typed list succeed; putting a str
        into one raises TypeError."""
        tlist([], int)

        a = tlist([1, 2], int)
        a.append(3)
        a += [4]
        # Concatenate with the typed list itself. Wrapping it as ``[a]``
        # rebinds ``a`` to a plain two-item list, so the operations below
        # would no longer exercise ``tlist``.
        a = [5] + a
        a = a + [6]
        a[0] = 1
        a[5:] = [5]

        with self.assertRaises(TypeError):
            # NOTE(review): the constructor call below is expected to raise,
            # so the statements after it are only reached if it does not.
            # They were presumably meant to be checked one by one -- confirm
            # that each raises before splitting them into separate blocks.
            tlist([1, 2, 'a'], int)
            a.append('a')
            a += ['a']
            _ = ['a'] + a
            _ = a + ['a']
            a[0] = 'a'
            a[0:] = 'a'
530
531
class TestMemoization(unittest.TestCase):
    """Counts real invocations of functions wrapped by the memoize helpers."""

    def test_memoize(self):
        """A repeated argument does not invoke the function again."""
        calls = [0]

        def traced(arg):
            calls[0] += 1
            print(arg, calls)
        traced = memoize(traced)

        traced(1)
        traced(1)
        assert calls == [1]

        traced(2)
        assert calls == [2]

    def test_memoize_ignore_none(self):
        """A ``None`` result is recomputed on every call; other results are
        reused."""
        calls = [0]

        def traced(arg):
            calls[0] += 1
            print(arg, calls)
            return arg
        traced = memoize_ignore_none(traced)

        traced(None)
        traced(None)
        assert calls == [2]

        traced(1)
        assert calls == [3]
        traced(1)
        assert calls == [3]

    def test_memoize_ignore_values(self):
        """A result listed in the ignore tuple is recomputed on every call;
        other results are reused."""
        calls = [0]

        def traced(arg):
            calls[0] += 1
            print(arg, calls)
            return arg
        traced = memoize_ignore((1,))(traced)

        traced(1)
        traced(1)
        assert calls == [2]

        traced(2)
        assert calls == [3]
        traced(2)
        assert calls == [3]

    def test_memoize_id(self):
        """The same dict object passed twice invokes the function once; each
        fresh ``{}`` invokes it again."""
        calls = [0]

        def traced(arg):
            calls[0] += 1
            print(arg, calls)
            return arg
        traced = memoize_id(traced)

        shared = {}
        traced(shared)
        traced(shared)
        assert calls == [1]

        traced({})
        assert calls == [2]
        traced({})
        assert calls == [3]
598
599
if __name__ == "__main__":
    # Run the test cases of this module when it is executed as a script.
    unittest.main()
602