1/*-
2 * Copyright 2015 Square Inc.
3 * Copyright 2014 CoreOS
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package pkix
19
20import (
21	"bytes"
22	"crypto"
23	"crypto/ecdsa"
24	"crypto/rand"
25	"crypto/rsa"
26	"crypto/x509"
27	"crypto/x509/pkix"
28	"encoding/asn1"
29	"encoding/pem"
30	"errors"
31	"fmt"
32	"math/big"
33	"net"
34	"strings"
35)
36
// csrPEMBlockType is the PEM block type written by Export and required by
// NewCertificateSigningRequestFromPEM.
const csrPEMBlockType = "CERTIFICATE REQUEST"
40
// csrPkixName is a default subject template for certificate signing requests.
// Country and Organization start out as empty (non-nil) slices; every other
// field is left at its zero value. Because this is shared package-level state,
// copy it rather than mutating it in place.
var csrPkixName = pkix.Name{
	Country:      []string{},
	Organization: []string{},
}
54
// ParseAndValidateIPs parses a comma-delimited list of IP addresses, such as
// "10.0.0.1, ::1", into a slice of net.IP.
//
// Whitespace around each entry is ignored. An empty or all-whitespace list
// yields a nil slice and a nil error. If any entry is rejected by net.ParseIP
// (including an empty entry left by a stray comma), an error naming that entry
// is returned together with a nil slice.
func ParseAndValidateIPs(ipList string) (res []net.IP, err error) {
	// The IP list may legitimately be blank; that means "no addresses".
	if strings.TrimSpace(ipList) == "" {
		return nil, nil
	}
	entries := strings.Split(ipList, ",")
	res = make([]net.IP, 0, len(entries))
	for _, entry := range entries {
		// Tolerate "a, b" as well as "a,b".
		entry = strings.TrimSpace(entry)
		ip := net.ParseIP(entry)
		if ip == nil {
			// %q makes empty entries (e.g. "1.2.3.4,") visible in the message.
			return nil, fmt.Errorf("invalid IP address: %q", entry)
		}
		res = append(res, ip)
	}
	return res, nil
}
70
71// CreateCertificateSigningRequest sets up a request to create a csr file with the given parameters
72func CreateCertificateSigningRequest(key *Key, organizationalUnit string, ipList []net.IP, domainList []string, organization string, country string, province string, locality string, commonName string) (*CertificateSigningRequest, error) {
73
74	csrPkixName.CommonName = commonName
75
76	if len(organizationalUnit) > 0 {
77		csrPkixName.OrganizationalUnit = []string{organizationalUnit}
78	}
79	if len(organization) > 0 {
80		csrPkixName.Organization = []string{organization}
81	}
82	if len(country) > 0 {
83		csrPkixName.Country = []string{country}
84	}
85	if len(province) > 0 {
86		csrPkixName.Province = []string{province}
87	}
88	if len(locality) > 0 {
89		csrPkixName.Locality = []string{locality}
90	}
91	csrTemplate := &x509.CertificateRequest{
92		Subject:     csrPkixName,
93		IPAddresses: ipList,
94		DNSNames:    domainList,
95	}
96
97	csrBytes, err := x509.CreateCertificateRequest(rand.Reader, csrTemplate, key.Private)
98	if err != nil {
99		return nil, err
100	}
101	return NewCertificateSigningRequestFromDER(csrBytes), nil
102}
103
// CertificateSigningRequest wraps a certificate signing request. It keeps the
// DER encoding it was constructed from and, once parsed, the decoded
// x509.CertificateRequest.
type CertificateSigningRequest struct {
	// derBytes holds the DER encoding of the request; every constructor sets it.
	derBytes []byte
	// cr caches the parsed form of derBytes; nil until it is first parsed.
	cr *x509.CertificateRequest
}
111
112// NewCertificateSigningRequestFromDER inits CertificateSigningRequest from DER-format bytes
113func NewCertificateSigningRequestFromDER(derBytes []byte) *CertificateSigningRequest {
114	return &CertificateSigningRequest{derBytes: derBytes}
115}
116
117// NewCertificateSigningRequestFromPEM inits CertificateSigningRequest from PEM-format bytes
118// data should contain at most one certificate
119func NewCertificateSigningRequestFromPEM(data []byte) (*CertificateSigningRequest, error) {
120	pemBlock, _ := pem.Decode(data)
121	if pemBlock == nil {
122		return nil, errors.New("cannot find the next PEM formatted block")
123	}
124	if pemBlock.Type != csrPEMBlockType || len(pemBlock.Headers) != 0 {
125		return nil, errors.New("unmatched type or headers")
126	}
127	return &CertificateSigningRequest{derBytes: pemBlock.Bytes}, nil
128}
129
130// build cr field if needed
131func (c *CertificateSigningRequest) buildPKCS10CertificateSigningRequest() error {
132	if c.cr != nil {
133		return nil
134	}
135
136	var err error
137	c.cr, err = x509.ParseCertificateRequest(c.derBytes)
138	if err != nil {
139		return err
140	}
141	return nil
142}
143
144// GetRawCertificateSigningRequest returns a copy of this certificate request as an x509.CertificateRequest.
145func (c *CertificateSigningRequest) GetRawCertificateSigningRequest() (*x509.CertificateRequest, error) {
146	if err := c.buildPKCS10CertificateSigningRequest(); err != nil {
147		return nil, err
148	}
149	return c.cr, nil
150}
151
152// CheckSignature verifies that the signature is a valid signature
153// using the public key in CertificateSigningRequest.
154func (c *CertificateSigningRequest) CheckSignature() error {
155	if err := c.buildPKCS10CertificateSigningRequest(); err != nil {
156		return err
157	}
158	return checkSignature(c.cr, c.cr.SignatureAlgorithm, c.cr.RawTBSCertificateRequest, c.cr.Signature)
159}
160
// checkSignature verifies that signature is a valid signature over signed,
// made with the public key carried in csr under algorithm algo (for example,
// the CSR's own self-signature).
//
// Only PKCS #1 v1.5 RSA and ECDSA signatures with SHA-1/256/384/512 are
// supported. Any other algorithm, an unavailable hash, or an unsupported key
// type yields x509.ErrUnsupportedAlgorithm. The public key type must belong to
// the same family as algo. An ECDSA signature must be exactly one DER-encoded
// (R, S) pair with positive values and no trailing bytes.
func checkSignature(csr *x509.CertificateRequest, algo x509.SignatureAlgorithm, signed, signature []byte) error {
	var hashType crypto.Hash
	rsaAlgo := false
	switch algo {
	case x509.SHA1WithRSA:
		hashType, rsaAlgo = crypto.SHA1, true
	case x509.SHA256WithRSA:
		hashType, rsaAlgo = crypto.SHA256, true
	case x509.SHA384WithRSA:
		hashType, rsaAlgo = crypto.SHA384, true
	case x509.SHA512WithRSA:
		hashType, rsaAlgo = crypto.SHA512, true
	case x509.ECDSAWithSHA1:
		hashType = crypto.SHA1
	case x509.ECDSAWithSHA256:
		hashType = crypto.SHA256
	case x509.ECDSAWithSHA384:
		hashType = crypto.SHA384
	case x509.ECDSAWithSHA512:
		hashType = crypto.SHA512
	default:
		return x509.ErrUnsupportedAlgorithm
	}
	if !hashType.Available() {
		return x509.ErrUnsupportedAlgorithm
	}
	h := hashType.New()
	h.Write(signed) // hash.Hash.Write never returns an error
	digest := h.Sum(nil)

	switch pub := csr.PublicKey.(type) {
	case *rsa.PublicKey:
		// Reject algorithm/key confusion, e.g. an ECDSA algorithm on an RSA key.
		if !rsaAlgo {
			return fmt.Errorf("x509: signature algorithm %v does not match RSA public key", algo)
		}
		return rsa.VerifyPKCS1v15(pub, hashType, digest, signature)
	case *ecdsa.PublicKey:
		if rsaAlgo {
			return fmt.Errorf("x509: signature algorithm %v does not match ECDSA public key", algo)
		}
		ecdsaSig := new(struct{ R, S *big.Int })
		rest, err := asn1.Unmarshal(signature, ecdsaSig)
		if err != nil {
			return err
		}
		// Bytes after the DER SEQUENCE would make the signature malleable.
		if len(rest) != 0 {
			return errors.New("x509: trailing data after ECDSA signature")
		}
		if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
			return errors.New("x509: ECDSA signature contained zero or negative values")
		}
		if !ecdsa.Verify(pub, digest, ecdsaSig.R, ecdsaSig.S) {
			return errors.New("x509: ECDSA verification failure")
		}
		return nil
	}
	return x509.ErrUnsupportedAlgorithm
}
201
202// Export returns PEM-format bytes
203func (c *CertificateSigningRequest) Export() ([]byte, error) {
204	pemBlock := &pem.Block{
205		Type:    csrPEMBlockType,
206		Headers: nil,
207		Bytes:   c.derBytes,
208	}
209
210	buf := new(bytes.Buffer)
211	if err := pem.Encode(buf, pemBlock); err != nil {
212		return nil, err
213	}
214	return buf.Bytes(), nil
215}
216