1/*
2 *
3 * Copyright 2021 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package google
20
21import (
22	"context"
23	"net"
24	"testing"
25
26	"google.golang.org/grpc/credentials"
27	"google.golang.org/grpc/internal"
28	icredentials "google.golang.org/grpc/internal/credentials"
29	"google.golang.org/grpc/resolver"
30)
31
32type testCreds struct {
33	credentials.TransportCredentials
34	typ string
35}
36
37func (c *testCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
38	return nil, &testAuthInfo{typ: c.typ}, nil
39}
40
41func (c *testCreds) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
42	return nil, &testAuthInfo{typ: c.typ}, nil
43}
44
// testAuthInfo is a fake credentials.AuthInfo whose AuthType reports the
// protocol name it was constructed with.
type testAuthInfo struct {
	typ string // value returned by AuthType, e.g. "tls" or "alts"
}

// AuthType returns the protocol name stored in the receiver.
func (a *testAuthInfo) AuthType() string {
	protocol := a.typ
	return protocol
}
52
53var (
54	testTLS  = &testCreds{typ: "tls"}
55	testALTS = &testCreds{typ: "alts"}
56)
57
58func overrideNewCredsFuncs() func() {
59	oldNewTLS := newTLS
60	newTLS = func() credentials.TransportCredentials {
61		return testTLS
62	}
63	oldNewALTS := newALTS
64	newALTS = func() credentials.TransportCredentials {
65		return testALTS
66	}
67	return func() {
68		newTLS = oldNewTLS
69		newALTS = oldNewALTS
70	}
71}
72
73// TestClientHandshakeBasedOnClusterName that by default (without switching
74// modes), ClientHandshake does either tls or alts base on the cluster name in
75// attributes.
76func TestClientHandshakeBasedOnClusterName(t *testing.T) {
77	defer overrideNewCredsFuncs()()
78	for bundleTyp, tc := range map[string]credentials.Bundle{
79		"defaultCreds": NewDefaultCredentials(),
80		"computeCreds": NewComputeEngineCredentials(),
81	} {
82		tests := []struct {
83			name    string
84			ctx     context.Context
85			wantTyp string
86		}{
87			{
88				name:    "no cluster name",
89				ctx:     context.Background(),
90				wantTyp: "tls",
91			},
92			{
93				name: "with non-CFE cluster name",
94				ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
95					Attributes: internal.SetXDSHandshakeClusterName(resolver.Address{}, "lalala").Attributes,
96				}),
97				// non-CFE backends should use alts.
98				wantTyp: "alts",
99			},
100			{
101				name: "with CFE cluster name",
102				ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
103					Attributes: internal.SetXDSHandshakeClusterName(resolver.Address{}, cfeClusterName).Attributes,
104				}),
105				// CFE should use tls.
106				wantTyp: "tls",
107			},
108		}
109		for _, tt := range tests {
110			t.Run(bundleTyp+" "+tt.name, func(t *testing.T) {
111				_, info, err := tc.TransportCredentials().ClientHandshake(tt.ctx, "", nil)
112				if err != nil {
113					t.Fatalf("ClientHandshake failed: %v", err)
114				}
115				if gotType := info.AuthType(); gotType != tt.wantTyp {
116					t.Fatalf("unexpected authtype: %v, want: %v", gotType, tt.wantTyp)
117				}
118
119				_, infoServer, err := tc.TransportCredentials().ServerHandshake(nil)
120				if err != nil {
121					t.Fatalf("ClientHandshake failed: %v", err)
122				}
123				// ServerHandshake should always do TLS.
124				if gotType := infoServer.AuthType(); gotType != "tls" {
125					t.Fatalf("unexpected server authtype: %v, want: %v", gotType, "tls")
126				}
127			})
128		}
129	}
130}
131