package gcs

import (
	"context"
	"crypto/md5"
	"errors"
	"fmt"
	"io/ioutil"
	"os"
	"sort"
	"strconv"
	"strings"
	"time"

	metrics "github.com/armon/go-metrics"
	"github.com/hashicorp/errwrap"
	log "github.com/hashicorp/go-hclog"
	multierror "github.com/hashicorp/go-multierror"
	"github.com/hashicorp/vault/sdk/helper/useragent"
	"github.com/hashicorp/vault/sdk/physical"

	"cloud.google.com/go/storage"
	"google.golang.org/api/iterator"
	"google.golang.org/api/option"
)

// Verify Backend satisfies the correct interfaces at compile time.
var _ physical.Backend = (*Backend)(nil)

const (
	// envBucket is the name of the environment variable to search for the
	// storage bucket name. It takes precedence over the "bucket" config key.
	envBucket = "GOOGLE_STORAGE_BUCKET"

	// envChunkSize is the environment variable to search for the chunk size for
	// requests. It takes precedence over the "chunk_size" config key.
	envChunkSize = "GOOGLE_STORAGE_CHUNK_SIZE"

	// envHAEnabled is the name of the environment variable to search for the
	// boolean indicating if HA is enabled. It takes precedence over the
	// "ha_enabled" config key.
	envHAEnabled = "GOOGLE_STORAGE_HA_ENABLED"

	// defaultChunkSize is the default chunk size, expressed in kilobytes
	// (NewBackend converts the configured value to bytes before handing it to
	// the storage writer).
	defaultChunkSize = "8192"

	// objectDelimiter is the string to use to delimit objects, producing
	// folder-like semantics for List.
	objectDelimiter = "/"
)

var (
	// metricDelete is the key for the metric for measuring a Delete call.
	metricDelete = []string{"gcs", "delete"}

	// metricGet is the key for the metric for measuring a Get call.
	metricGet = []string{"gcs", "get"}

	// metricList is the key for the metric for measuring a List call.
	metricList = []string{"gcs", "list"}

	// metricPut is the key for the metric for measuring a Put call.
	metricPut = []string{"gcs", "put"}
)

// Backend implements physical.Backend and describes the steps necessary to
// persist data in Google Cloud Storage.
67type Backend struct { 68 // bucket is the name of the bucket to use for data storage and retrieval. 69 bucket string 70 71 // chunkSize is the chunk size to use for requests. 72 chunkSize int 73 74 // client is the underlying API client for talking to gcs. 75 client *storage.Client 76 77 // haEnabled indicates if HA is enabled. 78 haEnabled bool 79 80 // logger and permitPool are internal constructs 81 logger log.Logger 82 permitPool *physical.PermitPool 83} 84 85// NewBackend constructs a Google Cloud Storage backend with the given 86// configuration. This uses the official Golang Cloud SDK and therefore supports 87// specifying credentials via envvars, credential files, etc. from environment 88// variables or a service account file 89func NewBackend(c map[string]string, logger log.Logger) (physical.Backend, error) { 90 logger.Debug("configuring backend") 91 92 // Bucket name 93 bucket := os.Getenv(envBucket) 94 if bucket == "" { 95 bucket = c["bucket"] 96 } 97 if bucket == "" { 98 return nil, errors.New("missing bucket name") 99 } 100 101 // Chunk size 102 chunkSizeStr := os.Getenv(envChunkSize) 103 if chunkSizeStr == "" { 104 chunkSizeStr = c["chunk_size"] 105 } 106 if chunkSizeStr == "" { 107 chunkSizeStr = defaultChunkSize 108 } 109 chunkSize, err := strconv.Atoi(chunkSizeStr) 110 if err != nil { 111 return nil, errwrap.Wrapf("failed to parse chunk_size: {{err}}", err) 112 } 113 114 // Values are specified as kb, but the API expects them as bytes. 
115 chunkSize = chunkSize * 1024 116 117 // HA configuration 118 haEnabled := false 119 haEnabledStr := os.Getenv(envHAEnabled) 120 if haEnabledStr == "" { 121 haEnabledStr = c["ha_enabled"] 122 } 123 if haEnabledStr != "" { 124 var err error 125 haEnabled, err = strconv.ParseBool(haEnabledStr) 126 if err != nil { 127 return nil, errwrap.Wrapf("failed to parse HA enabled: {{err}}", err) 128 } 129 } 130 131 // Max parallel 132 maxParallel, err := extractInt(c["max_parallel"]) 133 if err != nil { 134 return nil, errwrap.Wrapf("failed to parse max_parallel: {{err}}", err) 135 } 136 137 logger.Debug("configuration", 138 "bucket", bucket, 139 "chunk_size", chunkSize, 140 "ha_enabled", haEnabled, 141 "max_parallel", maxParallel, 142 ) 143 logger.Debug("creating client") 144 145 // Client 146 opts := []option.ClientOption{option.WithUserAgent(useragent.String())} 147 if credentialsFile := c["credentials_file"]; credentialsFile != "" { 148 logger.Warn("specifying credentials_file as an option is " + 149 "deprecated. Please use the GOOGLE_APPLICATION_CREDENTIALS environment " + 150 "variable or instance credentials instead.") 151 opts = append(opts, option.WithCredentialsFile(credentialsFile)) 152 } 153 154 ctx := context.Background() 155 client, err := storage.NewClient(ctx, opts...) 
156 if err != nil { 157 return nil, errwrap.Wrapf("failed to create storage client: {{err}}", err) 158 } 159 160 return &Backend{ 161 bucket: bucket, 162 haEnabled: haEnabled, 163 chunkSize: chunkSize, 164 client: client, 165 permitPool: physical.NewPermitPool(maxParallel), 166 logger: logger, 167 }, nil 168} 169 170// Put is used to insert or update an entry 171func (b *Backend) Put(ctx context.Context, entry *physical.Entry) (retErr error) { 172 defer metrics.MeasureSince(metricPut, time.Now()) 173 174 // Pooling 175 b.permitPool.Acquire() 176 defer b.permitPool.Release() 177 178 // Insert 179 w := b.client.Bucket(b.bucket).Object(entry.Key).NewWriter(ctx) 180 w.ChunkSize = b.chunkSize 181 md5Array := md5.Sum(entry.Value) 182 w.MD5 = md5Array[:] 183 defer func() { 184 closeErr := w.Close() 185 if closeErr != nil { 186 retErr = multierror.Append(retErr, errwrap.Wrapf("error closing connection: {{err}}", closeErr)) 187 } 188 }() 189 190 if _, err := w.Write(entry.Value); err != nil { 191 return errwrap.Wrapf("failed to put data: {{err}}", err) 192 } 193 return nil 194} 195 196// Get fetches an entry. If no entry exists, this function returns nil. 
197func (b *Backend) Get(ctx context.Context, key string) (retEntry *physical.Entry, retErr error) { 198 defer metrics.MeasureSince(metricGet, time.Now()) 199 200 // Pooling 201 b.permitPool.Acquire() 202 defer b.permitPool.Release() 203 204 // Read 205 r, err := b.client.Bucket(b.bucket).Object(key).NewReader(ctx) 206 if err == storage.ErrObjectNotExist { 207 return nil, nil 208 } 209 if err != nil { 210 return nil, errwrap.Wrapf(fmt.Sprintf("failed to read value for %q: {{err}}", key), err) 211 } 212 213 defer func() { 214 closeErr := r.Close() 215 if closeErr != nil { 216 retErr = multierror.Append(retErr, errwrap.Wrapf("error closing connection: {{err}}", closeErr)) 217 } 218 }() 219 220 value, err := ioutil.ReadAll(r) 221 if err != nil { 222 return nil, errwrap.Wrapf("failed to read value into a string: {{err}}", err) 223 } 224 225 return &physical.Entry{ 226 Key: key, 227 Value: value, 228 }, nil 229} 230 231// Delete deletes an entry with the given key 232func (b *Backend) Delete(ctx context.Context, key string) error { 233 defer metrics.MeasureSince(metricDelete, time.Now()) 234 235 // Pooling 236 b.permitPool.Acquire() 237 defer b.permitPool.Release() 238 239 // Delete 240 err := b.client.Bucket(b.bucket).Object(key).Delete(ctx) 241 if err != nil && err != storage.ErrObjectNotExist { 242 return errwrap.Wrapf(fmt.Sprintf("failed to delete key %q: {{err}}", key), err) 243 } 244 return nil 245} 246 247// List is used to list all the keys under a given 248// prefix, up to the next prefix. 
249func (b *Backend) List(ctx context.Context, prefix string) ([]string, error) { 250 defer metrics.MeasureSince(metricList, time.Now()) 251 252 // Pooling 253 b.permitPool.Acquire() 254 defer b.permitPool.Release() 255 256 iter := b.client.Bucket(b.bucket).Objects(ctx, &storage.Query{ 257 Prefix: prefix, 258 Delimiter: objectDelimiter, 259 Versions: false, 260 }) 261 262 keys := []string{} 263 264 for { 265 objAttrs, err := iter.Next() 266 if err == iterator.Done { 267 break 268 } 269 if err != nil { 270 return nil, errwrap.Wrapf("failed to read object: {{err}}", err) 271 } 272 273 var path string 274 if objAttrs.Prefix != "" { 275 // "subdirectory" 276 path = objAttrs.Prefix 277 } else { 278 // file 279 path = objAttrs.Name 280 } 281 282 // get relative file/dir just like "basename" 283 key := strings.TrimPrefix(path, prefix) 284 keys = append(keys, key) 285 } 286 287 sort.Strings(keys) 288 289 return keys, nil 290} 291 292// extractInt is a helper function that takes a string and converts that string 293// to an int, but accounts for the empty string. 294func extractInt(s string) (int, error) { 295 if s == "" { 296 return 0, nil 297 } 298 return strconv.Atoi(s) 299} 300