1/*
2 *
3 * Copyright 2018 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
// Package alts implements ALTS credential support for the gRPC library. ALTS
// credentials encapsulate all the state needed by a client to authenticate
// with a server using ALTS and make various assertions, e.g., about the
// client's identity, role, or whether it is authorized to make a particular
// call.
// This package is experimental.
24package alts
25
26import (
27	"context"
28	"errors"
29	"fmt"
30	"net"
31	"sync"
32	"time"
33
34	"google.golang.org/grpc/credentials"
35	core "google.golang.org/grpc/credentials/alts/internal"
36	"google.golang.org/grpc/credentials/alts/internal/handshaker"
37	"google.golang.org/grpc/credentials/alts/internal/handshaker/service"
38	altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
39	"google.golang.org/grpc/grpclog"
40)
41
const (
	// hypervisorHandshakerServiceAddress represents the default ALTS gRPC
	// handshaker service address in the hypervisor. It is the value set by
	// DefaultClientOptions/DefaultServerOptions and the fallback used by
	// newALTS when an empty handshaker service address is supplied.
	hypervisorHandshakerServiceAddress = "metadata.google.internal.:8080"
	// defaultTimeout specifies the server handshake timeout. ServerHandshake
	// is not given a context by its caller, so it bounds the handshake with
	// this value; ClientHandshake relies on the caller's context instead.
	defaultTimeout = 30.0 * time.Second
	// The following constants specify the minimum and maximum acceptable
	// protocol versions. They populate maxRPCVersion and minRPCVersion, which
	// are handed to the handshaker as the local RPC protocol version range
	// and checked against the peer's range after each handshake.
	protocolVersionMaxMajor = 2
	protocolVersionMaxMinor = 1
	protocolVersionMinMajor = 2
	protocolVersionMinMinor = 1
)
55
var (
	// once guards the one-time detection (in newALTS) of whether this process
	// runs on GCP. maxRPCVersion and minRPCVersion are the local RPC protocol
	// version bounds that both handshake paths pass to the handshaker.
	once          sync.Once
	maxRPCVersion = &altspb.RpcProtocolVersions_Version{
		Major: protocolVersionMaxMajor,
		Minor: protocolVersionMaxMinor,
	}
	minRPCVersion = &altspb.RpcProtocolVersions_Version{
		Major: protocolVersionMinMajor,
		Minor: protocolVersionMinMinor,
	}
	// ErrUntrustedPlatform is returned from ClientHandshake and
	// ServerHandshake when they are called on a platform where the
	// trustworthiness of the handshaker service is not guaranteed, i.e. when
	// the process was not detected to be running on GCP.
	ErrUntrustedPlatform = errors.New("ALTS: untrusted platform. ALTS is only supported on GCP")
	logger               = grpclog.Component("alts")
)
72
// AuthInfo exposes security information from the ALTS handshake to the
// application. This interface is to be implemented by ALTS. Users should not
// need a brand new implementation of this interface. For situations like
// testing, any new implementation should embed this interface. This allows
// ALTS to add new methods to this interface.
//
// ClientHandshake and ServerHandshake in this package reject a handshake whose
// resulting credentials.AuthInfo does not implement AuthInfo.
type AuthInfo interface {
	// ApplicationProtocol returns the application protocol negotiated for
	// the ALTS connection.
	ApplicationProtocol() string
	// RecordProtocol returns the record protocol negotiated for the ALTS
	// connection.
	RecordProtocol() string
	// SecurityLevel returns the security level of the created ALTS secure
	// channel.
	SecurityLevel() altspb.SecurityLevel
	// PeerServiceAccount returns the peer service account.
	PeerServiceAccount() string
	// LocalServiceAccount returns the local service account.
	LocalServiceAccount() string
	// PeerRPCVersions returns the range of RPC protocol versions supported
	// by the peer. The handshake in this package is rejected when this range
	// does not overlap with the local range.
	PeerRPCVersions() *altspb.RpcProtocolVersions
}
95
// ClientOptions contains the client-side options of an ALTS channel. These
// options will be passed to the underlying ALTS handshaker.
type ClientOptions struct {
	// TargetServiceAccounts contains a list of expected target service
	// accounts. They are forwarded to the client handshaker's
	// TargetServiceAccounts option on every ClientHandshake.
	TargetServiceAccounts []string
	// HandshakerServiceAddress represents the ALTS handshaker gRPC service
	// address to connect to. If empty, the default hypervisor handshaker
	// service address is used.
	HandshakerServiceAddress string
}
106
107// DefaultClientOptions creates a new ClientOptions object with the default
108// values.
109func DefaultClientOptions() *ClientOptions {
110	return &ClientOptions{
111		HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
112	}
113}
114
// ServerOptions contains the server-side options of an ALTS channel. These
// options will be passed to the underlying ALTS handshaker.
type ServerOptions struct {
	// HandshakerServiceAddress represents the ALTS handshaker gRPC service
	// address to connect to. If empty, the default hypervisor handshaker
	// service address is used.
	HandshakerServiceAddress string
}
122
123// DefaultServerOptions creates a new ServerOptions object with the default
124// values.
125func DefaultServerOptions() *ServerOptions {
126	return &ServerOptions{
127		HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
128	}
129}
130
// altsTC is the credentials required for authenticating a connection using ALTS.
// It implements credentials.TransportCredentials interface.
//
// Fields: info is the protocol information returned (by value) from Info;
// side records whether the credentials were built for the client or the
// server; accounts holds the target service accounts (set only via
// NewClientCreds and consulted only by ClientHandshake); hsAddress is the
// handshaker service address dialed for every handshake.
type altsTC struct {
	info      *credentials.ProtocolInfo
	side      core.Side
	accounts  []string
	hsAddress string
}
139
140// NewClientCreds constructs a client-side ALTS TransportCredentials object.
141func NewClientCreds(opts *ClientOptions) credentials.TransportCredentials {
142	return newALTS(core.ClientSide, opts.TargetServiceAccounts, opts.HandshakerServiceAddress)
143}
144
145// NewServerCreds constructs a server-side ALTS TransportCredentials object.
146func NewServerCreds(opts *ServerOptions) credentials.TransportCredentials {
147	return newALTS(core.ServerSide, nil, opts.HandshakerServiceAddress)
148}
149
150func newALTS(side core.Side, accounts []string, hsAddress string) credentials.TransportCredentials {
151	once.Do(func() {
152		vmOnGCP = isRunningOnGCP()
153	})
154
155	if hsAddress == "" {
156		hsAddress = hypervisorHandshakerServiceAddress
157	}
158	return &altsTC{
159		info: &credentials.ProtocolInfo{
160			SecurityProtocol: "alts",
161			SecurityVersion:  "1.0",
162		},
163		side:      side,
164		accounts:  accounts,
165		hsAddress: hsAddress,
166	}
167}
168
// ClientHandshake implements the client side handshake protocol.
//
// It fails with ErrUntrustedPlatform unless the process was detected to be on
// GCP. Otherwise it obtains the (shared) handshaker service connection via
// service.Dial, runs the ALTS client handshake over rawConn using addr as the
// target name and the configured target service accounts, and then checks
// that the peer's RPC protocol version range overlaps with the local one. On
// success it returns the secure connection and the ALTS AuthInfo.
//
// err is a named result so that the deferred cleanups below can observe the
// final error of every return path.
func (g *altsTC) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
	if !vmOnGCP {
		return nil, nil, ErrUntrustedPlatform
	}

	// Connecting to ALTS handshaker service.
	hsConn, err := service.Dial(g.hsAddress)
	if err != nil {
		return nil, nil, err
	}
	// Do not close hsConn since it is shared with other handshakes.

	// Possible context leak:
	// The cancel function for the child context we create will only be
	// called if a non-nil error is returned. On success the child context is
	// only released when the parent ctx is done.
	// NOTE(review): presumably this is because the handshaker may still use
	// ctx after a successful return; confirm before adding an unconditional
	// cancel.
	var cancel context.CancelFunc
	ctx, cancel = context.WithCancel(ctx)
	defer func() {
		if err != nil {
			cancel()
		}
	}()

	opts := handshaker.DefaultClientHandshakerOptions()
	opts.TargetName = addr
	opts.TargetServiceAccounts = g.accounts
	opts.RPCVersions = &altspb.RpcProtocolVersions{
		MaxRpcVersion: maxRPCVersion,
		MinRpcVersion: minRPCVersion,
	}
	chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, opts)
	if err != nil {
		return nil, nil, err
	}
	// Release the handshaker only on failure; this defer runs before the
	// cancel defer above (LIFO).
	defer func() {
		if err != nil {
			chs.Close()
		}
	}()
	secConn, authInfo, err := chs.ClientHandshake(ctx)
	if err != nil {
		return nil, nil, err
	}
	altsAuthInfo, ok := authInfo.(AuthInfo)
	if !ok {
		return nil, nil, errors.New("client-side auth info is not of type alts.AuthInfo")
	}
	// Reject peers whose supported RPC version range does not overlap ours.
	match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
	if !match {
		return nil, nil, fmt.Errorf("server-side RPC versions are not compatible with this client, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
	}
	return secConn, authInfo, nil
}
223
// ServerHandshake implements the server side ALTS handshaker.
//
// It fails with ErrUntrustedPlatform unless the process was detected to be on
// GCP. Otherwise it obtains the (shared) handshaker service connection via
// service.Dial, runs the ALTS server handshake over rawConn bounded by
// defaultTimeout, and then checks that the peer's RPC protocol version range
// overlaps with the local one. On success it returns the secure connection
// and the ALTS AuthInfo.
//
// err is a named result so that the deferred cleanup below can observe the
// final error of every return path.
func (g *altsTC) ServerHandshake(rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
	if !vmOnGCP {
		return nil, nil, ErrUntrustedPlatform
	}
	// Connecting to ALTS handshaker service.
	hsConn, err := service.Dial(g.hsAddress)
	if err != nil {
		return nil, nil, err
	}
	// Do not close hsConn since it's shared with other handshakes.

	// No caller context is available here, so bound the handshake with a
	// fixed timeout. Unlike the client side, the context is always cancelled
	// when this method returns.
	ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
	defer cancel()
	opts := handshaker.DefaultServerHandshakerOptions()
	opts.RPCVersions = &altspb.RpcProtocolVersions{
		MaxRpcVersion: maxRPCVersion,
		MinRpcVersion: minRPCVersion,
	}
	shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, opts)
	if err != nil {
		return nil, nil, err
	}
	// Release the handshaker only on failure.
	defer func() {
		if err != nil {
			shs.Close()
		}
	}()
	secConn, authInfo, err := shs.ServerHandshake(ctx)
	if err != nil {
		return nil, nil, err
	}
	altsAuthInfo, ok := authInfo.(AuthInfo)
	if !ok {
		return nil, nil, errors.New("server-side auth info is not of type alts.AuthInfo")
	}
	// Reject peers whose supported RPC version range does not overlap ours.
	match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
	if !match {
		return nil, nil, fmt.Errorf("client-side RPC versions is not compatible with this server, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
	}
	return secConn, authInfo, nil
}
266
267func (g *altsTC) Info() credentials.ProtocolInfo {
268	return *g.info
269}
270
271func (g *altsTC) Clone() credentials.TransportCredentials {
272	info := *g.info
273	var accounts []string
274	if g.accounts != nil {
275		accounts = make([]string, len(g.accounts))
276		copy(accounts, g.accounts)
277	}
278	return &altsTC{
279		info:      &info,
280		side:      g.side,
281		hsAddress: g.hsAddress,
282		accounts:  accounts,
283	}
284}
285
286func (g *altsTC) OverrideServerName(serverNameOverride string) error {
287	g.info.ServerName = serverNameOverride
288	return nil
289}
290
291// compareRPCVersion returns 0 if v1 == v2, 1 if v1 > v2 and -1 if v1 < v2.
292func compareRPCVersions(v1, v2 *altspb.RpcProtocolVersions_Version) int {
293	switch {
294	case v1.GetMajor() > v2.GetMajor(),
295		v1.GetMajor() == v2.GetMajor() && v1.GetMinor() > v2.GetMinor():
296		return 1
297	case v1.GetMajor() < v2.GetMajor(),
298		v1.GetMajor() == v2.GetMajor() && v1.GetMinor() < v2.GetMinor():
299		return -1
300	}
301	return 0
302}
303
304// checkRPCVersions performs a version check between local and peer rpc protocol
305// versions. This function returns true if the check passes which means both
306// parties agreed on a common rpc protocol to use, and false otherwise. The
307// function also returns the highest common RPC protocol version both parties
308// agreed on.
309func checkRPCVersions(local, peer *altspb.RpcProtocolVersions) (bool, *altspb.RpcProtocolVersions_Version) {
310	if local == nil || peer == nil {
311		logger.Error("invalid checkRPCVersions argument, either local or peer is nil.")
312		return false, nil
313	}
314
315	// maxCommonVersion is MIN(local.max, peer.max).
316	maxCommonVersion := local.GetMaxRpcVersion()
317	if compareRPCVersions(local.GetMaxRpcVersion(), peer.GetMaxRpcVersion()) > 0 {
318		maxCommonVersion = peer.GetMaxRpcVersion()
319	}
320
321	// minCommonVersion is MAX(local.min, peer.min).
322	minCommonVersion := peer.GetMinRpcVersion()
323	if compareRPCVersions(local.GetMinRpcVersion(), peer.GetMinRpcVersion()) > 0 {
324		minCommonVersion = local.GetMinRpcVersion()
325	}
326
327	if compareRPCVersions(maxCommonVersion, minCommonVersion) < 0 {
328		return false, nil
329	}
330	return true, maxCommonVersion
331}
332