1import sqlalchemy as sa
2from sqlalchemy import Integer, String, ForeignKey, event
3from sqlalchemy import testing
4from sqlalchemy.testing.schema import Table, Column
5from sqlalchemy.orm import mapper, relationship, create_session
6from sqlalchemy.testing import fixtures
7from sqlalchemy.testing import eq_
8
9
class TriggerDefaultsTest(fixtures.MappedTest):
    """Check that the ORM picks up column values written by row triggers.

    Table ``dt`` marks ``col2`` with a ``FetchedValue`` server default,
    ``col3`` with a ``FetchedValue(for_update=True)`` and ``col4`` with
    both.  Per-backend ``CREATE TRIGGER`` statements attached to the table
    write ``'ins'`` / ``'up'`` into those columns.
    """

    __requires__ = ('row_triggers',)

    @classmethod
    def define_tables(cls, metadata):
        """Declare table ``dt`` and attach trigger create/drop DDL to it."""
        trigger_table = Table(
            'dt', metadata,
            Column('id', Integer, primary_key=True,
                   test_needs_autoincrement=True),
            Column('col1', String(20)),
            Column('col2', String(20),
                   server_default=sa.schema.FetchedValue()),
            Column('col3', String(20),
                   sa.schema.FetchedValue(for_update=True)),
            Column('col4', String(20),
                   sa.schema.FetchedValue(),
                   sa.schema.FetchedValue(for_update=True)),
        )

        # Backends that get a dedicated CREATE TRIGGER statement below.
        named_backends = ('oracle', 'mssql', 'sqlite')

        def on_any_other_backend(ddl, event, target, bind, **kw):
            # Fallback condition: true for every engine not named above.
            return bind.engine.name not in named_backends

        insert_triggers = [
            sa.DDL("CREATE TRIGGER dt_ins AFTER INSERT ON dt "
                   "FOR EACH ROW BEGIN "
                   "UPDATE dt SET col2='ins', col4='ins' "
                   "WHERE dt.id = NEW.id; END",
                   on='sqlite'),
            sa.DDL("CREATE TRIGGER dt_ins ON dt AFTER INSERT AS "
                   "UPDATE dt SET col2='ins', col4='ins' "
                   "WHERE dt.id IN (SELECT id FROM inserted);",
                   on='mssql'),
            sa.DDL("CREATE TRIGGER dt_ins BEFORE INSERT "
                   "ON dt "
                   "FOR EACH ROW "
                   "BEGIN "
                   ":NEW.col2 := 'ins'; :NEW.col4 := 'ins'; END;",
                   on='oracle'),
            sa.DDL("CREATE TRIGGER dt_ins BEFORE INSERT ON dt "
                   "FOR EACH ROW BEGIN "
                   "SET NEW.col2='ins'; SET NEW.col4='ins'; END",
                   on=on_any_other_backend),
        ]

        update_triggers = [
            sa.DDL("CREATE TRIGGER dt_up AFTER UPDATE ON dt "
                   "FOR EACH ROW BEGIN "
                   "UPDATE dt SET col3='up', col4='up' "
                   "WHERE dt.id = OLD.id; END",
                   on='sqlite'),
            sa.DDL("CREATE TRIGGER dt_up ON dt AFTER UPDATE AS "
                   "UPDATE dt SET col3='up', col4='up' "
                   "WHERE dt.id IN (SELECT id FROM deleted);",
                   on='mssql'),
            sa.DDL("CREATE TRIGGER dt_up BEFORE UPDATE ON dt "
                   "FOR EACH ROW BEGIN "
                   ":NEW.col3 := 'up'; :NEW.col4 := 'up'; END;",
                   on='oracle'),
            sa.DDL("CREATE TRIGGER dt_up BEFORE UPDATE ON dt "
                   "FOR EACH ROW BEGIN "
                   "SET NEW.col3='up'; SET NEW.col4='up'; END",
                   on=on_any_other_backend),
        ]

        # Each group registers its CREATE statements first and then the
        # matching DROP, insert group before update group.
        trigger_groups = (
            (insert_triggers, "DROP TRIGGER dt_ins"),
            (update_triggers, "DROP TRIGGER dt_up"),
        )
        for create_statements, drop_sql in trigger_groups:
            for statement in create_statements:
                event.listen(trigger_table, 'after_create', statement)
            event.listen(trigger_table, 'before_drop', sa.DDL(drop_sql))

    @classmethod
    def setup_classes(cls):
        """Declare the plain ``Default`` class that gets mapped to ``dt``."""
        class Default(cls.Comparable):
            # No behaviour of its own; only the class definition is needed.
            pass

    @classmethod
    def setup_mappers(cls):
        """Map ``Default`` onto table ``dt``."""
        mapper(cls.classes.Default, cls.tables.dt)

    def test_insert(self):
        """After the INSERT flush, trigger-written values show on the object."""
        Default = self.classes.Default

        obj = Default(id=1)

        # Nothing is populated before the object has been flushed.
        for attr in ('col1', 'col2', 'col3', 'col4'):
            eq_(getattr(obj, attr), None)

        sess = create_session()
        sess.add(obj)
        sess.flush()

        eq_(obj.col1, None)
        eq_(obj.col2, 'ins')
        eq_(obj.col3, None)
        # don't care which trigger fired
        assert obj.col4 in ('ins', 'up')

    def test_update(self):
        """After an UPDATE flush, the update trigger's values are visible."""
        obj = self.classes.Default(id=1)

        sess = create_session()
        sess.add(obj)
        sess.flush()

        # Changing col1 makes the second flush emit an UPDATE.
        obj.col1 = 'set'
        sess.flush()

        expected = (
            ('col1', 'set'),
            ('col2', 'ins'),
            ('col3', 'up'),
            ('col4', 'up'),
        )
        for attr, value in expected:
            eq_(getattr(obj, attr), value)
125
class ExcludedDefaultsTest(fixtures.MappedTest):
    """A column default still applies when the column is left unmapped."""

    @classmethod
    def define_tables(cls, metadata):
        """Declare table ``dt`` whose ``col1`` has the default ``"hello"``."""
        Table(
            'dt', metadata,
            Column('id', Integer, primary_key=True,
                   test_needs_autoincrement=True),
            Column('col1', String(20), default="hello"),
        )

    def test_exclude(self):
        """Insert via a mapping that excludes ``col1`` and check the row."""
        table = self.tables.dt

        class Foo(fixtures.BasicEntity):
            pass

        # col1 is deliberately kept out of the mapping.
        mapper(Foo, table, exclude_properties=('col1',))

        entity = Foo()
        sess = create_session()
        sess.add(entity)
        sess.flush()

        # The stored row should still carry the column default.
        stored_rows = table.select().execute().fetchall()
        eq_(stored_rows, [(1, "hello")])
146
147