# -*- coding: utf-8 -*-
"""
    pyrseas.dbobject.function
    ~~~~~~~~~~~~~~~~~~~~~~~~~

    This module defines four classes: Proc derived from
    DbSchemaObject, Function and Aggregate derived from Proc, and
    ProcDict derived from DbObjectDict.
"""
from pyrseas.lib.pycompat import PY2
from pyrseas.yamlutil import MultiLineStr
from . import DbObjectDict, DbSchemaObject
from . import commentable, ownable, grantable, split_schema_obj

# Single-letter catalog codes mapped to the keywords used in YAML and SQL
VOLATILITY_TYPES = {'i': 'immutable', 's': 'stable', 'v': 'volatile'}
PARALLEL_SAFETY = {'r': 'restricted', 's': 'safe', 'u': 'unsafe'}


def split_schema_func(schema, func):
    """Split a function related to an object from its schema

    :param schema: schema to which the main object belongs
    :param func: possibly qualified function name
    :returns: a schema, function tuple, or just the unqualified function name
    """
    (sch, fnc) = split_schema_obj(func, schema)
    if sch != schema:
        return (sch, fnc)
    else:
        return fnc


def join_schema_func(func):
    """Join the schema and function, if needed, to form a qualified name

    :param func: a schema, function tuple, or just an unqualified function name
    :returns: a possibly-qualified schema.function string
    """
    if isinstance(func, tuple):
        return "%s.%s" % func
    else:
        return func


class Proc(DbSchemaObject):
    """A procedure such as a FUNCTION or an AGGREGATE"""

    keylist = ['schema', 'name', 'arguments']
    catalog = 'pg_proc'

    @property
    def allprivs(self):
        """Return the privilege codes applicable to a procedure ('X')"""
        return 'X'

    def __init__(self, name, schema, description, owner, privileges,
                 arguments):
        """Initialize the procedure

        :param name: function name (from proname)
        :param schema: schema name (from pronamespace)
        :param description: comment text (from obj_description())
        :param owner: owner name (from rolname via proowner)
        :param privileges: access privileges (from proacl)
        :param arguments: argument list (without default values, from
            pg_function_identity_arguments)
        """
        super(Proc, self).__init__(name, schema, description)
        self._init_own_privs(owner, privileges)
        self.arguments = arguments

    def extern_key(self):
        """Return the key to be used in external maps for this function

        :return: string
        """
        return '%s %s(%s)' % (self.objtype.lower(), self.name, self.arguments)

    def identifier(self):
        """Return a full identifier for a function object

        :return: string
        """
        return "%s(%s)" % (self.qualname(), self.arguments)

    def get_implied_deps(self, db):
        """Return the dependencies implied by the procedure definition

        Adds the implementation language (for functions only) and the
        types of the arguments to the dependencies found by the base
        class.

        :param db: the database object collection to look objects up in
        :return: set of dependencies
        """
        # List the previous dependencies
        deps = super(Proc, self).get_implied_deps(db)

        # Add back the language
        if isinstance(self, Function) and getattr(self, 'language', None):
            lang = db.languages.get(self.language)
            if lang:
                deps.add(lang)

        # Add back the types
        if self.arguments:
            for arg in self.arguments.split(', '):
                # the type is the last word of each "[mode] [name] type"
                arg = db.find_type(arg.split()[-1])
                if arg is not None:
                    deps.add(arg)

        return deps


class Function(Proc):
    """A procedural language function"""

    def __init__(self, name, schema, description, owner, privileges,
                 arguments, language, returns, source, obj_file=None,
                 configuration=None, volatility=None, leakproof=False,
                 strict=False, security_definer=False, cost=0, rows=0,
                 allargs=None, oid=None):
        """Initialize the function

        :param name-arguments: see Proc.__init__ params
        :param language: implementation language (from prolang)
        :param returns: return type (from pg_get_function_result/prorettype)
        :param source: source code, link symbol, etc. (from prosrc)
        :param obj_file: language-specific info (from probin)
        :param configuration: configuration variables (from proconfig)
        :param volatility: volatility type (from provolatile)
        :param leakproof: has no side effects (from proleakproof)
        :param strict: null handling (from proisstrict)
        :param security_definer: security definer (from prosecdef)
        :param cost: execution cost estimate (from procost)
        :param rows: result row estimate (from prorows)
        :param allargs: argument list with defaults (from
            pg_get_function_arguments)
        :param oid: object identifier of the pg_proc row
        """
        super(Function, self).__init__(
            name, schema, description, owner, privileges, arguments)
        self.language = language
        self.returns = returns
        if source and '\n' in source:
            # Drop trailing blanks/tabs on each line of a multi-line source
            newsrc = []
            for line in source.split('\n'):
                if line and line[-1] in (' ', '\t'):
                    line = line.rstrip()
                newsrc.append(line)
            source = '\n'.join(newsrc)
        if source is None:
            # Keep a missing source as None on every Python version so that
            # "source is not None" checks (see alter) behave consistently.
            self.source = None
        elif PY2:
            self.source = source.encode('utf_8').decode('utf_8')
        else:
            self.source = MultiLineStr(source)
        self.obj_file = obj_file
        self.configuration = configuration
        self.allargs = allargs
        if volatility is not None:
            # accept both the catalog code and the spelled-out keyword
            self.volatility = volatility[:1].lower()
        else:
            self.volatility = 'v'
        assert self.volatility in VOLATILITY_TYPES.keys()
        self.leakproof = leakproof
        self.strict = strict
        self.security_definer = security_definer
        self.cost = cost
        self.rows = rows
        self.oid = oid

    @staticmethod
    def query(dbversion=None):
        """Return the catalog query used to fetch non-aggregate functions

        :param dbversion: Postgres version number (e.g., 110000).  When
            omitted, the pre-11 predicate (NOT proisagg) is used.
        :return: SQL query string
        """
        query = """
            SELECT nspname AS schema, proname AS name,
                   pg_get_function_identity_arguments(p.oid) AS arguments,
                   pg_get_function_arguments(p.oid) AS allargs,
                   pg_get_function_result(p.oid) AS returns, rolname AS owner,
                   array_to_string(proacl, ',') AS privileges,
                   l.lanname AS language, provolatile AS volatility,
                   proisstrict AS strict, prosrc AS source,
                   probin::text AS obj_file, proconfig AS configuration,
                   prosecdef AS security_definer, procost AS cost,
                   proleakproof AS leakproof, prorows::integer AS rows,
                   obj_description(p.oid, 'pg_proc') AS description, p.oid
            FROM pg_proc p JOIN pg_roles r ON (r.oid = proowner)
                 JOIN pg_namespace n ON (pronamespace = n.oid)
                 JOIN pg_language l ON (prolang = l.oid)
            WHERE (nspname != 'pg_catalog' AND nspname != 'information_schema')
              AND %s
              AND p.oid NOT IN (
                  SELECT objid FROM pg_depend WHERE deptype = 'e'
                               AND classid = 'pg_proc'::regclass)
            ORDER BY nspname, proname"""
        # Comparing None with an int raises TypeError on Python 3, so the
        # missing-version case is tested explicitly.
        if dbversion is None or dbversion < 110000:
            query = query % "NOT proisagg"
        else:
            query = query % "prokind = 'f'"
        return query

    @staticmethod
    def from_map(name, schema, arguments, inobj):
        """Initialize a function instance from a YAML map

        :param name: function name
        :param schema: schema object owning the function
        :param arguments: arguments
        :param inobj: YAML map of the function
        :return: function instance
        :raises ValueError: if both or neither of source and obj_file
            are given
        """
        src = inobj.get('source', None)
        objfile = inobj.get('obj_file', None)
        if (src and objfile) or not (src or objfile):
            raise ValueError("Function '%s': either source or obj_file must "
                             "be specified" % name)
        obj = Function(
            name, schema.name, inobj.pop('description', None),
            inobj.pop('owner', None), inobj.pop('privileges', []),
            arguments, inobj.pop('language', None),
            inobj.pop('returns', None), inobj.pop('source', None),
            inobj.pop('obj_file', None),
            inobj.pop('configuration', None),
            inobj.pop('volatility', None),
            inobj.pop('leakproof', False), inobj.pop('strict', False),
            inobj.pop('security_definer', False),
            inobj.pop('cost', 0), inobj.pop('rows', 0),
            inobj.pop('allargs', None))
        obj.fix_privileges()
        return obj

    def to_map(self, db, no_owner, no_privs):
        """Convert a function to a YAML-suitable format

        Attributes holding their default values are left out of the map.

        :param db: the database object collection
        :param no_owner: exclude function owner information
        :param no_privs: exclude privilege information
        :return: dictionary
        """
        dct = super(Function, self).to_map(db, no_owner, no_privs)
        for attr in ('leakproof', 'strict', 'security_definer'):
            if dct[attr] is False:
                dct.pop(attr)
        if self.allargs is None or len(self.allargs) == 0 or \
                self.allargs == self.arguments:
            dct.pop('allargs')
        if self.configuration is None:
            dct.pop('configuration')
        if self.volatility == 'v':
            dct.pop('volatility')
        else:
            dct['volatility'] = VOLATILITY_TYPES[self.volatility]
        if self.obj_file is not None:
            # with an object file, the source is reported as a link symbol
            dct['link_symbol'] = self.source
            del dct['source']
        else:
            del dct['obj_file']
        # 0 means "not specified"; otherwise 1 is the default cost for the
        # 'c' and 'internal' languages and 100 for all the others
        default_cost = 1 if self.language in ['c', 'internal'] else 100
        if self.cost == 0 or self.cost == default_cost:
            del dct['cost']
        # 0 means "not specified" and 1000 is the default row estimate
        if self.rows == 0 or self.rows == 1000:
            del dct['rows']

        return dct

    @commentable
    @grantable
    @ownable
    def create(self, dbversion=None, newsrc=None, basetype=False, returns=None):
        """Return SQL statements to CREATE or REPLACE the function

        :param dbversion: Postgres version (not used by this method)
        :param newsrc: new source for a changed function; when given, the
            statement is a CREATE OR REPLACE
        :param basetype: not used by this method
        :param returns: return type overriding the function's own
        :return: SQL statements
        """
        stmts = []
        if self.obj_file is not None:
            src = "'%s', '%s'" % (
                self.obj_file,
                getattr(self, 'link_symbol', None) or self.name)
        elif self.language == 'internal':
            src = "$$%s$$" % (newsrc or self.source)
        else:
            src = "$_$%s$_$" % (newsrc or self.source)
        volat = leakproof = strict = secdef = cost = rows = config = ''
        if self.volatility != 'v':
            volat = ' ' + VOLATILITY_TYPES[self.volatility].upper()
        if self.leakproof is True:
            leakproof = ' LEAKPROOF'
        if self.strict:
            strict = ' STRICT'
        if self.security_definer:
            secdef = ' SECURITY DEFINER'
        if self.configuration is not None:
            # one SET clause per configuration entry, not just the first
            config = ''.join(' SET %s' % cfg for cfg in self.configuration)
        if self.cost != 0:
            if self.language in ['c', 'internal']:
                if self.cost != 1:
                    cost = " COST %s" % self.cost
            else:
                if self.cost != 100:
                    cost = " COST %s" % self.cost
        if self.rows != 0:
            if self.rows != 1000:
                rows = " ROWS %s" % self.rows

        # We may have to create a shell type if we are its input or output
        # functions
        t = getattr(self, '_defining', None)
        if t is not None:
            if not hasattr(t, '_shell_created'):
                t._shell_created = True
                stmts.append("CREATE TYPE %s" % t.qualname())

        if self.allargs is not None:
            args = self.allargs
        elif self.arguments is not None:
            args = self.arguments
        else:
            args = ''
        stmts.append("CREATE%s FUNCTION %s(%s) RETURNS %s\n LANGUAGE %s"
                     "%s%s%s%s%s%s%s\n AS %s" % (
                         " OR REPLACE" if newsrc else '', self.qualname(),
                         args, returns or self.returns, self.language, volat,
                         leakproof, strict, secdef, cost, rows, config, src))
        return stmts

    def alter(self, infunction, dbversion=None, no_owner=False):
        """Generate SQL to transform an existing function

        :param infunction: a YAML map defining the new function
        :param dbversion: Postgres version, passed on to create
        :param no_owner: exclude owner changes
        :return: list of SQL statements

        Compares the function to an input function and generates SQL
        statements to transform it into the one represented by the
        input.
        """
        stmts = []
        if self.source != infunction.source and infunction.source is not None:
            stmts.append(self.create(
                dbversion=dbversion,
                returns=infunction.returns,
                newsrc=infunction.source,
            ))
        # Only emit an ALTER when the leakproof flag actually changes, and
        # always include the argument list so that overloaded functions
        # are identified unambiguously.
        was_leakproof = self.leakproof is True
        now_leakproof = infunction.leakproof is True
        if was_leakproof and not now_leakproof:
            stmts.append("ALTER FUNCTION %s NOT LEAKPROOF"
                         % self.identifier())
        elif now_leakproof and not was_leakproof:
            stmts.append("ALTER FUNCTION %s LEAKPROOF" % self.identifier())
        stmts.append(super(Function, self).alter(infunction,
                                                 no_owner=no_owner))
        return stmts

    def get_implied_deps(self, db):
        """Return the dependencies implied by the function definition

        Adds the return type to the dependencies found by Proc.

        :param db: the database object collection to look objects up in
        :return: set of dependencies
        """
        # List the previous dependencies
        deps = super(Function, self).get_implied_deps(db)

        # Add back the return type (it may be missing in an input map)
        rettype = self.returns
        if rettype:
            if rettype.upper().startswith("SETOF "):
                rettype = rettype.split(None, 1)[-1]
            rettype = db.find_type(rettype)
            if rettype is not None:
                deps.add(rettype)

        return deps

    def get_deps(self, db):
        """Return the function dependencies, breaking type definition loops

        :param db: the database object collection to look objects up in
        :return: set of dependencies
        """
        deps = super(Function, self).get_deps(db)

        # avoid circular import dependencies
        from .dbtype import DbType

        # drop the dependency on the type if this function is an in/out
        # because there is a loop here.
        for dep in list(deps):
            if isinstance(dep, DbType):
                for attr in ('input', 'output', 'send', 'receive'):
                    fname = getattr(dep, attr, None)
                    if not fname:
                        # the type has no such support function
                        continue
                    if isinstance(fname, tuple):
                        fname = "%s.%s" % fname
                    else:
                        fname = "%s.%s" % (self.schema, fname)
                    if fname == self.qualname():
                        deps.remove(dep)
                        self._defining = dep    # we may need a shell for this
                        break

        return deps

    def drop(self):
        """Generate SQL to drop the current function

        :return: list of SQL statements
        """
        # If the function defines a type it will be dropped by the CASCADE
        # on the type.
        if getattr(self, '_defining', None):
            return []
        else:
            return super(Function, self).drop()


# Single-letter aggkind codes mapped to the names used in YAML
AGGREGATE_KINDS = {'n': 'normal', 'o': 'ordered', 'h': 'hypothetical'}


class Aggregate(Proc):
    """An aggregate function"""

    def __init__(self, name, schema, description, owner, privileges,
                 arguments, sfunc, stype, sspace=0, finalfunc=None,
                 finalfunc_extra=False, initcond=None, sortop=None,
                 msfunc=None, minvfunc=None, mstype=None, msspace=0,
                 mfinalfunc=None, mfinalfunc_extra=False, minitcond=None,
                 kind='normal', combinefunc=None, serialfunc=None,
                 deserialfunc=None, parallel='unsafe',
                 oid=None):
        """Initialize the aggregate

        Function and type values equal to '-' (and a sortop of '0') are
        stored as None.

        :param name-arguments: see Proc.__init__ params
        :param sfunc: state transition function (from aggtransfn)
        :param stype: state datatype (from aggtranstype)
        :param sspace: transition state data size (from aggtransspace)
        :param finalfunc: final function (from aggfinalfn)
        :param finalfunc_extra: extra args? (from aggfinalextra)
        :param initcond: initial value (from agginitval)
        :param sortop: sort operator (from aggsortop)
        :param msfunc: state transition function (from aggmtransfn)
        :param minvfunc: inverse transition function (from aggminvtransfn)
        :param mstype: state datatype (from aggmtranstype)
        :param msspace: transition state data size (from aggmtransspace)
        :param mfinalfunc: final function (from aggmfinalfn)
        :param mfinalfunc_extra: extra args? (from aggmfinalextra)
        :param minitcond: initial value (from aggminitval)
        :param kind: aggregate kind (from aggkind)
        :param combinefunc: combine function (from aggcombinefn)
        :param serialfunc: serialization function (from aggserialfn)
        :param deserialfunc: deserialization function (from aggdeserialfn)
        :param parallel: parallel safety indicator (from proparallel)
        :param oid: object identifier of the pg_proc row
        """
        super(Aggregate, self).__init__(
            name, schema, description, owner, privileges, arguments)
        self.sfunc = split_schema_obj(sfunc, self.schema)
        self.stype = self.unqualify(stype)
        self.sspace = sspace
        if finalfunc is not None and finalfunc != '-':
            self.finalfunc = split_schema_obj(finalfunc, self.schema)
        else:
            self.finalfunc = None
        self.finalfunc_extra = finalfunc_extra
        self.initcond = initcond
        self.sortop = sortop if sortop != '0' else None
        if msfunc is not None and msfunc != '-':
            self.msfunc = split_schema_obj(msfunc, self.schema)
        else:
            self.msfunc = None
        if minvfunc is not None and minvfunc != '-':
            self.minvfunc = split_schema_obj(minvfunc, self.schema)
        else:
            self.minvfunc = None
        if mstype is not None and mstype != '-':
            self.mstype = self.unqualify(mstype)
        else:
            self.mstype = None
        self.msspace = msspace
        if mfinalfunc is not None and mfinalfunc != '-':
            self.mfinalfunc = split_schema_obj(mfinalfunc, self.schema)
        else:
            self.mfinalfunc = None
        self.mfinalfunc_extra = mfinalfunc_extra
        self.minitcond = minitcond
        # kind and parallel accept either the catalog code or the full name
        if kind is None:
            self.kind = 'normal'
        elif len(kind) == 1:
            self.kind = AGGREGATE_KINDS[kind]
        else:
            self.kind = kind
        assert self.kind in AGGREGATE_KINDS.values()
        self.combinefunc = combinefunc if combinefunc != '-' else None
        self.serialfunc = serialfunc if serialfunc != '-' else None
        self.deserialfunc = deserialfunc if deserialfunc != '-' else None
        if parallel is None:
            self.parallel = 'unsafe'
        elif len(parallel) == 1:
            self.parallel = PARALLEL_SAFETY[parallel]
        else:
            self.parallel = parallel
        assert self.parallel in PARALLEL_SAFETY.values()
        self.oid = oid

    @staticmethod
    def query(dbversion):
        """Return the catalog query used to fetch aggregates

        The selected columns and the aggregate predicate depend on the
        server version.

        :param dbversion: Postgres version number (e.g., 90600)
        :return: SQL query string
        """
        query = """
            SELECT nspname AS schema, proname AS name,
                   pg_get_function_identity_arguments(p.oid) AS arguments,
                   rolname AS owner,
                   array_to_string(proacl, ',') AS privileges,
                   aggtransfn::regproc AS sfunc,
                   aggtranstype::regtype AS stype, %s AS sspace,
                   aggfinalfn::regproc AS finalfunc, %s AS finalfunc_extra,
                   agginitval AS initcond, aggsortop::regoper AS sortop, %s,
                   obj_description(p.oid, 'pg_proc') AS description, p.oid
            FROM pg_proc p JOIN pg_roles r ON (r.oid = proowner)
                 JOIN pg_namespace n ON (pronamespace = n.oid)
                 LEFT JOIN pg_aggregate a ON (p.oid = aggfnoid)
            WHERE (nspname != 'pg_catalog' AND nspname != 'information_schema')
              AND %s
              AND p.oid NOT IN (
                  SELECT objid FROM pg_depend WHERE deptype = 'e'
                               AND classid = 'pg_proc'::regclass)
            ORDER BY nspname, proname"""
        V94_COLS = """aggmtransfn::regproc AS msfunc,
                   aggminvtransfn::regproc AS minvfunc,
                   aggmtranstype::regtype AS mstype,
                   aggmtransspace AS msspace,
                   aggmfinalfn::regproc AS mfinalfunc,
                   aggmfinalextra AS mfinalfunc_extra,
                   aggminitval AS minitcond, aggkind AS kind"""
        V96_COLS = V94_COLS + """,aggcombinefn AS combinefunc,
                   aggserialfn AS serialfunc, aggdeserialfn AS deserialfunc,
                   proparallel AS parallel"""
        cols = ('aggtransspace', 'aggfinalextra')
        if dbversion < 90400:
            # substitute constants for the columns missing before 9.4
            cols = ('0', 'false',
                    """'-' AS msfunc, '-' AS minvfunc, NULL AS mstype,
                   0 AS msspace, '-' AS mfinalfunc, false AS mfinalfunc_extra,
                   NULL AS minitcond""", "proisagg")
        elif dbversion < 90600:
            cols += (V94_COLS, "proisagg")
        elif dbversion < 110000:
            cols += (V96_COLS, "proisagg")
        else:
            cols += (V96_COLS, "prokind = 'a'")
        return query % cols

    @staticmethod
    def from_map(name, schema, arguments, inobj):
        """Initialize an aggregate instance from a YAML map

        :param name: aggregate name
        :param schema: schema object owning the aggregate
        :param arguments: arguments
        :param inobj: YAML map of the aggregate
        :return: aggregate instance
        """
        obj = Aggregate(
            name, schema.name, inobj.pop('description', None),
            inobj.pop('owner', None), inobj.pop('privileges', []),
            arguments, inobj.get('sfunc'), inobj.get('stype'),
            inobj.pop('sspace', 0), inobj.pop('finalfunc', None),
            inobj.pop('finalfunc_extra', False), inobj.pop('initcond', None),
            inobj.pop('sortop', None), inobj.pop('msfunc', None),
            inobj.pop('minvfunc', None), inobj.pop('mstype', None),
            inobj.pop('msspace', 0), inobj.pop('mfinalfunc', None),
            inobj.pop('mfinalfunc_extra', False),
            inobj.pop('minitcond', None), inobj.pop('kind', 'normal'),
            inobj.pop('combinefunc', None), inobj.pop('serialfunc', None),
            inobj.pop('deserialfunc', None), inobj.pop('parallel', 'unsafe'))
        obj.fix_privileges()
        return obj

    def to_map(self, db, no_owner, no_privs):
        """Convert an aggregate to a YAML-suitable format

        Attributes holding their default values are left out of the map.

        :param db: the database object collection
        :param no_owner: exclude aggregate owner information
        :param no_privs: exclude privilege information
        :return: dictionary
        """
        dct = super(Aggregate, self).to_map(db, no_owner, no_privs)
        dct['sfunc'] = self.unqualify(join_schema_func(self.sfunc))
        for attr in ('finalfunc', 'msfunc', 'minvfunc', 'mfinalfunc'):
            if getattr(self, attr) is None:
                dct.pop(attr)
            else:
                dct[attr] = self.unqualify(
                    join_schema_func(getattr(self, attr)))
        for attr in ('initcond', 'sortop', 'minitcond', 'mstype',
                     'combinefunc', 'serialfunc', 'deserialfunc'):
            if getattr(self, attr) is None:
                dct.pop(attr)
        for attr in ('sspace', 'msspace'):
            if getattr(self, attr) == 0:
                dct.pop(attr)
        for attr in ('finalfunc_extra', 'mfinalfunc_extra'):
            if getattr(self, attr) is False:
                dct.pop(attr)
        if self.kind == 'normal':
            dct.pop('kind')
        if self.parallel == 'unsafe':
            dct.pop('parallel')
        return dct

    @commentable
    @grantable
    @ownable
    def create(self, dbversion=None):
        """Return SQL statements to CREATE the aggregate

        :param dbversion: Postgres version.  When omitted, none of the
            version-dependent clauses (9.4+ and 9.6+) are generated.
        :return: SQL statements
        """
        # Comparing None with an int raises TypeError on Python 3; treat a
        # missing version as older than any version-dependent feature.
        version = dbversion if dbversion is not None else 0
        opt_clauses = []
        if self.finalfunc is not None:
            opt_clauses.append("FINALFUNC = %s" %
                               join_schema_func(self.finalfunc))
        if self.initcond is not None:
            opt_clauses.append("INITCOND = '%s'" % self.initcond)
        if version >= 90600:
            if self.combinefunc is not None:
                opt_clauses.append("COMBINEFUNC = %s" % self.combinefunc)
            if self.serialfunc is not None:
                opt_clauses.append("SERIALFUNC = %s" % self.serialfunc)
            if self.deserialfunc is not None:
                opt_clauses.append("DESERIALFUNC = %s" % self.deserialfunc)
        if version >= 90400:
            if self.sspace > 0:
                opt_clauses.append("SSPACE = %d" % self.sspace)
            if self.finalfunc_extra:
                opt_clauses.append("FINALFUNC_EXTRA")
            if self.msfunc is not None:
                opt_clauses.append("MSFUNC = %s" %
                                   join_schema_func(self.msfunc))
            if self.minvfunc is not None:
                opt_clauses.append("MINVFUNC = %s" %
                                   join_schema_func(self.minvfunc))
            if self.mstype is not None:
                opt_clauses.append("MSTYPE = %s" % self.mstype)
            if self.msspace > 0:
                opt_clauses.append("MSSPACE = %d" % self.msspace)
            if self.mfinalfunc is not None:
                opt_clauses.append("MFINALFUNC = %s" %
                                   join_schema_func(self.mfinalfunc))
            if self.mfinalfunc_extra:
                opt_clauses.append("MFINALFUNC_EXTRA")
            if self.minitcond is not None:
                opt_clauses.append("MINITCOND = '%s'" % self.minitcond)
            if self.kind == 'hypothetical':
                opt_clauses.append("HYPOTHETICAL")
        if self.sortop is not None:
            clause = self.sortop
            if not clause.startswith('OPERATOR'):
                clause = "OPERATOR(%s)" % clause
            opt_clauses.append("SORTOP = %s" % clause)
        if version >= 90600:
            if self.parallel != 'unsafe':
                opt_clauses.append("PARALLEL = %s" % self.parallel.upper())
        return ["CREATE AGGREGATE %s(%s) (\n SFUNC = %s,"
                "\n STYPE = %s%s%s)" % (
                    self.qualname(), self.arguments,
                    join_schema_func(self.sfunc), self.stype,
                    ',\n ' if opt_clauses else '',
                    ',\n '.join(opt_clauses))]

    def get_implied_deps(self, db):
        """Return the dependencies implied by the aggregate definition

        Adds the state transition, final and moving-aggregate functions
        to the dependencies found by Proc.

        :param db: the database object collection to look objects up in
        :return: set of dependencies
        :raises KeyError: if a referenced function is not in db.functions
        """
        # List the previous dependencies
        deps = super(Aggregate, self).get_implied_deps(db)

        if isinstance(self.sfunc, tuple):
            sch, fnc = self.sfunc
        else:
            sch, fnc = self.schema, self.sfunc
        if 'ORDER BY' in self.arguments:
            args = self.arguments.replace(' ORDER BY', ',')
        else:
            # the transition function takes the state followed by the
            # aggregate arguments
            args = self.stype + ', ' + self.arguments
        deps.add(db.functions[sch, fnc, args])
        for fn in ('finalfunc', 'mfinalfunc'):
            if getattr(self, fn) is not None:
                func = getattr(self, fn)
                if isinstance(func, tuple):
                    sch, fnc = func
                else:
                    sch, fnc = self.schema, func
                # final functions are looked up by the state type only
                deps.add(db.functions[sch, fnc, self.mstype
                                      if fn[0] == 'm' else self.stype])
        for fn in ('msfunc', 'minvfunc'):
            if getattr(self, fn) is not None:
                func = getattr(self, fn)
                if isinstance(func, tuple):
                    sch, fnc = func
                else:
                    sch, fnc = self.schema, func
                args = self.mstype + ", " + self.arguments
                deps.add(db.functions[sch, fnc, args])

        return deps


class ProcDict(DbObjectDict):
    "The collection of regular and aggregate functions in a database"

    cls = Proc

    def _from_catalog(self):
        """Initialize the dictionary of procedures by querying the catalogs"""
        # fetch() is run once per concrete class, selected through self.cls
        for cls in (Function, Aggregate):
            self.cls = cls
            for obj in self.fetch():
                self[obj.key()] = obj
                self.by_oid[obj.oid] = obj

    def from_map(self, schema, infuncs):
        """Initialize the dictionary of functions by converting the input map

        :param schema: schema owning the functions
        :param infuncs: YAML map defining the functions
        :raises KeyError: if a key is not of the form
            "function name(args)" or "aggregate name(args)"
        """
        for key in infuncs:
            (objtype, spc, fnc) = key.partition(' ')
            if spc != ' ' or objtype not in ['function', 'aggregate']:
                raise KeyError("Unrecognized object type: %s" % key)
            paren = fnc.find('(')
            if paren == -1 or fnc[-1:] != ')':
                raise KeyError("Invalid function signature: %s" % fnc)
            arguments = fnc[paren + 1:-1]
            inobj = infuncs[key]
            fnc = fnc[:paren]
            if objtype == 'function':
                func = Function.from_map(fnc, schema, arguments, inobj)
            else:
                func = Aggregate.from_map(fnc, schema, arguments, inobj)
            self[(schema.name, fnc, arguments)] = func

    def find(self, func, args):
        """Return a function given its name and arguments

        :param func: name of the function, optionally schema-qualified
        :param args: list of type names

        Return the function found, else None.
        """
        schema, name = split_schema_obj(func)
        args = ', '.join(args)
        return self.get((schema, name, args))

    def link_refs(self, dbtypes):
        """Connect the functions to other objects

        - Connect defining functions to the type they define

        :param dbtypes: dictionary of types
        """
        # TODO: this link is needed from map, not from sql.
        # is this a pattern? I was assuming link_refs would have disappeared
        # but I'm actually still maintaining them. Verify if they are always
        # only used for from_map, not for from_catalog
        for key in dbtypes:
            t = dbtypes[key]
            for f in t.find_defining_funcs(self):
                f._defining = t