#!/usr/bin/env python

# test_copy.py - unit test for COPY support
#
# Copyright (C) 2010-2019 Daniele Varrazzo <daniele.varrazzo@gmail.com>
# Copyright (C) 2020-2021 The Psycopg Team
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import io
import sys
import string
import unittest
from .testutils import ConnectingTestCase, skip_before_postgres, slow, StringIO
from .testutils import skip_if_crdb
from itertools import cycle
from subprocess import Popen, PIPE

import psycopg2
import psycopg2.extensions
from .testutils import skip_copy_if_green, TextIOBase
from .testconfig import dsn


class MinimalRead(TextIOBase):
    """A file wrapper exposing the minimal interface to copy from.

    Only ``read()`` and ``readline()`` are forwarded to the wrapped file *f*.
    """
    def __init__(self, f):
        self.f = f

    def read(self, size):
        return self.f.read(size)

    def readline(self):
        return self.f.readline()


class MinimalWrite(TextIOBase):
    """A file wrapper exposing the minimal interface to copy to.

    Only ``write()`` is forwarded to the wrapped file *f*.
    """
    def __init__(self, f):
        self.f = f

    def write(self, data):
        return self.f.write(data)


@skip_copy_if_green
class CopyTests(ConnectingTestCase):
    """Tests for ``cursor.copy_from()``, ``copy_to()`` and ``copy_expert()``."""

    def setUp(self):
        ConnectingTestCase.setUp(self)
        self._create_temp_table()

    def _create_temp_table(self):
        """Create the temporary ``tcopy (id serial, data text)`` table.

        ``skip_if_crdb`` is called first; presumably it skips the test on
        CockroachDB, which lacks the COPY support exercised here.
        """
        skip_if_crdb("copy", self.conn)
        curs = self.conn.cursor()
        curs.execute('''
            CREATE TEMPORARY TABLE tcopy (
              id serial PRIMARY KEY,
              data text
            )''')

    @slow
    def test_copy_from(self):
        curs = self.conn.cursor()
        try:
            self._copy_from(curs, nrecs=1024, srec=10 * 1024, copykw={})
        finally:
            curs.close()

    @slow
    def test_copy_from_insane_size(self):
        # Trying to trigger a "would block" error
        curs = self.conn.cursor()
        try:
            self._copy_from(curs, nrecs=10 * 1024, srec=10 * 1024,
                copykw={'size': 20 * 1024 * 1024})
        finally:
            curs.close()

    def test_copy_from_cols(self):
        """copy_from() with an explicit column list fills only those columns."""
        curs = self.conn.cursor()
        f = StringIO()
        for i in range(10):
            f.write(f"{i}\n")

        f.seek(0)
        curs.copy_from(MinimalRead(f), "tcopy", columns=['id'])

        curs.execute("select * from tcopy order by id")
        self.assertEqual([(i, None) for i in range(10)], curs.fetchall())

    def test_copy_from_cols_err(self):
        """An exception raised while iterating *columns* reaches the caller."""
        curs = self.conn.cursor()
        f = StringIO()
        for i in range(10):
            f.write(f"{i}\n")

        f.seek(0)

        def cols():
            raise ZeroDivisionError()
            yield 'id'  # unreachable: only makes this function a generator

        self.assertRaises(ZeroDivisionError,
            curs.copy_from, MinimalRead(f), "tcopy", columns=cols())

    @slow
    def test_copy_to(self):
        curs = self.conn.cursor()
        try:
            self._copy_from(curs, nrecs=1024, srec=10 * 1024, copykw={})
            self._copy_to(curs, srec=10 * 1024)
        finally:
            curs.close()

    def test_copy_text(self):
        """copy_to() into a text file round-trips latin1 characters."""
        self.conn.set_client_encoding('latin1')
        self._create_temp_table()  # the above call closed the xn

        abin = bytes(list(range(32, 127))
            + list(range(160, 256))).decode('latin1')
        # COPY text format escapes backslashes by doubling them.
        about = abin.replace('\\', '\\\\')

        curs = self.conn.cursor()
        curs.execute('insert into tcopy values (%s, %s)',
            (42, abin))

        f = io.StringIO()
        curs.copy_to(f, 'tcopy', columns=('data',))
        f.seek(0)
        self.assertEqual(f.readline().rstrip(), about)

    def test_copy_bytes(self):
        """copy_to() into a binary file writes latin1-encoded bytes."""
        self.conn.set_client_encoding('latin1')
        self._create_temp_table()  # the above call closed the xn

        abin = bytes(list(range(32, 127))
            + list(range(160, 255))).decode('latin1')
        about = abin.replace('\\', '\\\\').encode('latin1')

        curs = self.conn.cursor()
        curs.execute('insert into tcopy values (%s, %s)',
            (42, abin))

        f = io.BytesIO()
        curs.copy_to(f, 'tcopy', columns=('data',))
        f.seek(0)
        self.assertEqual(f.readline().rstrip(), about)

    def test_copy_expert_textiobase(self):
        """copy_expert() works in both directions with io.StringIO.

        Also check that the *size* argument is passed to ``read()``.
        """
        self.conn.set_client_encoding('latin1')
        self._create_temp_table()  # the above call closed the xn

        abin = bytes(list(range(32, 127))
            + list(range(160, 256))).decode('latin1')
        about = abin.replace('\\', '\\\\')

        f = io.StringIO()
        f.write(about)
        f.seek(0)

        curs = self.conn.cursor()
        psycopg2.extensions.register_type(
            psycopg2.extensions.UNICODE, curs)

        curs.copy_expert('COPY tcopy (data) FROM STDIN', f)
        curs.execute("select data from tcopy;")
        self.assertEqual(curs.fetchone()[0], abin)

        f = io.StringIO()
        curs.copy_expert('COPY tcopy (data) TO STDOUT', f)
        f.seek(0)
        self.assertEqual(f.readline().rstrip(), about)

        # same tests with setting size
        f = io.StringIO()
        f.write(about)
        f.seek(0)
        exp_size = 123
        # hack here to leave file as is, only check size when reading
        real_read = f.read

        def read(_size):
            self.assertEqual(_size, exp_size)
            return real_read(_size)

        f.read = read
        curs.copy_expert('COPY tcopy (data) FROM STDIN', f, size=exp_size)
        curs.execute("select data from tcopy;")
        self.assertEqual(curs.fetchone()[0], abin)

    def _copy_from(self, curs, nrecs, srec, copykw):
        """Load *nrecs* rows into ``tcopy`` through copy_from() and verify them.

        Row ``i`` has ``data`` set to one ASCII letter, cycled over
        ``string.ascii_letters``, repeated *srec* times. *copykw* is passed
        as extra keyword arguments to copy_from().
        """
        f = StringIO()
        for i, c in zip(range(nrecs), cycle(string.ascii_letters)):
            data = c * srec
            f.write(f"{i}\t{data}\n")

        f.seek(0)
        curs.copy_from(MinimalRead(f), "tcopy", **copykw)

        curs.execute("select count(*) from tcopy")
        self.assertEqual(nrecs, curs.fetchone()[0])

        curs.execute("select data from tcopy where id < %s order by id",
            (len(string.ascii_letters),))
        for i, (data,) in enumerate(curs):
            self.assertEqual(data, string.ascii_letters[i] * srec)

    def _copy_to(self, curs, srec):
        """Dump ``tcopy`` through copy_to() and check the rows from _copy_from().

        Only the first ``len(string.ascii_letters)`` ids are checked.
        """
        f = StringIO()
        curs.copy_to(MinimalWrite(f), "tcopy")

        f.seek(0)
        ntests = 0
        for line in f:
            n, s = line.split()
            if int(n) < len(string.ascii_letters):
                self.assertEqual(s, string.ascii_letters[int(n)] * srec)
                ntests += 1

        self.assertEqual(ntests, len(string.ascii_letters))

    def test_copy_expert_file_refcount(self):
        """copy_expert() raises TypeError for a non file-like object."""
        class Whatever:
            pass

        f = Whatever()
        curs = self.conn.cursor()
        self.assertRaises(TypeError,
            curs.copy_expert, 'COPY tcopy (data) FROM STDIN', f)

    def test_copy_no_column_limit(self):
        """COPY works with 200 columns with long names."""
        cols = [f"c{i:050}" for i in range(200)]

        curs = self.conn.cursor()
        curs.execute('CREATE TEMPORARY TABLE manycols (%s)' % ',\n'.join(
            ["%s int" % c for c in cols]))
        curs.execute("INSERT INTO manycols DEFAULT VALUES")

        f = StringIO()
        curs.copy_to(f, "manycols", columns=cols)
        f.seek(0)
        self.assertEqual(f.read().split(), ['\\N'] * len(cols))

        f.seek(0)
        curs.copy_from(f, "manycols", columns=cols)
        curs.execute("select count(*) from manycols;")
        self.assertEqual(curs.fetchone()[0], 2)

    def test_copy_funny_names(self):
        """Table and column names that are SQL keywords get quoted."""
        cols = ["select", "insert", "group"]

        curs = self.conn.cursor()
        curs.execute('CREATE TEMPORARY TABLE "select" (%s)' % ',\n'.join(
            ['"%s" int' % c for c in cols]))
        curs.execute('INSERT INTO "select" DEFAULT VALUES')

        f = StringIO()
        curs.copy_to(f, "select", columns=cols)
        f.seek(0)
        self.assertEqual(f.read().split(), ['\\N'] * len(cols))

        f.seek(0)
        curs.copy_from(f, "select", columns=cols)
        curs.execute('select count(*) from "select";')
        self.assertEqual(curs.fetchone()[0], 2)

    @skip_before_postgres(8, 2)  # they don't send the count
    def test_copy_rowcount(self):
        """rowcount reflects the number of rows copied in or out."""
        curs = self.conn.cursor()

        curs.copy_from(StringIO('aaa\nbbb\nccc\n'), 'tcopy', columns=['data'])
        self.assertEqual(curs.rowcount, 3)

        curs.copy_expert(
            "copy tcopy (data) from stdin",
            StringIO('ddd\neee\n'))
        self.assertEqual(curs.rowcount, 2)

        curs.copy_to(StringIO(), "tcopy")
        self.assertEqual(curs.rowcount, 5)

        curs.execute("insert into tcopy (data) values ('fff')")
        curs.copy_expert("copy tcopy to stdout", StringIO())
        self.assertEqual(curs.rowcount, 6)

    def test_copy_rowcount_error(self):
        """A failed copy_from() resets rowcount to -1."""
        curs = self.conn.cursor()

        curs.execute("insert into tcopy (data) values ('fff')")
        self.assertEqual(curs.rowcount, 1)

        self.assertRaises(psycopg2.DataError,
            curs.copy_from, StringIO('aaa\nbbb\nccc\n'), 'tcopy')
        self.assertEqual(curs.rowcount, -1)

    def test_copy_query(self):
        """cursor.query holds the COPY statement that was sent."""
        curs = self.conn.cursor()

        # assertIn replaces the assert_ alias, which was removed in Python 3.12.
        curs.copy_from(StringIO('aaa\nbbb\nccc\n'), 'tcopy', columns=['data'])
        self.assertIn(b"copy ", curs.query.lower())
        self.assertIn(b" from stdin", curs.query.lower())

        curs.copy_expert(
            "copy tcopy (data) from stdin",
            StringIO('ddd\neee\n'))
        self.assertIn(b"copy ", curs.query.lower())
        self.assertIn(b" from stdin", curs.query.lower())

        curs.copy_to(StringIO(), "tcopy")
        self.assertIn(b"copy ", curs.query.lower())
        self.assertIn(b" to stdout", curs.query.lower())

        curs.execute("insert into tcopy (data) values ('fff')")
        curs.copy_expert("copy tcopy to stdout", StringIO())
        self.assertIn(b"copy ", curs.query.lower())
        self.assertIn(b" to stdout", curs.query.lower())

    @slow
    def test_copy_from_segfault(self):
        # issue #219
        # Run in a subprocess so that a crash does not take down the test run.
        script = f"""import psycopg2
conn = psycopg2.connect({dsn!r})
curs = conn.cursor()
curs.execute("create table copy_segf (id int)")
try:
    curs.execute("copy copy_segf from stdin")
except psycopg2.ProgrammingError:
    pass
conn.close()
"""

        proc = Popen([sys.executable, '-c', script])
        proc.communicate()
        self.assertEqual(0, proc.returncode)

    @slow
    def test_copy_to_segfault(self):
        # issue #219
        # Run in a subprocess so that a crash does not take down the test run.
        script = f"""import psycopg2
conn = psycopg2.connect({dsn!r})
curs = conn.cursor()
curs.execute("create table copy_segf (id int)")
try:
    curs.execute("copy copy_segf to stdout")
except psycopg2.ProgrammingError:
    pass
conn.close()
"""

        proc = Popen([sys.executable, '-c', script], stdout=PIPE)
        proc.communicate()
        self.assertEqual(0, proc.returncode)

    def test_copy_from_propagate_error(self):
        """An error raised by the source file's read() is reported."""
        class BrokenRead(TextIOBase):
            def read(self, size):
                return 1 / 0

            def readline(self):
                return 1 / 0

        curs = self.conn.cursor()
        # The original ZeroDivisionError is not re-raised as such, but its
        # name must appear in the message of the exception that is raised.
        # Assert that an exception is raised at all: previously the check
        # lived in an ``except`` clause and the test passed silently if
        # copy_from() did not raise.
        with self.assertRaises(Exception) as cm:
            curs.copy_from(BrokenRead(), "tcopy")
        self.assertIn('ZeroDivisionError', str(cm.exception))

    def test_copy_to_propagate_error(self):
        """An error raised by the target file's write() propagates unchanged."""
        class BrokenWrite(TextIOBase):
            def write(self, data):
                return 1 / 0

        curs = self.conn.cursor()
        curs.execute("insert into tcopy values (10, 'hi')")
        self.assertRaises(ZeroDivisionError,
            curs.copy_to, BrokenWrite(), "tcopy")


def test_suite():
    return unittest.TestLoader().loadTestsFromName(__name__)


if __name__ == "__main__":
    unittest.main()