#!/usr/bin/env python
"""Protoc Plugin to generate mypy stubs. Loosely based on @zbarsky's go implementation"""
import os

import sys
from collections import defaultdict
from contextlib import contextmanager
from functools import wraps
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    Optional,
    Set,
    Sequence,
    Tuple,
)

import google.protobuf.descriptor_pb2 as d
from google.protobuf.compiler import plugin_pb2 as plugin_pb2
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
from google.protobuf.internal.well_known_types import WKTBASES
from . import extensions_pb2

__version__ = "2.10"

# SourceCodeLocation is defined by `message Location` here
# https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/descriptor.proto
SourceCodeLocation = List[int]

# So phabricator doesn't think mypy_protobuf.py is generated
GENERATED = "@ge" + "nerated"
HEADER = """\"\"\"
{} by mypy-protobuf. Do not edit manually!
isort:skip_file
\"\"\"
""".format(
    GENERATED
)

# See https://github.com/dropbox/mypy-protobuf/issues/73 for details
PYTHON_RESERVED = {
    "False",
    "None",
    "True",
    "and",
    "as",
    "async",
    "await",
    "assert",
    "break",
    "class",
    "continue",
    "def",
    "del",
    "elif",
    "else",
    "except",
    "finally",
    "for",
    "from",
    "global",
    "if",
    "import",
    "in",
    "is",
    "lambda",
    "nonlocal",
    "not",
    "or",
    "pass",
    "raise",
    "return",
    "try",
    "while",
    "with",
    "yield",
}

# Names that collide with methods protobuf generates on enum wrapper types.
PROTO_ENUM_RESERVED = {
    "Name",
    "Value",
    "keys",
    "values",
    "items",
}


def _mangle_global_identifier(name: str) -> str:
    """
    Module level identifiers are mangled and aliased so that they can be disambiguated
    from fields/enum variants with the same name within the file.

    Eg:
    Enum variant `Name` or message field `Name` might conflict with a top level
    message or enum named `Name`, so mangle it with a global___ prefix for
    internal references. Note that this doesn't affect inner enums/messages
    because they get fully qualified when referenced within a file"""
    return "global___{}".format(name)


class Descriptors(object):
    """Index of all messages/enums in the request, keyed by fully-qualified proto name."""

    def __init__(self, request: plugin_pb2.CodeGeneratorRequest) -> None:
        files = {f.name: f for f in request.proto_file}
        to_generate = {n: files[n] for n in request.file_to_generate}
        self.files: Dict[str, d.FileDescriptorProto] = files
        self.to_generate: Dict[str, d.FileDescriptorProto] = to_generate
        self.messages: Dict[str, d.DescriptorProto] = {}
        self.message_to_fd: Dict[str, d.FileDescriptorProto] = {}

        def _add_enums(
            enums: "RepeatedCompositeFieldContainer[d.EnumDescriptorProto]",
            prefix: str,
            _fd: d.FileDescriptorProto,
        ) -> None:
            for enum in enums:
                self.message_to_fd[prefix + enum.name] = _fd
                self.message_to_fd[prefix + enum.name + ".V"] = _fd

        def _add_messages(
            messages: "RepeatedCompositeFieldContainer[d.DescriptorProto]",
            prefix: str,
            _fd: d.FileDescriptorProto,
        ) -> None:
            for message in messages:
                self.messages[prefix + message.name] = message
                self.message_to_fd[prefix + message.name] = _fd
                sub_prefix = prefix + message.name + "."
                _add_messages(message.nested_type, sub_prefix, _fd)
                _add_enums(message.enum_type, sub_prefix, _fd)

        for fd in request.proto_file:
            start_prefix = "." + fd.package + "." if fd.package else "."
            _add_messages(fd.message_type, start_prefix, fd)
            _add_enums(fd.enum_type, start_prefix, fd)


class PkgWriter(object):
    """Writes a single pyi file"""

    def __init__(
        self,
        fd: d.FileDescriptorProto,
        descriptors: Descriptors,
        readable_stubs: bool,
        relax_strict_optional_primitives: bool,
        grpc: bool,
    ) -> None:
        self.fd = fd
        self.descriptors = descriptors
        self.readable_stubs = readable_stubs
        self.relax_strict_optional_primitives = relax_strict_optional_primitives
        self.grpc = grpc
        self.lines: List[str] = []
        self.indent = ""

        # Set of {x}, where {x} corresponds to `import {x}`
        self.imports: Set[str] = set()
        # dictionary of x->(y,z) for `from {x} import {y} as {z}`
        # if {z} is None, then it shortens to `from {x} import {y}`
        self.from_imports: Dict[str, Set[Tuple[str, Optional[str]]]] = defaultdict(set)

        # Comments
        self.source_code_info_by_scl = {
            tuple(location.path): location for location in fd.source_code_info.location
        }

    def _import(self, path: str, name: str) -> str:
        """Imports a stdlib path and returns a handle to it
        eg. self._import("typing", "Optional") -> "Optional"
        """
        imp = path.replace("/", ".")
        if self.readable_stubs:
            self.from_imports[imp].add((name, None))
            return name
        else:
            self.imports.add(imp)
            return imp + "." + name

    def _import_message(self, name: str) -> str:
        """Import a referenced message and return a handle"""
        message_fd = self.descriptors.message_to_fd[name]
        assert message_fd.name.endswith(".proto")

        # Strip off package name
        if message_fd.package:
            assert name.startswith("." + message_fd.package + ".")
            name = name[len("." + message_fd.package + ".") :]
        else:
            assert name.startswith(".")
            name = name[1:]

        # Use prepended "_r_" to disambiguate message names that alias python reserved keywords
        split = name.split(".")
        for i, part in enumerate(split):
            if part in PYTHON_RESERVED:
                split[i] = "_r_" + part
        name = ".".join(split)

        # Message defined in this file. Note: GRPC stubs in same .proto are generated into separate files
        if not self.grpc and message_fd.name == self.fd.name:
            return name if self.readable_stubs else _mangle_global_identifier(name)

        # Not in file. Must import
        # Python generated code ignores proto packages, so the only relevant factor is
        # whether it is in the file or not.
        import_name = self._import(
            message_fd.name[:-6].replace("-", "_") + "_pb2", split[0]
        )

        remains = ".".join(split[1:])
        if not remains:
            return import_name

        # remains could either be a direct import of a nested enum or message
        # from another package.
        return import_name + "." + remains

    def _builtin(self, name: str) -> str:
        return self._import("builtins", name)

    @contextmanager
    def _indent(self) -> Iterator[None]:
        self.indent = self.indent + "    "
        yield
        self.indent = self.indent[:-4]

    def _write_line(self, line: str, *args: Any) -> None:
        line = line.format(*args)
        if line == "":
            self.lines.append(line)
        else:
            self.lines.append(self.indent + line)

    def _break_text(self, text_block: str) -> List[str]:
        if text_block == "":
            return []
        return [
            "{}".format(l[1:] if l.startswith(" ") else l)
            for l in text_block.rstrip().split("\n")
        ]

    def _has_comments(self, scl: SourceCodeLocation) -> bool:
        sci_loc = self.source_code_info_by_scl.get(tuple(scl))
        return sci_loc is not None and bool(
            sci_loc.leading_detached_comments
            or sci_loc.leading_comments
            or sci_loc.trailing_comments
        )

    def _write_comments(self, scl: SourceCodeLocation) -> bool:
        """Return true if any comments were written"""
        if not self._has_comments(scl):
            return False

        sci_loc = self.source_code_info_by_scl.get(tuple(scl))
        assert sci_loc is not None

        lines = []
        for leading_detached_comment in sci_loc.leading_detached_comments:
            lines.extend(self._break_text(leading_detached_comment))
            lines.append("")
        if sci_loc.leading_comments is not None:
            lines.extend(self._break_text(sci_loc.leading_comments))
        # Trailing comments also go in the header - to make sure it gets into the docstring
        if sci_loc.trailing_comments is not None:
            lines.extend(self._break_text(sci_loc.trailing_comments))

        if len(lines) == 1:
            self._write_line('"""{}"""', lines[0])
        else:
            for i, line in enumerate(lines):
                if i == 0:
                    self._write_line('"""{}', line)
                else:
                    self._write_line("{}", line)
            self._write_line('"""')

        return True

    def write_enum_values(
        self,
        values: Iterable[Tuple[int, d.EnumValueDescriptorProto]],
        value_type: str,
        scl_prefix: SourceCodeLocation,
    ) -> None:
        for i, val in values:
            if val.name in PYTHON_RESERVED:
                continue

            scl = scl_prefix + [i]
            self._write_line(
                "{} = {}({})",
                val.name,
                value_type,
                val.number,
            )
            if self._write_comments(scl):
                self._write_line("")  # Extra newline to separate

    def write_module_attributes(self) -> None:
        l = self._write_line
        l(
            "DESCRIPTOR: {} = ...",
            self._import("google.protobuf.descriptor", "FileDescriptor"),
        )
        l("")

    def write_enums(
        self,
        enums: Iterable[d.EnumDescriptorProto],
        prefix: str,
        scl_prefix: SourceCodeLocation,
    ) -> None:
        l = self._write_line
        for i, enum in enumerate(enums):
            class_name = (
                enum.name if enum.name not in PYTHON_RESERVED else "_r_" + enum.name
            )
            value_type_fq = prefix + class_name + ".V"

            l(
                "class {}({}, metaclass={}):",
                class_name,
                "_" + enum.name,
                "_" + enum.name + "EnumTypeWrapper",
            )
            with self._indent():
                scl = scl_prefix + [i]
                self._write_comments(scl)
                l("pass")
            l("class {}:", "_" + enum.name)
            with self._indent():
                l(
                    "V = {}('V', {})",
                    self._import("typing", "NewType"),
                    self._builtin("int"),
                )
            l(
                "class {}({}[{}], {}):",
                "_" + enum.name + "EnumTypeWrapper",
                self._import(
                    "google.protobuf.internal.enum_type_wrapper", "_EnumTypeWrapper"
                ),
                "_" + enum.name + ".V",
                self._builtin("type"),
            )
            with self._indent():
                l(
                    "DESCRIPTOR: {} = ...",
                    self._import("google.protobuf.descriptor", "EnumDescriptor"),
                )
                self.write_enum_values(
                    [
                        (i, v)
                        for i, v in enumerate(enum.value)
                        if v.name not in PROTO_ENUM_RESERVED
                    ],
                    value_type_fq,
                    scl + [d.EnumDescriptorProto.VALUE_FIELD_NUMBER],
                )
            l("")

            self.write_enum_values(
                enumerate(enum.value),
                value_type_fq,
                scl + [d.EnumDescriptorProto.VALUE_FIELD_NUMBER],
            )
            if prefix == "" and not self.readable_stubs:
                l("{} = {}", _mangle_global_identifier(class_name), class_name)
            l("")
            l("")

    def write_messages(
        self,
        messages: Iterable[d.DescriptorProto],
        prefix: str,
        scl_prefix: SourceCodeLocation,
    ) -> None:
        l = self._write_line

        for i, desc in enumerate(messages):
            qualified_name = prefix + desc.name

            # Reproduce some hardcoded logic from the protobuf implementation - where
            # some specific "well_known_types" generated protos to have additional
            # base classes
            addl_base = u""
            if self.fd.package + "." + desc.name in WKTBASES:
                # chop off the .proto - and import the well known type
                # eg `from google.protobuf.duration import Duration`
                well_known_type = WKTBASES[self.fd.package + "." + desc.name]
                addl_base = ", " + self._import(
                    "google.protobuf.internal.well_known_types",
                    well_known_type.__name__,
                )

            class_name = (
                desc.name if desc.name not in PYTHON_RESERVED else "_r_" + desc.name
            )
            message_class = self._import("google.protobuf.message", "Message")
            l("class {}({}{}):", class_name, message_class, addl_base)
            with self._indent():
                scl = scl_prefix + [i]
                self._write_comments(scl)

                l(
                    "DESCRIPTOR: {} = ...",
                    self._import("google.protobuf.descriptor", "Descriptor"),
                )

                # Nested enums/messages
                self.write_enums(
                    desc.enum_type,
                    qualified_name + ".",
                    scl + [d.DescriptorProto.ENUM_TYPE_FIELD_NUMBER],
                )
                self.write_messages(
                    desc.nested_type,
                    qualified_name + ".",
                    scl + [d.DescriptorProto.NESTED_TYPE_FIELD_NUMBER],
                )

                # integer constants for field numbers
                for f in desc.field:
                    l("{}_FIELD_NUMBER: {}", f.name.upper(), self._builtin("int"))

                for idx, field in enumerate(desc.field):
                    if field.name in PYTHON_RESERVED:
                        continue

                    if (
                        is_scalar(field)
                        and field.label != d.FieldDescriptorProto.LABEL_REPEATED
                    ):
                        # Scalar non repeated fields are r/w
                        l("{}: {} = ...", field.name, self.python_type(field))
                        if self._write_comments(
                            scl + [d.DescriptorProto.FIELD_FIELD_NUMBER, idx]
                        ):
                            l("")
                    else:
                        # r/o Getters for non-scalar fields and scalar-repeated fields
                        scl_field = scl + [d.DescriptorProto.FIELD_FIELD_NUMBER, idx]
                        l("@property")
                        l(
                            "def {}(self) -> {}:{}",
                            field.name,
                            self.python_type(field),
                            " ..." if not self._has_comments(scl_field) else "",
                        )
                        if self._has_comments(scl_field):
                            with self._indent():
                                self._write_comments(scl_field)
                                l("pass")

                self.write_extensions(
                    desc.extension, scl + [d.DescriptorProto.EXTENSION_FIELD_NUMBER]
                )

                # Constructor
                self_arg = (
                    "self_" if any(f.name == "self" for f in desc.field) else "self"
                )
                l("def __init__({},", self_arg)
                with self._indent():
                    constructor_fields = [
                        f for f in desc.field if f.name not in PYTHON_RESERVED
                    ]
                    if len(constructor_fields) > 0:
                        # Only positional args allowed
                        # See https://github.com/dropbox/mypy-protobuf/issues/71
                        l("*,")
                    for field in constructor_fields:
                        if (
                            self.fd.syntax == "proto3"
                            and is_scalar(field)
                            and field.label != d.FieldDescriptorProto.LABEL_REPEATED
                            and not self.relax_strict_optional_primitives
                        ):
                            l(
                                "{} : {} = ...,",
                                field.name,
                                self.python_type(field, generic_container=True),
                            )
                        else:
                            l(
                                "{} : {}[{}] = ...,",
                                field.name,
                                self._import("typing", "Optional"),
                                self.python_type(field, generic_container=True),
                            )
                l(") -> None: ...")

                self.write_stringly_typed_fields(desc)

            if prefix == "" and not self.readable_stubs:
                l("{} = {}", _mangle_global_identifier(class_name), class_name)
            l("")

    def write_stringly_typed_fields(self, desc: d.DescriptorProto) -> None:
        """Type the stringly-typed methods as a Union[Literal, Literal ...]"""
        l = self._write_line
        # HasField, ClearField, WhichOneof accepts both bytes/unicode
        # HasField only supports singular. ClearField supports repeated as well
        # In proto3, HasField only supports message fields and optional fields
        # HasField always supports oneof fields
        hf_fields = [
            f.name
            for f in desc.field
            if f.HasField("oneof_index")
            or (
                f.label != d.FieldDescriptorProto.LABEL_REPEATED
                and (
                    self.fd.syntax != "proto3"
                    or f.type == d.FieldDescriptorProto.TYPE_MESSAGE
                    or f.proto3_optional
                )
            )
        ]
        cf_fields = [f.name for f in desc.field]
        wo_fields = {
            oneof.name: [
                f.name
                for f in desc.field
                if f.HasField("oneof_index") and f.oneof_index == idx
            ]
            for idx, oneof in enumerate(desc.oneof_decl)
        }

        hf_fields.extend(wo_fields.keys())
        cf_fields.extend(wo_fields.keys())

        hf_fields_text = ",".join(
            sorted('u"{}",b"{}"'.format(name, name) for name in hf_fields)
        )
        cf_fields_text = ",".join(
            sorted('u"{}",b"{}"'.format(name, name) for name in cf_fields)
        )

        if not hf_fields and not cf_fields and not wo_fields:
            return

        if hf_fields:
            l(
                "def HasField(self, field_name: {}[{}]) -> {}: ...",
                self._import("typing_extensions", "Literal"),
                hf_fields_text,
                self._builtin("bool"),
            )
        if cf_fields:
            l(
                "def ClearField(self, field_name: {}[{}]) -> None: ...",
                self._import("typing_extensions", "Literal"),
                cf_fields_text,
            )

        for wo_field, members in sorted(wo_fields.items()):
            if len(wo_fields) > 1:
                l("@{}", self._import("typing", "overload"))
            l(
                "def WhichOneof(self, oneof_group: {}[{}]) -> {}[{}[{}]]: ...",
                self._import("typing_extensions", "Literal"),
                # Accepts both unicode and bytes in both py2 and py3
                'u"{}",b"{}"'.format(wo_field, wo_field),
                self._import("typing", "Optional"),
                self._import("typing_extensions", "Literal"),
                # Returns `str` in both py2 and py3 (bytes in py2, unicode in py3)
                ",".join('"{}"'.format(m) for m in members),
            )

    def write_extensions(
        self,
        extensions: Sequence[d.FieldDescriptorProto],
        scl_prefix: SourceCodeLocation,
    ) -> None:
        l = self._write_line
        for i, ext in enumerate(extensions):
            scl = scl_prefix + [i]

            l(
                "{}: {}[{}, {}] = ...",
                ext.name,
                self._import(
                    "google.protobuf.internal.extension_dict",
                    "_ExtensionFieldDescriptor",
                ),
                self._import_message(ext.extendee),
                self.python_type(ext),
            )
            self._write_comments(scl)
            l("")

    def write_methods(
        self,
        service: d.ServiceDescriptorProto,
        is_abstract: bool,
        scl_prefix: SourceCodeLocation,
    ) -> None:
        l = self._write_line
        methods = [
            (i, m)
            for i, m in enumerate(service.method)
            if m.name not in PYTHON_RESERVED
        ]
        if not methods:
            l("pass")
        for i, method in methods:
            if is_abstract:
                l("@{}", self._import("abc", "abstractmethod"))
            l("def {}(self,", method.name)
            with self._indent():
                l(
                    "rpc_controller: {},",
                    self._import("google.protobuf.service", "RpcController"),
                )
                l("request: {},", self._import_message(method.input_type))
                l(
                    "done: {}[{}[[{}], None]],",
                    self._import("typing", "Optional"),
                    self._import("typing", "Callable"),
                    self._import_message(method.output_type),
                )

            scl_method = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
            l(
                ") -> {}[{}]:{}",
                self._import("concurrent.futures", "Future"),
                self._import_message(method.output_type),
                " ..." if not self._has_comments(scl_method) else "",
            )
            if self._has_comments(scl_method):
                with self._indent():
                    self._write_comments(scl_method)
                    l("pass")

    def write_services(
        self,
        services: Iterable[d.ServiceDescriptorProto],
        scl_prefix: SourceCodeLocation,
    ) -> None:
        l = self._write_line
        for i, service in enumerate(services):
            scl = scl_prefix + [i]
            class_name = (
                service.name
                if service.name not in PYTHON_RESERVED
                else "_r_" + service.name
            )
            # The service definition interface
            l(
                "class {}({}, metaclass={}):",
                class_name,
                self._import("google.protobuf.service", "Service"),
                self._import("abc", "ABCMeta"),
            )
            with self._indent():
                self._write_comments(scl)
                self.write_methods(service, is_abstract=True, scl_prefix=scl)

            # The stub client
            l("class {}({}):", service.name + "_Stub", class_name)
            with self._indent():
                self._write_comments(scl)
                l(
                    "def __init__(self, rpc_channel: {}) -> None: ...",
                    self._import("google.protobuf.service", "RpcChannel"),
                )
                self.write_methods(service, is_abstract=False, scl_prefix=scl)

    def _import_casttype(self, casttype: str) -> str:
        split = casttype.split(".")
        assert (
            len(split) == 2
        ), "mypy_protobuf.[casttype,keytype,valuetype] is expected to be of format path/to/file.TypeInFile"
        pkg = split[0].replace("/", ".")
        return self._import(pkg, split[1])

    def _map_key_value_types(
        self,
        map_field: d.FieldDescriptorProto,
        key_field: d.FieldDescriptorProto,
        value_field: d.FieldDescriptorProto,
    ) -> Tuple[str, str]:
        key_casttype = map_field.options.Extensions[extensions_pb2.keytype]
        ktype = (
            self._import_casttype(key_casttype)
            if key_casttype
            else self.python_type(key_field)
        )
        value_casttype = map_field.options.Extensions[extensions_pb2.valuetype]
        vtype = (
            self._import_casttype(value_casttype)
            if value_casttype
            else self.python_type(value_field)
        )
        return ktype, vtype

    def _callable_type(self, method: d.MethodDescriptorProto) -> str:
        if method.client_streaming:
            if method.server_streaming:
                return self._import("grpc", "StreamStreamMultiCallable")
            else:
                return self._import("grpc", "StreamUnaryMultiCallable")
        else:
            if method.server_streaming:
                return self._import("grpc", "UnaryStreamMultiCallable")
            else:
                return self._import("grpc", "UnaryUnaryMultiCallable")

    def _input_type(
        self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True
    ) -> str:
        result = self._import_message(method.input_type)
        if use_stream_iterator and method.client_streaming:
            result = "{}[{}]".format(self._import("typing", "Iterator"), result)
        return result

    def _output_type(
        self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True
    ) -> str:
        result = self._import_message(method.output_type)
        if use_stream_iterator and method.server_streaming:
            result = "{}[{}]".format(self._import("typing", "Iterator"), result)
        return result

    def write_grpc_methods(
        self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation
    ) -> None:
        l = self._write_line
        methods = [
            (i, m)
            for i, m in enumerate(service.method)
            if m.name not in PYTHON_RESERVED
        ]
        if not methods:
            l("pass")
            l("")
        for i, method in methods:
            scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]

            l("@{}", self._import("abc", "abstractmethod"))
            l("def {}(self,", method.name)
            with self._indent():
                l("request: {},", self._input_type(method))
                l("context: {},", self._import("grpc", "ServicerContext"))
            l(
                ") -> {}:{}",
                self._output_type(method),
                " ..." if not self._has_comments(scl) else "",
            )
            if self._has_comments(scl):
                with self._indent():
                    self._write_comments(scl)
                    l("pass")
            l("")

    def write_grpc_stub_methods(
        self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation
    ) -> None:
        l = self._write_line
        methods = [
            (i, m)
            for i, m in enumerate(service.method)
            if m.name not in PYTHON_RESERVED
        ]
        if not methods:
            l("pass")
            l("")
        for i, method in methods:
            scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]

            l("{}: {}[", method.name, self._callable_type(method))
            with self._indent():
                l("{},", self._input_type(method, False))
                l("{}] = ...", self._output_type(method, False))
            self._write_comments(scl)
            l("")

    def write_grpc_services(
        self,
        services: Iterable[d.ServiceDescriptorProto],
        scl_prefix: SourceCodeLocation,
    ) -> None:
        l = self._write_line
        for i, service in enumerate(services):
            if service.name in PYTHON_RESERVED:
                continue

            scl = scl_prefix + [i]

            # The stub client
            l("class {}Stub:", service.name)
            with self._indent():
                self._write_comments(scl)
                l(
                    "def __init__(self, channel: {}) -> None: ...",
                    self._import("grpc", "Channel"),
                )
                self.write_grpc_stub_methods(service, scl)
            l("")

            # The service definition interface
            l(
                "class {}Servicer(metaclass={}):",
                service.name,
                self._import("abc", "ABCMeta"),
            )
            with self._indent():
                self._write_comments(scl)
                self.write_grpc_methods(service, scl)
            l("")
            l(
                "def add_{}Servicer_to_server(servicer: {}Servicer, server: {}) -> None: ...",
                service.name,
                service.name,
                self._import("grpc", "Server"),
            )
            l("")

    def python_type(
        self, field: d.FieldDescriptorProto, generic_container: bool = False
    ) -> str:
        """
        generic_container
          if set, type the field with generic interfaces. Eg.
          - Iterable[int] rather than RepeatedScalarFieldContainer[int]
          - Mapping[k, v] rather than MessageMap[k, v]
          Can be useful for input types (eg constructor)
        """
        casttype = field.options.Extensions[extensions_pb2.casttype]
        if casttype:
            return self._import_casttype(casttype)

        mapping: Dict[d.FieldDescriptorProto.Type.V, Callable[[], str]] = {
            d.FieldDescriptorProto.TYPE_DOUBLE: lambda: self._builtin("float"),
            d.FieldDescriptorProto.TYPE_FLOAT: lambda: self._builtin("float"),
            d.FieldDescriptorProto.TYPE_INT64: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_UINT64: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_FIXED64: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_SFIXED64: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_SINT64: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_INT32: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_UINT32: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_FIXED32: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_SFIXED32: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_SINT32: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_BOOL: lambda: self._builtin("bool"),
            d.FieldDescriptorProto.TYPE_STRING: lambda: self._import("typing", "Text"),
            d.FieldDescriptorProto.TYPE_BYTES: lambda: self._builtin("bytes"),
            d.FieldDescriptorProto.TYPE_ENUM: lambda: self._import_message(
                field.type_name + ".V"
            ),
            d.FieldDescriptorProto.TYPE_MESSAGE: lambda: self._import_message(
                field.type_name
            ),
            d.FieldDescriptorProto.TYPE_GROUP: lambda: self._import_message(
                field.type_name
            ),
        }

        assert field.type in mapping, "Unrecognized type: " + repr(field.type)
        field_type = mapping[field.type]()

        # For non-repeated fields, we're done!
        if field.label != d.FieldDescriptorProto.LABEL_REPEATED:
            return field_type

        # Scalar repeated fields go in RepeatedScalarFieldContainer
        if is_scalar(field):
            container = (
                self._import("typing", "Iterable")
                if generic_container
                else self._import(
                    "google.protobuf.internal.containers",
                    "RepeatedScalarFieldContainer",
                )
            )
            return "{}[{}]".format(container, field_type)

        # non-scalar repeated map fields go in ScalarMap/MessageMap
        msg = self.descriptors.messages[field.type_name]
        if msg.options.map_entry:
            # map generates a special Entry wrapper message
            if generic_container:
                container = self._import("typing", "Mapping")
            elif is_scalar(msg.field[1]):
                container = self._import(
                    "google.protobuf.internal.containers", "ScalarMap"
                )
            else:
                container = self._import(
                    "google.protobuf.internal.containers", "MessageMap"
                )
            ktype, vtype = self._map_key_value_types(field, msg.field[0], msg.field[1])
            return "{}[{}, {}]".format(container, ktype, vtype)

        # non-scalar repeated fields go in RepeatedCompositeFieldContainer
        container = (
            self._import("typing", "Iterable")
            if generic_container
            else self._import(
                "google.protobuf.internal.containers",
                "RepeatedCompositeFieldContainer",
            )
        )
        return "{}[{}]".format(container, field_type)

    def write(self) -> str:
        """Render the accumulated stub: re-exports, then import lines, then body."""
        for reexport_idx in self.fd.public_dependency:
            reexport_file = self.fd.dependency[reexport_idx]
            reexport_fd = self.descriptors.files[reexport_file]
            reexport_imp = (
                reexport_file[:-6].replace("-", "_").replace("/", ".") + "_pb2"
            )
            names = (
                [m.name for m in reexport_fd.message_type]
                + [m.name for m in reexport_fd.enum_type]
                + [v.name for m in reexport_fd.enum_type for v in m.value]
                + [m.name for m in reexport_fd.extension]
            )
            if reexport_fd.options.py_generic_services:
                names.extend(m.name for m in reexport_fd.service)

            if names:
                # n,n to force a reexport (from x import y as y)
                self.from_imports[reexport_imp].update((n, n) for n in names)

        import_lines = []
        for pkg in sorted(self.imports):
            import_lines.append(u"import {}".format(pkg))

        for pkg, items in sorted(self.from_imports.items()):
            import_lines.append(u"from {} import (".format(pkg))
            for (name, reexport_name) in sorted(items):
                if reexport_name is None:
                    import_lines.append(u"    {},".format(name))
                else:
                    import_lines.append(u"    {} as {},".format(name, reexport_name))
            import_lines.append(u")\n")
        import_lines.append("")

        return "\n".join(import_lines + self.lines)


def is_scalar(fd: d.FieldDescriptorProto) -> bool:
    """True for every field type except message and group (i.e. non-composite)."""
    return not (
        fd.type == d.FieldDescriptorProto.TYPE_MESSAGE
        or fd.type == d.FieldDescriptorProto.TYPE_GROUP
    )


def generate_mypy_stubs(
    descriptors: Descriptors,
    response: plugin_pb2.CodeGeneratorResponse,
    quiet: bool,
    readable_stubs: bool,
    relax_strict_optional_primitives: bool,
) -> None:
    """Emit one <name>_pb2.pyi file into `response` per file to generate."""
    for name, fd in descriptors.to_generate.items():
        pkg_writer = PkgWriter(
            fd,
            descriptors,
            readable_stubs,
            relax_strict_optional_primitives,
            grpc=False,
        )

        pkg_writer.write_module_attributes()
        pkg_writer.write_enums(
            fd.enum_type, "", [d.FileDescriptorProto.ENUM_TYPE_FIELD_NUMBER]
        )
        pkg_writer.write_messages(
            fd.message_type, "", [d.FileDescriptorProto.MESSAGE_TYPE_FIELD_NUMBER]
        )
        pkg_writer.write_extensions(
            fd.extension, [d.FileDescriptorProto.EXTENSION_FIELD_NUMBER]
        )
        if fd.options.py_generic_services:
            pkg_writer.write_services(
                fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER]
            )

        assert name == fd.name
        assert fd.name.endswith(".proto")
        output = response.file.add()
        output.name = fd.name[:-6].replace("-", "_").replace(".", "/") + "_pb2.pyi"
        output.content = HEADER + pkg_writer.write()
        if not quiet:
            print("Writing mypy to", output.name, file=sys.stderr)


def generate_mypy_grpc_stubs(
    descriptors: Descriptors,
    response: plugin_pb2.CodeGeneratorResponse,
    quiet: bool,
    readable_stubs: bool,
    relax_strict_optional_primitives: bool,
) -> None:
    """Emit one <name>_pb2_grpc.pyi file into `response` per file to generate."""
    for name, fd in descriptors.to_generate.items():
        pkg_writer = PkgWriter(
            fd,
            descriptors,
            readable_stubs,
            relax_strict_optional_primitives,
            grpc=True,
        )
        pkg_writer.write_grpc_services(
            fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER]
        )

        assert name == fd.name
        assert fd.name.endswith(".proto")
        output = response.file.add()
        output.name = fd.name[:-6].replace("-", "_").replace(".", "/") + "_pb2_grpc.pyi"
        output.content = HEADER + pkg_writer.write()
        if not quiet:
            print("Writing mypy to", output.name, file=sys.stderr)


@contextmanager
def code_generation() -> Iterator[
    Tuple[plugin_pb2.CodeGeneratorRequest, plugin_pb2.CodeGeneratorResponse],
]:
    """protoc plugin harness: parse the request from stdin, yield it with a fresh
    response, and serialize the response to stdout on exit."""
    if len(sys.argv) > 1 and sys.argv[1] in ("-V", "--version"):
        print("mypy-protobuf " + __version__)
        sys.exit(0)

    # Read request message from stdin
    data = sys.stdin.buffer.read()

    # Parse request
    request = plugin_pb2.CodeGeneratorRequest()
    request.ParseFromString(data)

    # Create response
    response = plugin_pb2.CodeGeneratorResponse()

    # Declare support for optional proto3 fields
    response.supported_features |= (
        plugin_pb2.CodeGeneratorResponse.FEATURE_PROTO3_OPTIONAL
    )

    yield request, response

    # Serialise response message
    output = response.SerializeToString()

    # Write to stdout
    sys.stdout.buffer.write(output)


def main() -> None:
    # Generate mypy
    with code_generation() as (request, response):
        generate_mypy_stubs(
            Descriptors(request),
            response,
            "quiet" in request.parameter,
            "readable_stubs" in request.parameter,
            "relax_strict_optional_primitives" in request.parameter,
        )


def grpc() -> None:
    # Generate grpc mypy
    with code_generation() as (request, response):
        generate_mypy_grpc_stubs(
            Descriptors(request),
            response,
            "quiet" in request.parameter,
            "readable_stubs" in request.parameter,
            "relax_strict_optional_primitives" in request.parameter,
        )


if __name__ == "__main__":
    main()