1#!/usr/bin/env python
2"""Protoc Plugin to generate mypy stubs. Loosely based on @zbarsky's go implementation"""
3import os
4
5import sys
6from collections import defaultdict
7from contextlib import contextmanager
8from functools import wraps
9from typing import (
10    Any,
11    Callable,
12    Dict,
13    Iterable,
14    Iterator,
15    List,
16    Optional,
17    Set,
18    Sequence,
19    Tuple,
20)
21
22import google.protobuf.descriptor_pb2 as d
23from google.protobuf.compiler import plugin_pb2 as plugin_pb2
24from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
25from google.protobuf.internal.well_known_types import WKTBASES
26from . import extensions_pb2
27
__version__ = "2.10"

# A SourceCodeLocation is the `path` of a `message Location` entry, as defined in
# https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/descriptor.proto
SourceCodeLocation = List[int]

# The marker is assembled from two pieces so that Phabricator does not treat
# this plugin's own source (mypy_protobuf.py) as a generated file.
GENERATED = "@ge" + "nerated"
# Module docstring written at the top of every generated stub.
HEADER = (
    '"""\n'
    + GENERATED
    + " by mypy-protobuf.  Do not edit manually!\n"
    + "isort:skip_file\n"
    + '"""\n'
)
43
# Python keywords. Protos may use these as message/enum/field/service names,
# but they cannot appear verbatim as Python identifiers in the stubs.
# See https://github.com/dropbox/mypy-protobuf/issues/73 for details
PYTHON_RESERVED = {
    # keyword constants
    "False", "None", "True",
    # boolean, membership and identity operators
    "and", "in", "is", "not", "or",
    # control flow
    "break", "continue", "elif", "else", "for",
    "if", "pass", "return", "while", "yield",
    # exceptions
    "assert", "except", "finally", "raise", "try",
    # definitions and scoping
    "class", "def", "del", "global", "lambda", "nonlocal",
    # imports
    "as", "from", "import",
    # context managers and coroutines
    "async", "await", "with",
}

# Names of attributes on the generated enum type wrapper class; enum values with
# these names are not re-declared on the wrapper (see write_enums).
PROTO_ENUM_RESERVED = {"Name", "Value", "keys", "values", "items"}
90
91
92def _mangle_global_identifier(name: str) -> str:
93    """
94    Module level identifiers are mangled and aliased so that they can be disambiguated
95    from fields/enum variants with the same name within the file.
96
97    Eg:
98    Enum variant `Name` or message field `Name` might conflict with a top level
99    message or enum named `Name`, so mangle it with a global___ prefix for
100    internal references. Note that this doesn't affect inner enums/messages
101    because they get fuly qualified when referenced within a file"""
102    return "global___{}".format(name)
103
104
class Descriptors(object):
    """Lookup tables built from every proto file in a CodeGeneratorRequest.

    Attributes:
        files: file name -> FileDescriptorProto, for every file in the request.
        to_generate: the subset of `files` named in `file_to_generate`.
        messages: fully qualified message name (leading ".", e.g.
            ".pkg.Outer.Inner") -> DescriptorProto, nested messages included.
        message_to_fd: fully qualified message or enum name -> the file that
            defines it. Enums are also registered under "<name>.V", the name
            used for their value type.
    """

    def __init__(self, request: plugin_pb2.CodeGeneratorRequest) -> None:
        self.files: Dict[str, d.FileDescriptorProto] = {
            proto.name: proto for proto in request.proto_file
        }
        self.to_generate: Dict[str, d.FileDescriptorProto] = {
            wanted: self.files[wanted] for wanted in request.file_to_generate
        }
        self.messages: Dict[str, d.DescriptorProto] = {}
        self.message_to_fd: Dict[str, d.FileDescriptorProto] = {}

        def register_enums(
            enums: "RepeatedCompositeFieldContainer[d.EnumDescriptorProto]",
            scope: str,
            owner: d.FileDescriptorProto,
        ) -> None:
            for enum in enums:
                full_name = scope + enum.name
                self.message_to_fd[full_name] = owner
                # The enum's value type is referenced as "<enum>.V"
                self.message_to_fd[full_name + ".V"] = owner

        def register_messages(
            messages: "RepeatedCompositeFieldContainer[d.DescriptorProto]",
            scope: str,
            owner: d.FileDescriptorProto,
        ) -> None:
            for message in messages:
                full_name = scope + message.name
                self.messages[full_name] = message
                self.message_to_fd[full_name] = owner
                # Nested types live in the scope of their parent message
                inner_scope = full_name + "."
                register_messages(message.nested_type, inner_scope, owner)
                register_enums(message.enum_type, inner_scope, owner)

        for proto in request.proto_file:
            root = "." + proto.package + "." if proto.package else "."
            register_messages(proto.message_type, root, proto)
            register_enums(proto.enum_type, root, proto)
139
140
141class PkgWriter(object):
142    """Writes a single pyi file"""
143
144    def __init__(
145        self,
146        fd: d.FileDescriptorProto,
147        descriptors: Descriptors,
148        readable_stubs: bool,
149        relax_strict_optional_primitives: bool,
150        grpc: bool,
151    ) -> None:
152        self.fd = fd
153        self.descriptors = descriptors
154        self.readable_stubs = readable_stubs
155        self.relax_strict_optional_primitives = relax_strict_optional_primitives
156        self.grpc = grpc
157        self.lines: List[str] = []
158        self.indent = ""
159
160        # Set of {x}, where {x} corresponds to to `import {x}`
161        self.imports: Set[str] = set()
162        # dictionary of x->(y,z) for `from {x} import {y} as {z}`
163        # if {z} is None, then it shortens to `from {x} import {y}`
164        self.from_imports: Dict[str, Set[Tuple[str, Optional[str]]]] = defaultdict(set)
165
166        # Comments
167        self.source_code_info_by_scl = {
168            tuple(location.path): location for location in fd.source_code_info.location
169        }
170
171    def _import(self, path: str, name: str) -> str:
172        """Imports a stdlib path and returns a handle to it
173        eg. self._import("typing", "Optional") -> "Optional"
174        """
175        imp = path.replace("/", ".")
176        if self.readable_stubs:
177            self.from_imports[imp].add((name, None))
178            return name
179        else:
180            self.imports.add(imp)
181            return imp + "." + name
182
183    def _import_message(self, name: str) -> str:
184        """Import a referenced message and return a handle"""
185        message_fd = self.descriptors.message_to_fd[name]
186        assert message_fd.name.endswith(".proto")
187
188        # Strip off package name
189        if message_fd.package:
190            assert name.startswith("." + message_fd.package + ".")
191            name = name[len("." + message_fd.package + ".") :]
192        else:
193            assert name.startswith(".")
194            name = name[1:]
195
196        # Use prepended "_r_" to disambiguate message names that alias python reserved keywords
197        split = name.split(".")
198        for i, part in enumerate(split):
199            if part in PYTHON_RESERVED:
200                split[i] = "_r_" + part
201        name = ".".join(split)
202
203        # Message defined in this file. Note: GRPC stubs in same .proto are generated into separate files
204        if not self.grpc and message_fd.name == self.fd.name:
205            return name if self.readable_stubs else _mangle_global_identifier(name)
206
207        # Not in file. Must import
208        # Python generated code ignores proto packages, so the only relevant factor is
209        # whether it is in the file or not.
210        import_name = self._import(
211            message_fd.name[:-6].replace("-", "_") + "_pb2", split[0]
212        )
213
214        remains = ".".join(split[1:])
215        if not remains:
216            return import_name
217
218        # remains could either be a direct import of a nested enum or message
219        # from another package.
220        return import_name + "." + remains
221
222    def _builtin(self, name: str) -> str:
223        return self._import("builtins", name)
224
225    @contextmanager
226    def _indent(self) -> Iterator[None]:
227        self.indent = self.indent + "    "
228        yield
229        self.indent = self.indent[:-4]
230
231    def _write_line(self, line: str, *args: Any) -> None:
232        line = line.format(*args)
233        if line == "":
234            self.lines.append(line)
235        else:
236            self.lines.append(self.indent + line)
237
238    def _break_text(self, text_block: str) -> List[str]:
239        if text_block == "":
240            return []
241        return [
242            "{}".format(l[1:] if l.startswith(" ") else l)
243            for l in text_block.rstrip().split("\n")
244        ]
245
246    def _has_comments(self, scl: SourceCodeLocation) -> bool:
247        sci_loc = self.source_code_info_by_scl.get(tuple(scl))
248        return sci_loc is not None and bool(
249            sci_loc.leading_detached_comments
250            or sci_loc.leading_comments
251            or sci_loc.trailing_comments
252        )
253
254    def _write_comments(self, scl: SourceCodeLocation) -> bool:
255        """Return true if any comments were written"""
256        if not self._has_comments(scl):
257            return False
258
259        sci_loc = self.source_code_info_by_scl.get(tuple(scl))
260        assert sci_loc is not None
261
262        lines = []
263        for leading_detached_comment in sci_loc.leading_detached_comments:
264            lines.extend(self._break_text(leading_detached_comment))
265            lines.append("")
266        if sci_loc.leading_comments is not None:
267            lines.extend(self._break_text(sci_loc.leading_comments))
268        # Trailing comments also go in the header - to make sure it gets into the docstring
269        if sci_loc.trailing_comments is not None:
270            lines.extend(self._break_text(sci_loc.trailing_comments))
271
272        if len(lines) == 1:
273            self._write_line('"""{}"""', lines[0])
274        else:
275            for i, line in enumerate(lines):
276                if i == 0:
277                    self._write_line('"""{}', line)
278                else:
279                    self._write_line("{}", line)
280            self._write_line('"""')
281
282        return True
283
284    def write_enum_values(
285        self,
286        values: Iterable[Tuple[int, d.EnumValueDescriptorProto]],
287        value_type: str,
288        scl_prefix: SourceCodeLocation,
289    ) -> None:
290        for i, val in values:
291            if val.name in PYTHON_RESERVED:
292                continue
293
294            scl = scl_prefix + [i]
295            self._write_line(
296                "{} = {}({})",
297                val.name,
298                value_type,
299                val.number,
300            )
301            if self._write_comments(scl):
302                self._write_line("")  # Extra newline to separate
303
304    def write_module_attributes(self) -> None:
305        l = self._write_line
306        l(
307            "DESCRIPTOR: {} = ...",
308            self._import("google.protobuf.descriptor", "FileDescriptor"),
309        )
310        l("")
311
312    def write_enums(
313        self,
314        enums: Iterable[d.EnumDescriptorProto],
315        prefix: str,
316        scl_prefix: SourceCodeLocation,
317    ) -> None:
318        l = self._write_line
319        for i, enum in enumerate(enums):
320            class_name = (
321                enum.name if enum.name not in PYTHON_RESERVED else "_r_" + enum.name
322            )
323            value_type_fq = prefix + class_name + ".V"
324
325            l(
326                "class {}({}, metaclass={}):",
327                class_name,
328                "_" + enum.name,
329                "_" + enum.name + "EnumTypeWrapper",
330            )
331            with self._indent():
332                scl = scl_prefix + [i]
333                self._write_comments(scl)
334                l("pass")
335            l("class {}:", "_" + enum.name)
336            with self._indent():
337                l(
338                    "V = {}('V', {})",
339                    self._import("typing", "NewType"),
340                    self._builtin("int"),
341                )
342            l(
343                "class {}({}[{}], {}):",
344                "_" + enum.name + "EnumTypeWrapper",
345                self._import(
346                    "google.protobuf.internal.enum_type_wrapper", "_EnumTypeWrapper"
347                ),
348                "_" + enum.name + ".V",
349                self._builtin("type"),
350            )
351            with self._indent():
352                l(
353                    "DESCRIPTOR: {} = ...",
354                    self._import("google.protobuf.descriptor", "EnumDescriptor"),
355                )
356                self.write_enum_values(
357                    [
358                        (i, v)
359                        for i, v in enumerate(enum.value)
360                        if v.name not in PROTO_ENUM_RESERVED
361                    ],
362                    value_type_fq,
363                    scl + [d.EnumDescriptorProto.VALUE_FIELD_NUMBER],
364                )
365            l("")
366
367            self.write_enum_values(
368                enumerate(enum.value),
369                value_type_fq,
370                scl + [d.EnumDescriptorProto.VALUE_FIELD_NUMBER],
371            )
372            if prefix == "" and not self.readable_stubs:
373                l("{} = {}", _mangle_global_identifier(class_name), class_name)
374                l("")
375            l("")
376
377    def write_messages(
378        self,
379        messages: Iterable[d.DescriptorProto],
380        prefix: str,
381        scl_prefix: SourceCodeLocation,
382    ) -> None:
383        l = self._write_line
384
385        for i, desc in enumerate(messages):
386            qualified_name = prefix + desc.name
387
388            # Reproduce some hardcoded logic from the protobuf implementation - where
389            # some specific "well_known_types" generated protos to have additional
390            # base classes
391            addl_base = u""
392            if self.fd.package + "." + desc.name in WKTBASES:
393                # chop off the .proto - and import the well known type
394                # eg `from google.protobuf.duration import Duration`
395                well_known_type = WKTBASES[self.fd.package + "." + desc.name]
396                addl_base = ", " + self._import(
397                    "google.protobuf.internal.well_known_types",
398                    well_known_type.__name__,
399                )
400
401            class_name = (
402                desc.name if desc.name not in PYTHON_RESERVED else "_r_" + desc.name
403            )
404            message_class = self._import("google.protobuf.message", "Message")
405            l("class {}({}{}):", class_name, message_class, addl_base)
406            with self._indent():
407                scl = scl_prefix + [i]
408                self._write_comments(scl)
409
410                l(
411                    "DESCRIPTOR: {} = ...",
412                    self._import("google.protobuf.descriptor", "Descriptor"),
413                )
414
415                # Nested enums/messages
416                self.write_enums(
417                    desc.enum_type,
418                    qualified_name + ".",
419                    scl + [d.DescriptorProto.ENUM_TYPE_FIELD_NUMBER],
420                )
421                self.write_messages(
422                    desc.nested_type,
423                    qualified_name + ".",
424                    scl + [d.DescriptorProto.NESTED_TYPE_FIELD_NUMBER],
425                )
426
427                # integer constants  for field numbers
428                for f in desc.field:
429                    l("{}_FIELD_NUMBER: {}", f.name.upper(), self._builtin("int"))
430
431                for idx, field in enumerate(desc.field):
432                    if field.name in PYTHON_RESERVED:
433                        continue
434
435                    if (
436                        is_scalar(field)
437                        and field.label != d.FieldDescriptorProto.LABEL_REPEATED
438                    ):
439                        # Scalar non repeated fields are r/w
440                        l("{}: {} = ...", field.name, self.python_type(field))
441                        if self._write_comments(
442                            scl + [d.DescriptorProto.FIELD_FIELD_NUMBER, idx]
443                        ):
444                            l("")
445                    else:
446                        # r/o Getters for non-scalar fields and scalar-repeated fields
447                        scl_field = scl + [d.DescriptorProto.FIELD_FIELD_NUMBER, idx]
448                        l("@property")
449                        l(
450                            "def {}(self) -> {}:{}",
451                            field.name,
452                            self.python_type(field),
453                            " ..." if not self._has_comments(scl_field) else "",
454                        )
455                        if self._has_comments(scl_field):
456                            with self._indent():
457                                self._write_comments(scl_field)
458                                l("pass")
459
460                self.write_extensions(
461                    desc.extension, scl + [d.DescriptorProto.EXTENSION_FIELD_NUMBER]
462                )
463
464                # Constructor
465                self_arg = (
466                    "self_" if any(f.name == "self" for f in desc.field) else "self"
467                )
468                l("def __init__({},", self_arg)
469                with self._indent():
470                    constructor_fields = [
471                        f for f in desc.field if f.name not in PYTHON_RESERVED
472                    ]
473                    if len(constructor_fields) > 0:
474                        # Only positional args allowed
475                        # See https://github.com/dropbox/mypy-protobuf/issues/71
476                        l("*,")
477                    for field in constructor_fields:
478                        if (
479                            self.fd.syntax == "proto3"
480                            and is_scalar(field)
481                            and field.label != d.FieldDescriptorProto.LABEL_REPEATED
482                            and not self.relax_strict_optional_primitives
483                        ):
484                            l(
485                                "{} : {} = ...,",
486                                field.name,
487                                self.python_type(field, generic_container=True),
488                            )
489                        else:
490                            l(
491                                "{} : {}[{}] = ...,",
492                                field.name,
493                                self._import("typing", "Optional"),
494                                self.python_type(field, generic_container=True),
495                            )
496                    l(") -> None: ...")
497
498                self.write_stringly_typed_fields(desc)
499
500            if prefix == "" and not self.readable_stubs:
501                l("{} = {}", _mangle_global_identifier(class_name), class_name)
502            l("")
503
504    def write_stringly_typed_fields(self, desc: d.DescriptorProto) -> None:
505        """Type the stringly-typed methods as a Union[Literal, Literal ...]"""
506        l = self._write_line
507        # HasField, ClearField, WhichOneof accepts both bytes/unicode
508        # HasField only supports singular. ClearField supports repeated as well
509        # In proto3, HasField only supports message fields and optional fields
510        # HasField always supports oneof fields
511        hf_fields = [
512            f.name
513            for f in desc.field
514            if f.HasField("oneof_index")
515            or (
516                f.label != d.FieldDescriptorProto.LABEL_REPEATED
517                and (
518                    self.fd.syntax != "proto3"
519                    or f.type == d.FieldDescriptorProto.TYPE_MESSAGE
520                    or f.proto3_optional
521                )
522            )
523        ]
524        cf_fields = [f.name for f in desc.field]
525        wo_fields = {
526            oneof.name: [
527                f.name
528                for f in desc.field
529                if f.HasField("oneof_index") and f.oneof_index == idx
530            ]
531            for idx, oneof in enumerate(desc.oneof_decl)
532        }
533
534        hf_fields.extend(wo_fields.keys())
535        cf_fields.extend(wo_fields.keys())
536
537        hf_fields_text = ",".join(
538            sorted('u"{}",b"{}"'.format(name, name) for name in hf_fields)
539        )
540        cf_fields_text = ",".join(
541            sorted('u"{}",b"{}"'.format(name, name) for name in cf_fields)
542        )
543
544        if not hf_fields and not cf_fields and not wo_fields:
545            return
546
547        if hf_fields:
548            l(
549                "def HasField(self, field_name: {}[{}]) -> {}: ...",
550                self._import("typing_extensions", "Literal"),
551                hf_fields_text,
552                self._builtin("bool"),
553            )
554        if cf_fields:
555            l(
556                "def ClearField(self, field_name: {}[{}]) -> None: ...",
557                self._import("typing_extensions", "Literal"),
558                cf_fields_text,
559            )
560
561        for wo_field, members in sorted(wo_fields.items()):
562            if len(wo_fields) > 1:
563                l("@{}", self._import("typing", "overload"))
564            l(
565                "def WhichOneof(self, oneof_group: {}[{}]) -> {}[{}[{}]]: ...",
566                self._import("typing_extensions", "Literal"),
567                # Accepts both unicode and bytes in both py2 and py3
568                'u"{}",b"{}"'.format(wo_field, wo_field),
569                self._import("typing", "Optional"),
570                self._import("typing_extensions", "Literal"),
571                # Returns `str` in both py2 and py3 (bytes in py2, unicode in py3)
572                ",".join('"{}"'.format(m) for m in members),
573            )
574
575    def write_extensions(
576        self,
577        extensions: Sequence[d.FieldDescriptorProto],
578        scl_prefix: SourceCodeLocation,
579    ) -> None:
580        l = self._write_line
581        for i, ext in enumerate(extensions):
582            scl = scl_prefix + [i]
583
584            l(
585                "{}: {}[{}, {}] = ...",
586                ext.name,
587                self._import(
588                    "google.protobuf.internal.extension_dict",
589                    "_ExtensionFieldDescriptor",
590                ),
591                self._import_message(ext.extendee),
592                self.python_type(ext),
593            )
594            self._write_comments(scl)
595            l("")
596
597    def write_methods(
598        self,
599        service: d.ServiceDescriptorProto,
600        is_abstract: bool,
601        scl_prefix: SourceCodeLocation,
602    ) -> None:
603        l = self._write_line
604        methods = [
605            (i, m)
606            for i, m in enumerate(service.method)
607            if m.name not in PYTHON_RESERVED
608        ]
609        if not methods:
610            l("pass")
611        for i, method in methods:
612            if is_abstract:
613                l("@{}", self._import("abc", "abstractmethod"))
614            l("def {}(self,", method.name)
615            with self._indent():
616                l(
617                    "rpc_controller: {},",
618                    self._import("google.protobuf.service", "RpcController"),
619                )
620                l("request: {},", self._import_message(method.input_type))
621                l(
622                    "done: {}[{}[[{}], None]],",
623                    self._import("typing", "Optional"),
624                    self._import("typing", "Callable"),
625                    self._import_message(method.output_type),
626                )
627
628            scl_method = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
629            l(
630                ") -> {}[{}]:{}",
631                self._import("concurrent.futures", "Future"),
632                self._import_message(method.output_type),
633                " ..." if not self._has_comments(scl_method) else "",
634            )
635            if self._has_comments(scl_method):
636                with self._indent():
637                    self._write_comments(scl_method)
638                    l("pass")
639
640    def write_services(
641        self,
642        services: Iterable[d.ServiceDescriptorProto],
643        scl_prefix: SourceCodeLocation,
644    ) -> None:
645        l = self._write_line
646        for i, service in enumerate(services):
647            scl = scl_prefix + [i]
648            class_name = (
649                service.name
650                if service.name not in PYTHON_RESERVED
651                else "_r_" + service.name
652            )
653            # The service definition interface
654            l(
655                "class {}({}, metaclass={}):",
656                class_name,
657                self._import("google.protobuf.service", "Service"),
658                self._import("abc", "ABCMeta"),
659            )
660            with self._indent():
661                self._write_comments(scl)
662                self.write_methods(service, is_abstract=True, scl_prefix=scl)
663
664            # The stub client
665            l("class {}({}):", service.name + "_Stub", class_name)
666            with self._indent():
667                self._write_comments(scl)
668                l(
669                    "def __init__(self, rpc_channel: {}) -> None: ...",
670                    self._import("google.protobuf.service", "RpcChannel"),
671                )
672                self.write_methods(service, is_abstract=False, scl_prefix=scl)
673
674    def _import_casttype(self, casttype: str) -> str:
675        split = casttype.split(".")
676        assert (
677            len(split) == 2
678        ), "mypy_protobuf.[casttype,keytype,valuetype] is expected to be of format path/to/file.TypeInFile"
679        pkg = split[0].replace("/", ".")
680        return self._import(pkg, split[1])
681
682    def _map_key_value_types(
683        self,
684        map_field: d.FieldDescriptorProto,
685        key_field: d.FieldDescriptorProto,
686        value_field: d.FieldDescriptorProto,
687    ) -> Tuple[str, str]:
688        key_casttype = map_field.options.Extensions[extensions_pb2.keytype]
689        ktype = (
690            self._import_casttype(key_casttype)
691            if key_casttype
692            else self.python_type(key_field)
693        )
694        value_casttype = map_field.options.Extensions[extensions_pb2.valuetype]
695        vtype = (
696            self._import_casttype(value_casttype)
697            if value_casttype
698            else self.python_type(value_field)
699        )
700        return ktype, vtype
701
702    def _callable_type(self, method: d.MethodDescriptorProto) -> str:
703        if method.client_streaming:
704            if method.server_streaming:
705                return self._import("grpc", "StreamStreamMultiCallable")
706            else:
707                return self._import("grpc", "StreamUnaryMultiCallable")
708        else:
709            if method.server_streaming:
710                return self._import("grpc", "UnaryStreamMultiCallable")
711            else:
712                return self._import("grpc", "UnaryUnaryMultiCallable")
713
714    def _input_type(
715        self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True
716    ) -> str:
717        result = self._import_message(method.input_type)
718        if use_stream_iterator and method.client_streaming:
719            result = "{}[{}]".format(self._import("typing", "Iterator"), result)
720        return result
721
722    def _output_type(
723        self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True
724    ) -> str:
725        result = self._import_message(method.output_type)
726        if use_stream_iterator and method.server_streaming:
727            result = "{}[{}]".format(self._import("typing", "Iterator"), result)
728        return result
729
730    def write_grpc_methods(
731        self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation
732    ) -> None:
733        l = self._write_line
734        methods = [
735            (i, m)
736            for i, m in enumerate(service.method)
737            if m.name not in PYTHON_RESERVED
738        ]
739        if not methods:
740            l("pass")
741            l("")
742        for i, method in methods:
743            scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
744
745            l("@{}", self._import("abc", "abstractmethod"))
746            l("def {}(self,", method.name)
747            with self._indent():
748                l("request: {},", self._input_type(method))
749                l("context: {},", self._import("grpc", "ServicerContext"))
750            l(
751                ") -> {}:{}",
752                self._output_type(method),
753                " ..." if not self._has_comments(scl) else "",
754            ),
755            if self._has_comments(scl):
756                with self._indent():
757                    self._write_comments(scl)
758                    l("pass")
759            l("")
760
761    def write_grpc_stub_methods(
762        self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation
763    ) -> None:
764        l = self._write_line
765        methods = [
766            (i, m)
767            for i, m in enumerate(service.method)
768            if m.name not in PYTHON_RESERVED
769        ]
770        if not methods:
771            l("pass")
772            l("")
773        for i, method in methods:
774            scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
775
776            l("{}: {}[", method.name, self._callable_type(method))
777            with self._indent():
778                l("{},", self._input_type(method, False))
779                l("{}] = ...", self._output_type(method, False))
780            self._write_comments(scl)
781            l("")
782
783    def write_grpc_services(
784        self,
785        services: Iterable[d.ServiceDescriptorProto],
786        scl_prefix: SourceCodeLocation,
787    ) -> None:
788        l = self._write_line
789        for i, service in enumerate(services):
790            if service.name in PYTHON_RESERVED:
791                continue
792
793            scl = scl_prefix + [i]
794
795            # The stub client
796            l("class {}Stub:", service.name)
797            with self._indent():
798                self._write_comments(scl)
799                l(
800                    "def __init__(self, channel: {}) -> None: ...",
801                    self._import("grpc", "Channel"),
802                )
803                self.write_grpc_stub_methods(service, scl)
804            l("")
805
806            # The service definition interface
807            l(
808                "class {}Servicer(metaclass={}):",
809                service.name,
810                self._import("abc", "ABCMeta"),
811            )
812            with self._indent():
813                self._write_comments(scl)
814                self.write_grpc_methods(service, scl)
815            l("")
816            l(
817                "def add_{}Servicer_to_server(servicer: {}Servicer, server: {}) -> None: ...",
818                service.name,
819                service.name,
820                self._import("grpc", "Server"),
821            )
822            l("")
823
824    def python_type(
825        self, field: d.FieldDescriptorProto, generic_container: bool = False
826    ) -> str:
827        """
828        generic_container
829          if set, type the field with generic interfaces. Eg.
830          - Iterable[int] rather than RepeatedScalarFieldContainer[int]
831          - Mapping[k, v] rather than MessageMap[k, v]
832          Can be useful for input types (eg constructor)
833        """
834        casttype = field.options.Extensions[extensions_pb2.casttype]
835        if casttype:
836            return self._import_casttype(casttype)
837
838        mapping: Dict[d.FieldDescriptorProto.Type.V, Callable[[], str]] = {
839            d.FieldDescriptorProto.TYPE_DOUBLE: lambda: self._builtin("float"),
840            d.FieldDescriptorProto.TYPE_FLOAT: lambda: self._builtin("float"),
841            d.FieldDescriptorProto.TYPE_INT64: lambda: self._builtin("int"),
842            d.FieldDescriptorProto.TYPE_UINT64: lambda: self._builtin("int"),
843            d.FieldDescriptorProto.TYPE_FIXED64: lambda: self._builtin("int"),
844            d.FieldDescriptorProto.TYPE_SFIXED64: lambda: self._builtin("int"),
845            d.FieldDescriptorProto.TYPE_SINT64: lambda: self._builtin("int"),
846            d.FieldDescriptorProto.TYPE_INT32: lambda: self._builtin("int"),
847            d.FieldDescriptorProto.TYPE_UINT32: lambda: self._builtin("int"),
848            d.FieldDescriptorProto.TYPE_FIXED32: lambda: self._builtin("int"),
849            d.FieldDescriptorProto.TYPE_SFIXED32: lambda: self._builtin("int"),
850            d.FieldDescriptorProto.TYPE_SINT32: lambda: self._builtin("int"),
851            d.FieldDescriptorProto.TYPE_BOOL: lambda: self._builtin("bool"),
852            d.FieldDescriptorProto.TYPE_STRING: lambda: self._import("typing", "Text"),
853            d.FieldDescriptorProto.TYPE_BYTES: lambda: self._builtin("bytes"),
854            d.FieldDescriptorProto.TYPE_ENUM: lambda: self._import_message(
855                field.type_name + ".V"
856            ),
857            d.FieldDescriptorProto.TYPE_MESSAGE: lambda: self._import_message(
858                field.type_name
859            ),
860            d.FieldDescriptorProto.TYPE_GROUP: lambda: self._import_message(
861                field.type_name
862            ),
863        }
864
865        assert field.type in mapping, "Unrecognized type: " + repr(field.type)
866        field_type = mapping[field.type]()
867
868        # For non-repeated fields, we're done!
869        if field.label != d.FieldDescriptorProto.LABEL_REPEATED:
870            return field_type
871
872        # Scalar repeated fields go in RepeatedScalarFieldContainer
873        if is_scalar(field):
874            container = (
875                self._import("typing", "Iterable")
876                if generic_container
877                else self._import(
878                    "google.protobuf.internal.containers",
879                    "RepeatedScalarFieldContainer",
880                )
881            )
882            return "{}[{}]".format(container, field_type)
883
884        # non-scalar repeated map fields go in ScalarMap/MessageMap
885        msg = self.descriptors.messages[field.type_name]
886        if msg.options.map_entry:
887            # map generates a special Entry wrapper message
888            if generic_container:
889                container = self._import("typing", "Mapping")
890            elif is_scalar(msg.field[1]):
891                container = self._import(
892                    "google.protobuf.internal.containers", "ScalarMap"
893                )
894            else:
895                container = self._import(
896                    "google.protobuf.internal.containers", "MessageMap"
897                )
898            ktype, vtype = self._map_key_value_types(field, msg.field[0], msg.field[1])
899            return "{}[{}, {}]".format(container, ktype, vtype)
900
901        # non-scalar repetated fields go in RepeatedCompositeFieldContainer
902        container = (
903            self._import("typing", "Iterable")
904            if generic_container
905            else self._import(
906                "google.protobuf.internal.containers",
907                "RepeatedCompositeFieldContainer",
908            )
909        )
910        return "{}[{}]".format(container, field_type)
911
912    def write(self) -> str:
913        for reexport_idx in self.fd.public_dependency:
914            reexport_file = self.fd.dependency[reexport_idx]
915            reexport_fd = self.descriptors.files[reexport_file]
916            reexport_imp = (
917                reexport_file[:-6].replace("-", "_").replace("/", ".") + "_pb2"
918            )
919            names = (
920                [m.name for m in reexport_fd.message_type]
921                + [m.name for m in reexport_fd.enum_type]
922                + [v.name for m in reexport_fd.enum_type for v in m.value]
923                + [m.name for m in reexport_fd.extension]
924            )
925            if reexport_fd.options.py_generic_services:
926                names.extend(m.name for m in reexport_fd.service)
927
928            if names:
929                # n,n to force a reexport (from x import y as y)
930                self.from_imports[reexport_imp].update((n, n) for n in names)
931
932        import_lines = []
933        for pkg in sorted(self.imports):
934            import_lines.append(u"import {}".format(pkg))
935
936        for pkg, items in sorted(self.from_imports.items()):
937            import_lines.append(u"from {} import (".format(pkg))
938            for (name, reexport_name) in sorted(items):
939                if reexport_name is None:
940                    import_lines.append(u"    {},".format(name))
941                else:
942                    import_lines.append(u"    {} as {},".format(name, reexport_name))
943            import_lines.append(u")\n")
944        import_lines.append("")
945
946        return "\n".join(import_lines + self.lines)
947
948
def is_scalar(fd: d.FieldDescriptorProto) -> bool:
    """Return True unless *fd* holds a message or group (i.e. a composite value)."""
    composite_kinds = (
        d.FieldDescriptorProto.TYPE_MESSAGE,
        d.FieldDescriptorProto.TYPE_GROUP,
    )
    return fd.type not in composite_kinds
954
955
def generate_mypy_stubs(
    descriptors: Descriptors,
    response: plugin_pb2.CodeGeneratorResponse,
    quiet: bool,
    readable_stubs: bool,
    relax_strict_optional_primitives: bool,
) -> None:
    """Add one ``_pb2.pyi`` stub to *response* per file in ``descriptors.to_generate``.

    Each stub contains:
      - module attributes;
      - top-level enums, messages and extensions;
      - services, only when the file sets ``py_generic_services``.

    The output path is the proto filename without ".proto", with "-" -> "_"
    and "." -> "/", plus "_pb2.pyi".

    Unless *quiet* is set, each written path is reported on stderr.
    """
    for proto_name, file_desc in descriptors.to_generate.items():
        writer = PkgWriter(
            file_desc,
            descriptors,
            readable_stubs,
            relax_strict_optional_primitives,
            grpc=False,
        )

        # Emit sections in file order; each gets its source-code-location prefix.
        writer.write_module_attributes()
        writer.write_enums(
            file_desc.enum_type, "", [d.FileDescriptorProto.ENUM_TYPE_FIELD_NUMBER]
        )
        writer.write_messages(
            file_desc.message_type,
            "",
            [d.FileDescriptorProto.MESSAGE_TYPE_FIELD_NUMBER],
        )
        writer.write_extensions(
            file_desc.extension, [d.FileDescriptorProto.EXTENSION_FIELD_NUMBER]
        )
        if file_desc.options.py_generic_services:
            writer.write_services(
                file_desc.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER]
            )

        assert proto_name == file_desc.name
        assert file_desc.name.endswith(".proto")
        stem = file_desc.name[: -len(".proto")]
        stub_file = response.file.add()
        stub_file.name = stem.replace("-", "_").replace(".", "/") + "_pb2.pyi"
        stub_file.content = HEADER + writer.write()
        if not quiet:
            print("Writing mypy to", stub_file.name, file=sys.stderr)
994
995
def generate_mypy_grpc_stubs(
    descriptors: Descriptors,
    response: plugin_pb2.CodeGeneratorResponse,
    quiet: bool,
    readable_stubs: bool,
    relax_strict_optional_primitives: bool,
) -> None:
    """Add one ``_pb2_grpc.pyi`` stub to *response* per file in ``descriptors.to_generate``.

    Each stub holds the gRPC declarations for the file's services. A stub is
    written even when the file has no services. The output path is the proto
    filename without ".proto", with "-" -> "_" and "." -> "/", plus
    "_pb2_grpc.pyi".

    Unless *quiet* is set, each written path is reported on stderr.
    """
    for proto_name, file_desc in descriptors.to_generate.items():
        writer = PkgWriter(
            file_desc,
            descriptors,
            readable_stubs,
            relax_strict_optional_primitives,
            grpc=True,
        )
        writer.write_grpc_services(
            file_desc.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER]
        )

        assert proto_name == file_desc.name
        assert file_desc.name.endswith(".proto")
        stem = file_desc.name[: -len(".proto")]
        stub_file = response.file.add()
        stub_file.name = stem.replace("-", "_").replace(".", "/") + "_pb2_grpc.pyi"
        stub_file.content = HEADER + writer.write()
        if not quiet:
            print("Writing mypy to", stub_file.name, file=sys.stderr)
1022
1023
@contextmanager
def code_generation() -> Iterator[
    Tuple[plugin_pb2.CodeGeneratorRequest, plugin_pb2.CodeGeneratorResponse],
]:
    """Wrap the caller's generation step in the protoc plugin I/O cycle.

    If the first command-line argument is ``-V`` or ``--version``, print the
    version and exit with status 0.

    Otherwise:
      - read the serialized CodeGeneratorRequest from stdin and parse it;
      - yield it together with a fresh CodeGeneratorResponse that advertises
        proto3 ``optional`` support;
      - once the ``with`` body completes without raising, write the serialized
        response to stdout.
    """
    cli_args = sys.argv[1:]
    if cli_args and cli_args[0] in ("-V", "--version"):
        print("mypy-protobuf " + __version__)
        sys.exit(0)

    # protoc sends the request as raw bytes on stdin.
    raw_request = sys.stdin.buffer.read()
    request = plugin_pb2.CodeGeneratorRequest()
    request.ParseFromString(raw_request)

    response = plugin_pb2.CodeGeneratorResponse()
    response.supported_features |= (
        plugin_pb2.CodeGeneratorResponse.FEATURE_PROTO3_OPTIONAL
    )

    yield request, response

    sys.stdout.buffer.write(response.SerializeToString())
1054
1055
def main() -> None:
    """Plugin entry point that generates ``_pb2.pyi`` message stubs.

    Flags are detected by substring match against the request's parameter
    string: ``quiet``, ``readable_stubs``, ``relax_strict_optional_primitives``.
    """
    with code_generation() as (request, response):
        descriptors = Descriptors(request)
        flags = request.parameter
        generate_mypy_stubs(
            descriptors,
            response,
            quiet="quiet" in flags,
            readable_stubs="readable_stubs" in flags,
            relax_strict_optional_primitives="relax_strict_optional_primitives"
            in flags,
        )
1066
1067
def grpc() -> None:
    """Plugin entry point that generates ``_pb2_grpc.pyi`` service stubs.

    Flags are detected by substring match against the request's parameter
    string: ``quiet``, ``readable_stubs``, ``relax_strict_optional_primitives``.
    """
    with code_generation() as (request, response):
        descriptors = Descriptors(request)
        flags = request.parameter
        generate_mypy_grpc_stubs(
            descriptors,
            response,
            quiet="quiet" in flags,
            readable_stubs="readable_stubs" in flags,
            relax_strict_optional_primitives="relax_strict_optional_primitives"
            in flags,
        )
1078
1079
# Running this file directly generates message stubs via main(); gRPC stubs
# come from the separate grpc() entry point instead.
if __name__ == "__main__":
    main()
1082