1/*
2 *
3 * Copyright 2021 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package google
20
21import (
22	"context"
23	"net"
24
25	"google.golang.org/grpc/credentials"
26	"google.golang.org/grpc/internal"
27)
28
29const cfeClusterName = "google-cfe"
30
31// clusterTransportCreds is a combo of TLS + ALTS.
32//
33// On the client, ClientHandshake picks TLS or ALTS based on address attributes.
34// - if attributes has cluster name
35//   - if cluster name is "google_cfe", use TLS
36//   - otherwise, use ALTS
37// - else, do TLS
38//
39// On the server, ServerHandshake always does TLS.
40type clusterTransportCreds struct {
41	tls  credentials.TransportCredentials
42	alts credentials.TransportCredentials
43}
44
45func newClusterTransportCreds(tls, alts credentials.TransportCredentials) *clusterTransportCreds {
46	return &clusterTransportCreds{
47		tls:  tls,
48		alts: alts,
49	}
50}
51
52func (c *clusterTransportCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
53	chi := credentials.ClientHandshakeInfoFromContext(ctx)
54	if chi.Attributes == nil {
55		return c.tls.ClientHandshake(ctx, authority, rawConn)
56	}
57	cn, ok := internal.GetXDSHandshakeClusterName(chi.Attributes)
58	if !ok || cn == cfeClusterName {
59		return c.tls.ClientHandshake(ctx, authority, rawConn)
60	}
61	// If attributes have cluster name, and cluster name is not cfe, it's a
62	// backend address, use ALTS.
63	return c.alts.ClientHandshake(ctx, authority, rawConn)
64}
65
66func (c *clusterTransportCreds) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
67	return c.tls.ServerHandshake(conn)
68}
69
70func (c *clusterTransportCreds) Info() credentials.ProtocolInfo {
71	// TODO: this always returns tls.Info now, because we don't have a cluster
72	// name to check when this method is called. This method doesn't affect
73	// anything important now. We may want to revisit this if it becomes more
74	// important later.
75	return c.tls.Info()
76}
77
78func (c *clusterTransportCreds) Clone() credentials.TransportCredentials {
79	return &clusterTransportCreds{
80		tls:  c.tls.Clone(),
81		alts: c.alts.Clone(),
82	}
83}
84
85func (c *clusterTransportCreds) OverrideServerName(s string) error {
86	if err := c.tls.OverrideServerName(s); err != nil {
87		return err
88	}
89	return c.alts.OverrideServerName(s)
90}
91