1/* 2 * 3 * Copyright 2021 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19package google 20 21import ( 22 "context" 23 "net" 24 "testing" 25 26 "google.golang.org/grpc/credentials" 27 "google.golang.org/grpc/internal" 28 icredentials "google.golang.org/grpc/internal/credentials" 29 "google.golang.org/grpc/resolver" 30) 31 32type testCreds struct { 33 credentials.TransportCredentials 34 typ string 35} 36 37func (c *testCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 38 return nil, &testAuthInfo{typ: c.typ}, nil 39} 40 41func (c *testCreds) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { 42 return nil, &testAuthInfo{typ: c.typ}, nil 43} 44 45type testAuthInfo struct { 46 typ string 47} 48 49func (t *testAuthInfo) AuthType() string { 50 return t.typ 51} 52 53var ( 54 testTLS = &testCreds{typ: "tls"} 55 testALTS = &testCreds{typ: "alts"} 56) 57 58func overrideNewCredsFuncs() func() { 59 oldNewTLS := newTLS 60 newTLS = func() credentials.TransportCredentials { 61 return testTLS 62 } 63 oldNewALTS := newALTS 64 newALTS = func() credentials.TransportCredentials { 65 return testALTS 66 } 67 return func() { 68 newTLS = oldNewTLS 69 newALTS = oldNewALTS 70 } 71} 72 73// TestClientHandshakeBasedOnClusterName that by default (without switching 74// modes), ClientHandshake does either tls or alts base on the cluster name in 75// attributes. 76func TestClientHandshakeBasedOnClusterName(t *testing.T) { 77 defer overrideNewCredsFuncs()() 78 for bundleTyp, tc := range map[string]credentials.Bundle{ 79 "defaultCreds": NewDefaultCredentials(), 80 "computeCreds": NewComputeEngineCredentials(), 81 } { 82 tests := []struct { 83 name string 84 ctx context.Context 85 wantTyp string 86 }{ 87 { 88 name: "no cluster name", 89 ctx: context.Background(), 90 wantTyp: "tls", 91 }, 92 { 93 name: "with non-CFE cluster name", 94 ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{ 95 Attributes: internal.SetXDSHandshakeClusterName(resolver.Address{}, "lalala").Attributes, 96 }), 97 // non-CFE backends should use alts. 98 wantTyp: "alts", 99 }, 100 { 101 name: "with CFE cluster name", 102 ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{ 103 Attributes: internal.SetXDSHandshakeClusterName(resolver.Address{}, cfeClusterName).Attributes, 104 }), 105 // CFE should use tls. 106 wantTyp: "tls", 107 }, 108 } 109 for _, tt := range tests { 110 t.Run(bundleTyp+" "+tt.name, func(t *testing.T) { 111 _, info, err := tc.TransportCredentials().ClientHandshake(tt.ctx, "", nil) 112 if err != nil { 113 t.Fatalf("ClientHandshake failed: %v", err) 114 } 115 if gotType := info.AuthType(); gotType != tt.wantTyp { 116 t.Fatalf("unexpected authtype: %v, want: %v", gotType, tt.wantTyp) 117 } 118 119 _, infoServer, err := tc.TransportCredentials().ServerHandshake(nil) 120 if err != nil { 121 t.Fatalf("ClientHandshake failed: %v", err) 122 } 123 // ServerHandshake should always do TLS. 124 if gotType := infoServer.AuthType(); gotType != "tls" { 125 t.Fatalf("unexpected server authtype: %v, want: %v", gotType, "tls") 126 } 127 }) 128 } 129 } 130} 131