1/*
2 *
3 * Copyright 2018 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
// Package alts implements ALTS credential support for the gRPC library. It
// encapsulates all the state needed by a client to authenticate with a server
// using ALTS and make various assertions, e.g., about the client's identity,
// role, or whether it is authorized to make a particular call.
// This package is experimental.
24package alts
25
26import (
27	"context"
28	"errors"
29	"fmt"
30	"net"
31	"sync"
32	"time"
33
34	"google.golang.org/grpc/credentials"
35	core "google.golang.org/grpc/credentials/alts/internal"
36	"google.golang.org/grpc/credentials/alts/internal/handshaker"
37	"google.golang.org/grpc/credentials/alts/internal/handshaker/service"
38	altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
39	"google.golang.org/grpc/grpclog"
40)
41
// Handshaker service defaults.
const (
	// hypervisorHandshakerServiceAddress is the default address of the ALTS
	// gRPC handshaker service in the hypervisor.
	hypervisorHandshakerServiceAddress = "metadata.google.internal:8080"
	// defaultTimeout is the deadline applied to a server-side handshake.
	defaultTimeout = 30 * time.Second
)

// Bounds of the RPC protocol version range this package accepts, expressed
// as (major, minor) pairs.
const (
	protocolVersionMinMajor = 2
	protocolVersionMinMinor = 1
	protocolVersionMaxMajor = 2
	protocolVersionMaxMinor = 1
)
55
var (
	// once ensures the platform check in newALTS runs a single time.
	once          sync.Once
	// maxRPCVersion is the highest RPC protocol version advertised to the
	// peer during a handshake.
	maxRPCVersion = &altspb.RpcProtocolVersions_Version{
		Major: protocolVersionMaxMajor,
		Minor: protocolVersionMaxMinor,
	}
	// minRPCVersion is the lowest RPC protocol version advertised to the
	// peer during a handshake.
	minRPCVersion = &altspb.RpcProtocolVersions_Version{
		Major: protocolVersionMinMajor,
		Minor: protocolVersionMinMinor,
	}
	// ErrUntrustedPlatform is returned from ClientHandshake and
	// ServerHandshake when running on a platform where the trustworthiness
	// of the handshaker service is not guaranteed.
	ErrUntrustedPlatform = errors.New("ALTS: untrusted platform. ALTS is only supported on GCP")
)
71
// AuthInfo exposes security information from the ALTS handshake to the
// application. This interface is implemented by ALTS itself; users should not
// need to write a brand new implementation. For situations like testing, any
// new implementation should embed this interface so that ALTS can add new
// methods to it without breaking that implementation.
type AuthInfo interface {
	// ApplicationProtocol returns the application protocol negotiated for
	// the ALTS connection.
	ApplicationProtocol() string
	// RecordProtocol returns the record protocol negotiated for the ALTS
	// connection.
	RecordProtocol() string
	// SecurityLevel returns the security level of the created ALTS secure
	// channel.
	SecurityLevel() altspb.SecurityLevel
	// PeerServiceAccount returns the peer service account.
	PeerServiceAccount() string
	// LocalServiceAccount returns the local service account.
	LocalServiceAccount() string
	// PeerRPCVersions returns the RPC protocol versions supported by the
	// peer.
	PeerRPCVersions() *altspb.RpcProtocolVersions
}
94
// ClientOptions contains the client-side options of an ALTS channel. These
// options will be passed to the underlying ALTS handshaker.
type ClientOptions struct {
	// TargetServiceAccounts contains a list of expected target service
	// accounts. It is forwarded to the client handshaker options.
	TargetServiceAccounts []string
	// HandshakerServiceAddress represents the ALTS handshaker gRPC service
	// address to connect to. When empty, the default hypervisor handshaker
	// service address is used instead.
	HandshakerServiceAddress string
}
105
106// DefaultClientOptions creates a new ClientOptions object with the default
107// values.
108func DefaultClientOptions() *ClientOptions {
109	return &ClientOptions{
110		HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
111	}
112}
113
// ServerOptions contains the server-side options of an ALTS channel. These
// options will be passed to the underlying ALTS handshaker.
type ServerOptions struct {
	// HandshakerServiceAddress represents the ALTS handshaker gRPC service
	// address to connect to. When empty, the default hypervisor handshaker
	// service address is used instead.
	HandshakerServiceAddress string
}
121
122// DefaultServerOptions creates a new ServerOptions object with the default
123// values.
124func DefaultServerOptions() *ServerOptions {
125	return &ServerOptions{
126		HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
127	}
128}
129
// altsTC is the credentials required for authenticating a connection using ALTS.
// It implements credentials.TransportCredentials interface.
type altsTC struct {
	// info is the protocol information reported by Info; its ServerName can
	// be replaced through OverrideServerName.
	info      *credentials.ProtocolInfo
	// side records whether these credentials were built for the client or
	// the server.
	side      core.Side
	// accounts holds the expected target service accounts handed to the
	// client handshaker; it is nil for server credentials.
	accounts  []string
	// hsAddress is the ALTS handshaker service address dialed for every
	// handshake.
	hsAddress string
}
138
139// NewClientCreds constructs a client-side ALTS TransportCredentials object.
140func NewClientCreds(opts *ClientOptions) credentials.TransportCredentials {
141	return newALTS(core.ClientSide, opts.TargetServiceAccounts, opts.HandshakerServiceAddress)
142}
143
144// NewServerCreds constructs a server-side ALTS TransportCredentials object.
145func NewServerCreds(opts *ServerOptions) credentials.TransportCredentials {
146	return newALTS(core.ServerSide, nil, opts.HandshakerServiceAddress)
147}
148
149func newALTS(side core.Side, accounts []string, hsAddress string) credentials.TransportCredentials {
150	once.Do(func() {
151		vmOnGCP = isRunningOnGCP()
152	})
153
154	if hsAddress == "" {
155		hsAddress = hypervisorHandshakerServiceAddress
156	}
157	return &altsTC{
158		info: &credentials.ProtocolInfo{
159			SecurityProtocol: "alts",
160			SecurityVersion:  "1.0",
161		},
162		side:      side,
163		accounts:  accounts,
164		hsAddress: hsAddress,
165	}
166}
167
168// ClientHandshake implements the client side handshake protocol.
169func (g *altsTC) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
170	if !vmOnGCP {
171		return nil, nil, ErrUntrustedPlatform
172	}
173
174	// Connecting to ALTS handshaker service.
175	hsConn, err := service.Dial(g.hsAddress)
176	if err != nil {
177		return nil, nil, err
178	}
179	// Do not close hsConn since it is shared with other handshakes.
180
181	// Possible context leak:
182	// The cancel function for the child context we create will only be
183	// called a non-nil error is returned.
184	var cancel context.CancelFunc
185	ctx, cancel = context.WithCancel(ctx)
186	defer func() {
187		if err != nil {
188			cancel()
189		}
190	}()
191
192	opts := handshaker.DefaultClientHandshakerOptions()
193	opts.TargetName = addr
194	opts.TargetServiceAccounts = g.accounts
195	opts.RPCVersions = &altspb.RpcProtocolVersions{
196		MaxRpcVersion: maxRPCVersion,
197		MinRpcVersion: minRPCVersion,
198	}
199	chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, opts)
200	if err != nil {
201		return nil, nil, err
202	}
203	defer func() {
204		if err != nil {
205			chs.Close()
206		}
207	}()
208	secConn, authInfo, err := chs.ClientHandshake(ctx)
209	if err != nil {
210		return nil, nil, err
211	}
212	altsAuthInfo, ok := authInfo.(AuthInfo)
213	if !ok {
214		return nil, nil, errors.New("client-side auth info is not of type alts.AuthInfo")
215	}
216	match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
217	if !match {
218		return nil, nil, fmt.Errorf("server-side RPC versions are not compatible with this client, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
219	}
220	return secConn, authInfo, nil
221}
222
223// ServerHandshake implements the server side ALTS handshaker.
224func (g *altsTC) ServerHandshake(rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
225	if !vmOnGCP {
226		return nil, nil, ErrUntrustedPlatform
227	}
228	// Connecting to ALTS handshaker service.
229	hsConn, err := service.Dial(g.hsAddress)
230	if err != nil {
231		return nil, nil, err
232	}
233	// Do not close hsConn since it's shared with other handshakes.
234
235	ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
236	defer cancel()
237	opts := handshaker.DefaultServerHandshakerOptions()
238	opts.RPCVersions = &altspb.RpcProtocolVersions{
239		MaxRpcVersion: maxRPCVersion,
240		MinRpcVersion: minRPCVersion,
241	}
242	shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, opts)
243	if err != nil {
244		return nil, nil, err
245	}
246	defer func() {
247		if err != nil {
248			shs.Close()
249		}
250	}()
251	secConn, authInfo, err := shs.ServerHandshake(ctx)
252	if err != nil {
253		return nil, nil, err
254	}
255	altsAuthInfo, ok := authInfo.(AuthInfo)
256	if !ok {
257		return nil, nil, errors.New("server-side auth info is not of type alts.AuthInfo")
258	}
259	match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
260	if !match {
261		return nil, nil, fmt.Errorf("client-side RPC versions is not compatible with this server, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
262	}
263	return secConn, authInfo, nil
264}
265
266func (g *altsTC) Info() credentials.ProtocolInfo {
267	return *g.info
268}
269
270func (g *altsTC) Clone() credentials.TransportCredentials {
271	info := *g.info
272	var accounts []string
273	if g.accounts != nil {
274		accounts = make([]string, len(g.accounts))
275		copy(accounts, g.accounts)
276	}
277	return &altsTC{
278		info:      &info,
279		side:      g.side,
280		hsAddress: g.hsAddress,
281		accounts:  accounts,
282	}
283}
284
285func (g *altsTC) OverrideServerName(serverNameOverride string) error {
286	g.info.ServerName = serverNameOverride
287	return nil
288}
289
290// compareRPCVersion returns 0 if v1 == v2, 1 if v1 > v2 and -1 if v1 < v2.
291func compareRPCVersions(v1, v2 *altspb.RpcProtocolVersions_Version) int {
292	switch {
293	case v1.GetMajor() > v2.GetMajor(),
294		v1.GetMajor() == v2.GetMajor() && v1.GetMinor() > v2.GetMinor():
295		return 1
296	case v1.GetMajor() < v2.GetMajor(),
297		v1.GetMajor() == v2.GetMajor() && v1.GetMinor() < v2.GetMinor():
298		return -1
299	}
300	return 0
301}
302
303// checkRPCVersions performs a version check between local and peer rpc protocol
304// versions. This function returns true if the check passes which means both
305// parties agreed on a common rpc protocol to use, and false otherwise. The
306// function also returns the highest common RPC protocol version both parties
307// agreed on.
308func checkRPCVersions(local, peer *altspb.RpcProtocolVersions) (bool, *altspb.RpcProtocolVersions_Version) {
309	if local == nil || peer == nil {
310		grpclog.Error("invalid checkRPCVersions argument, either local or peer is nil.")
311		return false, nil
312	}
313
314	// maxCommonVersion is MIN(local.max, peer.max).
315	maxCommonVersion := local.GetMaxRpcVersion()
316	if compareRPCVersions(local.GetMaxRpcVersion(), peer.GetMaxRpcVersion()) > 0 {
317		maxCommonVersion = peer.GetMaxRpcVersion()
318	}
319
320	// minCommonVersion is MAX(local.min, peer.min).
321	minCommonVersion := peer.GetMinRpcVersion()
322	if compareRPCVersions(local.GetMinRpcVersion(), peer.GetMinRpcVersion()) > 0 {
323		minCommonVersion = local.GetMinRpcVersion()
324	}
325
326	if compareRPCVersions(maxCommonVersion, minCommonVersion) < 0 {
327		return false, nil
328	}
329	return true, maxCommonVersion
330}
331