1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package mongo
8
9import (
10	"context"
11	"strings"
12
13	"github.com/pkg/errors"
14	"go.mongodb.org/mongo-driver/bson"
15	"go.mongodb.org/mongo-driver/bson/primitive"
16	"go.mongodb.org/mongo-driver/mongo/options"
17	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
18	"go.mongodb.org/mongo-driver/x/mongo/driver"
19	cryptOpts "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options"
20)
21
// ClientEncryption is used to create data keys and explicitly encrypt and decrypt BSON values.
// Instances are constructed with NewClientEncryption and their resources are released with Close.
type ClientEncryption struct {
	crypt          *driver.Crypt // creates data keys and performs explicit encryption/decryption
	keyVaultClient *Client       // client supplied to NewClientEncryption; disconnected by Close
	keyVaultColl   *Collection   // key vault collection that newly created data keys are inserted into
}
28
29// NewClientEncryption creates a new ClientEncryption instance configured with the given options.
30func NewClientEncryption(keyVaultClient *Client, opts ...*options.ClientEncryptionOptions) (*ClientEncryption, error) {
31	if keyVaultClient == nil {
32		return nil, errors.New("keyVaultClient must not be nil")
33	}
34
35	ce := &ClientEncryption{
36		keyVaultClient: keyVaultClient,
37	}
38	ceo := options.MergeClientEncryptionOptions(opts...)
39
40	// create keyVaultColl
41	db, coll := splitNamespace(ceo.KeyVaultNamespace)
42	ce.keyVaultColl = ce.keyVaultClient.Database(db).Collection(coll, keyVaultCollOpts)
43
44	// create Crypt
45	var err error
46	kr := keyRetriever{coll: ce.keyVaultColl}
47	cir := collInfoRetriever{client: ce.keyVaultClient}
48	ce.crypt, err = driver.NewCrypt(&driver.CryptOptions{
49		KeyFn:        kr.cryptKeys,
50		CollInfoFn:   cir.cryptCollInfo,
51		KmsProviders: ceo.KmsProviders,
52	})
53	if err != nil {
54		return nil, err
55	}
56
57	return ce, nil
58}
59
60// CreateDataKey creates a new key document and inserts it into the key vault collection. Returns the _id of the
61// created document.
62func (ce *ClientEncryption) CreateDataKey(ctx context.Context, kmsProvider string, opts ...*options.DataKeyOptions) (primitive.Binary, error) {
63	// translate opts to cryptOpts.DataKeyOptions
64	dko := options.MergeDataKeyOptions(opts...)
65	co := cryptOpts.DataKey().SetKeyAltNames(dko.KeyAltNames)
66	if dko.MasterKey != nil {
67		keyDoc, err := transformBsoncoreDocument(ce.keyVaultClient.registry, dko.MasterKey)
68		if err != nil {
69			return primitive.Binary{}, err
70		}
71
72		co.SetMasterKey(keyDoc)
73	}
74
75	// create data key document
76	dataKeyDoc, err := ce.crypt.CreateDataKey(ctx, kmsProvider, co)
77	if err != nil {
78		return primitive.Binary{}, err
79	}
80
81	// insert key into key vault
82	_, err = ce.keyVaultColl.InsertOne(ctx, dataKeyDoc)
83	if err != nil {
84		return primitive.Binary{}, err
85	}
86
87	subtype, data := bson.Raw(dataKeyDoc).Lookup("_id").Binary()
88	return primitive.Binary{Subtype: subtype, Data: data}, nil
89}
90
91// Encrypt encrypts a BSON value with the given key and algorithm. Returns an encrypted value (BSON binary of subtype 6).
92func (ce *ClientEncryption) Encrypt(ctx context.Context, val bson.RawValue, opts ...*options.EncryptOptions) (primitive.Binary, error) {
93	eo := options.MergeEncryptOptions(opts...)
94	transformed := cryptOpts.ExplicitEncryption()
95	if eo.KeyID != nil {
96		transformed.SetKeyID(*eo.KeyID)
97	}
98	if eo.KeyAltName != nil {
99		transformed.SetKeyAltName(*eo.KeyAltName)
100	}
101	transformed.SetAlgorithm(eo.Algorithm)
102
103	subtype, data, err := ce.crypt.EncryptExplicit(ctx, bsoncore.Value{Type: val.Type, Data: val.Value}, transformed)
104	if err != nil {
105		return primitive.Binary{}, err
106	}
107	return primitive.Binary{Subtype: subtype, Data: data}, nil
108}
109
110// Decrypt decrypts an encrypted value (BSON binary of subtype 6) and returns the original BSON value.
111func (ce *ClientEncryption) Decrypt(ctx context.Context, val primitive.Binary) (bson.RawValue, error) {
112	decrypted, err := ce.crypt.DecryptExplicit(ctx, val.Subtype, val.Data)
113	if err != nil {
114		return bson.RawValue{}, err
115	}
116
117	return bson.RawValue{Type: decrypted.Type, Value: decrypted.Data}, nil
118}
119
120// Close cleans up any resources associated with the ClientEncryption instance. This includes disconnecting the
121// key-vault Client instance.
122func (ce *ClientEncryption) Close(ctx context.Context) error {
123	ce.crypt.Close()
124	return ce.keyVaultClient.Disconnect(ctx)
125}
126
// splitNamespace splits a namespace of the form "database.collection" at its first dot and returns
// (database name, collection name). Any further dots remain part of the collection name. A namespace with no dot
// yields an empty database name and the whole input as the collection name.
func splitNamespace(ns string) (string, string) {
	if dot := strings.IndexByte(ns, '.'); dot >= 0 {
		return ns[:dot], ns[dot+1:]
	}
	return "", ns
}
136