1/*
2 *
3 * Copyright 2018 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
// Package alts implements ALTS credential support for the gRPC library. ALTS
// credentials encapsulate all the state needed by a client to authenticate
// with a server using ALTS and make various assertions, e.g., about the
// client's identity, role, or whether it is authorized to make a particular
// call.
// This package is experimental.
24package alts
25
26import (
27	"context"
28	"errors"
29	"fmt"
30	"net"
31	"sync"
32	"time"
33
34	"google.golang.org/grpc/credentials"
35	core "google.golang.org/grpc/credentials/alts/internal"
36	"google.golang.org/grpc/credentials/alts/internal/handshaker"
37	"google.golang.org/grpc/credentials/alts/internal/handshaker/service"
38	altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
39	"google.golang.org/grpc/grpclog"
40	"google.golang.org/grpc/internal/googlecloud"
41)
42
const (
	// hypervisorHandshakerServiceAddress is the ALTS handshaker service
	// address used whenever no other handshaker address is configured.
	hypervisorHandshakerServiceAddress = "metadata.google.internal.:8080"
	// defaultTimeout bounds the duration of a server-side ALTS handshake.
	defaultTimeout = 30 * time.Second
)

// Lowest and highest RPC protocol versions (major, minor) that this
// implementation advertises and accepts during the handshake.
const (
	protocolVersionMinMajor = 2
	protocolVersionMinMinor = 1
	protocolVersionMaxMajor = 2
	protocolVersionMaxMinor = 1
)
56
57var (
58	vmOnGCP       bool
59	once          sync.Once
60	maxRPCVersion = &altspb.RpcProtocolVersions_Version{
61		Major: protocolVersionMaxMajor,
62		Minor: protocolVersionMaxMinor,
63	}
64	minRPCVersion = &altspb.RpcProtocolVersions_Version{
65		Major: protocolVersionMinMajor,
66		Minor: protocolVersionMinMinor,
67	}
68	// ErrUntrustedPlatform is returned from ClientHandshake and
69	// ServerHandshake is running on a platform where the trustworthiness of
70	// the handshaker service is not guaranteed.
71	ErrUntrustedPlatform = errors.New("ALTS: untrusted platform. ALTS is only supported on GCP")
72	logger               = grpclog.Component("alts")
73)
74
75// AuthInfo exposes security information from the ALTS handshake to the
76// application. This interface is to be implemented by ALTS. Users should not
77// need a brand new implementation of this interface. For situations like
78// testing, any new implementation should embed this interface. This allows
79// ALTS to add new methods to this interface.
80type AuthInfo interface {
81	// ApplicationProtocol returns application protocol negotiated for the
82	// ALTS connection.
83	ApplicationProtocol() string
84	// RecordProtocol returns the record protocol negotiated for the ALTS
85	// connection.
86	RecordProtocol() string
87	// SecurityLevel returns the security level of the created ALTS secure
88	// channel.
89	SecurityLevel() altspb.SecurityLevel
90	// PeerServiceAccount returns the peer service account.
91	PeerServiceAccount() string
92	// LocalServiceAccount returns the local service account.
93	LocalServiceAccount() string
94	// PeerRPCVersions returns the RPC version supported by the peer.
95	PeerRPCVersions() *altspb.RpcProtocolVersions
96}
97
// ClientOptions contains the client-side options of an ALTS channel. These
// options will be passed to the underlying ALTS handshaker.
type ClientOptions struct {
	// TargetServiceAccounts contains a list of expected target service
	// accounts. It is forwarded to the client handshaker as its
	// TargetServiceAccounts option.
	TargetServiceAccounts []string
	// HandshakerServiceAddress represents the ALTS handshaker gRPC service
	// address to connect to. When empty, the default hypervisor handshaker
	// service address is used.
	HandshakerServiceAddress string
}
108
109// DefaultClientOptions creates a new ClientOptions object with the default
110// values.
111func DefaultClientOptions() *ClientOptions {
112	return &ClientOptions{
113		HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
114	}
115}
116
// ServerOptions contains the server-side options of an ALTS channel. These
// options will be passed to the underlying ALTS handshaker.
type ServerOptions struct {
	// HandshakerServiceAddress represents the ALTS handshaker gRPC service
	// address to connect to. When empty, the default hypervisor handshaker
	// service address is used.
	HandshakerServiceAddress string
}
124
125// DefaultServerOptions creates a new ServerOptions object with the default
126// values.
127func DefaultServerOptions() *ServerOptions {
128	return &ServerOptions{
129		HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
130	}
131}
132
// altsTC is the credentials required for authenticating a connection using ALTS.
// It implements credentials.TransportCredentials interface.
type altsTC struct {
	info      *credentials.ProtocolInfo // returned by Info; ServerName set by OverrideServerName
	side      core.Side                 // core.ClientSide or core.ServerSide
	accounts  []string                  // target service accounts; nil for server creds
	hsAddress string                    // handshaker service address dialed per handshake
}
141
142// NewClientCreds constructs a client-side ALTS TransportCredentials object.
143func NewClientCreds(opts *ClientOptions) credentials.TransportCredentials {
144	return newALTS(core.ClientSide, opts.TargetServiceAccounts, opts.HandshakerServiceAddress)
145}
146
147// NewServerCreds constructs a server-side ALTS TransportCredentials object.
148func NewServerCreds(opts *ServerOptions) credentials.TransportCredentials {
149	return newALTS(core.ServerSide, nil, opts.HandshakerServiceAddress)
150}
151
152func newALTS(side core.Side, accounts []string, hsAddress string) credentials.TransportCredentials {
153	once.Do(func() {
154		vmOnGCP = googlecloud.OnGCE()
155	})
156	if hsAddress == "" {
157		hsAddress = hypervisorHandshakerServiceAddress
158	}
159	return &altsTC{
160		info: &credentials.ProtocolInfo{
161			SecurityProtocol: "alts",
162			SecurityVersion:  "1.0",
163		},
164		side:      side,
165		accounts:  accounts,
166		hsAddress: hsAddress,
167	}
168}
169
// ClientHandshake implements the client side handshake protocol.
//
// It returns ErrUntrustedPlatform unless vmOnGCP was set by newALTS. The
// method proceeds as follows:
//   - Dial the ALTS handshaker service at g.hsAddress.
//   - Run a client handshake over rawConn. The handshake targets addr and
//     g.accounts and advertises the [minRPCVersion, maxRPCVersion] range.
//   - Reject the connection if the peer's RPC protocol versions do not
//     overlap with that range.
//
// Cleanup relies on the named result err. Deferred functions cancel the
// child context and call chs.Close() only when an error is returned.
func (g *altsTC) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
	if !vmOnGCP {
		return nil, nil, ErrUntrustedPlatform
	}

	// Connecting to ALTS handshaker service.
	hsConn, err := service.Dial(g.hsAddress)
	if err != nil {
		return nil, nil, err
	}
	// Do not close hsConn since it is shared with other handshakes.

	// Possible context leak:
	// The cancel function for the child context we create will only be
	// called if a non-nil error is returned. On success the child context
	// is left uncanceled until the parent ctx is done.
	var cancel context.CancelFunc
	ctx, cancel = context.WithCancel(ctx)
	defer func() {
		if err != nil {
			cancel()
		}
	}()

	opts := handshaker.DefaultClientHandshakerOptions()
	opts.TargetName = addr
	opts.TargetServiceAccounts = g.accounts
	opts.RPCVersions = &altspb.RpcProtocolVersions{
		MaxRpcVersion: maxRPCVersion,
		MinRpcVersion: minRPCVersion,
	}
	chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, opts)
	if err != nil {
		return nil, nil, err
	}
	// Close the handshaker only if this call ends up failing.
	defer func() {
		if err != nil {
			chs.Close()
		}
	}()
	secConn, authInfo, err := chs.ClientHandshake(ctx)
	if err != nil {
		return nil, nil, err
	}
	altsAuthInfo, ok := authInfo.(AuthInfo)
	if !ok {
		return nil, nil, errors.New("client-side auth info is not of type alts.AuthInfo")
	}
	// The local and peer version ranges must intersect.
	match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
	if !match {
		return nil, nil, fmt.Errorf("server-side RPC versions are not compatible with this client, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
	}
	return secConn, authInfo, nil
}
224
225// ServerHandshake implements the server side ALTS handshaker.
226func (g *altsTC) ServerHandshake(rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
227	if !vmOnGCP {
228		return nil, nil, ErrUntrustedPlatform
229	}
230	// Connecting to ALTS handshaker service.
231	hsConn, err := service.Dial(g.hsAddress)
232	if err != nil {
233		return nil, nil, err
234	}
235	// Do not close hsConn since it's shared with other handshakes.
236
237	ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
238	defer cancel()
239	opts := handshaker.DefaultServerHandshakerOptions()
240	opts.RPCVersions = &altspb.RpcProtocolVersions{
241		MaxRpcVersion: maxRPCVersion,
242		MinRpcVersion: minRPCVersion,
243	}
244	shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, opts)
245	if err != nil {
246		return nil, nil, err
247	}
248	defer func() {
249		if err != nil {
250			shs.Close()
251		}
252	}()
253	secConn, authInfo, err := shs.ServerHandshake(ctx)
254	if err != nil {
255		return nil, nil, err
256	}
257	altsAuthInfo, ok := authInfo.(AuthInfo)
258	if !ok {
259		return nil, nil, errors.New("server-side auth info is not of type alts.AuthInfo")
260	}
261	match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
262	if !match {
263		return nil, nil, fmt.Errorf("client-side RPC versions is not compatible with this server, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
264	}
265	return secConn, authInfo, nil
266}
267
268func (g *altsTC) Info() credentials.ProtocolInfo {
269	return *g.info
270}
271
272func (g *altsTC) Clone() credentials.TransportCredentials {
273	info := *g.info
274	var accounts []string
275	if g.accounts != nil {
276		accounts = make([]string, len(g.accounts))
277		copy(accounts, g.accounts)
278	}
279	return &altsTC{
280		info:      &info,
281		side:      g.side,
282		hsAddress: g.hsAddress,
283		accounts:  accounts,
284	}
285}
286
287func (g *altsTC) OverrideServerName(serverNameOverride string) error {
288	g.info.ServerName = serverNameOverride
289	return nil
290}
291
292// compareRPCVersion returns 0 if v1 == v2, 1 if v1 > v2 and -1 if v1 < v2.
293func compareRPCVersions(v1, v2 *altspb.RpcProtocolVersions_Version) int {
294	switch {
295	case v1.GetMajor() > v2.GetMajor(),
296		v1.GetMajor() == v2.GetMajor() && v1.GetMinor() > v2.GetMinor():
297		return 1
298	case v1.GetMajor() < v2.GetMajor(),
299		v1.GetMajor() == v2.GetMajor() && v1.GetMinor() < v2.GetMinor():
300		return -1
301	}
302	return 0
303}
304
305// checkRPCVersions performs a version check between local and peer rpc protocol
306// versions. This function returns true if the check passes which means both
307// parties agreed on a common rpc protocol to use, and false otherwise. The
308// function also returns the highest common RPC protocol version both parties
309// agreed on.
310func checkRPCVersions(local, peer *altspb.RpcProtocolVersions) (bool, *altspb.RpcProtocolVersions_Version) {
311	if local == nil || peer == nil {
312		logger.Error("invalid checkRPCVersions argument, either local or peer is nil.")
313		return false, nil
314	}
315
316	// maxCommonVersion is MIN(local.max, peer.max).
317	maxCommonVersion := local.GetMaxRpcVersion()
318	if compareRPCVersions(local.GetMaxRpcVersion(), peer.GetMaxRpcVersion()) > 0 {
319		maxCommonVersion = peer.GetMaxRpcVersion()
320	}
321
322	// minCommonVersion is MAX(local.min, peer.min).
323	minCommonVersion := peer.GetMinRpcVersion()
324	if compareRPCVersions(local.GetMinRpcVersion(), peer.GetMinRpcVersion()) > 0 {
325		minCommonVersion = local.GetMinRpcVersion()
326	}
327
328	if compareRPCVersions(maxCommonVersion, minCommonVersion) < 0 {
329		return false, nil
330	}
331	return true, maxCommonVersion
332}
333