1// +build linux windows
2
3/*
4 *
5 * Copyright 2018 gRPC authors.
6 *
7 * Licensed under the Apache License, Version 2.0 (the "License");
8 * you may not use this file except in compliance with the License.
9 * You may obtain a copy of the License at
10 *
11 *     http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an "AS IS" BASIS,
15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
18 *
19 */
20
21package alts
22
23import (
24	"context"
25	"strings"
26	"testing"
27	"time"
28
29	"google.golang.org/grpc/codes"
30	altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
31	"google.golang.org/grpc/peer"
32	"google.golang.org/grpc/status"
33)
34
// Service-account identities used by the tests in this file; the fake peer is
// configured with testServiceAccount1, and the others serve as non-matching
// expected accounts.
const (
	testServiceAccount1 = "service_account1"
	testServiceAccount2 = "service_account2"
	testServiceAccount3 = "service_account3"
)

// defaultTestTimeout is the deadline applied to contexts that tests in this
// file create with context.WithTimeout.
const defaultTestTimeout = 10 * time.Second
42
43func (s) TestAuthInfoFromContext(t *testing.T) {
44	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
45	defer cancel()
46	altsAuthInfo := &fakeALTSAuthInfo{}
47	p := &peer.Peer{
48		AuthInfo: altsAuthInfo,
49	}
50	for _, tc := range []struct {
51		desc    string
52		ctx     context.Context
53		success bool
54		out     AuthInfo
55	}{
56		{
57			"working case",
58			peer.NewContext(ctx, p),
59			true,
60			altsAuthInfo,
61		},
62	} {
63		authInfo, err := AuthInfoFromContext(tc.ctx)
64		if got, want := (err == nil), tc.success; got != want {
65			t.Errorf("%v: AuthInfoFromContext(_)=(err=nil)=%v, want %v", tc.desc, got, want)
66		}
67		if got, want := authInfo, tc.out; got != want {
68			t.Errorf("%v:, AuthInfoFromContext(_)=(%v, _), want (%v, _)", tc.desc, got, want)
69		}
70	}
71}
72
73func (s) TestAuthInfoFromPeer(t *testing.T) {
74	altsAuthInfo := &fakeALTSAuthInfo{}
75	p := &peer.Peer{
76		AuthInfo: altsAuthInfo,
77	}
78	for _, tc := range []struct {
79		desc    string
80		p       *peer.Peer
81		success bool
82		out     AuthInfo
83	}{
84		{
85			"working case",
86			p,
87			true,
88			altsAuthInfo,
89		},
90	} {
91		authInfo, err := AuthInfoFromPeer(tc.p)
92		if got, want := (err == nil), tc.success; got != want {
93			t.Errorf("%v: AuthInfoFromPeer(_)=(err=nil)=%v, want %v", tc.desc, got, want)
94		}
95		if got, want := authInfo, tc.out; got != want {
96			t.Errorf("%v:, AuthInfoFromPeer(_)=(%v, _), want (%v, _)", tc.desc, got, want)
97		}
98	}
99}
100
101func (s) TestClientAuthorizationCheck(t *testing.T) {
102	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
103	defer cancel()
104	altsAuthInfo := &fakeALTSAuthInfo{testServiceAccount1}
105	p := &peer.Peer{
106		AuthInfo: altsAuthInfo,
107	}
108	for _, tc := range []struct {
109		desc                    string
110		ctx                     context.Context
111		expectedServiceAccounts []string
112		success                 bool
113		code                    codes.Code
114	}{
115		{
116			"working case",
117			peer.NewContext(ctx, p),
118			[]string{testServiceAccount1, testServiceAccount2},
119			true,
120			codes.OK, // err is nil, code is OK.
121		},
122		{
123			"working case (case ignored)",
124			peer.NewContext(ctx, p),
125			[]string{strings.ToUpper(testServiceAccount1), testServiceAccount2},
126			true,
127			codes.OK, // err is nil, code is OK.
128		},
129		{
130			"context does not have AuthInfo",
131			ctx,
132			[]string{testServiceAccount1, testServiceAccount2},
133			false,
134			codes.PermissionDenied,
135		},
136		{
137			"unauthorized client",
138			peer.NewContext(ctx, p),
139			[]string{testServiceAccount2, testServiceAccount3},
140			false,
141			codes.PermissionDenied,
142		},
143	} {
144		err := ClientAuthorizationCheck(tc.ctx, tc.expectedServiceAccounts)
145		if got, want := (err == nil), tc.success; got != want {
146			t.Errorf("%v: ClientAuthorizationCheck(_, %v)=(err=nil)=%v, want %v", tc.desc, tc.expectedServiceAccounts, got, want)
147		}
148		if got, want := status.Code(err), tc.code; got != want {
149			t.Errorf("%v: ClientAuthorizationCheck(_, %v).Code=%v, want %v", tc.desc, tc.expectedServiceAccounts, got, want)
150		}
151	}
152}
153
154type fakeALTSAuthInfo struct {
155	peerServiceAccount string
156}
157
158func (*fakeALTSAuthInfo) AuthType() string            { return "" }
159func (*fakeALTSAuthInfo) ApplicationProtocol() string { return "" }
160func (*fakeALTSAuthInfo) RecordProtocol() string      { return "" }
161func (*fakeALTSAuthInfo) SecurityLevel() altspb.SecurityLevel {
162	return altspb.SecurityLevel_SECURITY_NONE
163}
164func (f *fakeALTSAuthInfo) PeerServiceAccount() string                 { return f.peerServiceAccount }
165func (*fakeALTSAuthInfo) LocalServiceAccount() string                  { return "" }
166func (*fakeALTSAuthInfo) PeerRPCVersions() *altspb.RpcProtocolVersions { return nil }
167