1// Copyright 2019 The etcd Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
// Package credentials implements the gRPC credential interfaces with etcd-specific logic,
// e.g., a client handshake with a custom authority parameter.
17package credentials
18
19import (
20	"context"
21	"crypto/tls"
22	"net"
23	"sync"
24
25	"go.etcd.io/etcd/clientv3/balancer/resolver/endpoint"
26	"go.etcd.io/etcd/etcdserver/api/v3rpc/rpctypes"
27	grpccredentials "google.golang.org/grpc/credentials"
28)
29
// Config defines gRPC credential configuration.
type Config struct {
	// TLSConfig is handed to newTransportCredential, which wraps it with
	// grpccredentials.NewTLS to build the bundle's transport credentials.
	TLSConfig *tls.Config
}
34
// Bundle defines gRPC credential interface.
// It extends grpccredentials.Bundle with UpdateAuthToken, which in the
// implementation returned by NewBundle replaces the auth token supplied by
// the bundle's per-RPC credentials.
type Bundle interface {
	grpccredentials.Bundle
	UpdateAuthToken(token string)
}
40
41// NewBundle constructs a new gRPC credential bundle.
42func NewBundle(cfg Config) Bundle {
43	return &bundle{
44		tc: newTransportCredential(cfg.TLSConfig),
45		rc: newPerRPCCredential(),
46	}
47}
48
// bundle implements "grpccredentials.Bundle" interface.
type bundle struct {
	// tc is returned by TransportCredentials.
	tc *transportCredential
	// rc is returned by PerRPCCredentials and updated by UpdateAuthToken.
	rc *perRPCCredential
}
54
55func (b *bundle) TransportCredentials() grpccredentials.TransportCredentials {
56	return b.tc
57}
58
59func (b *bundle) PerRPCCredentials() grpccredentials.PerRPCCredentials {
60	return b.rc
61}
62
63func (b *bundle) NewWithMode(mode string) (grpccredentials.Bundle, error) {
64	// no-op
65	return nil, nil
66}
67
// transportCredential implements "grpccredentials.TransportCredentials" interface.
// transportCredential wraps TransportCredentials to track which
// addresses are dialed for which endpoints, and then sets the authority used when checking the endpoint's cert to the
// hostname or IP of the dialed endpoint.
// This is a workaround for a gRPC load balancer issue. gRPC uses the dialed target's service name as the authority when
// checking all endpoint certs, which does not work for etcd servers using their hostname or IP as the Subject Alternative Name
// in their TLS certs.
// To enable, include both WithTransportCredentials(creds) and WithContextDialer(creds.Dialer)
// when dialing.
type transportCredential struct {
	// gtc is the wrapped TLS transport credential that performs the actual handshakes.
	gtc grpccredentials.TransportCredentials
	// mu guards addrToEndpoint, which is written by Dialer and read by
	// ClientHandshake and Clone.
	mu  sync.Mutex
	// addrToEndpoint maps from the connection addresses that are dialed to the hostname or IP of the
	// endpoint provided to the dialer when dialing
	addrToEndpoint map[string]string
}
84
85func newTransportCredential(cfg *tls.Config) *transportCredential {
86	return &transportCredential{
87		gtc:            grpccredentials.NewTLS(cfg),
88		addrToEndpoint: map[string]string{},
89	}
90}
91
92func (tc *transportCredential) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, grpccredentials.AuthInfo, error) {
93	// Set the authority when checking the endpoint's cert to the hostname or IP of the dialed endpoint
94	tc.mu.Lock()
95	dialEp, ok := tc.addrToEndpoint[rawConn.RemoteAddr().String()]
96	tc.mu.Unlock()
97	if ok {
98		_, host, _ := endpoint.ParseEndpoint(dialEp)
99		authority = host
100	}
101	return tc.gtc.ClientHandshake(ctx, authority, rawConn)
102}
103
// isIP reports whether ep is a literal IPv4 or IPv6 address as accepted by
// net.ParseIP; hostnames and "host:port" strings are not IPs.
func isIP(ep string) bool {
	if ip := net.ParseIP(ep); ip == nil {
		return false
	}
	return true
}
108
109func (tc *transportCredential) ServerHandshake(rawConn net.Conn) (net.Conn, grpccredentials.AuthInfo, error) {
110	return tc.gtc.ServerHandshake(rawConn)
111}
112
113func (tc *transportCredential) Info() grpccredentials.ProtocolInfo {
114	return tc.gtc.Info()
115}
116
117func (tc *transportCredential) Clone() grpccredentials.TransportCredentials {
118	copy := map[string]string{}
119	tc.mu.Lock()
120	for k, v := range tc.addrToEndpoint {
121		copy[k] = v
122	}
123	tc.mu.Unlock()
124	return &transportCredential{
125		gtc:            tc.gtc.Clone(),
126		addrToEndpoint: copy,
127	}
128}
129
130func (tc *transportCredential) OverrideServerName(serverNameOverride string) error {
131	return tc.gtc.OverrideServerName(serverNameOverride)
132}
133
134func (tc *transportCredential) Dialer(ctx context.Context, dialEp string) (net.Conn, error) {
135	// Keep track of which addresses are dialed for which endpoints
136	conn, err := endpoint.Dialer(ctx, dialEp)
137	if conn != nil {
138		tc.mu.Lock()
139		tc.addrToEndpoint[conn.RemoteAddr().String()] = dialEp
140		tc.mu.Unlock()
141	}
142	return conn, err
143}
144
// perRPCCredential implements "grpccredentials.PerRPCCredentials" interface.
type perRPCCredential struct {
	// authToken is the token reported by GetRequestMetadata; it is set by
	// UpdateAuthToken.
	authToken string
	// authTokenMu guards authToken (read-locked for reads, write-locked for updates).
	authTokenMu sync.RWMutex
}
150
151func newPerRPCCredential() *perRPCCredential { return &perRPCCredential{} }
152
153func (rc *perRPCCredential) RequireTransportSecurity() bool { return false }
154
155func (rc *perRPCCredential) GetRequestMetadata(ctx context.Context, s ...string) (map[string]string, error) {
156	rc.authTokenMu.RLock()
157	authToken := rc.authToken
158	rc.authTokenMu.RUnlock()
159	return map[string]string{rpctypes.TokenFieldNameGRPC: authToken}, nil
160}
161
162func (b *bundle) UpdateAuthToken(token string) {
163	if b.rc == nil {
164		return
165	}
166	b.rc.UpdateAuthToken(token)
167}
168
169func (rc *perRPCCredential) UpdateAuthToken(token string) {
170	rc.authTokenMu.Lock()
171	rc.authToken = token
172	rc.authTokenMu.Unlock()
173}
174