1# Copyright (c) 2015 The Johns Hopkins University/Applied Physics Laboratory
2# All Rights Reserved.
3#
4#    Licensed under the Apache License, Version 2.0 (the "License"); you may
5#    not use this file except in compliance with the License. You may obtain
6#    a copy of the License at
7#
8#         http://www.apache.org/licenses/LICENSE-2.0
9#
10#    Unless required by applicable law or agreed to in writing, software
11#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13#    License for the specific language governing permissions and limitations
14#    under the License.
15
16"""
17Test cases for the mock key manager.
18"""
19
20from cryptography.hazmat import backends
21from cryptography.hazmat.primitives.asymmetric import padding
22from cryptography.hazmat.primitives import hashes
23from cryptography.hazmat.primitives import serialization
24from oslo_context import context
25
26from castellan.common import exception
27from castellan.common.objects import symmetric_key as sym_key
28from castellan.tests.unit.key_manager import mock_key_manager as mock_key_mgr
29from castellan.tests.unit.key_manager import test_key_manager as test_key_mgr
30
31
def get_cryptography_private_key(private_key):
    """Deserialize a castellan private key into a ``cryptography`` key.

    :param private_key: object whose ``get_encoded()`` yields the DER bytes
        of an unencrypted private key.
    :returns: the private key object produced by
        ``serialization.load_der_private_key``.
    """
    der_bytes = bytes(private_key.get_encoded())
    return serialization.load_der_private_key(
        der_bytes, password=None, backend=backends.default_backend())
38
39
def get_cryptography_public_key(public_key):
    """Deserialize a castellan public key into a ``cryptography`` key.

    :param public_key: object whose ``get_encoded()`` yields the DER bytes
        of a public key.
    :returns: the public key object produced by
        ``serialization.load_der_public_key``.
    """
    der_bytes = bytes(public_key.get_encoded())
    return serialization.load_der_public_key(
        der_bytes, backend=backends.default_backend())
45
46
class MockKeyManagerTestCase(test_key_mgr.KeyManagerTestCase):
    """Exercise the key manager API against ``MockKeyManager``.

    ``self.key_mgr`` is not assigned in this class; it is presumably created
    by the ``KeyManagerTestCase`` base via ``_create_key_manager`` during
    ``setUp`` -- confirm against test_key_manager.
    """

    def _create_key_manager(self):
        return mock_key_mgr.MockKeyManager()

    def setUp(self):
        super(MockKeyManagerTestCase, self).setUp()

        self.context = context.RequestContext('fake', 'fake')

        # Test frameworks do not invoke a method named ``cleanUp`` on their
        # own, so register it explicitly or the key reset would never run.
        self.addCleanup(self.cleanUp)

    def cleanUp(self):
        """Discard every key stored in the mock key manager."""
        # No super() call: the TestCase base classes provide no ``cleanUp``
        # to chain to, so calling it would raise AttributeError.
        self.key_mgr.keys = {}

    def test_create_key(self):
        key_id_1 = self.key_mgr.create_key(self.context)
        key_id_2 = self.key_mgr.create_key(self.context)
        # ensure that the UUIDs are unique
        self.assertNotEqual(key_id_1, key_id_2)

    def test_create_key_with_length(self):
        for length in [64, 128, 256]:
            key_id = self.key_mgr.create_key(self.context, length=length)
            key = self.key_mgr.get(self.context, key_id)
            # length is in bits; the encoded key is measured in bytes.
            self.assertEqual(length // 8, len(key.get_encoded()))
            self.assertIsNotNone(key.id)

    def test_create_key_with_name(self):
        name = 'my key'
        key_id = self.key_mgr.create_key(self.context, name=name)
        key = self.key_mgr.get(self.context, key_id)
        self.assertEqual(name, key.name)
        self.assertIsNotNone(key.id)

    def test_create_key_with_algorithm(self):
        algorithm = 'DES'
        key_id = self.key_mgr.create_key(self.context, algorithm=algorithm)
        key = self.key_mgr.get(self.context, key_id)
        self.assertEqual(algorithm, key.algorithm)
        self.assertIsNotNone(key.id)

    def test_create_key_null_context(self):
        self.assertRaises(exception.Forbidden,
                          self.key_mgr.create_key, None)

    def test_create_key_pair(self):
        for length in [2048, 3072, 4096]:
            name = str(length) + ' key'
            private_key_uuid, public_key_uuid = self.key_mgr.create_key_pair(
                self.context, 'RSA', length, name=name)

            private_key = self.key_mgr.get(self.context, private_key_uuid)
            self.assertIsNotNone(private_key.id)
            public_key = self.key_mgr.get(self.context, public_key_uuid)
            self.assertIsNotNone(public_key.id)

            crypto_private_key = get_cryptography_private_key(private_key)
            crypto_public_key = get_cryptography_public_key(public_key)

            self.assertEqual(name, private_key.name)
            self.assertEqual(name, public_key.name)

            self.assertEqual(length, crypto_private_key.key_size)
            self.assertEqual(length, crypto_public_key.key_size)

    def test_create_key_pair_encryption(self):
        private_key_uuid, public_key_uuid = self.key_mgr.create_key_pair(
            self.context, 'RSA', 2048)

        private_key = self.key_mgr.get(self.context, private_key_uuid)
        public_key = self.key_mgr.get(self.context, public_key_uuid)

        crypto_private_key = get_cryptography_private_key(private_key)
        crypto_public_key = get_cryptography_public_key(public_key)

        message = b'secret plaintext'
        # Round-trip through the generated pair to prove the two halves match.
        ciphertext = crypto_public_key.encrypt(
            message,
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA1()),
                algorithm=hashes.SHA1(),
                label=None))
        plaintext = crypto_private_key.decrypt(
            ciphertext,
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA1()),
                algorithm=hashes.SHA1(),
                label=None))

        self.assertEqual(message, plaintext)

    def test_create_key_pair_null_context(self):
        self.assertRaises(exception.Forbidden,
                          self.key_mgr.create_key_pair, None, 'RSA', 2048)

    def test_create_key_pair_invalid_algorithm(self):
        self.assertRaises(ValueError,
                          self.key_mgr.create_key_pair,
                          self.context, 'DSA', 2048)

    def test_create_key_pair_invalid_length(self):
        self.assertRaises(ValueError,
                          self.key_mgr.create_key_pair,
                          self.context, 'RSA', 10)

    def test_store_and_get_key(self):
        secret_key = b'0' * 64
        _key = sym_key.SymmetricKey('AES', 64 * 8, secret_key)
        key_id = self.key_mgr.store(self.context, _key)

        actual_key = self.key_mgr.get(self.context, key_id)
        self.assertEqual(_key, actual_key)

        self.assertIsNotNone(actual_key.id)

    def test_store_key_and_get_metadata(self):
        secret_key = b'0' * 64
        _key = sym_key.SymmetricKey('AES', 64 * 8, secret_key)
        key_id = self.key_mgr.store(self.context, _key)

        actual_key = self.key_mgr.get(self.context,
                                      key_id,
                                      metadata_only=True)
        self.assertIsNone(actual_key.get_encoded())
        self.assertTrue(actual_key.is_metadata_only())

        self.assertIsNotNone(actual_key.id)

    def test_store_key_and_get_metadata_and_get_key(self):
        secret_key = b'0' * 64
        _key = sym_key.SymmetricKey('AES', 64 * 8, secret_key)
        key_id = self.key_mgr.store(self.context, _key)

        actual_key = self.key_mgr.get(self.context,
                                      key_id,
                                      metadata_only=True)
        self.assertIsNone(actual_key.get_encoded())
        self.assertTrue(actual_key.is_metadata_only())

        # A later full fetch must still return the secret material.
        actual_key = self.key_mgr.get(self.context,
                                      key_id,
                                      metadata_only=False)
        self.assertIsNotNone(actual_key.get_encoded())
        self.assertFalse(actual_key.is_metadata_only())

        self.assertIsNotNone(actual_key.id)

    def test_store_null_context(self):
        self.assertRaises(exception.Forbidden,
                          self.key_mgr.store, None, None)

    def test_get_null_context(self):
        self.assertRaises(exception.Forbidden,
                          self.key_mgr.get, None, None)

    def test_get_unknown_key(self):
        self.assertRaises(KeyError, self.key_mgr.get, self.context, None)

    def test_delete_key(self):
        key_id = self.key_mgr.create_key(self.context)
        self.key_mgr.delete(self.context, key_id)

        self.assertRaises(KeyError, self.key_mgr.get, self.context,
                          key_id)

    def test_delete_null_context(self):
        self.assertRaises(exception.Forbidden,
                          self.key_mgr.delete, None, None)

    def test_delete_unknown_key(self):
        self.assertRaises(KeyError, self.key_mgr.delete, self.context,
                          None)

    def test_list_null_context(self):
        self.assertRaises(exception.Forbidden, self.key_mgr.list, None)

    def test_list_keys(self):
        key1 = sym_key.SymmetricKey('AES', 64 * 8, b'0' * 64)
        self.key_mgr.store(self.context, key1)
        key2 = sym_key.SymmetricKey('AES', 32 * 8, b'0' * 32)
        self.key_mgr.store(self.context, key2)

        keys = self.key_mgr.list(self.context)
        self.assertEqual(2, len(keys))
        self.assertIn(key1, keys)
        self.assertIn(key2, keys)

        for key in keys:
            self.assertIsNotNone(key.id)

    def test_list_keys_metadata_only(self):
        key1 = sym_key.SymmetricKey('AES', 64 * 8, b'0' * 64)
        self.key_mgr.store(self.context, key1)
        key2 = sym_key.SymmetricKey('AES', 32 * 8, b'0' * 32)
        self.key_mgr.store(self.context, key2)

        keys = self.key_mgr.list(self.context, metadata_only=True)
        self.assertEqual(2, len(keys))
        # Metadata-only results cannot be compared to the stored keys
        # directly, so match them on bit length instead.
        bit_length_list = [key1.bit_length, key2.bit_length]
        for key in keys:
            self.assertTrue(key.is_metadata_only())
            self.assertIn(key.bit_length, bit_length_list)
            self.assertIsNotNone(key.id)
253