1package sshkeys
2
3import (
4	"crypto/aes"
5	"crypto/cipher"
6	"crypto/dsa"
7	"crypto/ecdsa"
8	"crypto/rand"
9	"crypto/rsa"
10	"crypto/x509"
11	"encoding/asn1"
12	"encoding/pem"
13	"fmt"
14	"math/big"
15	mrand "math/rand"
16
17	"github.com/dchest/bcrypt_pbkdf"
18	"golang.org/x/crypto/ed25519"
19	"golang.org/x/crypto/ssh"
20)
21
// Format selects the encoding that Marshal produces for a private key.
type Format int

// Supported private key encodings. The zero value is FormatOpenSSHv1.
const (
	// FormatOpenSSHv1 encodes a private key using OpenSSH's PROTOCOL.key format:
	// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key
	FormatOpenSSHv1 Format = 0

	// FormatClassicPEM encodes a private key in PEM with a key-specific DER
	// payload (PKCS#1 for RSA, SEC 1 for ECDSA, the OpenSSL layout for DSA),
	// as used by OpenSSH.
	FormatClassicPEM Format = 1
)

// MarshalOptions controls the output format and optional encryption used by
// Marshal.
type MarshalOptions struct {
	// Passphrase encrypts the private key when non-empty; a nil or
	// zero-length passphrase leaves the key unencrypted. FormatOpenSSHv1
	// uses bcrypt + aes256-cbc, FormatClassicPEM uses legacy PEM AES-128.
	Passphrase []byte

	// Format chooses the encoding; the zero value is FormatOpenSSHv1.
	Format Format
}
39
40// Marshal converts a private key into an optionally encrypted format.
41func Marshal(pk interface{}, opts *MarshalOptions) ([]byte, error) {
42	switch opts.Format {
43	case FormatOpenSSHv1:
44		return marshalOpenssh(pk, opts)
45	case FormatClassicPEM:
46		return marshalPem(pk, opts)
47	default:
48		return nil, fmt.Errorf("sshkeys: invalid format %d", opts.Format)
49	}
50}
51
52func marshalPem(pk interface{}, opts *MarshalOptions) ([]byte, error) {
53	var err error
54	var plain []byte
55	var pemType string
56
57	switch key := pk.(type) {
58	case *rsa.PrivateKey:
59		pemType = "RSA PRIVATE KEY"
60		plain = x509.MarshalPKCS1PrivateKey(key)
61	case *ecdsa.PrivateKey:
62		pemType = "EC PRIVATE KEY"
63		plain, err = x509.MarshalECPrivateKey(key)
64		if err != nil {
65			return nil, err
66		}
67	case *dsa.PrivateKey:
68		pemType = "DSA PRIVATE KEY"
69		plain, err = marshalDSAPrivateKey(key)
70		if err != nil {
71			return nil, err
72		}
73	case *ed25519.PrivateKey:
74		return nil, fmt.Errorf("sshkeys: ed25519 keys must be marshaled with FormatOpenSSHv1")
75	default:
76		return nil, fmt.Errorf("sshkeys: unsupported key type %T", pk)
77	}
78
79	if len(opts.Passphrase) > 0 {
80		block, err := x509.EncryptPEMBlock(rand.Reader, pemType, plain, opts.Passphrase, x509.PEMCipherAES128)
81		if err != nil {
82			return nil, err
83		}
84		return pem.EncodeToMemory(block), nil
85	}
86
87	return pem.EncodeToMemory(&pem.Block{
88		Type:  pemType,
89		Bytes: plain,
90	}), nil
91}
92
// dsaOpenssl is the ASN.1 SEQUENCE written into "DSA PRIVATE KEY" PEM blocks.
// It holds a version INTEGER followed by p, q, g, the public value y and the
// private value x. Field order defines the DER encoding.
type dsaOpenssl struct {
	Version int
	P       *big.Int
	Q       *big.Int
	G       *big.Int
	Pub     *big.Int
	Priv    *big.Int
}

// marshalDSAPrivateKey DER-encodes pk in the dsaOpenssl layout with version 0.
// Reference: https://github.com/golang/crypto/blob/master/ssh/keys.go#L793-L804
func marshalDSAPrivateKey(pk *dsa.PrivateKey) ([]byte, error) {
	// Version is deliberately left at its zero value, 0.
	return asn1.Marshal(dsaOpenssl{
		P:    pk.P,
		Q:    pk.Q,
		G:    pk.G,
		Pub:  pk.Y,
		Priv: pk.X,
	})
}
115
// opensshv1Magic is the prefix of the binary payload of an OpenSSH v1 private
// key. marshalOpenssh writes it followed by a single zero byte, then the
// ssh-marshaled opensshHeader.
const opensshv1Magic = "openssh-key-v1"
117
// opensshHeader is the outer structure of an "openssh-key-v1" payload. It is
// serialized with ssh.Marshal right after the magic string, so field order is
// the wire order.
//
// As filled in by marshalOpenssh:
//   - CipherName is "none", or "aes256-cbc" when a passphrase is set.
//   - KdfName is "none", or "bcrypt" when a passphrase is set.
//   - KdfOpts is empty, or the ssh-marshaled bcrypt salt and rounds.
//   - NumKeys is always 1.
//   - PubKey is the public key in SSH wire format (ssh.PublicKey.Marshal).
//   - PrivKeyBlock is the padded, optionally encrypted, ssh-marshaled opensshKey.
type opensshHeader struct {
	CipherName   string
	KdfName      string
	KdfOpts      string
	NumKeys      uint32
	PubKey       string
	PrivKeyBlock string
}
126
// opensshKey is the private section of an "openssh-key-v1" payload, before
// padding and optional encryption.
//
// Check1 and Check2 are set to the same random value. PROTOCOL.key uses the
// pair to detect a failed decryption. Keytype is the SSH algorithm name
// (e.g. ssh.KeyAlgoRSA). Rest carries the ssh-marshaled key-specific fields
// (opensshRsa or opensshED25519). Its `ssh:"rest"` tag makes ssh.Marshal
// append those bytes without a length prefix.
type opensshKey struct {
	Check1  uint32
	Check2  uint32
	Keytype string
	Rest    []byte `ssh:"rest"`
}
133
// opensshRsa holds the RSA-specific private key fields, in wire order
// n, e, d, iqmp, p, q, followed by the key comment. Iqmp is q^-1 mod p, which
// is what rsa.PrivateKey.Precomputed.Qinv holds.
//
// Pad is the `ssh:"rest"` tail. marshalOpenssh leaves it empty and pads the
// whole private section with padBytes instead. Presumably a parser uses it to
// collect the padding bytes; verify against any unmarshal code.
type opensshRsa struct {
	N       *big.Int
	E       *big.Int
	D       *big.Int
	Iqmp    *big.Int
	P       *big.Int
	Q       *big.Int
	Comment string
	Pad     []byte `ssh:"rest"`
}
144
// opensshED25519 holds the Ed25519-specific private key fields in wire order.
// marshalOpenssh sets Pub to the public key bytes and Priv to the complete
// ed25519.PrivateKey bytes, and leaves Comment and Pad empty. Pad is the
// `ssh:"rest"` tail, presumably used to capture padding when parsing;
// confirm against any unmarshal code.
type opensshED25519 struct {
	Pub     []byte
	Priv    []byte
	Comment string
	Pad     []byte `ssh:"rest"`
}
151
// padBytes appends the deterministic sequence 1, 2, 3, ... to data until its
// length is a multiple of blocksize, and returns the (possibly reallocated)
// slice. A blocksize of 0 disables padding and returns data unchanged. Input
// that is already aligned gets no padding.
func padBytes(data []byte, blocksize int) []byte {
	if blocksize == 0 {
		return data
	}
	for next := byte(1); len(data)%blocksize != 0; next++ {
		data = append(data, next)
	}
	return data
}
161
162func marshalOpenssh(pk interface{}, opts *MarshalOptions) ([]byte, error) {
163	var blocksize int
164	var keylen int
165
166	out := opensshHeader{
167		CipherName: "none",
168		KdfName:    "none",
169		KdfOpts:    "",
170		NumKeys:    1,
171		PubKey:     "",
172	}
173
174	if len(opts.Passphrase) > 0 {
175		out.CipherName = "aes256-cbc"
176		out.KdfName = "bcrypt"
177		keylen = keySizeAES256
178		blocksize = aes.BlockSize
179	}
180
181	check := mrand.Uint32()
182	pk1 := opensshKey{
183		Check1: check,
184		Check2: check,
185	}
186
187	switch key := pk.(type) {
188	case *rsa.PrivateKey:
189		k := &opensshRsa{
190			N:       key.N,
191			E:       big.NewInt(int64(key.E)),
192			D:       key.D,
193			Iqmp:    key.Precomputed.Qinv,
194			P:       key.Primes[0],
195			Q:       key.Primes[1],
196			Comment: "",
197		}
198
199		data := ssh.Marshal(k)
200		pk1.Keytype = ssh.KeyAlgoRSA
201		pk1.Rest = data
202		publicKey, err := ssh.NewPublicKey(&key.PublicKey)
203		if err != nil {
204			return nil, err
205		}
206		out.PubKey = string(publicKey.Marshal())
207
208	case ed25519.PrivateKey:
209		k := opensshED25519{
210			Pub:  key.Public().(ed25519.PublicKey),
211			Priv: key,
212		}
213		data := ssh.Marshal(k)
214		pk1.Keytype = ssh.KeyAlgoED25519
215		pk1.Rest = data
216
217		publicKey, err := ssh.NewPublicKey(key.Public())
218		if err != nil {
219			return nil, err
220		}
221		out.PubKey = string(publicKey.Marshal())
222	default:
223		return nil, fmt.Errorf("sshkeys: unsupported key type %T", pk)
224	}
225
226	if len(opts.Passphrase) > 0 {
227		rounds := 16
228		ivlen := blocksize
229		salt := make([]byte, blocksize)
230		_, err := rand.Read(salt)
231		if err != nil {
232			return nil, err
233		}
234
235		kdfdata, err := bcrypt_pbkdf.Key(opts.Passphrase, salt, rounds, keylen+ivlen)
236		if err != nil {
237			return nil, err
238		}
239		iv := kdfdata[keylen : ivlen+keylen]
240		aeskey := kdfdata[0:keylen]
241
242		block, err := aes.NewCipher(aeskey)
243		if err != nil {
244			return nil, err
245		}
246
247		pkblock := padBytes(ssh.Marshal(pk1), blocksize)
248
249		cbc := cipher.NewCBCEncrypter(block, iv)
250		cbc.CryptBlocks(pkblock, pkblock)
251
252		out.PrivKeyBlock = string(pkblock)
253
254		var opts struct {
255			Salt   []byte
256			Rounds uint32
257		}
258
259		opts.Salt = salt
260		opts.Rounds = uint32(rounds)
261
262		out.KdfOpts = string(ssh.Marshal(&opts))
263	} else {
264		out.PrivKeyBlock = string(ssh.Marshal(pk1))
265	}
266
267	outBytes := []byte(opensshv1Magic)
268	outBytes = append(outBytes, 0)
269	outBytes = append(outBytes, ssh.Marshal(out)...)
270	block := &pem.Block{
271		Type:  "OPENSSH PRIVATE KEY",
272		Bytes: outBytes,
273	}
274	return pem.EncodeToMemory(block), nil
275}
276