1# Copyright 2018 Google LLC 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15"""Module containing wrapper classes around meta-descriptors. 16 17This module contains dataclasses which wrap the descriptor protos 18defined in google/protobuf/descriptor.proto (which are descriptors that 19describe descriptors). 20 21These wrappers exist in order to provide useful helper methods and 22generally ease access to things in templates (in particular, documentation, 23certain aggregate views of things, etc.) 24 25Reading of underlying descriptor properties in templates *is* okay, a 26``__getattr__`` method which consistently routes in this way is provided. 27Documentation is consistently at ``{thing}.meta.doc``. 28""" 29 30import collections 31import dataclasses 32import re 33from itertools import chain 34from typing import (cast, Dict, FrozenSet, Iterable, List, Mapping, 35 ClassVar, Optional, Sequence, Set, Tuple, Union) 36from google.api import annotations_pb2 # type: ignore 37from google.api import client_pb2 38from google.api import field_behavior_pb2 39from google.api import resource_pb2 40from google.api_core import exceptions # type: ignore 41from google.protobuf import descriptor_pb2 # type: ignore 42 43from gapic import utils 44from gapic.schema import metadata 45 46 47@dataclasses.dataclass(frozen=True) 48class Field: 49 """Description of a field.""" 50 field_pb: descriptor_pb2.FieldDescriptorProto 51 message: Optional['MessageType'] = None 52 enum: Optional['EnumType'] = None 53 meta: metadata.Metadata = dataclasses.field( 54 default_factory=metadata.Metadata, 55 ) 56 oneof: Optional[str] = None 57 58 def __getattr__(self, name): 59 return getattr(self.field_pb, name) 60 61 def __hash__(self): 62 # The only sense in which it is meaningful to say a field is equal to 63 # another field is if they are the same, i.e. they live in the same 64 # message type under the same moniker, i.e. they have the same id. 65 return id(self) 66 67 @property 68 def name(self) -> str: 69 """Used to prevent collisions with python keywords""" 70 name = self.field_pb.name 71 return name + "_" if name in utils.RESERVED_NAMES else name 72 73 @utils.cached_property 74 def ident(self) -> metadata.FieldIdentifier: 75 """Return the identifier to be used in templates.""" 76 return metadata.FieldIdentifier( 77 ident=self.type.ident, 78 repeated=self.repeated, 79 ) 80 81 @property 82 def is_primitive(self) -> bool: 83 """Return True if the field is a primitive, False otherwise.""" 84 return isinstance(self.type, PrimitiveType) 85 86 @property 87 def map(self) -> bool: 88 """Return True if this field is a map, False otherwise.""" 89 return bool(self.repeated and self.message and self.message.map) 90 91 @utils.cached_property 92 def mock_value(self) -> str: 93 visited_fields: Set["Field"] = set() 94 stack = [self] 95 answer = "{}" 96 while stack: 97 expr = stack.pop() 98 answer = answer.format(expr.inner_mock(stack, visited_fields)) 99 100 return answer 101 102 def inner_mock(self, stack, visited_fields): 103 """Return a repr of a valid, usually truthy mock value.""" 104 # For primitives, send a truthy value computed from the 105 # field name. 106 answer = 'None' 107 if isinstance(self.type, PrimitiveType): 108 if self.type.python_type == bool: 109 answer = 'True' 110 elif self.type.python_type == str: 111 answer = f"'{self.name}_value'" 112 elif self.type.python_type == bytes: 113 answer = f"b'{self.name}_blob'" 114 elif self.type.python_type == int: 115 answer = f'{sum([ord(i) for i in self.name])}' 116 elif self.type.python_type == float: 117 answer = f'0.{sum([ord(i) for i in self.name])}' 118 else: # Impossible; skip coverage checks. 119 raise TypeError('Unrecognized PrimitiveType. This should ' 120 'never happen; please file an issue.') 121 122 # If this is an enum, select the first truthy value (or the zero 123 # value if nothing else exists). 124 if isinstance(self.type, EnumType): 125 # Note: The slightly-goofy [:2][-1] lets us gracefully fall 126 # back to index 0 if there is only one element. 127 mock_value = self.type.values[:2][-1] 128 answer = f'{self.type.ident}.{mock_value.name}' 129 130 # If this is another message, set one value on the message. 131 if ( 132 not self.map # Maps are handled separately 133 and isinstance(self.type, MessageType) 134 and len(self.type.fields) 135 # Nested message types need to terminate eventually 136 and self not in visited_fields 137 ): 138 sub = next(iter(self.type.fields.values())) 139 stack.append(sub) 140 visited_fields.add(self) 141 # Don't do the recursive rendering here, just set up 142 # where the nested value should go with the double {}. 143 answer = f'{self.type.ident}({sub.name}={{}})' 144 145 if self.map: 146 # Maps are a special case beacuse they're represented internally as 147 # a list of a generated type with two fields: 'key' and 'value'. 148 answer = '{{{}: {}}}'.format( 149 self.type.fields["key"].mock_value, 150 self.type.fields["value"].mock_value, 151 ) 152 elif self.repeated: 153 # If this is a repeated field, then the mock answer should 154 # be a list. 155 answer = f'[{answer}]' 156 157 # Done; return the mock value. 158 return answer 159 160 @property 161 def proto_type(self) -> str: 162 """Return the proto type constant to be used in templates.""" 163 return cast(str, descriptor_pb2.FieldDescriptorProto.Type.Name( 164 self.field_pb.type, 165 ))[len('TYPE_'):] 166 167 @property 168 def repeated(self) -> bool: 169 """Return True if this is a repeated field, False otherwise. 170 171 Returns: 172 bool: Whether this field is repeated. 173 """ 174 return self.label == \ 175 descriptor_pb2.FieldDescriptorProto.Label.Value( 176 'LABEL_REPEATED') # type: ignore 177 178 @property 179 def required(self) -> bool: 180 """Return True if this is a required field, False otherwise. 181 182 Returns: 183 bool: Whether this field is required. 184 """ 185 return (field_behavior_pb2.FieldBehavior.Value('REQUIRED') in 186 self.options.Extensions[field_behavior_pb2.field_behavior]) 187 188 @utils.cached_property 189 def type(self) -> Union['MessageType', 'EnumType', 'PrimitiveType']: 190 """Return the type of this field.""" 191 # If this is a message or enum, return the appropriate thing. 192 if self.type_name and self.message: 193 return self.message 194 if self.type_name and self.enum: 195 return self.enum 196 197 # This is a primitive. Return the corresponding Python type. 198 # The enum values used here are defined in: 199 # Repository: https://github.com/google/protobuf/ 200 # Path: src/google/protobuf/descriptor.proto 201 # 202 # The values are used here because the code would be excessively 203 # verbose otherwise, and this is guaranteed never to change. 204 # 205 # 10, 11, and 14 are intentionally missing. They correspond to 206 # group (unused), message (covered above), and enum (covered above). 207 if self.field_pb.type in (1, 2): 208 return PrimitiveType.build(float) 209 if self.field_pb.type in (3, 4, 5, 6, 7, 13, 15, 16, 17, 18): 210 return PrimitiveType.build(int) 211 if self.field_pb.type == 8: 212 return PrimitiveType.build(bool) 213 if self.field_pb.type == 9: 214 return PrimitiveType.build(str) 215 if self.field_pb.type == 12: 216 return PrimitiveType.build(bytes) 217 218 # This should never happen. 219 raise TypeError(f'Unrecognized protobuf type: {self.field_pb.type}. ' 220 'This code should not be reachable; please file a bug.') 221 222 def with_context( 223 self, 224 *, 225 collisions: FrozenSet[str], 226 visited_messages: FrozenSet["MessageType"], 227 ) -> 'Field': 228 """Return a derivative of this field with the provided context. 229 230 This method is used to address naming collisions. The returned 231 ``Field`` object aliases module names to avoid naming collisions 232 in the file being written. 233 """ 234 return dataclasses.replace( 235 self, 236 message=self.message.with_context( 237 collisions=collisions, 238 skip_fields=self.message in visited_messages, 239 visited_messages=visited_messages, 240 ) if self.message else None, 241 enum=self.enum.with_context(collisions=collisions) 242 if self.enum else None, 243 meta=self.meta.with_context(collisions=collisions), 244 ) 245 246 247@dataclasses.dataclass(frozen=True) 248class Oneof: 249 """Description of a field.""" 250 oneof_pb: descriptor_pb2.OneofDescriptorProto 251 252 def __getattr__(self, name): 253 return getattr(self.oneof_pb, name) 254 255 256@dataclasses.dataclass(frozen=True) 257class MessageType: 258 """Description of a message (defined with the ``message`` keyword).""" 259 # Class attributes 260 PATH_ARG_RE = re.compile(r'\{([a-zA-Z0-9_-]+)\}') 261 262 # Instance attributes 263 message_pb: descriptor_pb2.DescriptorProto 264 fields: Mapping[str, Field] 265 nested_enums: Mapping[str, 'EnumType'] 266 nested_messages: Mapping[str, 'MessageType'] 267 meta: metadata.Metadata = dataclasses.field( 268 default_factory=metadata.Metadata, 269 ) 270 oneofs: Optional[Mapping[str, 'Oneof']] = None 271 272 def __getattr__(self, name): 273 return getattr(self.message_pb, name) 274 275 def __hash__(self): 276 # Identity is sufficiently unambiguous. 277 return hash(self.ident) 278 279 def oneof_fields(self, include_optional=False): 280 oneof_fields = collections.defaultdict(list) 281 for field in self.fields.values(): 282 # Only include proto3 optional oneofs if explicitly looked for. 283 if field.oneof and not field.proto3_optional or include_optional: 284 oneof_fields[field.oneof].append(field) 285 286 return oneof_fields 287 288 @utils.cached_property 289 def field_types(self) -> Sequence[Union['MessageType', 'EnumType']]: 290 answer = tuple( 291 field.type 292 for field in self.fields.values() 293 if field.message or field.enum 294 ) 295 296 return answer 297 298 @utils.cached_property 299 def recursive_field_types(self) -> Sequence[ 300 Union['MessageType', 'EnumType'] 301 ]: 302 """Return all composite fields used in this proto's messages.""" 303 types: Set[Union['MessageType', 'EnumType']] = set() 304 305 stack = [iter(self.fields.values())] 306 while stack: 307 fields_iter = stack.pop() 308 for field in fields_iter: 309 if field.message and field.type not in types: 310 stack.append(iter(field.message.fields.values())) 311 if not field.is_primitive: 312 types.add(field.type) 313 314 return tuple(types) 315 316 @utils.cached_property 317 def recursive_resource_fields(self) -> FrozenSet[Field]: 318 all_fields = chain( 319 self.fields.values(), 320 (field 321 for t in self.recursive_field_types if isinstance(t, MessageType) 322 for field in t.fields.values()), 323 ) 324 return frozenset( 325 f 326 for f in all_fields 327 if (f.options.Extensions[resource_pb2.resource_reference].type or 328 f.options.Extensions[resource_pb2.resource_reference].child_type) 329 ) 330 331 @property 332 def map(self) -> bool: 333 """Return True if the given message is a map, False otherwise.""" 334 return self.message_pb.options.map_entry 335 336 @property 337 def ident(self) -> metadata.Address: 338 """Return the identifier data to be used in templates.""" 339 return self.meta.address 340 341 @property 342 def resource_path(self) -> Optional[str]: 343 """If this message describes a resource, return the path to the resource. 344 If there are multiple paths, returns the first one.""" 345 return next( 346 iter(self.options.Extensions[resource_pb2.resource].pattern), 347 None 348 ) 349 350 @property 351 def resource_type(self) -> Optional[str]: 352 resource = self.options.Extensions[resource_pb2.resource] 353 return resource.type[resource.type.find('/') + 1:] if resource else None 354 355 @property 356 def resource_path_args(self) -> Sequence[str]: 357 return self.PATH_ARG_RE.findall(self.resource_path or '') 358 359 @utils.cached_property 360 def path_regex_str(self) -> str: 361 # The indirection here is a little confusing: 362 # we're using the resource path template as the base of a regex, 363 # with each resource ID segment being captured by a regex. 364 # E.g., the path schema 365 # kingdoms/{kingdom}/phyla/{phylum} 366 # becomes the regex 367 # ^kingdoms/(?P<kingdom>.+?)/phyla/(?P<phylum>.+?)$ 368 parsing_regex_str = ( 369 "^" + 370 self.PATH_ARG_RE.sub( 371 # We can't just use (?P<name>[^/]+) because segments may be 372 # separated by delimiters other than '/'. 373 # Multiple delimiter characters within one schema are allowed, 374 # e.g. 375 # as/{a}-{b}/cs/{c}%{d}_{e} 376 # This is discouraged but permitted by AIP4231 377 lambda m: "(?P<{name}>.+?)".format(name=m.groups()[0]), 378 self.resource_path or '' 379 ) + 380 "$" 381 ) 382 return parsing_regex_str 383 384 def get_field(self, *field_path: str, 385 collisions: FrozenSet[str] = frozenset()) -> Field: 386 """Return a field arbitrarily deep in this message's structure. 387 388 This method recursively traverses the message tree to return the 389 requested inner-field. 390 391 Traversing through repeated fields is not supported; a repeated field 392 may be specified if and only if it is the last field in the path. 393 394 Args: 395 field_path (Sequence[str]): The field path. 396 397 Returns: 398 ~.Field: A field object. 399 400 Raises: 401 KeyError: If a repeated field is used in the non-terminal position 402 in the path. 403 """ 404 # If collisions are not explicitly specified, retrieve them 405 # from this message's address. 406 # This ensures that calls to `get_field` will return a field with 407 # the same context, regardless of the number of levels through the 408 # chain (in order to avoid infinite recursion on circular references, 409 # we only shallowly bind message references held by fields; this 410 # binds deeply in the one spot where that might be a problem). 411 collisions = collisions or self.meta.address.collisions 412 413 # Get the first field in the path. 414 cursor = self.fields[field_path[0]] 415 416 # Base case: If this is the last field in the path, return it outright. 417 if len(field_path) == 1: 418 return cursor.with_context( 419 collisions=collisions, 420 visited_messages=frozenset({self}), 421 ) 422 423 # Sanity check: If cursor is a repeated field, then raise an exception. 424 # Repeated fields are only permitted in the terminal position. 425 if cursor.repeated: 426 raise KeyError( 427 f'The {cursor.name} field is repeated; unable to use ' 428 '`get_field` to retrieve its children.\n' 429 'This exception usually indicates that a ' 430 'google.api.method_signature annotation uses a repeated field ' 431 'in the fields list in a position other than the end.', 432 ) 433 434 # Sanity check: If this cursor has no message, there is a problem. 435 if not cursor.message: 436 raise KeyError( 437 f'Field {".".join(field_path)} could not be resolved from ' 438 f'{cursor.name}.', 439 ) 440 441 # Recursion case: Pass the remainder of the path to the sub-field's 442 # message. 443 return cursor.message.get_field(*field_path[1:], collisions=collisions) 444 445 def with_context(self, *, 446 collisions: FrozenSet[str], 447 skip_fields: bool = False, 448 visited_messages: FrozenSet["MessageType"] = frozenset(), 449 ) -> 'MessageType': 450 """Return a derivative of this message with the provided context. 451 452 This method is used to address naming collisions. The returned 453 ``MessageType`` object aliases module names to avoid naming collisions 454 in the file being written. 455 456 The ``skip_fields`` argument will omit applying the context to the 457 underlying fields. This provides for an "exit" in the case of circular 458 references. 459 """ 460 visited_messages = visited_messages | {self} 461 return dataclasses.replace( 462 self, 463 fields=collections.OrderedDict( 464 (k, v.with_context( 465 collisions=collisions, 466 visited_messages=visited_messages 467 )) 468 for k, v in self.fields.items() 469 ) if not skip_fields else self.fields, 470 nested_enums=collections.OrderedDict( 471 (k, v.with_context(collisions=collisions)) 472 for k, v in self.nested_enums.items() 473 ), 474 nested_messages=collections.OrderedDict( 475 (k, v.with_context( 476 collisions=collisions, 477 skip_fields=skip_fields, 478 visited_messages=visited_messages, 479 )) 480 for k, v in self.nested_messages.items()), 481 meta=self.meta.with_context(collisions=collisions), 482 ) 483 484 485@dataclasses.dataclass(frozen=True) 486class EnumValueType: 487 """Description of an enum value.""" 488 enum_value_pb: descriptor_pb2.EnumValueDescriptorProto 489 meta: metadata.Metadata = dataclasses.field( 490 default_factory=metadata.Metadata, 491 ) 492 493 def __getattr__(self, name): 494 return getattr(self.enum_value_pb, name) 495 496 497@dataclasses.dataclass(frozen=True) 498class EnumType: 499 """Description of an enum (defined with the ``enum`` keyword.)""" 500 enum_pb: descriptor_pb2.EnumDescriptorProto 501 values: List[EnumValueType] 502 meta: metadata.Metadata = dataclasses.field( 503 default_factory=metadata.Metadata, 504 ) 505 506 def __hash__(self): 507 # Identity is sufficiently unambiguous. 508 return hash(self.ident) 509 510 def __getattr__(self, name): 511 return getattr(self.enum_pb, name) 512 513 @property 514 def resource_path(self) -> Optional[str]: 515 # This is a minor duck-typing workaround for the resource_messages 516 # property in the Service class: we need to check fields recursively 517 # to see if they're resources, and recursive_field_types includes enums 518 return None 519 520 @property 521 def ident(self) -> metadata.Address: 522 """Return the identifier data to be used in templates.""" 523 return self.meta.address 524 525 def with_context(self, *, collisions: FrozenSet[str]) -> 'EnumType': 526 """Return a derivative of this enum with the provided context. 527 528 This method is used to address naming collisions. The returned 529 ``EnumType`` object aliases module names to avoid naming collisions in 530 the file being written. 531 """ 532 return dataclasses.replace( 533 self, 534 meta=self.meta.with_context(collisions=collisions), 535 ) 536 537 538@dataclasses.dataclass(frozen=True) 539class PythonType: 540 """Wrapper class for Python types. 541 542 This exists for interface consistency, so that methods like 543 :meth:`Field.type` can return an object and the caller can be confident 544 that a ``name`` property will be present. 545 """ 546 meta: metadata.Metadata 547 548 def __eq__(self, other): 549 return self.meta == other.meta 550 551 def __ne__(self, other): 552 return not self == other 553 554 @utils.cached_property 555 def ident(self) -> metadata.Address: 556 """Return the identifier to be used in templates.""" 557 return self.meta.address 558 559 @property 560 def name(self) -> str: 561 return self.ident.name 562 563 @property 564 def field_types(self) -> Sequence[Union['MessageType', 'EnumType']]: 565 return tuple() 566 567 568@dataclasses.dataclass(frozen=True) 569class PrimitiveType(PythonType): 570 """A representation of a Python primitive type.""" 571 python_type: Optional[type] 572 573 @classmethod 574 def build(cls, primitive_type: Optional[type]): 575 """Return a PrimitiveType object for the given Python primitive type. 576 577 Args: 578 primitive_type (cls): A Python primitive type, such as 579 :class:`int` or :class:`str`. Despite not being a type, 580 ``None`` is also accepted here. 581 582 Returns: 583 ~.PrimitiveType: The instantiated PrimitiveType object. 584 """ 585 # Primitives have no import, and no module to reference, so the 586 # address just uses the name of the class (e.g. "int", "str"). 587 return cls(meta=metadata.Metadata(address=metadata.Address( 588 name='None' if primitive_type is None else primitive_type.__name__, 589 )), python_type=primitive_type) 590 591 def __eq__(self, other): 592 # If we are sent the actual Python type (not the PrimitiveType object), 593 # claim to be equal to that. 594 if not hasattr(other, 'meta'): 595 return self.python_type is other 596 return super().__eq__(other) 597 598 599@dataclasses.dataclass(frozen=True) 600class OperationInfo: 601 """Representation of long-running operation info.""" 602 response_type: MessageType 603 metadata_type: MessageType 604 605 def with_context(self, *, collisions: FrozenSet[str]) -> 'OperationInfo': 606 """Return a derivative of this OperationInfo with the provided context. 607 608 This method is used to address naming collisions. The returned 609 ``OperationInfo`` object aliases module names to avoid naming collisions 610 in the file being written. 611 """ 612 return dataclasses.replace( 613 self, 614 response_type=self.response_type.with_context( 615 collisions=collisions 616 ), 617 metadata_type=self.metadata_type.with_context( 618 collisions=collisions 619 ), 620 ) 621 622 623@dataclasses.dataclass(frozen=True) 624class RetryInfo: 625 """Representation of the method's retry behavior.""" 626 max_attempts: int 627 initial_backoff: float 628 max_backoff: float 629 backoff_multiplier: float 630 retryable_exceptions: FrozenSet[exceptions.GoogleAPICallError] 631 632 633@dataclasses.dataclass(frozen=True) 634class Method: 635 """Description of a method (defined with the ``rpc`` keyword).""" 636 method_pb: descriptor_pb2.MethodDescriptorProto 637 input: MessageType 638 output: MessageType 639 lro: Optional[OperationInfo] = dataclasses.field(default=None) 640 retry: Optional[RetryInfo] = dataclasses.field(default=None) 641 timeout: Optional[float] = None 642 meta: metadata.Metadata = dataclasses.field( 643 default_factory=metadata.Metadata, 644 ) 645 646 def __getattr__(self, name): 647 return getattr(self.method_pb, name) 648 649 @utils.cached_property 650 def client_output(self): 651 return self._client_output(enable_asyncio=False) 652 653 @utils.cached_property 654 def client_output_async(self): 655 return self._client_output(enable_asyncio=True) 656 657 def flattened_oneof_fields(self, include_optional=False): 658 oneof_fields = collections.defaultdict(list) 659 for field in self.flattened_fields.values(): 660 # Only include proto3 optional oneofs if explicitly looked for. 661 if field.oneof and not field.proto3_optional or include_optional: 662 oneof_fields[field.oneof].append(field) 663 664 return oneof_fields 665 666 def _client_output(self, enable_asyncio: bool): 667 """Return the output from the client layer. 668 669 This takes into account transformations made by the outer GAPIC 670 client to transform the output from the transport. 671 672 Returns: 673 Union[~.MessageType, ~.PythonType]: 674 A description of the return type. 675 """ 676 # Void messages ultimately return None. 677 if self.void: 678 return PrimitiveType.build(None) 679 680 # If this method is an LRO, return a PythonType instance representing 681 # that. 682 if self.lro: 683 return PythonType(meta=metadata.Metadata( 684 address=metadata.Address( 685 name='AsyncOperation' if enable_asyncio else 'Operation', 686 module='operation_async' if enable_asyncio else 'operation', 687 package=('google', 'api_core'), 688 collisions=self.lro.response_type.ident.collisions, 689 ), 690 documentation=utils.doc( 691 'An object representing a long-running operation. \n\n' 692 'The result type for the operation will be ' 693 ':class:`{ident}`: {doc}'.format( 694 doc=self.lro.response_type.meta.doc, 695 ident=self.lro.response_type.ident.sphinx, 696 ), 697 ), 698 )) 699 700 # If this method is paginated, return that method's pager class. 701 if self.paged_result_field: 702 return PythonType(meta=metadata.Metadata( 703 address=metadata.Address( 704 name=f'{self.name}AsyncPager' if enable_asyncio else f'{self.name}Pager', 705 package=self.ident.api_naming.module_namespace + (self.ident.api_naming.versioned_module_name,) + self.ident.subpackage + ( 706 'services', 707 utils.to_snake_case(self.ident.parent[-1]), 708 ), 709 module='pagers', 710 collisions=self.input.ident.collisions, 711 ), 712 documentation=utils.doc( 713 f'{self.output.meta.doc}\n\n' 714 'Iterating over this object will yield results and ' 715 'resolve additional pages automatically.', 716 ), 717 )) 718 719 # Return the usual output. 720 return self.output 721 722 @property 723 def field_headers(self) -> Sequence[str]: 724 """Return the field headers defined for this method.""" 725 http = self.options.Extensions[annotations_pb2.http] 726 727 pattern = re.compile(r'\{([a-z][\w\d_.]+)=') 728 729 potential_verbs = [ 730 http.get, 731 http.put, 732 http.post, 733 http.delete, 734 http.patch, 735 http.custom.path, 736 ] 737 738 return next((tuple(pattern.findall(verb)) for verb in potential_verbs if verb), ()) 739 740 @utils.cached_property 741 def flattened_fields(self) -> Mapping[str, Field]: 742 """Return the signature defined for this method.""" 743 cross_pkg_request = self.input.ident.package != self.ident.package 744 745 def filter_fields(sig: str) -> Iterable[Tuple[str, Field]]: 746 for f in sig.split(','): 747 if not f: 748 # Special case for an empty signature 749 continue 750 name = f.strip() 751 field = self.input.get_field(*name.split('.')) 752 if cross_pkg_request and not field.is_primitive: 753 # This is not a proto-plus wrapped message type, 754 # and setting a non-primitive field directly is verboten. 755 continue 756 757 yield name, field 758 759 signatures = self.options.Extensions[client_pb2.method_signature] 760 answer: Dict[str, Field] = collections.OrderedDict( 761 name_and_field 762 for sig in signatures 763 for name_and_field in filter_fields(sig) 764 ) 765 766 return answer 767 768 @utils.cached_property 769 def flattened_field_to_key(self): 770 return {field.name: key for key, field in self.flattened_fields.items()} 771 772 @utils.cached_property 773 def legacy_flattened_fields(self) -> Mapping[str, Field]: 774 """Return the legacy flattening interface: top level fields only, 775 required fields first""" 776 required, optional = utils.partition(lambda f: f.required, 777 self.input.fields.values()) 778 return collections.OrderedDict((f.name, f) 779 for f in chain(required, optional)) 780 781 @property 782 def grpc_stub_type(self) -> str: 783 """Return the type of gRPC stub to use.""" 784 return '{client}_{server}'.format( 785 client='stream' if self.client_streaming else 'unary', 786 server='stream' if self.server_streaming else 'unary', 787 ) 788 789 @utils.cached_property 790 def idempotent(self) -> bool: 791 """Return True if we know this method is idempotent, False otherwise. 792 793 Note: We are intentionally conservative here. It is far less bad 794 to falsely believe an idempotent method is non-idempotent than 795 the converse. 796 """ 797 return bool(self.options.Extensions[annotations_pb2.http].get) 798 799 @property 800 def ident(self) -> metadata.Address: 801 """Return the identifier data to be used in templates.""" 802 return self.meta.address 803 804 @utils.cached_property 805 def paged_result_field(self) -> Optional[Field]: 806 """Return the response pagination field if the method is paginated.""" 807 # If the request field lacks any of the expected pagination fields, 808 # then the method is not paginated. 809 for page_field in ((self.input, int, 'page_size'), 810 (self.input, str, 'page_token'), 811 (self.output, str, 'next_page_token')): 812 field = page_field[0].fields.get(page_field[2], None) 813 if not field or field.type != page_field[1]: 814 return None 815 816 # Return the first repeated field. 817 for field in self.output.fields.values(): 818 if field.repeated: 819 return field 820 821 # We found no repeated fields. Return None. 822 return None 823 824 @utils.cached_property 825 def ref_types(self) -> Sequence[Union[MessageType, EnumType]]: 826 return self._ref_types(True) 827 828 @utils.cached_property 829 def flat_ref_types(self) -> Sequence[Union[MessageType, EnumType]]: 830 return self._ref_types(False) 831 832 def _ref_types(self, recursive: bool) -> Sequence[Union[MessageType, EnumType]]: 833 """Return types referenced by this method.""" 834 # Begin with the input (request) and output (response) messages. 835 answer: List[Union[MessageType, EnumType]] = [self.input] 836 types: Iterable[Union[MessageType, EnumType]] = ( 837 self.input.recursive_field_types if recursive 838 else ( 839 f.type 840 for f in self.flattened_fields.values() 841 if f.message or f.enum 842 ) 843 ) 844 answer.extend(types) 845 846 if not self.void: 847 answer.append(self.client_output) 848 answer.extend(self.client_output.field_types) 849 answer.append(self.client_output_async) 850 answer.extend(self.client_output_async.field_types) 851 852 # If this method has LRO, it is possible (albeit unlikely) that 853 # the LRO messages reside in a different module. 854 if self.lro: 855 answer.append(self.lro.response_type) 856 answer.append(self.lro.metadata_type) 857 858 # If this message paginates its responses, it is possible 859 # that the individual result messages reside in a different module. 860 if self.paged_result_field and self.paged_result_field.message: 861 answer.append(self.paged_result_field.message) 862 863 # Done; return the answer. 864 return tuple(answer) 865 866 @property 867 def void(self) -> bool: 868 """Return True if this method has no return value, False otherwise.""" 869 return self.output.ident.proto == 'google.protobuf.Empty' 870 871 def with_context(self, *, collisions: FrozenSet[str]) -> 'Method': 872 """Return a derivative of this method with the provided context. 873 874 This method is used to address naming collisions. The returned 875 ``Method`` object aliases module names to avoid naming collisions 876 in the file being written. 877 """ 878 maybe_lro = self.lro.with_context( 879 collisions=collisions 880 ) if self.lro else None 881 882 return dataclasses.replace( 883 self, 884 lro=maybe_lro, 885 input=self.input.with_context(collisions=collisions), 886 output=self.output.with_context(collisions=collisions), 887 meta=self.meta.with_context(collisions=collisions), 888 ) 889 890 891@dataclasses.dataclass(frozen=True) 892class CommonResource: 893 type_name: str 894 pattern: str 895 896 @classmethod 897 def build(cls, resource: resource_pb2.ResourceDescriptor): 898 return cls( 899 type_name=resource.type, 900 pattern=next(iter(resource.pattern)) 901 ) 902 903 @utils.cached_property 904 def message_type(self): 905 message_pb = descriptor_pb2.DescriptorProto() 906 res_pb = message_pb.options.Extensions[resource_pb2.resource] 907 res_pb.type = self.type_name 908 res_pb.pattern.append(self.pattern) 909 910 return MessageType( 911 message_pb=message_pb, 912 fields={}, 913 nested_enums={}, 914 nested_messages={}, 915 ) 916 917 918@dataclasses.dataclass(frozen=True) 919class Service: 920 """Description of a service (defined with the ``service`` keyword).""" 921 service_pb: descriptor_pb2.ServiceDescriptorProto 922 methods: Mapping[str, Method] 923 # N.B.: visible_resources is intended to be a read-only view 924 # whose backing store is owned by the API. 925 # This is represented by a types.MappingProxyType instance. 926 visible_resources: Mapping[str, MessageType] 927 meta: metadata.Metadata = dataclasses.field( 928 default_factory=metadata.Metadata, 929 ) 930 931 common_resources: ClassVar[Mapping[str, CommonResource]] = dataclasses.field( 932 default={ 933 "cloudresourcemanager.googleapis.com/Project": CommonResource( 934 "cloudresourcemanager.googleapis.com/Project", 935 "projects/{project}", 936 ), 937 "cloudresourcemanager.googleapis.com/Organization": CommonResource( 938 "cloudresourcemanager.googleapis.com/Organization", 939 "organizations/{organization}", 940 ), 941 "cloudresourcemanager.googleapis.com/Folder": CommonResource( 942 "cloudresourcemanager.googleapis.com/Folder", 943 "folders/{folder}", 944 ), 945 "cloudbilling.googleapis.com/BillingAccount": CommonResource( 946 "cloudbilling.googleapis.com/BillingAccount", 947 "billingAccounts/{billing_account}", 948 ), 949 "locations.googleapis.com/Location": CommonResource( 950 "locations.googleapis.com/Location", 951 "projects/{project}/locations/{location}", 952 ), 953 }, 954 init=False, 955 compare=False, 956 ) 957 958 def __getattr__(self, name): 959 return getattr(self.service_pb, name) 960 961 @property 962 def client_name(self) -> str: 963 """Returns the name of the generated client class""" 964 return self.name + "Client" 965 966 @property 967 def async_client_name(self) -> str: 968 """Returns the name of the generated AsyncIO client class""" 969 return self.name + "AsyncClient" 970 971 @property 972 def transport_name(self): 973 return self.name + "Transport" 974 975 @property 976 def grpc_transport_name(self): 977 return self.name + "GrpcTransport" 978 979 @property 980 def grpc_asyncio_transport_name(self): 981 return self.name + "GrpcAsyncIOTransport" 982 983 @property 984 def has_lro(self) -> bool: 985 """Return whether the service has a long-running method.""" 986 return any([m.lro for m in self.methods.values()]) 987 988 @property 989 def has_pagers(self) -> bool: 990 """Return whether the service has paged methods.""" 991 return any(m.paged_result_field for m in self.methods.values()) 992 993 @property 994 def host(self) -> str: 995 """Return the hostname for this service, if specified. 996 997 Returns: 998 str: The hostname, with no protocol and no trailing ``/``. 999 """ 1000 if self.options.Extensions[client_pb2.default_host]: 1001 return self.options.Extensions[client_pb2.default_host] 1002 return '' 1003 1004 @property 1005 def oauth_scopes(self) -> Sequence[str]: 1006 """Return a sequence of oauth scopes, if applicable. 1007 1008 Returns: 1009 Sequence[str]: A sequence of OAuth scopes. 1010 """ 1011 # Return the OAuth scopes, split on comma. 1012 return tuple( 1013 i.strip() 1014 for i in self.options.Extensions[client_pb2.oauth_scopes].split(',') 1015 if i 1016 ) 1017 1018 @property 1019 def module_name(self) -> str: 1020 """Return the appropriate module name for this service. 1021 1022 Returns: 1023 str: The service name, in snake case. 1024 """ 1025 return utils.to_snake_case(self.name) 1026 1027 @utils.cached_property 1028 def names(self) -> FrozenSet[str]: 1029 """Return a set of names used in this service. 1030 1031 This is used for detecting naming collisions in the module names 1032 used for imports. 1033 """ 1034 # Put together a set of the service and method names. 1035 answer = {self.name, self.client_name, self.async_client_name} 1036 answer.update( 1037 utils.to_snake_case(i.name) for i in self.methods.values() 1038 ) 1039 1040 # Identify any import module names where the same module name is used 1041 # from distinct packages. 1042 modules: Dict[str, Set[str]] = collections.defaultdict(set) 1043 for m in self.methods.values(): 1044 for t in m.ref_types: 1045 modules[t.ident.module].add(t.ident.package) 1046 1047 answer.update( 1048 module_name 1049 for module_name, packages in modules.items() 1050 if len(packages) > 1 1051 ) 1052 1053 # Done; return the answer. 1054 return frozenset(answer) 1055 1056 @utils.cached_property 1057 def resource_messages(self) -> FrozenSet[MessageType]: 1058 """Returns all the resource message types used in all 1059 request and response fields in the service.""" 1060 def gen_resources(message): 1061 if message.resource_path: 1062 yield message 1063 1064 for type_ in message.recursive_field_types: 1065 if type_.resource_path: 1066 yield type_ 1067 1068 def gen_indirect_resources_used(message): 1069 for field in message.recursive_resource_fields: 1070 resource = field.options.Extensions[ 1071 resource_pb2.resource_reference] 1072 resource_type = resource.type or resource.child_type 1073 # The resource may not be visible if the resource type is one of 1074 # the common_resources (see the class var in class definition) 1075 # or if it's something unhelpful like '*'. 1076 resource = self.visible_resources.get(resource_type) 1077 if resource: 1078 yield resource 1079 1080 return frozenset( 1081 msg 1082 for method in self.methods.values() 1083 for msg in chain( 1084 gen_resources(method.input), 1085 gen_resources( 1086 method.lro.response_type if method.lro else method.output 1087 ), 1088 gen_indirect_resources_used(method.input), 1089 gen_indirect_resources_used( 1090 method.lro.response_type if method.lro else method.output 1091 ), 1092 ) 1093 ) 1094 1095 @utils.cached_property 1096 def any_client_streaming(self) -> bool: 1097 return any(m.client_streaming for m in self.methods.values()) 1098 1099 @utils.cached_property 1100 def any_server_streaming(self) -> bool: 1101 return any(m.server_streaming for m in self.methods.values()) 1102 1103 def with_context(self, *, collisions: FrozenSet[str]) -> 'Service': 1104 """Return a derivative of this service with the provided context. 1105 1106 This method is used to address naming collisions. The returned 1107 ``Service`` object aliases module names to avoid naming collisions 1108 in the file being written. 1109 """ 1110 return dataclasses.replace( 1111 self, 1112 methods=collections.OrderedDict( 1113 (k, v.with_context( 1114 # A methodd's flattened fields create additional names 1115 # that may conflict with module imports. 1116 collisions=collisions | frozenset(v.flattened_fields.keys())) 1117 ) 1118 for k, v in self.methods.items() 1119 ), 1120 meta=self.meta.with_context(collisions=collisions), 1121 ) 1122