1#!/usr/bin/env python3
2
3##
4# Licensed to the Apache Software Foundation (ASF) under one
5# or more contributor license agreements.  See the NOTICE file
6# distributed with this work for additional information
7# regarding copyright ownership.  The ASF licenses this file
8# to you under the Apache License, Version 2.0 (the
9# "License"); you may not use this file except in compliance
10# with the License.  You may obtain a copy of the License at
11#
12# https://www.apache.org/licenses/LICENSE-2.0
13#
14# Unless required by applicable law or agreed to in writing, software
15# distributed under the License is distributed on an "AS IS" BASIS,
16# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17# See the License for the specific language governing permissions and
18# limitations under the License.
19from copy import copy
20from enum import Enum
21from typing import Container, Iterable, List, Optional, Set, cast
22
23from avro.errors import AvroRuntimeException
24from avro.schema import (
25    ArraySchema,
26    EnumSchema,
27    Field,
28    FixedSchema,
29    MapSchema,
30    NamedSchema,
31    RecordSchema,
32    Schema,
33    UnionSchema,
34)
35
36
class SchemaType(str, Enum):
    """Avro schema type names.

    The ``str`` mix-in makes each member compare equal to its plain string
    value, so members can be compared directly against ``schema.type``.
    """

    ARRAY = "array"
    BOOLEAN = "boolean"
    BYTES = "bytes"
    DOUBLE = "double"
    ENUM = "enum"
    FIXED = "fixed"
    FLOAT = "float"
    INT = "int"
    LONG = "long"
    MAP = "map"
    NULL = "null"
    RECORD = "record"
    STRING = "string"
    UNION = "union"
52
53
class SchemaCompatibilityType(Enum):
    """Overall verdict of a reader/writer compatibility check."""

    compatible = "compatible"
    incompatible = "incompatible"
    # Default state of a freshly constructed SchemaCompatibilityResult. The
    # checker stores such a result as a placeholder while a reader/writer pair
    # is being evaluated, so that re-entering the same pair (recursive
    # schemas) can be detected.
    recursion_in_progress = "recursion_in_progress"
58
59
class SchemaIncompatibilityType(Enum):
    """Kinds of incompatibility that a compatibility check can report."""

    # Named schemas whose names differ and are not covered by a reader alias.
    name_mismatch = "name_mismatch"
    # Fixed schemas with different sizes.
    fixed_size_mismatch = "fixed_size_mismatch"
    # Writer enum has symbols the reader enum lacks (and no usable reader default).
    missing_enum_symbols = "missing_enum_symbols"
    # Reader record field absent from the writer record and without a default.
    reader_field_missing_default_value = "reader_field_missing_default_value"
    # Reader and writer types differ and no promotion or union rule applies.
    type_mismatch = "type_mismatch"
    # A union branch on one side has no compatible counterpart on the other.
    missing_union_branch = "missing_union_branch"
67
68
# Avro primitive (non-composite) types. Two schemas sharing one of these types
# are compatible without any further structural comparison.
PRIMITIVE_TYPES = {
    SchemaType[member_name]
    for member_name in (
        "NULL",
        "BOOLEAN",
        "INT",
        "LONG",
        "FLOAT",
        "DOUBLE",
        "BYTES",
        "STRING",
    )
}
79
80
class SchemaCompatibilityResult:
    """
    Outcome of checking a reader schema against a writer schema.

    Attributes:
        compatibility: the overall SchemaCompatibilityType verdict.
        incompatibilities: one SchemaIncompatibilityType entry per problem found.
        messages: human-readable descriptions of the problems found.
        locations: slash-separated paths of the schema nodes involved;
            defaults to ``{"/"}`` (the root).
    """

    def __init__(
        self,
        compatibility: SchemaCompatibilityType = SchemaCompatibilityType.recursion_in_progress,
        incompatibilities: Optional[List[SchemaIncompatibilityType]] = None,
        messages: Optional[Set[str]] = None,
        locations: Optional[Set[str]] = None,
    ):
        # ``or`` (not ``is None``) is deliberate: empty collections also fall
        # back to the defaults.
        self.locations = locations or {"/"}
        self.messages = messages or set()
        self.compatibility = compatibility
        self.incompatibilities = incompatibilities or []

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(compatibility={self.compatibility!r}, "
            f"incompatibilities={self.incompatibilities!r}, "
            f"messages={self.messages!r}, locations={self.locations!r})"
        )
93
94
def merge(this: SchemaCompatibilityResult, that: SchemaCompatibilityResult) -> SchemaCompatibilityResult:
    """
    Merge two SchemaCompatibilityResult instances into a new instance.

    The incompatibility lists are concatenated (``this`` first). If ``this`` is
    compatible, the merged result takes ``that``'s verdict, messages and
    locations. Otherwise it keeps ``this``'s verdict and adds ``that``'s
    messages and locations. A compatible ``that`` contributes no messages or
    locations: its locations are only the default ``"/"`` placeholder, which
    would otherwise show up as a spurious incompatibility location.

    The merged result always gets fresh sets, so it never aliases the sets of
    its inputs (in particular those of the shared ``CompatibleResult``).

    :param this: SchemaCompatibilityResult
    :param that: SchemaCompatibilityResult
    :return: SchemaCompatibilityResult
    """
    merged = [*this.incompatibilities, *that.incompatibilities]
    if this.compatibility is SchemaCompatibilityType.compatible:
        compat = that.compatibility
        messages = copy(that.messages)
        locations = copy(that.locations)
    elif that.compatibility is SchemaCompatibilityType.compatible:
        compat = this.compatibility
        messages = copy(this.messages)
        locations = copy(this.locations)
    else:
        compat = this.compatibility
        messages = this.messages.union(that.messages)
        locations = this.locations.union(that.locations)
    return SchemaCompatibilityResult(
        compatibility=compat,
        incompatibilities=merged,
        messages=messages,
        locations=locations,
    )
119
120
# Shared "no problem found" result, used as the starting accumulator for the
# checks below. It is a single shared instance, so it must never be mutated.
CompatibleResult = SchemaCompatibilityResult(compatibility=SchemaCompatibilityType.compatible)
122
123
class ReaderWriter:
    """
    Hashable key for a (reader, writer) schema pair, based on object identity.

    Two keys are equal only when they hold the very same reader object and the
    very same writer object (compared with ``is``, not ``==``). Used as the
    memoization key of the compatibility checker.
    """

    def __init__(self, reader: "Schema", writer: "Schema") -> None:
        self.reader, self.writer = reader, writer

    def __hash__(self) -> int:
        # Hash the ordered pair of identities. XOR-ing the ids would be
        # symmetric, making (a, b) collide with (b, a), and would hash every
        # pair with ``reader is writer`` to 0. The latter is the common case
        # when a schema is checked against itself.
        return hash((id(self.reader), id(self.writer)))

    def __eq__(self, other) -> bool:
        if not isinstance(other, ReaderWriter):
            return False
        return self.reader is other.reader and self.writer is other.writer
135
136
class ReaderWriterCompatibilityChecker:
    """
    Determines whether data written with a writer schema can be read with a reader schema.

    Results are memoized per (reader, writer) object pair in ``memoize_map``. A pair that is
    requested again while it is still being evaluated (as happens with recursive schemas) is
    treated as compatible, which lets the traversal terminate.
    """

    ROOT_REFERENCE_TOKEN = "/"

    def __init__(self):
        # ReaderWriter -> SchemaCompatibilityResult. While a pair is being calculated it maps
        # to a placeholder whose compatibility is recursion_in_progress.
        self.memoize_map = {}

    def get_compatibility(
        self,
        reader: Schema,
        writer: Schema,
        reference_token: str = ROOT_REFERENCE_TOKEN,
        location: Optional[List[str]] = None,
    ) -> SchemaCompatibilityResult:
        """
        Return the compatibility of ``reader`` with ``writer``, calculating it at most once per pair.

        :param reader: avro.schema.Schema used to read the data
        :param writer: avro.schema.Schema the data was written with
        :param reference_token: path segment appended to ``location`` for this pair
        :param location: path segments leading to this pair's parent; defaults to []
        :return: SchemaCompatibilityResult
        """
        if location is None:
            location = []
        pair = ReaderWriter(reader, writer)
        memoized = self.memoize_map.get(pair)
        if memoized is not None:
            result = cast(SchemaCompatibilityResult, memoized)
            if result.compatibility is SchemaCompatibilityType.recursion_in_progress:
                # The pair is still being evaluated further up the call stack (recursive
                # schema): assume compatible here and let the outer evaluation decide.
                result = CompatibleResult
            return result
        self.memoize_map[pair] = SchemaCompatibilityResult()
        try:
            result = self.calculate_compatibility(reader, writer, location + [reference_token])
        except BaseException:
            # Drop the placeholder. If it were left behind, every later lookup of this pair on
            # this checker would report "compatible" for a pair that was never evaluated.
            self.memoize_map.pop(pair, None)
            raise
        self.memoize_map[pair] = result
        return result

    # pylint: disable=too-many-return-statements
    def calculate_compatibility(
        self,
        reader: Schema,
        writer: Schema,
        location: List[str],
    ) -> SchemaCompatibilityResult:
        """
        Calculates the compatibility of a reader/writer schema pair. Will be positive if the reader is capable of reading
        whatever the writer may write.

        Pairs of the same type are compared structurally (names, sizes, symbols, fields, nested types). Pairs of
        different types are compatible only when:
          - the writer is a union and the reader can read every writer branch,
          - the reader is a union and at least one reader branch can read the writer, or
          - a primitive promotion applies: int -> long/float/double, long -> float/double,
            float -> double, string <-> bytes.

        :param reader: avro.schema.Schema
        :param writer: avro.schema.Schema
        :param location: List[str] path segments of this pair, starting with the root token
        :return: SchemaCompatibilityResult
        :raises AvroRuntimeException: if the reader's type is not one this checker recognizes
        """
        assert reader is not None
        assert writer is not None
        result = CompatibleResult
        if reader.type == writer.type:
            if reader.type in PRIMITIVE_TYPES:
                return result
            if reader.type == SchemaType.ARRAY:
                reader, writer = cast(ArraySchema, reader), cast(ArraySchema, writer)
                return merge(
                    result,
                    self.get_compatibility(reader.items, writer.items, "items", location),
                )
            if reader.type == SchemaType.MAP:
                reader, writer = cast(MapSchema, reader), cast(MapSchema, writer)
                return merge(
                    result,
                    self.get_compatibility(reader.values, writer.values, "values", location),
                )
            if reader.type == SchemaType.FIXED:
                reader, writer = cast(FixedSchema, reader), cast(FixedSchema, writer)
                result = merge(result, check_schema_names(reader, writer, location))
                return merge(result, check_fixed_size(reader, writer, location))
            if reader.type == SchemaType.ENUM:
                reader, writer = cast(EnumSchema, reader), cast(EnumSchema, writer)
                result = merge(result, check_schema_names(reader, writer, location))
                return merge(
                    result,
                    check_reader_enum_contains_writer_enum(reader, writer, location),
                )
            if reader.type == SchemaType.RECORD:
                reader, writer = cast(RecordSchema, reader), cast(RecordSchema, writer)
                result = merge(result, check_schema_names(reader, writer, location))
                return merge(
                    result,
                    self.check_reader_writer_record_fields(reader, writer, location),
                )
            if reader.type == SchemaType.UNION:
                reader, writer = cast(UnionSchema, reader), cast(UnionSchema, writer)
                # Every writer branch must be readable by the reader union. The nested lookup starts
                # a fresh path; the incompatibility is recorded here with the branch index instead.
                for i, writer_branch in enumerate(writer.schemas):
                    compat = self.get_compatibility(reader, writer_branch)
                    if compat.compatibility is SchemaCompatibilityType.incompatible:
                        result = merge(
                            result,
                            incompatible(
                                SchemaIncompatibilityType.missing_union_branch,
                                f"reader union lacking writer type: {writer_branch.type.upper()}",
                                location + [str(i)],
                            ),
                        )
                return result
            raise AvroRuntimeException(f"Unknown schema type: {reader.type}")
        if writer.type == SchemaType.UNION:
            # Non-union reader vs. writer union: the reader must handle every branch the writer may use.
            # NOTE(review): nested results are located relative to the root ("/"), not to ``location``.
            writer = cast(UnionSchema, writer)
            for s in writer.schemas:
                result = merge(result, self.get_compatibility(reader, s))
            return result
        # From here on the types differ and the writer is not a union: only promotions can succeed.
        if reader.type in {SchemaType.NULL, SchemaType.BOOLEAN, SchemaType.INT}:
            return merge(result, type_mismatch(reader, writer, location))
        if reader.type == SchemaType.LONG:
            if writer.type == SchemaType.INT:
                return result
            return merge(result, type_mismatch(reader, writer, location))
        if reader.type == SchemaType.FLOAT:
            if writer.type in {SchemaType.INT, SchemaType.LONG}:
                return result
            return merge(result, type_mismatch(reader, writer, location))
        if reader.type == SchemaType.DOUBLE:
            if writer.type in {SchemaType.INT, SchemaType.LONG, SchemaType.FLOAT}:
                return result
            return merge(result, type_mismatch(reader, writer, location))
        if reader.type == SchemaType.BYTES:
            if writer.type == SchemaType.STRING:
                return result
            return merge(result, type_mismatch(reader, writer, location))
        if reader.type == SchemaType.STRING:
            if writer.type == SchemaType.BYTES:
                return result
            return merge(result, type_mismatch(reader, writer, location))
        if reader.type in {
            SchemaType.ARRAY,
            SchemaType.MAP,
            SchemaType.FIXED,
            SchemaType.ENUM,
            SchemaType.RECORD,
        }:
            return merge(result, type_mismatch(reader, writer, location))
        if reader.type == SchemaType.UNION:
            # Reader union vs. non-union writer: one readable reader branch is enough.
            reader = cast(UnionSchema, reader)
            for reader_branch in reader.schemas:
                compat = self.get_compatibility(reader_branch, writer)
                if compat.compatibility is SchemaCompatibilityType.compatible:
                    return result
            # No branch in reader compatible with writer
            message = f"reader union lacking writer type {writer.type}"
            return merge(
                result,
                incompatible(SchemaIncompatibilityType.missing_union_branch, message, location),
            )
        raise AvroRuntimeException(f"Unknown schema type: {reader.type}")

    # pylint: enable=too-many-return-statements

    def check_reader_writer_record_fields(
        self, reader: RecordSchema, writer: RecordSchema, location: List[str]
    ) -> SchemaCompatibilityResult:
        """
        Check every reader record field against the writer record.

        A reader field that resolves to a writer field (by name or alias, see lookup_writer_field) must have
        a compatible type. A reader field with no writer counterpart must have a default value; otherwise a
        reader_field_missing_default_value incompatibility is reported at ``<location>/fields/<index>``.

        :param reader: reader RecordSchema
        :param writer: writer RecordSchema
        :param location: List[str] path segments of the record pair
        :return: SchemaCompatibilityResult
        """
        result = CompatibleResult
        for i, reader_field in enumerate(reader.fields):
            reader_field = cast(Field, reader_field)
            writer_field = lookup_writer_field(writer_schema=writer, reader_field=reader_field)
            if writer_field is None:
                if not reader_field.has_default:
                    if reader_field.type.type == SchemaType.ENUM and reader_field.type.props.get("default"):
                        # NOTE(review): this compares the reader field's enum type with the whole writer
                        # *record*, which calculate_compatibility reports as a type mismatch. Confirm intended.
                        result = merge(
                            result,
                            self.get_compatibility(
                                reader_field.type,
                                writer,
                                "type",
                                location + ["fields", str(i)],
                            ),
                        )
                    else:
                        result = merge(
                            result,
                            incompatible(
                                SchemaIncompatibilityType.reader_field_missing_default_value,
                                reader_field.name,
                                location + ["fields", str(i)],
                            ),
                        )
            else:
                result = merge(
                    result,
                    self.get_compatibility(
                        reader_field.type,
                        writer_field.type,
                        "type",
                        location + ["fields", str(i)],
                    ),
                )
        return result
317
318
def type_mismatch(reader: Schema, writer: Schema, location: List[str]) -> SchemaCompatibilityResult:
    """
    Build a type_mismatch incompatibility for a reader/writer pair.

    :param reader: reader schema; its ``type`` appears in the message
    :param writer: writer schema; its ``type`` appears in the message
    :param location: path segments of the mismatching node
    :return: an incompatible SchemaCompatibilityResult
    """
    return incompatible(
        SchemaIncompatibilityType.type_mismatch,
        f"reader type: {reader.type} not compatible with writer type: {writer.type}",
        location,
    )
322
323
def check_schema_names(reader: NamedSchema, writer: NamedSchema, location: List[str]) -> SchemaCompatibilityResult:
    """
    Check that the reader's name (or one of its aliases) matches the writer's.

    :param reader: named reader schema
    :param writer: named writer schema
    :param location: path segments of the pair; a mismatch is reported at ``<location>/name``
    :return: CompatibleResult, or a name_mismatch incompatibility
    """
    if schema_name_equals(reader, writer):
        return CompatibleResult
    return incompatible(
        SchemaIncompatibilityType.name_mismatch,
        f"expected: {writer.fullname}",
        location + ["name"],
    )
330
331
def check_fixed_size(reader: FixedSchema, writer: FixedSchema, location: List[str]) -> SchemaCompatibilityResult:
    """
    Check that two fixed schemas declare the same size.

    :param reader: reader fixed schema (the "found" size)
    :param writer: writer fixed schema (the "expected" size)
    :param location: path segments of the pair; a mismatch is reported at ``<location>/size``
    :return: CompatibleResult, or a fixed_size_mismatch incompatibility
    """
    actual, expected = reader.size, writer.size
    if actual == expected:
        return CompatibleResult
    return incompatible(
        SchemaIncompatibilityType.fixed_size_mismatch,
        f"expected: {expected}, found: {actual}",
        location + ["size"],
    )
344
345
def check_reader_enum_contains_writer_enum(reader: EnumSchema, writer: EnumSchema, location: List[str]) -> SchemaCompatibilityResult:
    """
    Check that the reader enum can decode every symbol the writer enum may emit.

    Writer symbols missing from the reader are tolerated when the reader declares a
    ``default`` that is one of its own symbols. Otherwise the missing symbols are
    reported as a missing_enum_symbols incompatibility at ``<location>/symbols``.

    :param reader: reader enum schema
    :param writer: writer enum schema
    :param location: path segments of the pair
    :return: CompatibleResult, or a missing_enum_symbols incompatibility
    """
    reader_symbols = set(reader.symbols)
    extra_symbols = set(writer.symbols) - reader_symbols
    if not extra_symbols:
        return CompatibleResult
    default = reader.props.get("default")
    if default and default in reader_symbols:
        return CompatibleResult
    # Sort the symbols so the message does not depend on set iteration order, which
    # varies between interpreter runs because str hashing is randomized. Braces and
    # repr() keep the same "{'A', 'B'}" shape that str(set) produces.
    message = "{" + ", ".join(repr(symbol) for symbol in sorted(extra_symbols)) + "}"
    return incompatible(
        SchemaIncompatibilityType.missing_enum_symbols,
        message,
        location + ["symbols"],
    )
361
362
def incompatible(incompat_type: SchemaIncompatibilityType, message: str, location: List[str]) -> SchemaCompatibilityResult:
    """
    Build an incompatible result carrying a single incompatibility.

    :param incompat_type: kind of incompatibility
    :param message: human-readable description
    :param location: path segments, joined with "/" into the reported location
    :return: an incompatible SchemaCompatibilityResult
    """
    joined = "/".join(location)
    # Paths normally begin with the root token "/", so joining several segments
    # yields a leading "//"; drop the first character in that case.
    path = joined[1:] if len(location) > 1 else joined
    return SchemaCompatibilityResult(
        compatibility=SchemaCompatibilityType.incompatible,
        incompatibilities=[incompat_type],
        locations={path},
        messages={message},
    )
374
375
def schema_name_equals(reader: "NamedSchema", writer: "NamedSchema") -> bool:
    """
    Return True if the reader schema can stand in for the writer schema by name.

    Either the unqualified names are equal, or the writer's full name is among the
    reader's ``aliases``. An unqualified alias is resolved against the reader's
    namespace, so alias ``"Foo"`` on reader ``ns.Bar`` matches writer ``ns.Foo``.

    :param reader: named reader schema
    :param writer: named writer schema
    :return: whether the names match
    """
    if reader.name == writer.name:
        return True
    aliases = reader.props.get("aliases")
    if isinstance(aliases, str):
        # A bare string is a malformed alias list; treat it as a single alias
        # instead of doing a substring test.
        aliases = [aliases]
    if not isinstance(aliases, Container):
        return False
    if writer.fullname in aliases:
        return True
    # An unqualified alias "X" on a reader in namespace "ns" denotes "ns.X". It can
    # only match the writer if the writer is in that same namespace and has short
    # name "X".
    reader_space = reader.fullname.rpartition(".")[0]
    writer_space, _, writer_short = writer.fullname.rpartition(".")
    return writer_space == reader_space and writer_short in aliases
379
380
def lookup_writer_field(writer_schema: "RecordSchema", reader_field: "Field") -> "Optional[Field]":
    """
    Find the writer field that a reader field resolves to.

    The reader field's own name is tried first, then each of its ``aliases`` in
    order. The first name present in the writer record's ``fields_dict`` wins.

    :param writer_schema: writer record schema
    :param reader_field: field of the reader record schema
    :return: the matching writer field, or None if there is none
    """
    direct = writer_schema.fields_dict.get(reader_field.name)
    if direct is not None:
        return cast("Field", direct)
    aliases = reader_field.props.get("aliases")
    if isinstance(aliases, str):
        # A bare string is a malformed alias list; treat it as one alias rather
        # than iterating over (and matching on) its individual characters.
        aliases = [aliases]
    if not isinstance(aliases, Iterable):
        return None
    for alias in aliases:
        writer_field = writer_schema.fields_dict.get(alias)
        if writer_field is not None:
            return cast("Field", writer_field)
    return None
393