1package policychecker_test
2
3import (
4	"bytes"
5	"errors"
6	"io/ioutil"
7	"net/http"
8	"net/http/httptest"
9
10	"github.com/concourse/concourse/atc/api/accessor"
11	"github.com/concourse/concourse/atc/api/accessor/accessorfakes"
12	"github.com/concourse/concourse/atc/api/policychecker"
13	"github.com/concourse/concourse/atc/policy"
14	"github.com/concourse/concourse/atc/policy/policyfakes"
15
16	. "github.com/onsi/ginkgo"
17	. "github.com/onsi/gomega"
18)
19
20var _ = Describe("PolicyChecker", func() {
21	var (
22		policyFilter policy.Filter
23		fakeAccess   *accessorfakes.FakeAccess
24		fakeRequest  *http.Request
25		result       policy.PolicyCheckOutput
26		checkErr     error
27	)
28
29	BeforeEach(func() {
30		fakeAccess = new(accessorfakes.FakeAccess)
31		fakePolicyAgent = new(policyfakes.FakeAgent)
32		fakePolicyAgentFactory.NewAgentReturns(fakePolicyAgent, nil)
33
34		policyFilter = policy.Filter{
35			ActionsToSkip: []string{},
36			Actions:       []string{},
37			HttpMethods:   []string{},
38		}
39	})
40
41	JustBeforeEach(func() {
42		policyCheck, err := policy.Initialize(testLogger, "some-cluster", "some-version", policyFilter)
43		Expect(err).ToNot(HaveOccurred())
44		Expect(policyCheck).ToNot(BeNil())
45		result, checkErr = policychecker.NewApiPolicyChecker(policyCheck).Check("some-action", fakeAccess, fakeRequest)
46	})
47
48	Context("when system action", func() {
49		BeforeEach(func() {
50			fakeAccess.IsSystemReturns(true)
51		})
52		It("should pass", func() {
53			Expect(checkErr).ToNot(HaveOccurred())
54			Expect(result.Allowed).To(BeTrue())
55		})
56		It("Agent should not be called", func() {
57			Expect(fakePolicyAgent.CheckCallCount()).To(Equal(0))
58		})
59	})
60
61	Context("when not system action", func() {
62		BeforeEach(func() {
63			fakeAccess.IsSystemReturns(false)
64		})
65
66		Context("when the action should be skipped", func() {
67			BeforeEach(func() {
68				policyFilter.ActionsToSkip = []string{"some-action"}
69			})
70			It("should pass", func() {
71				Expect(checkErr).ToNot(HaveOccurred())
72				Expect(result.Allowed).To(BeTrue())
73			})
74			It("Agent should not be called", func() {
75				Expect(fakePolicyAgent.CheckCallCount()).To(Equal(0))
76			})
77		})
78
79		Context("when the http method no need to check", func() {
80			BeforeEach(func() {
81				fakeRequest = httptest.NewRequest("GET", "/something", nil)
82				policyFilter.HttpMethods = []string{"PUT"}
83			})
84			It("should pass", func() {
85				Expect(checkErr).ToNot(HaveOccurred())
86				Expect(result.Allowed).To(BeTrue())
87			})
88			It("Agent should not be called", func() {
89				Expect(fakePolicyAgent.CheckCallCount()).To(Equal(0))
90			})
91		})
92
93		Context("when not in action list", func() {
94			BeforeEach(func() {
95				fakeRequest = httptest.NewRequest("PUT", "/something", nil)
96				policyFilter.HttpMethods = []string{}
97				policyFilter.Actions = []string{}
98			})
99			It("should pass", func() {
100				Expect(checkErr).ToNot(HaveOccurred())
101				Expect(result.Allowed).To(BeTrue())
102			})
103			It("Agent should not be called", func() {
104				Expect(fakePolicyAgent.CheckCallCount()).To(Equal(0))
105			})
106		})
107
108		Context("when the http method needs to check", func() {
109			BeforeEach(func() {
110				fakeRequest = httptest.NewRequest("PUT", "/something", nil)
111				policyFilter.HttpMethods = []string{"PUT"}
112			})
113
114			Context("when request body is a bad json", func() {
115				BeforeEach(func() {
116					body := bytes.NewBuffer([]byte("hello"))
117					fakeRequest = httptest.NewRequest("PUT", "/something", body)
118					fakeRequest.Header.Add("Content-type", "application/json")
119				})
120
121				It("should error", func() {
122					Expect(checkErr).To(HaveOccurred())
123					Expect(checkErr.Error()).To(Equal(`invalid character 'h' looking for beginning of value`))
124					Expect(result.Allowed).To(BeFalse())
125				})
126				It("Agent should not be called", func() {
127					Expect(fakePolicyAgent.CheckCallCount()).To(Equal(0))
128				})
129			})
130
131			Context("when request body is a bad yaml", func() {
132				BeforeEach(func() {
133					body := bytes.NewBuffer([]byte("a:\nb"))
134					fakeRequest = httptest.NewRequest("PUT", "/something", body)
135					fakeRequest.Header.Add("Content-type", "application/x-yaml")
136				})
137
138				It("should error", func() {
139					Expect(checkErr).To(HaveOccurred())
140					Expect(checkErr.Error()).To(Equal(`error converting YAML to JSON: yaml: line 3: could not find expected ':'`))
141					Expect(result.Allowed).To(BeFalse())
142				})
143				It("Agent should not be called", func() {
144					Expect(fakePolicyAgent.CheckCallCount()).To(Equal(0))
145				})
146			})
147
148			Context("when every is ok", func() {
149				BeforeEach(func() {
150					fakeAccess.TeamRolesReturns(map[string][]string{
151						"some-team": []string{"some-role"},
152					})
153					fakeAccess.ClaimsReturns(accessor.Claims{UserName: "some-user"})
154					body := bytes.NewBuffer([]byte("a: b"))
155					fakeRequest = httptest.NewRequest("PUT", "/something?:team_name=some-team&:pipeline_name=some-pipeline", body)
156					fakeRequest.Header.Add("Content-type", "application/x-yaml")
157					fakeRequest.ParseForm()
158				})
159
160				It("should not error", func() {
161					Expect(checkErr).ToNot(HaveOccurred())
162				})
163				It("Agent should be called", func() {
164					Expect(fakePolicyAgent.CheckCallCount()).To(Equal(1))
165				})
166				It("Agent should take correct input", func() {
167					Expect(fakePolicyAgent.CheckArgsForCall(0)).To(Equal(policy.PolicyCheckInput{
168						Service:        "concourse",
169						ClusterName:    "some-cluster",
170						ClusterVersion: "some-version",
171						HttpMethod:     "PUT",
172						Action:         "some-action",
173						User:           "some-user",
174						Team:           "some-team",
175						Roles:          []string{"some-role"},
176						Pipeline:       "some-pipeline",
177						Data:           map[string]interface{}{"a": "b"},
178					}))
179				})
180
181				It("request body should still be readable", func() {
182					body, err := ioutil.ReadAll(fakeRequest.Body)
183					Expect(err).ToNot(HaveOccurred())
184					Expect(body).To(Equal([]byte("a: b")))
185				})
186
187				Context("when Agent says pass", func() {
188					BeforeEach(func() {
189						fakePolicyAgent.CheckReturns(policy.PassedPolicyCheck(), nil)
190					})
191
192					It("it should pass", func() {
193						Expect(checkErr).ToNot(HaveOccurred())
194						Expect(result.Allowed).To(BeTrue())
195					})
196				})
197
198				Context("when Agent says not-pass", func() {
199					BeforeEach(func() {
200						fakePolicyAgent.CheckReturns(policy.PolicyCheckOutput{
201							Allowed: false,
202							Reasons: []string{"a policy says you can't do that"},
203						}, nil)
204					})
205
206					It("should not pass", func() {
207						Expect(checkErr).ToNot(HaveOccurred())
208						Expect(result.Allowed).To(BeFalse())
209						Expect(result.Reasons).To(ConsistOf("a policy says you can't do that"))
210					})
211				})
212
213				Context("when Agent says error", func() {
214					BeforeEach(func() {
215						fakePolicyAgent.CheckReturns(policy.FailedPolicyCheck(), errors.New("some-error"))
216					})
217
218					It("should not pass", func() {
219						Expect(checkErr).To(HaveOccurred())
220						Expect(checkErr.Error()).To(Equal("some-error"))
221						Expect(result.Allowed).To(BeFalse())
222					})
223				})
224			})
225		})
226	})
227})
228