1from __future__ import absolute_import, print_function, division
2from pony.py23compat import PY2, basestring, unicode, buffer, int_types, iteritems
3
4import os, re, json
5from decimal import Decimal, InvalidOperation
6from datetime import datetime, date, time, timedelta
7from uuid import uuid4, UUID
8
9import pony
10from pony.utils import is_utf8, decorator, throw, localbase, deprecated
11from pony.converting import str2date, str2time, str2datetime, str2timedelta
12from pony.orm.ormtypes import LongStr, LongUnicode, RawSQLType, TrackedValue, TrackedArray, Json, QueryType, Array
13
class DBException(Exception):
    """Base class for Pony's wrappers around exceptions raised by a DB-API driver.

    The wrapped driver exception is stored as ``original_exc``.  When no
    explicit message arguments are supplied, the ``args`` of the original
    exception (if it has any) become the args of the wrapper.
    """
    def __init__(exc, original_exc, *args):
        if not args:
            args = getattr(original_exc, 'args', ())
        Exception.__init__(exc, *args)
        exc.original_exc = original_exc
19
20# Exception inheritance layout of DBAPI 2.0-compatible provider:
21#
22# Exception
23#   Warning
24#   Error
25#     InterfaceError
26#     DatabaseError
27#       DataError
28#       OperationalError
29#       IntegrityError
30#       InternalError
31#       ProgrammingError
32#       NotSupportedError
33
# Pony counterparts of the DB-API 2.0 exception classes, all rooted at DBException.
# NOTE: ``Warning`` deliberately keeps the DB-API name and therefore shadows the
# builtin ``Warning`` inside this module.

class Warning(DBException):
    """Pony counterpart of the driver's DB-API ``Warning``."""

class Error(DBException):
    """Pony counterpart of the driver's DB-API ``Error``."""

class InterfaceError(Error):
    """Pony counterpart of the driver's DB-API ``InterfaceError``."""

class DatabaseError(Error):
    """Pony counterpart of the driver's DB-API ``DatabaseError``."""

class DataError(DatabaseError):
    """Pony counterpart of the driver's DB-API ``DataError``."""

class OperationalError(DatabaseError):
    """Pony counterpart of the driver's DB-API ``OperationalError``."""

class IntegrityError(DatabaseError):
    """Pony counterpart of the driver's DB-API ``IntegrityError``."""

class InternalError(DatabaseError):
    """Pony counterpart of the driver's DB-API ``InternalError``."""

class ProgrammingError(DatabaseError):
    """Pony counterpart of the driver's DB-API ``ProgrammingError``."""

class NotSupportedError(DatabaseError):
    """Pony counterpart of the driver's DB-API ``NotSupportedError``."""
44
@decorator
def wrap_dbapi_exceptions(func, provider, *args, **kwargs):
    """Call ``func(provider, *args, **kwargs)`` and translate driver exceptions.

    Each exception class of ``provider.dbapi_module`` is re-raised as the
    same-named class defined in this module, wrapping the driver exception.
    For PostgreSQL, an OperationalError whose ``pgcode`` is '40001' is
    re-raised with the attribute ``should_retry = True`` set on it.
    """
    dbapi_module = provider.dbapi_module
    should_retry = False
    try:
        try:
            if provider.dialect != 'SQLite':
                return func(provider, *args, **kwargs)
            else:
                # The flag is set only for the duration of the call and is
                # reset even when func raises.
                provider.local_exceptions.keep_traceback = True
                try: return func(provider, *args, **kwargs)
                finally: provider.local_exceptions.keep_traceback = False
        # Clause order matters: the specific DatabaseError subclasses must be
        # matched before the DatabaseError / Error catch-alls further down.
        except dbapi_module.NotSupportedError as e: raise NotSupportedError(e)
        except dbapi_module.ProgrammingError as e:
            if provider.dialect == 'PostgreSQL':
                msg = str(e)
                if msg.startswith('operator does not exist:') and ' json ' in msg:
                    msg += ' (Note: use column type `jsonb` instead of `json`)'
                    raise ProgrammingError(e, msg, *e.args[1:])
            raise ProgrammingError(e)
        except dbapi_module.InternalError as e: raise InternalError(e)
        except dbapi_module.IntegrityError as e: raise IntegrityError(e)
        except dbapi_module.OperationalError as e:
            # SQLSTATE 40001 is PostgreSQL's serialization_failure.  getattr() is
            # used because not every driver exception carries ``pgcode``; an
            # AttributeError here would hide the original database error.
            if provider.dialect == 'PostgreSQL' and getattr(e, 'pgcode', None) == '40001':
                should_retry = True
            if provider.dialect == 'SQLite':
                provider.restore_exception()
            raise OperationalError(e)
        except dbapi_module.DataError as e: raise DataError(e)
        except dbapi_module.DatabaseError as e: raise DatabaseError(e)
        except dbapi_module.InterfaceError as e:
            if e.args == (0, '') and getattr(dbapi_module, '__name__', None) == 'MySQLdb':
                throw(InterfaceError, e, 'MySQL server misconfiguration')
            raise InterfaceError(e)
        except dbapi_module.Error as e: raise Error(e)
        except dbapi_module.Warning as e: raise Warning(e)
    except Exception as e:
        # Attach the retry hint to whatever exception is propagating.
        if should_retry:
            e.should_retry = True
        raise
85
def unexpected_args(attr, args):
    """Raise TypeError listing the positional arguments ``attr`` does not accept.

    :param attr: the attribute being declared (used only in the message).
    :param args: the offending positional arguments.
    :raises TypeError: always.
    """
    plural = 's' if len(args) > 1 else ''
    throw(TypeError, 'Unexpected positional argument{} for attribute {}: {}'.format(
        plural, attr, ', '.join(repr(arg) for arg in args)))
90
# Leading run of dot-separated digit groups, e.g. '10.4.13' in '10.4.13-MariaDB'.
# Raw string: the old literal '[0-9\.]+' contained an invalid escape sequence.
# Requiring digits around every dot means no empty components can be produced
# (the old pattern accepted '5.7.' or '.', and int('') then raised ValueError).
version_re = re.compile(r'[0-9]+(?:\.[0-9]+)*')

def get_version_tuple(s):
    """Parse the version number at the start of ``s`` into a tuple of ints.

    ``'8.0.21-log'`` -> ``(8, 0, 21)``.  Returns None when ``s`` does not
    start with a digit.
    """
    m = version_re.match(s)
    if m is None:
        return None
    return tuple(int(component) for component in m.group(0).split('.'))
99
class DBAPIProvider(object):
    """Base class for database providers built on a DB-API 2.0 driver module.

    The class attributes describe dialect capabilities and defaults.  In this
    base class ``dialect``, ``dbapi_module`` and the helper classes are None,
    and ``table_exists`` / ``index_exists`` / ``fk_exists`` raise
    NotImplementedError, so a concrete provider has to supply them.
    """
    paramstyle = 'qmark'
    quote_char = '"'
    max_params_count = 999
    max_name_len = 128
    table_if_not_exists_syntax = True
    index_if_not_exists_syntax = True
    max_time_precision = default_time_precision = 6
    uint64_support = False

    # SQLite and PostgreSQL do not limit varchar max length.
    varchar_default_max_len = None

    dialect = None
    dbapi_module = None
    dbschema_cls = None
    translator_cls = None
    sqlbuilder_cls = None
    array_converter_cls = None

    name_before_table = 'schema_name'
    default_schema_name = None

    # Type substitutions applied by Converter.get_fk_type(), e.g. SERIAL -> INTEGER.
    fk_types = { 'SERIAL' : 'INTEGER', 'BIGSERIAL' : 'BIGINT' }

    def __init__(provider, *args, **kwargs):
        # pony_* keyword arguments are consumed here; everything else goes to get_pool().
        pool_mockup = kwargs.pop('pony_pool_mockup', None)
        call_on_connect = kwargs.pop('pony_call_on_connect', None)
        if pool_mockup: provider.pool = pool_mockup
        else: provider.pool = provider.get_pool(*args, **kwargs)
        # Obtain one connection right away so inspect_connection() can examine it.
        connection, _ = provider.connect()
        if call_on_connect:
            call_on_connect(connection)
        provider.inspect_connection(connection)
        provider.release(connection)

    @wrap_dbapi_exceptions
    def inspect_connection(provider, connection):
        """Hook called from __init__ with the connection obtained there; no-op here."""
        pass

    def normalize_name(provider, name):
        """Truncate ``name`` to ``max_name_len`` characters."""
        return name[:provider.max_name_len]

    def get_default_entity_table_name(provider, entity):
        return provider.normalize_name(entity.__name__)

    def get_default_m2m_table_name(provider, attr, reverse):
        """Name of the intermediate table for a many-to-many relationship."""
        if attr.symmetric:
            assert reverse is attr
            name = attr.entity.__name__ + '_' + attr.name
        else:
            name = attr.entity.__name__ + '_' + reverse.entity.__name__
        return provider.normalize_name(name)

    def get_default_column_names(provider, attr, reverse_pk_columns=None):
        """Column names for ``attr``.

        The attribute name itself is used unless ``reverse_pk_columns`` has
        several entries, in which case each column becomes ``<attr>_<column>``.
        """
        normalize_name = provider.normalize_name
        if reverse_pk_columns is None or len(reverse_pk_columns) == 1:
            return [ normalize_name(attr.name) ]
        prefix = attr.name + '_'
        return [ normalize_name(prefix + column) for column in reverse_pk_columns ]

    def get_default_m2m_column_names(provider, entity):
        """Column names referencing ``entity`` from an intermediate m2m table."""
        normalize_name = provider.normalize_name
        columns = entity._get_pk_columns_()
        if len(columns) == 1:
            return [ normalize_name(entity.__name__.lower()) ]
        else:
            prefix = entity.__name__.lower() + '_'
            return [ normalize_name(prefix + column) for column in columns ]

    def get_default_index_name(provider, table_name, column_names, is_pk=False, is_unique=False, m2m=False):
        """Build a lower-cased index name prefixed with ``pk_``, ``unq_`` or ``idx_``."""
        if is_pk: index_name = 'pk_%s' % provider.base_name(table_name)
        else:
            if is_unique: template = 'unq_%(tname)s__%(cnames)s'
            elif m2m: template = 'idx_%(tname)s'
            else: template = 'idx_%(tname)s__%(cnames)s'
            index_name = template % dict(tname=provider.base_name(table_name),
                                         cnames='_'.join(column_names))
        return provider.normalize_name(index_name.lower())

    def get_default_fk_name(provider, child_table_name, parent_table_name, child_column_names):
        fk_name = 'fk_%s__%s' % (provider.base_name(child_table_name), '__'.join(child_column_names))
        return provider.normalize_name(fk_name.lower())

    def split_table_name(provider, table_name):
        """Return ``(schema_name, table_name)``.

        A plain string is paired with ``default_schema_name``; otherwise a
        two-element qualified name is expected.
        :raises TypeError: for an empty name or a wrong number of components.
        """
        if isinstance(table_name, basestring): return provider.default_schema_name, table_name
        # Wrap in a 1-tuple: '%r' % () would fail with a formatting error
        # instead of producing the intended message.
        if not table_name: throw(TypeError, 'Invalid table name: %r' % (table_name,))
        if len(table_name) != 2:
            size = len(table_name)
            throw(TypeError, '%s qualified table name must have two components: '
                             '%s and table_name. Got %d component%s: %s'
                             % (provider.dialect, provider.name_before_table,
                                size, 's' if size != 1 else '', table_name))
        return table_name[0], table_name[1]

    def base_name(provider, name):
        """Return the unqualified (last) component of a possibly qualified name."""
        if not isinstance(name, basestring):
            assert type(name) is tuple
            name = name[-1]
            assert isinstance(name, basestring)
        return name

    def quote_name(provider, name):
        """Quote an identifier, doubling embedded quote characters.

        A non-string name is treated as a sequence of parts joined with '.'.
        """
        quote_char = provider.quote_char
        if isinstance(name, basestring):
            name = name.replace(quote_char, quote_char+quote_char)
            return quote_char + name + quote_char
        return '.'.join(provider.quote_name(item) for item in name)

    def format_table_name(provider, name):
        return provider.quote_name(name)

    def normalize_vars(provider, vars, vartypes):
        """Replace QueryType variables (and their types) in place with normalized values."""
        for key, value in iteritems(vars):
            vartype = vartypes[key]
            if isinstance(vartype, QueryType):
                vartypes[key], vars[key] = value._normalize_var(vartype)

    def ast2sql(provider, ast):
        """Translate ``ast`` with ``sqlbuilder_cls``; returns ``(sql, adapter)``."""
        builder = provider.sqlbuilder_cls(provider, ast)
        return builder.sql, builder.adapter

    def should_reconnect(provider, exc):
        return False

    @wrap_dbapi_exceptions
    def connect(provider):
        """Return ``(connection, is_new_connection)`` from the pool."""
        return provider.pool.connect()

    @wrap_dbapi_exceptions
    def set_transaction_mode(provider, connection, cache):
        pass

    @wrap_dbapi_exceptions
    def commit(provider, connection, cache=None):
        core = pony.orm.core
        if core.local.debug: core.log_orm('COMMIT')
        connection.commit()
        if cache is not None: cache.in_transaction = False

    @wrap_dbapi_exceptions
    def rollback(provider, connection, cache=None):
        core = pony.orm.core
        if core.local.debug: core.log_orm('ROLLBACK')
        connection.rollback()
        if cache is not None: cache.in_transaction = False

    @wrap_dbapi_exceptions
    def release(provider, connection, cache=None):
        """Return ``connection`` to the pool, or drop it if the session did DDL."""
        core = pony.orm.core
        if cache is not None and cache.db_session is not None and cache.db_session.ddl:
            provider.drop(connection, cache)
        else:
            if core.local.debug: core.log_orm('RELEASE CONNECTION')
            provider.pool.release(connection)

    @wrap_dbapi_exceptions
    def drop(provider, connection, cache=None):
        core = pony.orm.core
        if core.local.debug: core.log_orm('CLOSE CONNECTION')
        provider.pool.drop(connection)
        if cache is not None: cache.in_transaction = False

    @wrap_dbapi_exceptions
    def disconnect(provider):
        core = pony.orm.core
        if core.local.debug: core.log_orm('DISCONNECT')
        provider.pool.disconnect()

    @wrap_dbapi_exceptions
    def execute(provider, cursor, sql, arguments=None, returning_id=False):
        """Execute ``sql`` on ``cursor``.

        A list of ``arguments`` is run with ``executemany``.  Otherwise a single
        execute is done and, if ``returning_id`` is set, ``cursor.lastrowid``
        is returned.
        """
        if type(arguments) is list:
            assert arguments and not returning_id
            cursor.executemany(sql, arguments)
        else:
            if arguments is None: cursor.execute(sql)
            else: cursor.execute(sql, arguments)
            if returning_id: return cursor.lastrowid

    # (python type, converter class) pairs checked in order by
    # _get_converter_type_by_py_type; empty in the base class.
    converter_classes = []

    def _get_converter_type_by_py_type(provider, py_type):
        """Find the converter class for ``py_type``.

        :raises NotImplementedError: for Array types without ``array_converter_cls``.
        :raises TypeError: when no converter matches.
        """
        if isinstance(py_type, type):
            for t, converter_cls in provider.converter_classes:
                if issubclass(py_type, t): return converter_cls
            if issubclass(py_type, Array):
                converter_cls = provider.array_converter_cls
                if converter_cls is None:
                    throw(NotImplementedError, 'Array type is not supported for %r' % provider.dialect)
                return converter_cls
        if isinstance(py_type, RawSQLType):
            return Converter  # for cases like select(raw_sql(...) for x in X)
        throw(TypeError, 'No database converter found for type %s' % py_type)

    def get_converter_by_py_type(provider, py_type):
        converter_cls = provider._get_converter_type_by_py_type(py_type)
        return converter_cls(provider, py_type)

    def get_converter_by_attr(provider, attr):
        py_type = attr.py_type
        converter_cls = provider._get_converter_type_by_py_type(py_type)
        return converter_cls(provider, py_type, attr)

    def get_pool(provider, *args, **kwargs):
        return Pool(provider.dbapi_module, *args, **kwargs)

    def table_exists(provider, connection, table_name, case_sensitive=True):
        throw(NotImplementedError)

    def index_exists(provider, connection, table_name, index_name, case_sensitive=True):
        throw(NotImplementedError)

    def fk_exists(provider, connection, table_name, fk_name, case_sensitive=True):
        throw(NotImplementedError)

    def table_has_data(provider, connection, table_name):
        """Return True if ``table_name`` contains at least one row."""
        cursor = connection.cursor()
        cursor.execute('SELECT 1 FROM %s LIMIT 1' % provider.quote_name(table_name))
        return cursor.fetchone() is not None

    def disable_fk_checks(provider, connection):
        pass

    def enable_fk_checks(provider, connection, prev_state):
        pass

    def drop_table(provider, connection, table_name):
        cursor = connection.cursor()
        sql = 'DROP TABLE %s' % provider.quote_name(table_name)
        cursor.execute(sql)
333
class Pool(localbase):
    """Holder of at most one DB-API connection per thread.

    ``__init__`` is called separately in each thread (see ``localbase``).
    """
    # Connections found to belong to another process (the stored pid differs
    # from os.getpid()) are parked here as (connection, pid) pairs instead of
    # being reused or closed.  Class-level, so shared by all pools.
    forked_connections = []

    def __init__(pool, dbapi_module, *args, **kwargs): # called separately in each thread
        pool.dbapi_module, pool.args, pool.kwargs = dbapi_module, args, kwargs
        pool.con = None
        pool.pid = None

    def connect(pool):
        """Return ``(connection, is_new_connection)``, opening one if necessary."""
        current_pid = os.getpid()
        if pool.con is not None and pool.pid != current_pid:
            pool.forked_connections.append((pool.con, pool.pid))
            pool.con = None
            pool.pid = None
        core = pony.orm.core
        if pool.con is not None:
            if core.local.debug:
                core.log_orm('GET CONNECTION FROM THE LOCAL POOL')
            return pool.con, False
        if core.local.debug:
            core.log_orm('GET NEW CONNECTION')
        pool._connect()
        pool.pid = current_pid
        return pool.con, True

    def _connect(pool):
        pool.con = pool.dbapi_module.connect(*pool.args, **pool.kwargs)

    def release(pool, con):
        """Roll back ``con`` for reuse; if the rollback fails, drop it and re-raise."""
        assert con is pool.con
        try:
            con.rollback()
        except BaseException:
            pool.drop(con)
            raise

    def drop(pool, con):
        """Forget and close ``con``."""
        assert con is pool.con, (con, pool.con)
        pool.con = None
        con.close()

    def disconnect(pool):
        """Close the pooled connection, if any."""
        con, pool.con = pool.con, None
        if con is not None:
            con.close()
372
class Converter(object):
    """Base converter between Python values and database values.

    A converter is created either for an attribute (``attr`` given) or just
    for a Python type (``attr`` is None).  The hooks below are identity
    functions here; subclasses override them and provide ``sql_type()``.
    """
    EQ = 'EQ'
    NE = 'NE'
    optimistic = True
    def __deepcopy__(converter, memo):
        return converter  # Converter instances are "immutable"
    def __init__(converter, provider, py_type, attr=None):
        converter.provider = provider
        converter.py_type = py_type
        converter.attr = attr
        if attr is None: return
        kwargs = attr.kwargs.copy()
        # init() pops the options it recognises; anything left over is an error.
        converter.init(kwargs)
        for option in kwargs: throw(TypeError, 'Attribute %s has unknown option %r' % (attr, option))
    def init(converter, kwargs):
        """Consume converter-specific options; the base class accepts no positional args."""
        attr = converter.attr
        if attr and attr.args: unexpected_args(attr, attr.args)
    def validate(converter, val, obj=None):
        return val
    def py2sql(converter, val):
        return val
    def sql2py(converter, val):
        return val
    def val2dbval(converter, val, obj=None):
        return val
    def dbval2val(converter, dbval, obj=None):
        return dbval
    def dbvals_equal(converter, x, y):
        return x == y
    def get_sql_type(converter, attr=None):
        """Return the SQL column type.

        An explicit ``sql_type`` on the ``attr`` argument wins.  Otherwise the
        converter's own attribute is used: its explicit ``sql_type`` (mapped
        through get_fk_type), or, for a non-collection reference, the type of
        the matching primary-key converter of the referenced entity.  In all
        remaining cases, including converters without an attribute,
        ``sql_type()`` is used.
        """
        if attr is not None and attr.sql_type is not None:
            return attr.sql_type
        attr = converter.attr
        if attr is None:
            # Check before touching attr.sql_type: attribute-less converters
            # previously crashed here with AttributeError.
            return converter.sql_type()
        if attr.sql_type is not None:
            assert len(attr.columns) == 1
            return converter.get_fk_type(attr.sql_type)
        if attr.reverse and not attr.is_collection:
            i = attr.converters.index(converter)
            rentity = attr.reverse.entity
            rpk_converters = rentity._pk_converters_
            assert rpk_converters is not None and len(attr.converters) == len(rpk_converters)
            rconverter = rpk_converters[i]
            return rconverter.sql_type()
        return converter.sql_type()
    def get_fk_type(converter, sql_type):
        """Map ``sql_type`` through ``provider.fk_types`` (matched in upper case).

        An all-uppercase input yields the mapped value as is; any other input
        yields it lower-cased.
        """
        fk_types = converter.provider.fk_types
        if sql_type.isupper(): return fk_types.get(sql_type, sql_type)
        sql_type = sql_type.upper()
        return fk_types.get(sql_type, sql_type).lower()
422
class NoneConverter(Converter):
    """Converter for ``None`` values; used for raw_sql() parameters only.

    It refuses to be bound to an attribute, and its SQL-type methods are
    never expected to be called.
    """
    def __init__(converter, provider, py_type, attr=None):
        if attr is None:
            Converter.__init__(converter, provider, py_type)
            return
        throw(TypeError, 'Attribute %s has invalid type NoneType' % attr)
    def get_sql_type(converter, attr=None):
        assert False
    def get_fk_type(converter, sql_type):
        assert False
431
class BoolConverter(Converter):
    """Converter for ``bool`` attributes, stored as SQL BOOLEAN."""
    def sql_type(converter):
        return "BOOLEAN"
    def sql2py(converter, val):
        # database values are coerced by ordinary truth testing
        return True if val else False
    def validate(converter, val, obj=None):
        # any Python value is accepted and coerced by truth testing
        return True if val else False
439
class StrConverter(Converter):
    """Converter for text attributes (including LongStr / LongUnicode).

    Options: an optional positional ``max_len`` (or ``max_len=`` keyword),
    ``db_encoding`` and ``autostrip`` (default True).
    """
    def __init__(converter, provider, py_type, attr=None):
        # Defaults for the case when attr is None; init() overrides them otherwise.
        # autostrip is read by validate(), so it needs a default as well.
        converter.max_len = None
        converter.db_encoding = None
        converter.autostrip = True
        Converter.__init__(converter, provider, py_type, attr)
    def init(converter, kwargs):
        attr = converter.attr
        max_len = kwargs.pop('max_len', None)
        if len(attr.args) > 1: unexpected_args(attr, attr.args[1:])
        elif attr.args:
            if max_len is not None: throw(TypeError,
                'Max length option specified twice: as a positional argument and as a `max_len` named argument')
            max_len = attr.args[0]
        if issubclass(attr.py_type, (LongStr, LongUnicode)):
            if max_len is not None: throw(TypeError, 'Max length is not supported for CLOBs')
        elif max_len is None: max_len = converter.provider.varchar_default_max_len
        elif not isinstance(max_len, int_types):
            throw(TypeError, 'Max length argument must be int. Got: %r' % max_len)
        converter.max_len = max_len
        converter.db_encoding = kwargs.pop('db_encoding', None)
        converter.autostrip = kwargs.pop('autostrip', True)
    def validate(converter, val, obj=None):
        """Return ``val`` as text, stripped if autostrip is on.

        :raises TypeError: for non-text values (Python 2 byte strings are decoded as ASCII).
        :raises ValueError: when the value is longer than ``max_len``.
        """
        if PY2 and isinstance(val, str): val = val.decode('ascii')
        elif not isinstance(val, unicode): throw(TypeError,
            'Value type for attribute %s must be %s. Got: %r' % (converter.attr, unicode.__name__, type(val)))
        if converter.autostrip: val = val.strip()
        max_len = converter.max_len
        val_len = len(val)
        if max_len and val_len > max_len:
            throw(ValueError, 'Value for attribute %s is too long. Max length is %d, value length is %d'
                             % (converter.attr, max_len, val_len))
        return val
    def sql_type(converter):
        if converter.max_len:
            return 'VARCHAR(%d)' % converter.max_len
        return 'TEXT'
476
class IntConverter(Converter):
    """Converter for integer attributes.

    Options: ``min``, ``max``, ``size`` (8, 16, 24, 32 or 64 bits) and
    ``unsigned``.  The size/unsigned combination defines a representable
    range, which also serves as the default bound when ``min`` / ``max`` are
    not given.
    """
    signed_types = {None: 'INTEGER', 8: 'TINYINT', 16: 'SMALLINT', 24: 'MEDIUMINT', 32: 'INTEGER', 64: 'BIGINT'}
    # When None, sql_type() appends ' UNSIGNED' to the signed type name instead.
    unsigned_types = None
    def init(converter, kwargs):
        Converter.init(converter, kwargs)
        attr = converter.attr

        min_val = kwargs.pop('min', None)
        if min_val is not None and not isinstance(min_val, int_types):
            throw(TypeError, "'min' argument for attribute %s must be int. Got: %r" % (attr, min_val))

        max_val = kwargs.pop('max', None)
        if max_val is not None and not isinstance(max_val, int_types):
            throw(TypeError, "'max' argument for attribute %s must be int. Got: %r" % (attr, max_val))

        size = kwargs.pop('size', None)
        if size is None:
            if attr.py_type.__name__ == 'long':
                deprecated(9, "Attribute %s: 'long' attribute type is deprecated. "
                              "Please use 'int' type with size=64 option instead" % attr)
                attr.py_type = int
                size = 64
        elif attr.py_type.__name__ == 'long': throw(TypeError,
            "Attribute %s: 'size' option cannot be used with long type. Please use int type instead" % attr)
        elif not isinstance(size, int_types):
            throw(TypeError, "'size' option for attribute %s must be of int type. Got: %r" % (attr, size))
        elif size not in (8, 16, 24, 32, 64):
            throw(TypeError, "incorrect value of 'size' option for attribute %s. "
                             "Should be 8, 16, 24, 32 or 64. Got: %d" % (attr, size))

        unsigned = kwargs.pop('unsigned', False)
        if unsigned is not None and not isinstance(unsigned, bool):
            throw(TypeError, "'unsigned' option for attribute %s must be of bool type. Got: %r" % (attr, unsigned))

        if size == 64 and unsigned and not converter.provider.uint64_support: throw(TypeError,
            'Attribute %s: %s provider does not support unsigned bigint type' % (attr, converter.provider.dialect))

        # Only an explicit unsigned=None without a size leaves the range unbounded.
        if unsigned is not None and size is None: size = 32
        lowest = highest = None
        if size:
            highest = 2 ** size - 1 if unsigned else 2 ** (size - 1) - 1
            lowest = 0 if unsigned else -(2 ** (size - 1))

        # Format arguments follow the placeholder order: bound, size, unsigned, value.
        if highest is not None and max_val is not None and max_val > highest:
            throw(ValueError, "'max' argument should be less or equal to %d because of size=%d and unsigned=%s. "
                              "Got: %d" % (highest, size, unsigned, max_val))

        if lowest is not None and min_val is not None and min_val < lowest:
            throw(ValueError, "'min' argument should be greater or equal to %d because of size=%d and unsigned=%s. "
                              "Got: %d" % (lowest, size, unsigned, min_val))

        # Compare with None explicitly: 0 is a legitimate bound (min=0, max=0,
        # or the lowest value of an unsigned type) and must not be replaced.
        converter.min_val = min_val if min_val is not None else lowest
        converter.max_val = max_val if max_val is not None else highest
        converter.size = size
        converter.unsigned = unsigned
    def validate(converter, val, obj=None):
        """Coerce ``val`` to int and check it against min_val / max_val.

        :raises TypeError: for values that are neither ints, index-able, nor strings.
        :raises ValueError: for non-numeric strings or out-of-range values.
        """
        if isinstance(val, int_types): pass
        elif hasattr(val, '__index__'):
            val = val.__index__()
        elif isinstance(val, basestring):
            try: val = int(val)
            except ValueError: throw(ValueError,
                'Value type for attribute %s must be int. Got string %r' % (converter.attr, val))
        else: throw(TypeError, 'Value type for attribute %s must be int. Got: %r' % (converter.attr, type(val)))

        # `is not None` rather than truthiness, so a bound of 0 is enforced.
        if converter.min_val is not None and val < converter.min_val:
            throw(ValueError, 'Value %r of attr %s is less than the minimum allowed value %r'
                             % (val, converter.attr, converter.min_val))
        if converter.max_val is not None and val > converter.max_val:
            throw(ValueError, 'Value %r of attr %s is greater than the maximum allowed value %r'
                             % (val, converter.attr, converter.max_val))
        return val
    def sql2py(converter, val):
        return int(val)
    def sql_type(converter):
        if not converter.unsigned:
            return converter.signed_types.get(converter.size)
        if converter.unsigned_types is None:
            return converter.signed_types.get(converter.size) + ' UNSIGNED'
        return converter.unsigned_types.get(converter.size)
557
class RealConverter(Converter):
    """Converter for ``float`` attributes with optional ``min``, ``max`` and ``tolerance``."""
    EQ = 'FLOAT_EQ'
    NE = 'FLOAT_NE'
    # The tolerance is necessary for Oracle, because it has different representation of float numbers.
    # For other databases the default tolerance is set because the precision can be lost during
    # Python -> JavaScript -> Python conversion
    default_tolerance = 1e-14
    optimistic = False
    def init(converter, kwargs):
        Converter.init(converter, kwargs)
        min_val = kwargs.pop('min', None)
        if min_val is not None:
            # float() raises TypeError (not ValueError) for e.g. lists; report both uniformly
            try: min_val = float(min_val)
            except (TypeError, ValueError):
                throw(TypeError, "Invalid value for 'min' argument for attribute %s: %r" % (converter.attr, min_val))
        max_val = kwargs.pop('max', None)
        if max_val is not None:
            try: max_val = float(max_val)
            except (TypeError, ValueError):
                throw(TypeError, "Invalid value for 'max' argument for attribute %s: %r" % (converter.attr, max_val))
        converter.min_val = min_val
        converter.max_val = max_val
        converter.tolerance = kwargs.pop('tolerance', converter.default_tolerance)
    def validate(converter, val, obj=None):
        """Coerce ``val`` to float and check it against min_val / max_val.

        :raises TypeError: when ``val`` cannot be converted to float.
        :raises ValueError: for out-of-range values.
        """
        try: val = float(val)
        except (TypeError, ValueError):
            throw(TypeError, 'Invalid value for attribute %s: %r' % (converter.attr, val))
        # `is not None` rather than truthiness, so a bound of 0.0 is enforced.
        if converter.min_val is not None and val < converter.min_val:
            throw(ValueError, 'Value %r of attr %s is less than the minimum allowed value %r'
                             % (val, converter.attr, converter.min_val))
        if converter.max_val is not None and val > converter.max_val:
            throw(ValueError, 'Value %r of attr %s is greater than the maximum allowed value %r'
                             % (val, converter.attr, converter.max_val))
        return val
    def dbvals_equal(converter, x, y):
        """Compare with relative ``tolerance``; exact comparison if it or either value is None."""
        tolerance = converter.tolerance
        if tolerance is None or x is None or y is None: return x == y
        denominator = max(abs(x), abs(y))
        if not denominator: return True
        diff = abs(x-y) / denominator
        return diff <= tolerance
    def sql2py(converter, val):
        return float(val)
    def sql_type(converter):
        return 'REAL'
603
class DecimalConverter(Converter):
    """Converter for ``Decimal`` attributes.

    Accepts ``precision`` and ``scale`` positionally or as keywords (defaults
    12 and 2), plus optional ``min`` / ``max``.
    """
    def __init__(converter, provider, py_type, attr=None):
        converter.exp = None  # for the case when attr is None
        Converter.__init__(converter, provider, py_type, attr)
    def init(converter, kwargs):
        attr = converter.attr
        args = attr.args
        if len(args) > 2: throw(TypeError, 'Too many positional parameters for Decimal '
                                           '(expected: precision and scale), got: %s' % args)
        if args: precision = args[0]
        else: precision = kwargs.pop('precision', 12)
        if not isinstance(precision, int_types):
            throw(TypeError, "'precision' positional argument for attribute %s must be int. Got: %r" % (attr, precision))
        if precision <= 0: throw(TypeError,
            "'precision' positional argument for attribute %s must be positive. Got: %r" % (attr, precision))

        if len(args) == 2: scale = args[1]
        else: scale = kwargs.pop('scale', 2)
        if not isinstance(scale, int_types):
            throw(TypeError, "'scale' positional argument for attribute %s must be int. Got: %r" % (attr, scale))
        # A scale of 0 (DECIMAL(p, 0)) is valid SQL; only negative scales are rejected.
        if scale < 0: throw(TypeError,
            "'scale' positional argument for attribute %s must be non-negative. Got: %r" % (attr, scale))

        if scale > precision: throw(ValueError, "'scale' must be less or equal 'precision'")
        converter.precision = precision
        converter.scale = scale
        converter.exp = Decimal(10) ** -scale

        # Decimal() raises InvalidOperation for malformed strings and TypeError
        # for unsupported types; both are reported with the same message.
        min_val = kwargs.pop('min', None)
        if min_val is not None:
            try: min_val = Decimal(min_val)
            except (TypeError, InvalidOperation): throw(TypeError,
                "Invalid value for 'min' argument for attribute %s: %r" % (attr, min_val))

        max_val = kwargs.pop('max', None)
        if max_val is not None:
            try: max_val = Decimal(max_val)
            except (TypeError, InvalidOperation): throw(TypeError,
                "Invalid value for 'max' argument for attribute %s: %r" % (attr, max_val))

        converter.min_val = min_val
        converter.max_val = max_val
    def validate(converter, val, obj=None):
        """Coerce ``val`` to Decimal and check it against min_val / max_val.

        :raises TypeError: when ``val`` cannot be converted to Decimal.
        :raises ValueError: for out-of-range values.
        """
        if isinstance(val, float):
            # Go through the float's shortest text form so that e.g. 0.1 becomes
            # Decimal('0.1') rather than its exact binary expansion.
            s = str(val)
            if float(s) != val: s = repr(val)
            val = Decimal(s)
        try: val = Decimal(val)
        except (TypeError, InvalidOperation):
            throw(TypeError, 'Invalid value for attribute %s: %r' % (converter.attr, val))
        if converter.min_val is not None and val < converter.min_val:
            throw(ValueError, 'Value %r of attr %s is less than the minimum allowed value %r'
                             % (val, converter.attr, converter.min_val))
        if converter.max_val is not None and val > converter.max_val:
            throw(ValueError, 'Value %r of attr %s is greater than the maximum allowed value %r'
                             % (val, converter.attr, converter.max_val))
        return val
    def sql2py(converter, val):
        return Decimal(val)
    def sql_type(converter):
        return 'DECIMAL(%d, %d)' % (converter.precision, converter.scale)
665
class BlobConverter(Converter):
    """Converter for binary attributes (``buffer`` type), stored as SQL BLOB."""
    def validate(converter, val, obj=None):
        """Return ``val`` as a ``buffer``.

        :raises TypeError: for values that are not binary data (including
            Python 3 text strings).
        """
        if isinstance(val, buffer): return val
        if PY2:
            # A Python 2 ``str`` is a byte string and can be wrapped directly.
            if isinstance(val, str): return buffer(val)
        elif isinstance(val, (bytearray, memoryview)):
            # Python 3: other bytes-like objects are accepted; a text ``str``
            # falls through to the explicit error below instead of failing
            # inside buffer() with an unrelated message.
            return buffer(val)
        throw(TypeError, "Attribute %r: expected type is 'buffer'. Got: %r" % (converter.attr, type(val)))
    def sql2py(converter, val):
        if not isinstance(val, buffer):
            # Best effort: keep the raw value if it cannot be converted.
            try: val = buffer(val)
            except Exception: pass
        elif PY2 and converter.attr is not None and converter.attr.is_part_of_unique_index:
            # Values of unique-indexed attributes must be hashable (presumably
            # because they are used as lookup keys; verify); re-wrap otherwise.
            try: hash(val)
            except TypeError:
                val = buffer(val)
        return val
    def sql_type(converter):
        return 'BLOB'
682
class DateConverter(Converter):
    """Converter for ``date`` attributes, stored as SQL DATE."""
    def validate(converter, val, obj=None):
        """Return ``val`` as a date; datetimes are truncated, strings are parsed."""
        # datetime is a subclass of date, so it must be tested first.
        if isinstance(val, datetime):
            return val.date()
        elif isinstance(val, date):
            return val
        elif isinstance(val, basestring):
            return str2date(val)
        throw(TypeError, "Attribute %r: expected type is 'date'. Got: %r" % (converter.attr, val))
    def sql2py(converter, val):
        if isinstance(val, date):
            return val
        throw(ValueError, 'Value of unexpected type received from database: instead of date got %s' % type(val))
    def sql_type(converter):
        return 'DATE'
695
class ConverterWithMicroseconds(Converter):
    """Base for converters whose values carry a fractional-second ``precision``.

    The precision (0..6 digits) comes from a single positional argument or the
    ``precision`` keyword, defaulting to the provider's default_time_precision.
    Subclasses supply ``sql_type_name``.
    """
    def __init__(converter, provider, py_type, attr=None):
        converter.precision = None  # for the case when attr is None
        Converter.__init__(converter, provider, py_type, attr)
    def init(converter, kwargs):
        attr = converter.attr
        positional = attr.args
        if len(positional) > 1: throw(TypeError, 'Too many positional parameters for attribute %s. '
                                                 'Expected: precision, got: %r' % (attr, positional))
        provider = attr.entity._database_.provider
        if not positional:
            precision = kwargs.pop('precision', provider.default_time_precision)
        elif 'precision' in kwargs:
            throw(TypeError, 'Precision for attribute %s has both positional and keyword value' % attr)
        else:
            precision = positional[0]
        if not (isinstance(precision, int) and 0 <= precision <= 6): throw(ValueError,
            'Precision value of attribute %s must be between 0 and 6. Got: %r' % (attr, precision))
        if precision > provider.max_time_precision: throw(ValueError,
            'Precision value (%d) of attribute %s exceeds max datetime precision (%d) of %s %s'
            % (precision, attr, provider.max_time_precision, provider.dialect, provider.server_version))
        converter.precision = precision
    def round_microseconds_to_precision(converter, microseconds, precision):
        """Floor ``microseconds`` to ``precision`` fractional digits.

        Returns the new microseconds value, or None when no change is needed
        (precision of 6 or more, or the value is already aligned).
        """
        if not precision:
            truncated = 0
        elif precision >= 6:
            return None
        else:
            step = 10 ** (6-precision)
            truncated = (microseconds // step) * step
        return None if truncated == microseconds else truncated
    def sql_type(converter):
        attr = converter.attr
        precision = converter.precision
        # An explicit "(n)" suffix is only emitted for a non-default precision.
        if attr and precision != attr.entity._database_.provider.default_time_precision:
            return converter.sql_type_name + '(%d)' % precision
        return converter.sql_type_name
731
class TimeConverter(ConverterWithMicroseconds):
    """Converter for ``time`` attributes (SQL type ``TIME``)."""
    sql_type_name = 'TIME'

    def validate(converter, val, obj=None):
        """Coerce *val* to ``time`` and truncate its microseconds to the precision."""
        if isinstance(val, basestring):
            val = str2time(val)
        elif not isinstance(val, time):
            throw(TypeError, "Attribute %r: expected type is 'time'. Got: %r" % (converter.attr, val))
        adjusted = converter.round_microseconds_to_precision(val.microsecond, converter.precision)
        if adjusted is None:
            return val
        return val.replace(microsecond=adjusted)

    def sql2py(converter, val):
        """Return a database value unchanged after checking that it is a ``time``."""
        if isinstance(val, time):
            return val
        throw(ValueError,
              'Value of unexpected type received from database: instead of time got %s' % type(val))
745
class TimedeltaConverter(ConverterWithMicroseconds):
    """Converter for ``timedelta`` attributes (SQL type ``INTERVAL``)."""
    sql_type_name = 'INTERVAL'
    def validate(converter, val, obj=None):
        """Coerce *val* to ``timedelta`` and truncate its microseconds to the precision.

        Strings are parsed with ``str2timedelta``; any other type raises TypeError.
        """
        if isinstance(val, timedelta): pass
        elif isinstance(val, basestring): val = str2timedelta(val)
        else: throw(TypeError, "Attribute %r: expected type is 'timedelta'. Got: %r" % (converter.attr, val))
        mcs = converter.round_microseconds_to_precision(val.microseconds, converter.precision)
        # timedelta has no replace(); rebuild it from its normalized components.
        if mcs is not None: val = timedelta(val.days, val.seconds, mcs)
        return val
    def sql2py(converter, val):
        """Return a database value unchanged after checking that it is a ``timedelta``.

        Raises ValueError for values of any other type.
        """
        if not isinstance(val, timedelta): throw(ValueError,
            'Value of unexpected type received from database: instead of timedelta got %s' % type(val))
        return val
759
class DatetimeConverter(ConverterWithMicroseconds):
    """Converter for ``datetime`` attributes (SQL type ``DATETIME``)."""
    sql_type_name = 'DATETIME'

    def validate(converter, val, obj=None):
        """Coerce *val* to ``datetime`` and truncate its microseconds to the precision."""
        if isinstance(val, basestring):
            val = str2datetime(val)
        elif not isinstance(val, datetime):
            throw(TypeError, "Attribute %r: expected type is 'datetime'. Got: %r" % (converter.attr, val))
        adjusted = converter.round_microseconds_to_precision(val.microsecond, converter.precision)
        if adjusted is None:
            return val
        return val.replace(microsecond=adjusted)

    def sql2py(converter, val):
        """Return a database value unchanged after checking that it is a ``datetime``."""
        if isinstance(val, datetime):
            return val
        throw(ValueError,
              'Value of unexpected type received from database: instead of datetime got %s' % type(val))
773
class UuidConverter(Converter):
    """Converter for ``UUID`` attributes (SQL type ``UUID``)."""
    def __init__(converter, provider, py_type, attr=None):
        # For an ``auto`` UUID attribute, the flag is cleared and ``uuid4``
        # becomes its default unless a default was already set.
        if attr is not None and attr.auto:
            attr.auto = False
            if not attr.default: attr.default = uuid4
        Converter.__init__(converter, provider, py_type, attr)
    def validate(converter, val, obj=None):
        """Coerce *val* to ``UUID``.

        Accepts a UUID, raw 16 bytes (``buffer`` or a 16-character string),
        a hex string, or an integer. Raises ValueError for any other type.
        """
        if isinstance(val, UUID): return val
        if isinstance(val, buffer): return UUID(bytes=val)
        if isinstance(val, basestring):
            # 16-character strings are treated as raw bytes, anything else as hex.
            if len(val) == 16: return UUID(bytes=val)
            return UUID(hex=val)
        # int_types includes ``long`` on Python 2, where most 128-bit UUID
        # integers are longs and a plain ``int`` check would reject them.
        if isinstance(val, int_types): return UUID(int=val)
        if converter.attr is not None:
            throw(ValueError, 'Value type of attribute %s must be UUID. Got: %r'
                               % (converter.attr, type(val)))
        else: throw(ValueError, 'Expected UUID value, got: %r' % type(val))
    def py2sql(converter, val):
        """Return the UUID's 16 raw bytes wrapped in ``buffer``."""
        return buffer(val.bytes)
    sql2py = validate
    def sql_type(converter):
        return "UUID"
796
class JsonConverter(Converter):
    """Converter for ``Json`` attributes, stored in the database as JSON text."""
    json_kwargs = {}  # extra keyword arguments passed to json.dumps
    class JsonEncoder(json.JSONEncoder):
        """JSON encoder that serializes ``Json`` instances via their ``wrapped`` value."""
        def default(encoder, obj):
            if isinstance(obj, Json):
                return obj.wrapped
            return json.JSONEncoder.default(encoder, obj)
    def validate(converter, val, obj=None):
        """Wrap *val* in a ``TrackedValue`` bound to *obj*, unless it already is one.

        Without an owning object or attribute, *val* is returned unchanged.
        """
        if obj is None or converter.attr is None:
            return val
        if isinstance(val, TrackedValue) and val.obj_ref() is obj and val.attr is converter.attr:
            return val
        return TrackedValue.make(obj, converter.attr, val)
    def val2dbval(converter, val, obj=None):
        """Serialize *val* to JSON text using ``JsonEncoder`` and ``json_kwargs``."""
        return json.dumps(val, cls=converter.JsonEncoder, **converter.json_kwargs)
    def dbval2val(converter, dbval, obj=None):
        """Decode a database value; wrap it in a ``TrackedValue`` when *obj* is given.

        Numeric values and None are returned as-is instead of being parsed as
        JSON text. ``int_types`` also covers Python 2 ``long``, which
        ``json.loads`` would otherwise reject with TypeError.
        """
        if isinstance(dbval, (bool, float, type(None)) + int_types):
            return dbval
        val = json.loads(dbval)
        if obj is None:
            return val
        return TrackedValue.make(obj, converter.attr, val)
    def dbvals_equal(converter, x, y):
        """Compare two database values, parsing string operands as JSON first."""
        if x == y: return True  # optimization
        if isinstance(x, basestring): x = json.loads(x)
        if isinstance(y, basestring): y = json.loads(y)
        return x == y
    def sql_type(converter):
        return "JSON"
826
class ArrayConverter(Converter):
    """Converter for ``Array`` attributes.

    ``array_types`` maps each supported item type to a pair of
    (SQL item type name, item converter class).
    """
    array_types = {
        int: ('int', IntConverter),
        unicode: ('text', StrConverter),
        float: ('real', RealConverter)
    }

    def __init__(converter, provider, py_type, attr=None):
        Converter.__init__(converter, provider, py_type, attr)
        converter.item_converter = converter.array_types[converter.py_type.item_type][1]

    def validate(converter, val, obj=None):
        """Return *val* as a list of items of the array's item type.

        A string, or any object without ``__len__``, is treated as a single
        item. In numeric arrays, items of another type are coerced through
        ``__index__`` when they support it. Any other mismatched item raises
        TypeError. When *obj* and the attribute are known, the list is wrapped
        in a ``TrackedArray``.
        """
        if isinstance(val, TrackedValue) and val.obj_ref() is obj and val.attr is converter.attr:
            return val

        if isinstance(val, basestring) or not hasattr(val, '__len__'):
            items = [val]
        else:
            items = list(val)
        item_type = converter.py_type.item_type
        # Only numeric arrays may accept integer-like objects via __index__;
        # in text arrays such items must be rejected, not stored as ints.
        coerce_index = item_type in (int, float)
        if item_type == float:
            item_type = (float, int)
        for i, v in enumerate(items):
            if PY2 and isinstance(v, str):
                # Keep the decoded text in the list, not only for the type check.
                v = items[i] = v.decode('ascii')
            if not isinstance(v, item_type):
                if coerce_index and hasattr(v, '__index__'):
                    items[i] = v.__index__()
                else:
                    throw(TypeError, 'Cannot store %s item in array of %s' %
                          (type(v).__name__, converter.py_type.item_type.__name__))

        if obj is None or converter.attr is None:
            return items
        return TrackedArray(obj, converter.attr, items)

    def dbval2val(converter, dbval, obj=None):
        """Wrap a database list in a ``TrackedArray`` bound to *obj* (if given and not None)."""
        if obj is None or dbval is None:
            return dbval
        return TrackedArray(obj, converter.attr, dbval)

    def val2dbval(converter, val, obj=None):
        """Return the array as a plain list for the database driver."""
        return list(val)

    def sql_type(converter):
        """Return the SQL array type, e.g. ``int[]``."""
        return '%s[]' % converter.array_types[converter.py_type.item_type][0]
873