1/*
2 *
3 * Copyright 2018 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package alts
20
21import (
22	"context"
23	"io"
24	"os"
25	"strings"
26	"testing"
27
28	"google.golang.org/grpc/codes"
29	altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
30	"google.golang.org/grpc/peer"
31	"google.golang.org/grpc/status"
32)
33
// Service account names used by TestClientAuthorizationCheck, both as the
// fake peer's own account and as entries in the expected-account lists.
const (
	testServiceAccount1 = "service_account1"
	testServiceAccount2 = "service_account2"
	testServiceAccount3 = "service_account3"
)
39
40func setupManufacturerReader(testOS string, reader func() (io.Reader, error)) func() {
41	tmpOS := runningOS
42	tmpReader := manufacturerReader
43
44	// Set test OS and reader function.
45	runningOS = testOS
46	manufacturerReader = reader
47	return func() {
48		runningOS = tmpOS
49		manufacturerReader = tmpReader
50	}
51
52}
53
54func setup(testOS string, testReader io.Reader) func() {
55	reader := func() (io.Reader, error) {
56		return testReader, nil
57	}
58	return setupManufacturerReader(testOS, reader)
59}
60
61func setupError(testOS string, err error) func() {
62	reader := func() (io.Reader, error) {
63		return nil, err
64	}
65	return setupManufacturerReader(testOS, reader)
66}
67
68func TestIsRunningOnGCP(t *testing.T) {
69	for _, tc := range []struct {
70		description string
71		testOS      string
72		testReader  io.Reader
73		out         bool
74	}{
75		// Linux tests.
76		{"linux: not a GCP platform", "linux", strings.NewReader("not GCP"), false},
77		{"Linux: GCP platform (Google)", "linux", strings.NewReader("Google"), true},
78		{"Linux: GCP platform (Google Compute Engine)", "linux", strings.NewReader("Google Compute Engine"), true},
79		{"Linux: GCP platform (Google Compute Engine) with extra spaces", "linux", strings.NewReader("  Google Compute Engine        "), true},
80		// Windows tests.
81		{"windows: not a GCP platform", "windows", strings.NewReader("not GCP"), false},
82		{"windows: GCP platform (Google)", "windows", strings.NewReader("Google"), true},
83		{"windows: GCP platform (Google) with extra spaces", "windows", strings.NewReader("  Google     "), true},
84	} {
85		reverseFunc := setup(tc.testOS, tc.testReader)
86		if got, want := isRunningOnGCP(), tc.out; got != want {
87			t.Errorf("%v: isRunningOnGCP()=%v, want %v", tc.description, got, want)
88		}
89		reverseFunc()
90	}
91}
92
93func TestIsRunningOnGCPNoProductNameFile(t *testing.T) {
94	reverseFunc := setupError("linux", os.ErrNotExist)
95	if isRunningOnGCP() {
96		t.Errorf("ErrNotExist: isRunningOnGCP()=true, want false")
97	}
98	reverseFunc()
99}
100
101func TestAuthInfoFromContext(t *testing.T) {
102	ctx := context.Background()
103	altsAuthInfo := &fakeALTSAuthInfo{}
104	p := &peer.Peer{
105		AuthInfo: altsAuthInfo,
106	}
107	for _, tc := range []struct {
108		desc    string
109		ctx     context.Context
110		success bool
111		out     AuthInfo
112	}{
113		{
114			"working case",
115			peer.NewContext(ctx, p),
116			true,
117			altsAuthInfo,
118		},
119	} {
120		authInfo, err := AuthInfoFromContext(tc.ctx)
121		if got, want := (err == nil), tc.success; got != want {
122			t.Errorf("%v: AuthInfoFromContext(_)=(err=nil)=%v, want %v", tc.desc, got, want)
123		}
124		if got, want := authInfo, tc.out; got != want {
125			t.Errorf("%v:, AuthInfoFromContext(_)=(%v, _), want (%v, _)", tc.desc, got, want)
126		}
127	}
128}
129
130func TestAuthInfoFromPeer(t *testing.T) {
131	altsAuthInfo := &fakeALTSAuthInfo{}
132	p := &peer.Peer{
133		AuthInfo: altsAuthInfo,
134	}
135	for _, tc := range []struct {
136		desc    string
137		p       *peer.Peer
138		success bool
139		out     AuthInfo
140	}{
141		{
142			"working case",
143			p,
144			true,
145			altsAuthInfo,
146		},
147	} {
148		authInfo, err := AuthInfoFromPeer(tc.p)
149		if got, want := (err == nil), tc.success; got != want {
150			t.Errorf("%v: AuthInfoFromPeer(_)=(err=nil)=%v, want %v", tc.desc, got, want)
151		}
152		if got, want := authInfo, tc.out; got != want {
153			t.Errorf("%v:, AuthInfoFromPeer(_)=(%v, _), want (%v, _)", tc.desc, got, want)
154		}
155	}
156}
157
158func TestClientAuthorizationCheck(t *testing.T) {
159	ctx := context.Background()
160	altsAuthInfo := &fakeALTSAuthInfo{testServiceAccount1}
161	p := &peer.Peer{
162		AuthInfo: altsAuthInfo,
163	}
164	for _, tc := range []struct {
165		desc                    string
166		ctx                     context.Context
167		expectedServiceAccounts []string
168		success                 bool
169		code                    codes.Code
170	}{
171		{
172			"working case",
173			peer.NewContext(ctx, p),
174			[]string{testServiceAccount1, testServiceAccount2},
175			true,
176			codes.OK, // err is nil, code is OK.
177		},
178		{
179			"context does not have AuthInfo",
180			ctx,
181			[]string{testServiceAccount1, testServiceAccount2},
182			false,
183			codes.PermissionDenied,
184		},
185		{
186			"unauthorized client",
187			peer.NewContext(ctx, p),
188			[]string{testServiceAccount2, testServiceAccount3},
189			false,
190			codes.PermissionDenied,
191		},
192	} {
193		err := ClientAuthorizationCheck(tc.ctx, tc.expectedServiceAccounts)
194		if got, want := (err == nil), tc.success; got != want {
195			t.Errorf("%v: ClientAuthorizationCheck(_, %v)=(err=nil)=%v, want %v", tc.desc, tc.expectedServiceAccounts, got, want)
196		}
197		if got, want := status.Code(err), tc.code; got != want {
198			t.Errorf("%v: ClientAuthorizationCheck(_, %v).Code=%v, want %v", tc.desc, tc.expectedServiceAccounts, got, want)
199		}
200	}
201}
202
203type fakeALTSAuthInfo struct {
204	peerServiceAccount string
205}
206
207func (*fakeALTSAuthInfo) AuthType() string            { return "" }
208func (*fakeALTSAuthInfo) ApplicationProtocol() string { return "" }
209func (*fakeALTSAuthInfo) RecordProtocol() string      { return "" }
210func (*fakeALTSAuthInfo) SecurityLevel() altspb.SecurityLevel {
211	return altspb.SecurityLevel_SECURITY_NONE
212}
213func (f *fakeALTSAuthInfo) PeerServiceAccount() string                 { return f.peerServiceAccount }
214func (*fakeALTSAuthInfo) LocalServiceAccount() string                  { return "" }
215func (*fakeALTSAuthInfo) PeerRPCVersions() *altspb.RpcProtocolVersions { return nil }
216