1# -*- coding: utf-8 -*-
2"""
3    pyrseas.database
4    ~~~~~~~~~~~~~~~~
5
6    A `Database` is initialized with a DbConnection object.  It
7    consists of one or two `Dicts` objects, each holding various
8    dictionary objects.  The `db` Dicts object defines the database
9    schemas, including their tables and other objects, by querying the
10    system catalogs.  The `ndb` Dicts object defines the schemas based
11    on the `input_map` supplied to the `from_map` method.
12"""
13import os
14import sys
15from operator import itemgetter
16from collections import defaultdict, deque
17import yaml
18
19from pgdbconn.dbconn import DbConnection
20
21from pyrseas.yamlutil import yamldump
22from pyrseas.dbobject import fetch_reserved_words, DbObjectDict, DbSchemaObject
23from pyrseas.dbobject.language import LanguageDict
24from pyrseas.dbobject.cast import CastDict
25from pyrseas.dbobject.schema import SchemaDict
26from pyrseas.dbobject.dbtype import TypeDict
27from pyrseas.dbobject.table import ClassDict
28from pyrseas.dbobject.column import ColumnDict
29from pyrseas.dbobject.constraint import ConstraintDict
30from pyrseas.dbobject.index import IndexDict
31from pyrseas.dbobject.function import ProcDict
32from pyrseas.dbobject.operator import OperatorDict
33from pyrseas.dbobject.operclass import OperatorClassDict
34from pyrseas.dbobject.operfamily import OperatorFamilyDict
35from pyrseas.dbobject.rule import RuleDict
36from pyrseas.dbobject.trigger import TriggerDict
37from pyrseas.dbobject.conversion import ConversionDict
38from pyrseas.dbobject.textsearch import TSConfigurationDict, TSDictionaryDict
39from pyrseas.dbobject.textsearch import TSParserDict, TSTemplateDict
40from pyrseas.dbobject.foreign import ForeignDataWrapperDict
41from pyrseas.dbobject.foreign import ForeignServerDict, UserMappingDict
42from pyrseas.dbobject.foreign import ForeignTableDict
43from pyrseas.dbobject.extension import ExtensionDict
44from pyrseas.dbobject.collation import CollationDict
45from pyrseas.dbobject.eventtrig import EventTriggerDict
46
47
def flatten(lst):
    """Flatten a list possibly containing lists to a single list

    :param lst: an iterable, possibly containing nested lists
    :return: generator yielding the leaf elements in order

    Note: only `list` instances are descended into; tuples, strings and
    other iterables are yielded as-is.  (The previous extra check for
    `str` was dead code -- a str is never a list instance.)
    """
    for elem in lst:
        if isinstance(elem, list):
            yield from flatten(elem)
        else:
            yield elem
56
57
class CatDbConnection(DbConnection):
    """A database connection, specialized for querying catalogs"""

    def connect(self):
        """Connect to the database and set up the catalog search path

        After connecting, the search_path is set to ``pg_catalog`` plus
        any non-'public' schemas already in the session's current
        schemas, so that catalog queries need not schema-qualify names.
        """
        super(CatDbConnection, self).connect()
        rows = self.fetchall("SELECT current_schemas(false)")
        extras = [name for name in rows[0][0] if name != 'public']
        path = ", ".join(["pg_catalog"] + extras)
        self.execute("set search_path to %s" % path)
        self.commit()
        # cache the server version while the connection is open
        self._version = self.conn.server_version

    @property
    def version(self):
        "The server's version number"
        if self.conn is None:
            # connecting has the side effect of filling in _version
            self.connect()
        return self._version
79
80
class Database(object):
    """A database definition, from its catalogs and/or a YAML spec."""

    class Dicts(object):
        """A holder for dictionaries (maps) describing a database"""

        def __init__(self, dbconn=None, single_db=False):
            """Initialize the various DbObjectDict-derived dictionaries

            :param dbconn: a DbConnection object
            :param single_db: True when mapping a single database; passed
                              to `all_dicts` so that empty dicts are
                              skipped when building the catalog map
            """
            self.schemas = SchemaDict(dbconn)
            self.extensions = ExtensionDict(dbconn)
            self.languages = LanguageDict(dbconn)
            self.casts = CastDict(dbconn)
            self.types = TypeDict(dbconn)
            self.tables = ClassDict(dbconn)
            self.columns = ColumnDict(dbconn)
            self.constraints = ConstraintDict(dbconn)
            self.indexes = IndexDict(dbconn)
            self.functions = ProcDict(dbconn)
            self.operators = OperatorDict(dbconn)
            self.operclasses = OperatorClassDict(dbconn)
            self.operfams = OperatorFamilyDict(dbconn)
            self.rules = RuleDict(dbconn)
            self.triggers = TriggerDict(dbconn)
            self.conversions = ConversionDict(dbconn)
            self.tstempls = TSTemplateDict(dbconn)
            self.tsdicts = TSDictionaryDict(dbconn)
            self.tsparsers = TSParserDict(dbconn)
            self.tsconfigs = TSConfigurationDict(dbconn)
            self.fdwrappers = ForeignDataWrapperDict(dbconn)
            self.servers = ForeignServerDict(dbconn)
            self.usermaps = UserMappingDict(dbconn)
            self.ftables = ForeignTableDict(dbconn)
            self.collations = CollationDict(dbconn)
            self.eventtrigs = EventTriggerDict(dbconn)

            # Populate a map from system catalog name (e.g. 'pg_class')
            # to the respective dict
            self._catalog_map = {}
            for _, d in self.all_dicts(single_db):
                if d.cls.catalog is not None:
                    self._catalog_map[d.cls.catalog] = d

            # Lazy map from an object's extern_key() to the object itself;
            # warmed up on first use by _get_by_extkey()
            self._extkey_map = {}

        def _get_by_extkey(self, extkey):
            """Return any database item from its extkey

            :param extkey: the object's extern_key()
            :return: the matching database object
            :raises KeyError: if no object has that extkey

            Note: probably doesn't work for all the objects, e.g.
            constraints may clash because two in different tables could
            share an extkey.  However this shouldn't matter as such
            objects are generated as part of the containing one and
            should be returned by the `get_implied_deps()` implementation
            of specific classes, which looks for the object by key in the
            right dict instead (e.g. see `Domain.get_implied_deps()`).
            """
            try:
                return self._extkey_map[extkey]
            except KeyError:
                # TODO: Likely it's the first time we call this function so
                # let's warm up the cache. But we should really define the life
                # cycle of this object as trying and catching KeyError on it is
                # *very* expensive!
                for _, d in self.all_dicts():
                    for obj in list(d.values()):
                        self._extkey_map[obj.extern_key()] = obj

                return self._extkey_map[extkey]

        def all_dicts(self, non_empty=False):
            """Iterate over the DbObjectDict-derived dictionaries returning
            an ordered list of tuples (dict name, DbObjectDict object).

            :param non_empty: do not include empty dicts

            :return: list of tuples
            """
            rv = []
            for attr in self.__dict__:
                d = getattr(self, attr)
                if non_empty and len(d) == 0:
                    continue
                if isinstance(d, DbObjectDict):
                    # skip ColumnDict as not needed for dependency tracking
                    # and internally has lists, not objects
                    if not isinstance(d, ColumnDict):
                        rv.append((attr, d))

            # first return the dicts for non-schema objects, then the
            # others, each group sorted alphabetically by class name.
            rv.sort(key=lambda pair: (issubclass(pair[1].cls, DbSchemaObject),
                                      pair[1].cls.__name__))

            return rv

        def dbobjdict_from_catalog(self, catalog):
            """Given a catalog name, return corresponding DbObjectDict

            :param catalog: full name of a pg_ catalog
            :return: DbObjectDict object (or None if not mapped)
            """
            return self._catalog_map.get(catalog)

        def find_type(self, name):
            """Return a db type given a qualname

            :param name: qualified name of the type
            :return: the type object, or None if not found

            Note that tables and views are types too.
            """
            rv = self.types.find(name)
            if rv is not None:
                return rv

            rv = self.tables.find(name)
            return rv

    def __init__(self, config):
        """Initialize the database

        :param config: configuration dictionary; must contain a
                       'database' entry with connection parameters
        """
        db = config['database']
        self.dbconn = CatDbConnection(db['dbname'], db['username'],
                                      db['password'], db['host'], db['port'])
        # self.db (catalog view) is populated lazily by from_catalog()
        self.db = None
        self.config = config

    def _link_refs(self, db):
        """Link related objects in a Dicts holder

        :param db: a Dicts object (either `db` or `ndb`)
        """
        langs = []
        if self.dbconn.version >= 90100:
            # languages installed via extensions (deptype 'e') need
            # special handling when linking to functions
            langs = [lang[0] for lang in self.dbconn.fetchall(
                """SELECT lanname FROM pg_language l
                     JOIN pg_depend p ON (l.oid = p.objid)
                    WHERE deptype = 'e' """)]
        db.languages.link_refs(db.functions, langs)
        copycfg = {}
        if 'datacopy' in self.config:
            copycfg = self.config['datacopy']
        db.schemas.link_refs(db, copycfg)
        db.tables.link_refs(db.columns, db.constraints, db.indexes, db.rules,
                            db.triggers)
        db.functions.link_refs(db.types)
        db.fdwrappers.link_refs(db.servers)
        db.servers.link_refs(db.usermaps)
        db.ftables.link_refs(db.columns)
        db.types.link_refs(db.columns, db.constraints, db.functions)
        db.constraints.link_refs(db)

    def _build_dependency_graph(self, db, dbconn):
        """Build the dependency graph of the database objects

        :param db: dictionary of dictionary of all objects
        :param dbconn: a DbConnection object

        Populates each object's `depends_on` list from pg_depend,
        view rewrite rules and column defaults.
        """
        alldeps = defaultdict(list)

        # This query wanted to be simple. it got complicated because
        # we don't handle indexes together with the other pg_class
        # but in their own pg_index place (so fetch i1, i2)
        # "Normal" dependencies, but excluding system objects
        # (objid < 16384 and refobjid < 16384)
        query = """SELECT DISTINCT
                          CASE WHEN i1.indexrelid IS NOT NULL
                          THEN 'pg_index'::regclass
                          ELSE classid::regclass END AS class_name, objid,
                          CASE WHEN i2.indexrelid IS NOT NULL
                          THEN 'pg_index'::regclass
                          ELSE refclassid::regclass END AS refclass, refobjid
                   FROM pg_depend
                        LEFT JOIN pg_index i1 ON classid = 'pg_class'::regclass
                             AND objid = i1.indexrelid
                        LEFT JOIN pg_index i2
                             ON refclassid = 'pg_class'::regclass
                             AND refobjid = i2.indexrelid
                   WHERE deptype = 'n'
                   AND NOT (objid < 16384 AND refobjid < 16384)"""
        for r in dbconn.fetchall(query):
            alldeps[r['class_name'], r['objid']].append(
                (r['refclass'], r['refobjid']))

        # The dependencies across views is not in pg_depend. We have to
        # parse the rewrite rule.  "ev_class >= 16384" is to exclude
        # system views.  (Raw string: the pattern contains \s and \d.)
        query = r"""SELECT DISTINCT 'pg_class' AS class_name, ev_class,
                          CASE WHEN depid[1] = 'relid' THEN 'pg_class'
                               WHEN depid[1] = 'funcid' THEN 'pg_proc'
                               END AS refclass, depid[2]::oid AS refobjid
                   FROM (SELECT ev_class, regexp_matches(ev_action,
                                ':(relid|funcid)\s+(\d+)', 'g') AS depid
                         FROM pg_rewrite
                         WHERE rulename = '_RETURN'
                         AND ev_class >= 16384) x
                         LEFT JOIN pg_class c
                              ON (depid[1], depid[2]::oid) = ('relid', c.oid)
                         LEFT JOIN pg_namespace cs ON cs.oid = relnamespace
                         LEFT JOIN pg_proc p
                              ON (depid[1], depid[2]::oid) = ('funcid', p.oid)
                         LEFT JOIN pg_namespace ps ON ps.oid = pronamespace
                   WHERE ev_class <> depid[2]::oid
                   AND coalesce(cs.nspname, ps.nspname)
                         NOT IN ('information_schema', 'pg_catalog')"""
        for r in dbconn.fetchall(query):
            alldeps[r['class_name'], r['ev_class']].append(
                (r['refclass'], r['refobjid']))

        # Add the dependencies between a table and other objects through the
        # columns defaults
        query = """SELECT 'pg_class' AS class_name, adrelid,
                          d.refclassid::regclass, d.refobjid
                   FROM pg_attrdef ad JOIN pg_depend d
                        ON classid = 'pg_attrdef'::regclass AND objid = ad.oid
                        AND deptype = 'n'"""
        for r in dbconn.fetchall(query):
            alldeps[r['class_name'], r['adrelid']].append(
                (r['refclassid'], r['refobjid']))

        # Translate the (catalog, oid) pairs into object references,
        # silently skipping anything we don't track
        for (stbl, soid), deps in list(alldeps.items()):
            sdict = db.dbobjdict_from_catalog(stbl)
            if sdict is None or len(sdict) == 0:
                continue
            src = sdict.by_oid.get(soid)
            if src is None:
                continue
            for ttbl, toid in deps:
                tdict = db.dbobjdict_from_catalog(ttbl)
                if tdict is None or len(tdict) == 0:
                    continue
                tgt = tdict.by_oid.get(toid)
                if tgt is None:
                    continue
                src.depends_on.append(tgt)

    def _trim_objects(self, schemas):
        """Remove unwanted schema objects

        :param schemas: list of schemas to keep
        """
        for objtype in ['types', 'tables', 'constraints', 'indexes',
                        'functions', 'operators', 'operclasses', 'operfams',
                        'rules', 'triggers', 'conversions', 'tstempls',
                        'tsdicts', 'tsparsers', 'tsconfigs', 'extensions',
                        'collations', 'eventtrigs']:
            objdict = getattr(self.db, objtype)
            for obj in list(objdict.keys()):
                # obj[0] is the schema name in all these dicts
                if obj[0] not in schemas:
                    del objdict[obj]
        for sch in list(self.db.schemas.keys()):
            if sch not in schemas:
                del self.db.schemas[sch]
        # exclude database-wide objects
        self.db.languages = LanguageDict()
        self.db.casts = CastDict()

    def from_catalog(self, single_db=False):
        """Populate the database objects by querying the catalogs

        :param single_db: populating only this database?

        The `db` holder is populated by various DbObjectDict-derived
        classes by querying the catalogs.  A dependency graph is
        constructed by querying the pg_depend catalog.  The objects in
        the dictionary are then linked to related objects, e.g.,
        columns are linked to the tables they belong.
        """
        self.db = self.Dicts(self.dbconn, single_db)
        self._build_dependency_graph(self.db, self.dbconn)
        if self.dbconn.conn:
            self.dbconn.conn.close()
        self._link_refs(self.db)

    def from_map(self, input_map, langs=None):
        """Populate the new database objects from the input map

        :param input_map: a YAML map defining the new database
        :param langs: list of language templates
        :raises KeyError: if a top-level key is not a recognized
                          object type

        The `ndb` holder is populated by various DbObjectDict-derived
        classes by traversing the YAML input map. The objects in the
        dictionary are then linked to related objects, e.g., columns
        are linked to the tables they belong.
        """
        self.ndb = self.Dicts()
        input_schemas = {}
        input_extens = {}
        input_langs = {}
        input_casts = {}
        input_fdws = {}
        input_ums = {}
        input_evttrigs = {}
        for key in input_map:
            if key.startswith('schema '):
                input_schemas.update({key: input_map[key]})
            elif key.startswith('extension '):
                input_extens.update({key: input_map[key]})
            elif key.startswith('language '):
                input_langs.update({key: input_map[key]})
            elif key.startswith('cast '):
                input_casts.update({key: input_map[key]})
            elif key.startswith('foreign data wrapper '):
                input_fdws.update({key: input_map[key]})
            elif key.startswith('user mapping for '):
                input_ums.update({key: input_map[key]})
            elif key.startswith('event trigger '):
                input_evttrigs.update({key: input_map[key]})
            else:
                raise KeyError("Expected typed object, found '%s'" % key)
        self.ndb.extensions.from_map(input_extens, langs, self.ndb)
        self.ndb.languages.from_map(input_langs)
        self.ndb.schemas.from_map(input_schemas, self.ndb)
        self.ndb.casts.from_map(input_casts, self.ndb)
        self.ndb.fdwrappers.from_map(input_fdws, self.ndb)
        self.ndb.eventtrigs.from_map(input_evttrigs, self.ndb)
        self._link_refs(self.ndb)

    def map_from_dir(self):
        """Read the database maps starting from the metadata directory

        :return: dictionary

        Non-schema .yaml files are merged at the top level; each
        schema.<name> directory is merged under its schema.<name>.yaml
        entry.
        """
        metadata_dir = self.config['files']['metadata_path']
        if not os.path.isdir(metadata_dir):
            sys.exit("Metadata directory '%s' doesn't exist" % metadata_dir)

        def load(subdir, obj):
            # read a single YAML file; non-dict content is ignored
            with open(os.path.join(subdir, obj), 'r') as f:
                objmap = yaml.safe_load(f)
            return objmap if isinstance(objmap, dict) else {}

        inmap = {}
        for entry in os.listdir(metadata_dir):
            if entry.endswith('.yaml'):
                if entry.startswith('database.'):
                    continue
                if not entry.startswith('schema.'):
                    inmap.update(load(metadata_dir, entry))
            else:
                # skip over unknown files/dirs
                if not entry.startswith('schema.'):
                    continue
                # read schema.xxx.yaml first
                schmap = load(metadata_dir, entry + '.yaml')
                assert len(schmap) == 1
                key = list(schmap.keys())[0]
                inmap.update({key: {}})
                subdir = os.path.join(metadata_dir, entry)
                if os.path.isdir(subdir):
                    for schobj in os.listdir(subdir):
                        schmap[key].update(load(subdir, schobj))
                inmap.update(schmap)

        return inmap

    def to_map(self):
        """Convert the db maps to a single hierarchy suitable for YAML

        :return: a YAML-suitable dictionary (without any Python objects)
        """
        if not self.db:
            self.from_catalog(True)

        opts = self.config['options']

        def mkdir_parents(dir):
            # create dir and any missing ancestors (like os.makedirs)
            head, tail = os.path.split(dir)
            if head and not os.path.isdir(head):
                mkdir_parents(head)
            if tail:
                os.mkdir(dir)

        if opts.multiple_files:
            opts.metadata_dir = self.config['files']['metadata_path']
            if not os.path.exists(opts.metadata_dir):
                mkdir_parents(opts.metadata_dir)
            dbfilepath = os.path.join(opts.metadata_dir, 'database.%s.yaml' %
                                      self.dbconn.dbname)
            if os.path.exists(dbfilepath):
                # remove files written by a previous run so renamed or
                # dropped objects don't leave stale files behind
                with open(dbfilepath, 'r') as f:
                    objmap = yaml.safe_load(f)
                for obj, val in list(objmap.items()):
                    if isinstance(val, dict):
                        dirpath = ''
                        for schobj, fpath in list(val.items()):
                            filepath = os.path.join(opts.metadata_dir, fpath)
                            if os.path.exists(filepath):
                                os.remove(filepath)
                                if schobj == 'schema':
                                    (dirpath, ext) = os.path.splitext(filepath)
                        if os.path.exists(dirpath):
                            os.rmdir(dirpath)
                    else:
                        filepath = os.path.join(opts.metadata_dir, val)
                        if os.path.exists(filepath):
                            os.remove(filepath)

        dbmap = self.db.extensions.to_map(self.db, opts)
        dbmap.update(self.db.languages.to_map(self.db, opts))
        dbmap.update(self.db.casts.to_map(self.db, opts))
        dbmap.update(self.db.fdwrappers.to_map(self.db, opts))
        dbmap.update(self.db.eventtrigs.to_map(self.db, opts))
        if 'datacopy' in self.config:
            opts.data_dir = self.config['files']['data_path']
            if not os.path.exists(opts.data_dir):
                mkdir_parents(opts.data_dir)
        dbmap.update(self.db.schemas.to_map(self.db, opts))

        if opts.multiple_files:
            with open(dbfilepath, 'w') as f:
                f.write(yamldump(dbmap))

        return dbmap

    def diff_map(self, input_map, quote_reserved=True):
        """Generate SQL to transform an existing database

        :param input_map: a YAML map defining the new database
        :param quote_reserved: fetch reserved words
        :return: list of SQL statements

        Compares the existing database definition, as fetched from the
        catalogs, to the input YAML map and generates SQL statements
        to transform the database into the one represented by the
        input.
        """
        from .dbobject.table import Table

        if not self.db:
            self.from_catalog()
        opts = self.config['options']
        if opts.schemas:
            schlist = ['schema ' + sch for sch in opts.schemas]
            for sch in list(input_map.keys()):
                if sch not in schlist and sch.startswith('schema '):
                    del input_map[sch]
            self._trim_objects(opts.schemas)

        # quote_reserved is only set to False by most tests
        if quote_reserved:
            fetch_reserved_words(self.dbconn)

        langs = [lang[0] for lang in self.dbconn.fetchall(
            "SELECT tmplname FROM pg_pltemplate")]
        self.from_map(input_map, langs)
        if opts.revert:
            (self.db, self.ndb) = (self.ndb, self.db)
            del self.ndb.schemas['pg_catalog']
            self.db.languages.dbconn = self.dbconn

        # First sort the objects in the new db in dependency order
        new_objs = []
        for _, d in self.ndb.all_dicts():
            pairs = list(d.items())
            pairs.sort()
            new_objs.extend(list(map(itemgetter(1), pairs)))

        new_objs = self.dep_sorted(new_objs, self.ndb)

        # Then generate the sql for all the objects, walking in dependency
        # order over all the db objects

        stmts = []
        for new in new_objs:
            d = self.db.dbobjdict_from_catalog(new.catalog)
            old = d.get(new.key())
            if old is not None:
                stmts.append(old.alter(new))
            else:
                stmts.append(new.create_sql(self.dbconn.version))

                # Check if the object just created was renamed, in which case
                # don't try to delete the original one
                if getattr(new, 'oldname', None):
                    try:
                        origname, new.name = new.name, new.oldname
                        oldkey = new.key()
                    finally:
                        new.name = origname
                    # Intentionally raising KeyError as tested e.g. in
                    # test_bad_rename_view -- ok Joe?
                    old = d[oldkey]
                    old._nodrop = True

        # Order the old database objects in reverse dependency order
        old_objs = []
        for _, d in self.db.all_dicts():
            pairs = list(d.items())
            # BUG FIX: was `pairs.sort` (attribute access, a no-op);
            # the list must actually be sorted like for new_objs above
            pairs.sort()
            old_objs.extend(list(map(itemgetter(1), pairs)))
        old_objs = self.dep_sorted(old_objs, self.db)
        old_objs.reverse()

        # Drop the objects that don't appear in the new db
        for old in old_objs:
            d = self.ndb.dbobjdict_from_catalog(old.catalog)
            if isinstance(old, Table):
                new = d.get(old.key())
                if new is not None:
                    stmts.extend(old.alter_drop_columns(new))
            if not getattr(old, '_nodrop', False) and old.key() not in d:
                stmts.extend(old.drop())

        if 'datacopy' in self.config:
            opts.data_dir = self.config['files']['data_path']
            stmts.append(self.ndb.schemas.data_import(opts))

        stmts = list(flatten(stmts))
        # SQL functions may reference objects created later in the
        # script, so disable body checking if any are present
        funcs = False
        for s in stmts:
            if "LANGUAGE sql" in s and (
                    s.startswith("CREATE FUNCTION ") or
                    s.startswith("CREATE OR REPLACE FUNCTION ")):
                funcs = True
                break
        if funcs:
            stmts.insert(0, "SET check_function_bodies = false")

        return stmts

    def dep_sorted(self, objs, db):
        """Sort `objs` in order of dependency.

        :param objs: list of objects responding to `get_deps(db)`
        :param db: the Dicts holder passed through to `get_deps`
        :return: list of the same objects, dependencies first
        :raises Exception: if the dependency graph has a cycle

        The function implements the classic Kahn 62 algorithm, see
        <http://en.wikipedia.org/wiki/Topological_sorting>.
        """
        # List of objects to return
        L = []

        # Collect the graph edges.
        # Note that our "dependencies" are sort of backwards compared to the
        # terms used in the algorithm (an edge in the algo would be from the
        # schema to the table, we have the table depending on the schema)
        ein = defaultdict(set)
        eout = defaultdict(deque)
        for obj in objs:
            for dep in obj.get_deps(db):
                eout[dep].append(obj)
                ein[obj].add(dep)

        # The objects with no dependency to start with
        S = deque()
        for obj in objs:
            if obj not in ein:
                S.append(obj)

        while S:
            # Objects with no dependencies can be emitted
            obj = S.popleft()
            L.append(obj)

            # Delete the edges and check if depending objects have no
            # dependency now
            while eout[obj]:
                ch = eout[obj].popleft()
                ein[ch].remove(obj)
                if not ein[ch]:
                    del ein[ch]
                    S.append(ch)

            del eout[obj]   # remove the empty set

        # any edge left means a cycle: both maps empty or both non-empty
        assert bool(ein) == bool(eout)
        if not ein:
            return L
        else:
            # is it possible? How do we deal with that?
            raise Exception("the objects dependencies graph has loops")
650