1#!/usr/bin/env python
2
3# test_copy.py - unit test for COPY support
4#
5# Copyright (C) 2010-2019 Daniele Varrazzo  <daniele.varrazzo@gmail.com>
6# Copyright (C) 2020-2021 The Psycopg Team
7#
8# psycopg2 is free software: you can redistribute it and/or modify it
9# under the terms of the GNU Lesser General Public License as published
10# by the Free Software Foundation, either version 3 of the License, or
11# (at your option) any later version.
12#
13# In addition, as a special exception, the copyright holders give
14# permission to link this program with the OpenSSL library (or with
15# modified versions of OpenSSL that use the same license as OpenSSL),
16# and distribute linked combinations including the two.
17#
18# You must obey the GNU Lesser General Public License in all respects for
19# all of the code used other than OpenSSL.
20#
21# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
22# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
23# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
24# License for more details.
25
26import io
27import sys
28import string
29import unittest
30from .testutils import ConnectingTestCase, skip_before_postgres, slow, StringIO
31from .testutils import skip_if_crdb
32from itertools import cycle
33from subprocess import Popen, PIPE
34
35import psycopg2
36import psycopg2.extensions
37from .testutils import skip_copy_if_green, TextIOBase
38from .testconfig import dsn
39
40
class MinimalRead(TextIOBase):
    """Read-only wrapper that forwards ``read()``/``readline()`` to a file.

    Only these two methods are overridden, so passing an instance to
    ``copy_from()`` checks that COPY FROM works with a source object that
    offers just this small reading interface.
    """

    def __init__(self, f):
        self.f = f

    def read(self, size):
        """Return up to *size* characters taken from the wrapped file."""
        chunk = self.f.read(size)
        return chunk

    def readline(self):
        """Return the next line of the wrapped file."""
        line = self.f.readline()
        return line
51
52
class MinimalWrite(TextIOBase):
    """Write-only wrapper that forwards ``write()`` to a file.

    Only ``write()`` is overridden, so passing an instance to ``copy_to()``
    checks that COPY TO works with a destination offering just that method.
    """

    def __init__(self, f):
        self.f = f

    def write(self, data):
        """Write *data* to the wrapped file, returning whatever it returns."""
        target = self.f
        return target.write(data)
60
61
@skip_copy_if_green
class CopyTests(ConnectingTestCase):
    """Tests for the cursor COPY methods.

    Covers ``copy_from()``, ``copy_to()`` and ``copy_expert()``. Every test
    starts with an empty temporary ``tcopy (id serial, data text)`` table.
    """

    def setUp(self):
        ConnectingTestCase.setUp(self)
        self._create_temp_table()

    def _create_temp_table(self):
        """Create the temporary ``tcopy`` table used by the tests.

        The ``skip_if_crdb("copy", ...)`` call guards this for CockroachDB.
        """
        skip_if_crdb("copy", self.conn)
        curs = self.conn.cursor()
        curs.execute('''
            CREATE TEMPORARY TABLE tcopy (
              id serial PRIMARY KEY,
              data text
            )''')

    @staticmethod
    def _latin1_text():
        """Return the printable latin1 characters as a str.

        That is code points 32-126 and 160-255; the control characters in
        0-31 and 127-159 are left out. Shared by the encoding tests so they
        all exercise the same character set.
        """
        return bytes(list(range(32, 127))
            + list(range(160, 256))).decode('latin1')

    @slow
    def test_copy_from(self):
        with self.conn.cursor() as curs:
            self._copy_from(curs, nrecs=1024, srec=10 * 1024, copykw={})

    @slow
    def test_copy_from_insane_size(self):
        # Trying to trigger a "would block" error
        with self.conn.cursor() as curs:
            self._copy_from(curs, nrecs=10 * 1024, srec=10 * 1024,
                copykw={'size': 20 * 1024 * 1024})

    def test_copy_from_cols(self):
        """copy_from() with a column list leaves the other columns NULL."""
        curs = self.conn.cursor()
        f = StringIO()
        for i in range(10):
            f.write(f"{i}\n")

        f.seek(0)
        curs.copy_from(MinimalRead(f), "tcopy", columns=['id'])

        curs.execute("select * from tcopy order by id")
        self.assertEqual([(i, None) for i in range(10)], curs.fetchall())

    def test_copy_from_cols_err(self):
        """An error raised while iterating *columns* reaches the caller."""
        curs = self.conn.cursor()
        f = StringIO()
        for i in range(10):
            f.write(f"{i}\n")

        f.seek(0)

        def cols():
            # Fails as soon as copy_from() starts iterating the columns.
            raise ZeroDivisionError()
            yield 'id'

        self.assertRaises(ZeroDivisionError,
            curs.copy_from, MinimalRead(f), "tcopy", columns=cols())

    @slow
    def test_copy_to(self):
        with self.conn.cursor() as curs:
            self._copy_from(curs, nrecs=1024, srec=10 * 1024, copykw={})
            self._copy_to(curs, srec=10 * 1024)

    def test_copy_text(self):
        """copy_to() into a text file returns latin1 data, backslash-escaped."""
        self.conn.set_client_encoding('latin1')
        self._create_temp_table()  # the above call closed the xn

        abin = self._latin1_text()
        about = abin.replace('\\', '\\\\')

        curs = self.conn.cursor()
        curs.execute('insert into tcopy values (%s, %s)',
            (42, abin))

        f = io.StringIO()
        curs.copy_to(f, 'tcopy', columns=('data',))
        f.seek(0)
        self.assertEqual(f.readline().rstrip(), about)

    def test_copy_bytes(self):
        """copy_to() into a binary file returns latin1-encoded bytes."""
        self.conn.set_client_encoding('latin1')
        self._create_temp_table()  # the above call closed the xn

        abin = self._latin1_text()
        about = abin.replace('\\', '\\\\').encode('latin1')

        curs = self.conn.cursor()
        curs.execute('insert into tcopy values (%s, %s)',
            (42, abin))

        f = io.BytesIO()
        curs.copy_to(f, 'tcopy', columns=('data',))
        f.seek(0)
        self.assertEqual(f.readline().rstrip(), about)

    def test_copy_expert_textiobase(self):
        """copy_expert() works with text files in both directions.

        Also checks that the ``size`` argument is what the source's
        ``read()`` is called with.
        """
        self.conn.set_client_encoding('latin1')
        self._create_temp_table()  # the above call closed the xn

        abin = self._latin1_text()
        about = abin.replace('\\', '\\\\')

        f = io.StringIO()
        f.write(about)
        f.seek(0)

        curs = self.conn.cursor()
        psycopg2.extensions.register_type(
            psycopg2.extensions.UNICODE, curs)

        curs.copy_expert('COPY tcopy (data) FROM STDIN', f)
        curs.execute("select data from tcopy;")
        self.assertEqual(curs.fetchone()[0], abin)

        f = io.StringIO()
        curs.copy_expert('COPY tcopy (data) TO STDOUT', f)
        f.seek(0)
        self.assertEqual(f.readline().rstrip(), about)

        # same tests with setting size
        f = io.StringIO()
        f.write(about)
        f.seek(0)
        exp_size = 123
        # hack here to leave file as is, only check size when reading
        real_read = f.read

        def read(_size):
            self.assertEqual(_size, exp_size)
            return real_read(_size)

        f.read = read
        curs.copy_expert('COPY tcopy (data) FROM STDIN', f, size=exp_size)
        curs.execute("select data from tcopy;")
        self.assertEqual(curs.fetchone()[0], abin)

    def _copy_from(self, curs, nrecs, srec, copykw):
        """COPY *nrecs* generated rows into ``tcopy`` and verify them.

        Row ``i`` has id ``i`` and a data value made of one ascii letter,
        cycling through ``string.ascii_letters``, repeated *srec* times.
        *copykw* is passed through to ``copy_from()``.
        """
        f = StringIO()
        for i, letter in zip(range(nrecs), cycle(string.ascii_letters)):
            f.write(f"{i}\t{letter * srec}\n")

        f.seek(0)
        curs.copy_from(MinimalRead(f), "tcopy", **copykw)

        curs.execute("select count(*) from tcopy")
        self.assertEqual(nrecs, curs.fetchone()[0])

        nletters = len(string.ascii_letters)
        curs.execute("select data from tcopy where id < %s order by id",
                (nletters,))
        rows = curs.fetchall()
        # Guard against the check below passing vacuously on no rows.
        self.assertEqual(len(rows), min(nrecs, nletters))
        for i, (data,) in enumerate(rows):
            self.assertEqual(data, string.ascii_letters[i] * srec)

    def _copy_to(self, curs, srec):
        """COPY ``tcopy`` out and verify the rows written by _copy_from()."""
        f = StringIO()
        curs.copy_to(MinimalWrite(f), "tcopy")

        f.seek(0)
        ntests = 0
        for line in f:
            n, s = line.split()
            if int(n) < len(string.ascii_letters):
                self.assertEqual(s, string.ascii_letters[int(n)] * srec)
                ntests += 1

        self.assertEqual(ntests, len(string.ascii_letters))

    def test_copy_expert_file_refcount(self):
        """copy_expert() rejects a non file-like object with TypeError."""
        class Whatever:
            pass

        f = Whatever()
        curs = self.conn.cursor()
        self.assertRaises(TypeError,
            curs.copy_expert, 'COPY tcopy (data) FROM STDIN', f)

    def test_copy_no_column_limit(self):
        """COPY works with 200 columns that have long names."""
        cols = [f"c{i:050}" for i in range(200)]

        curs = self.conn.cursor()
        curs.execute('CREATE TEMPORARY TABLE manycols (%s)' % ',\n'.join(
            ["%s int" % c for c in cols]))
        curs.execute("INSERT INTO manycols DEFAULT VALUES")

        f = StringIO()
        curs.copy_to(f, "manycols", columns=cols)
        f.seek(0)
        self.assertEqual(f.read().split(), ['\\N'] * len(cols))

        f.seek(0)
        curs.copy_from(f, "manycols", columns=cols)
        curs.execute("select count(*) from manycols;")
        self.assertEqual(curs.fetchone()[0], 2)

    def test_copy_funny_names(self):
        """COPY works with table and column names that are SQL keywords."""
        cols = ["select", "insert", "group"]

        curs = self.conn.cursor()
        curs.execute('CREATE TEMPORARY TABLE "select" (%s)' % ',\n'.join(
            ['"%s" int' % c for c in cols]))
        curs.execute('INSERT INTO "select" DEFAULT VALUES')

        f = StringIO()
        curs.copy_to(f, "select", columns=cols)
        f.seek(0)
        self.assertEqual(f.read().split(), ['\\N'] * len(cols))

        f.seek(0)
        curs.copy_from(f, "select", columns=cols)
        curs.execute('select count(*) from "select";')
        self.assertEqual(curs.fetchone()[0], 2)

    @skip_before_postgres(8, 2)     # they don't send the count
    def test_copy_rowcount(self):
        """rowcount holds the number of rows each COPY transferred."""
        curs = self.conn.cursor()

        curs.copy_from(StringIO('aaa\nbbb\nccc\n'), 'tcopy', columns=['data'])
        self.assertEqual(curs.rowcount, 3)

        curs.copy_expert(
            "copy tcopy (data) from stdin",
            StringIO('ddd\neee\n'))
        self.assertEqual(curs.rowcount, 2)

        curs.copy_to(StringIO(), "tcopy")
        self.assertEqual(curs.rowcount, 5)

        curs.execute("insert into tcopy (data) values ('fff')")
        curs.copy_expert("copy tcopy to stdout", StringIO())
        self.assertEqual(curs.rowcount, 6)

    def test_copy_rowcount_error(self):
        """rowcount is -1 after a COPY that fails."""
        curs = self.conn.cursor()

        curs.execute("insert into tcopy (data) values ('fff')")
        self.assertEqual(curs.rowcount, 1)

        self.assertRaises(psycopg2.DataError,
            curs.copy_from, StringIO('aaa\nbbb\nccc\n'), 'tcopy')
        self.assertEqual(curs.rowcount, -1)

    def test_copy_query(self):
        """cursor.query holds the COPY statement that was sent."""
        curs = self.conn.cursor()

        curs.copy_from(StringIO('aaa\nbbb\nccc\n'), 'tcopy', columns=['data'])
        self.assertIn(b"copy ", curs.query.lower())
        self.assertIn(b" from stdin", curs.query.lower())

        curs.copy_expert(
            "copy tcopy (data) from stdin",
            StringIO('ddd\neee\n'))
        self.assertIn(b"copy ", curs.query.lower())
        self.assertIn(b" from stdin", curs.query.lower())

        curs.copy_to(StringIO(), "tcopy")
        self.assertIn(b"copy ", curs.query.lower())
        self.assertIn(b" to stdout", curs.query.lower())

        curs.execute("insert into tcopy (data) values ('fff')")
        curs.copy_expert("copy tcopy to stdout", StringIO())
        self.assertIn(b"copy ", curs.query.lower())
        self.assertIn(b" to stdout", curs.query.lower())

    def _run_copy_segfault_script(self, copy_sql, **popen_kwargs):
        """Run *copy_sql* via ``execute()`` in a child interpreter.

        Regression helper for issue #219. The child swallows any
        ProgrammingError from the COPY, closes the connection and exits. The
        test passes only if the child's return code is 0, i.e. it did not
        crash. *popen_kwargs* are forwarded to ``Popen``.
        """
        script = f"""import psycopg2
conn = psycopg2.connect({dsn!r})
curs = conn.cursor()
curs.execute("create table copy_segf (id int)")
try:
    curs.execute({copy_sql!r})
except psycopg2.ProgrammingError:
    pass
conn.close()
"""

        proc = Popen([sys.executable, '-c', script], **popen_kwargs)
        proc.communicate()
        self.assertEqual(0, proc.returncode)

    @slow
    def test_copy_from_segfault(self):
        # issue #219
        self._run_copy_segfault_script("copy copy_segf from stdin")

    @slow
    def test_copy_to_segfault(self):
        # issue #219
        self._run_copy_segfault_script(
            "copy copy_segf to stdout", stdout=PIPE)

    def test_copy_from_propagate_error(self):
        """An error from the source's read() is reported by copy_from()."""
        class BrokenRead(TextIOBase):
            def read(self, size):
                return 1 / 0

            def readline(self):
                return 1 / 0

        curs = self.conn.cursor()
        # The ZeroDivisionError itself does not come back out of copy_from(),
        # but the error that is raised must mention it. Fail if nothing is
        # raised at all.
        with self.assertRaises(Exception) as cm:
            curs.copy_from(BrokenRead(), "tcopy")
        self.assertIn('ZeroDivisionError', str(cm.exception))

    def test_copy_to_propagate_error(self):
        """An error from the destination's write() propagates unchanged."""
        class BrokenWrite(TextIOBase):
            def write(self, data):
                return 1 / 0

        curs = self.conn.cursor()
        curs.execute("insert into tcopy values (10, 'hi')")
        self.assertRaises(ZeroDivisionError,
            curs.copy_to, BrokenWrite(), "tcopy")
397
398
def test_suite():
    """Build a TestSuite holding every test case defined in this module."""
    loader = unittest.TestLoader()
    suite = loader.loadTestsFromName(__name__)
    return suite
401
402
if '__main__' == __name__:
    # Run this module's tests when it is executed directly.
    unittest.main()
405