1// +build linux windows
2
3/*
4 *
5 * Copyright 2018 gRPC authors.
6 *
7 * Licensed under the Apache License, Version 2.0 (the "License");
8 * you may not use this file except in compliance with the License.
9 * You may obtain a copy of the License at
10 *
11 *     http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an "AS IS" BASIS,
15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
18 *
19 */
20
21package alts
22
23import (
24	"context"
25	"io"
26	"os"
27	"strings"
28	"testing"
29	"time"
30
31	"google.golang.org/grpc/codes"
32	altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
33	"google.golang.org/grpc/peer"
34	"google.golang.org/grpc/status"
35)
36
// Service-account identities handed to the fake ALTS peers in these tests.
const (
	testServiceAccount1 = "service_account1"
	testServiceAccount2 = "service_account2"
	testServiceAccount3 = "service_account3"
)

// defaultTestTimeout bounds the contexts created by the tests in this file.
const defaultTestTimeout = 10 * time.Second
44
45func setupManufacturerReader(testOS string, reader func() (io.Reader, error)) func() {
46	tmpOS := runningOS
47	tmpReader := manufacturerReader
48
49	// Set test OS and reader function.
50	runningOS = testOS
51	manufacturerReader = reader
52	return func() {
53		runningOS = tmpOS
54		manufacturerReader = tmpReader
55	}
56
57}
58
59func setup(testOS string, testReader io.Reader) func() {
60	reader := func() (io.Reader, error) {
61		return testReader, nil
62	}
63	return setupManufacturerReader(testOS, reader)
64}
65
66func setupError(testOS string, err error) func() {
67	reader := func() (io.Reader, error) {
68		return nil, err
69	}
70	return setupManufacturerReader(testOS, reader)
71}
72
73func (s) TestIsRunningOnGCP(t *testing.T) {
74	for _, tc := range []struct {
75		description string
76		testOS      string
77		testReader  io.Reader
78		out         bool
79	}{
80		// Linux tests.
81		{"linux: not a GCP platform", "linux", strings.NewReader("not GCP"), false},
82		{"Linux: GCP platform (Google)", "linux", strings.NewReader("Google"), true},
83		{"Linux: GCP platform (Google Compute Engine)", "linux", strings.NewReader("Google Compute Engine"), true},
84		{"Linux: GCP platform (Google Compute Engine) with extra spaces", "linux", strings.NewReader("  Google Compute Engine        "), true},
85		// Windows tests.
86		{"windows: not a GCP platform", "windows", strings.NewReader("not GCP"), false},
87		{"windows: GCP platform (Google)", "windows", strings.NewReader("Google"), true},
88		{"windows: GCP platform (Google) with extra spaces", "windows", strings.NewReader("  Google     "), true},
89	} {
90		reverseFunc := setup(tc.testOS, tc.testReader)
91		if got, want := isRunningOnGCP(), tc.out; got != want {
92			t.Errorf("%v: isRunningOnGCP()=%v, want %v", tc.description, got, want)
93		}
94		reverseFunc()
95	}
96}
97
98func (s) TestIsRunningOnGCPNoProductNameFile(t *testing.T) {
99	reverseFunc := setupError("linux", os.ErrNotExist)
100	if isRunningOnGCP() {
101		t.Errorf("ErrNotExist: isRunningOnGCP()=true, want false")
102	}
103	reverseFunc()
104}
105
106func (s) TestAuthInfoFromContext(t *testing.T) {
107	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
108	defer cancel()
109	altsAuthInfo := &fakeALTSAuthInfo{}
110	p := &peer.Peer{
111		AuthInfo: altsAuthInfo,
112	}
113	for _, tc := range []struct {
114		desc    string
115		ctx     context.Context
116		success bool
117		out     AuthInfo
118	}{
119		{
120			"working case",
121			peer.NewContext(ctx, p),
122			true,
123			altsAuthInfo,
124		},
125	} {
126		authInfo, err := AuthInfoFromContext(tc.ctx)
127		if got, want := (err == nil), tc.success; got != want {
128			t.Errorf("%v: AuthInfoFromContext(_)=(err=nil)=%v, want %v", tc.desc, got, want)
129		}
130		if got, want := authInfo, tc.out; got != want {
131			t.Errorf("%v:, AuthInfoFromContext(_)=(%v, _), want (%v, _)", tc.desc, got, want)
132		}
133	}
134}
135
136func (s) TestAuthInfoFromPeer(t *testing.T) {
137	altsAuthInfo := &fakeALTSAuthInfo{}
138	p := &peer.Peer{
139		AuthInfo: altsAuthInfo,
140	}
141	for _, tc := range []struct {
142		desc    string
143		p       *peer.Peer
144		success bool
145		out     AuthInfo
146	}{
147		{
148			"working case",
149			p,
150			true,
151			altsAuthInfo,
152		},
153	} {
154		authInfo, err := AuthInfoFromPeer(tc.p)
155		if got, want := (err == nil), tc.success; got != want {
156			t.Errorf("%v: AuthInfoFromPeer(_)=(err=nil)=%v, want %v", tc.desc, got, want)
157		}
158		if got, want := authInfo, tc.out; got != want {
159			t.Errorf("%v:, AuthInfoFromPeer(_)=(%v, _), want (%v, _)", tc.desc, got, want)
160		}
161	}
162}
163
164func (s) TestClientAuthorizationCheck(t *testing.T) {
165	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
166	defer cancel()
167	altsAuthInfo := &fakeALTSAuthInfo{testServiceAccount1}
168	p := &peer.Peer{
169		AuthInfo: altsAuthInfo,
170	}
171	for _, tc := range []struct {
172		desc                    string
173		ctx                     context.Context
174		expectedServiceAccounts []string
175		success                 bool
176		code                    codes.Code
177	}{
178		{
179			"working case",
180			peer.NewContext(ctx, p),
181			[]string{testServiceAccount1, testServiceAccount2},
182			true,
183			codes.OK, // err is nil, code is OK.
184		},
185		{
186			"working case (case ignored)",
187			peer.NewContext(ctx, p),
188			[]string{strings.ToUpper(testServiceAccount1), testServiceAccount2},
189			true,
190			codes.OK, // err is nil, code is OK.
191		},
192		{
193			"context does not have AuthInfo",
194			ctx,
195			[]string{testServiceAccount1, testServiceAccount2},
196			false,
197			codes.PermissionDenied,
198		},
199		{
200			"unauthorized client",
201			peer.NewContext(ctx, p),
202			[]string{testServiceAccount2, testServiceAccount3},
203			false,
204			codes.PermissionDenied,
205		},
206	} {
207		err := ClientAuthorizationCheck(tc.ctx, tc.expectedServiceAccounts)
208		if got, want := (err == nil), tc.success; got != want {
209			t.Errorf("%v: ClientAuthorizationCheck(_, %v)=(err=nil)=%v, want %v", tc.desc, tc.expectedServiceAccounts, got, want)
210		}
211		if got, want := status.Code(err), tc.code; got != want {
212			t.Errorf("%v: ClientAuthorizationCheck(_, %v).Code=%v, want %v", tc.desc, tc.expectedServiceAccounts, got, want)
213		}
214	}
215}
216
217type fakeALTSAuthInfo struct {
218	peerServiceAccount string
219}
220
221func (*fakeALTSAuthInfo) AuthType() string            { return "" }
222func (*fakeALTSAuthInfo) ApplicationProtocol() string { return "" }
223func (*fakeALTSAuthInfo) RecordProtocol() string      { return "" }
224func (*fakeALTSAuthInfo) SecurityLevel() altspb.SecurityLevel {
225	return altspb.SecurityLevel_SECURITY_NONE
226}
227func (f *fakeALTSAuthInfo) PeerServiceAccount() string                 { return f.peerServiceAccount }
228func (*fakeALTSAuthInfo) LocalServiceAccount() string                  { return "" }
229func (*fakeALTSAuthInfo) PeerRPCVersions() *altspb.RpcProtocolVersions { return nil }
230