#!/usr/bin/env python3

##
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Reader/writer schema compatibility checks.

Given a reader schema and a writer schema, the checker in this module reports
whether the pair is compatible and, if not, which kinds of incompatibility
were found, with a message and a "/"-separated location for each.
"""
from copy import copy
from enum import Enum
from typing import Container, Dict, Iterable, List, Optional, Set, cast

from avro.errors import AvroRuntimeException
from avro.schema import (
    ArraySchema,
    EnumSchema,
    Field,
    FixedSchema,
    MapSchema,
    NamedSchema,
    RecordSchema,
    Schema,
    UnionSchema,
)


class SchemaType(str, Enum):
    """Schema type names; members are ``str`` so they compare equal to the plain type strings."""

    ARRAY = "array"
    BOOLEAN = "boolean"
    BYTES = "bytes"
    DOUBLE = "double"
    ENUM = "enum"
    FIXED = "fixed"
    FLOAT = "float"
    INT = "int"
    LONG = "long"
    MAP = "map"
    NULL = "null"
    RECORD = "record"
    STRING = "string"
    UNION = "union"


class SchemaCompatibilityType(Enum):
    """Overall verdict carried by a :class:`SchemaCompatibilityResult`."""

    compatible = "compatible"
    incompatible = "incompatible"
    # Placeholder verdict stored in the memoization map while a pair is still being evaluated.
    recursion_in_progress = "recursion_in_progress"


class SchemaIncompatibilityType(Enum):
    """The individual kinds of incompatibility the checks in this module can report."""

    name_mismatch = "name_mismatch"
    fixed_size_mismatch = "fixed_size_mismatch"
    missing_enum_symbols = "missing_enum_symbols"
    reader_field_missing_default_value = "reader_field_missing_default_value"
    type_mismatch = "type_mismatch"
    missing_union_branch = "missing_union_branch"


PRIMITIVE_TYPES = {
    SchemaType.NULL,
    SchemaType.BOOLEAN,
    SchemaType.INT,
    SchemaType.LONG,
    SchemaType.FLOAT,
    SchemaType.DOUBLE,
    SchemaType.BYTES,
    SchemaType.STRING,
}


class SchemaCompatibilityResult:
    """Outcome of a compatibility check.

    :param compatibility: overall verdict; defaults to ``recursion_in_progress``.
    :param incompatibilities: kinds of incompatibility found; defaults to an empty list.
    :param messages: human-readable messages; defaults to an empty set.
    :param locations: "/"-separated locations; ``None`` or an empty set falls back to ``{"/"}``.
    """

    def __init__(
        self,
        compatibility: SchemaCompatibilityType = SchemaCompatibilityType.recursion_in_progress,
        incompatibilities: Optional[List[SchemaIncompatibilityType]] = None,
        messages: Optional[Set[str]] = None,
        locations: Optional[Set[str]] = None,
    ):
        self.locations = locations or {"/"}
        self.messages = messages or set()
        self.compatibility = compatibility
        self.incompatibilities = incompatibilities or []


def merge(this: SchemaCompatibilityResult, that: SchemaCompatibilityResult) -> SchemaCompatibilityResult:
    """
    Merge two results into a new ``SchemaCompatibilityResult``.

    The incompatibility lists are concatenated (``this`` first). If ``this`` is
    compatible, the verdict, messages and locations are taken from ``that``;
    otherwise the verdict of ``this`` is kept and the messages and locations of
    both results are unioned.

    The returned instance never shares its list or sets with either input, so
    mutating it cannot corrupt a memoized result or the shared
    ``CompatibleResult``.

    :param this: SchemaCompatibilityResult
    :param that: SchemaCompatibilityResult
    :return: SchemaCompatibilityResult
    """
    merged = [*this.incompatibilities, *that.incompatibilities]
    if this.compatibility is SchemaCompatibilityType.compatible:
        compat = that.compatibility
        # Copy rather than alias: ``that`` may be a cached result.
        messages = copy(that.messages)
        locations = copy(that.locations)
    else:
        compat = this.compatibility
        messages = this.messages.union(that.messages)
        locations = this.locations.union(that.locations)
    return SchemaCompatibilityResult(
        compatibility=compat,
        incompatibilities=merged,
        messages=messages,
        locations=locations,
    )


# Shared "everything is fine" result used as the starting point of every check.
CompatibleResult = SchemaCompatibilityResult(SchemaCompatibilityType.compatible)


class ReaderWriter:
    """Identity-based key for a (reader, writer) schema pair, used for memoization."""

    def __init__(self, reader: Schema, writer: Schema) -> None:
        self.reader, self.writer = reader, writer

    def __hash__(self) -> int:
        # Hash the ordered pair of ids so that (a, b) and (b, a) do not collide
        # and (a, a) pairs do not all hash to the same value.
        return hash((id(self.reader), id(self.writer)))

    def __eq__(self, other) -> bool:
        if not isinstance(other, ReaderWriter):
            return False
        return self.reader is other.reader and self.writer is other.writer


class ReaderWriterCompatibilityChecker:
    """Computes reader/writer compatibility, memoizing results per schema pair."""

    ROOT_REFERENCE_TOKEN = "/"

    def __init__(self):
        self.memoize_map: Dict[ReaderWriter, SchemaCompatibilityResult] = {}

    def get_compatibility(
        self,
        reader: Schema,
        writer: Schema,
        reference_token: str = ROOT_REFERENCE_TOKEN,
        location: Optional[List[str]] = None,
    ) -> SchemaCompatibilityResult:
        """
        Return the (memoized) compatibility of a reader/writer schema pair.

        :param reader: avro.schema.Schema
        :param writer: avro.schema.Schema
        :param reference_token: path segment appended to ``location`` for this pair
        :param location: path segments leading to this pair; ``None`` means empty
        :return: SchemaCompatibilityResult
        """
        if location is None:
            location = []
        pair = ReaderWriter(reader, writer)
        if pair in self.memoize_map:
            result = cast(SchemaCompatibilityResult, self.memoize_map[pair])
            if result.compatibility is SchemaCompatibilityType.recursion_in_progress:
                # The pair is being evaluated further up the call stack; treat the
                # recursive reference as compatible so the evaluation terminates.
                result = CompatibleResult
        else:
            # Mark the pair as in progress before descending (default verdict is
            # recursion_in_progress), then replace the marker with the real result.
            self.memoize_map[pair] = SchemaCompatibilityResult()
            result = self.calculate_compatibility(reader, writer, location + [reference_token])
            self.memoize_map[pair] = result
        return result

    # pylint: disable=too-many-return-statements
    def calculate_compatibility(
        self,
        reader: Schema,
        writer: Schema,
        location: List[str],
    ) -> SchemaCompatibilityResult:
        """
        Calculates the compatibility of a reader/writer schema pair. Will be positive if the reader is capable of reading
        whatever the writer may write
        :param reader: avro.schema.Schema
        :param writer: avro.schema.Schema
        :param location: List[str]
        :return: SchemaCompatibilityResult
        :raises AvroRuntimeException: if the reader schema type is not one of the known types
        """
        assert reader is not None
        assert writer is not None
        result = CompatibleResult
        if reader.type == writer.type:
            if reader.type in PRIMITIVE_TYPES:
                return result
            if reader.type == SchemaType.ARRAY:
                reader, writer = cast(ArraySchema, reader), cast(ArraySchema, writer)
                return merge(
                    result,
                    self.get_compatibility(reader.items, writer.items, "items", location),
                )
            if reader.type == SchemaType.MAP:
                reader, writer = cast(MapSchema, reader), cast(MapSchema, writer)
                return merge(
                    result,
                    self.get_compatibility(reader.values, writer.values, "values", location),
                )
            if reader.type == SchemaType.FIXED:
                reader, writer = cast(FixedSchema, reader), cast(FixedSchema, writer)
                result = merge(result, check_schema_names(reader, writer, location))
                return merge(result, check_fixed_size(reader, writer, location))
            if reader.type == SchemaType.ENUM:
                reader, writer = cast(EnumSchema, reader), cast(EnumSchema, writer)
                result = merge(result, check_schema_names(reader, writer, location))
                return merge(
                    result,
                    check_reader_enum_contains_writer_enum(reader, writer, location),
                )
            if reader.type == SchemaType.RECORD:
                reader, writer = cast(RecordSchema, reader), cast(RecordSchema, writer)
                result = merge(result, check_schema_names(reader, writer, location))
                return merge(
                    result,
                    self.check_reader_writer_record_fields(reader, writer, location),
                )
            if reader.type == SchemaType.UNION:
                reader, writer = cast(UnionSchema, reader), cast(UnionSchema, writer)
                # Each writer branch is checked against the whole reader union; only
                # the verdict of that check is used, reported at the branch index.
                for i, writer_branch in enumerate(writer.schemas):
                    compat = self.get_compatibility(reader, writer_branch)
                    if compat.compatibility is SchemaCompatibilityType.incompatible:
                        result = merge(
                            result,
                            incompatible(
                                SchemaIncompatibilityType.missing_union_branch,
                                f"reader union lacking writer type: {writer_branch.type.upper()}",
                                location + [str(i)],
                            ),
                        )
                return result
            raise AvroRuntimeException(f"Unknown schema type: {reader.type}")

        # From here on the reader and writer types differ.
        if writer.type == SchemaType.UNION:
            writer = cast(UnionSchema, writer)
            # The reader is checked against every branch of the writer union.
            for s in writer.schemas:
                result = merge(result, self.get_compatibility(reader, s))
            return result
        if reader.type in {SchemaType.NULL, SchemaType.BOOLEAN, SchemaType.INT}:
            return merge(result, type_mismatch(reader, writer, location))
        # Accepted writer -> reader combinations between differing primitive types.
        if reader.type == SchemaType.LONG:
            if writer.type == SchemaType.INT:
                return result
            return merge(result, type_mismatch(reader, writer, location))
        if reader.type == SchemaType.FLOAT:
            if writer.type in {SchemaType.INT, SchemaType.LONG}:
                return result
            return merge(result, type_mismatch(reader, writer, location))
        if reader.type == SchemaType.DOUBLE:
            if writer.type in {SchemaType.INT, SchemaType.LONG, SchemaType.FLOAT}:
                return result
            return merge(result, type_mismatch(reader, writer, location))
        if reader.type == SchemaType.BYTES:
            if writer.type == SchemaType.STRING:
                return result
            return merge(result, type_mismatch(reader, writer, location))
        if reader.type == SchemaType.STRING:
            if writer.type == SchemaType.BYTES:
                return result
            return merge(result, type_mismatch(reader, writer, location))
        if reader.type in {
            SchemaType.ARRAY,
            SchemaType.MAP,
            SchemaType.FIXED,
            SchemaType.ENUM,
            SchemaType.RECORD,
        }:
            return merge(result, type_mismatch(reader, writer, location))
        if reader.type == SchemaType.UNION:
            reader = cast(UnionSchema, reader)
            # A single compatible reader branch is enough.
            for reader_branch in reader.schemas:
                compat = self.get_compatibility(reader_branch, writer)
                if compat.compatibility is SchemaCompatibilityType.compatible:
                    return result
            # No branch in reader compatible with writer
            message = f"reader union lacking writer type {writer.type}"
            return merge(
                result,
                incompatible(SchemaIncompatibilityType.missing_union_branch, message, location),
            )
        raise AvroRuntimeException(f"Unknown schema type: {reader.type}")

    # pylint: enable=too-many-return-statements

    def check_reader_writer_record_fields(self, reader: RecordSchema, writer: RecordSchema, location: List[str]) -> SchemaCompatibilityResult:
        """
        Check every reader field against the writer record.

        A reader field with a matching writer field (by name or alias) has its
        type checked against the writer field's type. A reader field with no
        match and no default is reported as ``reader_field_missing_default_value``,
        except for an enum-typed field whose enum schema has a truthy ``default``
        property: that one is checked against the writer record schema itself.
        A reader field with no match but with a default adds nothing.

        :param reader: avro.schema.RecordSchema
        :param writer: avro.schema.RecordSchema
        :param location: List[str]
        :return: SchemaCompatibilityResult
        """
        result = CompatibleResult
        for i, reader_field in enumerate(reader.fields):
            reader_field = cast(Field, reader_field)
            writer_field = lookup_writer_field(writer_schema=writer, reader_field=reader_field)
            if writer_field is None:
                if not reader_field.has_default:
                    if reader_field.type.type == SchemaType.ENUM and reader_field.type.props.get("default"):
                        result = merge(
                            result,
                            self.get_compatibility(
                                reader_field.type,
                                writer,
                                "type",
                                location + ["fields", str(i)],
                            ),
                        )
                    else:
                        result = merge(
                            result,
                            incompatible(
                                SchemaIncompatibilityType.reader_field_missing_default_value,
                                reader_field.name,
                                location + ["fields", str(i)],
                            ),
                        )
            else:
                result = merge(
                    result,
                    self.get_compatibility(
                        reader_field.type,
                        writer_field.type,
                        "type",
                        location + ["fields", str(i)],
                    ),
                )
        return result


def type_mismatch(reader: Schema, writer: Schema, location: List[str]) -> SchemaCompatibilityResult:
    """Build a ``type_mismatch`` incompatibility naming the reader and writer types."""
    message = f"reader type: {reader.type} not compatible with writer type: {writer.type}"
    return incompatible(SchemaIncompatibilityType.type_mismatch, message, location)


def check_schema_names(reader: NamedSchema, writer: NamedSchema, location: List[str]) -> SchemaCompatibilityResult:
    """Return ``CompatibleResult`` if the names match (see ``schema_name_equals``), else a ``name_mismatch`` at ``.../name``."""
    if schema_name_equals(reader, writer):
        return CompatibleResult
    message = f"expected: {writer.fullname}"
    return incompatible(SchemaIncompatibilityType.name_mismatch, message, location + ["name"])


def check_fixed_size(reader: FixedSchema, writer: FixedSchema, location: List[str]) -> SchemaCompatibilityResult:
    """Return ``CompatibleResult`` if both fixed schemas have the same size, else a ``fixed_size_mismatch`` at ``.../size``."""
    actual = reader.size
    expected = writer.size
    if actual == expected:
        return CompatibleResult
    message = f"expected: {expected}, found: {actual}"
    return incompatible(
        SchemaIncompatibilityType.fixed_size_mismatch,
        message,
        location + ["size"],
    )


def check_reader_enum_contains_writer_enum(reader: EnumSchema, writer: EnumSchema, location: List[str]) -> SchemaCompatibilityResult:
    """
    Check that the reader enum can represent every writer symbol.

    Writer symbols missing from the reader are tolerated only when the reader
    schema has a truthy ``default`` property that is one of the reader's own
    symbols; otherwise a ``missing_enum_symbols`` incompatibility listing the
    missing symbols is returned at ``.../symbols``.
    """
    writer_symbols, reader_symbols = set(writer.symbols), set(reader.symbols)
    extra_symbols = writer_symbols.difference(reader_symbols)
    if not extra_symbols:
        return CompatibleResult
    default = reader.props.get("default")
    if default and default in reader_symbols:
        return CompatibleResult
    return incompatible(
        SchemaIncompatibilityType.missing_enum_symbols,
        str(extra_symbols),
        location + ["symbols"],
    )


def incompatible(incompat_type: SchemaIncompatibilityType, message: str, location: List[str]) -> SchemaCompatibilityResult:
    """
    Build an incompatible result carrying a single incompatibility.

    :param incompat_type: the kind of incompatibility
    :param message: human-readable description
    :param location: path segments; joined with "/" to form the reported location
    :return: SchemaCompatibilityResult
    """
    path = "/".join(location)
    if len(location) > 1:
        # Drop the first character of the joined path: when the first segment is
        # the root token "/", the join yields a doubled leading slash ("//a/b").
        path = path[1:]
    return SchemaCompatibilityResult(
        compatibility=SchemaCompatibilityType.incompatible,
        incompatibilities=[incompat_type],
        locations={path},
        messages={message},
    )


def schema_name_equals(reader: NamedSchema, writer: NamedSchema) -> bool:
    """Return True if the names are equal or the writer's full name is among the reader's ``aliases`` property."""
    aliases = reader.props.get("aliases")
    return (reader.name == writer.name) or (isinstance(aliases, Container) and writer.fullname in aliases)


def lookup_writer_field(writer_schema: RecordSchema, reader_field: Field) -> Optional[Field]:
    """
    Find the writer field that corresponds to ``reader_field``.

    The reader field's name is tried first, then each entry of its ``aliases``
    property in order.

    :param writer_schema: avro.schema.RecordSchema
    :param reader_field: avro.schema.Field
    :return: the matching writer field, or None if there is none
    """
    direct = writer_schema.fields_dict.get(reader_field.name)
    # Explicit None check, consistent with the alias lookup below.
    if direct is not None:
        return cast(Field, direct)
    aliases = reader_field.props.get("aliases")
    if not isinstance(aliases, Iterable):
        return None
    for alias in aliases:
        writer_field = writer_schema.fields_dict.get(alias)
        if writer_field is not None:
            return cast(Field, writer_field)
    return None