1"""
2   Module for visitor class mapping.
3"""
4import sqlalchemy as sa
5
6from migrate.changeset import ansisql
7from migrate.changeset.databases import (sqlite,
8                                         postgres,
9                                         mysql,
10                                         oracle,
11                                         firebird)
12
13
14# Map SA dialects to the corresponding Migrate extensions
15DIALECTS = {
16    "default": ansisql.ANSIDialect,
17    "sqlite": sqlite.SQLiteDialect,
18    "postgres": postgres.PGDialect,
19    "postgresql": postgres.PGDialect,
20    "mysql": mysql.MySQLDialect,
21    "oracle": oracle.OracleDialect,
22    "firebird": firebird.FBDialect,
23}
24
25
26# NOTE(mriedem): We have to conditionally check for DB2 in case ibm_db_sa
27# isn't available since ibm_db_sa is not packaged in sqlalchemy like the
28# other dialects.
29try:
30    from migrate.changeset.databases import ibmdb2
31    DIALECTS["ibm_db_sa"] = ibmdb2.IBMDBDialect
32except ImportError:
33    pass
34
35
36def get_engine_visitor(engine, name):
37    """
38    Get the visitor implementation for the given database engine.
39
40    :param engine: SQLAlchemy Engine
41    :param name: Name of the visitor
42    :type name: string
43    :type engine: Engine
44    :returns: visitor
45    """
46    # TODO: link to supported visitors
47    return get_dialect_visitor(engine.dialect, name)
48
49
50def get_dialect_visitor(sa_dialect, name):
51    """
52    Get the visitor implementation for the given dialect.
53
54    Finds the visitor implementation based on the dialect class and
55    returns and instance initialized with the given name.
56
57    Binds dialect specific preparer to visitor.
58    """
59
60    # map sa dialect to migrate dialect and return visitor
61    sa_dialect_name = getattr(sa_dialect, 'name', 'default')
62    migrate_dialect_cls = DIALECTS[sa_dialect_name]
63    visitor = getattr(migrate_dialect_cls, name)
64
65    # bind preparer
66    visitor.preparer = sa_dialect.preparer(sa_dialect)
67
68    return visitor
69
70def run_single_visitor(engine, visitorcallable, element,
71    connection=None, **kwargs):
72    """Taken from :meth:`sqlalchemy.engine.base.Engine._run_single_visitor`
73    with support for migrate visitors.
74    """
75    if connection is None:
76        conn = engine.connect()
77    else:
78        conn = connection
79    visitor = visitorcallable(engine.dialect, conn)
80    try:
81        if hasattr(element, '__migrate_visit_name__'):
82            fn = getattr(visitor, 'visit_' + element.__migrate_visit_name__)
83        else:
84            fn = getattr(visitor, 'visit_' + element.__visit_name__)
85        fn(element, **kwargs)
86    finally:
87        if connection is None:
88            conn.close()
89