1import sqlalchemy as sa
2from sqlalchemy import Computed
3from sqlalchemy import event
4from sqlalchemy import Integer
5from sqlalchemy import String
6from sqlalchemy import testing
7from sqlalchemy.orm import create_session
8from sqlalchemy.orm import mapper
9from sqlalchemy.orm import Session
10from sqlalchemy.testing import eq_
11from sqlalchemy.testing import fixtures
12from sqlalchemy.testing.assertsql import assert_engine
13from sqlalchemy.testing.assertsql import CompiledSQL
14from sqlalchemy.testing.schema import Column
15from sqlalchemy.testing.schema import Table
16
17
class TriggerDefaultsTest(fixtures.MappedTest):
    """ORM flushes pick up column values written by database triggers.

    ``col2``, ``col3`` and ``col4`` carry ``FetchedValue`` markers and are
    populated by per-dialect ``dt_ins`` / ``dt_up`` triggers that are created
    right after the ``dt`` table.  The tests assert that mapped objects see
    the trigger-written values after a flush.
    """

    __requires__ = ("row_triggers",)

    @classmethod
    def define_tables(cls, metadata):
        """Create ``dt`` plus the INSERT/UPDATE triggers for every dialect.

        Each trigger DDL is guarded with ``execute_if`` so only the variant
        for the current dialect actually runs.
        """
        dt = Table(
            "dt",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("col1", String(20)),
            # written by the INSERT trigger
            Column(
                "col2", String(20), server_default=sa.schema.FetchedValue()
            ),
            # written by the UPDATE trigger
            Column(
                "col3", String(20), sa.schema.FetchedValue(for_update=True)
            ),
            # written by both the INSERT and the UPDATE trigger
            Column(
                "col4",
                String(20),
                sa.schema.FetchedValue(),
                sa.schema.FetchedValue(for_update=True),
            ),
        )

        dialect_name = testing.db.dialect.name

        def _no_dedicated_variant(ddl, target, bind, **kw):
            # fallback trigger syntax for every dialect that has no
            # dedicated DDL variant below
            return bind.engine.name not in (
                "oracle",
                "mssql",
                "sqlite",
                "postgresql",
            )

        # PostgreSQL triggers delegate to PL/pgSQL functions, which must
        # exist before the CREATE TRIGGER that references them.  Register
        # each function exactly once, ahead of its triggers (it used to be
        # re-registered on every loop iteration).
        event.listen(
            dt,
            "after_create",
            sa.DDL(
                "CREATE OR REPLACE FUNCTION my_func_ins() "
                "RETURNS TRIGGER AS $$ "
                "BEGIN "
                "NEW.col2 := 'ins'; NEW.col4 := 'ins'; "
                "RETURN NEW; "
                "END; $$ LANGUAGE PLPGSQL"
            ).execute_if(dialect="postgresql"),
        )

        insert_triggers = (
            sa.DDL(
                "CREATE TRIGGER dt_ins AFTER INSERT ON dt "
                "FOR EACH ROW BEGIN "
                "UPDATE dt SET col2='ins', col4='ins' "
                "WHERE dt.id = NEW.id; END"
            ).execute_if(dialect="sqlite"),
            sa.DDL(
                "CREATE TRIGGER dt_ins ON dt AFTER INSERT AS "
                "UPDATE dt SET col2='ins', col4='ins' "
                "WHERE dt.id IN (SELECT id FROM inserted);"
            ).execute_if(dialect="mssql"),
            sa.DDL(
                "CREATE TRIGGER dt_ins BEFORE INSERT "
                "ON dt "
                "FOR EACH ROW "
                "BEGIN "
                ":NEW.col2 := 'ins'; :NEW.col4 := 'ins'; END;"
            ).execute_if(dialect="oracle"),
            sa.DDL(
                "CREATE TRIGGER dt_ins BEFORE INSERT "
                "ON dt "
                "FOR EACH ROW "
                "EXECUTE PROCEDURE my_func_ins();"
            ).execute_if(dialect="postgresql"),
            sa.DDL(
                "CREATE TRIGGER dt_ins BEFORE INSERT ON dt "
                "FOR EACH ROW BEGIN "
                "SET NEW.col2='ins'; SET NEW.col4='ins'; END"
            ).execute_if(callable_=_no_dedicated_variant),
        )
        for ins in insert_triggers:
            event.listen(dt, "after_create", ins)

        event.listen(
            dt,
            "after_create",
            sa.DDL(
                "CREATE OR REPLACE FUNCTION my_func_up() "
                "RETURNS TRIGGER AS $$ "
                "BEGIN "
                "NEW.col3 := 'up'; NEW.col4 := 'up'; "
                "RETURN NEW; "
                "END; $$ LANGUAGE PLPGSQL"
            ).execute_if(dialect="postgresql"),
        )

        update_triggers = (
            sa.DDL(
                "CREATE TRIGGER dt_up AFTER UPDATE ON dt "
                "FOR EACH ROW BEGIN "
                "UPDATE dt SET col3='up', col4='up' "
                "WHERE dt.id = OLD.id; END"
            ).execute_if(dialect="sqlite"),
            sa.DDL(
                "CREATE TRIGGER dt_up ON dt AFTER UPDATE AS "
                "UPDATE dt SET col3='up', col4='up' "
                "WHERE dt.id IN (SELECT id FROM deleted);"
            ).execute_if(dialect="mssql"),
            sa.DDL(
                "CREATE TRIGGER dt_up BEFORE UPDATE ON dt "
                "FOR EACH ROW BEGIN "
                ":NEW.col3 := 'up'; :NEW.col4 := 'up'; END;"
            ).execute_if(dialect="oracle"),
            sa.DDL(
                "CREATE TRIGGER dt_up BEFORE UPDATE ON dt "
                "FOR EACH ROW "
                "EXECUTE PROCEDURE my_func_up();"
            ).execute_if(dialect="postgresql"),
            sa.DDL(
                "CREATE TRIGGER dt_up BEFORE UPDATE ON dt "
                "FOR EACH ROW BEGIN "
                "SET NEW.col3='up'; SET NEW.col4='up'; END"
            ).execute_if(callable_=_no_dedicated_variant),
        )
        for up in update_triggers:
            event.listen(dt, "after_create", up)

        # PostgreSQL's DROP TRIGGER needs the owning table named.
        on_table = " ON dt" if dialect_name == "postgresql" else ""
        event.listen(
            dt, "before_drop", sa.DDL("DROP TRIGGER dt_ins" + on_table)
        )
        event.listen(
            dt, "before_drop", sa.DDL("DROP TRIGGER dt_up" + on_table)
        )

        # Without this the PL/pgSQL functions would outlive the table;
        # drop them once the table (and therefore its triggers) is gone.
        event.listen(
            dt,
            "after_drop",
            sa.DDL("DROP FUNCTION IF EXISTS my_func_ins()").execute_if(
                dialect="postgresql"
            ),
        )
        event.listen(
            dt,
            "after_drop",
            sa.DDL("DROP FUNCTION IF EXISTS my_func_up()").execute_if(
                dialect="postgresql"
            ),
        )

    @classmethod
    def setup_classes(cls):
        class Default(cls.Comparable):
            pass

    @classmethod
    def setup_mappers(cls):
        Default, dt = cls.classes.Default, cls.tables.dt

        mapper(Default, dt)

    def test_insert(self):
        """After an INSERT flush, the INSERT trigger's values are loaded."""
        Default = self.classes.Default

        d1 = Default(id=1)

        eq_(d1.col1, None)
        eq_(d1.col2, None)
        eq_(d1.col3, None)
        eq_(d1.col4, None)

        session = create_session()
        session.add(d1)
        session.flush()

        eq_(d1.col1, None)
        eq_(d1.col2, "ins")
        eq_(d1.col3, None)
        # don't care which trigger fired
        assert d1.col4 in ("ins", "up")

    def test_update(self):
        """After an UPDATE flush, the UPDATE trigger's values are loaded.

        ``col2`` keeps the value written by the INSERT trigger.
        """
        Default = self.classes.Default

        d1 = Default(id=1)

        session = create_session()
        session.add(d1)
        session.flush()
        d1.col1 = "set"
        session.flush()

        eq_(d1.col1, "set")
        eq_(d1.col2, "ins")
        eq_(d1.col3, "up")
        eq_(d1.col4, "up")
192
193
class ExcludedDefaultsTest(fixtures.MappedTest):
    """A Column-level default still fires for a column the mapper excludes."""

    @classmethod
    def define_tables(cls, metadata):
        pk_column = Column(
            "id", Integer, primary_key=True, test_needs_autoincrement=True
        )
        defaulted_column = Column("col1", String(20), default="hello")
        Table("dt", metadata, pk_column, defaulted_column)

    def test_exclude(self):
        table = self.tables.dt

        class Foo(fixtures.BasicEntity):
            pass

        # "col1" is not mapped, so the object never supplies a value for
        # it; the emitted INSERT must still apply the Column default.
        mapper(Foo, table, exclude_properties=("col1",))

        sess = create_session()
        sess.add(Foo())
        sess.flush()

        persisted = table.select().execute().fetchall()
        eq_(persisted, [(1, "hello")])
219
220
class ComputedDefaultsOnUpdateTest(fixtures.MappedTest):
    """test that computed columns are recognized as server
    oninsert/onupdate defaults.

    ``bar`` is computed by the database as ``foo + 42``.  After a flush the
    ORM must surface its value, either through RETURNING (eager defaults on
    a dialect with implicit_returning) or through follow-up SELECTs.
    """

    __backend__ = True
    __requires__ = ("computed_columns",)

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "test",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("foo", Integer),
            Column("bar", Integer, Computed("foo + 42")),
        )

    @classmethod
    def setup_classes(cls):
        class Thing(cls.Basic):
            pass

        class ThingNoEager(cls.Basic):
            pass

    @classmethod
    def setup_mappers(cls):
        table = cls.tables.test
        # same table, mapped twice: with and without eager_defaults
        mapper(cls.classes.Thing, table, eager_defaults=True)
        mapper(cls.classes.ThingNoEager, table, eager_defaults=False)

    def _thing_class(self, eager):
        """Return the eager_defaults mapping if *eager*, else the lazy one."""
        return self.classes.Thing if eager else self.classes.ThingNoEager

    @staticmethod
    def _select_bar(pk):
        """Expected statement that loads ``bar`` for the row with id *pk*."""
        return CompiledSQL(
            "SELECT test.bar AS test_bar FROM test "
            "WHERE test.id = :param_1",
            [{"param_1": pk}],
        )

    @testing.combinations(("eager", True), ("noneager", False), id_="ia")
    def test_insert_computed(self, eager):
        Thing = self._thing_class(eager)

        s = Session()
        t1 = Thing(id=1, foo=5)
        t2 = Thing(id=2, foo=10)
        s.add_all([t1, t2])

        with assert_engine(testing.db) as asserter:
            s.flush()
            eq_(t1.bar, 5 + 42)
            eq_(t2.bar, 10 + 42)

        if eager and testing.db.dialect.implicit_returning:
            # one INSERT per row, each returning the computed column
            expected = [
                CompiledSQL(
                    "INSERT INTO test (id, foo) VALUES (%(id)s, %(foo)s) "
                    "RETURNING test.bar",
                    [{"foo": foo, "id": pk}],
                    dialect="postgresql",
                )
                for pk, foo in ((1, 5), (2, 10))
            ]
        else:
            # a single executemany INSERT, then one SELECT of bar per row
            expected = [
                CompiledSQL(
                    "INSERT INTO test (id, foo) VALUES (:id, :foo)",
                    [{"foo": 5, "id": 1}, {"foo": 10, "id": 2}],
                ),
                self._select_bar(1),
                self._select_bar(2),
            ]
        asserter.assert_(*expected)

    @testing.combinations(
        (
            "eagerload",
            True,
            testing.requires.computed_columns_on_update_returning,
        ),
        (
            "noneagerload",
            False,
        ),
        id_="ia",
    )
    def test_update_computed(self, eager):
        Thing = self._thing_class(eager)

        s = Session()
        t1 = Thing(id=1, foo=1)
        t2 = Thing(id=2, foo=2)
        s.add_all([t1, t2])
        s.flush()

        t1.foo = 5
        t2.foo = 6

        with assert_engine(testing.db) as asserter:
            s.flush()
            eq_(t1.bar, 5 + 42)
            eq_(t2.bar, 6 + 42)

        # (new foo value, primary key) for each row, in flush order
        changes = ((5, 1), (6, 2))

        if eager and testing.db.dialect.implicit_returning:
            # one UPDATE per row, each returning the recomputed column
            expected = [
                CompiledSQL(
                    "UPDATE test SET foo=%(foo)s "
                    "WHERE test.id = %(test_id)s "
                    "RETURNING test.bar",
                    [{"foo": foo, "test_id": pk}],
                    dialect="postgresql",
                )
                for foo, pk in changes
            ]
        elif eager:
            # per-row UPDATEs, followed by one SELECT of bar per row
            expected = [
                CompiledSQL(
                    "UPDATE test SET foo=:foo WHERE test.id = :test_id",
                    [{"foo": foo, "test_id": pk}],
                )
                for foo, pk in changes
            ]
            expected.append(self._select_bar(1))
            expected.append(self._select_bar(2))
        else:
            # a single executemany UPDATE, then one SELECT of bar per row
            expected = [
                CompiledSQL(
                    "UPDATE test SET foo=:foo WHERE test.id = :test_id",
                    [{"foo": 5, "test_id": 1}, {"foo": 6, "test_id": 2}],
                ),
                self._select_bar(1),
                self._select_bar(2),
            ]
        asserter.assert_(*expected)
394