1# -*- coding: utf-8 -*- #
2# Copyright 2019 Google LLC. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#    http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15"""Shared utilities to access the Google Secret Manager API."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import unicode_literals
20
21from apitools.base.py import exceptions as apitools_exceptions
22from apitools.base.py import list_pager
23from googlecloudsdk.api_lib.util import apis
24
25
def GetClient(version=None):
  """Return a client instance for the 'secretmanager' API.

  Args:
    version: API version to use. When falsy, the version is taken from
      apis.ResolveVersion('secretmanager').

  Returns:
    Whatever apis.GetClientInstance returns for the chosen version.
  """
  if not version:
    version = apis.ResolveVersion('secretmanager')
  return apis.GetClientInstance('secretmanager', version)
30
31
def GetMessages(version=None):
  """Return the messages module for the 'secretmanager' API.

  Args:
    version: API version to use. When falsy, the version is taken from
      apis.ResolveVersion('secretmanager').

  Returns:
    Whatever apis.GetMessagesModule returns for the chosen version.
  """
  if not version:
    version = apis.ResolveVersion('secretmanager')
  return apis.GetMessagesModule('secretmanager', version)
36
37
def _FormatUpdateMask(update_mask):
  """Join the entries of update_mask into a single comma-separated string."""
  separator = ','
  return separator.join(update_mask)
40
41
def _MakeReplicationMessage(messages, policy, locations, keys):
  """Create a replication message from its components.

  Args:
    messages: Messages module providing the Replication, Automatic,
      UserManaged, Replica and CustomerManagedEncryption constructors.
    policy: Replication policy name. 'user-managed' yields a user-managed
      replication; every other value yields an automatic one.
    locations: Iterable of replica locations. Only read when policy is
      'user-managed'.
    keys: Sequence of KMS key names, or None/empty for no customer-managed
      encryption. With policy 'automatic' only keys[0] is used. With policy
      'user-managed', keys[i] is attached to the i-th location; locations
      past the end of keys get no encryption configuration.

  Returns:
    A messages.Replication instance.
  """
  # Callers may pass None (no keys supplied); treat it like an empty list so
  # the length check below cannot raise TypeError.
  keys = keys or []

  if policy == 'user-managed':
    replicas = []
    for i, location in enumerate(locations):
      if i < len(keys):
        replicas.append(
            messages.Replica(
                location=location,
                customerManagedEncryption=messages.CustomerManagedEncryption(
                    kmsKeyName=keys[i])))
      else:
        replicas.append(messages.Replica(location=location))
    return messages.Replication(
        userManaged=messages.UserManaged(replicas=replicas))

  if policy == 'automatic' and keys:
    return messages.Replication(
        automatic=messages.Automatic(
            customerManagedEncryption=messages.CustomerManagedEncryption(
                kmsKeyName=keys[0])))

  # Fallback for any other policy value, or 'automatic' without keys.
  return messages.Replication(automatic=messages.Automatic())
65
66
class Client(object):
  """Common base for the high-level API wrappers in this module.

  Attributes:
    client: The API client passed in, or the result of GetClient() when the
      argument is falsy.
    messages: The messages module passed in, or client.MESSAGES_MODULE when
      the argument is falsy.
  """

  def __init__(self, client=None, messages=None):
    if not client:
      client = GetClient()
    self.client = client

    if not messages:
      messages = self.client.MESSAGES_MODULE
    self.messages = messages
73
74
class Locations(Client):
  """High-level client for locations.

  Requests are sent through the client's projects_locations service.
  """

  def __init__(self, client=None, messages=None):
    super(Locations, self).__init__(client=client, messages=messages)
    self.service = self.client.projects_locations

  def Get(self, location_ref):
    """Get the location named by location_ref.RelativeName()."""
    get_request = self.messages.SecretmanagerProjectsLocationsGetRequest(
        name=location_ref.RelativeName())
    return self.service.Get(get_request)

  def ListWithPager(self, project_ref, limit):
    """List locations under a project, returning a pager object.

    Args:
      project_ref: Resource reference whose RelativeName() is used as the
        request name.
      limit: Passed through to list_pager.YieldFromList as its limit.

    Returns:
      The result of list_pager.YieldFromList over the 'locations' field.
    """
    return list_pager.YieldFromList(
        service=self.service,
        request=self.messages.SecretmanagerProjectsLocationsListRequest(
            name=project_ref.RelativeName()),
        field='locations',
        limit=limit,
        batch_size_attribute='pageSize')
99
100
class Secrets(Client):
  """High-level client for secrets.

  Requests are sent through the client's projects_secrets service.
  """

  def __init__(self, client=None, messages=None):
    super(Secrets, self).__init__(client, messages)
    self.service = self.client.projects_secrets

  def _MakeTopicMessages(self, topics):
    """Wrap each topic name in a Topic message; falsy input yields []."""
    return [self.messages.Topic(name=topic) for topic in topics or []]

  def Create(self,
             secret_ref,
             policy,
             locations,
             labels,
             expire_time=None,
             ttl=None,
             keys=None,
             topics=None):
    """Create a secret.

    Args:
      secret_ref: Resource reference for the new secret; its parent's
        RelativeName() is the request parent and its Name() the secret id.
      policy: Replication policy name, see _MakeReplicationMessage.
      locations: Replica locations, see _MakeReplicationMessage.
      labels: Value for the Secret message's labels field.
      expire_time: Value for the Secret message's expireTime field.
      ttl: Value for the Secret message's ttl field.
      keys: KMS key names for the replication, or None for no keys.
      topics: Iterable of topic names, or None for no topics.

    Returns:
      The result of the service's Create call.
    """
    replication = _MakeReplicationMessage(self.messages, policy, locations,
                                          keys or [])
    secret = self.messages.Secret(
        labels=labels,
        replication=replication,
        expireTime=expire_time,
        ttl=ttl,
        topics=self._MakeTopicMessages(topics))
    return self.service.Create(
        self.messages.SecretmanagerProjectsSecretsCreateRequest(
            parent=secret_ref.Parent().RelativeName(),
            secretId=secret_ref.Name(),
            secret=secret))

  def Delete(self, secret_ref):
    """Delete the secret named by secret_ref.RelativeName()."""
    return self.service.Delete(
        self.messages.SecretmanagerProjectsSecretsDeleteRequest(
            name=secret_ref.RelativeName()))

  def Get(self, secret_ref):
    """Get the secret with the given name."""
    return self.service.Get(
        self.messages.SecretmanagerProjectsSecretsGetRequest(
            name=secret_ref.RelativeName()))

  def GetOrNone(self, secret_ref):
    """Get the secret, returning None if the API raises HttpNotFoundError."""
    try:
      return self.Get(secret_ref=secret_ref)
    except apitools_exceptions.HttpNotFoundError:
      return None

  def ListWithPager(self, project_ref, limit):
    """List secrets under a project, returning a pager object.

    Args:
      project_ref: Resource reference whose RelativeName() is the request
        parent.
      limit: Passed through to list_pager.YieldFromList as its limit.

    Returns:
      The result of list_pager.YieldFromList over the 'secrets' field.
    """
    request = self.messages.SecretmanagerProjectsSecretsListRequest(
        parent=project_ref.RelativeName())

    return list_pager.YieldFromList(
        service=self.service,
        request=request,
        field='secrets',
        limit=limit,
        batch_size_attribute='pageSize')

  def AddVersion(self, secret_ref, data):
    """Add a new version, with payload data, to an existing secret."""
    request = self.messages.SecretmanagerProjectsSecretsAddVersionRequest(
        parent=secret_ref.RelativeName(),
        addSecretVersionRequest=self.messages.AddSecretVersionRequest(
            payload=self.messages.SecretPayload(data=data)))
    return self.service.AddVersion(request)

  def Update(self,
             secret_ref,
             labels,
             update_mask,
             expire_time=None,
             ttl=None,
             topics=None):
    """Update a secret via a Patch request.

    Args:
      secret_ref: Resource reference of the secret to patch.
      labels: Value for the Secret message's labels field.
      update_mask: Iterable of field names, sent comma-joined as updateMask.
      expire_time: Value for the Secret message's expireTime field.
      ttl: Value for the Secret message's ttl field.
      topics: Iterable of topic names, or None. The Secret message always
        carries a topics list (empty when topics is falsy).

    Returns:
      The result of the service's Patch call.
    """
    # NOTE(review): presumably only fields named in update_mask are applied
    # server-side, so the always-present topics list is harmless -- confirm.
    secret = self.messages.Secret(
        labels=labels,
        expireTime=expire_time,
        ttl=ttl,
        topics=self._MakeTopicMessages(topics))
    return self.service.Patch(
        self.messages.SecretmanagerProjectsSecretsPatchRequest(
            name=secret_ref.RelativeName(),
            secret=secret,
            updateMask=_FormatUpdateMask(update_mask)))

  def SetReplication(self, secret_ref, policy, locations, keys):
    """Set the replication policy on an existing secret.

    Args:
      secret_ref: Resource reference of the secret to patch.
      policy: Replication policy name, see _MakeReplicationMessage.
      locations: Replica locations, see _MakeReplicationMessage.
      keys: KMS key names for the replication, or None for no keys.

    Returns:
      The result of the service's Patch call, sent with updateMask
      'replication'.
    """
    # Normalize None the same way Create does so a missing keys argument
    # cannot reach the length check in _MakeReplicationMessage.
    replication = _MakeReplicationMessage(self.messages, policy, locations,
                                          keys or [])
    return self.service.Patch(
        self.messages.SecretmanagerProjectsSecretsPatchRequest(
            name=secret_ref.RelativeName(),
            secret=self.messages.Secret(replication=replication),
            updateMask=_FormatUpdateMask(['replication'])))
207
208
class SecretsLatest(Client):
  """High-level client for the latest version of secrets.

  Requests are sent through the client's projects_secrets_latest service.
  """

  def __init__(self, client=None, messages=None):
    super(SecretsLatest, self).__init__(client=client, messages=messages)
    self.service = self.client.projects_secrets_latest

  def Access(self, secret_ref):
    """Access the latest version of the secret named by secret_ref."""
    access_request = (
        self.messages.SecretmanagerProjectsSecretsLatestAccessRequest(
            name=secret_ref.RelativeName()))
    return self.service.Access(access_request)
221
222
class Versions(Client):
  """High-level client for secret versions.

  Requests are sent through the client's projects_secrets_versions service.
  Methods taking version_ref use its RelativeName() as the request name.
  """

  def __init__(self, client=None, messages=None):
    super(Versions, self).__init__(client=client, messages=messages)
    self.service = self.client.projects_secrets_versions

  def Access(self, version_ref):
    """Access a specific version of a secret."""
    access_request = (
        self.messages.SecretmanagerProjectsSecretsVersionsAccessRequest(
            name=version_ref.RelativeName()))
    return self.service.Access(access_request)

  def Destroy(self, version_ref):
    """Destroy a secret version."""
    destroy_request = (
        self.messages.SecretmanagerProjectsSecretsVersionsDestroyRequest(
            name=version_ref.RelativeName()))
    return self.service.Destroy(destroy_request)

  def Disable(self, version_ref):
    """Disable a secret version."""
    disable_request = (
        self.messages.SecretmanagerProjectsSecretsVersionsDisableRequest(
            name=version_ref.RelativeName()))
    return self.service.Disable(disable_request)

  def Enable(self, version_ref):
    """Enable a secret version."""
    enable_request = (
        self.messages.SecretmanagerProjectsSecretsVersionsEnableRequest(
            name=version_ref.RelativeName()))
    return self.service.Enable(enable_request)

  def Get(self, version_ref):
    """Get the secret version with the given name."""
    get_request = self.messages.SecretmanagerProjectsSecretsVersionsGetRequest(
        name=version_ref.RelativeName())
    return self.service.Get(get_request)

  def List(self, secret_ref, limit):
    """List versions of a secret with a single List call.

    Args:
      secret_ref: Resource reference whose RelativeName() is the request
        parent.
      limit: Sent as the request's pageSize.

    Returns:
      The result of the service's List call.
    """
    list_request = (
        self.messages.SecretmanagerProjectsSecretsVersionsListRequest(
            parent=secret_ref.RelativeName(), pageSize=limit))
    return self.service.List(list_request)

  def ListWithPager(self, secret_ref, limit):
    """List versions of a secret, returning a pager object.

    Args:
      secret_ref: Resource reference whose RelativeName() is the request
        parent.
      limit: Passed through to list_pager.YieldFromList as its limit.

    Returns:
      The result of list_pager.YieldFromList over the 'versions' field.
    """
    # Both the request's pageSize and the pager's batch_size are set to 0.
    # NOTE(review): presumably this leaves page sizing to the server -- confirm.
    list_request = (
        self.messages.SecretmanagerProjectsSecretsVersionsListRequest(
            parent=secret_ref.RelativeName(), pageSize=0))
    return list_pager.YieldFromList(
        service=self.service,
        request=list_request,
        field='versions',
        limit=limit,
        batch_size=0,
        batch_size_attribute='pageSize')
277