1# -*- coding: utf-8 -*- #
2# Copyright 2018 Google LLC. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#    http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15"""Utility methods for the compute node templates commands."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import unicode_literals
20
21from apitools.base.py import encoding
22from googlecloudsdk.command_lib.compute.sole_tenancy.node_templates import flags
23from googlecloudsdk.command_lib.util.apis import arg_utils
24import six
25
26
def _ParseNodeAffinityLabels(affinity_labels, messages):
  """Wraps a label dict in a NodeTemplate.NodeAffinityLabelsValue message.

  Args:
    affinity_labels: mapping of affinity label keys to values, as supplied on
      the command line.
    messages: the compute API messages module.

  Returns:
    A messages.NodeTemplate.NodeAffinityLabelsValue built from the mapping.
  """
  # sort_items=True is passed through to apitools; presumably so the additional
  # properties come out in a deterministic (key-sorted) order.
  return encoding.DictToAdditionalPropertyMessage(
      affinity_labels,
      messages.NodeTemplate.NodeAffinityLabelsValue,
      sort_items=True)
31
32
def CreateNodeTemplate(node_template_ref,
                       args,
                       messages):
  """Builds a NodeTemplate message from parsed command-line arguments.

  Args:
    node_template_ref: resource reference for the template; its Name() becomes
      the template name.
    args: parsed argument namespace. Reads description, node_type,
      node_affinity_labels, node_requirements, disk, cpu_overcommit_type,
      accelerator (via GetAccelerators) and server_binding.
    messages: the compute API messages module.

  Returns:
    A populated messages.NodeTemplate. The accelerators and serverBinding fields
    are always set. The disks, cpuOvercommitType and nodeTypeFlexibility fields
    are only populated when the corresponding flag was given.
  """
  affinity_labels = (
      _ParseNodeAffinityLabels(args.node_affinity_labels, messages)
      if args.node_affinity_labels else None)

  flexibility = None
  if args.IsSpecified('node_requirements'):
    requirements = args.node_requirements
    flexibility = messages.NodeTemplateNodeTypeFlexibility(
        cpus=six.text_type(requirements.get('vCPU', 'any')),
        memory=requirements.get('memory', 'any'),
        # Local SSD is deliberately left unset (None) when omitted: the user
        # may drop the constraint entirely so that node types without local
        # SSD are also eligible. "any" means "greater than zero".
        localSsd=requirements.get('localSSD', None))

  template = messages.NodeTemplate(
      name=node_template_ref.Name(),
      description=args.description,
      nodeAffinityLabels=affinity_labels,
      nodeType=args.node_type,
      nodeTypeFlexibility=flexibility)

  if args.IsSpecified('disk'):
    disk_spec = args.disk
    template.disks = [
        messages.LocalDisk(
            diskCount=disk_spec.get('count'),
            diskSizeGb=disk_spec.get('size'),
            diskType=disk_spec.get('type'))
    ]

  if args.IsSpecified('cpu_overcommit_type'):
    template.cpuOvercommitType = arg_utils.ChoiceToEnum(
        args.cpu_overcommit_type,
        messages.NodeTemplate.CpuOvercommitTypeValueValuesEnum)

  template.accelerators = GetAccelerators(args, messages)

  binding_mapper = flags.GetServerBindingMapperFlag(messages)
  template.serverBinding = messages.ServerBinding(
      type=binding_mapper.GetEnumForChoice(args.server_binding))

  return template
80
81
def GetAccelerators(args, messages):
  """Translates the --accelerator argument into AcceleratorConfig messages.

  Args:
    args: parsed argument namespace; only args.accelerator is read.
    messages: the compute API messages module.

  Returns:
    An empty list when args.accelerator is falsy. Otherwise, a single-element
    list built from its 'type' key and its 'count' key (defaulting to 4 when
    absent, converted with int()).
  """
  spec = args.accelerator
  if not spec:
    return []
  # Look up 'type' before converting 'count' so errors surface in the same
  # order as before.
  kind = spec['type']
  count = int(spec.get('count', 4))
  return CreateAcceleratorConfigMessages(messages, kind, count)
90
91
def CreateAcceleratorConfigMessages(msgs, accelerator_type, accelerator_count):
  """Wraps one accelerator type/count pair in an AcceleratorConfig list.

  Args:
    msgs: tracked GCE API messages.
    accelerator_type: reference to the accelerator type.
    accelerator_count: number of accelerators to attach.

  Returns:
    A one-element list containing msgs.AcceleratorConfig with acceleratorType
    and acceleratorCount set from the arguments.
  """
  return [
      msgs.AcceleratorConfig(
          acceleratorCount=accelerator_count,
          acceleratorType=accelerator_type)
  ]
108
109
def ParseAcceleratorType(accelerator_type_name, resource_parser, project,
                         region):
  """Resolves an accelerator type name to its self link.

  Args:
    accelerator_type_name: the accelerator type name (or URL) to parse.
    resource_parser: resource registry whose Parse() produces a reference.
    project: project used to fill in the reference.
    region: region used to fill in the reference.

  Returns:
    The SelfLink() of the 'compute.regionAcceleratorTypes' reference.
  """
  accelerator_ref = resource_parser.Parse(
      accelerator_type_name,
      collection='compute.regionAcceleratorTypes',
      params={'project': project, 'region': region})
  return accelerator_ref.SelfLink()
117