1# -*- coding: utf-8 -*-
2"""
3    pyrseas.dbobject.function
4    ~~~~~~~~~~~~~~~~~~~~~~~~~
5
    This module defines four classes: Proc derived from
    DbSchemaObject, Function and Aggregate derived from Proc, and
    ProcDict derived from DbObjectDict.
9"""
10from pyrseas.lib.pycompat import PY2
11from pyrseas.yamlutil import MultiLineStr
12from . import DbObjectDict, DbSchemaObject
13from . import commentable, ownable, grantable, split_schema_obj
14
# Single-letter volatility codes (as stored in pg_proc.provolatile) mapped to
# the keyword spelled out in YAML maps and, upper-cased, in CREATE FUNCTION
VOLATILITY_TYPES = {'i': 'immutable', 's': 'stable', 'v': 'volatile'}
# Single-letter parallel-safety codes (pg_proc.proparallel) mapped to the
# names used in YAML maps and, upper-cased, in the PARALLEL clause
PARALLEL_SAFETY = {'r': 'restricted', 's': 'safe', 'u': 'unsafe'}
17
18
def split_schema_func(schema, func):
    """Separate a related function's name from its schema, if needed

    A function living in the same schema as the main object is returned
    as a bare name; one from another schema keeps its schema.

    :param schema: schema to which the main object belongs
    :param func: possibly qualified function name
    :returns: the unqualified function name when its schema is `schema`,
              otherwise a (schema, function) tuple
    """
    func_schema, func_name = split_schema_obj(func, schema)
    return func_name if func_schema == schema else (func_schema, func_name)
31
32
def join_schema_func(func):
    """Build a possibly schema-qualified function name

    This is the inverse of split_schema_func: a (schema, function) tuple
    becomes "schema.function"; anything else is returned untouched.

    :param func: a schema, function tuple, or just an unqualified function name
    :returns: a possibly-qualified schema.function string
    """
    if not isinstance(func, tuple):
        return func
    return "%s.%s" % func
43
44
class Proc(DbSchemaObject):
    """Common base for objects stored in pg_proc (functions, aggregates)

    Procedures are keyed by schema, name and identity argument list, so
    overloaded procedures get distinct keys.
    """

    keylist = ['schema', 'name', 'arguments']
    catalog = 'pg_proc'

    @property
    def allprivs(self):
        """Return the privilege letters applicable to procedures

        Only EXECUTE ('X') applies to functions and aggregates.
        """
        return 'X'

    def __init__(self, name, schema, description, owner, privileges,
                 arguments):
        """Initialize the procedure

        :param name: function name (from proname)
        :param schema: schema name (from pronamespace)
        :param description: comment text (from obj_description())
        :param owner: owner name (from rolname via proowner)
        :param privileges: access privileges (from proacl)
        :param arguments: argument list (without default values, from
               pg_function_identity_arguments)
        """
        super(Proc, self).__init__(name, schema, description)
        self._init_own_privs(owner, privileges)
        self.arguments = arguments

    def extern_key(self):
        """Return the key used for this procedure in external maps

        The key has the form ``<objtype> <name>(<arguments>)`` with the
        object type lower-cased, e.g. ``function f1(integer)``.

        :return: string
        """
        objtype = self.objtype.lower()
        return '%s %s(%s)' % (objtype, self.name, self.arguments)

    def identifier(self):
        """Return the schema-qualified name followed by the argument list

        :return: string
        """
        qualified = self.qualname()
        return "%s(%s)" % (qualified, self.arguments)

    def get_implied_deps(self, db):
        """Extend the inherited implied dependencies

        Adds the implementation language (plain functions only) and the
        types of the identity arguments that `db.find_type` resolves.

        :param db: database object providing ``languages`` and
               ``find_type``
        :return: the inherited dependency collection, extended
        """
        deps = super(Proc, self).get_implied_deps(db)

        # Only Function instances carry a language worth depending on
        language = None
        if isinstance(self, Function):
            language = getattr(self, 'language', None)
        if language:
            found_lang = db.languages.get(language)
            if found_lang:
                deps.add(found_lang)

        # The last word of each declaration is taken as its type, which
        # skips leading argument names and modes.
        # NOTE(review): multi-word type names (e.g. 'double precision') are
        # therefore looked up by their last word only.
        declarations = self.arguments.split(', ') if self.arguments else []
        for decl in declarations:
            argtype = db.find_type(decl.split()[-1])
            if argtype is not None:
                deps.add(argtype)

        return deps
103
104
class Function(Proc):
    """A procedural language function"""

    def __init__(self, name, schema, description, owner, privileges,
                 arguments, language, returns, source, obj_file=None,
                 configuration=None, volatility=None, leakproof=False,
                 strict=False, security_definer=False, cost=0, rows=0,
                 allargs=None, oid=None):
        """Initialize the function

        :param name-arguments: see Proc.__init__ params
        :param language: implementation language (from prolang)
        :param returns: return type (from pg_get_function_result/prorettype)
        :param source: source code, link symbol, etc. (from prosrc); may
               be None, e.g. for a function defined through obj_file
        :param obj_file: language-specific info (from probin)
        :param configuration: configuration variables (from proconfig),
               a list of "name=value" settings
        :param volatility: volatility type (from provolatile); only its
               first letter, lower-cased, is kept; defaults to 'v'
        :param leakproof: declared LEAKPROOF, i.e. no side effects and no
               leaking of argument values (from proleakproof)
        :param strict: null handling (from proisstrict)
        :param security_definer: security definer (from prosecdef)
        :param cost: execution cost estimate (from procost)
        :param rows: result row estimate (from prorows)
        :param allargs: argument list with defaults (from
               pg_get_function_arguments)
        :param oid: object identifier (from pg_proc.oid)
        """
        super(Function, self).__init__(
            name, schema, description, owner, privileges, arguments)
        self.language = language
        self.returns = returns
        if source and '\n' in source:
            # Strip trailing spaces/tabs from each line of a multi-line
            # source
            newsrc = []
            for line in source.split('\n'):
                if line and line[-1] in (' ', '\t'):
                    line = line.rstrip()
                newsrc.append(line)
            source = '\n'.join(newsrc)
        if source is None:
            # Keep a missing source as None (as the PY2 path always did)
            # instead of wrapping it in MultiLineStr, so alter() and
            # to_map() can tell "no source" apart from real source text
            self.source = None
        elif PY2:
            self.source = source.encode('utf_8').decode('utf_8')
        else:
            self.source = MultiLineStr(source)
        self.obj_file = obj_file
        self.configuration = configuration
        self.allargs = allargs
        if volatility is not None:
            self.volatility = volatility[:1].lower()
        else:
            self.volatility = 'v'
        assert self.volatility in VOLATILITY_TYPES.keys()
        self.leakproof = leakproof
        self.strict = strict
        self.security_definer = security_definer
        self.cost = cost
        self.rows = rows
        self.oid = oid

    @staticmethod
    def query(dbversion=None):
        """Return the catalog query that lists user-defined functions

        :param dbversion: server version number; selects the
               ``proisagg`` (before 11) or ``prokind`` (11+) filter
        :return: SQL query text
        """
        query = """
            SELECT nspname AS schema, proname AS name,
                   pg_get_function_identity_arguments(p.oid) AS arguments,
                   pg_get_function_arguments(p.oid) AS allargs,
                   pg_get_function_result(p.oid) AS returns, rolname AS owner,
                   array_to_string(proacl, ',') AS privileges,
                   l.lanname AS language, provolatile AS volatility,
                   proisstrict AS strict, prosrc AS source,
                   probin::text AS obj_file, proconfig AS configuration,
                   prosecdef AS security_definer, procost AS cost,
                   proleakproof AS leakproof, prorows::integer AS rows,
                   obj_description(p.oid, 'pg_proc') AS description, p.oid
            FROM pg_proc p JOIN pg_roles r ON (r.oid = proowner)
                 JOIN pg_namespace n ON (pronamespace = n.oid)
                 JOIN pg_language l ON (prolang = l.oid)
            WHERE (nspname != 'pg_catalog' AND nspname != 'information_schema')
              AND %s
              AND p.oid NOT IN (
                  SELECT objid FROM pg_depend WHERE deptype = 'e'
                               AND classid = 'pg_proc'::regclass)
            ORDER BY nspname, proname"""
        if dbversion < 110000:
            query = query % "NOT proisagg"
        else:
            query = query % "prokind = 'f'"
        return query

    @staticmethod
    def from_map(name, schema, arguments, inobj):
        """Initialize a function instance from a YAML map

        :param name: function name
        :param schema: schema object owning the function
        :param arguments: arguments
        :param inobj: YAML map of the function
        :return: function instance
        :raises ValueError: unless exactly one of source/obj_file is given
        """
        src = inobj.get('source', None)
        objfile = inobj.get('obj_file', None)
        if (src and objfile) or not (src or objfile):
            raise ValueError("Function '%s': either source or obj_file must "
                             "be specified" % name)
        obj = Function(
            name, schema.name, inobj.pop('description', None),
            inobj.pop('owner', None), inobj.pop('privileges', []),
            arguments, inobj.pop('language', None),
            inobj.pop('returns', None), inobj.pop('source', None),
            inobj.pop('obj_file', None),
            inobj.pop('configuration', None),
            inobj.pop('volatility', None),
            inobj.pop('leakproof', False), inobj.pop('strict', False),
            inobj.pop('security_definer', False),
            inobj.pop('cost', 0), inobj.pop('rows', 0),
            inobj.pop('allargs', None))
        obj.fix_privileges()
        return obj

    def _default_cost(self):
        """Return the COST PostgreSQL assumes when none is specified

        C and internal functions default to 1 unit, all others to 100.
        """
        return 1 if self.language in ['c', 'internal'] else 100

    def to_map(self, db, no_owner, no_privs):
        """Convert a function to a YAML-suitable format

        Attributes that are unset or equal to their defaults are omitted.

        :param no_owner: exclude function owner information
        :param no_privs: exclude privilege information
        :return: dictionary
        """
        dct = super(Function, self).to_map(db, no_owner, no_privs)
        for attr in ('leakproof', 'strict', 'security_definer'):
            if dct[attr] is False:
                dct.pop(attr)
        if not self.allargs or self.allargs == self.arguments:
            dct.pop('allargs')
        if self.configuration is None:
            dct.pop('configuration')
        if self.volatility == 'v':
            dct.pop('volatility')
        else:
            dct['volatility'] = VOLATILITY_TYPES[self.volatility]
        if self.obj_file is not None:
            dct['link_symbol'] = self.source
            del dct['source']
        else:
            del dct['obj_file']
        # Zero means "not set"; the language default is implicit too
        if self.cost in (0, self._default_cost()):
            del dct['cost']
        if self.rows in (0, 1000):
            del dct['rows']

        return dct

    @commentable
    @grantable
    @ownable
    def create(self, dbversion=None, newsrc=None, basetype=False, returns=None):
        """Return SQL statements to CREATE or REPLACE the function

        :param dbversion: PostgreSQL version (unused here)
        :param newsrc: new source for a changed function; when given the
               statement is CREATE OR REPLACE
        :param basetype: unused, kept for interface compatibility
        :param returns: return type overriding self.returns
        :return: SQL statements
        """
        stmts = []
        if self.obj_file is not None:
            link = getattr(self, 'link_symbol', None) or self.name
            src = "'%s', '%s'" % (self.obj_file, link)
        else:
            body = newsrc or self.source
            if self.language == 'internal':
                src = "$$%s$$" % body
            else:
                src = "$_$%s$_$" % body
        volat = leakproof = strict = secdef = cost = rows = config = ''
        if self.volatility != 'v':
            volat = ' ' + VOLATILITY_TYPES[self.volatility].upper()
        if self.leakproof is True:
            leakproof = ' LEAKPROOF'
        if self.strict:
            strict = ' STRICT'
        if self.security_definer:
            secdef = ' SECURITY DEFINER'
        if self.configuration is not None:
            # Every proconfig entry needs its own SET clause; a bare string
            # (e.g. hand-written YAML) is treated as a single setting
            settings = self.configuration
            if not isinstance(settings, (list, tuple)):
                settings = [settings]
            config = ''.join(' SET %s' % setting for setting in settings)
        if self.cost not in (0, self._default_cost()):
            cost = " COST %s" % self.cost
        if self.rows not in (0, 1000):
            rows = " ROWS %s" % self.rows

        # We may have to create a shell type if we are its input or output
        # functions
        t = getattr(self, '_defining', None)
        if t is not None:
            if not hasattr(t, '_shell_created'):
                t._shell_created = True
                stmts.append("CREATE TYPE %s" % t.qualname())

        if self.allargs is not None:
            args = self.allargs
        elif self.arguments is not None:
            args = self.arguments
        else:
            args = ''
        stmts.append("CREATE%s FUNCTION %s(%s) RETURNS %s\n    LANGUAGE %s"
                     "%s%s%s%s%s%s%s\n    AS %s" % (
                         " OR REPLACE" if newsrc else '', self.qualname(),
                         args, returns or self.returns, self.language,
                         volat, leakproof, strict, secdef, cost, rows,
                         config, src))
        return stmts

    def alter(self, infunction, dbversion=None, no_owner=False):
        """Generate SQL to transform an existing function

        :param infunction: a YAML map defining the new function
        :param dbversion: PostgreSQL version, passed on to create()
        :param no_owner: passed on to the inherited alter()
        :return: list of SQL statements

        Compares the function to an input function and generates SQL
        statements to transform it into the one represented by the
        input.
        """
        stmts = []
        if self.source != infunction.source and infunction.source is not None:
            stmts.append(self.create(
                dbversion=dbversion,
                returns=infunction.returns,
                newsrc=infunction.source,
            ))
        # Only emit a LEAKPROOF change when the attribute actually differs,
        # and always name the full signature so overloads are unambiguous
        was_leakproof = self.leakproof is True
        want_leakproof = infunction.leakproof is True
        if want_leakproof and not was_leakproof:
            stmts.append("ALTER FUNCTION %s LEAKPROOF" % self.identifier())
        elif was_leakproof and not want_leakproof:
            stmts.append("ALTER FUNCTION %s NOT LEAKPROOF"
                         % self.identifier())
        stmts.append(super(Function, self).alter(infunction,
                                                 no_owner=no_owner))
        return stmts

    def get_implied_deps(self, db):
        """Extend the inherited implied dependencies with the return type

        A leading "SETOF " is stripped before looking the type up; a
        missing return type adds nothing.
        """
        deps = super(Function, self).get_implied_deps(db)

        rettype = self.returns
        if rettype:
            if rettype.upper().startswith("SETOF "):
                rettype = rettype.split(None, 1)[-1]
            rettype = db.find_type(rettype)
            if rettype is not None:
                deps.add(rettype)

        return deps

    def get_deps(self, db):
        """Return dependencies, minus a type this function helps define

        If this function is a type's input/output/send/receive function,
        the type is dropped from the dependencies and remembered in
        ``_defining`` so create() can emit a shell type first.
        """
        deps = super(Function, self).get_deps(db)

        # avoid circular import dependencies
        from .dbtype import DbType

        # drop the dependency on the type if this function is an in/out
        # because there is a loop here.
        qualname = self.qualname()
        for dep in list(deps):
            if not isinstance(dep, DbType):
                continue
            for attr in ('input', 'output', 'send', 'receive'):
                fname = getattr(dep, attr, None)
                if fname is None:
                    continue
                if isinstance(fname, tuple):
                    fname = "%s.%s" % fname
                else:
                    fname = "%s.%s" % (self.schema, fname)
                if fname == qualname:
                    deps.remove(dep)
                    self._defining = dep    # we may need a shell for this
                    break

        return deps

    def drop(self):
        """Generate SQL to drop the current function

        :return: list of SQL statements
        """
        # If the function defines a type it will be dropped by the CASCADE
        # on the type.
        if getattr(self, '_defining', None):
            return []
        else:
            return super(Function, self).drop()
402
403
# Single-letter aggregate kind codes (pg_aggregate.aggkind) mapped to the
# names stored on Aggregate.kind and written to YAML maps
AGGREGATE_KINDS = {'n': 'normal', 'o': 'ordered', 'h': 'hypothetical'}
405
406
class Aggregate(Proc):
    """An aggregate function"""

    def __init__(self, name, schema, description, owner, privileges,
                 arguments, sfunc, stype, sspace=0, finalfunc=None,
                 finalfunc_extra=False, initcond=None, sortop=None,
                 msfunc=None, minvfunc=None, mstype=None, msspace=0,
                 mfinalfunc=None, mfinalfunc_extra=False, minitcond=None,
                 kind='normal', combinefunc=None, serialfunc=None,
                 deserialfunc=None, parallel='unsafe',
                 oid=None):
        """Initialize the aggregate

        Function-valued arguments given as '-' (the regproc rendering of
        "no function") are stored as None.

        :param name-arguments: see Proc.__init__ params
        :param sfunc: state transition function (from aggtransfn)
        :param stype: state datatype (from aggtranstype)
        :param sspace: transition state data size (from aggtransspace)
        :param finalfunc: final function (from aggfinalfn)
        :param finalfunc_extra: extra args? (from aggfinalextra)
        :param initcond: initial value (from agginitval)
        :param sortop: sort operator (from aggsortop)
        :param msfunc: moving-aggregate state transition function
               (from aggmtransfn)
        :param minvfunc: inverse transition function (from aggminvtransfn)
        :param mstype: moving-aggregate state datatype (from aggmtranstype)
        :param msspace: moving-aggregate state data size
               (from aggmtransspace)
        :param mfinalfunc: moving-aggregate final function (from aggmfinalfn)
        :param mfinalfunc_extra: extra args? (from aggmfinalextra)
        :param minitcond: moving-aggregate initial value (from aggminitval)
        :param kind: aggregate kind (from aggkind), a single-letter code
               or one of the AGGREGATE_KINDS names
        :param combinefunc: combine function (from aggcombinefn)
        :param serialfunc: serialization function (from aggserialfn)
        :param deserialfunc: deserialization function (from aggdeserialfn)
        :param parallel: parallel safety indicator (from proparallel)
        :param oid: object identifier (from pg_proc.oid)
        """
        super(Aggregate, self).__init__(
            name, schema, description, owner, privileges, arguments)
        self.sfunc = split_schema_obj(sfunc, self.schema)
        self.stype = self.unqualify(stype)
        self.sspace = sspace
        if finalfunc is not None and finalfunc != '-':
            self.finalfunc = split_schema_obj(finalfunc, self.schema)
        else:
            self.finalfunc = None
        self.finalfunc_extra = finalfunc_extra
        self.initcond = initcond
        self.sortop = sortop if sortop != '0' else None
        if msfunc is not None and msfunc != '-':
            self.msfunc = split_schema_obj(msfunc, self.schema)
        else:
            self.msfunc = None
        if minvfunc is not None and minvfunc != '-':
            self.minvfunc = split_schema_obj(minvfunc, self.schema)
        else:
            self.minvfunc = None
        if mstype is not None and mstype != '-':
            self.mstype = self.unqualify(mstype)
        else:
            self.mstype = None
        self.msspace = msspace
        if mfinalfunc is not None and mfinalfunc != '-':
            self.mfinalfunc = split_schema_obj(mfinalfunc, self.schema)
        else:
            self.mfinalfunc = None
        self.mfinalfunc_extra = mfinalfunc_extra
        self.minitcond = minitcond
        if kind is None:
            self.kind = 'normal'
        elif len(kind) == 1:
            self.kind = AGGREGATE_KINDS[kind]
        else:
            self.kind = kind
        assert self.kind in AGGREGATE_KINDS.values()
        self.combinefunc = combinefunc if combinefunc != '-' else None
        self.serialfunc = serialfunc if serialfunc != '-' else None
        self.deserialfunc = deserialfunc if deserialfunc != '-' else None
        if parallel is None:
            self.parallel = 'unsafe'
        elif len(parallel) == 1:
            self.parallel = PARALLEL_SAFETY[parallel]
        else:
            self.parallel = parallel
        assert self.parallel in PARALLEL_SAFETY.values()
        self.oid = oid

    @staticmethod
    def query(dbversion):
        """Return the catalog query that lists user-defined aggregates

        :param dbversion: server version number; selects which
               pg_aggregate columns exist (9.4, 9.6) and the
               ``proisagg``/``prokind`` filter (11)
        :return: SQL query text
        """
        query = """
            SELECT nspname AS schema, proname AS name,
                   pg_get_function_identity_arguments(p.oid) AS arguments,
                   rolname AS owner,
                   array_to_string(proacl, ',') AS privileges,
                   aggtransfn::regproc AS sfunc,
                   aggtranstype::regtype AS stype, %s AS sspace,
                   aggfinalfn::regproc AS finalfunc, %s AS finalfunc_extra,
                   agginitval AS initcond, aggsortop::regoper AS sortop, %s,
                   obj_description(p.oid, 'pg_proc') AS description, p.oid
            FROM pg_proc p JOIN pg_roles r ON (r.oid = proowner)
                 JOIN pg_namespace n ON (pronamespace = n.oid)
                 LEFT JOIN pg_aggregate a ON (p.oid = aggfnoid)
            WHERE (nspname != 'pg_catalog' AND nspname != 'information_schema')
              AND %s
              AND p.oid NOT IN (
                  SELECT objid FROM pg_depend WHERE deptype = 'e'
                               AND classid = 'pg_proc'::regclass)
            ORDER BY nspname, proname"""
        V94_COLS = """aggmtransfn::regproc AS msfunc,
                   aggminvtransfn::regproc AS minvfunc,
                   aggmtranstype::regtype AS mstype,
                   aggmtransspace AS msspace,
                   aggmfinalfn::regproc AS mfinalfunc,
                   aggmfinalextra AS mfinalfunc_extra,
                   aggminitval AS minitcond, aggkind AS kind"""
        V96_COLS = V94_COLS + """,aggcombinefn AS combinefunc,
                   aggserialfn AS serialfunc, aggdeserialfn AS deserialfunc,
                   proparallel AS parallel"""
        cols = ('aggtransspace', 'aggfinalextra')
        if dbversion < 90400:
            cols = ('0', 'false',
                    """'-' AS msfunc, '-' AS minvfunc, NULL AS mstype,
                    0 AS msspace, '-' AS mfinalfunc, false AS mfinalfunc_extra,
                    NULL AS minitcond""", "proisagg")
        elif dbversion < 90600:
            cols += (V94_COLS, "proisagg")
        elif dbversion < 110000:
            cols += (V96_COLS, "proisagg")
        else:
            cols += (V96_COLS, "prokind = 'a'")
        return query % cols

    @staticmethod
    def from_map(name, schema, arguments, inobj):
        """Initialize an aggregate instance from a YAML map

        :param name: aggregate name
        :param schema: schema object owning the aggregate
        :param arguments: arguments
        :param inobj: YAML map of the aggregate
        :return: aggregate instance
        """
        obj = Aggregate(
            name, schema.name, inobj.pop('description', None),
            inobj.pop('owner', None), inobj.pop('privileges', []),
            arguments, inobj.get('sfunc'), inobj.get('stype'),
            inobj.pop('sspace', 0), inobj.pop('finalfunc', None),
            inobj.pop('finalfunc_extra', False), inobj.pop('initcond', None),
            inobj.pop('sortop', None), inobj.pop('msfunc', None),
            inobj.pop('minvfunc', None), inobj.pop('mstype', None),
            inobj.pop('msspace', 0), inobj.pop('mfinalfunc', None),
            inobj.pop('mfinalfunc_extra', False),
            inobj.pop('minitcond', None), inobj.pop('kind', 'normal'),
            inobj.pop('combinefunc', None), inobj.pop('serialfunc', None),
            # same key that to_map() writes
            inobj.pop('deserialfunc', None), inobj.pop('parallel', 'unsafe'))
        obj.fix_privileges()
        return obj

    def to_map(self, db, no_owner, no_privs):
        """Convert an aggregate to a YAML-suitable format

        Unset attributes and those at their defaults are omitted.

        :param no_owner: exclude aggregate owner information
        :param no_privs: exclude privilege information
        :return: dictionary
        """
        dct = super(Aggregate, self).to_map(db, no_owner, no_privs)
        dct['sfunc'] = self.unqualify(join_schema_func(self.sfunc))
        for attr in ('finalfunc', 'msfunc', 'minvfunc', 'mfinalfunc'):
            if getattr(self, attr) is None:
                dct.pop(attr)
            else:
                dct[attr] = self.unqualify(
                    join_schema_func(getattr(self, attr)))
        for attr in ('initcond', 'sortop', 'minitcond', 'mstype',
                     'combinefunc', 'serialfunc', 'deserialfunc'):
            if getattr(self, attr) is None:
                dct.pop(attr)
        for attr in ('sspace', 'msspace'):
            if getattr(self, attr) == 0:
                dct.pop(attr)
        for attr in ('finalfunc_extra', 'mfinalfunc_extra'):
            if getattr(self, attr) is False:
                dct.pop(attr)
        if self.kind == 'normal':
            dct.pop('kind')
        if self.parallel == 'unsafe':
            dct.pop('parallel')
        return dct

    @staticmethod
    def _literal(value):
        """Return `value` as a single-quoted SQL literal

        Embedded single quotes are doubled so values such as "it's" do
        not break the generated statement.
        """
        return "'%s'" % ("%s" % value).replace("'", "''")

    @commentable
    @grantable
    @ownable
    def create(self, dbversion=None):
        """Return SQL statements to CREATE the aggregate

        :param dbversion: Postgres version; it is compared against 90400
               and 90600 to decide which optional clauses are emitted
        :return: SQL statements
        """
        opt_clauses = []
        if self.finalfunc is not None:
            opt_clauses.append("FINALFUNC = %s" %
                               join_schema_func(self.finalfunc))
        if self.initcond is not None:
            opt_clauses.append("INITCOND = %s" %
                               self._literal(self.initcond))
        if dbversion >= 90600:
            if self.combinefunc is not None:
                opt_clauses.append("COMBINEFUNC = %s" % self.combinefunc)
            if self.serialfunc is not None:
                opt_clauses.append("SERIALFUNC = %s" % self.serialfunc)
            if self.deserialfunc is not None:
                opt_clauses.append("DESERIALFUNC = %s" % self.deserialfunc)
        if dbversion >= 90400:
            if self.sspace > 0:
                opt_clauses.append("SSPACE = %d" % self.sspace)
            if self.finalfunc_extra:
                opt_clauses.append("FINALFUNC_EXTRA")
            if self.msfunc is not None:
                opt_clauses.append("MSFUNC = %s" %
                                   join_schema_func(self.msfunc))
            if self.minvfunc is not None:
                opt_clauses.append("MINVFUNC = %s" %
                                   join_schema_func(self.minvfunc))
            if self.mstype is not None:
                opt_clauses.append("MSTYPE = %s" % self.mstype)
            if self.msspace > 0:
                opt_clauses.append("MSSPACE = %d" % self.msspace)
            if self.mfinalfunc is not None:
                opt_clauses.append("MFINALFUNC = %s" %
                                   join_schema_func(self.mfinalfunc))
            if self.mfinalfunc_extra:
                opt_clauses.append("MFINALFUNC_EXTRA")
            if self.minitcond is not None:
                opt_clauses.append("MINITCOND = %s" %
                                   self._literal(self.minitcond))
        if self.kind == 'hypothetical':
            opt_clauses.append("HYPOTHETICAL")
        if self.sortop is not None:
            clause = self.sortop
            if not clause.startswith('OPERATOR'):
                clause = "OPERATOR(%s)" % clause
            opt_clauses.append("SORTOP = %s" % clause)
        if dbversion >= 90600:
            if self.parallel != 'unsafe':
                opt_clauses.append("PARALLEL = %s" % self.parallel.upper())
        return ["CREATE AGGREGATE %s(%s) (\n    SFUNC = %s,"
                "\n    STYPE = %s%s%s)" % (
                    self.qualname(), self.arguments,
                    join_schema_func(self.sfunc), self.stype,
                    ',\n    ' if opt_clauses else '',
                    ',\n    '.join(opt_clauses))]

    def _transition_args(self, statetype):
        """Return the argument list of a (moving) transition function

        A transition function receives the state value followed by the
        aggregated arguments.  For ordered-set aggregates only the
        arguments after ORDER BY are aggregated; the direct arguments are
        not passed to it.  A zero-argument aggregate's transition function
        takes the state alone.

        :param statetype: state data type (stype or mstype)
        :return: comma-separated argument type list
        """
        if 'ORDER BY' in self.arguments:
            aggregated = self.arguments.split('ORDER BY', 1)[1].strip()
        else:
            aggregated = self.arguments
        if aggregated in ('', '*'):
            return statetype
        return statetype + ', ' + aggregated

    @staticmethod
    def _find_func(db, func, schema, args):
        """Look up a related function among the functions being managed

        :param db: database object providing ``functions``
        :param func: (schema, name) tuple or unqualified name
        :param schema: schema assumed for an unqualified name
        :param args: argument type list used in the function key
        :return: the function, or None if ``db.functions`` has no such
                 key (e.g. a pg_catalog function, which the catalog
                 queries exclude)
        """
        if isinstance(func, tuple):
            sch, fnc = func
        else:
            sch, fnc = schema, func
        try:
            return db.functions[sch, fnc, args]
        except KeyError:
            return None

    def get_implied_deps(self, db):
        """Extend the inherited implied dependencies with support functions

        Adds the transition, final, moving-transition, inverse and moving
        final functions when they are among the functions being managed.
        """
        deps = super(Aggregate, self).get_implied_deps(db)

        related = [(self.sfunc, self._transition_args(self.stype))]
        # NOTE(review): final functions are keyed by the state type alone,
        # so those taking extra arguments (FINALFUNC_EXTRA, ordered-set
        # direct arguments) are not found and are skipped.
        if self.finalfunc is not None:
            related.append((self.finalfunc, self.stype))
        if self.mfinalfunc is not None:
            related.append((self.mfinalfunc, self.mstype))
        for func in (self.msfunc, self.minvfunc):
            if func is not None:
                related.append((func, self._transition_args(self.mstype)))

        for func, args in related:
            found = self._find_func(db, func, self.schema, args)
            if found is not None:
                deps.add(found)

        return deps
687
688
class ProcDict(DbObjectDict):
    "The collection of regular and aggregate functions in a database"

    cls = Proc

    def _from_catalog(self):
        """Fill the dictionary with functions, then aggregates, from catalogs

        ``self.cls`` is set to each concrete class before ``fetch()`` is
        called for it, and is left set to Aggregate afterwards.  Each
        object is indexed both by its key and by its oid.
        """
        for proc_class in (Function, Aggregate):
            self.cls = proc_class
            for proc in self.fetch():
                self[proc.key()] = proc
                self.by_oid[proc.oid] = proc

    def from_map(self, schema, infuncs):
        """Initialize the dictionary of functions by converting the input map

        Keys look like ``function name(args)`` or ``aggregate name(args)``.

        :param schema: schema owning the functions
        :param infuncs: YAML map defining the functions
        :raises KeyError: on an unknown object type or malformed signature
        """
        for key in infuncs:
            objtype, sep, signature = key.partition(' ')
            if sep != ' ' or objtype not in ['function', 'aggregate']:
                raise KeyError("Unrecognized object type: %s" % key)
            paren = signature.find('(')
            if paren < 0 or not signature.endswith(')'):
                raise KeyError("Invalid function signature: %s" % signature)
            name = signature[:paren]
            arguments = signature[paren + 1:-1]
            builder = Function if objtype == 'function' else Aggregate
            proc = builder.from_map(name, schema, arguments, infuncs[key])
            self[(schema.name, name, arguments)] = proc

    def find(self, func, args):
        """Return a function given its name and arguments

        :param func: name of the function, optionally schema-qualified
        :param args: list of type names

        Return the function found, else None.
        """
        schema, name = split_schema_obj(func)
        return self.get((schema, name, ', '.join(args)))

    def link_refs(self, dbtypes):
        """Connect the functions to other objects

        - Connect defining functions to the type they define

        :param dbtypes: dictionary of types
        """
        # TODO: this link is needed when building from a map, not from the
        # catalogs.  link_refs were expected to disappear but are still
        # maintained; verify whether they are only ever needed for
        # from_map, never for _from_catalog.
        for key in dbtypes:
            dbtype = dbtypes[key]
            for func in dbtype.find_defining_funcs(self):
                func._defining = dbtype
751