#!/usr/bin/env python3

##
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import binascii
import datetime
import decimal
import io
import itertools
import json
import unittest
import warnings
from typing import BinaryIO, Collection, Dict, List, Optional, Tuple, Union, cast

import avro.errors
import avro.io
import avro.schema
import avro.timezones
from avro.utils import TypedDict


class DefaultValueTestCaseType(TypedDict):
    H: object


# Pairs of (schema JSON, datum) in which the datum is a valid example of the schema.
SCHEMAS_TO_VALIDATE = tuple(
    (json.dumps(schema), datum)
    for schema, datum in (
        ("null", None),
        ("boolean", True),
        ("string", "adsfasdf09809dsf-=adsf"),
        ("bytes", b"12345abcd"),
        ("int", 1234),
        ("long", 1234),
        ("float", 1234.0),
        ("double", 1234.0),
        ({"type": "fixed", "name": "Test", "size": 1}, b"B"),
        (
            {
                "type": "fixed",
                "logicalType": "decimal",
                "name": "Test",
                "size": 8,
                "precision": 5,
                "scale": 4,
            },
            decimal.Decimal("3.1415"),
        ),
        (
            {
                "type": "fixed",
                "logicalType": "decimal",
                "name": "Test",
                "size": 8,
                "precision": 5,
                "scale": 4,
            },
            decimal.Decimal("-3.1415"),
        ),
        (
            {"type": "bytes", "logicalType": "decimal", "precision": 5, "scale": 4},
            decimal.Decimal("3.1415"),
        ),
        (
            {"type": "bytes", "logicalType": "decimal", "precision": 5, "scale": 4},
            decimal.Decimal("-3.1415"),
        ),
        ({"type": "enum", "name": "Test", "symbols": ["A", "B"]}, "B"),
        ({"type": "array", "items": "long"}, [1, 3, 2]),
        ({"type": "map", "values": "long"}, {"a": 1, "b": 3, "c": 2}),
        (["string", "null", "long"], None),
        ({"type": "int", "logicalType": "date"}, datetime.date(2000, 1, 1)),
        (
            {"type": "int", "logicalType": "time-millis"},
            datetime.time(23, 59, 59, 999000),
        ),
        ({"type": "int", "logicalType": "time-millis"}, datetime.time(0, 0, 0, 000000)),
        (
            {"type": "long", "logicalType": "time-micros"},
            datetime.time(23, 59, 59, 999999),
        ),
        (
            {"type": "long", "logicalType": "time-micros"},
            datetime.time(0, 0, 0, 000000),
        ),
        (
            {"type": "long", "logicalType": "timestamp-millis"},
            datetime.datetime(1000, 1, 1, 0, 0, 0, 000000, tzinfo=avro.timezones.utc),
        ),
        (
            {"type": "long", "logicalType": "timestamp-millis"},
            datetime.datetime(9999, 12, 31, 23, 59, 59, 999000, tzinfo=avro.timezones.utc),
        ),
        (
            {"type": "long", "logicalType": "timestamp-millis"},
            datetime.datetime(2000, 1, 18, 2, 2, 1, 100000, tzinfo=avro.timezones.tst),
        ),
        (
            {"type": "long", "logicalType": "timestamp-micros"},
            datetime.datetime(1000, 1, 1, 0, 0, 0, 000000, tzinfo=avro.timezones.utc),
        ),
        (
            {"type": "long", "logicalType": "timestamp-micros"},
            datetime.datetime(9999, 12, 31, 23, 59, 59, 999999, tzinfo=avro.timezones.utc),
        ),
        (
            {"type": "long", "logicalType": "timestamp-micros"},
            datetime.datetime(2000, 1, 18, 2, 2, 1, 123499, tzinfo=avro.timezones.tst),
        ),
        (
            {"type": "string", "logicalType": "uuid"},
            "a4818e1c-8e59-11eb-8dcd-0242ac130003",
        ),  # UUID1
        (
            {"type": "string", "logicalType": "uuid"},
            "570feebe-2bbc-4937-98df-285944e1dbbd",
        ),  # UUID4
        ({"type": "string", "logicalType": "unknown-logical-type"}, "12345abcd"),
        ({"type": "string", "logicalType": "timestamp-millis"}, "12345abcd"),
        (
            {
                "type": "record",
                "name": "Test",
                "fields": [{"name": "f", "type": "long"}],
            },
            {"f": 5},
        ),
        (
            {
                "type": "record",
                "name": "Lisp",
                "fields": [
                    {
                        "name": "value",
                        "type": [
                            "null",
                            "string",
                            {
                                "type": "record",
                                "name": "Cons",
                                "fields": [
                                    {"name": "car", "type": "Lisp"},
                                    {"name": "cdr", "type": "Lisp"},
                                ],
                            },
                        ],
                    }
                ],
            },
            {"value": {"car": {"value": "head"}, "cdr": {"value": None}}},
        ),
    )
)

# Pairs of (integer value, expected zigzag/varint encoding as space-separated hex bytes).
BINARY_ENCODINGS = (
    (0, b"00"),
    (-1, b"01"),
    (1, b"02"),
    (-2, b"03"),
    (2, b"04"),
    (-64, b"7f"),
    (64, b"80 01"),
    (8192, b"80 80 01"),
    (-8193, b"81 80 01"),
)

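
# Illustrative sketch only (not used by the tests; the helper name `_zigzag_varint_hex`
# is hypothetical): Avro writes int and long values with zigzag coding followed by a
# little-endian base-128 varint, which is how the expected hex bytes in
# BINARY_ENCODINGS above are derived.
def _zigzag_varint_hex(value: int) -> bytes:
    n = (value << 1) ^ (value >> 63)  # zigzag: small magnitudes map to small unsigned ints (assumes 64-bit long range)
    out = bytearray()
    while True:
        if n > 0x7F:
            out.append((n & 0x7F) | 0x80)  # emit low 7 bits with the continuation bit set
            n >>= 7
        else:
            out.append(n)  # final byte: continuation bit clear
            break
    return b" ".join(b"%02x" % byte for byte in out)  # e.g. _zigzag_varint_hex(64) == b"80 01"

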
# Pairs of (field type, default value); defaults are given as they would appear in a
# schema document (JSON), which is why bytes and fixed defaults are strings here.
DEFAULT_VALUE_EXAMPLES = (
    ("null", None),
    ("boolean", True),
    ("string", "foo"),
    ("bytes", "\xff\xff"),
    ("int", 5),
    ("long", 5),
    ("float", 1.1),
    ("double", 1.1),
    ({"type": "fixed", "name": "F", "size": 2}, "\xff\xff"),
    ({"type": "enum", "name": "F", "symbols": ["FOO", "BAR"]}, "FOO"),
    ({"type": "array", "items": "int"}, [1, 2, 3]),
    ({"type": "map", "values": "int"}, {"a": 1, "b": 2}),
    (["int", "null"], 5),
    (
        {"type": "record", "name": "F", "fields": [{"name": "A", "type": "int"}]},
        {"A": 5},
    ),
)

LONG_RECORD_SCHEMA = avro.schema.parse(
    json.dumps(
        {
            "type": "record",
            "name": "Test",
            "fields": [
                {"name": "A", "type": "int"},
                {"name": "B", "type": "int"},
                {"name": "C", "type": "int"},
                {"name": "D", "type": "int"},
                {"name": "E", "type": "int"},
                {"name": "F", "type": "int"},
                {"name": "G", "type": "int"},
            ],
        }
    )
)

LONG_RECORD_DATUM = {"A": 1, "B": 2, "C": 3, "D": 4, "E": 5, "F": 6, "G": 7}


def avro_hexlify(reader: BinaryIO) -> bytes:
    """Return the hex value, as space-separated bytes, of a binary-encoded int or long."""
    b = []
    current_byte = reader.read(1)
    b.append(binascii.hexlify(current_byte))
    while (ord(current_byte) & 0x80) != 0:  # the high bit marks a continuation byte
        current_byte = reader.read(1)
        b.append(binascii.hexlify(current_byte))
    return b" ".join(b)


def write_datum(datum: object, writers_schema: avro.schema.Schema) -> Tuple[io.BytesIO, avro.io.BinaryEncoder, avro.io.DatumWriter]:
    writer = io.BytesIO()
    encoder = avro.io.BinaryEncoder(writer)
    datum_writer = avro.io.DatumWriter(writers_schema)
    datum_writer.write(datum, encoder)
    return writer, encoder, datum_writer


def read_datum(buffer: io.BytesIO, writers_schema: avro.schema.Schema, readers_schema: Optional[avro.schema.Schema] = None) -> object:
    reader = io.BytesIO(buffer.getvalue())
    decoder = avro.io.BinaryDecoder(reader)
    datum_reader = avro.io.DatumReader(writers_schema, readers_schema)
    return datum_reader.read(decoder)

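
# A minimal usage sketch of the two helpers above (assumed standalone; the name
# `_round_trip_example` is hypothetical and nothing in this module calls it):
# encode a datum with a writer's schema, then decode it again from the same buffer.
def _round_trip_example() -> object:
    schema = avro.schema.parse(json.dumps({"type": "map", "values": "long"}))
    buffer, _, _ = write_datum({"answer": 42}, schema)
    return read_datum(buffer, schema)  # -> {"answer": 42}
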

class IoValidateTestCase(unittest.TestCase):
    def __init__(self, test_schema: str, test_datum: object) -> None:
        """Ignore the normal signature for unittest.TestCase because we are generating
        many test cases from this one class. This is safe as long as the autoloader
        ignores this class. The autoloader will ignore this class as long as it has
        no methods starting with `test_`.
        """
        super().__init__("io_valid")
        self.test_schema = avro.schema.parse(test_schema)
        self.test_datum = test_datum
        # Never hide repeated warnings when running this test case.
        warnings.simplefilter("always")

    def io_valid(self) -> None:
        """
        The provided datum should be valid against the given schema.
        """
        with warnings.catch_warnings(record=True) as actual_warnings:
            self.assertTrue(
                avro.io.validate(self.test_schema, self.test_datum),
                f"{self.test_datum} did not validate in the schema {self.test_schema}",
            )


class RoundTripTestCase(unittest.TestCase):
    def __init__(self, test_schema: str, test_datum: object) -> None:
        """Ignore the normal signature for unittest.TestCase because we are generating
        many test cases from this one class. This is safe as long as the autoloader
        ignores this class. The autoloader will ignore this class as long as it has
        no methods starting with `test_`.
        """
        super().__init__("io_round_trip")
        self.test_schema = avro.schema.parse(test_schema)
        self.test_datum = test_datum
        # Never hide repeated warnings when running this test case.
        warnings.simplefilter("always")

    def io_round_trip(self) -> None:
        """
        A datum should be the same after being encoded and then decoded.
        """
        with warnings.catch_warnings(record=True) as actual_warnings:
            writer, encoder, datum_writer = write_datum(self.test_datum, self.test_schema)
            round_trip_datum = read_datum(writer, self.test_schema)
            expected: object
            round_trip: object
            if isinstance(round_trip_datum, decimal.Decimal):
                expected, round_trip, message = (
                    str(self.test_datum),
                    round_trip_datum.to_eng_string(),
                    "Decimal datum changed value after encode and decode",
                )
            elif isinstance(round_trip_datum, datetime.datetime):
                expected, round_trip, message = (
                    cast(datetime.datetime, self.test_datum).astimezone(tz=avro.timezones.utc),
                    round_trip_datum,
                    "DateTime datum changed value after encode and decode",
                )
            else:
                expected, round_trip, message = (
                    self.test_datum,
                    round_trip_datum,
                    "Datum changed value after encode and decode",
                )
            self.assertEqual(expected, round_trip, message)


class BinaryEncodingTestCase(unittest.TestCase):
    def __init__(self, skip: bool, test_type: str, test_datum: object, test_hex: bytes) -> None:
        """Ignore the normal signature for unittest.TestCase because we are generating
        many test cases from this one class. This is safe as long as the autoloader
        ignores this class. The autoloader will ignore this class as long as it has
        no methods starting with `test_`.
        """
        super().__init__(f"check_{'skip' if skip else 'binary'}_encoding")
        self.writers_schema = avro.schema.parse(f'"{test_type}"')
        self.test_datum = test_datum
        self.test_hex = test_hex
        # Never hide repeated warnings when running this test case.
        warnings.simplefilter("always")

    def check_binary_encoding(self) -> None:
        with warnings.catch_warnings(record=True) as actual_warnings:
            writer, encoder, datum_writer = write_datum(self.test_datum, self.writers_schema)
            writer.seek(0)
            hex_val = avro_hexlify(writer)
            self.assertEqual(
                self.test_hex,
                hex_val,
                "Binary encoding did not match expected hex representation.",
            )

    def check_skip_encoding(self) -> None:
        VALUE_TO_READ = 6253
        with warnings.catch_warnings(record=True) as actual_warnings:
            # write the value to be skipped, followed by a known value
            writer, encoder, datum_writer = write_datum(self.test_datum, self.writers_schema)
            datum_writer.write(VALUE_TO_READ, encoder)

            # skip past the first value
            reader = io.BytesIO(writer.getvalue())
            decoder = avro.io.BinaryDecoder(reader)
            decoder.skip_long()

            # read the known value from the byte buffer
            datum_reader = avro.io.DatumReader(self.writers_schema)
            read_value = datum_reader.read(decoder)

            self.assertEqual(
                read_value,
                VALUE_TO_READ,
                "Unexpected value after skipping a binary encoded value.",
            )


class SchemaPromotionTestCase(unittest.TestCase):
    def __init__(self, write_type: str, read_type: str) -> None:
        """Ignore the normal signature for unittest.TestCase because we are generating
        many test cases from this one class. This is safe as long as the autoloader
        ignores this class. The autoloader will ignore this class as long as it has
        no methods starting with `test_`.
        """
        super().__init__("check_schema_promotion")
        self.writers_schema = avro.schema.parse(f'"{write_type}"')
        self.readers_schema = avro.schema.parse(f'"{read_type}"')
        # Never hide repeated warnings when running this test case.
        warnings.simplefilter("always")

    def check_schema_promotion(self) -> None:
        """Test schema promotion"""
        # note that checking writers_schema.type in read_data
        # allows us to handle promotion correctly
        DATUM_TO_WRITE = 219
        with warnings.catch_warnings(record=True) as actual_warnings:
            writer, enc, dw = write_datum(DATUM_TO_WRITE, self.writers_schema)
            datum_read = read_datum(writer, self.writers_schema, self.readers_schema)
            self.assertEqual(
                datum_read,
                DATUM_TO_WRITE,
                f"Datum changed between schemas that were supposed to promote: writer: {self.writers_schema} reader: {self.readers_schema}.",
            )

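
# A promotion sketch mirroring check_schema_promotion above (assumed standalone;
# `_promotion_example` is a hypothetical name): a value written as "int" can still be
# read with a "double" reader schema, because resolution follows the writer's type.
def _promotion_example() -> object:
    buffer, _, _ = write_datum(219, avro.schema.parse('"int"'))
    return read_datum(buffer, avro.schema.parse('"int"'), avro.schema.parse('"double"'))  # compares equal to 219
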

class DefaultValueTestCase(unittest.TestCase):
    def __init__(self, field_type: Collection[str], default: Union[Dict[str, int], List[int], None, float, str]) -> None:
        """Ignore the normal signature for unittest.TestCase because we are generating
        many test cases from this one class. This is safe as long as the autoloader
        ignores this class. The autoloader will ignore this class as long as it has
        no methods starting with `test_`.
        """
        super().__init__("check_default_value")
        self.field_type = field_type
        self.default = default
        # Never hide repeated warnings when running this test case.
        warnings.simplefilter("always")

    def check_default_value(self) -> None:
        datum_read: DefaultValueTestCaseType
        with warnings.catch_warnings(record=True) as actual_warnings:
            datum_to_read = cast(DefaultValueTestCaseType, {"H": self.default})
            readers_schema = avro.schema.parse(
                json.dumps(
                    {
                        "type": "record",
                        "name": "Test",
                        "fields": [
                            {
                                "name": "H",
                                "type": self.field_type,
                                "default": self.default,
                            }
                        ],
                    }
                )
            )
            writer, _, _ = write_datum(LONG_RECORD_DATUM, LONG_RECORD_SCHEMA)
            datum_read_ = cast(DefaultValueTestCaseType, read_datum(writer, LONG_RECORD_SCHEMA, readers_schema))
            datum_read = {"H": cast(bytes, datum_read_["H"]).decode()} if isinstance(datum_read_["H"], bytes) else datum_read_
            self.assertEqual(datum_to_read, datum_read)


class TestMisc(unittest.TestCase):
    def test_decimal_bytes_small_scale(self) -> None:
        """Avro should raise an AvroOutOfScaleException when writing a decimal whose scale exceeds the schema's scale."""
        datum = decimal.Decimal("3.1415")
        _, _, exp = datum.as_tuple()
        scale = -1 * exp - 1
        schema = avro.schema.parse(
            json.dumps(
                {
                    "type": "bytes",
                    "logicalType": "decimal",
                    "precision": 5,
                    "scale": scale,
                }
            )
        )
        self.assertRaises(avro.errors.AvroOutOfScaleException, write_datum, datum, schema)

    def test_decimal_fixed_small_scale(self) -> None:
        """Avro should raise an AvroOutOfScaleException when writing a decimal whose scale exceeds the schema's scale."""
        datum = decimal.Decimal("3.1415")
        _, _, exp = datum.as_tuple()
        scale = -1 * exp - 1
        schema = avro.schema.parse(
            json.dumps(
                {
                    "type": "fixed",
                    "logicalType": "decimal",
                    "name": "Test",
                    "size": 8,
                    "precision": 5,
                    "scale": scale,
                }
            )
        )
        self.assertRaises(avro.errors.AvroOutOfScaleException, write_datum, datum, schema)

    def test_unknown_symbol(self) -> None:
        datum_to_write = "FOO"
        writers_schema = avro.schema.parse(json.dumps({"type": "enum", "name": "Test", "symbols": ["FOO", "BAR"]}))
        readers_schema = avro.schema.parse(json.dumps({"type": "enum", "name": "Test", "symbols": ["BAR", "BAZ"]}))

        writer, encoder, datum_writer = write_datum(datum_to_write, writers_schema)
        reader = io.BytesIO(writer.getvalue())
        decoder = avro.io.BinaryDecoder(reader)
        datum_reader = avro.io.DatumReader(writers_schema, readers_schema)
        self.assertRaises(avro.errors.SchemaResolutionException, datum_reader.read, decoder)

    def test_no_default_value(self) -> None:
        writers_schema = LONG_RECORD_SCHEMA
        datum_to_write = LONG_RECORD_DATUM

        readers_schema = avro.schema.parse(
            json.dumps(
                {
                    "type": "record",
                    "name": "Test",
                    "fields": [{"name": "H", "type": "int"}],
                }
            )
        )

        writer, encoder, datum_writer = write_datum(datum_to_write, writers_schema)
        reader = io.BytesIO(writer.getvalue())
        decoder = avro.io.BinaryDecoder(reader)
        datum_reader = avro.io.DatumReader(writers_schema, readers_schema)
        self.assertRaises(avro.errors.SchemaResolutionException, datum_reader.read, decoder)

    def test_projection(self) -> None:
        writers_schema = LONG_RECORD_SCHEMA
        datum_to_write = LONG_RECORD_DATUM

        readers_schema = avro.schema.parse(
            json.dumps(
                {
                    "type": "record",
                    "name": "Test",
                    "fields": [
                        {"name": "E", "type": "int"},
                        {"name": "F", "type": "int"},
                    ],
                }
            )
        )
        datum_to_read = {"E": 5, "F": 6}

        writer, encoder, datum_writer = write_datum(datum_to_write, writers_schema)
        datum_read = read_datum(writer, writers_schema, readers_schema)
        self.assertEqual(datum_to_read, datum_read)

    def test_field_order(self) -> None:
        writers_schema = LONG_RECORD_SCHEMA
        datum_to_write = LONG_RECORD_DATUM

        readers_schema = avro.schema.parse(
            json.dumps(
                {
                    "type": "record",
                    "name": "Test",
                    "fields": [
                        {"name": "F", "type": "int"},
                        {"name": "E", "type": "int"},
                    ],
                }
            )
        )
        datum_to_read = {"E": 5, "F": 6}

        writer, encoder, datum_writer = write_datum(datum_to_write, writers_schema)
        datum_read = read_datum(writer, writers_schema, readers_schema)
        self.assertEqual(datum_to_read, datum_read)

    def test_type_exception_int(self) -> None:
        writers_schema = avro.schema.parse(
            json.dumps(
                {
                    "type": "record",
                    "name": "Test",
                    "fields": [
                        {"name": "F", "type": "int"},
                        {"name": "E", "type": "int"},
                    ],
                }
            )
        )
        datum_to_write = {"E": 5, "F": "Bad"}
        with self.assertRaises(avro.errors.AvroTypeException) as exc:
            write_datum(datum_to_write, writers_schema)
        assert str(exc.exception) == 'The datum "Bad" provided for "F" is not an example of the schema "int"'

    def test_type_exception_long(self) -> None:
        writers_schema = avro.schema.parse(json.dumps({"type": "record", "name": "Test", "fields": [{"name": "foo", "type": "long"}]}))
        datum_to_write = {"foo": 5.0}

        with self.assertRaises(avro.errors.AvroTypeException) as exc:
            write_datum(datum_to_write, writers_schema)
        assert str(exc.exception) == 'The datum "5.0" provided for "foo" is not an example of the schema "long"'

    def test_type_exception_record(self) -> None:
        writers_schema = avro.schema.parse(json.dumps({"type": "record", "name": "Test", "fields": [{"name": "foo", "type": "long"}]}))
        datum_to_write = ("foo", 5.0)

        with self.assertRaisesRegex(avro.errors.AvroTypeException, r"The datum \".*\" provided for \".*\" is not an example of the schema [\s\S]*"):
            write_datum(datum_to_write, writers_schema)


def load_tests(loader: unittest.TestLoader, default_tests: None, pattern: None) -> unittest.TestSuite:
    """Generate test cases across many test schemas."""
    suite = unittest.TestSuite()
    suite.addTests(loader.loadTestsFromTestCase(TestMisc))
    suite.addTests(IoValidateTestCase(schema_str, datum) for schema_str, datum in SCHEMAS_TO_VALIDATE)
    suite.addTests(RoundTripTestCase(schema_str, datum) for schema_str, datum in SCHEMAS_TO_VALIDATE)
    for skip in False, True:
        for type_ in "int", "long":
            suite.addTests(BinaryEncodingTestCase(skip, type_, datum, hex_) for datum, hex_ in BINARY_ENCODINGS)
    suite.addTests(
        SchemaPromotionTestCase(write_type, read_type) for write_type, read_type in itertools.combinations(("int", "long", "float", "double"), 2)
    )
    suite.addTests(DefaultValueTestCase(field_type, default) for field_type, default in DEFAULT_VALUE_EXAMPLES)
    return suite


if __name__ == "__main__":  # pragma: no coverage
    unittest.main()
