1from sqlalchemy import Column
2from sqlalchemy import Integer
3from sqlalchemy import MetaData
4from sqlalchemy import String
5from sqlalchemy import Table
6from sqlalchemy import text
7from sqlalchemy.sql import column
8from sqlalchemy.sql import table
9from sqlalchemy.types import TypeEngine
10
11from alembic import op
12from alembic.migration import MigrationContext
13from alembic.testing import assert_raises_message
14from alembic.testing import config
15from alembic.testing import eq_
16from alembic.testing.fixtures import op_fixture
17from alembic.testing.fixtures import TestBase
18
19
class BulkInsertTest(TestBase):
    """Tests for ``op.bulk_insert()`` that check the emitted SQL.

    Each test builds an ``op_fixture`` context for a dialect, runs
    ``op.bulk_insert()`` and compares the recorded statements using
    ``context.assert_()``.  With ``as_sql=True`` (offline mode) the rows are
    rendered as literal values, one INSERT statement per row.
    """

    def _table_fixture(self, dialect, as_sql):
        """Return ``(context, t1)`` where ``t1`` is a lightweight ``table()``
        construct named ``ins_table`` with ``id``, ``v1`` and ``v2`` columns.
        """
        context = op_fixture(dialect, as_sql)
        t1 = table(
            "ins_table",
            column("id", Integer),
            column("v1", String()),
            column("v2", String()),
        )
        return context, t1

    def _big_t_table_fixture(self, dialect, as_sql):
        """Like ``_table_fixture`` but ``t1`` is a full ``Table`` bound to a
        fresh ``MetaData`` and ``id`` is declared as the primary key.
        """
        context = op_fixture(dialect, as_sql)
        t1 = Table(
            "ins_table",
            MetaData(),
            Column("id", Integer, primary_key=True),
            Column("v1", String()),
            Column("v2", String()),
        )
        return context, t1

    def _test_bulk_insert(self, dialect, as_sql):
        """Bulk-insert four rows into the ``table()`` fixture and return the
        context so the caller can assert on the emitted SQL.
        """
        context, t1 = self._table_fixture(dialect, as_sql)

        op.bulk_insert(
            t1,
            [
                {"id": 1, "v1": "row v1", "v2": "row v5"},
                {"id": 2, "v1": "row v2", "v2": "row v6"},
                {"id": 3, "v1": "row v3", "v2": "row v7"},
                {"id": 4, "v1": "row v4", "v2": "row v8"},
            ],
        )
        return context

    def _test_bulk_insert_single(self, dialect, as_sql):
        """Bulk-insert a single row into the ``table()`` fixture."""
        context, t1 = self._table_fixture(dialect, as_sql)

        op.bulk_insert(t1, [{"id": 1, "v1": "row v1", "v2": "row v5"}])
        return context

    def _test_bulk_insert_single_bigt(self, dialect, as_sql):
        """Bulk-insert a single row into the ``Table`` fixture."""
        context, t1 = self._big_t_table_fixture(dialect, as_sql)

        op.bulk_insert(t1, [{"id": 1, "v1": "row v1", "v2": "row v5"}])
        return context

    def test_bulk_insert(self):
        context = self._test_bulk_insert("default", False)
        context.assert_(
            "INSERT INTO ins_table (id, v1, v2) VALUES (:id, :v1, :v2)"
        )

    def test_bulk_insert_wrong_cols(self):
        """A row that omits some columns still yields an INSERT that names
        every column of the table."""
        context, t1 = self._table_fixture("postgresql", False)
        op.bulk_insert(t1, [{"v1": "row v1"}])
        context.assert_(
            "INSERT INTO ins_table (id, v1, v2) "
            "VALUES (%(id)s, %(v1)s, %(v2)s)"
        )

    def test_bulk_insert_no_rows(self):
        """An empty row list emits no statements at all."""
        context, t1 = self._table_fixture("default", False)

        op.bulk_insert(t1, [])
        context.assert_()

    def test_bulk_insert_pg(self):
        context = self._test_bulk_insert("postgresql", False)
        context.assert_(
            "INSERT INTO ins_table (id, v1, v2) "
            "VALUES (%(id)s, %(v1)s, %(v2)s)"
        )

    def test_bulk_insert_pg_single(self):
        context = self._test_bulk_insert_single("postgresql", False)
        context.assert_(
            "INSERT INTO ins_table (id, v1, v2) "
            "VALUES (%(id)s, %(v1)s, %(v2)s)"
        )

    def test_bulk_insert_pg_single_as_sql(self):
        context = self._test_bulk_insert_single("postgresql", True)
        context.assert_(
            "INSERT INTO ins_table (id, v1, v2) "
            "VALUES (1, 'row v1', 'row v5')"
        )

    def test_bulk_insert_pg_single_big_t_as_sql(self):
        context = self._test_bulk_insert_single_bigt("postgresql", True)
        context.assert_(
            "INSERT INTO ins_table (id, v1, v2) "
            "VALUES (1, 'row v1', 'row v5')"
        )

    def test_bulk_insert_mssql(self):
        context = self._test_bulk_insert("mssql", False)
        context.assert_(
            "INSERT INTO ins_table (id, v1, v2) VALUES (:id, :v1, :v2)"
        )

    def test_bulk_insert_inline_literal_as_sql(self):
        """Values wrapped in ``op.inline_literal()`` render inline in offline
        mode even for a column whose type is a bare ``TypeEngine`` subclass.
        """
        context = op_fixture("postgresql", True)

        class MyType(TypeEngine):
            pass

        t1 = table("t", column("id", Integer), column("data", MyType()))

        op.bulk_insert(
            t1,
            [
                {"id": 1, "data": op.inline_literal("d1")},
                {"id": 2, "data": op.inline_literal("d2")},
            ],
        )
        context.assert_(
            "INSERT INTO t (id, data) VALUES (1, 'd1')",
            "INSERT INTO t (id, data) VALUES (2, 'd2')",
        )

    def test_bulk_insert_as_sql(self):
        context = self._test_bulk_insert("default", True)
        context.assert_(
            "INSERT INTO ins_table (id, v1, v2) "
            "VALUES (1, 'row v1', 'row v5')",
            "INSERT INTO ins_table (id, v1, v2) "
            "VALUES (2, 'row v2', 'row v6')",
            "INSERT INTO ins_table (id, v1, v2) "
            "VALUES (3, 'row v3', 'row v7')",
            "INSERT INTO ins_table (id, v1, v2) "
            "VALUES (4, 'row v4', 'row v8')",
        )

    def test_bulk_insert_as_sql_pg(self):
        context = self._test_bulk_insert("postgresql", True)
        context.assert_(
            "INSERT INTO ins_table (id, v1, v2) "
            "VALUES (1, 'row v1', 'row v5')",
            "INSERT INTO ins_table (id, v1, v2) "
            "VALUES (2, 'row v2', 'row v6')",
            "INSERT INTO ins_table (id, v1, v2) "
            "VALUES (3, 'row v3', 'row v7')",
            "INSERT INTO ins_table (id, v1, v2) "
            "VALUES (4, 'row v4', 'row v8')",
        )

    def test_bulk_insert_as_sql_mssql(self):
        context = self._test_bulk_insert("mssql", True)
        # SQL server requires IDENTITY_INSERT
        # TODO: figure out if this is safe to enable for a table that
        # doesn't have an IDENTITY column
        context.assert_(
            "SET IDENTITY_INSERT ins_table ON",
            "GO",
            "INSERT INTO ins_table (id, v1, v2) "
            "VALUES (1, 'row v1', 'row v5')",
            "GO",
            "INSERT INTO ins_table (id, v1, v2) "
            "VALUES (2, 'row v2', 'row v6')",
            "GO",
            "INSERT INTO ins_table (id, v1, v2) "
            "VALUES (3, 'row v3', 'row v7')",
            "GO",
            "INSERT INTO ins_table (id, v1, v2) "
            "VALUES (4, 'row v4', 'row v8')",
            "GO",
            "SET IDENTITY_INSERT ins_table OFF",
            "GO",
        )

    def test_bulk_insert_from_new_table(self):
        """The ``Table`` returned by ``op.create_table()`` can be passed
        straight to ``op.bulk_insert()``."""
        context = op_fixture("postgresql", True)
        t1 = op.create_table(
            "ins_table",
            Column("id", Integer),
            Column("v1", String()),
            Column("v2", String()),
        )
        op.bulk_insert(
            t1,
            [
                {"id": 1, "v1": "row v1", "v2": "row v5"},
                {"id": 2, "v1": "row v2", "v2": "row v6"},
            ],
        )
        context.assert_(
            "CREATE TABLE ins_table (id INTEGER, v1 VARCHAR, v2 VARCHAR)",
            "INSERT INTO ins_table (id, v1, v2) VALUES "
            "(1, 'row v1', 'row v5')",
            "INSERT INTO ins_table (id, v1, v2) VALUES "
            "(2, 'row v2', 'row v6')",
        )

    def test_invalid_format(self):
        """A non-list argument, or a list whose items are not dictionaries,
        raises TypeError."""
        # The fixture call still has to happen to set up ``op``'s context;
        # only the table is needed afterwards.
        _, t1 = self._table_fixture("sqlite", False)
        assert_raises_message(
            TypeError, "List expected", op.bulk_insert, t1, {"id": 5}
        )

        assert_raises_message(
            TypeError,
            "List of dictionaries expected",
            op.bulk_insert,
            t1,
            [(5,)],
        )
233
234
class RoundTripTest(TestBase):
    """Run ``bulk_insert()`` for real against a SQLite connection.

    ``setUp`` creates the ``foo`` table in a committed transaction and then
    opens ``self.trans``, inside which each test runs.  ``tearDown`` rolls
    that transaction back and drops the fixture tables.
    """

    __only_on__ = "sqlite"

    def setUp(self):
        self.conn = config.db.connect()
        # Commit the DDL separately so the table outlives the per-test
        # rollback performed in tearDown.
        with self.conn.begin():
            self.conn.execute(
                text(
                    """
                create table foo(
                    id integer primary key,
                    data varchar(50),
                    x integer
                )
            """
                )
            )
        context = MigrationContext.configure(self.conn)
        self.op = op.Operations(context)
        self.t1 = table("foo", column("id"), column("data"), column("x"))

        self.trans = self.conn.begin()

    def tearDown(self):
        # Always release the connection, even if the rollback or the
        # cleanup DDL fails.
        try:
            self.trans.rollback()
            with self.conn.begin():
                self.conn.execute(text("drop table foo"))
                # NOTE(review): under the pysqlite driver CREATE TABLE is not
                # covered by the test transaction, so ins_table (created by
                # test_bulk_insert_from_new_table) may survive the rollback.
                self.conn.execute(text("drop table if exists ins_table"))
        finally:
            self.conn.close()

    def test_single_insert_round_trip(self):
        self.op.bulk_insert(self.t1, [{"data": "d1", "x": "x1"}])

        eq_(
            self.conn.execute(text("select id, data, x from foo")).fetchall(),
            [(1, "d1", "x1")],
        )

    def test_bulk_insert_round_trip(self):
        self.op.bulk_insert(
            self.t1,
            [
                {"data": "d1", "x": "x1"},
                {"data": "d2", "x": "x2"},
                {"data": "d3", "x": "x3"},
            ],
        )

        # Explicit ordering: SQL gives no row-order guarantee without it.
        eq_(
            self.conn.execute(
                text("select id, data, x from foo order by id")
            ).fetchall(),
            [(1, "d1", "x1"), (2, "d2", "x2"), (3, "d3", "x3")],
        )

    def test_bulk_insert_inline_literal(self):
        """``inline_literal()`` values round-trip when ``multiinsert`` is off.
        """

        class MyType(TypeEngine):
            pass

        t1 = table("foo", column("id", Integer), column("data", MyType()))

        self.op.bulk_insert(
            t1,
            [
                {"id": 1, "data": self.op.inline_literal("d1")},
                {"id": 2, "data": self.op.inline_literal("d2")},
            ],
            multiinsert=False,
        )

        eq_(
            self.conn.execute(
                text("select id, data from foo order by id")
            ).fetchall(),
            [(1, "d1"), (2, "d2")],
        )

    def test_bulk_insert_from_new_table(self):
        """Rows can be inserted into the ``Table`` returned by
        ``create_table()``."""
        t1 = self.op.create_table(
            "ins_table",
            Column("id", Integer),
            Column("v1", String()),
            Column("v2", String()),
        )
        self.op.bulk_insert(
            t1,
            [
                {"id": 1, "v1": "row v1", "v2": "row v5"},
                {"id": 2, "v1": "row v2", "v2": "row v6"},
            ],
        )
        eq_(
            self.conn.execute(
                text("select id, v1, v2 from ins_table order by id")
            ).fetchall(),
            [(1, "row v1", "row v5"), (2, "row v2", "row v6")],
        )
327