#!/usr/bin/env python3

##
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for avro.io: datum validation, binary round-trips, schema promotion,
and reader-side default values.

Most test cases here are parameterized: the ``*TestCase`` classes below are
instantiated many times by ``load_tests`` — one instance per schema/datum
pair — rather than discovered via ``test_``-prefixed methods.
"""

import binascii
import datetime
import decimal
import io
import itertools
import json
import unittest
import warnings
from typing import BinaryIO, Collection, Dict, List, Optional, Tuple, Union, cast

# Explicitly import avro.errors: TestMisc references avro.errors.* directly.
# Previously this worked only because `import avro.io` transitively imported
# the submodule, which is fragile and implicit.
import avro.errors
import avro.io
import avro.schema
import avro.timezones
from avro.utils import TypedDict


class DefaultValueTestCaseType(TypedDict):
    # Single-field record shape used by DefaultValueTestCase; "H" is the
    # reader-side field whose default is under test.
    H: object


# (schema-json, datum) pairs where the datum is valid for the schema.
# Used both for validation tests and for encode/decode round-trip tests.
SCHEMAS_TO_VALIDATE = tuple(
    (json.dumps(schema), datum)
    for schema, datum in (
        ("null", None),
        ("boolean", True),
        ("string", "adsfasdf09809dsf-=adsf"),
        ("bytes", b"12345abcd"),
        ("int", 1234),
        ("long", 1234),
        ("float", 1234.0),
        ("double", 1234.0),
        ({"type": "fixed", "name": "Test", "size": 1}, b"B"),
        (
            {
                "type": "fixed",
                "logicalType": "decimal",
                "name": "Test",
                "size": 8,
                "precision": 5,
                "scale": 4,
            },
            decimal.Decimal("3.1415"),
        ),
        (
            {
                "type": "fixed",
                "logicalType": "decimal",
                "name": "Test",
                "size": 8,
                "precision": 5,
                "scale": 4,
            },
            decimal.Decimal("-3.1415"),
        ),
        (
            {"type": "bytes", "logicalType": "decimal", "precision": 5, "scale": 4},
            decimal.Decimal("3.1415"),
        ),
        (
            {"type": "bytes", "logicalType": "decimal", "precision": 5, "scale": 4},
            decimal.Decimal("-3.1415"),
        ),
        ({"type": "enum", "name": "Test", "symbols": ["A", "B"]}, "B"),
        ({"type": "array", "items": "long"}, [1, 3, 2]),
        ({"type": "map", "values": "long"}, {"a": 1, "b": 3, "c": 2}),
        (["string", "null", "long"], None),
        ({"type": "int", "logicalType": "date"}, datetime.date(2000, 1, 1)),
        (
            {"type": "int", "logicalType": "time-millis"},
            datetime.time(23, 59, 59, 999000),
        ),
        ({"type": "int", "logicalType": "time-millis"}, datetime.time(0, 0, 0, 000000)),
        (
            {"type": "long", "logicalType": "time-micros"},
            datetime.time(23, 59, 59, 999999),
        ),
        (
            {"type": "long", "logicalType": "time-micros"},
            datetime.time(0, 0, 0, 000000),
        ),
        (
            {"type": "long", "logicalType": "timestamp-millis"},
            datetime.datetime(1000, 1, 1, 0, 0, 0, 000000, tzinfo=avro.timezones.utc),
        ),
        (
            {"type": "long", "logicalType": "timestamp-millis"},
            datetime.datetime(9999, 12, 31, 23, 59, 59, 999000, tzinfo=avro.timezones.utc),
        ),
        (
            {"type": "long", "logicalType": "timestamp-millis"},
            datetime.datetime(2000, 1, 18, 2, 2, 1, 100000, tzinfo=avro.timezones.tst),
        ),
        (
            {"type": "long", "logicalType": "timestamp-micros"},
            datetime.datetime(1000, 1, 1, 0, 0, 0, 000000, tzinfo=avro.timezones.utc),
        ),
        (
            {"type": "long", "logicalType": "timestamp-micros"},
            datetime.datetime(9999, 12, 31, 23, 59, 59, 999999, tzinfo=avro.timezones.utc),
        ),
        (
            {"type": "long", "logicalType": "timestamp-micros"},
            datetime.datetime(2000, 1, 18, 2, 2, 1, 123499, tzinfo=avro.timezones.tst),
        ),
        (
            {"type": "string", "logicalType": "uuid"},
            "a4818e1c-8e59-11eb-8dcd-0242ac130003",
        ),  # UUID1
        (
            {"type": "string", "logicalType": "uuid"},
            "570feebe-2bbc-4937-98df-285944e1dbbd",
        ),  # UUID4
        # Unknown or mismatched logical types fall back to the base type.
        ({"type": "string", "logicalType": "unknown-logical-type"}, "12345abcd"),
        ({"type": "string", "logicalType": "timestamp-millis"}, "12345abcd"),
        (
            {
                "type": "record",
                "name": "Test",
                "fields": [{"name": "f", "type": "long"}],
            },
            {"f": 5},
        ),
        (
            {
                "type": "record",
                "name": "Lisp",
                "fields": [
                    {
                        "name": "value",
                        "type": [
                            "null",
                            "string",
                            {
                                "type": "record",
                                "name": "Cons",
                                "fields": [
                                    {"name": "car", "type": "Lisp"},
                                    {"name": "cdr", "type": "Lisp"},
                                ],
                            },
                        ],
                    }
                ],
            },
            {"value": {"car": {"value": "head"}, "cdr": {"value": None}}},
        ),
    )
)

# (value, expected zigzag/varint hex bytes) pairs for int/long encoding.
BINARY_ENCODINGS = (
    (0, b"00"),
    (-1, b"01"),
    (1, b"02"),
    (-2, b"03"),
    (2, b"04"),
    (-64, b"7f"),
    (64, b"80 01"),
    (8192, b"80 80 01"),
    (-8193, b"81 80 01"),
)

# (field-type, default-value) pairs for reader-side default resolution.
DEFAULT_VALUE_EXAMPLES = (
    ("null", None),
    ("boolean", True),
    ("string", "foo"),
    ("bytes", "\xff\xff"),
    ("int", 5),
    ("long", 5),
    ("float", 1.1),
    ("double", 1.1),
    ({"type": "fixed", "name": "F", "size": 2}, "\xff\xff"),
    ({"type": "enum", "name": "F", "symbols": ["FOO", "BAR"]}, "FOO"),
    ({"type": "array", "items": "int"}, [1, 2, 3]),
    ({"type": "map", "values": "int"}, {"a": 1, "b": 2}),
    (["int", "null"], 5),
    (
        {"type": "record", "name": "F", "fields": [{"name": "A", "type": "int"}]},
        {"A": 5},
    ),
)

# A seven-field record schema/datum used by projection, field-order and
# default-value tests.
LONG_RECORD_SCHEMA = avro.schema.parse(
    json.dumps(
        {
            "type": "record",
            "name": "Test",
            "fields": [
                {"name": "A", "type": "int"},
                {"name": "B", "type": "int"},
                {"name": "C", "type": "int"},
                {"name": "D", "type": "int"},
                {"name": "E", "type": "int"},
                {"name": "F", "type": "int"},
                {"name": "G", "type": "int"},
            ],
        }
    )
)

LONG_RECORD_DATUM = {"A": 1, "B": 2, "C": 3, "D": 4, "E": 5, "F": 6, "G": 7}


def avro_hexlify(reader: BinaryIO) -> bytes:
    """Return the hex value, as a string, of a binary-encoded int or long.

    Reads one byte at a time until the varint continuation bit (0x80) is
    clear, hexlifying each byte and joining with single spaces.
    """
    b = []
    current_byte = reader.read(1)
    b.append(binascii.hexlify(current_byte))
    while (ord(current_byte) & 0x80) != 0:
        current_byte = reader.read(1)
        b.append(binascii.hexlify(current_byte))
    return b" ".join(b)


def write_datum(datum: object, writers_schema: avro.schema.Schema) -> Tuple[io.BytesIO, avro.io.BinaryEncoder, avro.io.DatumWriter]:
    """Encode ``datum`` with ``writers_schema`` into a fresh BytesIO buffer.

    Returns the buffer, encoder, and writer so callers can keep writing
    or inspect the raw bytes.
    """
    writer = io.BytesIO()
    encoder = avro.io.BinaryEncoder(writer)
    datum_writer = avro.io.DatumWriter(writers_schema)
    datum_writer.write(datum, encoder)
    return writer, encoder, datum_writer


def read_datum(buffer: io.BytesIO, writers_schema: avro.schema.Schema, readers_schema: Optional[avro.schema.Schema] = None) -> object:
    """Decode one datum from ``buffer``, optionally resolving to ``readers_schema``."""
    reader = io.BytesIO(buffer.getvalue())
    decoder = avro.io.BinaryDecoder(reader)
    datum_reader = avro.io.DatumReader(writers_schema, readers_schema)
    return datum_reader.read(decoder)


class IoValidateTestCase(unittest.TestCase):
    def __init__(self, test_schema: str, test_datum: object) -> None:
        """Ignore the normal signature for unittest.TestCase because we are generating
        many test cases from this one class. This is safe as long as the autoloader
        ignores this class. The autoloader will ignore this class as long as it has
        no methods starting with `test_`.
        """
        super().__init__("io_valid")
        self.test_schema = avro.schema.parse(test_schema)
        self.test_datum = test_datum
        # Never hide repeated warnings when running this test case.
        warnings.simplefilter("always")

    def io_valid(self) -> None:
        """
        In these cases, the provided data should be valid with the given schema.
        """
        with warnings.catch_warnings(record=True) as actual_warnings:
            self.assertTrue(
                avro.io.validate(self.test_schema, self.test_datum),
                f"{self.test_datum} did not validate in the schema {self.test_schema}",
            )


class RoundTripTestCase(unittest.TestCase):
    def __init__(self, test_schema: str, test_datum: object) -> None:
        """Ignore the normal signature for unittest.TestCase because we are generating
        many test cases from this one class. This is safe as long as the autoloader
        ignores this class. The autoloader will ignore this class as long as it has
        no methods starting with `test_`.
        """
        super().__init__("io_round_trip")
        self.test_schema = avro.schema.parse(test_schema)
        self.test_datum = test_datum
        # Never hide repeated warnings when running this test case.
        warnings.simplefilter("always")

    def io_round_trip(self) -> None:
        """
        A datum should be the same after being encoded and then decoded.
        """
        with warnings.catch_warnings(record=True) as actual_warnings:
            writer, encoder, datum_writer = write_datum(self.test_datum, self.test_schema)
            round_trip_datum = read_datum(writer, self.test_schema)
            expected: object
            round_trip: object
            if isinstance(round_trip_datum, decimal.Decimal):
                # Compare decimals via string form to sidestep context effects.
                expected, round_trip, message = (
                    str(self.test_datum),
                    round_trip_datum.to_eng_string(),
                    "Decimal datum changed value after encode and decode",
                )
            elif isinstance(round_trip_datum, datetime.datetime):
                # Timestamps are normalized to UTC on decode.
                expected, round_trip, message = (
                    cast(datetime.datetime, self.test_datum).astimezone(tz=avro.timezones.utc),
                    round_trip_datum,
                    "DateTime datum changed value after encode and decode",
                )
            else:
                expected, round_trip, message = (
                    self.test_datum,
                    round_trip_datum,
                    "Datum changed value after encode and decode",
                )
            self.assertEqual(expected, round_trip, message)


class BinaryEncodingTestCase(unittest.TestCase):
    def __init__(self, skip: bool, test_type: str, test_datum: object, test_hex: bytes) -> None:
        """Ignore the normal signature for unittest.TestCase because we are generating
        many test cases from this one class. This is safe as long as the autoloader
        ignores this class. The autoloader will ignore this class as long as it has
        no methods starting with `test_`.
        """
        super().__init__(f"check_{'skip' if skip else 'binary'}_encoding")
        self.writers_schema = avro.schema.parse(f'"{test_type}"')
        self.test_datum = test_datum
        self.test_hex = test_hex
        # Never hide repeated warnings when running this test case.
        warnings.simplefilter("always")

    def check_binary_encoding(self) -> None:
        """The binary encoding of the datum should match the expected hex bytes."""
        with warnings.catch_warnings(record=True) as actual_warnings:
            writer, encoder, datum_writer = write_datum(self.test_datum, self.writers_schema)
            writer.seek(0)
            hex_val = avro_hexlify(writer)
            self.assertEqual(
                self.test_hex,
                hex_val,
                "Binary encoding did not match expected hex representation.",
            )

    def check_skip_encoding(self) -> None:
        """Skipping one encoded value should leave the decoder at the next value."""
        VALUE_TO_READ = 6253
        with warnings.catch_warnings(record=True) as actual_warnings:
            # write the value to skip and a known value
            writer, encoder, datum_writer = write_datum(self.test_datum, self.writers_schema)
            datum_writer.write(VALUE_TO_READ, encoder)

            # skip the value
            reader = io.BytesIO(writer.getvalue())
            decoder = avro.io.BinaryDecoder(reader)
            decoder.skip_long()

            # read data from string buffer
            datum_reader = avro.io.DatumReader(self.writers_schema)
            read_value = datum_reader.read(decoder)

            self.assertEqual(
                read_value,
                VALUE_TO_READ,
                "Unexpected value after skipping a binary encoded value.",
            )


class SchemaPromotionTestCase(unittest.TestCase):
    def __init__(self, write_type: str, read_type: str) -> None:
        """Ignore the normal signature for unittest.TestCase because we are generating
        many test cases from this one class. This is safe as long as the autoloader
        ignores this class. The autoloader will ignore this class as long as it has
        no methods starting with `test_`.
        """
        super().__init__("check_schema_promotion")
        self.writers_schema = avro.schema.parse(f'"{write_type}"')
        self.readers_schema = avro.schema.parse(f'"{read_type}"')
        # Never hide repeated warnings when running this test case.
        warnings.simplefilter("always")

    def check_schema_promotion(self) -> None:
        """Test schema promotion"""
        # note that checking writers_schema.type in read_data
        # allows us to handle promotion correctly
        DATUM_TO_WRITE = 219
        with warnings.catch_warnings(record=True) as actual_warnings:
            writer, enc, dw = write_datum(DATUM_TO_WRITE, self.writers_schema)
            datum_read = read_datum(writer, self.writers_schema, self.readers_schema)
            self.assertEqual(
                datum_read,
                DATUM_TO_WRITE,
                f"Datum changed between schema that were supposed to promote: writer: {self.writers_schema} reader: {self.readers_schema}.",
            )


class DefaultValueTestCase(unittest.TestCase):
    def __init__(self, field_type: Collection[str], default: Union[Dict[str, int], List[int], None, float, str]) -> None:
        """Ignore the normal signature for unittest.TestCase because we are generating
        many test cases from this one class. This is safe as long as the autoloader
        ignores this class. The autoloader will ignore this class as long as it has
        no methods starting with `test_`.
        """
        super().__init__("check_default_value")
        self.field_type = field_type
        self.default = default
        # Never hide repeated warnings when running this test case.
        warnings.simplefilter("always")

    def check_default_value(self) -> None:
        """A reader field absent from the writer's record gets its default value."""
        datum_read: DefaultValueTestCaseType
        with warnings.catch_warnings(record=True) as actual_warnings:
            datum_to_read = cast(DefaultValueTestCaseType, {"H": self.default})
            readers_schema = avro.schema.parse(
                json.dumps(
                    {
                        "type": "record",
                        "name": "Test",
                        "fields": [
                            {
                                "name": "H",
                                "type": self.field_type,
                                "default": self.default,
                            }
                        ],
                    }
                )
            )
            writer, _, _ = write_datum(LONG_RECORD_DATUM, LONG_RECORD_SCHEMA)
            datum_read_ = cast(DefaultValueTestCaseType, read_datum(writer, LONG_RECORD_SCHEMA, readers_schema))
            # bytes defaults decode to bytes; normalize to str for comparison
            # with the str default in DEFAULT_VALUE_EXAMPLES.
            datum_read = {"H": cast(bytes, datum_read_["H"]).decode()} if isinstance(datum_read_["H"], bytes) else datum_read_
            self.assertEqual(datum_to_read, datum_read)


class TestMisc(unittest.TestCase):
    def test_decimal_bytes_small_scale(self) -> None:
        """Avro should raise an AvroTypeException when attempting to write a decimal with a larger exponent than the schema's scale."""
        datum = decimal.Decimal("3.1415")
        _, _, exp = datum.as_tuple()
        # One less than the datum's own scale, so the write must fail.
        scale = -1 * exp - 1
        schema = avro.schema.parse(
            json.dumps(
                {
                    "type": "bytes",
                    "logicalType": "decimal",
                    "precision": 5,
                    "scale": scale,
                }
            )
        )
        self.assertRaises(avro.errors.AvroOutOfScaleException, write_datum, datum, schema)

    def test_decimal_fixed_small_scale(self) -> None:
        """Avro should raise an AvroTypeException when attempting to write a decimal with a larger exponent than the schema's scale."""
        datum = decimal.Decimal("3.1415")
        _, _, exp = datum.as_tuple()
        # One less than the datum's own scale, so the write must fail.
        scale = -1 * exp - 1
        schema = avro.schema.parse(
            json.dumps(
                {
                    "type": "fixed",
                    "logicalType": "decimal",
                    "name": "Test",
                    "size": 8,
                    "precision": 5,
                    "scale": scale,
                }
            )
        )
        self.assertRaises(avro.errors.AvroOutOfScaleException, write_datum, datum, schema)

    def test_unknown_symbol(self) -> None:
        """Reading an enum symbol absent from the reader's schema should fail resolution."""
        datum_to_write = "FOO"
        writers_schema = avro.schema.parse(json.dumps({"type": "enum", "name": "Test", "symbols": ["FOO", "BAR"]}))
        readers_schema = avro.schema.parse(json.dumps({"type": "enum", "name": "Test", "symbols": ["BAR", "BAZ"]}))

        writer, encoder, datum_writer = write_datum(datum_to_write, writers_schema)
        reader = io.BytesIO(writer.getvalue())
        decoder = avro.io.BinaryDecoder(reader)
        datum_reader = avro.io.DatumReader(writers_schema, readers_schema)
        self.assertRaises(avro.errors.SchemaResolutionException, datum_reader.read, decoder)

    def test_no_default_value(self) -> None:
        """A reader field missing from the writer's record with no default should fail resolution."""
        writers_schema = LONG_RECORD_SCHEMA
        datum_to_write = LONG_RECORD_DATUM

        readers_schema = avro.schema.parse(
            json.dumps(
                {
                    "type": "record",
                    "name": "Test",
                    "fields": [{"name": "H", "type": "int"}],
                }
            )
        )

        writer, encoder, datum_writer = write_datum(datum_to_write, writers_schema)
        reader = io.BytesIO(writer.getvalue())
        decoder = avro.io.BinaryDecoder(reader)
        datum_reader = avro.io.DatumReader(writers_schema, readers_schema)
        self.assertRaises(avro.errors.SchemaResolutionException, datum_reader.read, decoder)

    def test_projection(self) -> None:
        """A reader schema with a subset of the writer's fields reads just those fields."""
        writers_schema = LONG_RECORD_SCHEMA
        datum_to_write = LONG_RECORD_DATUM

        readers_schema = avro.schema.parse(
            json.dumps(
                {
                    "type": "record",
                    "name": "Test",
                    "fields": [
                        {"name": "E", "type": "int"},
                        {"name": "F", "type": "int"},
                    ],
                }
            )
        )
        datum_to_read = {"E": 5, "F": 6}

        writer, encoder, datum_writer = write_datum(datum_to_write, writers_schema)
        datum_read = read_datum(writer, writers_schema, readers_schema)
        self.assertEqual(datum_to_read, datum_read)

    def test_field_order(self) -> None:
        """Reader field order may differ from writer field order."""
        writers_schema = LONG_RECORD_SCHEMA
        datum_to_write = LONG_RECORD_DATUM

        readers_schema = avro.schema.parse(
            json.dumps(
                {
                    "type": "record",
                    "name": "Test",
                    "fields": [
                        {"name": "F", "type": "int"},
                        {"name": "E", "type": "int"},
                    ],
                }
            )
        )
        datum_to_read = {"E": 5, "F": 6}

        writer, encoder, datum_writer = write_datum(datum_to_write, writers_schema)
        datum_read = read_datum(writer, writers_schema, readers_schema)
        self.assertEqual(datum_to_read, datum_read)

    def test_type_exception_int(self) -> None:
        """Writing a non-int for an int field should raise with a precise message."""
        writers_schema = avro.schema.parse(
            json.dumps(
                {
                    "type": "record",
                    "name": "Test",
                    "fields": [
                        {"name": "F", "type": "int"},
                        {"name": "E", "type": "int"},
                    ],
                }
            )
        )
        datum_to_write = {"E": 5, "F": "Bad"}
        with self.assertRaises(avro.errors.AvroTypeException) as exc:
            write_datum(datum_to_write, writers_schema)
        assert str(exc.exception) == 'The datum "Bad" provided for "F" is not an example of the schema "int"'

    def test_type_exception_long(self) -> None:
        """Writing a float for a long field should raise with a precise message."""
        writers_schema = avro.schema.parse(json.dumps({"type": "record", "name": "Test", "fields": [{"name": "foo", "type": "long"}]}))
        datum_to_write = {"foo": 5.0}

        with self.assertRaises(avro.errors.AvroTypeException) as exc:
            write_datum(datum_to_write, writers_schema)
        assert str(exc.exception) == 'The datum "5.0" provided for "foo" is not an example of the schema "long"'

    def test_type_exception_record(self) -> None:
        """Writing a non-dict for a record should raise an AvroTypeException."""
        writers_schema = avro.schema.parse(json.dumps({"type": "record", "name": "Test", "fields": [{"name": "foo", "type": "long"}]}))
        datum_to_write = ("foo", 5.0)

        with self.assertRaisesRegex(avro.errors.AvroTypeException, r"The datum \".*\" provided for \".*\" is not an example of the schema [\s\S]*"):
            write_datum(datum_to_write, writers_schema)


def load_tests(loader: unittest.TestLoader, default_tests: None, pattern: None) -> unittest.TestSuite:
    """Generate test cases across many test schema."""
    suite = unittest.TestSuite()
    suite.addTests(loader.loadTestsFromTestCase(TestMisc))
    suite.addTests(IoValidateTestCase(schema_str, datum) for schema_str, datum in SCHEMAS_TO_VALIDATE)
    suite.addTests(RoundTripTestCase(schema_str, datum) for schema_str, datum in SCHEMAS_TO_VALIDATE)
    for skip in False, True:
        for type_ in "int", "long":
            suite.addTests(BinaryEncodingTestCase(skip, type_, datum, hex_) for datum, hex_ in BINARY_ENCODINGS)
    suite.addTests(
        SchemaPromotionTestCase(write_type, read_type) for write_type, read_type in itertools.combinations(("int", "long", "float", "double"), 2)
    )
    suite.addTests(DefaultValueTestCase(field_type, default) for field_type, default in DEFAULT_VALUE_EXAMPLES)
    return suite


if __name__ == "__main__":  # pragma: no coverage
    unittest.main()