1
2#
3# spyne - Copyright (C) Spyne contributors.
4#
5# This library is free software; you can redistribute it and/or
6# modify it under the terms of the GNU Lesser General Public
7# License as published by the Free Software Foundation; either
8# version 2.1 of the License, or (at your option) any later version.
9#
10# This library is distributed in the hope that it will be useful,
11# but WITHOUT ANY WARRANTY; without even the implied warranty of
12# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13# Lesser General Public License for more details.
14#
15# You should have received a copy of the GNU Lesser General Public
16# License along with this library; if not, write to the Free Software
17# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301
18#
19
20import logging
21logger = logging.getLogger(__name__)
22
23import os
24import json
25import shutil
26
27import sqlalchemy.dialects
28
29from uuid import uuid1
30from mmap import mmap, ACCESS_READ
31from contextlib import closing
32from os.path import join, abspath, dirname, basename, isfile
33
# lxml is treated as optional here: if it (or spyne's xml helpers, which are
# imported in the same ``try`` block) cannot be imported, this module still
# loads and the failure is deferred until one of the helpers is called.
try:
    from lxml import etree
    from lxml import html
    from spyne.util.xml import get_object_as_xml, get_xml_as_object

except ImportError as _import_error:
    # Placeholders so that the names exist at module level.
    etree = None
    html = None

    # Python 3 unbinds the ``except ... as`` name when the except clause ends,
    # so a second reference is kept for the stubs below to raise later.
    _local_import_error = _import_error
    def get_object_as_xml(*_, **__):
        # Stub: re-raises the ImportError captured above, whatever the args.
        raise _local_import_error
    def get_xml_as_object(*_, **__):
        # Stub: re-raises the ImportError captured above, whatever the args.
        raise _local_import_error
48
49from sqlalchemy.sql.type_api import UserDefinedType
50
51from spyne import ValidationError
52from spyne.model.relational import FileData
53
54from spyne.util import six
55from spyne.util.six import binary_type, text_type, BytesIO, StringIO
56from spyne.util.fileproxy import SeekableFileProxy
57
58
class PGXml(UserDefinedType):
    """Column type that emits ``xml`` as its column spec and converts between
    lxml elements (Python side) and serialized xml (database side).

    :param pretty_print: Passed to ``etree.tostring`` when serializing.
    :param xml_declaration: Stored on the instance; the bind processor always
        serializes with ``xml_declaration=False``.
    :param encoding: Output encoding used when serializing on Python 2.
    """

    def __init__(self, pretty_print=False, xml_declaration=False,
                                                              encoding='UTF-8'):
        super(PGXml, self).__init__()
        self.pretty_print = pretty_print
        self.xml_declaration = xml_declaration
        self.encoding = encoding

    def get_col_spec(self, **_):
        """Return the column type name used in DDL."""
        return "xml"

    def bind_processor(self, dialect):
        """Return a callable that serializes outgoing values.

        ``None`` and already-serialized (text or bytes) values are passed
        through untouched; anything else is handed to ``etree.tostring``.
        """
        passthrough_types = (six.text_type, six.binary_type)

        def process(value):
            if value is None or isinstance(value, passthrough_types):
                return value

            # Python 2 gets an encoded byte string, Python 3 gets text.
            target_encoding = self.encoding if six.PY2 else "unicode"

            return etree.tostring(value, pretty_print=self.pretty_print,
                               encoding=target_encoding, xml_declaration=False)

        return process

    def result_processor(self, dialect, col_type):
        """Return a callable that parses incoming values with
        ``etree.fromstring``, leaving ``None`` as it is."""
        def process(value):
            return value if value is None else etree.fromstring(value)

        return process
92
# Map the 'xml' type name to PGXml in the PostgreSQL dialect's ischema_names.
# NOTE(review): presumably this is what SQLAlchemy consults when reflecting
# existing tables -- confirm against the SQLAlchemy version in use.
sqlalchemy.dialects.postgresql.base.ischema_names['xml'] = PGXml
94
95
class PGHtml(UserDefinedType):
    """Column type that stores lxml html elements in a ``text`` column.

    :param pretty_print: Passed to ``html.tostring`` when serializing.
    :param encoding: Output encoding used when serializing on Python 2. On
        Python 3, the document is serialized to text, so this is not used.
    """

    def __init__(self, pretty_print=False, encoding='UTF-8'):
        super(PGHtml, self).__init__()

        self.pretty_print = pretty_print
        self.encoding = encoding

    def get_col_spec(self, **_):
        """Return the column type name used in DDL."""
        return "text"

    def bind_processor(self, dialect):
        """Return a callable that serializes outgoing values.

        ``None`` and already-serialized (text or bytes) values are passed
        through untouched; anything else is handed to ``html.tostring``.
        """
        def process(value):
            if value is None or \
                            isinstance(value, (six.text_type, six.binary_type)):
                return value

            if six.PY2:
                return html.tostring(value, pretty_print=self.pretty_print,
                                                         encoding=self.encoding)

            # On Python 3, an explicit byte encoding makes lxml return bytes,
            # which is not a suitable bind value for a text column. Serialize
            # to text instead, the same way PGXml does.
            return html.tostring(value, pretty_print=self.pretty_print,
                                                             encoding="unicode")

        return process

    def result_processor(self, dialect, col_type):
        """Return a callable that parses incoming values with
        ``html.fromstring``. ``None`` and empty values both come back as
        ``None``."""
        def process(value):
            if value is not None and len(value) > 0:
                return html.fromstring(value)
            else:
                return None
        return process
123
124
class PGJson(UserDefinedType):
    """Column type that emits ``json`` as its column spec and converts values
    with the standard ``json`` module.

    :param encoding: Passed to ``json.dumps`` on Python 2 only.
    """

    def __init__(self, encoding='UTF-8'):
        self.encoding = encoding

    def get_col_spec(self, **_):
        """Return the column type name used in DDL."""
        return "json"

    def bind_processor(self, dialect):
        """Return a callable that json-encodes outgoing values. ``None`` and
        text or bytes values are passed through untouched."""
        def process(value):
            if value is None or isinstance(value, (text_type, binary_type)):
                return value

            # Only Python 2's json.dumps accepts the ``encoding`` argument.
            extra_kwargs = {'encoding': self.encoding} if six.PY2 else {}
            return json.dumps(value, **extra_kwargs)

        return process

    def result_processor(self, dialect, col_type):
        """Return a callable that json-decodes incoming text or bytes values
        and returns everything else as it is."""
        def process(value):
            if not isinstance(value, (text_type, binary_type)):
                return value

            return json.loads(value)

        return process
150
# Map the 'json' type name to PGJson in the PostgreSQL dialect's
# ischema_names.
# NOTE(review): presumably this is what SQLAlchemy consults when reflecting
# existing tables -- confirm against the SQLAlchemy version in use.
sqlalchemy.dialects.postgresql.base.ischema_names['json'] = PGJson
152
153
class PGJsonB(PGJson):
    """Same value conversion as :class:`PGJson`; only the emitted column spec
    differs (``jsonb`` instead of ``json``)."""

    def get_col_spec(self, **_):
        """Return the column type name used in DDL."""
        return "jsonb"
157
158
# Map the 'jsonb' type name to PGJsonB in the PostgreSQL dialect's
# ischema_names.
# NOTE(review): presumably this is what SQLAlchemy consults when reflecting
# existing tables -- confirm against the SQLAlchemy version in use.
sqlalchemy.dialects.postgresql.base.ischema_names['jsonb'] = PGJsonB
160
161
class PGObjectXml(UserDefinedType):
    """Column type that stores objects of ``cls`` in an ``xml`` column, going
    through ``get_object_as_xml`` / ``get_xml_as_object`` for conversion.

    :param cls: The class passed to the xml conversion helpers.
    :param root_tag_name: Passed to ``get_object_as_xml``.
    :param no_namespace: Passed to ``get_object_as_xml``.
    :param pretty_print: Passed to ``etree.tostring`` when serializing.
    """

    def __init__(self, cls, root_tag_name=None, no_namespace=False,
                                                            pretty_print=False):
        self.cls = cls
        self.root_tag_name = root_tag_name
        self.no_namespace = no_namespace
        self.pretty_print = pretty_print

    def get_col_spec(self, **_):
        """Return the column type name used in DDL."""
        return "xml"

    def bind_processor(self, dialect):
        """Return a callable that serializes outgoing objects to xml.
        ``None`` stays ``None``."""
        def process(value):
            if value is None:
                return None

            elt = get_object_as_xml(value, self.cls,
                                         self.root_tag_name, self.no_namespace)

            # On Python 3, an explicit byte encoding makes lxml return bytes,
            # which is not a suitable bind value for an xml column. Serialize
            # to text there, the same way PGXml does.
            encoding = 'utf8' if six.PY2 else 'unicode'

            return etree.tostring(elt, encoding=encoding,
                          pretty_print=self.pretty_print, xml_declaration=False)

        return process

    def result_processor(self, dialect, col_type):
        """Return a callable that parses incoming xml and converts it to an
        instance of ``cls``. ``None`` stays ``None``."""
        def process(value):
            if value is not None:
                return get_xml_as_object(etree.fromstring(value), self.cls)
        return process
186
187
class PGObjectJson(UserDefinedType):
    """Column type that stores objects of ``cls`` as json, going through
    spyne's ``get_object_as_json`` / ``get_dict_as_object`` for conversion.

    :param cls: The class passed to the conversion helpers.
    :param ignore_wrappers: Passed to the conversion helpers.
    :param complex_as: Passed to the conversion helpers.
    :param dbt: The column type name returned by :meth:`get_col_spec`.
    :param encoding: Used to decode the serialized json before binding and to
        decode byte strings coming back from the database.
    """

    def __init__(self, cls, ignore_wrappers=True, complex_as=dict, dbt='json',
                                                               encoding='utf8'):
        self.cls = cls
        self.ignore_wrappers = ignore_wrappers
        self.complex_as = complex_as
        self.dbt = dbt
        self.encoding = encoding

        # Imported here rather than at module level and kept on the instance.
        from spyne.util.dictdoc import get_dict_as_object, get_object_as_json
        self.get_object_as_json = get_object_as_json
        self.get_dict_as_object = get_dict_as_object

    def get_col_spec(self, **_):
        """Return the column type name used in DDL."""
        return self.dbt

    def bind_processor(self, dialect):
        """Return a callable that serializes outgoing objects to decoded
        json. ``None`` stays ``None``; failures are logged at debug level and
        re-raised."""
        def process(value):
            if value is None:
                return None

            try:
                serialized = self.get_object_as_json(value, self.cls,
                    ignore_wrappers=self.ignore_wrappers,
                    complex_as=self.complex_as,
                )
                return serialized.decode(self.encoding)

            except Exception as e:
                logger.debug("Failed to serialize %r to json: %r", value, e)
                raise

        return process

    def result_processor(self, dialect, col_type):
        """Return a callable that converts incoming values to instances of
        ``cls``. Byte strings are decoded and text is json-parsed first; any
        other non-``None`` value is handed to the converter as it is."""
        from spyne.util.dictdoc import JsonDocument

        def process(value):
            if value is None:
                return None

            if isinstance(value, six.binary_type):
                value = value.decode(self.encoding)

            if isinstance(value, six.text_type):
                value = json.loads(value)

            return self.get_dict_as_object(value, self.cls,
                    ignore_wrappers=self.ignore_wrappers,
                    complex_as=self.complex_as,
                    protocol=JsonDocument,
                )

        return process
244
245
class PGFileJson(PGObjectJson):
    """Column type that writes file contents to the ``store`` directory and
    keeps the serialized file object as json in the database.

    :param store: Directory the file contents are written to and read from.
    :param type: Class of the file object. Defaults to ``FileData``.
    :param dbt: The column type name returned by ``get_col_spec``.
    """

    def __init__(self, store, type=None, dbt='json'):
        if type is None:
            type = FileData

        super(PGFileJson, self).__init__(type, ignore_wrappers=True,
                                                       complex_as=list, dbt=dbt)
        self.store = store

    def _new_store_path(self, value):
        """Assign a fresh uuid-based name to ``value.path`` and return the
        path of the corresponding file inside the store.

        :raises ValidationError: If the resulting path is outside the store.
        """
        value.path = uuid1().hex
        fp = join(self.store, value.path)

        # Both sides are made absolute so that a relative or non-normalized
        # store doesn't make this check fail for every file.
        if not abspath(fp).startswith(abspath(self.store)):
            raise ValidationError(value.path, "Path %r contains "
                                          "relative path operators (e.g. '..')")

        return fp

    def bind_processor(self, dialect):
        """Return a callable that persists the file contents and serializes
        the file object.

        The contents are taken from the first non-``None`` attribute among
        ``value.data`` (an iterable of chunks), ``value.handle`` (an in-memory
        buffer or an object with a ``fileno()``) and ``value.path`` (an
        existing file, moved or copied into the store when it's not already
        directly inside it). ``None`` values stay ``None``.

        :raises ValueError: If all three attributes are ``None``.
        """
        import stat

        def process(value):
            if value is None:
                return None

            if value.data is not None:
                fp = self._new_store_path(value)

                with open(fp, 'wb') as out_file:
                    for chunk in value.data:
                        out_file.write(chunk)

            elif value.handle is not None:
                fp = self._new_store_path(value)

                if isinstance(value.handle, (StringIO, BytesIO)):
                    payload = value.handle.getvalue()
                    # A text buffer can't be written to a binary file as it is
                    if isinstance(payload, text_type):
                        payload = payload.encode(self.encoding)

                    with open(fp, 'wb') as out_file:
                        out_file.write(payload)

                else:
                    fileno = value.handle.fileno()
                    st = os.fstat(fileno)

                    if stat.S_ISREG(st.st_mode) and st.st_size == 0:
                        # mmap refuses to map empty files, so just create an
                        # empty file in the store.
                        with open(fp, 'wb'):
                            pass

                    else:
                        with closing(mmap(fileno, 0,
                                                   access=ACCESS_READ)) as data:
                            with open(fp, 'wb') as out_file:
                                out_file.write(data)

            elif value.path is not None:
                in_file_path = value.path

                if not isfile(in_file_path):
                    logger.error("File path in %r not found", value)

                if dirname(abspath(in_file_path)) != self.store:
                    # UUID objects have no get_hex() on Python 3; the hex
                    # attribute is available on both Python 2 and 3.
                    dest = join(self.store, uuid1().hex)

                    if value.move:
                        shutil.move(in_file_path, dest)
                        logger.debug("move '%s' => '%s'", in_file_path, dest)

                    else:
                        shutil.copy(in_file_path, dest)
                        logger.debug("copy '%s' => '%s'", in_file_path, dest)

                    value.path = basename(dest)
                    value.abspath = dest

            else:
                raise ValueError("Invalid file object passed in. All of "
                                           ".data, .handle and .path are None.")

            value.store = self.store
            value.abspath = join(self.store, value.path)

            retval = self.get_object_as_json(value, self.cls,
                ignore_wrappers=self.ignore_wrappers,
                complex_as=self.complex_as,
            )

            # The parent class binds decoded json; do the same here.
            if isinstance(retval, six.binary_type):
                retval = retval.decode(self.encoding)

            return retval

        return process

    def result_processor(self, dialect, col_type):
        """Return a callable that deserializes the file object and attaches
        the file contents from the store.

        Text and byte strings are json-parsed first. When the file is not
        readable, an error is logged and the object is returned with
        ``handle`` set to ``None`` and ``data`` set to ``[b'']``. Otherwise
        ``handle`` is an open file proxy and, for non-empty files, ``data``
        is a one-tuple holding a read-only mmap of the file.
        """
        def process(value):
            if value is None:
                return None

            if isinstance(value, six.text_type):
                value = json.loads(value)

            elif isinstance(value, six.binary_type):
                value = json.loads(value.decode('utf8'))

            retval = self.get_dict_as_object(value, self.cls,
                    ignore_wrappers=self.ignore_wrappers,
                    complex_as=self.complex_as)

            retval.store = self.store
            retval.abspath = path = join(self.store, retval.path)
            retval.handle = None
            retval.data = [b'']

            if not os.access(path, os.R_OK):
                logger.error("File %r is not readable", path)
                return retval

            h = retval.handle = SeekableFileProxy(open(path, 'rb'))
            # Empty files are left with the default data as they can't be
            # mmapped.
            if os.fstat(retval.handle.fileno()).st_size > 0:
                h.mmap = mmap(h.fileno(), 0, access=ACCESS_READ)
                retval.data = (h.mmap,)
                # FIXME: Where do we close this mmap?

            return retval

        return process
354