1# Copyright 2018 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Module containing wrapper classes around meta-descriptors.
16
17This module contains dataclasses which wrap the descriptor protos
18defined in google/protobuf/descriptor.proto (which are descriptors that
19describe descriptors).
20
21These wrappers exist in order to provide useful helper methods and
22generally ease access to things in templates (in particular, documentation,
23certain aggregate views of things, etc.)
24
Reading of underlying descriptor properties in templates *is* okay: the
descriptor wrappers provide a ``__getattr__`` method that consistently
forwards such lookups to the wrapped descriptor proto.
27Documentation is consistently at ``{thing}.meta.doc``.
28"""
29
30import collections
31import dataclasses
32import re
33from itertools import chain
34from typing import (cast, Dict, FrozenSet, Iterable, List, Mapping,
35                    ClassVar, Optional, Sequence, Set, Tuple, Union)
36from google.api import annotations_pb2      # type: ignore
37from google.api import client_pb2
38from google.api import field_behavior_pb2
39from google.api import resource_pb2
40from google.api_core import exceptions      # type: ignore
41from google.protobuf import descriptor_pb2  # type: ignore
42
43from gapic import utils
44from gapic.schema import metadata
45
46
@dataclasses.dataclass(frozen=True)
class Field:
    """Description of a field.

    Wraps a ``FieldDescriptorProto``; attribute lookups that are not
    defined on this class fall through to the underlying proto.
    """
    field_pb: descriptor_pb2.FieldDescriptorProto
    message: Optional['MessageType'] = None
    enum: Optional['EnumType'] = None
    meta: metadata.Metadata = dataclasses.field(
        default_factory=metadata.Metadata,
    )
    oneof: Optional[str] = None

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails, so the dataclass
        # fields and the properties below take precedence over the proto.
        return getattr(self.field_pb, name)

    def __hash__(self):
        # The only sense in which it is meaningful to say a field is equal to
        # another field is if they are the same, i.e. they live in the same
        # message type under the same moniker, i.e. they have the same id.
        # NOTE(review): the dataclass-generated __eq__ still compares by
        # value, so two distinct-but-equal Field objects hash differently.
        return id(self)

    @property
    def name(self) -> str:
        """Return the field name, suffixed with ``_`` if it is reserved.

        Used to prevent collisions with Python keywords and the other
        names listed in ``utils.RESERVED_NAMES``.
        """
        name = self.field_pb.name
        return name + "_" if name in utils.RESERVED_NAMES else name

    @utils.cached_property
    def ident(self) -> metadata.FieldIdentifier:
        """Return the identifier to be used in templates."""
        return metadata.FieldIdentifier(
            ident=self.type.ident,
            repeated=self.repeated,
        )

    @property
    def is_primitive(self) -> bool:
        """Return True if the field is a primitive, False otherwise."""
        return isinstance(self.type, PrimitiveType)

    @property
    def map(self) -> bool:
        """Return True if this field is a map, False otherwise."""
        return bool(self.repeated and self.message and self.message.map)

    @utils.cached_property
    def mock_value(self) -> str:
        """Return a Python expression (as a string) usable as a mock value.

        Nested messages are expanded iteratively rather than recursively:
        each call to :meth:`inner_mock` may return a template containing a
        single ``{}`` placeholder and push the sub-field that fills it onto
        ``stack``. ``visited_fields`` stops the expansion of
        self-referential message types.
        """
        visited_fields: Set["Field"] = set()
        stack = [self]
        answer = "{}"
        while stack:
            expr = stack.pop()
            answer = answer.format(expr.inner_mock(stack, visited_fields))

        return answer

    def inner_mock(self, stack, visited_fields):
        """Return a repr of a valid, usually truthy mock value.

        Args:
            stack (List[Field]): Work list shared with :attr:`mock_value`.
                If this field is a non-map message with at least one field
                and has not been visited yet, its first sub-field is pushed
                here and the returned string contains a ``{}`` placeholder
                for that sub-field's mock value.
            visited_fields (Set[Field]): Fields already expanded; updated
                in place.

        Returns:
            str: A Python expression for the mock value.

        Raises:
            TypeError: If the field's primitive type is not one of
                ``bool``, ``str``, ``bytes``, ``int`` or ``float``.
        """
        # For primitives, send a truthy value computed from the
        # field name.
        answer = 'None'
        if isinstance(self.type, PrimitiveType):
            if self.type.python_type == bool:
                answer = 'True'
            elif self.type.python_type == str:
                answer = f"'{self.name}_value'"
            elif self.type.python_type == bytes:
                answer = f"b'{self.name}_blob'"
            elif self.type.python_type == int:
                answer = f'{sum(ord(ch) for ch in self.name)}'
            elif self.type.python_type == float:
                answer = f'0.{sum(ord(ch) for ch in self.name)}'
            else:  # Impossible; skip coverage checks.
                raise TypeError('Unrecognized PrimitiveType. This should '
                                'never happen; please file an issue.')

        # If this is an enum, select the first truthy value (or the zero
        # value if nothing else exists).
        if isinstance(self.type, EnumType):
            # Note: The slightly-goofy [:2][-1] lets us gracefully fall
            # back to index 0 if there is only one element.
            mock_value = self.type.values[:2][-1]
            answer = f'{self.type.ident}.{mock_value.name}'

        # If this is another message, set one value on the message.
        if (
                not self.map    # Maps are handled separately
                and isinstance(self.type, MessageType)
                and len(self.type.fields)
                # Nested message types need to terminate eventually
                and self not in visited_fields
        ):
            sub = next(iter(self.type.fields.values()))
            stack.append(sub)
            visited_fields.add(self)
            # Don't do the recursive rendering here, just set up
            # where the nested value should go with the double {}.
            answer = f'{self.type.ident}({sub.name}={{}})'

        if self.map:
            # Maps are a special case because they're represented internally
            # as a list of a generated type with two fields: 'key' and
            # 'value'.
            # NOTE(review): each ``mock_value`` here starts a fresh
            # visited set, so a map whose value type contains the same map
            # may not terminate; confirm whether such protos can occur.
            answer = '{{{}: {}}}'.format(
                self.type.fields["key"].mock_value,
                self.type.fields["value"].mock_value,
            )
        elif self.repeated:
            # If this is a repeated field, then the mock answer should
            # be a list.
            answer = f'[{answer}]'

        # Done; return the mock value.
        return answer

    @property
    def proto_type(self) -> str:
        """Return the proto type constant to be used in templates.

        This is the ``FieldDescriptorProto.Type`` name with its ``TYPE_``
        prefix removed, e.g. ``'STRING'`` for ``TYPE_STRING``.
        """
        return cast(str, descriptor_pb2.FieldDescriptorProto.Type.Name(
            self.field_pb.type,
        ))[len('TYPE_'):]

    @property
    def repeated(self) -> bool:
        """Return True if this is a repeated field, False otherwise.

        Returns:
            bool: Whether this field is repeated.
        """
        return self.label == \
            descriptor_pb2.FieldDescriptorProto.Label.Value(
                'LABEL_REPEATED')   # type: ignore

    @property
    def required(self) -> bool:
        """Return True if this is a required field, False otherwise.

        A field is required when its ``google.api.field_behavior`` options
        include ``REQUIRED``.

        Returns:
            bool: Whether this field is required.
        """
        return (field_behavior_pb2.FieldBehavior.Value('REQUIRED') in
                self.options.Extensions[field_behavior_pb2.field_behavior])

    @utils.cached_property
    def type(self) -> Union['MessageType', 'EnumType', 'PrimitiveType']:
        """Return the type of this field.

        Returns the bound :class:`MessageType` or :class:`EnumType` when the
        field references one; otherwise a :class:`PrimitiveType` wrapping
        the corresponding Python scalar type.

        Raises:
            TypeError: If the proto field type is not recognized.
        """
        # If this is a message or enum, return the appropriate thing.
        if self.type_name and self.message:
            return self.message
        if self.type_name and self.enum:
            return self.enum

        # This is a primitive. Return the corresponding Python type.
        # The enum values used here are defined in:
        #   Repository: https://github.com/google/protobuf/
        #   Path: src/google/protobuf/descriptor.proto
        #
        # The values are used here because the code would be excessively
        # verbose otherwise, and this is guaranteed never to change.
        #
        # 10, 11, and 14 are intentionally missing. They correspond to
        # group (unused), message (covered above), and enum (covered above).
        if self.field_pb.type in (1, 2):
            return PrimitiveType.build(float)
        if self.field_pb.type in (3, 4, 5, 6, 7, 13, 15, 16, 17, 18):
            return PrimitiveType.build(int)
        if self.field_pb.type == 8:
            return PrimitiveType.build(bool)
        if self.field_pb.type == 9:
            return PrimitiveType.build(str)
        if self.field_pb.type == 12:
            return PrimitiveType.build(bytes)

        # This should never happen.
        raise TypeError(f'Unrecognized protobuf type: {self.field_pb.type}. '
                        'This code should not be reachable; please file a bug.')

    def with_context(
            self,
            *,
            collisions: FrozenSet[str],
            visited_messages: FrozenSet["MessageType"],
    ) -> 'Field':
        """Return a derivative of this field with the provided context.

        This method is used to address naming collisions. The returned
        ``Field`` object aliases module names to avoid naming collisions
        in the file being written.

        Args:
            collisions (FrozenSet[str]): Names that collide in the file
                being written.
            visited_messages (FrozenSet[MessageType]): Messages already
                bound on the current path. If this field's message is one
                of them, its fields are not re-bound (``skip_fields``),
                which breaks cycles in self-referential messages.
        """
        return dataclasses.replace(
            self,
            message=self.message.with_context(
                collisions=collisions,
                skip_fields=self.message in visited_messages,
                visited_messages=visited_messages,
            ) if self.message else None,
            enum=self.enum.with_context(collisions=collisions)
            if self.enum else None,
            meta=self.meta.with_context(collisions=collisions),
        )
245
246
@dataclasses.dataclass(frozen=True)
class Oneof:
    """Description of a oneof (defined with the ``oneof`` keyword).

    Unknown attributes are looked up on the wrapped
    ``OneofDescriptorProto``.
    """
    oneof_pb: descriptor_pb2.OneofDescriptorProto

    def __getattr__(self, name):
        return getattr(self.oneof_pb, name)
254
255
@dataclasses.dataclass(frozen=True)
class MessageType:
    """Description of a message (defined with the ``message`` keyword).

    Unknown attributes are looked up on the wrapped ``DescriptorProto``.
    """
    # Class attributes
    # Matches ``{segment}`` placeholders in a resource path template and
    # captures the segment name.
    PATH_ARG_RE = re.compile(r'\{([a-zA-Z0-9_-]+)\}')

    # Instance attributes
    message_pb: descriptor_pb2.DescriptorProto
    fields: Mapping[str, Field]
    nested_enums: Mapping[str, 'EnumType']
    nested_messages: Mapping[str, 'MessageType']
    meta: metadata.Metadata = dataclasses.field(
        default_factory=metadata.Metadata,
    )
    oneofs: Optional[Mapping[str, 'Oneof']] = None

    def __getattr__(self, name):
        return getattr(self.message_pb, name)

    def __hash__(self):
        # Identity is sufficiently unambiguous.
        return hash(self.ident)

    def oneof_fields(self, include_optional=False):
        """Return this message's fields grouped by the oneof they belong to.

        Args:
            include_optional (bool): Whether to also include fields marked
                ``proto3_optional``.

        Returns:
            Mapping[str, List[Field]]: Oneof name to its member fields, in
                field order. Fields that are not part of any oneof are
                never included.
        """
        oneof_fields = collections.defaultdict(list)
        for field in self.fields.values():
            # Only include proto3 optional oneofs if explicitly looked for.
            # The parentheses matter: ``include_optional`` must not pull in
            # fields that are not in a oneof at all.
            if field.oneof and (include_optional or not field.proto3_optional):
                oneof_fields[field.oneof].append(field)

        return oneof_fields

    @utils.cached_property
    def field_types(self) -> Sequence[Union['MessageType', 'EnumType']]:
        """Return the message and enum types of this message's own fields.

        Primitive fields are skipped; duplicates are kept, in field order.
        """
        answer = tuple(
            field.type
            for field in self.fields.values()
            if field.message or field.enum
        )

        return answer

    @utils.cached_property
    def recursive_field_types(self) -> Sequence[
        Union['MessageType', 'EnumType']
    ]:
        """Return all composite types reachable through this message's fields.

        Message fields are followed to any depth (each message is expanded
        at most once). Each type appears once; the order is unspecified
        because the result is built from a set.
        """
        types: Set[Union['MessageType', 'EnumType']] = set()

        stack = [iter(self.fields.values())]
        while stack:
            fields_iter = stack.pop()
            for field in fields_iter:
                if field.message and field.type not in types:
                    stack.append(iter(field.message.fields.values()))
                if not field.is_primitive:
                    types.add(field.type)

        return tuple(types)

    @utils.cached_property
    def recursive_resource_fields(self) -> FrozenSet[Field]:
        """Return fields that reference a resource.

        Considers this message's fields and those of every message in
        :attr:`recursive_field_types`, keeping the ones whose
        ``google.api.resource_reference`` sets ``type`` or ``child_type``.
        """
        all_fields = chain(
            self.fields.values(),
            (field
             for t in self.recursive_field_types if isinstance(t, MessageType)
             for field in t.fields.values()),
        )
        return frozenset(
            f
            for f in all_fields
            if (f.options.Extensions[resource_pb2.resource_reference].type or
                f.options.Extensions[resource_pb2.resource_reference].child_type)
        )

    @property
    def map(self) -> bool:
        """Return True if the given message is a map, False otherwise."""
        return self.message_pb.options.map_entry

    @property
    def ident(self) -> metadata.Address:
        """Return the identifier data to be used in templates."""
        return self.meta.address

    @property
    def resource_path(self) -> Optional[str]:
        """If this message describes a resource, return the path to the resource.

        If there are multiple paths, returns the first one. Returns None
        when the ``google.api.resource`` annotation has no pattern.
        """
        return next(
            iter(self.options.Extensions[resource_pb2.resource].pattern),
            None
        )

    @property
    def resource_type(self) -> Optional[str]:
        """Return the resource type with everything up to the first ``/`` removed.

        NOTE(review): protobuf messages appear to always be truthy, so this
        likely returns ``''`` rather than ``None`` when no resource
        annotation is present; confirm before relying on ``None``.
        """
        resource = self.options.Extensions[resource_pb2.resource]
        return resource.type[resource.type.find('/') + 1:] if resource else None

    @property
    def resource_path_args(self) -> Sequence[str]:
        """Return the ``{segment}`` names found in :attr:`resource_path`."""
        return self.PATH_ARG_RE.findall(self.resource_path or '')

    @utils.cached_property
    def path_regex_str(self) -> str:
        """Return a regex string that parses :attr:`resource_path` into its segments."""
        # The indirection here is a little confusing:
        # we're using the resource path template as the base of a regex,
        # with each resource ID segment being captured by a regex.
        # E.g., the path schema
        # kingdoms/{kingdom}/phyla/{phylum}
        # becomes the regex
        # ^kingdoms/(?P<kingdom>.+?)/phyla/(?P<phylum>.+?)$
        parsing_regex_str = (
            "^" +
            self.PATH_ARG_RE.sub(
                # We can't just use (?P<name>[^/]+) because segments may be
                # separated by delimiters other than '/'.
                # Multiple delimiter characters within one schema are allowed,
                # e.g.
                # as/{a}-{b}/cs/{c}%{d}_{e}
                # This is discouraged but permitted by AIP-4231.
                lambda m: "(?P<{name}>.+?)".format(name=m.groups()[0]),
                self.resource_path or ''
            ) +
            "$"
        )
        return parsing_regex_str

    def get_field(self, *field_path: str,
                  collisions: FrozenSet[str] = frozenset()) -> Field:
        """Return a field arbitrarily deep in this message's structure.

        This method recursively traverses the message tree to return the
        requested inner-field.

        Traversing through repeated fields is not supported; a repeated field
        may be specified if and only if it is the last field in the path.

        Args:
            field_path (Sequence[str]): The field path, one field name per
                element.
            collisions (FrozenSet[str]): Naming collisions to bind into the
                returned field. Defaults to this message's address
                collisions.

        Returns:
            ~.Field: A field object.

        Raises:
            KeyError: If a field in the path does not exist, if a repeated
                field is used in the non-terminal position in the path, or
                if a non-terminal field is not a message.
        """
        # If collisions are not explicitly specified, retrieve them
        # from this message's address.
        # This ensures that calls to `get_field` will return a field with
        # the same context, regardless of the number of levels through the
        # chain (in order to avoid infinite recursion on circular references,
        # we only shallowly bind message references held by fields; this
        # binds deeply in the one spot where that might be a problem).
        collisions = collisions or self.meta.address.collisions

        # Get the first field in the path.
        cursor = self.fields[field_path[0]]

        # Base case: If this is the last field in the path, return it outright.
        if len(field_path) == 1:
            return cursor.with_context(
                collisions=collisions,
                visited_messages=frozenset({self}),
            )

        # Sanity check: If cursor is a repeated field, then raise an exception.
        # Repeated fields are only permitted in the terminal position.
        if cursor.repeated:
            raise KeyError(
                f'The {cursor.name} field is repeated; unable to use '
                '`get_field` to retrieve its children.\n'
                'This exception usually indicates that a '
                'google.api.method_signature annotation uses a repeated field '
                'in the fields list in a position other than the end.',
            )

        # Sanity check: If this cursor has no message, there is a problem.
        if not cursor.message:
            raise KeyError(
                f'Field {".".join(field_path)} could not be resolved from '
                f'{cursor.name}.',
            )

        # Recursion case: Pass the remainder of the path to the sub-field's
        # message.
        return cursor.message.get_field(*field_path[1:], collisions=collisions)

    def with_context(self, *,
                     collisions: FrozenSet[str],
                     skip_fields: bool = False,
                     visited_messages: FrozenSet["MessageType"] = frozenset(),
                     ) -> 'MessageType':
        """Return a derivative of this message with the provided context.

        This method is used to address naming collisions. The returned
        ``MessageType`` object aliases module names to avoid naming collisions
        in the file being written.

        The ``skip_fields`` argument will omit applying the context to the
        underlying fields. This provides for an "exit" in the case of circular
        references.

        Args:
            collisions (FrozenSet[str]): Names that collide in the file
                being written.
            skip_fields (bool): Leave ``fields`` unbound (see above).
            visited_messages (FrozenSet[MessageType]): Messages already
                bound on the current path; this message is added before
                its fields and nested messages are bound.
        """
        visited_messages = visited_messages | {self}
        return dataclasses.replace(
            self,
            fields=collections.OrderedDict(
                (k, v.with_context(
                    collisions=collisions,
                    visited_messages=visited_messages
                ))
                for k, v in self.fields.items()
            ) if not skip_fields else self.fields,
            nested_enums=collections.OrderedDict(
                (k, v.with_context(collisions=collisions))
                for k, v in self.nested_enums.items()
            ),
            nested_messages=collections.OrderedDict(
                (k, v.with_context(
                    collisions=collisions,
                    skip_fields=skip_fields,
                    visited_messages=visited_messages,
                ))
                for k, v in self.nested_messages.items()),
            meta=self.meta.with_context(collisions=collisions),
        )
483
484
@dataclasses.dataclass(frozen=True)
class EnumValueType:
    """Wrapper around a single value declared inside an enum.

    Attributes not defined here are read from the wrapped
    ``EnumValueDescriptorProto``.
    """
    enum_value_pb: descriptor_pb2.EnumValueDescriptorProto
    meta: metadata.Metadata = dataclasses.field(
        default_factory=metadata.Metadata,
    )

    def __getattr__(self, name):
        # Reached only when regular attribute lookup has failed.
        wrapped = self.enum_value_pb
        return getattr(wrapped, name)
495
496
@dataclasses.dataclass(frozen=True)
class EnumType:
    """Wrapper around an enum (declared with the ``enum`` keyword).

    Attributes not defined here are read from the wrapped
    ``EnumDescriptorProto``.
    """
    enum_pb: descriptor_pb2.EnumDescriptorProto
    values: List[EnumValueType]
    meta: metadata.Metadata = dataclasses.field(
        default_factory=metadata.Metadata,
    )

    def __hash__(self):
        # The address alone identifies the enum unambiguously.
        address = self.ident
        return hash(address)

    def __getattr__(self, name):
        wrapped = self.enum_pb
        return getattr(wrapped, name)

    @property
    def resource_path(self) -> Optional[str]:
        """Always ``None``; enums never describe resources.

        Duck-typing shim: the Service class's ``resource_messages`` property
        checks fields recursively, and ``recursive_field_types`` contains
        enums alongside messages.
        """
        return None

    @property
    def ident(self) -> metadata.Address:
        """The address used to refer to this enum from templates."""
        address = self.meta.address
        return address

    def with_context(self, *, collisions: FrozenSet[str]) -> 'EnumType':
        """Return a copy of this enum bound to the given naming context.

        Used to resolve naming collisions: the copy's metadata aliases
        module names so they do not clash with other names in the file
        being written.
        """
        bound_meta = self.meta.with_context(collisions=collisions)
        return dataclasses.replace(self, meta=bound_meta)
536
537
@dataclasses.dataclass(frozen=True)
class PythonType:
    """Wrapper class for Python types.

    This exists for interface consistency, so that methods like
    :meth:`Field.type` can return an object and the caller can be confident
    that a ``name`` property will be present.
    """
    meta: metadata.Metadata

    def __eq__(self, other):
        # Only objects carrying metadata are comparable here. For anything
        # else (None, raw Python types, ...) return NotImplemented so Python
        # falls back to the other operand and ultimately to identity,
        # instead of raising AttributeError.
        if not hasattr(other, 'meta'):
            return NotImplemented
        return self.meta == other.meta

    def __ne__(self, other):
        return not self == other

    @utils.cached_property
    def ident(self) -> metadata.Address:
        """Return the identifier to be used in templates."""
        return self.meta.address

    @property
    def name(self) -> str:
        """Return the name component of this type's address."""
        return self.ident.name

    @property
    def field_types(self) -> Sequence[Union['MessageType', 'EnumType']]:
        """Return an empty tuple; wrapped Python types have no fields."""
        return tuple()
566
567
@dataclasses.dataclass(frozen=True)
class PrimitiveType(PythonType):
    """A representation of a Python primitive type (or ``None``)."""
    python_type: Optional[type]

    @classmethod
    def build(cls, primitive_type: Optional[type]):
        """Construct the PrimitiveType wrapping a Python primitive type.

        Args:
            primitive_type (cls): A Python primitive such as :class:`int`
                or :class:`str`. ``None`` is accepted as well, even though
                it is not a type.

        Returns:
            ~.PrimitiveType: The wrapper instance.
        """
        # Primitives have no import and no module to reference, so the
        # address consists of the type's name alone (e.g. "int", "str").
        if primitive_type is None:
            type_name = 'None'
        else:
            type_name = primitive_type.__name__
        address = metadata.Address(name=type_name)
        return cls(
            meta=metadata.Metadata(address=address),
            python_type=primitive_type,
        )

    def __eq__(self, other):
        # Objects with metadata use the base comparison; anything else is
        # treated as a raw Python type and compared by identity, so that
        # ``PrimitiveType.build(int) == int`` holds.
        if hasattr(other, 'meta'):
            return super().__eq__(other)
        return self.python_type is other
597
598
@dataclasses.dataclass(frozen=True)
class OperationInfo:
    """Response and metadata types of a long-running operation."""
    response_type: MessageType
    metadata_type: MessageType

    def with_context(self, *, collisions: FrozenSet[str]) -> 'OperationInfo':
        """Return a copy whose message types are bound to the given context.

        Used to resolve naming collisions: both the response and metadata
        types are re-bound so that module names are aliased and do not
        clash with other names in the file being written.
        """
        bound_response = self.response_type.with_context(
            collisions=collisions,
        )
        bound_metadata = self.metadata_type.with_context(
            collisions=collisions,
        )
        return dataclasses.replace(
            self,
            response_type=bound_response,
            metadata_type=bound_metadata,
        )
621
622
@dataclasses.dataclass(frozen=True)
class RetryInfo:
    """Representation of the method's retry behavior.

    A plain immutable value object; where it is populated is not shown
    in this part of the file.
    """
    max_attempts: int
    # NOTE(review): the units of the backoff fields are not established
    # here (presumably seconds) — confirm against the code that builds this.
    initial_backoff: float
    max_backoff: float
    backoff_multiplier: float
    # NOTE(review): presumably holds exception *classes* rather than
    # instances; if so, ``FrozenSet[Type[exceptions.GoogleAPICallError]]``
    # would be the accurate annotation — confirm.
    retryable_exceptions: FrozenSet[exceptions.GoogleAPICallError]
631
632
633@dataclasses.dataclass(frozen=True)
634class Method:
635    """Description of a method (defined with the ``rpc`` keyword)."""
636    method_pb: descriptor_pb2.MethodDescriptorProto
637    input: MessageType
638    output: MessageType
639    lro: Optional[OperationInfo] = dataclasses.field(default=None)
640    retry: Optional[RetryInfo] = dataclasses.field(default=None)
641    timeout: Optional[float] = None
642    meta: metadata.Metadata = dataclasses.field(
643        default_factory=metadata.Metadata,
644    )
645
646    def __getattr__(self, name):
647        return getattr(self.method_pb, name)
648
649    @utils.cached_property
650    def client_output(self):
651        return self._client_output(enable_asyncio=False)
652
653    @utils.cached_property
654    def client_output_async(self):
655        return self._client_output(enable_asyncio=True)
656
657    def flattened_oneof_fields(self, include_optional=False):
658        oneof_fields = collections.defaultdict(list)
659        for field in self.flattened_fields.values():
660            # Only include proto3 optional oneofs if explicitly looked for.
661            if field.oneof and not field.proto3_optional or include_optional:
662                oneof_fields[field.oneof].append(field)
663
664        return oneof_fields
665
666    def _client_output(self, enable_asyncio: bool):
667        """Return the output from the client layer.
668
669        This takes into account transformations made by the outer GAPIC
670        client to transform the output from the transport.
671
672        Returns:
673            Union[~.MessageType, ~.PythonType]:
674                A description of the return type.
675        """
676        # Void messages ultimately return None.
677        if self.void:
678            return PrimitiveType.build(None)
679
680        # If this method is an LRO, return a PythonType instance representing
681        # that.
682        if self.lro:
683            return PythonType(meta=metadata.Metadata(
684                address=metadata.Address(
685                    name='AsyncOperation' if enable_asyncio else 'Operation',
686                    module='operation_async' if enable_asyncio else 'operation',
687                    package=('google', 'api_core'),
688                    collisions=self.lro.response_type.ident.collisions,
689                ),
690                documentation=utils.doc(
691                    'An object representing a long-running operation. \n\n'
692                    'The result type for the operation will be '
693                    ':class:`{ident}`: {doc}'.format(
694                        doc=self.lro.response_type.meta.doc,
695                        ident=self.lro.response_type.ident.sphinx,
696                    ),
697                ),
698            ))
699
700        # If this method is paginated, return that method's pager class.
701        if self.paged_result_field:
702            return PythonType(meta=metadata.Metadata(
703                address=metadata.Address(
704                    name=f'{self.name}AsyncPager' if enable_asyncio else f'{self.name}Pager',
705                    package=self.ident.api_naming.module_namespace + (self.ident.api_naming.versioned_module_name,) + self.ident.subpackage + (
706                        'services',
707                        utils.to_snake_case(self.ident.parent[-1]),
708                    ),
709                    module='pagers',
710                    collisions=self.input.ident.collisions,
711                ),
712                documentation=utils.doc(
713                    f'{self.output.meta.doc}\n\n'
714                    'Iterating over this object will yield results and '
715                    'resolve additional pages automatically.',
716                ),
717            ))
718
719        # Return the usual output.
720        return self.output
721
722    @property
723    def field_headers(self) -> Sequence[str]:
724        """Return the field headers defined for this method."""
725        http = self.options.Extensions[annotations_pb2.http]
726
727        pattern = re.compile(r'\{([a-z][\w\d_.]+)=')
728
729        potential_verbs = [
730            http.get,
731            http.put,
732            http.post,
733            http.delete,
734            http.patch,
735            http.custom.path,
736        ]
737
738        return next((tuple(pattern.findall(verb)) for verb in potential_verbs if verb), ())
739
740    @utils.cached_property
741    def flattened_fields(self) -> Mapping[str, Field]:
742        """Return the signature defined for this method."""
743        cross_pkg_request = self.input.ident.package != self.ident.package
744
745        def filter_fields(sig: str) -> Iterable[Tuple[str, Field]]:
746            for f in sig.split(','):
747                if not f:
748                    # Special case for an empty signature
749                    continue
750                name = f.strip()
751                field = self.input.get_field(*name.split('.'))
752                if cross_pkg_request and not field.is_primitive:
753                    # This is not a proto-plus wrapped message type,
754                    # and setting a non-primitive field directly is verboten.
755                    continue
756
757                yield name, field
758
759        signatures = self.options.Extensions[client_pb2.method_signature]
760        answer: Dict[str, Field] = collections.OrderedDict(
761            name_and_field
762            for sig in signatures
763            for name_and_field in filter_fields(sig)
764        )
765
766        return answer
767
768    @utils.cached_property
769    def flattened_field_to_key(self):
770        return {field.name: key for key, field in self.flattened_fields.items()}
771
772    @utils.cached_property
773    def legacy_flattened_fields(self) -> Mapping[str, Field]:
774        """Return the legacy flattening interface: top level fields only,
775        required fields first"""
776        required, optional = utils.partition(lambda f: f.required,
777                                             self.input.fields.values())
778        return collections.OrderedDict((f.name, f)
779                                       for f in chain(required, optional))
780
781    @property
782    def grpc_stub_type(self) -> str:
783        """Return the type of gRPC stub to use."""
784        return '{client}_{server}'.format(
785            client='stream' if self.client_streaming else 'unary',
786            server='stream' if self.server_streaming else 'unary',
787        )
788
789    @utils.cached_property
790    def idempotent(self) -> bool:
791        """Return True if we know this method is idempotent, False otherwise.
792
793        Note: We are intentionally conservative here. It is far less bad
794        to falsely believe an idempotent method is non-idempotent than
795        the converse.
796        """
797        return bool(self.options.Extensions[annotations_pb2.http].get)
798
799    @property
800    def ident(self) -> metadata.Address:
801        """Return the identifier data to be used in templates."""
802        return self.meta.address
803
804    @utils.cached_property
805    def paged_result_field(self) -> Optional[Field]:
806        """Return the response pagination field if the method is paginated."""
807        # If the request field lacks any of the expected pagination fields,
808        # then the method is not paginated.
809        for page_field in ((self.input, int, 'page_size'),
810                           (self.input, str, 'page_token'),
811                           (self.output, str, 'next_page_token')):
812            field = page_field[0].fields.get(page_field[2], None)
813            if not field or field.type != page_field[1]:
814                return None
815
816        # Return the first repeated field.
817        for field in self.output.fields.values():
818            if field.repeated:
819                return field
820
821        # We found no repeated fields. Return None.
822        return None
823
824    @utils.cached_property
825    def ref_types(self) -> Sequence[Union[MessageType, EnumType]]:
826        return self._ref_types(True)
827
828    @utils.cached_property
829    def flat_ref_types(self) -> Sequence[Union[MessageType, EnumType]]:
830        return self._ref_types(False)
831
832    def _ref_types(self, recursive: bool) -> Sequence[Union[MessageType, EnumType]]:
833        """Return types referenced by this method."""
834        # Begin with the input (request) and output (response) messages.
835        answer: List[Union[MessageType, EnumType]] = [self.input]
836        types: Iterable[Union[MessageType, EnumType]] = (
837            self.input.recursive_field_types if recursive
838            else (
839                f.type
840                for f in self.flattened_fields.values()
841                if f.message or f.enum
842            )
843        )
844        answer.extend(types)
845
846        if not self.void:
847            answer.append(self.client_output)
848            answer.extend(self.client_output.field_types)
849            answer.append(self.client_output_async)
850            answer.extend(self.client_output_async.field_types)
851
852        # If this method has LRO, it is possible (albeit unlikely) that
853        # the LRO messages reside in a different module.
854        if self.lro:
855            answer.append(self.lro.response_type)
856            answer.append(self.lro.metadata_type)
857
858        # If this message paginates its responses, it is possible
859        # that the individual result messages reside in a different module.
860        if self.paged_result_field and self.paged_result_field.message:
861            answer.append(self.paged_result_field.message)
862
863        # Done; return the answer.
864        return tuple(answer)
865
866    @property
867    def void(self) -> bool:
868        """Return True if this method has no return value, False otherwise."""
869        return self.output.ident.proto == 'google.protobuf.Empty'
870
871    def with_context(self, *, collisions: FrozenSet[str]) -> 'Method':
872        """Return a derivative of this method with the provided context.
873
874        This method is used to address naming collisions. The returned
875        ``Method`` object aliases module names to avoid naming collisions
876        in the file being written.
877        """
878        maybe_lro = self.lro.with_context(
879            collisions=collisions
880        ) if self.lro else None
881
882        return dataclasses.replace(
883            self,
884            lro=maybe_lro,
885            input=self.input.with_context(collisions=collisions),
886            output=self.output.with_context(collisions=collisions),
887            meta=self.meta.with_context(collisions=collisions),
888        )
889
890
@dataclasses.dataclass(frozen=True)
class CommonResource:
    """A resource type paired with one path pattern.

    Examples: type ``"cloudresourcemanager.googleapis.com/Project"`` with
    pattern ``"projects/{project}"``.
    """
    type_name: str
    pattern: str

    @classmethod
    def build(cls, resource: resource_pb2.ResourceDescriptor):
        """Build a ``CommonResource`` from a ``google.api.ResourceDescriptor``.

        Only the descriptor's first pattern is kept. A descriptor with no
        patterns makes ``next`` raise ``StopIteration``.
        """
        resource_type = resource.type
        first_pattern = next(iter(resource.pattern))
        return cls(type_name=resource_type, pattern=first_pattern)

    @utils.cached_property
    def message_type(self):
        """Synthesize a ``MessageType`` that carries only this resource.

        The returned message has no fields, nested enums, or nested
        messages. Its ``google.api.resource`` option is set to
        ``type_name`` and ``pattern``.
        """
        descriptor = descriptor_pb2.DescriptorProto()
        resource_option = descriptor.options.Extensions[resource_pb2.resource]
        resource_option.type = self.type_name
        resource_option.pattern.append(self.pattern)

        return MessageType(
            message_pb=descriptor,
            fields={},
            nested_enums={},
            nested_messages={},
        )
916
917
@dataclasses.dataclass(frozen=True)
class Service:
    """Description of a service (defined with the ``service`` keyword).

    Attribute lookups that fail on this wrapper (e.g. ``name``,
    ``options``) are forwarded to the wrapped ``service_pb`` descriptor.
    """
    service_pb: descriptor_pb2.ServiceDescriptorProto
    methods: Mapping[str, Method]
    # N.B.: visible_resources is intended to be a read-only view
    # whose backing store is owned by the API.
    # This is represented by a types.MappingProxyType instance.
    visible_resources: Mapping[str, MessageType]
    meta: metadata.Metadata = dataclasses.field(
        default_factory=metadata.Metadata,
    )

    # Well-known resource types, keyed by resource type name.
    # Annotated as a ClassVar, so it is shared by all instances and is not
    # a constructor argument. References to these types may have no entry
    # in ``visible_resources``.
    common_resources: ClassVar[Mapping[str, CommonResource]] = dataclasses.field(
        default={
            "cloudresourcemanager.googleapis.com/Project": CommonResource(
                "cloudresourcemanager.googleapis.com/Project",
                "projects/{project}",
            ),
            "cloudresourcemanager.googleapis.com/Organization": CommonResource(
                "cloudresourcemanager.googleapis.com/Organization",
                "organizations/{organization}",
            ),
            "cloudresourcemanager.googleapis.com/Folder": CommonResource(
                "cloudresourcemanager.googleapis.com/Folder",
                "folders/{folder}",
            ),
            "cloudbilling.googleapis.com/BillingAccount": CommonResource(
                "cloudbilling.googleapis.com/BillingAccount",
                "billingAccounts/{billing_account}",
            ),
            "locations.googleapis.com/Location": CommonResource(
                "locations.googleapis.com/Location",
                "projects/{project}/locations/{location}",
            ),
        },
        init=False,
        compare=False,
    )

    def __getattr__(self, name):
        # Only called when normal lookup fails; delegate descriptor
        # attributes such as ``name`` and ``options`` to ``service_pb``.
        return getattr(self.service_pb, name)

    @property
    def client_name(self) -> str:
        """Returns the name of the generated client class"""
        return self.name + "Client"

    @property
    def async_client_name(self) -> str:
        """Returns the name of the generated AsyncIO client class"""
        return self.name + "AsyncClient"

    @property
    def transport_name(self):
        """Return the service name suffixed with ``Transport``."""
        return self.name + "Transport"

    @property
    def grpc_transport_name(self):
        """Return the service name suffixed with ``GrpcTransport``."""
        return self.name + "GrpcTransport"

    @property
    def grpc_asyncio_transport_name(self):
        """Return the service name suffixed with ``GrpcAsyncIOTransport``."""
        return self.name + "GrpcAsyncIOTransport"

    @property
    def has_lro(self) -> bool:
        """Return whether the service has a long-running method."""
        return any(m.lro for m in self.methods.values())

    @property
    def has_pagers(self) -> bool:
        """Return whether the service has paged methods."""
        return any(m.paged_result_field for m in self.methods.values())

    @property
    def host(self) -> str:
        """Return the hostname for this service, if specified.

        The value comes from the ``google.api.default_host`` service option.
        An empty string is returned when the option is unset.

        Returns:
            str: The hostname, with no protocol and no trailing ``/``.
        """
        return self.options.Extensions[client_pb2.default_host] or ''

    @property
    def oauth_scopes(self) -> Sequence[str]:
        """Return a sequence of oauth scopes, if applicable.

        The ``google.api.oauth_scopes`` service option is a comma-separated
        string. Each entry is stripped of surrounding whitespace. Entries
        that are empty after stripping (e.g. from a trailing ``", "``) are
        dropped.

        Returns:
            Sequence[str]: A sequence of OAuth scopes.
        """
        raw_scopes = self.options.Extensions[client_pb2.oauth_scopes]
        stripped = (scope.strip() for scope in raw_scopes.split(','))
        # Filter *after* stripping: filtering the raw pieces would let
        # whitespace-only entries through as empty-string scopes.
        return tuple(scope for scope in stripped if scope)

    @property
    def module_name(self) -> str:
        """Return the appropriate module name for this service.

        Returns:
            str: The service name, in snake case.
        """
        return utils.to_snake_case(self.name)

    @utils.cached_property
    def names(self) -> FrozenSet[str]:
        """Return a set of names used in this service.

        This is used for detecting naming collisions in the module names
        used for imports. The set contains:

        * the service, client, and async client names;
        * the snake-cased name of every method;
        * every module name that is referenced (via a method's
          ``ref_types``) from more than one package.
        """
        # Put together a set of the service and method names.
        answer = {self.name, self.client_name, self.async_client_name}
        answer.update(
            utils.to_snake_case(i.name) for i in self.methods.values()
        )

        # Identify any import module names where the same module name is used
        # from distinct packages.
        modules: Dict[str, Set[str]] = collections.defaultdict(set)
        for m in self.methods.values():
            for t in m.ref_types:
                modules[t.ident.module].add(t.ident.package)

        answer.update(
            module_name
            for module_name, packages in modules.items()
            if len(packages) > 1
        )

        # Done; return the answer.
        return frozenset(answer)

    @utils.cached_property
    def resource_messages(self) -> FrozenSet[MessageType]:
        """Returns all the resource message types used in all
        request and response fields in the service.

        For each method, the request message and the effective response
        message are scanned. The effective response is the LRO
        ``response_type`` for long-running methods and ``output`` otherwise.
        The result collects:

        * the scanned message itself, plus any of its recursive field types,
          when they have a ``resource_path``;
        * resources named by a ``google.api.resource_reference`` on a
          recursive resource field, when that type (or child type) is a key
          of ``visible_resources``.
        """
        def gen_resources(message):
            if message.resource_path:
                yield message

            for type_ in message.recursive_field_types:
                if type_.resource_path:
                    yield type_

        def gen_indirect_resources_used(message):
            for field in message.recursive_resource_fields:
                reference = field.options.Extensions[
                    resource_pb2.resource_reference]
                resource_type = reference.type or reference.child_type
                # The resource may not be visible if the resource type is one of
                # the common_resources (see the class var in class definition)
                # or if it's something unhelpful like '*'.
                resource_message = self.visible_resources.get(resource_type)
                if resource_message:
                    yield resource_message

        def gen_method_resources(method):
            # For long-running methods, scan the LRO response type
            # instead of the raw ``output`` message.
            response = method.lro.response_type if method.lro else method.output
            return chain(
                gen_resources(method.input),
                gen_resources(response),
                gen_indirect_resources_used(method.input),
                gen_indirect_resources_used(response),
            )

        return frozenset(
            msg
            for method in self.methods.values()
            for msg in gen_method_resources(method)
        )

    @utils.cached_property
    def any_client_streaming(self) -> bool:
        """Return whether any method in the service uses client streaming."""
        return any(m.client_streaming for m in self.methods.values())

    @utils.cached_property
    def any_server_streaming(self) -> bool:
        """Return whether any method in the service uses server streaming."""
        return any(m.server_streaming for m in self.methods.values())

    def with_context(self, *, collisions: FrozenSet[str]) -> 'Service':
        """Return a derivative of this service with the provided context.

        This method is used to address naming collisions. The returned
        ``Service`` object aliases module names to avoid naming collisions
        in the file being written.

        Args:
            collisions (FrozenSet[str]): Names already in use in the target
                file. Each method additionally receives its own flattened
                field keys as collisions.

        Returns:
            Service: A copy with ``methods`` and ``meta`` contextualized.
        """
        methods = collections.OrderedDict(
            (name, method.with_context(
                # A method's flattened fields create additional names
                # that may conflict with module imports.
                collisions=collisions | frozenset(method.flattened_fields),
            ))
            for name, method in self.methods.items()
        )
        return dataclasses.replace(
            self,
            methods=methods,
            meta=self.meta.with_context(collisions=collisions),
        )
1122