# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15"""Flags defination for gcloud aiplatform."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import unicode_literals
20
21import argparse
22import sys
23import textwrap
24
25from googlecloudsdk.api_lib.util import apis
26
27from googlecloudsdk.calliope import actions
28from googlecloudsdk.calliope import arg_parsers
29from googlecloudsdk.calliope import base
30from googlecloudsdk.calliope.concepts import concepts
31from googlecloudsdk.calliope.concepts import deps
32from googlecloudsdk.command_lib.ai import constants
33from googlecloudsdk.command_lib.ai import errors
34from googlecloudsdk.command_lib.ai import region_util
35from googlecloudsdk.command_lib.iam import iam_util as core_iam_util
36from googlecloudsdk.command_lib.kms import resource_args as kms_resource_args
37from googlecloudsdk.command_lib.util.apis import arg_utils
38from googlecloudsdk.command_lib.util.concepts import concept_parsers
39from googlecloudsdk.core import properties
40
41CUSTOM_JOB_NAME = base.Argument('name', help=('Custom job\'s name to query.'))
42CUSTOM_JOB_DISPLAY_NAME = base.Argument(
43    '--display-name',
44    required=True,
45    help=('Display name of the custom job to create.'))
46AIPLATFORM_REGION = base.Argument(
47    '--region',
48    help=(
49        'Region of the AI Platform service to use. If not specified, the value '
50        'of the `ai/region` config property is used. If that property '
51        'is not configured, then you will be prompted to select a region. When '
52        'you specify this flag, its value is stored in the `ai/region` '
53        'config property.'),
54    action=actions.StoreProperty(properties.VALUES.ai.region))
55PYTHON_PACKGE_URIS = base.Argument(
56    '--python-package-uris',
57    metavar='PYTHON_PACKAGE_URIS',
58    type=arg_parsers.ArgList(),
59    help='The common python package uris that will be used by python image. '
60    'e.g. --python-package-uri=path1,path2'
61    'If customizing the python package is needed, please use config instead.')

CUSTOM_JOB_CONFIG = base.Argument(
    '--config',
    help="""
Path to the job configuration file. This file should be a YAML document containing a CustomJobSpec.
If an option is specified both in the configuration file **and** via command line arguments, the command line arguments
override the configuration file. Note that keys with underscores are invalid.

Example (YAML):

  workerPoolSpecs:
    machineSpec:
      machineType: n1-highmem-2
    replicaCount: 1
    containerSpec:
      imageUri: gcr.io/ucaip-test/ucaip-training-test
      args:
      - port=8500
      command:
      - start""")

WORKER_POOL_SPEC = base.Argument(
    '--worker-pool-spec',
    action='append',
    type=arg_parsers.ArgDict(
        spec={
            'replica-count': int,
            'machine-type': str,
            'container-image-uri': str,
            'python-image-uri': str,
            'python-module': str,
        },
        required_keys=['machine-type']),
    metavar='WORKER_POOL_SPEC',
    help="""
Define the worker pool configuration used by the custom job. You can specify multiple
worker pool specs in order to create a custom job with multiple worker pools.

The spec can contain the following fields, which are listed with corresponding
fields in the WorkerPoolSpec API message:

*machine-type*::: (Required) machineSpec.machineType
*replica-count*::: replicaCount
*container-image-uri*::: containerSpec.imageUri
*python-image-uri*::: pythonPackageSpec.executorImageUri
*python-module*::: pythonPackageSpec.pythonModule

For example:
`--worker-pool-spec=replica-count=1,machine-type=n1-highmem-2,container-image-uri=gcr.io/ucaip-test/ucaip-training-test`
""")

HPTUNING_JOB_DISPLAY_NAME = base.Argument(
    '--display-name',
    required=True,
    help=('Display name of the hyperparameter tuning job to create.'))

HPTUNING_MAX_TRIAL_COUNT = base.Argument(
    '--max-trial-count',
    type=int,
    default=1,
    help=('Desired total number of trials. The default value is 1.'))

HPTUNING_PARALLEL_TRIAL_COUNT = base.Argument(
    '--parallel-trial-count',
    type=int,
    default=1,
    help=(
        'Desired number of trials to run in parallel. The default value is 1.'))

HPTUNING_JOB_CONFIG = base.Argument(
    '--config',
    required=True,
    help="""
Path to the job configuration file. This file should be a YAML document containing a HyperparameterTuningSpec.
If an option is specified both in the configuration file **and** via command line arguments, the command line arguments
override the configuration file.

Example (YAML):

  displayName: TestHpTuningJob
  maxTrialCount: 1
  parallelTrialCount: 1
  studySpec:
    metrics:
    - metricId: x
      goal: MINIMIZE
    parameters:
    - parameterId: z
      integerValueSpec:
        minValue: 1
        maxValue: 100
    algorithm: RANDOM_SEARCH
  trialJobSpec:
    workerPoolSpecs:
    - machineSpec:
        machineType: n1-standard-4
      replicaCount: 1
      containerSpec:
        imageUri: gcr.io/ucaip-test/ucaip-training-test
""")

_POLLING_INTERVAL_FLAG = base.Argument(
    '--polling-interval',
    type=arg_parsers.BoundedInt(1, sys.maxsize, unlimited=True),
    default=60,
    help=('Number of seconds to wait between efforts to fetch the latest '
          'log messages.'))

_ALLOW_MULTILINE_LOGS = base.Argument(
    '--allow-multiline-logs',
    action='store_true',
    default=False,
    help='Output multiline log messages as single records.')

_TASK_NAME = base.Argument(
    '--task-name',
    required=False,
    default=None,
    help='If set, display only the logs for this particular task.')

_CUSTOM_JOB_COMMAND = base.Argument(
    '--command',
    type=arg_parsers.ArgList(),
    metavar='COMMAND',
    action=arg_parsers.UpdateAction,
    help="""\
Command to be invoked when containers are started.
It overrides the entrypoint instruction in the Dockerfile when provided.
""")
_CUSTOM_JOB_ARGS = base.Argument(
    '--args',
    metavar='ARG',
    type=arg_parsers.ArgList(),
    action=arg_parsers.UpdateAction,
    help="""\
Comma-separated arguments passed to containers or Python tasks.
""")

_NETWORK = base.Argument(
    '--network',
    help=textwrap.dedent("""\
      Full name of the Google Compute Engine network to which the Job
      is peered. Private services access must already have been configured.
      If unspecified, the Job is not peered with any network.
      """))

_TRAINING_SERVICE_ACCOUNT = base.Argument(
    '--service-account',
    type=core_iam_util.GetIamAccountFormatValidator(),
    required=False,
    help=textwrap.dedent("""\
      The email address of a service account to use when running the
      training application. You must have the `iam.serviceAccounts.actAs`
      permission for the specified service account.
      """))


def AddCreateCustomJobFlags(parser):
  """Adds flags related to creating a custom job."""
  AddRegionResourceArg(parser, 'to create a custom job')
  CUSTOM_JOB_DISPLAY_NAME.AddToParser(parser)
  PYTHON_PACKGE_URIS.AddToParser(parser)
  _CUSTOM_JOB_ARGS.AddToParser(parser)
  _CUSTOM_JOB_COMMAND.AddToParser(parser)
  _TRAINING_SERVICE_ACCOUNT.AddToParser(parser)
  _NETWORK.AddToParser(parser)
  AddKmsKeyResourceArg(parser, 'custom job')
  worker_pool_spec_group = base.ArgumentGroup(
      help='Worker pool specification.', required=True)
  worker_pool_spec_group.AddArgument(CUSTOM_JOB_CONFIG)
  worker_pool_spec_group.AddArgument(WORKER_POOL_SPEC)
  worker_pool_spec_group.AddToParser(parser)
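
# A hedged usage sketch (illustrative, not part of this module) of a command
# wiring these flags in; the `Create` class below is hypothetical:
#
#   class Create(base.CreateCommand):
#
#     @staticmethod
#     def Args(parser):
#       AddCreateCustomJobFlags(parser)
#
#     def Run(self, args):
#       # args.display_name, args.worker_pool_spec, args.config, args.network,
#       # etc. are now available for building the CreateCustomJob request.
#       ...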


def AddStreamLogsFlags(parser):
  _POLLING_INTERVAL_FLAG.AddToParser(parser)
  _TASK_NAME.AddToParser(parser)
  _ALLOW_MULTILINE_LOGS.AddToParser(parser)


def GetModelIdArg(required=True):
  return base.Argument(
      '--model', help='ID of the uploaded model.', required=required)


def GetDeployedModelId(required=True):
  return base.Argument(
      '--deployed-model-id',
      help='ID of the deployed model.',
      required=required)


def GetIndexIdArg(required=True, helper_text='ID of the index.'):
  return base.Argument('--index', help=helper_text, required=required)


def GetDeployedIndexId(required=True):
  return base.Argument(
      '--deployed-index-id',
      help='ID of the deployed index.',
      required=required)


def GetDisplayNameArg(noun, required=True):
  return base.Argument(
      '--display-name',
      required=required,
      help='Display name of the {noun}.'.format(noun=noun))


def GetDescriptionArg(noun):
  return base.Argument(
      '--description',
      required=False,
      default=None,
      help='Description of the {noun}.'.format(noun=noun))


def AddPredictInstanceArg(parser, required=True):
  """Add arguments for different types of predict instances."""
  base.Argument(
      '--json-request',
      required=required,
      help="""\
      Path to a local file containing the body of a JSON request.

      An example of a JSON request:

          {
            "instances": [
              {"x": [1, 2], "y": [3, 4]},
              {"x": [-1, -2], "y": [-3, -4]}
            ]
          }

      This flag accepts "-" for stdin.
      """).AddToParser(parser)


def GetTrafficSplitArg():
  """Add arguments for traffic split."""
  return base.Argument(
      '--traffic-split',
      metavar='DEPLOYED_MODEL_ID=VALUE',
      type=arg_parsers.ArgDict(value_type=int),
      action=arg_parsers.UpdateAction,
      help=('List of pairs of deployed model id and value to set as traffic '
            'split.'))


def AddTrafficSplitGroupArgs(parser):
  """Add arguments for traffic split."""
  group = parser.add_mutually_exclusive_group(required=False)
  group.add_argument(
      '--traffic-split',
      metavar='DEPLOYED_MODEL_ID=VALUE',
      type=arg_parsers.ArgDict(value_type=int),
      action=arg_parsers.UpdateAction,
      help=('List of pairs of deployed model id and value to set as traffic '
            'split.'))

  group.add_argument(
      '--clear-traffic-split',
      action='store_true',
      help=('Clears the traffic split map. If the map is empty, the endpoint '
            'does not accept any traffic.'))
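
# Illustrative parse result (values are placeholders): a flag such as
# `--traffic-split=1234=80,5678=20` is turned by the ArgDict above into
#
#   args.traffic_split  # -> {'1234': 80, '5678': 20}
#
# i.e. deployed model IDs mapped to integer percentages of traffic.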


def AddPredictionResourcesArgs(parser, version):
  """Add arguments for prediction resources."""
  base.Argument(
      '--min-replica-count',
      type=arg_parsers.BoundedInt(1, sys.maxsize, unlimited=True),
      help=("""\
Minimum number of machine replicas the deployed model will always be deployed
on. If specified, the value must be equal to or larger than 1.

If not specified and the uploaded models use dedicated resources, the default
value is 1.
""")).AddToParser(parser)

  base.Argument(
      '--max-replica-count',
      type=int,
      help=('Maximum number of machine replicas the deployed model will be '
            'deployed on.')).AddToParser(parser)

  base.Argument(
      '--autoscaling-metric-specs',
      metavar='METRIC-NAME=TARGET',
      type=arg_parsers.ArgDict(key_type=str, value_type=int),
      action=arg_parsers.UpdateAction,
      help="""\
Metric specifications that override a resource utilization metric's target
value. At most one entry is allowed per metric.

*METRIC-NAME*::: Resource metric name. Choices are {}.

*TARGET*::: Target resource utilization in percentage (1% - 100%) for the
given metric. If the value is set to 60, the target resource utilization is 60%.

For example:
`--autoscaling-metric-specs=cpu-usage=70`
""".format(', '.join([
    "'{}'".format(c)
    for c in sorted(constants.OP_AUTOSCALING_METRIC_NAME_MAPPER.keys())]
                     ))).AddToParser(parser)

  base.Argument(
      '--machine-type',
      help="""\
The machine resources to be used for each node of this deployment.
For available machine types, see
https://cloud.google.com/ai-platform-unified/docs/predictions/machine-types.
""").AddToParser(parser)

  base.Argument(
      '--accelerator',
      type=arg_parsers.ArgDict(
          spec={
              'type': str,
              'count': int,
          }, required_keys=['type']),
      help="""\
Manage the accelerator config for GPU serving. When deploying a model with
Compute Engine Machine Types, a GPU accelerator may also
be selected.

*type*::: The type of the accelerator. Choices are {}.

*count*::: The number of accelerators to attach to each machine running the job.
 This is usually 1. If not specified, the default value is 1.

For example:
`--accelerator=type=nvidia-tesla-k80,count=1`""".format(', '.join([
    "'{}'".format(c) for c in GetAcceleratorTypeMapper(version).choices
  ]))).AddToParser(parser)
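
# Illustrative only (values are placeholders): an invocation passing
# `--min-replica-count=1 --max-replica-count=3 --machine-type=n1-standard-4
# --accelerator=type=nvidia-tesla-k80,count=1
# --autoscaling-metric-specs=cpu-usage=70` surfaces in the parsed namespace as
#
#   args.min_replica_count         # -> 1
#   args.max_replica_count         # -> 3
#   args.machine_type              # -> 'n1-standard-4'
#   args.accelerator               # -> {'type': 'nvidia-tesla-k80', 'count': 1}
#   args.autoscaling_metric_specs  # -> {'cpu-usage': 70}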


def AddAutomaticResourcesArgs(parser, resource_type):
  """Add arguments for automatic deployment resources."""
  base.Argument(
      '--min-replica-count',
      type=arg_parsers.BoundedInt(1, sys.maxsize, unlimited=True),
      help=("""\
Minimum number of machine replicas the {} will always be deployed
on. If specified, the value must be equal to or larger than 1.
""".format(resource_type))).AddToParser(parser)

  base.Argument(
      '--max-replica-count',
      type=int,
      help=('Maximum number of machine replicas the {} will be '
            'deployed on.'.format(resource_type))).AddToParser(parser)


def GetEnableAccessLoggingArg():
  return base.Argument(
      '--enable-access-logging',
      action='store_true',
      default=False,
      required=False,
      help="""\
If true, online prediction access logs are sent to Cloud Logging.

These logs are standard server access logs, containing information like
timestamp and latency for each prediction request.
""")


def GetEnableContainerLoggingArg():
  return base.Argument(
      '--enable-container-logging',
      action='store_true',
      default=False,
      required=False,
      help="""\
If true, the container of the deployed model instances will send `stderr` and
`stdout` streams to Cloud Logging.

Currently, only supported for custom-trained Models and AutoML Tables Models.
""")


def GetServiceAccountArg():
  return base.Argument(
      '--service-account',
      required=False,
      help="""\
Service account that the deployed model's container runs as. Specify the
email address of the service account. If this service account is not
specified, the container runs as a service account that doesn't have access
to the resource project.
""")


def RegionAttributeConfig():
  return concepts.ResourceParameterAttributeConfig(
      name='region',
      help_text='Cloud region for the {resource}.',
      fallthroughs=[
          deps.PropertyFallthrough(properties.VALUES.ai.region),
          deps.Fallthrough(function=region_util.PromptForRegion, hint='region')
      ])


def GetRegionResourceSpec():
  return concepts.ResourceSpec(
      'aiplatform.projects.locations',
      resource_name='region',
      locationsId=RegionAttributeConfig(),
      projectsId=concepts.DEFAULT_PROJECT_ATTRIBUTE_CONFIG)


def GetModelResourceSpec(resource_name='model'):
  return concepts.ResourceSpec(
      'aiplatform.projects.locations.models',
      resource_name=resource_name,
      projectsId=concepts.DEFAULT_PROJECT_ATTRIBUTE_CONFIG,
      locationsId=RegionAttributeConfig(),
      disable_auto_completers=False)


def AddRegionResourceArg(parser, verb):
  """Add a resource argument for a Cloud AI Platform region.

  NOTE: Must be used only if it's the only resource arg in the command.

  Args:
    parser: the parser for the command.
    verb: str, the verb to describe the resource, such as 'to update'.
  """
  concept_parsers.ConceptParser.ForResource(
      '--region',
      GetRegionResourceSpec(),
      'Cloud region {}.'.format(verb),
      required=True).AddToParser(parser)
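
# A hedged usage sketch (not part of this module): after a command calls
# AddRegionResourceArg(parser, 'to list jobs'), its Run method can resolve the
# region roughly as
#
#   region_ref = args.CONCEPTS.region.Parse()
#   region = region_ref.AsDict()['locationsId']
#
# with the fallthroughs in RegionAttributeConfig() supplying the value from the
# `ai/region` property or an interactive prompt when `--region` is omitted.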


def GetDefaultOperationResourceSpec():
  return concepts.ResourceSpec(
      constants.DEFAULT_OPERATION_COLLECTION,
      resource_name='operation',
      projectsId=concepts.DEFAULT_PROJECT_ATTRIBUTE_CONFIG,
      locationsId=RegionAttributeConfig(),
      disable_auto_completers=False)


def AddOperationResourceArg(parser):
  """Add a resource argument for a Cloud AI Platform operation."""
  resource_name = 'operation'
  concept_parsers.ConceptParser.ForResource(
      resource_name,
      GetDefaultOperationResourceSpec(),
      'The ID of the operation.',
      required=True).AddToParser(parser)


def AddModelResourceArg(parser, verb):
  """Add a resource argument for a Cloud AI Platform model.

  NOTE: Must be used only if it's the only resource arg in the command.

  Args:
    parser: the parser for the command.
    verb: str, the verb to describe the resource, such as 'to update'.
  """
  name = 'model'
  concept_parsers.ConceptParser.ForResource(
      name, GetModelResourceSpec(), 'Model {}.'.format(verb),
      required=True).AddToParser(parser)


def AddUploadModelFlags(parser):
  """Adds flags for UploadModel."""
  AddRegionResourceArg(parser, 'to upload model')
  base.Argument(
      '--display-name', required=True,
      help=('Display name of the model.')).AddToParser(parser)
  base.Argument(
      '--description', required=False,
      help=('Description of the model.')).AddToParser(parser)
  base.Argument(
      '--container-image-uri',
      required=True,
      help=("""\
URI of the Model serving container image in the Container Registry
(e.g. gcr.io/myproject/server:latest).
""")).AddToParser(parser)
  base.Argument(
      '--artifact-uri',
      help=("""\
Path to the directory containing the Model artifact and any of its
supporting files.
""")).AddToParser(parser)
  parser.add_argument(
      '--container-env-vars',
      metavar='KEY=VALUE',
      type=arg_parsers.ArgDict(),
      action=arg_parsers.UpdateAction,
      help='List of key-value pairs to set as environment variables.')
  parser.add_argument(
      '--container-command',
      type=arg_parsers.ArgList(),
      metavar='COMMAND',
      action=arg_parsers.UpdateAction,
      help="""\
Entrypoint for the container image. If not specified, the container
image's default entrypoint is run.
""")
  parser.add_argument(
      '--container-args',
      metavar='ARG',
      type=arg_parsers.ArgList(),
      action=arg_parsers.UpdateAction,
      help="""\
Comma-separated arguments passed to the command run by the container
image. If not specified and no `--command` is provided, the container
image's default command is used.
""")
  parser.add_argument(
      '--container-ports',
      metavar='PORT',
      type=arg_parsers.ArgList(element_type=arg_parsers.BoundedInt(1, 65535)),
      action=arg_parsers.UpdateAction,
      help="""\
Container ports to receive requests at. Must be a number between 1 and 65535,
inclusive.
""")
  parser.add_argument(
      '--container-predict-route',
      help='HTTP path to send prediction requests to inside the container.')
  parser.add_argument(
      '--container-health-route',
      help='HTTP path to send health checks to inside the container.')
  # For Explanation.
  parser.add_argument(
      '--explanation-method',
      help='Method used for explanation. Accepted values are `integrated-gradients`, `xrai` and `sampled-shapley`.'
  )
  parser.add_argument(
      '--explanation-metadata-file',
      help='Path to a local JSON file that contains the metadata describing the Model\'s input and output for explanation.'
  )
  parser.add_argument(
      '--explanation-step-count',
      type=int,
      help='Number of steps to approximate the path integral for explanation.')
  parser.add_argument(
      '--explanation-path-count',
      type=int,
      help='Number of feature permutations to consider when approximating the Shapley values for explanation.'
  )
  parser.add_argument(
      '--smooth-grad-noisy-sample-count',
      type=int,
      help='Number of gradient samples used for approximation at explanation. Only applicable to explanation method `integrated-gradients` or `xrai`.'
  )
  parser.add_argument(
      '--smooth-grad-noise-sigma',
      type=float,
      help='Single float value used to add noise to all the features for explanation. Only applicable to explanation method `integrated-gradients` or `xrai`.'
  )
  parser.add_argument(
      '--smooth-grad-noise-sigma-by-feature',
      metavar='KEY=VALUE',
      type=arg_parsers.ArgDict(),
      action=arg_parsers.UpdateAction,
      help='Noise sigma by features for explanation. Noise sigma represents the standard deviation of the Gaussian kernel that will be used to add noise to interpolated inputs prior to computing gradients. Only applicable to explanation method `integrated-gradients` or `xrai`.'
  )
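
# Illustrative only: a command that calls AddUploadModelFlags could be invoked
# roughly as follows (the command path and all values are placeholders):
#
#   gcloud ai models upload \
#     --region=us-central1 \
#     --display-name=my-model \
#     --container-image-uri=gcr.io/myproject/server:latest \
#     --artifact-uri=gs://my-bucket/model/ \
#     --container-predict-route=/predict \
#     --container-health-route=/health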


def GetMetadataFilePathArg(noun, required=False):
  return base.Argument(
      '--metadata-file',
      required=required,
      help='Path to a local JSON file that contains additional metadata about the {noun}.'
      .format(noun=noun))


def GetMetadataSchemaUriArg(noun):
  return base.Argument(
      '--metadata-schema-uri',
      required=False,
      help='Points to a YAML file stored on Google Cloud Storage describing additional information about the {noun}.'
      .format(noun=noun))


def AddIndexResourceArg(parser, verb):
  """Add a resource argument for a Cloud AI Platform index.

  NOTE: Must be used only if it's the only resource arg in the command.

  Args:
    parser: the parser for the command.
    verb: str, the verb to describe the resource, such as 'to update'.
  """
  concept_parsers.ConceptParser.ForResource(
      'index', GetIndexResourceSpec(), 'Index {}.'.format(verb),
      required=True).AddToParser(parser)


def GetIndexResourceSpec(resource_name='index'):
  return concepts.ResourceSpec(
      constants.INDEXES_COLLECTION,
      resource_name=resource_name,
      projectsId=concepts.DEFAULT_PROJECT_ATTRIBUTE_CONFIG,
      locationsId=RegionAttributeConfig(),
      disable_auto_completers=False)


def GetEndpointId():
  return base.Argument('name', help='ID of the endpoint.')


def GetEndpointResourceSpec(resource_name='endpoint'):
  return concepts.ResourceSpec(
      constants.ENDPOINTS_COLLECTION,
      resource_name=resource_name,
      projectsId=concepts.DEFAULT_PROJECT_ATTRIBUTE_CONFIG,
      locationsId=RegionAttributeConfig(),
      disable_auto_completers=False)


def AddEndpointResourceArg(parser, verb):
  """Add a resource argument for a Cloud AI Platform endpoint.

  NOTE: Must be used only if it's the only resource arg in the command.

  Args:
    parser: the parser for the command.
    verb: str, the verb to describe the resource, such as 'to update'.
  """
  concept_parsers.ConceptParser.ForResource(
      'endpoint',
      GetEndpointResourceSpec(),
      'The endpoint {}.'.format(verb),
      required=True).AddToParser(parser)


def AddIndexEndpointResourceArg(parser, verb):
  """Add a resource argument for a Cloud AI Platform index endpoint.

  NOTE: Must be used only if it's the only resource arg in the command.

  Args:
    parser: the parser for the command.
    verb: str, the verb to describe the resource, such as 'to update'.
  """
  concept_parsers.ConceptParser.ForResource(
      'index_endpoint',
      GetIndexEndpointResourceSpec(),
      'The index endpoint {}.'.format(verb),
      required=True).AddToParser(parser)


def GetIndexEndpointResourceSpec(resource_name='index_endpoint'):
  return concepts.ResourceSpec(
      constants.INDEX_ENDPOINTS_COLLECTION,
      resource_name=resource_name,
      projectsId=concepts.DEFAULT_PROJECT_ATTRIBUTE_CONFIG,
      locationsId=RegionAttributeConfig(),
      disable_auto_completers=False)


# TODO(b/357812579): Consider switching to a resource arg.
def GetNetworkArg(required=True):
  """Add arguments for VPC network."""
  return base.Argument(
      '--network',
      required=required,
      help="""
      The Google Compute Engine network name to which the IndexEndpoint should be peered.
      """)


def GetTensorboardResourceSpec(resource_name='tensorboard'):
  return concepts.ResourceSpec(
      constants.TENSORBOARDS_COLLECTION,
      resource_name=resource_name,
      projectsId=concepts.DEFAULT_PROJECT_ATTRIBUTE_CONFIG,
      locationsId=RegionAttributeConfig(),
      disable_auto_completers=False)


def AddTensorboardResourceArg(parser, verb):
  """Add a resource argument for a Cloud AI Platform Tensorboard.

  NOTE: Must be used only if it's the only resource arg in the command.

  Args:
    parser: the parser for the command.
    verb: str, the verb to describe the resource, such as 'to update'.
  """
  concept_parsers.ConceptParser.ForResource(
      'tensorboard',
      GetTensorboardResourceSpec(),
      'The tensorboard {}.'.format(verb),
      required=True).AddToParser(parser)


def ParseAcceleratorFlag(accelerator, version):
  """Validates and returns an accelerator config message object."""
  if accelerator is None:
    return None
  types = list(c for c in GetAcceleratorTypeMapper(version).choices)
  raw_type = accelerator.get('type', None)
  if raw_type not in types:
    raise errors.ArgumentError("""\
The type of the accelerator can only be one of the following: {}.
""".format(', '.join(["'{}'".format(c) for c in types])))
  accelerator_count = accelerator.get('count', 1)
  if accelerator_count <= 0:
    raise errors.ArgumentError("""\
The count of the accelerator must be greater than 0.
""")
  if version == constants.ALPHA_VERSION:
    accelerator_msg = (
        apis.GetMessagesModule(constants.AI_PLATFORM_API_NAME,
                               constants.AI_PLATFORM_API_VERSION[version])
        .GoogleCloudAiplatformV1alpha1MachineSpec)
  else:
    accelerator_msg = (
        apis.GetMessagesModule(constants.AI_PLATFORM_API_NAME,
                               constants.AI_PLATFORM_API_VERSION[version])
        .GoogleCloudAiplatformV1beta1MachineSpec)
  accelerator_type = arg_utils.ChoiceToEnum(
      raw_type, accelerator_msg.AcceleratorTypeValueValuesEnum)
  return accelerator_msg(
      acceleratorCount=accelerator_count, acceleratorType=accelerator_type)
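
# Illustrative usage (the version constant is assumed to exist alongside
# ALPHA_VERSION): parsing the dict produced by
# `--accelerator=type=nvidia-tesla-k80,count=2` into a MachineSpec message.
#
#   spec = ParseAcceleratorFlag(
#       {'type': 'nvidia-tesla-k80', 'count': 2}, constants.BETA_VERSION)
#   # spec.acceleratorCount == 2
#   # spec.acceleratorType is the NVIDIA_TESLA_K80 enum value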


def GetAcceleratorTypeMapper(version):
  """Get a mapper for accelerator type to enum value."""
  if version == constants.ALPHA_VERSION:
    return arg_utils.ChoiceEnumMapper(
        'generic-accelerator',
        apis.GetMessagesModule(constants.AI_PLATFORM_API_NAME,
                               constants.AI_PLATFORM_API_VERSION[version])
        .GoogleCloudAiplatformV1alpha1MachineSpec.AcceleratorTypeValueValuesEnum,
        help_str='The available types of accelerators.',
        include_filter=lambda x: x.startswith('NVIDIA'),
        required=False)
  return arg_utils.ChoiceEnumMapper(
      'generic-accelerator',
      apis.GetMessagesModule(constants.AI_PLATFORM_API_NAME,
                             constants.AI_PLATFORM_API_VERSION[version])
      .GoogleCloudAiplatformV1beta1MachineSpec.AcceleratorTypeValueValuesEnum,
      help_str='The available types of accelerators.',
      include_filter=lambda x: x.startswith('NVIDIA'),
      required=False)


def AddCreateHpTuningJobFlags(parser, algorithm_enum):
  """Add arguments for creating a hyperparameter tuning job."""
  AddRegionResourceArg(parser, 'to create a hyperparameter tuning job')
  HPTUNING_JOB_DISPLAY_NAME.AddToParser(parser)
  HPTUNING_JOB_CONFIG.AddToParser(parser)
  HPTUNING_MAX_TRIAL_COUNT.AddToParser(parser)
  HPTUNING_PARALLEL_TRIAL_COUNT.AddToParser(parser)
  _TRAINING_SERVICE_ACCOUNT.AddToParser(parser)
  _NETWORK.AddToParser(parser)
  AddKmsKeyResourceArg(parser, 'hyperparameter tuning job')

  arg_utils.ChoiceEnumMapper(
      '--algorithm',
      algorithm_enum,
      help_str='Search algorithm specified for the given study.'
  ).choice_arg.AddToParser(parser)


def GetCustomJobResourceSpec(resource_name='custom_job'):
  return concepts.ResourceSpec(
      constants.CUSTOM_JOB_COLLECTION,
      resource_name=resource_name,
      projectsId=concepts.DEFAULT_PROJECT_ATTRIBUTE_CONFIG,
      locationsId=RegionAttributeConfig(),
      disable_auto_completers=False)


def AddCustomJobResourceArg(parser, verb):
  """Add a resource argument for a Cloud AI Platform custom job.

  NOTE: Must be used only if it's the only resource arg in the command.

  Args:
    parser: the parser for the command.
    verb: str, the verb to describe the resource, such as 'to update'.
  """
  concept_parsers.ConceptParser.ForResource(
      'custom_job',
      GetCustomJobResourceSpec(),
      'The custom job {}.'.format(verb),
      required=True).AddToParser(parser)


def GetHptuningJobResourceSpec(resource_name='hptuning_job'):
  return concepts.ResourceSpec(
      constants.HPTUNING_JOB_COLLECTION,
      resource_name=resource_name,
      projectsId=concepts.DEFAULT_PROJECT_ATTRIBUTE_CONFIG,
      locationsId=RegionAttributeConfig(),
      disable_auto_completers=False)


def AddHptuningJobResourceArg(parser, verb):
  """Add a resource argument for a Cloud AI Platform hyperparameter tuning job.

  NOTE: Must be used only if it's the only resource arg in the command.

  Args:
    parser: the parser for the command.
    verb: str, the verb to describe the resource, such as 'to update'.
  """
  concept_parsers.ConceptParser.ForResource(
      'hptuning_job',
      GetHptuningJobResourceSpec(),
      'The hyperparameter tuning job {}.'.format(verb),
      required=True).AddToParser(parser)


def AddKmsKeyResourceArg(parser, resource):
  """Add the --kms-key resource arg to the given parser."""
  permission_info = ("The 'AI Platform Service Agent' service account must hold"
                     " permission 'Cloud KMS CryptoKey Encrypter/Decrypter'")
  kms_resource_args.AddKmsKeyResourceArg(
      parser, resource, permission_info=permission_info)


def AddLocalRunCustomJobFlags(parser):
  """Add local-run related flags to the parser."""

  # Flags for entry point of the training application
  application_group = parser.add_mutually_exclusive_group(required=True)
  application_group.add_argument(
      '--python-module',
      metavar='PYTHON_MODULE',
      help=textwrap.dedent("""
      Name of the python module to execute, in 'trainer.train' or 'train'
      format. Its path should be relative to the `work_dir`.
      """))
  application_group.add_argument(
      '--script',
      metavar='SCRIPT',
      help=textwrap.dedent("""
      The relative path of the file to execute. Accepts a Python file,
      IPYNB file, or arbitrary bash script. This path should be relative to the
      `work_dir`.
      """))

  # Flags for working directory.
  parser.add_argument(
      '--work-dir',
      metavar='WORK_DIR',
      help=textwrap.dedent("""
      Path of the working directory where the python-module or script exists.
      If not specified, the directory from which you run this command is used.

      Only the contents of this directory will be accessible to the built
      container image.
      """))

  # Flags for extra directory
  parser.add_argument(
      '--extra-dirs',
      metavar='EXTRA_DIR',
      type=arg_parsers.ArgList(),
      help=textwrap.dedent("""
      Extra directories under the working directory to include, besides the one
      that contains the main executable.

      By default, only the parent directory of the main script or python module
      is copied to the container.
      For example, if the module is "training.task" or the script is
      "training/task.py", the whole "training" directory, including its
      sub-directories, will always be copied to the container. You may specify
      this flag to also copy other directories if necessary.

      Note: if no parent is specified in 'python_module' or 'script', the whole
      working directory is copied, so you don't need to specify this flag.
      """))

  # Flags for base container image
  parser.add_argument(
      '--base-image',
      metavar='BASE_IMAGE',
      required=True,
      help=textwrap.dedent("""
      URI or ID of the container image, either in the Container Registry or
      stored locally, that will run the application.
      See https://cloud.google.com/ai-platform-unified/docs/training/pre-built-containers
      for available pre-built container images provided by AI Platform for training.
      """))

  # Flags for extra requirements.
  parser.add_argument(
      '--requirements',
      metavar='REQUIREMENTS',
      type=arg_parsers.ArgList(),
      help=textwrap.dedent("""
      Python dependencies from PyPI to be used when running the application.
      If this is not specified, and there is no "setup.py" or "requirements.txt"
      in the working directory, your application will only have access to what
      exists in the base image with no other dependencies.

      Example:
      'tensorflow-cpu, pandas==1.2.0, matplotlib>=3.0.2'
      """))

  # Flags for extra dependencies.
  parser.add_argument(
      '--extra-packages',
      metavar='PACKAGE',
      type=arg_parsers.ArgList(),
      help=textwrap.dedent("""
      Local paths to Python archives used as training dependencies in the image
      container.
      These can be absolute or relative paths. However, they have to be under
      the work_dir; otherwise, this tool will not be able to access them.

      Example:
      'dep1.tar.gz, ./downloads/dep2.whl'
      """))

  # Flags for the output image
  parser.add_argument(
      '--output-image-uri',
      metavar='OUTPUT_IMAGE',
      help=textwrap.dedent("""
      URI of the custom container image to be built with your application
      packed in.
      """))

  # Flags for GPU support
  parser.add_argument(
      '--gpu', action='store_true', default=False, help='Enable GPU support.')

  # Flags for docker run
  parser.add_argument(
      '--docker-run-options',
      metavar='DOCKER_RUN_OPTIONS',
      type=arg_parsers.ArgList(),
      help=textwrap.dedent("""
      Custom Docker run options to pass to the image during execution.
      For example, '--no-healthcheck, -a stdin'.

      See https://docs.docker.com/engine/reference/commandline/run/#options for
      more details.
      """))

  # User custom flags.
  parser.add_argument(
      'args',
      nargs=argparse.REMAINDER,
      default=[],
      help="""Additional user arguments to be forwarded to your application.""",
      example=('$ {command} --script=my_run.sh --base-image=gcr.io/my/image '
               '-- --my-arg bar --enable_foo'))