1//go:build linux || windows
2// +build linux windows
3
4/*
5 *
6 * Copyright 2018 gRPC authors.
7 *
8 * Licensed under the Apache License, Version 2.0 (the "License");
9 * you may not use this file except in compliance with the License.
10 * You may obtain a copy of the License at
11 *
12 *     http://www.apache.org/licenses/LICENSE-2.0
13 *
14 * Unless required by applicable law or agreed to in writing, software
15 * distributed under the License is distributed on an "AS IS" BASIS,
16 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 * See the License for the specific language governing permissions and
18 * limitations under the License.
19 *
20 */
21
22package alts
23
24import (
25	"context"
26	"strings"
27	"testing"
28	"time"
29
30	"google.golang.org/grpc/codes"
31	altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
32	"google.golang.org/grpc/peer"
33	"google.golang.org/grpc/status"
34)
35
// Service account names used as peer identities and as entries of the
// expected-account lists in the tests below.
const (
	testServiceAccount1 = "service_account1"
	testServiceAccount2 = "service_account2"
	testServiceAccount3 = "service_account3"
)

// defaultTestTimeout is the deadline applied to the contexts created by the
// tests in this file.
const defaultTestTimeout = 10 * time.Second
43
44func (s) TestAuthInfoFromContext(t *testing.T) {
45	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
46	defer cancel()
47	altsAuthInfo := &fakeALTSAuthInfo{}
48	p := &peer.Peer{
49		AuthInfo: altsAuthInfo,
50	}
51	for _, tc := range []struct {
52		desc    string
53		ctx     context.Context
54		success bool
55		out     AuthInfo
56	}{
57		{
58			"working case",
59			peer.NewContext(ctx, p),
60			true,
61			altsAuthInfo,
62		},
63	} {
64		authInfo, err := AuthInfoFromContext(tc.ctx)
65		if got, want := (err == nil), tc.success; got != want {
66			t.Errorf("%v: AuthInfoFromContext(_)=(err=nil)=%v, want %v", tc.desc, got, want)
67		}
68		if got, want := authInfo, tc.out; got != want {
69			t.Errorf("%v:, AuthInfoFromContext(_)=(%v, _), want (%v, _)", tc.desc, got, want)
70		}
71	}
72}
73
74func (s) TestAuthInfoFromPeer(t *testing.T) {
75	altsAuthInfo := &fakeALTSAuthInfo{}
76	p := &peer.Peer{
77		AuthInfo: altsAuthInfo,
78	}
79	for _, tc := range []struct {
80		desc    string
81		p       *peer.Peer
82		success bool
83		out     AuthInfo
84	}{
85		{
86			"working case",
87			p,
88			true,
89			altsAuthInfo,
90		},
91	} {
92		authInfo, err := AuthInfoFromPeer(tc.p)
93		if got, want := (err == nil), tc.success; got != want {
94			t.Errorf("%v: AuthInfoFromPeer(_)=(err=nil)=%v, want %v", tc.desc, got, want)
95		}
96		if got, want := authInfo, tc.out; got != want {
97			t.Errorf("%v:, AuthInfoFromPeer(_)=(%v, _), want (%v, _)", tc.desc, got, want)
98		}
99	}
100}
101
102func (s) TestClientAuthorizationCheck(t *testing.T) {
103	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
104	defer cancel()
105	altsAuthInfo := &fakeALTSAuthInfo{testServiceAccount1}
106	p := &peer.Peer{
107		AuthInfo: altsAuthInfo,
108	}
109	for _, tc := range []struct {
110		desc                    string
111		ctx                     context.Context
112		expectedServiceAccounts []string
113		success                 bool
114		code                    codes.Code
115	}{
116		{
117			"working case",
118			peer.NewContext(ctx, p),
119			[]string{testServiceAccount1, testServiceAccount2},
120			true,
121			codes.OK, // err is nil, code is OK.
122		},
123		{
124			"working case (case ignored)",
125			peer.NewContext(ctx, p),
126			[]string{strings.ToUpper(testServiceAccount1), testServiceAccount2},
127			true,
128			codes.OK, // err is nil, code is OK.
129		},
130		{
131			"context does not have AuthInfo",
132			ctx,
133			[]string{testServiceAccount1, testServiceAccount2},
134			false,
135			codes.PermissionDenied,
136		},
137		{
138			"unauthorized client",
139			peer.NewContext(ctx, p),
140			[]string{testServiceAccount2, testServiceAccount3},
141			false,
142			codes.PermissionDenied,
143		},
144	} {
145		err := ClientAuthorizationCheck(tc.ctx, tc.expectedServiceAccounts)
146		if got, want := (err == nil), tc.success; got != want {
147			t.Errorf("%v: ClientAuthorizationCheck(_, %v)=(err=nil)=%v, want %v", tc.desc, tc.expectedServiceAccounts, got, want)
148		}
149		if got, want := status.Code(err), tc.code; got != want {
150			t.Errorf("%v: ClientAuthorizationCheck(_, %v).Code=%v, want %v", tc.desc, tc.expectedServiceAccounts, got, want)
151		}
152	}
153}
154
155type fakeALTSAuthInfo struct {
156	peerServiceAccount string
157}
158
159func (*fakeALTSAuthInfo) AuthType() string            { return "" }
160func (*fakeALTSAuthInfo) ApplicationProtocol() string { return "" }
161func (*fakeALTSAuthInfo) RecordProtocol() string      { return "" }
162func (*fakeALTSAuthInfo) SecurityLevel() altspb.SecurityLevel {
163	return altspb.SecurityLevel_SECURITY_NONE
164}
165func (f *fakeALTSAuthInfo) PeerServiceAccount() string                 { return f.peerServiceAccount }
166func (*fakeALTSAuthInfo) LocalServiceAccount() string                  { return "" }
167func (*fakeALTSAuthInfo) PeerRPCVersions() *altspb.RpcProtocolVersions { return nil }
168