/*
 *
 * Copyright 2021 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package google

import (
	"context"
	"net"

	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/internal"
)

// cfeClusterName is the xDS handshake cluster name for which
// clusterTransportCreds.ClientHandshake uses TLS instead of ALTS. The match is
// an exact string comparison.
const cfeClusterName = "google-cfe"

// clusterTransportCreds is a combo of TLS + ALTS.
//
// On the client, ClientHandshake picks TLS or ALTS based on the address
// attributes carried in the context's handshake info:
//   - if the attributes carry an xDS cluster name other than cfeClusterName
//     ("google-cfe"), the address is treated as a backend and ALTS is used;
//   - in every other case (no attributes, no cluster name, or a cluster name
//     equal to cfeClusterName), TLS is used.
//
// On the server, ServerHandshake always does TLS.
//
// NOTE(review): an earlier revision of this comment spelled the CFE cluster
// name "google_cfe" (underscore), but the code compares against "google-cfe"
// (hyphen). Confirm which spelling the xDS control plane actually sends.
40type clusterTransportCreds struct { 41 tls credentials.TransportCredentials 42 alts credentials.TransportCredentials 43} 44 45func newClusterTransportCreds(tls, alts credentials.TransportCredentials) *clusterTransportCreds { 46 return &clusterTransportCreds{ 47 tls: tls, 48 alts: alts, 49 } 50} 51 52func (c *clusterTransportCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 53 chi := credentials.ClientHandshakeInfoFromContext(ctx) 54 if chi.Attributes == nil { 55 return c.tls.ClientHandshake(ctx, authority, rawConn) 56 } 57 cn, ok := internal.GetXDSHandshakeClusterName(chi.Attributes) 58 if !ok || cn == cfeClusterName { 59 return c.tls.ClientHandshake(ctx, authority, rawConn) 60 } 61 // If attributes have cluster name, and cluster name is not cfe, it's a 62 // backend address, use ALTS. 63 return c.alts.ClientHandshake(ctx, authority, rawConn) 64} 65 66func (c *clusterTransportCreds) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { 67 return c.tls.ServerHandshake(conn) 68} 69 70func (c *clusterTransportCreds) Info() credentials.ProtocolInfo { 71 // TODO: this always returns tls.Info now, because we don't have a cluster 72 // name to check when this method is called. This method doesn't affect 73 // anything important now. We may want to revisit this if it becomes more 74 // important later. 75 return c.tls.Info() 76} 77 78func (c *clusterTransportCreds) Clone() credentials.TransportCredentials { 79 return &clusterTransportCreds{ 80 tls: c.tls.Clone(), 81 alts: c.alts.Clone(), 82 } 83} 84 85func (c *clusterTransportCreds) OverrideServerName(s string) error { 86 if err := c.tls.OverrideServerName(s); err != nil { 87 return err 88 } 89 return c.alts.OverrideServerName(s) 90} 91