1# pysqlite2/test/userfunctions.py: tests for user-defined functions and 2# aggregates. 3# 4# Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de> 5# 6# This file is part of pysqlite. 7# 8# This software is provided 'as-is', without any express or implied 9# warranty. In no event will the authors be held liable for any damages 10# arising from the use of this software. 11# 12# Permission is granted to anyone to use this software for any purpose, 13# including commercial applications, and to alter it and redistribute it 14# freely, subject to the following restrictions: 15# 16# 1. The origin of this software must not be misrepresented; you must not 17# claim that you wrote the original software. If you use this software 18# in a product, an acknowledgment in the product documentation would be 19# appreciated but is not required. 20# 2. Altered source versions must be plainly marked as such, and must not be 21# misrepresented as being the original software. 22# 3. This notice may not be removed or altered from any source distribution. 
import unittest
import unittest.mock
import sqlite3 as sqlite


# --- Scalar functions taking no arguments ----------------------------------
# Each returns a fixed value so the tests can check how a Python return
# value of that type comes back out of SQLite.

def func_returntext():
    return "foo"


def func_returnunicode():
    return "bar"


def func_returnint():
    return 42


def func_returnfloat():
    return 3.14


def func_returnnull():
    return None


def func_returnblob():
    return b"blob"


def func_returnlonglong():
    # Smallest value that no longer fits in a signed 32-bit integer.
    return 2 ** 31


def func_raiseexception():
    # Deliberately raises ZeroDivisionError.
    5/0


# --- Predicates taking one argument -----------------------------------------
# Each reports whether SQLite handed the function a value of the expected
# Python type.

def func_isstring(v):
    return str is type(v)


def func_isint(v):
    # Exact type check: bool, although an int subclass, does not qualify.
    return int is type(v)


def func_isfloat(v):
    return float is type(v)


def func_isnone(v):
    # NoneType has exactly one instance, so identity is equivalent to a
    # type check here.
    return v is None


def func_isblob(v):
    return isinstance(v, bytes) or isinstance(v, memoryview)


def func_islonglong(v):
    if not isinstance(v, int):
        return False
    return v >= 2 ** 31


def func(*args):
    # Variadic function: returns the number of arguments it received.
    return len(args)


# --- Aggregate classes ---------------------------------------------------------

class AggrNoStep:
    """Aggregate that deliberately lacks a step() method."""

    def __init__(self):
        pass

    def finalize(self):
        return 1


class AggrNoFinalize:
    """Aggregate that deliberately lacks a finalize() method."""

    def __init__(self):
        pass

    def step(self, x):
        pass


class AggrExceptionInInit:
    """Aggregate whose constructor raises ZeroDivisionError."""

    def __init__(self):
        5/0

    def step(self, x):
        pass

    def finalize(self):
        pass


class AggrExceptionInStep:
    """Aggregate whose step() raises ZeroDivisionError."""

    def __init__(self):
        pass

    def step(self, x):
        5/0

    def finalize(self):
        return 42


class AggrExceptionInFinalize:
    """Aggregate whose finalize() raises ZeroDivisionError."""

    def __init__(self):
        pass

    def step(self, x):
        pass

    def finalize(self):
        5/0


# Maps the type names passed from SQL to the Python type each value is
# expected to arrive as.
_TYPES_BY_NAME = {
    "str": str,
    "int": int,
    "float": float,
    "None": type(None),
    "blob": bytes,
}


class AggrCheckType:
    """Aggregate yielding 1 if the last stepped value has the named type, else 0."""

    def __init__(self):
        self.val = None

    def step(self, whichType, val):
        expected = _TYPES_BY_NAME[whichType]
        self.val = 1 if type(val) is expected else 0

    def finalize(self):
        return self.val


class AggrCheckTypes:
    """Aggregate counting how many stepped values have the named type."""

    def __init__(self):
        self.val = 0

    def step(self, whichType, *vals):
        # The lookup stays inside the generator, so an unknown type name is
        # only an error when there is at least one value to check.
        self.val += sum(1 for v in vals if type(v) is _TYPES_BY_NAME[whichType])

    def finalize(self):
        return self.val
class AggrSum:
    """Aggregate that adds up every value it is stepped with."""

    def __init__(self):
        # Start from a float so the total is a float even for integer input.
        self.val = 0.0

    def step(self, val):
        self.val += val

    def finalize(self):
        return self.val


class FunctionTests(unittest.TestCase):
    """Tests for scalar user-defined functions (Connection.create_function)."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")

        self.con.create_function("returntext", 0, func_returntext)
        self.con.create_function("returnunicode", 0, func_returnunicode)
        self.con.create_function("returnint", 0, func_returnint)
        self.con.create_function("returnfloat", 0, func_returnfloat)
        self.con.create_function("returnnull", 0, func_returnnull)
        self.con.create_function("returnblob", 0, func_returnblob)
        self.con.create_function("returnlonglong", 0, func_returnlonglong)
        self.con.create_function("raiseexception", 0, func_raiseexception)

        self.con.create_function("isstring", 1, func_isstring)
        self.con.create_function("isint", 1, func_isint)
        self.con.create_function("isfloat", 1, func_isfloat)
        self.con.create_function("isnone", 1, func_isnone)
        self.con.create_function("isblob", 1, func_isblob)
        self.con.create_function("islonglong", 1, func_islonglong)
        self.con.create_function("spam", -1, func)
        self.con.execute("create table test(t text)")

    def tearDown(self):
        self.con.close()

    def CheckFuncErrorOnCreate(self):
        # An argument count below -1 is rejected.
        with self.assertRaises(sqlite.OperationalError):
            self.con.create_function("bla", -100, lambda x: 2*x)

    def CheckFuncRefCount(self):
        # Register a temporary callable that nothing else references. The
        # connection must keep it alive on its own for the SQL call to work.
        def getfunc():
            def f():
                return 1
            return f
        self.con.create_function("reftest", 0, getfunc())
        cur = self.con.cursor()
        cur.execute("select reftest()")
        self.assertEqual(cur.fetchone()[0], 1)

    def CheckFuncReturnText(self):
        cur = self.con.cursor()
        cur.execute("select returntext()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), str)
        self.assertEqual(val, "foo")

    def CheckFuncReturnUnicode(self):
        cur = self.con.cursor()
        cur.execute("select returnunicode()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), str)
        self.assertEqual(val, "bar")

    def CheckFuncReturnInt(self):
        cur = self.con.cursor()
        cur.execute("select returnint()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), int)
        self.assertEqual(val, 42)

    def CheckFuncReturnFloat(self):
        cur = self.con.cursor()
        cur.execute("select returnfloat()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), float)
        # Same tolerance as before: 3.139 <= val <= 3.141.
        self.assertAlmostEqual(val, 3.14, delta=0.001)

    def CheckFuncReturnNull(self):
        cur = self.con.cursor()
        cur.execute("select returnnull()")
        val = cur.fetchone()[0]
        self.assertIsNone(val)

    def CheckFuncReturnBlob(self):
        cur = self.con.cursor()
        cur.execute("select returnblob()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), bytes)
        self.assertEqual(val, b"blob")

    def CheckFuncReturnLongLong(self):
        cur = self.con.cursor()
        cur.execute("select returnlonglong()")
        val = cur.fetchone()[0]
        self.assertEqual(val, 1<<31)

    def CheckFuncException(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select raiseexception()")
            cur.fetchone()
        self.assertEqual(str(cm.exception), 'user-defined function raised exception')

    def CheckParamString(self):
        cur = self.con.cursor()
        cur.execute("select isstring(?)", ("foo",))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamInt(self):
        cur = self.con.cursor()
        cur.execute("select isint(?)", (42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamFloat(self):
        cur = self.con.cursor()
        cur.execute("select isfloat(?)", (3.14,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamNone(self):
        cur = self.con.cursor()
        cur.execute("select isnone(?)", (None,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamBlob(self):
        cur = self.con.cursor()
        cur.execute("select isblob(?)", (memoryview(b"blob"),))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamLongLong(self):
        cur = self.con.cursor()
        cur.execute("select islonglong(?)", (1<<42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAnyArguments(self):
        cur = self.con.cursor()
        cur.execute("select spam(?, ?)", (1, 2))
        val = cur.fetchone()[0]
        self.assertEqual(val, 2)

    # Regarding deterministic functions:
    #
    # Between SQLite 3.8.3 and 3.15.0, the deterministic flag was only used to
    # optimize inner loops, so for those versions we can only test whether
    # SQLite factored out a call. From 3.15.0 onward, deterministic functions
    # are permitted in the WHERE clause of a partial index, which lets us test
    # based on syntax rather than on the query optimizer's behavior.
    @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher")
    def CheckFuncNonDeterministic(self):
        mock = unittest.mock.Mock(return_value=None)
        self.con.create_function("nondeterministic", 0, mock, deterministic=False)
        if sqlite.sqlite_version_info < (3, 15, 0):
            self.con.execute("select nondeterministic() = nondeterministic()")
            self.assertEqual(mock.call_count, 2)
        else:
            with self.assertRaises(sqlite.OperationalError):
                self.con.execute("create index t on test(t) where nondeterministic() is not null")

    @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher")
    def CheckFuncDeterministic(self):
        mock = unittest.mock.Mock(return_value=None)
        self.con.create_function("deterministic", 0, mock, deterministic=True)
        if sqlite.sqlite_version_info < (3, 15, 0):
            self.con.execute("select deterministic() = deterministic()")
            self.assertEqual(mock.call_count, 1)
        else:
            try:
                self.con.execute("create index t on test(t) where deterministic() is not null")
            except sqlite.OperationalError:
                self.fail("Unexpected failure while creating partial index")

    @unittest.skipIf(sqlite.sqlite_version_info >= (3, 8, 3), "SQLite < 3.8.3 needed")
    def CheckFuncDeterministicNotSupported(self):
        with self.assertRaises(sqlite.NotSupportedError):
            self.con.create_function("deterministic", 0, int, deterministic=True)

    def CheckFuncDeterministicKeywordOnly(self):
        # 'deterministic' must be passed by keyword.
        with self.assertRaises(TypeError):
            self.con.create_function("deterministic", 0, int, True)


class AggregateTests(unittest.TestCase):
    """Tests for user-defined aggregates (Connection.create_aggregate)."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        cur = self.con.cursor()
        cur.execute("""
            create table test(
                t text,
                i integer,
                f float,
                n,
                b blob
                )
            """)
        cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
            ("foo", 5, 3.14, None, memoryview(b"blob"),))

        self.con.create_aggregate("nostep", 1, AggrNoStep)
        self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
        self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
        self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
        self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
        self.con.create_aggregate("checkType", 2, AggrCheckType)
        self.con.create_aggregate("checkTypes", -1, AggrCheckTypes)
        self.con.create_aggregate("mysum", 1, AggrSum)

    def tearDown(self):
        # Release the in-memory database opened in setUp.
        self.con.close()

    def CheckAggrErrorOnCreate(self):
        # An argument count below -1 is rejected for aggregates as well.
        with self.assertRaises(sqlite.OperationalError):
            self.con.create_aggregate("bla", -100, AggrSum)

    def CheckAggrNoStep(self):
        cur = self.con.cursor()
        with self.assertRaises(AttributeError) as cm:
            cur.execute("select nostep(t) from test")
        self.assertEqual(str(cm.exception), "'AggrNoStep' object has no attribute 'step'")

    def CheckAggrNoFinalize(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select nofinalize(t) from test")
            cur.fetchone()
        self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")

    def CheckAggrExceptionInInit(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select excInit(t) from test")
            cur.fetchone()
        self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error")

    def CheckAggrExceptionInStep(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select excStep(t) from test")
            cur.fetchone()
        self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error")

    def CheckAggrExceptionInFinalize(self):
        cur = self.con.cursor()
        with self.assertRaises(sqlite.OperationalError) as cm:
            cur.execute("select excFinalize(t) from test")
            cur.fetchone()
        self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")

    def CheckAggrCheckParamStr(self):
        cur = self.con.cursor()
        cur.execute("select checkType('str', ?)", ("foo",))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamInt(self):
        cur = self.con.cursor()
        cur.execute("select checkType('int', ?)", (42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamsInt(self):
        cur = self.con.cursor()
        cur.execute("select checkTypes('int', ?, ?)", (42, 24))
        val = cur.fetchone()[0]
        self.assertEqual(val, 2)

    def CheckAggrCheckParamFloat(self):
        cur = self.con.cursor()
        cur.execute("select checkType('float', ?)", (3.14,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamNone(self):
        cur = self.con.cursor()
        cur.execute("select checkType('None', ?)", (None,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamBlob(self):
        cur = self.con.cursor()
        cur.execute("select checkType('blob', ?)", (memoryview(b"blob"),))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckAggrSum(self):
        cur = self.con.cursor()
        cur.execute("delete from test")
        cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
        cur.execute("select mysum(i) from test")
        val = cur.fetchone()[0]
        self.assertEqual(val, 60)


class AuthorizerTests(unittest.TestCase):
    """Authorizer that denies access to t2 and to column c2 by returning SQLITE_DENY."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        if action != sqlite.SQLITE_SELECT:
            return sqlite.SQLITE_DENY
        if arg2 == 'c2' or arg1 == 't2':
            return sqlite.SQLITE_DENY
        return sqlite.SQLITE_OK

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        self.con.executescript("""
            create table t1 (c1, c2);
            create table t2 (c1, c2);
            insert into t1 (c1, c2) values (1, 2);
            insert into t2 (c1, c2) values (4, 5);
            """)

        # For our security test:
        self.con.execute("select c2 from t2")

        self.con.set_authorizer(self.authorizer_cb)

    def tearDown(self):
        # Release the in-memory database opened in setUp.
        self.con.close()

    def test_table_access(self):
        with self.assertRaises(sqlite.DatabaseError) as cm:
            self.con.execute("select * from t2")
        self.assertIn('prohibited', str(cm.exception))

    def test_column_access(self):
        with self.assertRaises(sqlite.DatabaseError) as cm:
            self.con.execute("select c2 from t1")
        self.assertIn('prohibited', str(cm.exception))


class AuthorizerRaiseExceptionTests(AuthorizerTests):
    """Same checks, with an authorizer that raises instead of returning DENY."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        if action != sqlite.SQLITE_SELECT:
            raise ValueError
        if arg2 == 'c2' or arg1 == 't2':
            raise ValueError
        return sqlite.SQLITE_OK


class AuthorizerIllegalTypeTests(AuthorizerTests):
    """Same checks, with an authorizer that returns a non-integer."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        if action != sqlite.SQLITE_SELECT:
            return 0.0
        if arg2 == 'c2' or arg1 == 't2':
            return 0.0
        return sqlite.SQLITE_OK


class AuthorizerLargeIntegerTests(AuthorizerTests):
    """Same checks, with an authorizer that returns an out-of-range integer."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        if action != sqlite.SQLITE_SELECT:
            return 2**32
        if arg2 == 'c2' or arg1 == 't2':
            return 2**32
        return sqlite.SQLITE_OK


def _load_suite(test_case, prefix="test"):
    """Return a suite of the methods of *test_case* whose names start with *prefix*.

    This replaces ``unittest.makeSuite``, which is deprecated since Python 3.11
    and removed in 3.13.
    """
    loader = unittest.TestLoader()
    loader.testMethodPrefix = prefix
    return loader.loadTestsFromTestCase(test_case)


def suite():
    """Build the suite for this module.

    FunctionTests and AggregateTests use the legacy "Check" method prefix.
    """
    return unittest.TestSuite((
        _load_suite(FunctionTests, "Check"),
        _load_suite(AggregateTests, "Check"),
        _load_suite(AuthorizerTests),
        _load_suite(AuthorizerRaiseExceptionTests),
        _load_suite(AuthorizerIllegalTypeTests),
        _load_suite(AuthorizerLargeIntegerTests),
    ))


def test():
    runner = unittest.TextTestRunner()
    runner.run(suite())


if __name__ == "__main__":
    test()