1/*
2 *
3 * Copyright 2014 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package credentials
20
21import (
22	"context"
23	"crypto/tls"
24	"crypto/x509"
25	"fmt"
26	"io/ioutil"
27	"net"
28
29	"google.golang.org/grpc/credentials/internal"
30)
31
// TLSInfo contains the auth information for a TLS authenticated connection.
// It implements the AuthInfo interface.
//
// NOTE: do not reorder the fields; values of this type may be built with
// positional composite literals such as TLSInfo{state, CommonAuthInfo{...}}.
type TLSInfo struct {
	// State is the TLS connection state captured after a successful handshake.
	State tls.ConnectionState
	// CommonAuthInfo is embedded; the handshakes in this package populate it
	// with PrivacyAndIntegrity.
	CommonAuthInfo
}
38
39// AuthType returns the type of TLSInfo as a string.
40func (t TLSInfo) AuthType() string {
41	return "tls"
42}
43
44// GetSecurityValue returns security info requested by channelz.
45func (t TLSInfo) GetSecurityValue() ChannelzSecurityValue {
46	v := &TLSChannelzSecurityValue{
47		StandardName: cipherSuiteLookup[t.State.CipherSuite],
48	}
49	// Currently there's no way to get LocalCertificate info from tls package.
50	if len(t.State.PeerCertificates) > 0 {
51		v.RemoteCertificate = t.State.PeerCertificates[0].Raw
52	}
53	return v
54}
55
// tlsCreds is the credentials required for authenticating a connection using TLS.
// Values are created by NewTLS, which stores its own copy of the caller's
// config with "h2" ensured in NextProtos.
type tlsCreds struct {
	// TLS configuration shared by every handshake made with these credentials.
	// ClientHandshake clones it per call, ServerHandshake uses it directly, and
	// OverrideServerName mutates it in place.
	config *tls.Config
}
61
62func (c tlsCreds) Info() ProtocolInfo {
63	return ProtocolInfo{
64		SecurityProtocol: "tls",
65		SecurityVersion:  "1.2",
66		ServerName:       c.config.ServerName,
67	}
68}
69
70func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) {
71	// use local cfg to avoid clobbering ServerName if using multiple endpoints
72	cfg := cloneTLSConfig(c.config)
73	if cfg.ServerName == "" {
74		serverName, _, err := net.SplitHostPort(authority)
75		if err != nil {
76			// If the authority had no host port or if the authority cannot be parsed, use it as-is.
77			serverName = authority
78		}
79		cfg.ServerName = serverName
80	}
81	conn := tls.Client(rawConn, cfg)
82	errChannel := make(chan error, 1)
83	go func() {
84		errChannel <- conn.Handshake()
85		close(errChannel)
86	}()
87	select {
88	case err := <-errChannel:
89		if err != nil {
90			conn.Close()
91			return nil, nil, err
92		}
93	case <-ctx.Done():
94		conn.Close()
95		return nil, nil, ctx.Err()
96	}
97	return internal.WrapSyscallConn(rawConn, conn), TLSInfo{conn.ConnectionState(), CommonAuthInfo{PrivacyAndIntegrity}}, nil
98}
99
100func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) {
101	conn := tls.Server(rawConn, c.config)
102	if err := conn.Handshake(); err != nil {
103		conn.Close()
104		return nil, nil, err
105	}
106	return internal.WrapSyscallConn(rawConn, conn), TLSInfo{conn.ConnectionState(), CommonAuthInfo{PrivacyAndIntegrity}}, nil
107}
108
109func (c *tlsCreds) Clone() TransportCredentials {
110	return NewTLS(c.config)
111}
112
113func (c *tlsCreds) OverrideServerName(serverNameOverride string) error {
114	c.config.ServerName = serverNameOverride
115	return nil
116}
117
// alpnProtoStrH2 is the ALPN protocol ID that these credentials ensure is offered.
const alpnProtoStrH2 = "h2"

// appendH2ToNextProtos guarantees alpnProtoStrH2 is in the protocol list.
// If ps already contains it, ps itself is returned. Otherwise a newly
// allocated slice holding ps followed by alpnProtoStrH2 is returned, so the
// caller's backing array is never written to.
func appendH2ToNextProtos(ps []string) []string {
	hasH2 := false
	for i := range ps {
		if ps[i] == alpnProtoStrH2 {
			hasH2 = true
			break
		}
	}
	if hasH2 {
		return ps
	}
	out := make([]string, len(ps), len(ps)+1)
	copy(out, ps)
	return append(out, alpnProtoStrH2)
}
130
131// NewTLS uses c to construct a TransportCredentials based on TLS.
132func NewTLS(c *tls.Config) TransportCredentials {
133	tc := &tlsCreds{cloneTLSConfig(c)}
134	tc.config.NextProtos = appendH2ToNextProtos(tc.config.NextProtos)
135	return tc
136}
137
138// NewClientTLSFromCert constructs TLS credentials from the input certificate for client.
139// serverNameOverride is for testing only. If set to a non empty string,
140// it will override the virtual host name of authority (e.g. :authority header field) in requests.
141func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverride string) TransportCredentials {
142	return NewTLS(&tls.Config{ServerName: serverNameOverride, RootCAs: cp})
143}
144
145// NewClientTLSFromFile constructs TLS credentials from the input certificate file for client.
146// serverNameOverride is for testing only. If set to a non empty string,
147// it will override the virtual host name of authority (e.g. :authority header field) in requests.
148func NewClientTLSFromFile(certFile, serverNameOverride string) (TransportCredentials, error) {
149	b, err := ioutil.ReadFile(certFile)
150	if err != nil {
151		return nil, err
152	}
153	cp := x509.NewCertPool()
154	if !cp.AppendCertsFromPEM(b) {
155		return nil, fmt.Errorf("credentials: failed to append certificates")
156	}
157	return NewTLS(&tls.Config{ServerName: serverNameOverride, RootCAs: cp}), nil
158}
159
160// NewServerTLSFromCert constructs TLS credentials from the input certificate for server.
161func NewServerTLSFromCert(cert *tls.Certificate) TransportCredentials {
162	return NewTLS(&tls.Config{Certificates: []tls.Certificate{*cert}})
163}
164
165// NewServerTLSFromFile constructs TLS credentials from the input certificate file and key
166// file for server.
167func NewServerTLSFromFile(certFile, keyFile string) (TransportCredentials, error) {
168	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
169	if err != nil {
170		return nil, err
171	}
172	return NewTLS(&tls.Config{Certificates: []tls.Certificate{cert}}), nil
173}
174
// TLSChannelzSecurityValue defines the struct that TLS protocol should return
// from GetSecurityValue(), containing security info like cipher and certificate used.
//
// Fields:
//   - StandardName: the cipher suite name from cipherSuiteLookup.
//     TLSInfo.GetSecurityValue leaves it empty for suites missing from that table.
//   - LocalCertificate: raw bytes of the local certificate. TLSInfo.GetSecurityValue
//     never sets it, because the tls package does not expose that certificate.
//   - RemoteCertificate: the Raw bytes of the peer's first certificate. It is
//     set only when the peer presented at least one certificate.
//
// This API is EXPERIMENTAL.
type TLSChannelzSecurityValue struct {
	ChannelzSecurityValue
	StandardName      string
	LocalCertificate  []byte
	RemoteCertificate []byte
}
185
// cipherSuiteLookup maps crypto/tls cipher suite IDs to their standard names,
// for reporting through TLSInfo.GetSecurityValue. Suites missing from the
// table, including the TLS 1.3 suites, are looked up as the empty string.
//
// NOTE(review): Go 1.14+ provides tls.CipherSuiteName, which could replace
// this table once the minimum supported Go version allows it.
var cipherSuiteLookup = map[uint16]string{
	tls.TLS_RSA_WITH_RC4_128_SHA:                "TLS_RSA_WITH_RC4_128_SHA",
	tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA:           "TLS_RSA_WITH_3DES_EDE_CBC_SHA",
	tls.TLS_RSA_WITH_AES_128_CBC_SHA:            "TLS_RSA_WITH_AES_128_CBC_SHA",
	tls.TLS_RSA_WITH_AES_256_CBC_SHA:            "TLS_RSA_WITH_AES_256_CBC_SHA",
	tls.TLS_RSA_WITH_AES_128_GCM_SHA256:         "TLS_RSA_WITH_AES_128_GCM_SHA256",
	tls.TLS_RSA_WITH_AES_256_GCM_SHA384:         "TLS_RSA_WITH_AES_256_GCM_SHA384",
	tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA:        "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA",
	tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA:    "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA",
	tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA:    "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA",
	tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA:          "TLS_ECDHE_RSA_WITH_RC4_128_SHA",
	tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA:     "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA",
	tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA:      "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA",
	tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:      "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA",
	tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:   "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
	tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
	tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384:   "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
	tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
	tls.TLS_FALLBACK_SCSV:                       "TLS_FALLBACK_SCSV",
	tls.TLS_RSA_WITH_AES_128_CBC_SHA256:         "TLS_RSA_WITH_AES_128_CBC_SHA256",
	tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256: "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256",
	tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256:   "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256",
	tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305:    "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305",
	tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305:  "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305",
}
211
// cloneTLSConfig returns cfg.Clone(), which is a shallow clone of cfg, so
// slice and map fields still share backing storage with cfg.
//
// A nil cfg yields a new zero-valued tls.Config rather than nil, so callers
// can always assign fields on the result.
//
// TODO: inline this function if possible.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
	if cfg != nil {
		return cfg.Clone()
	}
	return &tls.Config{}
}
226