1# Copyright (c) Twisted Matrix Laboratories.
2# See LICENSE for details.
3
4"""
5Tests for L{twisted.conch.scripts.ckeygen}.
6"""
7
8import getpass
9import subprocess
10import sys
11from io import StringIO
12
13from twisted.conch.test.keydata import (
14    privateECDSA_openssh,
15    privateEd25519_openssh_new,
16    privateRSA_openssh,
17    privateRSA_openssh_encrypted,
18    publicRSA_openssh,
19)
20from twisted.python.filepath import FilePath
21from twisted.python.reflect import requireModule
22from twisted.trial.unittest import TestCase
23
24if requireModule("cryptography") and requireModule("pyasn1"):
25    from twisted.conch.scripts.ckeygen import (
26        _saveKey,
27        changePassPhrase,
28        displayPublicKey,
29        enumrepresentation,
30        printFingerprint,
31    )
32    from twisted.conch.ssh.keys import (
33        BadFingerPrintFormat,
34        BadKeyError,
35        FingerprintFormats,
36        Key,
37    )
38else:
39    skip = "cryptography and pyasn1 required for twisted.conch.scripts.ckeygen"
40
41
def makeGetpass(*passphrases):
    """
    Create a stand-in for C{getpass.getpass} that replays canned answers.

    Each call ignores its prompt argument and returns the next passphrase in
    the order given, which simulates a user typing an old passphrase followed
    by new one(s) at successive interactive prompts.

    @param passphrases: The passphrases to hand out, one per call, in order.

    @return: A one-argument callable suitable for patching over
        C{getpass.getpass}.  Calling it more times than there are
        passphrases raises C{StopIteration}.
    """
    nextAnswer = iter(passphrases).__next__
    # The prompt text is irrelevant to the tests; only the call order matters.
    return lambda _: nextAnswer()
58
59
class KeyGenTests(TestCase):
    """
    Tests for various functions used to implement the I{ckeygen} script.
    """

    def setUp(self):
        """
        Patch C{sys.stdout} so tests can make assertions about what's printed.
        """
        self.stdout = StringIO()
        self.patch(sys, "stdout", self.stdout)

    def _testrun(self, keyType, keySize=None, privateKeySubtype=None):
        """
        Run the I{ckeygen} executable to generate an unencrypted key pair,
        then check that both files it writes load and form a matching key
        pair of the requested type.

        @param keyType: The value passed to C{-t}, e.g. C{"rsa"}.
        @type keyType: L{str}

        @param keySize: The value passed to C{-b}, or L{None} to omit C{-b}.
        @type keySize: L{str} or L{None}

        @param privateKeySubtype: The value passed to
            C{--private-key-subtype}, or L{None} to omit that option.
        @type privateKeySubtype: L{str} or L{None}

        @raise subprocess.CalledProcessError: If I{ckeygen} exits with a
            non-zero status.
        """
        filename = self.mktemp()
        args = ["ckeygen", "-t", keyType, "-f", filename, "--no-passphrase"]
        if keySize is not None:
            args.extend(["-b", keySize])
        if privateKeySubtype is not None:
            args.extend(["--private-key-subtype", privateKeySubtype])
        # check_call (not call) so a failing ckeygen run is reported as such,
        # rather than surfacing later as a confusing missing-file error from
        # Key.fromFile.
        subprocess.check_call(args)
        privKey = Key.fromFile(filename)
        pubKey = Key.fromFile(filename + ".pub")
        # ckeygen's -t names don't all match the names Key.type() reports.
        if keyType == "ecdsa":
            self.assertEqual(privKey.type(), "EC")
        elif keyType == "ed25519":
            self.assertEqual(privKey.type(), "Ed25519")
        else:
            self.assertEqual(privKey.type(), keyType.upper())
        self.assertTrue(pubKey.isPublic())
        # The .pub file must hold the public half of the generated private key.
        self.assertEqual(pubKey, privKey.public())

    def test_keygeneration(self):
        """
        I{ckeygen} generates ECDSA, Ed25519, DSA and RSA key pairs, both with
        an explicit size and with its default size, and can optionally write
        the private key in OpenSSH v1 format.
        """
        self._testrun("ecdsa", "384")
        self._testrun("ecdsa", "384", privateKeySubtype="v1")
        self._testrun("ecdsa")
        self._testrun("ecdsa", privateKeySubtype="v1")
        self._testrun("ed25519")
        self._testrun("dsa", "2048")
        self._testrun("dsa", "2048", privateKeySubtype="v1")
        self._testrun("dsa")
        self._testrun("dsa", privateKeySubtype="v1")
        self._testrun("rsa", "2048")
        self._testrun("rsa", "2048", privateKeySubtype="v1")
        self._testrun("rsa")
        self._testrun("rsa", privateKeySubtype="v1")

    def test_runBadKeytype(self):
        """
        I{ckeygen} exits with a non-zero status when asked to generate a key
        of an unsupported type.
        """
        filename = self.mktemp()
        with self.assertRaises(subprocess.CalledProcessError):
            subprocess.check_call(["ckeygen", "-t", "foo", "-f", filename])

    def test_enumrepresentation(self):
        """
        L{enumrepresentation} takes a dictionary as input and returns a
        dictionary with its attributes changed to enum representation.
        """
        options = enumrepresentation({"format": "md5-hex"})
        self.assertIs(options["format"], FingerprintFormats.MD5_HEX)

    def test_enumrepresentationsha256(self):
        """
        L{enumrepresentation} maps the C{"sha256-base64"} format name to
        L{FingerprintFormats.SHA256_BASE64}.
        """
        options = enumrepresentation({"format": "sha256-base64"})
        self.assertIs(options["format"], FingerprintFormats.SHA256_BASE64)

    def test_enumrepresentationBadFormat(self):
        """
        L{enumrepresentation} raises L{BadFingerPrintFormat} for an
        unsupported fingerprint format name.
        """
        with self.assertRaises(BadFingerPrintFormat) as em:
            enumrepresentation({"format": "sha-base64"})
        self.assertEqual(
            "Unsupported fingerprint format: sha-base64", em.exception.args[0]
        )

    def test_printFingerprint(self):
        """
        L{printFingerprint} writes a line to standard out giving the number of
        bits of the key, its fingerprint, and the basename of the file from
        which it was read.
        """
        filename = self.mktemp()
        FilePath(filename).setContent(publicRSA_openssh)
        printFingerprint({"filename": filename, "format": "md5-hex"})
        self.assertEqual(
            self.stdout.getvalue(),
            "2048 85:25:04:32:58:55:96:9f:57:ee:fb:a8:1a:ea:69:da temp\n",
        )

    def test_printFingerprintsha256(self):
        """
        L{printFingerprint} prints the key fingerprint in
        L{FingerprintFormats.SHA256_BASE64} format if explicitly specified.
        """
        filename = self.mktemp()
        FilePath(filename).setContent(publicRSA_openssh)
        printFingerprint({"filename": filename, "format": "sha256-base64"})
        self.assertEqual(
            self.stdout.getvalue(),
            "2048 FBTCOoknq0mHy+kpfnY9tDdcAJuWtCpuQMaV3EsvbUI= temp\n",
        )

    def test_printFingerprintBadFingerPrintFormat(self):
        """
        L{printFingerprint} raises C{keys.BadFingerPrintFormat} when an
        unsupported format is requested.
        """
        filename = self.mktemp()
        FilePath(filename).setContent(publicRSA_openssh)
        with self.assertRaises(BadFingerPrintFormat) as em:
            printFingerprint({"filename": filename, "format": "sha-base64"})
        self.assertEqual(
            "Unsupported fingerprint format: sha-base64", em.exception.args[0]
        )

    def test_saveKey(self):
        """
        L{_saveKey} writes the private and public parts of a key to two
        different files and writes a report of this to standard out.
        """
        base = FilePath(self.mktemp())
        base.makedirs()
        filename = base.child("id_rsa").path
        key = Key.fromString(privateRSA_openssh)
        _saveKey(key, {"filename": filename, "pass": "passphrase", "format": "md5-hex"})
        self.assertEqual(
            self.stdout.getvalue(),
            "Your identification has been saved in %s\n"
            "Your public key has been saved in %s.pub\n"
            "The key fingerprint in <FingerprintFormats=MD5_HEX> is:\n"
            "85:25:04:32:58:55:96:9f:57:ee:fb:a8:1a:ea:69:da\n" % (filename, filename),
        )
        self.assertEqual(
            Key.fromString(base.child("id_rsa").getContent(), None, "passphrase"), key
        )
        self.assertEqual(
            Key.fromString(base.child("id_rsa.pub").getContent()), key.public()
        )

    def test_saveKeyECDSA(self):
        """
        L{_saveKey} writes the private and public parts of a key to two
        different files and writes a report of this to standard out.
        Test with ECDSA key.
        """
        base = FilePath(self.mktemp())
        base.makedirs()
        filename = base.child("id_ecdsa").path
        key = Key.fromString(privateECDSA_openssh)
        _saveKey(key, {"filename": filename, "pass": "passphrase", "format": "md5-hex"})
        self.assertEqual(
            self.stdout.getvalue(),
            "Your identification has been saved in %s\n"
            "Your public key has been saved in %s.pub\n"
            "The key fingerprint in <FingerprintFormats=MD5_HEX> is:\n"
            "1e:ab:83:a6:f2:04:22:99:7c:64:14:d2:ab:fa:f5:16\n" % (filename, filename),
        )
        self.assertEqual(
            Key.fromString(base.child("id_ecdsa").getContent(), None, "passphrase"), key
        )
        self.assertEqual(
            Key.fromString(base.child("id_ecdsa.pub").getContent()), key.public()
        )

    def test_saveKeyEd25519(self):
        """
        L{_saveKey} writes the private and public parts of a key to two
        different files and writes a report of this to standard out.
        Test with Ed25519 key.
        """
        base = FilePath(self.mktemp())
        base.makedirs()
        filename = base.child("id_ed25519").path
        key = Key.fromString(privateEd25519_openssh_new)
        _saveKey(key, {"filename": filename, "pass": "passphrase", "format": "md5-hex"})
        self.assertEqual(
            self.stdout.getvalue(),
            "Your identification has been saved in %s\n"
            "Your public key has been saved in %s.pub\n"
            "The key fingerprint in <FingerprintFormats=MD5_HEX> is:\n"
            "ab:ee:c8:ed:e5:01:1b:45:b7:8d:b2:f0:8f:61:1c:14\n" % (filename, filename),
        )
        self.assertEqual(
            Key.fromString(base.child("id_ed25519").getContent(), None, "passphrase"),
            key,
        )
        self.assertEqual(
            Key.fromString(base.child("id_ed25519.pub").getContent()), key.public()
        )

    def test_saveKeysha256(self):
        """
        L{_saveKey} reports the key fingerprint in
        L{FingerprintFormats.SHA256_BASE64} format if explicitly specified.
        """
        base = FilePath(self.mktemp())
        base.makedirs()
        filename = base.child("id_rsa").path
        key = Key.fromString(privateRSA_openssh)
        _saveKey(
            key, {"filename": filename, "pass": "passphrase", "format": "sha256-base64"}
        )
        self.assertEqual(
            self.stdout.getvalue(),
            "Your identification has been saved in %s\n"
            "Your public key has been saved in %s.pub\n"
            "The key fingerprint in <FingerprintFormats=SHA256_BASE64> is:\n"
            "FBTCOoknq0mHy+kpfnY9tDdcAJuWtCpuQMaV3EsvbUI=\n" % (filename, filename),
        )
        self.assertEqual(
            Key.fromString(base.child("id_rsa").getContent(), None, "passphrase"), key
        )
        self.assertEqual(
            Key.fromString(base.child("id_rsa.pub").getContent()), key.public()
        )

    def test_saveKeyBadFingerPrintformat(self):
        """
        L{_saveKey} raises C{keys.BadFingerPrintFormat} when an unsupported
        format is requested.
        """
        base = FilePath(self.mktemp())
        base.makedirs()
        filename = base.child("id_rsa").path
        key = Key.fromString(privateRSA_openssh)
        with self.assertRaises(BadFingerPrintFormat) as em:
            _saveKey(
                key,
                {"filename": filename, "pass": "passphrase", "format": "sha-base64"},
            )
        self.assertEqual(
            "Unsupported fingerprint format: sha-base64", em.exception.args[0]
        )

    def test_saveKeyEmptyPassphrase(self):
        """
        L{_saveKey} uses an empty passphrase when the C{no-passphrase} option
        is C{True}.
        """
        base = FilePath(self.mktemp())
        base.makedirs()
        filename = base.child("id_rsa").path
        key = Key.fromString(privateRSA_openssh)
        _saveKey(
            key, {"filename": filename, "no-passphrase": True, "format": "md5-hex"}
        )
        self.assertEqual(
            Key.fromString(base.child("id_rsa").getContent(), None, b""), key
        )

    def test_saveKeyECDSAEmptyPassphrase(self):
        """
        L{_saveKey} uses an empty passphrase when the C{no-passphrase} option
        is C{True}.  Test with ECDSA key.
        """
        base = FilePath(self.mktemp())
        base.makedirs()
        filename = base.child("id_ecdsa").path
        key = Key.fromString(privateECDSA_openssh)
        _saveKey(
            key, {"filename": filename, "no-passphrase": True, "format": "md5-hex"}
        )
        self.assertEqual(Key.fromString(base.child("id_ecdsa").getContent(), None), key)

    def test_saveKeyEd25519EmptyPassphrase(self):
        """
        L{_saveKey} uses an empty passphrase when the C{no-passphrase} option
        is C{True}.  Test with Ed25519 key.
        """
        base = FilePath(self.mktemp())
        base.makedirs()
        filename = base.child("id_ed25519").path
        key = Key.fromString(privateEd25519_openssh_new)
        _saveKey(
            key, {"filename": filename, "no-passphrase": True, "format": "md5-hex"}
        )
        self.assertEqual(
            Key.fromString(base.child("id_ed25519").getContent(), None), key
        )

    def test_saveKeyNoFilename(self):
        """
        When no filename is given, L{_saveKey} asks (via C{_inputSaveFile})
        for the path used to store the key.
        """
        base = FilePath(self.mktemp())
        base.makedirs()
        keyPath = base.child("custom_key").path

        import twisted.conch.scripts.ckeygen

        self.patch(twisted.conch.scripts.ckeygen, "_inputSaveFile", lambda _: keyPath)
        key = Key.fromString(privateRSA_openssh)
        _saveKey(key, {"filename": None, "no-passphrase": True, "format": "md5-hex"})

        persistedKeyContent = base.child("custom_key").getContent()
        persistedKey = Key.fromString(persistedKeyContent, None, b"")
        self.assertEqual(key, persistedKey)

    def test_saveKeySubtypeV1(self):
        """
        L{_saveKey} can be told to write the new private key file in OpenSSH
        v1 format.
        """
        base = FilePath(self.mktemp())
        base.makedirs()
        filename = base.child("id_rsa").path
        key = Key.fromString(privateRSA_openssh)
        _saveKey(
            key,
            {
                "filename": filename,
                "pass": "passphrase",
                "format": "md5-hex",
                "private-key-subtype": "v1",
            },
        )
        self.assertEqual(
            self.stdout.getvalue(),
            "Your identification has been saved in %s\n"
            "Your public key has been saved in %s.pub\n"
            "The key fingerprint in <FingerprintFormats=MD5_HEX> is:\n"
            "85:25:04:32:58:55:96:9f:57:ee:fb:a8:1a:ea:69:da\n" % (filename, filename),
        )
        privateKeyContent = base.child("id_rsa").getContent()
        self.assertEqual(Key.fromString(privateKeyContent, None, "passphrase"), key)
        self.assertTrue(
            privateKeyContent.startswith(b"-----BEGIN OPENSSH PRIVATE KEY-----\n")
        )
        self.assertEqual(
            Key.fromString(base.child("id_rsa.pub").getContent()), key.public()
        )

    def test_displayPublicKey(self):
        """
        L{displayPublicKey} prints out the public key associated with a given
        private key.
        """
        filename = self.mktemp()
        pubKey = Key.fromString(publicRSA_openssh)
        FilePath(filename).setContent(privateRSA_openssh)
        displayPublicKey({"filename": filename})
        # StringIO.getvalue() returns str; the key serializes to bytes.
        displayed = self.stdout.getvalue().strip("\n").encode("ascii")
        self.assertEqual(displayed, pubKey.toString("openssh"))

    def test_displayPublicKeyEncrypted(self):
        """
        L{displayPublicKey} prints out the public key associated with a given
        private key using the given passphrase when it's encrypted.
        """
        filename = self.mktemp()
        pubKey = Key.fromString(publicRSA_openssh)
        FilePath(filename).setContent(privateRSA_openssh_encrypted)
        displayPublicKey({"filename": filename, "pass": "encrypted"})
        displayed = self.stdout.getvalue().strip("\n").encode("ascii")
        self.assertEqual(displayed, pubKey.toString("openssh"))

    def test_displayPublicKeyEncryptedPassphrasePrompt(self):
        """
        L{displayPublicKey} prints out the public key associated with a given
        private key, asking for the passphrase when it's encrypted.
        """
        filename = self.mktemp()
        pubKey = Key.fromString(publicRSA_openssh)
        FilePath(filename).setContent(privateRSA_openssh_encrypted)
        self.patch(getpass, "getpass", lambda x: "encrypted")
        displayPublicKey({"filename": filename})
        displayed = self.stdout.getvalue().strip("\n").encode("ascii")
        self.assertEqual(displayed, pubKey.toString("openssh"))

    def test_displayPublicKeyWrongPassphrase(self):
        """
        L{displayPublicKey} fails with a L{BadKeyError} when trying to decrypt
        an encrypted key with the wrong password.
        """
        filename = self.mktemp()
        FilePath(filename).setContent(privateRSA_openssh_encrypted)
        self.assertRaises(
            BadKeyError, displayPublicKey, {"filename": filename, "pass": "wrong"}
        )

    def test_changePassphrase(self):
        """
        L{changePassPhrase} allows a user to change the passphrase of a
        private key interactively.
        """
        oldNewConfirm = makeGetpass("encrypted", "newpass", "newpass")
        self.patch(getpass, "getpass", oldNewConfirm)

        filename = self.mktemp()
        FilePath(filename).setContent(privateRSA_openssh_encrypted)

        changePassPhrase({"filename": filename})
        self.assertEqual(
            self.stdout.getvalue().strip("\n"),
            "Your identification has been saved with the new passphrase.",
        )
        self.assertNotEqual(
            privateRSA_openssh_encrypted, FilePath(filename).getContent()
        )

    def test_changePassphraseWithOld(self):
        """
        L{changePassPhrase} allows a user to change the passphrase of a
        private key, providing the old passphrase and prompting for new one.
        """
        newConfirm = makeGetpass("newpass", "newpass")
        self.patch(getpass, "getpass", newConfirm)

        filename = self.mktemp()
        FilePath(filename).setContent(privateRSA_openssh_encrypted)

        changePassPhrase({"filename": filename, "pass": "encrypted"})
        self.assertEqual(
            self.stdout.getvalue().strip("\n"),
            "Your identification has been saved with the new passphrase.",
        )
        self.assertNotEqual(
            privateRSA_openssh_encrypted, FilePath(filename).getContent()
        )

    def test_changePassphraseWithBoth(self):
        """
        L{changePassPhrase} allows a user to change the passphrase of a private
        key by providing both old and new passphrases without prompting.
        """
        filename = self.mktemp()
        FilePath(filename).setContent(privateRSA_openssh_encrypted)

        changePassPhrase(
            {"filename": filename, "pass": "encrypted", "newpass": "newencrypt"}
        )
        self.assertEqual(
            self.stdout.getvalue().strip("\n"),
            "Your identification has been saved with the new passphrase.",
        )
        self.assertNotEqual(
            privateRSA_openssh_encrypted, FilePath(filename).getContent()
        )

    def test_changePassphraseWrongPassphrase(self):
        """
        L{changePassPhrase} exits if passed an invalid old passphrase when
        trying to change the passphrase of a private key.
        """
        filename = self.mktemp()
        FilePath(filename).setContent(privateRSA_openssh_encrypted)
        error = self.assertRaises(
            SystemExit, changePassPhrase, {"filename": filename, "pass": "wrong"}
        )
        self.assertEqual(
            "Could not change passphrase: old passphrase error", str(error)
        )
        self.assertEqual(privateRSA_openssh_encrypted, FilePath(filename).getContent())

    def test_changePassphraseEmptyGetPass(self):
        """
        L{changePassPhrase} exits if no passphrase is specified for the
        C{getpass} call and the key is encrypted.
        """
        self.patch(getpass, "getpass", makeGetpass(""))
        filename = self.mktemp()
        FilePath(filename).setContent(privateRSA_openssh_encrypted)
        error = self.assertRaises(SystemExit, changePassPhrase, {"filename": filename})
        self.assertEqual(
            "Could not change passphrase: Passphrase must be provided "
            "for an encrypted key",
            str(error),
        )
        self.assertEqual(privateRSA_openssh_encrypted, FilePath(filename).getContent())

    def test_changePassphraseBadKey(self):
        """
        L{changePassPhrase} exits if the file specified points to an invalid
        key.
        """
        filename = self.mktemp()
        FilePath(filename).setContent(b"foobar")
        error = self.assertRaises(SystemExit, changePassPhrase, {"filename": filename})

        expected = "Could not change passphrase: cannot guess the type of b'foobar'"
        self.assertEqual(expected, str(error))
        self.assertEqual(b"foobar", FilePath(filename).getContent())

    def test_changePassphraseCreateError(self):
        """
        L{changePassPhrase} doesn't modify the key file if an unexpected error
        happens when trying to create the key with the new passphrase.
        """
        filename = self.mktemp()
        FilePath(filename).setContent(privateRSA_openssh)

        def toString(*args, **kwargs):
            raise RuntimeError("oops")

        self.patch(Key, "toString", toString)

        error = self.assertRaises(
            SystemExit,
            changePassPhrase,
            {"filename": filename, "newpass": "newencrypt"},
        )

        self.assertEqual("Could not change passphrase: oops", str(error))

        self.assertEqual(privateRSA_openssh, FilePath(filename).getContent())

    def test_changePassphraseEmptyStringError(self):
        """
        L{changePassPhrase} doesn't modify the key file if C{toString} returns
        an empty string.
        """
        filename = self.mktemp()
        FilePath(filename).setContent(privateRSA_openssh)

        def toString(*args, **kwargs):
            return ""

        self.patch(Key, "toString", toString)

        error = self.assertRaises(
            SystemExit,
            changePassPhrase,
            {"filename": filename, "newpass": "newencrypt"},
        )

        expected = "Could not change passphrase: cannot guess the type of b''"
        self.assertEqual(expected, str(error))

        self.assertEqual(privateRSA_openssh, FilePath(filename).getContent())

    def test_changePassphrasePublicKey(self):
        """
        L{changePassPhrase} exits when trying to change the passphrase on a
        public key, and doesn't change the file.
        """
        filename = self.mktemp()
        FilePath(filename).setContent(publicRSA_openssh)
        error = self.assertRaises(
            SystemExit, changePassPhrase, {"filename": filename, "newpass": "pass"}
        )
        self.assertEqual("Could not change passphrase: key not encrypted", str(error))
        self.assertEqual(publicRSA_openssh, FilePath(filename).getContent())

    def test_changePassphraseSubtypeV1(self):
        """
        L{changePassPhrase} can be told to write the new private key file in
        OpenSSH v1 format.
        """
        oldNewConfirm = makeGetpass("encrypted", "newpass", "newpass")
        self.patch(getpass, "getpass", oldNewConfirm)

        filename = self.mktemp()
        FilePath(filename).setContent(privateRSA_openssh_encrypted)

        changePassPhrase({"filename": filename, "private-key-subtype": "v1"})
        self.assertEqual(
            self.stdout.getvalue().strip("\n"),
            "Your identification has been saved with the new passphrase.",
        )
        privateKeyContent = FilePath(filename).getContent()
        self.assertNotEqual(privateRSA_openssh_encrypted, privateKeyContent)
        self.assertTrue(
            privateKeyContent.startswith(b"-----BEGIN OPENSSH PRIVATE KEY-----\n")
        )
631