1from __future__ import absolute_import, print_function, division 2from pony.py23compat import PY2, basestring, unicode, buffer, int_types, iteritems 3 4import os, re, json 5from decimal import Decimal, InvalidOperation 6from datetime import datetime, date, time, timedelta 7from uuid import uuid4, UUID 8 9import pony 10from pony.utils import is_utf8, decorator, throw, localbase, deprecated 11from pony.converting import str2date, str2time, str2datetime, str2timedelta 12from pony.orm.ormtypes import LongStr, LongUnicode, RawSQLType, TrackedValue, TrackedArray, Json, QueryType, Array 13 14class DBException(Exception): 15 def __init__(exc, original_exc, *args): 16 args = args or getattr(original_exc, 'args', ()) 17 Exception.__init__(exc, *args) 18 exc.original_exc = original_exc 19 20# Exception inheritance layout of DBAPI 2.0-compatible provider: 21# 22# Exception 23# Warning 24# Error 25# InterfaceError 26# DatabaseError 27# DataError 28# OperationalError 29# IntegrityError 30# InternalError 31# ProgrammingError 32# NotSupportedError 33 34class Warning(DBException): pass 35class Error(DBException): pass 36class InterfaceError(Error): pass 37class DatabaseError(Error): pass 38class DataError(DatabaseError): pass 39class OperationalError(DatabaseError): pass 40class IntegrityError(DatabaseError): pass 41class InternalError(DatabaseError): pass 42class ProgrammingError(DatabaseError): pass 43class NotSupportedError(DatabaseError): pass 44 45@decorator 46def wrap_dbapi_exceptions(func, provider, *args, **kwargs): 47 dbapi_module = provider.dbapi_module 48 should_retry = False 49 try: 50 try: 51 if provider.dialect != 'SQLite': 52 return func(provider, *args, **kwargs) 53 else: 54 provider.local_exceptions.keep_traceback = True 55 try: return func(provider, *args, **kwargs) 56 finally: provider.local_exceptions.keep_traceback = False 57 except dbapi_module.NotSupportedError as e: raise NotSupportedError(e) 58 except dbapi_module.ProgrammingError as e: 59 if provider.dialect == 'PostgreSQL': 60 msg = str(e) 61 if msg.startswith('operator does not exist:') and ' json ' in msg: 62 msg += ' (Note: use column type `jsonb` instead of `json`)' 63 raise ProgrammingError(e, msg, *e.args[1:]) 64 raise ProgrammingError(e) 65 except dbapi_module.InternalError as e: raise InternalError(e) 66 except dbapi_module.IntegrityError as e: raise IntegrityError(e) 67 except dbapi_module.OperationalError as e: 68 if provider.dialect == 'PostgreSQL' and e.pgcode == '40001': 69 should_retry = True 70 if provider.dialect == 'SQLite': 71 provider.restore_exception() 72 raise OperationalError(e) 73 except dbapi_module.DataError as e: raise DataError(e) 74 except dbapi_module.DatabaseError as e: raise DatabaseError(e) 75 except dbapi_module.InterfaceError as e: 76 if e.args == (0, '') and getattr(dbapi_module, '__name__', None) == 'MySQLdb': 77 throw(InterfaceError, e, 'MySQL server misconfiguration') 78 raise InterfaceError(e) 79 except dbapi_module.Error as e: raise Error(e) 80 except dbapi_module.Warning as e: raise Warning(e) 81 except Exception as e: 82 if should_retry: 83 e.should_retry = True 84 raise 85 86def unexpected_args(attr, args): 87 throw(TypeError, 'Unexpected positional argument{} for attribute {}: {}'.format( 88 len(args) > 1 and 's' or '', attr, ', '.join(repr(arg) for arg in args)) 89 ) 90 91version_re = re.compile('[0-9\.]+') 92 93def get_version_tuple(s): 94 m = version_re.match(s) 95 if m is not None: 96 components = m.group(0).split('.') 97 return tuple(int(component) for component in components) 98 return None 99 100class DBAPIProvider(object): 101 paramstyle = 'qmark' 102 quote_char = '"' 103 max_params_count = 999 104 max_name_len = 128 105 table_if_not_exists_syntax = True 106 index_if_not_exists_syntax = True 107 max_time_precision = default_time_precision = 6 108 uint64_support = False 109 110 # SQLite and PostgreSQL does not limit varchar max length. 111 varchar_default_max_len = None 112 113 dialect = None 114 dbapi_module = None 115 dbschema_cls = None 116 translator_cls = None 117 sqlbuilder_cls = None 118 array_converter_cls = None 119 120 name_before_table = 'schema_name' 121 default_schema_name = None 122 123 fk_types = { 'SERIAL' : 'INTEGER', 'BIGSERIAL' : 'BIGINT' } 124 125 def __init__(provider, *args, **kwargs): 126 pool_mockup = kwargs.pop('pony_pool_mockup', None) 127 call_on_connect = kwargs.pop('pony_call_on_connect', None) 128 if pool_mockup: provider.pool = pool_mockup 129 else: provider.pool = provider.get_pool(*args, **kwargs) 130 connection, is_new_connection = provider.connect() 131 if call_on_connect: 132 call_on_connect(connection) 133 provider.inspect_connection(connection) 134 provider.release(connection) 135 136 @wrap_dbapi_exceptions 137 def inspect_connection(provider, connection): 138 pass 139 140 def normalize_name(provider, name): 141 return name[:provider.max_name_len] 142 143 def get_default_entity_table_name(provider, entity): 144 return provider.normalize_name(entity.__name__) 145 146 def get_default_m2m_table_name(provider, attr, reverse): 147 if attr.symmetric: 148 assert reverse is attr 149 name = attr.entity.__name__ + '_' + attr.name 150 else: 151 name = attr.entity.__name__ + '_' + reverse.entity.__name__ 152 return provider.normalize_name(name) 153 154 def get_default_column_names(provider, attr, reverse_pk_columns=None): 155 normalize_name = provider.normalize_name 156 if reverse_pk_columns is None: 157 return [ normalize_name(attr.name) ] 158 elif len(reverse_pk_columns) == 1: 159 return [ normalize_name(attr.name) ] 160 else: 161 prefix = attr.name + '_' 162 return [ normalize_name(prefix + column) for column in reverse_pk_columns ] 163 164 def get_default_m2m_column_names(provider, entity): 165 normalize_name = provider.normalize_name 166 columns = entity._get_pk_columns_() 167 if len(columns) == 1: 168 return [ normalize_name(entity.__name__.lower()) ] 169 else: 170 prefix = entity.__name__.lower() + '_' 171 return [ normalize_name(prefix + column) for column in columns ] 172 173 def get_default_index_name(provider, table_name, column_names, is_pk=False, is_unique=False, m2m=False): 174 if is_pk: index_name = 'pk_%s' % provider.base_name(table_name) 175 else: 176 if is_unique: template = 'unq_%(tname)s__%(cnames)s' 177 elif m2m: template = 'idx_%(tname)s' 178 else: template = 'idx_%(tname)s__%(cnames)s' 179 index_name = template % dict(tname=provider.base_name(table_name), 180 cnames='_'.join(name for name in column_names)) 181 return provider.normalize_name(index_name.lower()) 182 183 def get_default_fk_name(provider, child_table_name, parent_table_name, child_column_names): 184 fk_name = 'fk_%s__%s' % (provider.base_name(child_table_name), '__'.join(child_column_names)) 185 return provider.normalize_name(fk_name.lower()) 186 187 def split_table_name(provider, table_name): 188 if isinstance(table_name, basestring): return provider.default_schema_name, table_name 189 if not table_name: throw(TypeError, 'Invalid table name: %r' % table_name) 190 if len(table_name) != 2: 191 size = len(table_name) 192 throw(TypeError, '%s qualified table name must have two components: ' 193 '%s and table_name. Got %d component%s: %s' 194 % (provider.dialect, provider.name_before_table, 195 size, 's' if size != 1 else '', table_name)) 196 return table_name[0], table_name[1] 197 198 def base_name(provider, name): 199 if not isinstance(name, basestring): 200 assert type(name) is tuple 201 name = name[-1] 202 assert isinstance(name, basestring) 203 return name 204 205 def quote_name(provider, name): 206 quote_char = provider.quote_char 207 if isinstance(name, basestring): 208 name = name.replace(quote_char, quote_char+quote_char) 209 return quote_char + name + quote_char 210 return '.'.join(provider.quote_name(item) for item in name) 211 212 def format_table_name(provider, name): 213 return provider.quote_name(name) 214 215 def normalize_vars(provider, vars, vartypes): 216 for key, value in iteritems(vars): 217 vartype = vartypes[key] 218 if isinstance(vartype, QueryType): 219 vartypes[key], vars[key] = value._normalize_var(vartype) 220 221 def ast2sql(provider, ast): 222 builder = provider.sqlbuilder_cls(provider, ast) 223 return builder.sql, builder.adapter 224 225 def should_reconnect(provider, exc): 226 return False 227 228 @wrap_dbapi_exceptions 229 def connect(provider): 230 return provider.pool.connect() 231 232 @wrap_dbapi_exceptions 233 def set_transaction_mode(provider, connection, cache): 234 pass 235 236 @wrap_dbapi_exceptions 237 def commit(provider, connection, cache=None): 238 core = pony.orm.core 239 if core.local.debug: core.log_orm('COMMIT') 240 connection.commit() 241 if cache is not None: cache.in_transaction = False 242 243 @wrap_dbapi_exceptions 244 def rollback(provider, connection, cache=None): 245 core = pony.orm.core 246 if core.local.debug: core.log_orm('ROLLBACK') 247 connection.rollback() 248 if cache is not None: cache.in_transaction = False 249 250 @wrap_dbapi_exceptions 251 def release(provider, connection, cache=None): 252 core = pony.orm.core 253 if cache is not None and cache.db_session is not None and cache.db_session.ddl: 254 provider.drop(connection, cache) 255 else: 256 if core.local.debug: core.log_orm('RELEASE CONNECTION') 257 provider.pool.release(connection) 258 259 @wrap_dbapi_exceptions 260 def drop(provider, connection, cache=None): 261 core = pony.orm.core 262 if core.local.debug: core.log_orm('CLOSE CONNECTION') 263 provider.pool.drop(connection) 264 if cache is not None: cache.in_transaction = False 265 266 @wrap_dbapi_exceptions 267 def disconnect(provider): 268 core = pony.orm.core 269 if core.local.debug: core.log_orm('DISCONNECT') 270 provider.pool.disconnect() 271 272 @wrap_dbapi_exceptions 273 def execute(provider, cursor, sql, arguments=None, returning_id=False): 274 if type(arguments) is list: 275 assert arguments and not returning_id 276 cursor.executemany(sql, arguments) 277 else: 278 if arguments is None: cursor.execute(sql) 279 else: cursor.execute(sql, arguments) 280 if returning_id: return cursor.lastrowid 281 282 converter_classes = [] 283 284 def _get_converter_type_by_py_type(provider, py_type): 285 if isinstance(py_type, type): 286 for t, converter_cls in provider.converter_classes: 287 if issubclass(py_type, t): return converter_cls 288 if issubclass(py_type, Array): 289 converter_cls = provider.array_converter_cls 290 if converter_cls is None: 291 throw(NotImplementedError, 'Array type is not supported for %r' % provider.dialect) 292 return converter_cls 293 if isinstance(py_type, RawSQLType): 294 return Converter # for cases like select(raw_sql(...) for x in X) 295 throw(TypeError, 'No database converter found for type %s' % py_type) 296 297 def get_converter_by_py_type(provider, py_type): 298 converter_cls = provider._get_converter_type_by_py_type(py_type) 299 return converter_cls(provider, py_type) 300 301 def get_converter_by_attr(provider, attr): 302 py_type = attr.py_type 303 converter_cls = provider._get_converter_type_by_py_type(py_type) 304 return converter_cls(provider, py_type, attr) 305 306 def get_pool(provider, *args, **kwargs): 307 return Pool(provider.dbapi_module, *args, **kwargs) 308 309 def table_exists(provider, connection, table_name, case_sensitive=True): 310 throw(NotImplementedError) 311 312 def index_exists(provider, connection, table_name, index_name, case_sensitive=True): 313 throw(NotImplementedError) 314 315 def fk_exists(provider, connection, table_name, fk_name, case_sensitive=True): 316 throw(NotImplementedError) 317 318 def table_has_data(provider, connection, table_name): 319 cursor = connection.cursor() 320 cursor.execute('SELECT 1 FROM %s LIMIT 1' % provider.quote_name(table_name)) 321 return cursor.fetchone() is not None 322 323 def disable_fk_checks(provider, connection): 324 pass 325 326 def enable_fk_checks(provider, connection, prev_state): 327 pass 328 329 def drop_table(provider, connection, table_name): 330 cursor = connection.cursor() 331 sql = 'DROP TABLE %s' % provider.quote_name(table_name) 332 cursor.execute(sql) 333 334class Pool(localbase): 335 forked_connections = [] 336 def __init__(pool, dbapi_module, *args, **kwargs): # called separately in each thread 337 pool.dbapi_module = dbapi_module 338 pool.args = args 339 pool.kwargs = kwargs 340 pool.con = pool.pid = None 341 def connect(pool): 342 pid = os.getpid() 343 if pool.con is not None and pool.pid != pid: 344 pool.forked_connections.append((pool.con, pool.pid)) 345 pool.con = pool.pid = None 346 core = pony.orm.core 347 is_new_connection = False 348 if pool.con is None: 349 if core.local.debug: core.log_orm('GET NEW CONNECTION') 350 is_new_connection = True 351 pool._connect() 352 pool.pid = pid 353 elif core.local.debug: 354 core.log_orm('GET CONNECTION FROM THE LOCAL POOL') 355 return pool.con, is_new_connection 356 def _connect(pool): 357 pool.con = pool.dbapi_module.connect(*pool.args, **pool.kwargs) 358 def release(pool, con): 359 assert con is pool.con 360 try: con.rollback() 361 except: 362 pool.drop(con) 363 raise 364 def drop(pool, con): 365 assert con is pool.con, (con, pool.con) 366 pool.con = None 367 con.close() 368 def disconnect(pool): 369 con = pool.con 370 pool.con = None 371 if con is not None: con.close() 372 373class Converter(object): 374 EQ = 'EQ' 375 NE = 'NE' 376 optimistic = True 377 def __deepcopy__(converter, memo): 378 return converter # Converter instances are "immutable" 379 def __init__(converter, provider, py_type, attr=None): 380 converter.provider = provider 381 converter.py_type = py_type 382 converter.attr = attr 383 if attr is None: return 384 kwargs = attr.kwargs.copy() 385 converter.init(kwargs) 386 for option in kwargs: throw(TypeError, 'Attribute %s has unknown option %r' % (attr, option)) 387 def init(converter, kwargs): 388 attr = converter.attr 389 if attr and attr.args: unexpected_args(attr, attr.args) 390 def validate(converter, val, obj=None): 391 return val 392 def py2sql(converter, val): 393 return val 394 def sql2py(converter, val): 395 return val 396 def val2dbval(self, val, obj=None): 397 return val 398 def dbval2val(self, dbval, obj=None): 399 return dbval 400 def dbvals_equal(self, x, y): 401 return x == y 402 def get_sql_type(converter, attr=None): 403 if attr is not None and attr.sql_type is not None: 404 return attr.sql_type 405 attr = converter.attr 406 if attr.sql_type is not None: 407 assert len(attr.columns) == 1 408 return converter.get_fk_type(attr.sql_type) 409 if attr is not None and attr.reverse and not attr.is_collection: 410 i = attr.converters.index(converter) 411 rentity = attr.reverse.entity 412 rpk_converters = rentity._pk_converters_ 413 assert rpk_converters is not None and len(attr.converters) == len(rpk_converters) 414 rconverter = rpk_converters[i] 415 return rconverter.sql_type() 416 return converter.sql_type() 417 def get_fk_type(converter, sql_type): 418 fk_types = converter.provider.fk_types 419 if sql_type.isupper(): return fk_types.get(sql_type, sql_type) 420 sql_type = sql_type.upper() 421 return fk_types.get(sql_type, sql_type).lower() 422 423class NoneConverter(Converter): # used for raw_sql() parameters only 424 def __init__(converter, provider, py_type, attr=None): 425 if attr is not None: throw(TypeError, 'Attribute %s has invalid type NoneType' % attr) 426 Converter.__init__(converter, provider, py_type) 427 def get_sql_type(converter, attr=None): 428 assert False 429 def get_fk_type(converter, sql_type): 430 assert False 431 432class BoolConverter(Converter): 433 def validate(converter, val, obj=None): 434 return bool(val) 435 def sql2py(converter, val): 436 return bool(val) 437 def sql_type(converter): 438 return "BOOLEAN" 439 440class StrConverter(Converter): 441 def __init__(converter, provider, py_type, attr=None): 442 converter.max_len = None 443 converter.db_encoding = None 444 Converter.__init__(converter, provider, py_type, attr) 445 def init(converter, kwargs): 446 attr = converter.attr 447 max_len = kwargs.pop('max_len', None) 448 if len(attr.args) > 1: unexpected_args(attr, attr.args[1:]) 449 elif attr.args: 450 if max_len is not None: throw(TypeError, 451 'Max length option specified twice: as a positional argument and as a `max_len` named argument') 452 max_len = attr.args[0] 453 if issubclass(attr.py_type, (LongStr, LongUnicode)): 454 if max_len is not None: throw(TypeError, 'Max length is not supported for CLOBs') 455 elif max_len is None: max_len = converter.provider.varchar_default_max_len 456 elif not isinstance(max_len, int_types): 457 throw(TypeError, 'Max length argument must be int. Got: %r' % max_len) 458 converter.max_len = max_len 459 converter.db_encoding = kwargs.pop('db_encoding', None) 460 converter.autostrip = kwargs.pop('autostrip', True) 461 def validate(converter, val, obj=None): 462 if PY2 and isinstance(val, str): val = val.decode('ascii') 463 elif not isinstance(val, unicode): throw(TypeError, 464 'Value type for attribute %s must be %s. Got: %r' % (converter.attr, unicode.__name__, type(val))) 465 if converter.autostrip: val = val.strip() 466 max_len = converter.max_len 467 val_len = len(val) 468 if max_len and val_len > max_len: 469 throw(ValueError, 'Value for attribute %s is too long. Max length is %d, value length is %d' 470 % (converter.attr, max_len, val_len)) 471 return val 472 def sql_type(converter): 473 if converter.max_len: 474 return 'VARCHAR(%d)' % converter.max_len 475 return 'TEXT' 476 477class IntConverter(Converter): 478 signed_types = {None: 'INTEGER', 8: 'TINYINT', 16: 'SMALLINT', 24: 'MEDIUMINT', 32: 'INTEGER', 64: 'BIGINT'} 479 unsigned_types = None 480 def init(converter, kwargs): 481 Converter.init(converter, kwargs) 482 attr = converter.attr 483 484 min_val = kwargs.pop('min', None) 485 if min_val is not None and not isinstance(min_val, int_types): 486 throw(TypeError, "'min' argument for attribute %s must be int. Got: %r" % (attr, min_val)) 487 488 max_val = kwargs.pop('max', None) 489 if max_val is not None and not isinstance(max_val, int_types): 490 throw(TypeError, "'max' argument for attribute %s must be int. Got: %r" % (attr, max_val)) 491 492 size = kwargs.pop('size', None) 493 if size is None: 494 if attr.py_type.__name__ == 'long': 495 deprecated(9, "Attribute %s: 'long' attribute type is deprecated. " 496 "Please use 'int' type with size=64 option instead" % attr) 497 attr.py_type = int 498 size = 64 499 elif attr.py_type.__name__ == 'long': throw(TypeError, 500 "Attribute %s: 'size' option cannot be used with long type. Please use int type instead" % attr) 501 elif not isinstance(size, int_types): 502 throw(TypeError, "'size' option for attribute %s must be of int type. Got: %r" % (attr, size)) 503 elif size not in (8, 16, 24, 32, 64): 504 throw(TypeError, "incorrect value of 'size' option for attribute %s. " 505 "Should be 8, 16, 24, 32 or 64. Got: %d" % (attr, size)) 506 507 unsigned = kwargs.pop('unsigned', False) 508 if unsigned is not None and not isinstance(unsigned, bool): 509 throw(TypeError, "'unsigned' option for attribute %s must be of bool type. Got: %r" % (attr, unsigned)) 510 511 if size == 64 and unsigned and not converter.provider.uint64_support: throw(TypeError, 512 'Attribute %s: %s provider does not support unsigned bigint type' % (attr, converter.provider.dialect)) 513 514 if unsigned is not None and size is None: size = 32 515 lowest = highest = None 516 if size: 517 highest = highest = 2 ** size - 1 if unsigned else 2 ** (size - 1) - 1 518 lowest = 0 if unsigned else -(2 ** (size - 1)) 519 520 if highest is not None and max_val is not None and max_val > highest: 521 throw(ValueError, "'max' argument should be less or equal to %d because of size=%d and unsigned=%s. " 522 "Got: %d" % (highest, size, max_val, unsigned)) 523 524 if lowest is not None and min_val is not None and min_val < lowest: 525 throw(ValueError, "'min' argument should be greater or equal to %d because of size=%d and unsigned=%s. " 526 "Got: %d" % (lowest, size, min_val, unsigned)) 527 528 converter.min_val = min_val or lowest 529 converter.max_val = max_val or highest 530 converter.size = size 531 converter.unsigned = unsigned 532 def validate(converter, val, obj=None): 533 if isinstance(val, int_types): pass 534 elif hasattr(val, '__index__'): 535 val = val.__index__() 536 elif isinstance(val, basestring): 537 try: val = int(val) 538 except ValueError: throw(ValueError, 539 'Value type for attribute %s must be int. Got string %r' % (converter.attr, val)) 540 else: throw(TypeError, 'Value type for attribute %s must be int. Got: %r' % (converter.attr, type(val))) 541 542 if converter.min_val and val < converter.min_val: 543 throw(ValueError, 'Value %r of attr %s is less than the minimum allowed value %r' 544 % (val, converter.attr, converter.min_val)) 545 if converter.max_val and val > converter.max_val: 546 throw(ValueError, 'Value %r of attr %s is greater than the maximum allowed value %r' 547 % (val, converter.attr, converter.max_val)) 548 return val 549 def sql2py(converter, val): 550 return int(val) 551 def sql_type(converter): 552 if not converter.unsigned: 553 return converter.signed_types.get(converter.size) 554 if converter.unsigned_types is None: 555 return converter.signed_types.get(converter.size) + ' UNSIGNED' 556 return converter.unsigned_types.get(converter.size) 557 558class RealConverter(Converter): 559 EQ = 'FLOAT_EQ' 560 NE = 'FLOAT_NE' 561 # The tolerance is necessary for Oracle, because it has different representation of float numbers. 562 # For other databases the default tolerance is set because the precision can be lost during 563 # Python -> JavaScript -> Python conversion 564 default_tolerance = 1e-14 565 optimistic = False 566 def init(converter, kwargs): 567 Converter.init(converter, kwargs) 568 min_val = kwargs.pop('min', None) 569 if min_val is not None: 570 try: min_val = float(min_val) 571 except ValueError: 572 throw(TypeError, "Invalid value for 'min' argument for attribute %s: %r" % (converter.attr, min_val)) 573 max_val = kwargs.pop('max', None) 574 if max_val is not None: 575 try: max_val = float(max_val) 576 except ValueError: 577 throw(TypeError, "Invalid value for 'max' argument for attribute %s: %r" % (converter.attr, max_val)) 578 converter.min_val = min_val 579 converter.max_val = max_val 580 converter.tolerance = kwargs.pop('tolerance', converter.default_tolerance) 581 def validate(converter, val, obj=None): 582 try: val = float(val) 583 except ValueError: 584 throw(TypeError, 'Invalid value for attribute %s: %r' % (converter.attr, val)) 585 if converter.min_val and val < converter.min_val: 586 throw(ValueError, 'Value %r of attr %s is less than the minimum allowed value %r' 587 % (val, converter.attr, converter.min_val)) 588 if converter.max_val and val > converter.max_val: 589 throw(ValueError, 'Value %r of attr %s is greater than the maximum allowed value %r' 590 % (val, converter.attr, converter.max_val)) 591 return val 592 def dbvals_equal(converter, x, y): 593 tolerance = converter.tolerance 594 if tolerance is None or x is None or y is None: return x == y 595 denominator = max(abs(x), abs(y)) 596 if not denominator: return True 597 diff = abs(x-y) / denominator 598 return diff <= tolerance 599 def sql2py(converter, val): 600 return float(val) 601 def sql_type(converter): 602 return 'REAL' 603 604class DecimalConverter(Converter): 605 def __init__(converter, provider, py_type, attr=None): 606 converter.exp = None # for the case when attr is None 607 Converter.__init__(converter, provider, py_type, attr) 608 def init(converter, kwargs): 609 attr = converter.attr 610 args = attr.args 611 if len(args) > 2: throw(TypeError, 'Too many positional parameters for Decimal ' 612 '(expected: precision and scale), got: %s' % args) 613 if args: precision = args[0] 614 else: precision = kwargs.pop('precision', 12) 615 if not isinstance(precision, int_types): 616 throw(TypeError, "'precision' positional argument for attribute %s must be int. Got: %r" % (attr, precision)) 617 if precision <= 0: throw(TypeError, 618 "'precision' positional argument for attribute %s must be positive. Got: %r" % (attr, precision)) 619 620 if len(args) == 2: scale = args[1] 621 else: scale = kwargs.pop('scale', 2) 622 if not isinstance(scale, int_types): 623 throw(TypeError, "'scale' positional argument for attribute %s must be int. Got: %r" % (attr, scale)) 624 if scale <= 0: throw(TypeError, 625 "'scale' positional argument for attribute %s must be positive. Got: %r" % (attr, scale)) 626 627 if scale > precision: throw(ValueError, "'scale' must be less or equal 'precision'") 628 converter.precision = precision 629 converter.scale = scale 630 converter.exp = Decimal(10) ** -scale 631 632 min_val = kwargs.pop('min', None) 633 if min_val is not None: 634 try: min_val = Decimal(min_val) 635 except TypeError: throw(TypeError, 636 "Invalid value for 'min' argument for attribute %s: %r" % (attr, min_val)) 637 638 max_val = kwargs.pop('max', None) 639 if max_val is not None: 640 try: max_val = Decimal(max_val) 641 except TypeError: throw(TypeError, 642 "Invalid value for 'max' argument for attribute %s: %r" % (attr, max_val)) 643 644 converter.min_val = min_val 645 converter.max_val = max_val 646 def validate(converter, val, obj=None): 647 if isinstance(val, float): 648 s = str(val) 649 if float(s) != val: s = repr(val) 650 val = Decimal(s) 651 try: val = Decimal(val) 652 except InvalidOperation as exc: 653 throw(TypeError, 'Invalid value for attribute %s: %r' % (converter.attr, val)) 654 if converter.min_val is not None and val < converter.min_val: 655 throw(ValueError, 'Value %r of attr %s is less than the minimum allowed value %r' 656 % (val, converter.attr, converter.min_val)) 657 if converter.max_val is not None and val > converter.max_val: 658 throw(ValueError, 'Value %r of attr %s is greater than the maximum allowed value %r' 659 % (val, converter.attr, converter.max_val)) 660 return val 661 def sql2py(converter, val): 662 return Decimal(val) 663 def sql_type(converter): 664 return 'DECIMAL(%d, %d)' % (converter.precision, converter.scale) 665 666class BlobConverter(Converter): 667 def validate(converter, val, obj=None): 668 if isinstance(val, buffer): return val 669 if isinstance(val, str): return buffer(val) 670 throw(TypeError, "Attribute %r: expected type is 'buffer'. Got: %r" % (converter.attr, type(val))) 671 def sql2py(converter, val): 672 if not isinstance(val, buffer): 673 try: val = buffer(val) 674 except: pass 675 elif PY2 and converter.attr is not None and converter.attr.is_part_of_unique_index: 676 try: hash(val) 677 except TypeError: 678 val = buffer(val) 679 return val 680 def sql_type(converter): 681 return 'BLOB' 682 683class DateConverter(Converter): 684 def validate(converter, val, obj=None): 685 if isinstance(val, datetime): return val.date() 686 if isinstance(val, date): return val 687 if isinstance(val, basestring): return str2date(val) 688 throw(TypeError, "Attribute %r: expected type is 'date'. Got: %r" % (converter.attr, val)) 689 def sql2py(converter, val): 690 if not isinstance(val, date): throw(ValueError, 691 'Value of unexpected type received from database: instead of date got %s' % type(val)) 692 return val 693 def sql_type(converter): 694 return 'DATE' 695 696class ConverterWithMicroseconds(Converter): 697 def __init__(converter, provider, py_type, attr=None): 698 converter.precision = None # for the case when attr is None 699 Converter.__init__(converter, provider, py_type, attr) 700 def init(converter, kwargs): 701 attr = converter.attr 702 args = attr.args 703 if len(args) > 1: throw(TypeError, 'Too many positional parameters for attribute %s. ' 704 'Expected: precision, got: %r' % (attr, args)) 705 provider = attr.entity._database_.provider 706 if args: 707 precision = args[0] 708 if 'precision' in kwargs: throw(TypeError, 709 'Precision for attribute %s has both positional and keyword value' % attr) 710 else: precision = kwargs.pop('precision', provider.default_time_precision) 711 if not isinstance(precision, int) or not 0 <= precision <= 6: throw(ValueError, 712 'Precision value of attribute %s must be between 0 and 6. Got: %r' % (attr, precision)) 713 if precision > provider.max_time_precision: throw(ValueError, 714 'Precision value (%d) of attribute %s exceeds max datetime precision (%d) of %s %s' 715 % (precision, attr, provider.max_time_precision, provider.dialect, provider.server_version)) 716 converter.precision = precision 717 def round_microseconds_to_precision(converter, microseconds, precision): 718 # returns None if no change is required 719 if not precision: result = 0 720 elif precision < 6: 721 rounding = 10 ** (6-precision) 722 result = (microseconds // rounding) * rounding 723 else: return None 724 return result if result != microseconds else None 725 def sql_type(converter): 726 attr = converter.attr 727 precision = converter.precision 728 if not attr or precision == attr.entity._database_.provider.default_time_precision: 729 return converter.sql_type_name 730 return converter.sql_type_name + '(%d)' % precision 731 732class TimeConverter(ConverterWithMicroseconds): 733 sql_type_name = 'TIME' 734 def validate(converter, val, obj=None): 735 if isinstance(val, time): pass 736 elif isinstance(val, basestring): val = str2time(val) 737 else: throw(TypeError, "Attribute %r: expected type is 'time'. Got: %r" % (converter.attr, val)) 738 mcs = converter.round_microseconds_to_precision(val.microsecond, converter.precision) 739 if mcs is not None: val = val.replace(microsecond=mcs) 740 return val 741 def sql2py(converter, val): 742 if not isinstance(val, time): throw(ValueError, 743 'Value of unexpected type received from database: instead of time got %s' % type(val)) 744 return val 745 746class TimedeltaConverter(ConverterWithMicroseconds): 747 sql_type_name = 'INTERVAL' 748 def validate(converter, val, obj=None): 749 if isinstance(val, timedelta): pass 750 elif isinstance(val, basestring): val = str2timedelta(val) 751 else: throw(TypeError, "Attribute %r: expected type is 'timedelta'. Got: %r" % (converter.attr, val)) 752 mcs = converter.round_microseconds_to_precision(val.microseconds, converter.precision) 753 if mcs is not None: val = timedelta(val.days, val.seconds, mcs) 754 return val 755 def sql2py(converter, val): 756 if not isinstance(val, timedelta): throw(ValueError, 757 'Value of unexpected type received from database: instead of time got %s' % type(val)) 758 return val 759 760class DatetimeConverter(ConverterWithMicroseconds): 761 sql_type_name = 'DATETIME' 762 def validate(converter, val, obj=None): 763 if isinstance(val, datetime): pass 764 elif isinstance(val, basestring): val = str2datetime(val) 765 else: throw(TypeError, "Attribute %r: expected type is 'datetime'. Got: %r" % (converter.attr, val)) 766 mcs = converter.round_microseconds_to_precision(val.microsecond, converter.precision) 767 if mcs is not None: val = val.replace(microsecond=mcs) 768 return val 769 def sql2py(converter, val): 770 if not isinstance(val, datetime): throw(ValueError, 771 'Value of unexpected type received from database: instead of datetime got %s' % type(val)) 772 return val 773 774class UuidConverter(Converter): 775 def __init__(converter, provider, py_type, attr=None): 776 if attr is not None and attr.auto: 777 attr.auto = False 778 if not attr.default: attr.default = uuid4 779 Converter.__init__(converter, provider, py_type, attr) 780 def validate(converter, val, obj=None): 781 if isinstance(val, UUID): return val 782 if isinstance(val, buffer): return UUID(bytes=val) 783 if isinstance(val, basestring): 784 if len(val) == 16: return UUID(bytes=val) 785 return UUID(hex=val) 786 if isinstance(val, int): return UUID(int=val) 787 if converter.attr is not None: 788 throw(ValueError, 'Value type of attribute %s must be UUID. Got: %r' 789 % (converter.attr, type(val))) 790 else: throw(ValueError, 'Expected UUID value, got: %r' % type(val)) 791 def py2sql(converter, val): 792 return buffer(val.bytes) 793 sql2py = validate 794 def sql_type(converter): 795 return "UUID" 796 797class JsonConverter(Converter): 798 json_kwargs = {} 799 class JsonEncoder(json.JSONEncoder): 800 def default(converter, obj): 801 if isinstance(obj, Json): 802 return obj.wrapped 803 return json.JSONEncoder.default(converter, obj) 804 def validate(converter, val, obj=None): 805 if obj is None or converter.attr is None: 806 return val 807 if isinstance(val, TrackedValue) and val.obj_ref() is obj and val.attr is converter.attr: 808 return val 809 return TrackedValue.make(obj, converter.attr, val) 810 def val2dbval(converter, val, obj=None): 811 return json.dumps(val, cls=converter.JsonEncoder, **converter.json_kwargs) 812 def dbval2val(converter, dbval, obj=None): 813 if isinstance(dbval, (int, bool, float, type(None))): 814 return dbval 815 val = json.loads(dbval) 816 if obj is None: 817 return val 818 return TrackedValue.make(obj, converter.attr, val) 819 def dbvals_equal(converter, x, y): 820 if x == y: return True # optimization 821 if isinstance(x, basestring): x = json.loads(x) 822 if isinstance(y, basestring): y = json.loads(y) 823 return x == y 824 def sql_type(converter): 825 return "JSON" 826 827class ArrayConverter(Converter): 828 array_types = { 829 int: ('int', IntConverter), 830 unicode: ('text', StrConverter), 831 float: ('real', RealConverter) 832 } 833 834 def __init__(converter, provider, py_type, attr=None): 835 Converter.__init__(converter, provider, py_type, attr) 836 converter.item_converter = converter.array_types[converter.py_type.item_type][1] 837 838 def validate(converter, val, obj=None): 839 if isinstance(val, TrackedValue) and val.obj_ref() is obj and val.attr is converter.attr: 840 return val 841 842 if isinstance(val, basestring) or not hasattr(val, '__len__'): 843 items = [val] 844 else: 845 items = list(val) 846 item_type = converter.py_type.item_type 847 if item_type == float: 848 item_type = (float, int) 849 for i, v in enumerate(items): 850 if PY2 and isinstance(v, str): 851 v = v.decode('ascii') 852 if not isinstance(v, item_type): 853 if hasattr(v, '__index__'): 854 items[i] = v.__index__() 855 else: 856 throw(TypeError, 'Cannot store %s item in array of %s' % 857 (type(v).__name__, converter.py_type.item_type.__name__)) 858 859 if obj is None or converter.attr is None: 860 return items 861 return TrackedArray(obj, converter.attr, items) 862 863 def dbval2val(converter, dbval, obj=None): 864 if obj is None or dbval is None: 865 return dbval 866 return TrackedArray(obj, converter.attr, dbval) 867 868 def val2dbval(converter, val, obj=None): 869 return list(val) 870 871 def sql_type(converter): 872 return '%s[]' % converter.array_types[converter.py_type.item_type][0] 873