1package gcs
2
3import (
4	"context"
5	"crypto/md5"
6	"errors"
7	"fmt"
8	"io/ioutil"
9	"os"
10	"sort"
11	"strconv"
12	"strings"
13	"time"
14
15	metrics "github.com/armon/go-metrics"
16	"github.com/hashicorp/errwrap"
17	log "github.com/hashicorp/go-hclog"
18	multierror "github.com/hashicorp/go-multierror"
19	"github.com/hashicorp/vault/sdk/helper/useragent"
20	"github.com/hashicorp/vault/sdk/physical"
21
22	"cloud.google.com/go/storage"
23	"google.golang.org/api/iterator"
24	"google.golang.org/api/option"
25)
26
// Compile-time check that Backend satisfies the physical.Backend interface.
var _ physical.Backend = (*Backend)(nil)
29
const (
	// envBucket is the name of the environment variable to search for the
	// storage bucket name.
	envBucket = "GOOGLE_STORAGE_BUCKET"

	// envChunkSize is the environment variable to search for the chunk size for
	// requests.
	envChunkSize = "GOOGLE_STORAGE_CHUNK_SIZE"

	// envHAEnabled is the name of the environment variable to search for the
	// boolean indicating if HA is enabled.
	envHAEnabled = "GOOGLE_STORAGE_HA_ENABLED"

	// defaultChunkSize is the default chunk size in kilobytes (the configured
	// value is multiplied by 1024 in NewBackend before being handed to the
	// storage writer).
	defaultChunkSize = "8192"

	// objectDelimiter is the string to use to delimit objects.
	objectDelimiter = "/"
)
50
var (
	// metricDelete is the key for the metric for measuring a Delete call.
	metricDelete = []string{"gcs", "delete"}

	// metricGet is the key for the metric for measuring a Get call.
	metricGet = []string{"gcs", "get"}

	// metricList is the key for the metric for measuring a List call.
	metricList = []string{"gcs", "list"}

	// metricPut is the key for the metric for measuring a Put call.
	metricPut = []string{"gcs", "put"}
)
64
// Backend implements physical.Backend and describes the steps necessary to
// persist data in Google Cloud Storage.
type Backend struct {
	// bucket is the name of the bucket to use for data storage and retrieval.
	bucket string

	// chunkSize is the chunk size, in bytes, to use for upload requests
	// (NewBackend converts the configured kilobyte value to bytes).
	chunkSize int

	// client is the underlying API client for talking to gcs.
	client *storage.Client

	// haEnabled indicates if HA is enabled.
	haEnabled bool

	// logger and permitPool are internal constructs; permitPool bounds the
	// number of concurrent requests to the storage API.
	logger     log.Logger
	permitPool *physical.PermitPool
}
84
85// NewBackend constructs a Google Cloud Storage backend with the given
86// configuration. This uses the official Golang Cloud SDK and therefore supports
87// specifying credentials via envvars, credential files, etc. from environment
88// variables or a service account file
89func NewBackend(c map[string]string, logger log.Logger) (physical.Backend, error) {
90	logger.Debug("configuring backend")
91
92	// Bucket name
93	bucket := os.Getenv(envBucket)
94	if bucket == "" {
95		bucket = c["bucket"]
96	}
97	if bucket == "" {
98		return nil, errors.New("missing bucket name")
99	}
100
101	// Chunk size
102	chunkSizeStr := os.Getenv(envChunkSize)
103	if chunkSizeStr == "" {
104		chunkSizeStr = c["chunk_size"]
105	}
106	if chunkSizeStr == "" {
107		chunkSizeStr = defaultChunkSize
108	}
109	chunkSize, err := strconv.Atoi(chunkSizeStr)
110	if err != nil {
111		return nil, errwrap.Wrapf("failed to parse chunk_size: {{err}}", err)
112	}
113
114	// Values are specified as kb, but the API expects them as bytes.
115	chunkSize = chunkSize * 1024
116
117	// HA configuration
118	haEnabled := false
119	haEnabledStr := os.Getenv(envHAEnabled)
120	if haEnabledStr == "" {
121		haEnabledStr = c["ha_enabled"]
122	}
123	if haEnabledStr != "" {
124		var err error
125		haEnabled, err = strconv.ParseBool(haEnabledStr)
126		if err != nil {
127			return nil, errwrap.Wrapf("failed to parse HA enabled: {{err}}", err)
128		}
129	}
130
131	// Max parallel
132	maxParallel, err := extractInt(c["max_parallel"])
133	if err != nil {
134		return nil, errwrap.Wrapf("failed to parse max_parallel: {{err}}", err)
135	}
136
137	logger.Debug("configuration",
138		"bucket", bucket,
139		"chunk_size", chunkSize,
140		"ha_enabled", haEnabled,
141		"max_parallel", maxParallel,
142	)
143	logger.Debug("creating client")
144
145	// Client
146	opts := []option.ClientOption{option.WithUserAgent(useragent.String())}
147	if credentialsFile := c["credentials_file"]; credentialsFile != "" {
148		logger.Warn("specifying credentials_file as an option is " +
149			"deprecated. Please use the GOOGLE_APPLICATION_CREDENTIALS environment " +
150			"variable or instance credentials instead.")
151		opts = append(opts, option.WithCredentialsFile(credentialsFile))
152	}
153
154	ctx := context.Background()
155	client, err := storage.NewClient(ctx, opts...)
156	if err != nil {
157		return nil, errwrap.Wrapf("failed to create storage client: {{err}}", err)
158	}
159
160	return &Backend{
161		bucket:     bucket,
162		haEnabled:  haEnabled,
163		chunkSize:  chunkSize,
164		client:     client,
165		permitPool: physical.NewPermitPool(maxParallel),
166		logger:     logger,
167	}, nil
168}
169
170// Put is used to insert or update an entry
171func (b *Backend) Put(ctx context.Context, entry *physical.Entry) (retErr error) {
172	defer metrics.MeasureSince(metricPut, time.Now())
173
174	// Pooling
175	b.permitPool.Acquire()
176	defer b.permitPool.Release()
177
178	// Insert
179	w := b.client.Bucket(b.bucket).Object(entry.Key).NewWriter(ctx)
180	w.ChunkSize = b.chunkSize
181	md5Array := md5.Sum(entry.Value)
182	w.MD5 = md5Array[:]
183	defer func() {
184		closeErr := w.Close()
185		if closeErr != nil {
186			retErr = multierror.Append(retErr, errwrap.Wrapf("error closing connection: {{err}}", closeErr))
187		}
188	}()
189
190	if _, err := w.Write(entry.Value); err != nil {
191		return errwrap.Wrapf("failed to put data: {{err}}", err)
192	}
193	return nil
194}
195
196// Get fetches an entry. If no entry exists, this function returns nil.
197func (b *Backend) Get(ctx context.Context, key string) (retEntry *physical.Entry, retErr error) {
198	defer metrics.MeasureSince(metricGet, time.Now())
199
200	// Pooling
201	b.permitPool.Acquire()
202	defer b.permitPool.Release()
203
204	// Read
205	r, err := b.client.Bucket(b.bucket).Object(key).NewReader(ctx)
206	if err == storage.ErrObjectNotExist {
207		return nil, nil
208	}
209	if err != nil {
210		return nil, errwrap.Wrapf(fmt.Sprintf("failed to read value for %q: {{err}}", key), err)
211	}
212
213	defer func() {
214		closeErr := r.Close()
215		if closeErr != nil {
216			retErr = multierror.Append(retErr, errwrap.Wrapf("error closing connection: {{err}}", closeErr))
217		}
218	}()
219
220	value, err := ioutil.ReadAll(r)
221	if err != nil {
222		return nil, errwrap.Wrapf("failed to read value into a string: {{err}}", err)
223	}
224
225	return &physical.Entry{
226		Key:   key,
227		Value: value,
228	}, nil
229}
230
231// Delete deletes an entry with the given key
232func (b *Backend) Delete(ctx context.Context, key string) error {
233	defer metrics.MeasureSince(metricDelete, time.Now())
234
235	// Pooling
236	b.permitPool.Acquire()
237	defer b.permitPool.Release()
238
239	// Delete
240	err := b.client.Bucket(b.bucket).Object(key).Delete(ctx)
241	if err != nil && err != storage.ErrObjectNotExist {
242		return errwrap.Wrapf(fmt.Sprintf("failed to delete key %q: {{err}}", key), err)
243	}
244	return nil
245}
246
247// List is used to list all the keys under a given
248// prefix, up to the next prefix.
249func (b *Backend) List(ctx context.Context, prefix string) ([]string, error) {
250	defer metrics.MeasureSince(metricList, time.Now())
251
252	// Pooling
253	b.permitPool.Acquire()
254	defer b.permitPool.Release()
255
256	iter := b.client.Bucket(b.bucket).Objects(ctx, &storage.Query{
257		Prefix:    prefix,
258		Delimiter: objectDelimiter,
259		Versions:  false,
260	})
261
262	keys := []string{}
263
264	for {
265		objAttrs, err := iter.Next()
266		if err == iterator.Done {
267			break
268		}
269		if err != nil {
270			return nil, errwrap.Wrapf("failed to read object: {{err}}", err)
271		}
272
273		var path string
274		if objAttrs.Prefix != "" {
275			// "subdirectory"
276			path = objAttrs.Prefix
277		} else {
278			// file
279			path = objAttrs.Name
280		}
281
282		// get relative file/dir just like "basename"
283		key := strings.TrimPrefix(path, prefix)
284		keys = append(keys, key)
285	}
286
287	sort.Strings(keys)
288
289	return keys, nil
290}
291
// extractInt is a helper function that takes a string and converts that string
// to an int, but accounts for the empty string.
func extractInt(s string) (int, error) {
	// An unset value means zero rather than a parse error.
	if len(s) == 0 {
		return 0, nil
	}
	n, err := strconv.Atoi(s)
	if err != nil {
		return 0, err
	}
	return n, nil
}
300