1package policychecker_test 2 3import ( 4 "bytes" 5 "errors" 6 "io/ioutil" 7 "net/http" 8 "net/http/httptest" 9 10 "github.com/concourse/concourse/atc/api/accessor" 11 "github.com/concourse/concourse/atc/api/accessor/accessorfakes" 12 "github.com/concourse/concourse/atc/api/policychecker" 13 "github.com/concourse/concourse/atc/policy" 14 "github.com/concourse/concourse/atc/policy/policyfakes" 15 16 . "github.com/onsi/ginkgo" 17 . "github.com/onsi/gomega" 18) 19 20var _ = Describe("PolicyChecker", func() { 21 var ( 22 policyFilter policy.Filter 23 fakeAccess *accessorfakes.FakeAccess 24 fakeRequest *http.Request 25 result policy.PolicyCheckOutput 26 checkErr error 27 ) 28 29 BeforeEach(func() { 30 fakeAccess = new(accessorfakes.FakeAccess) 31 fakePolicyAgent = new(policyfakes.FakeAgent) 32 fakePolicyAgentFactory.NewAgentReturns(fakePolicyAgent, nil) 33 34 policyFilter = policy.Filter{ 35 ActionsToSkip: []string{}, 36 Actions: []string{}, 37 HttpMethods: []string{}, 38 } 39 }) 40 41 JustBeforeEach(func() { 42 policyCheck, err := policy.Initialize(testLogger, "some-cluster", "some-version", policyFilter) 43 Expect(err).ToNot(HaveOccurred()) 44 Expect(policyCheck).ToNot(BeNil()) 45 result, checkErr = policychecker.NewApiPolicyChecker(policyCheck).Check("some-action", fakeAccess, fakeRequest) 46 }) 47 48 Context("when system action", func() { 49 BeforeEach(func() { 50 fakeAccess.IsSystemReturns(true) 51 }) 52 It("should pass", func() { 53 Expect(checkErr).ToNot(HaveOccurred()) 54 Expect(result.Allowed).To(BeTrue()) 55 }) 56 It("Agent should not be called", func() { 57 Expect(fakePolicyAgent.CheckCallCount()).To(Equal(0)) 58 }) 59 }) 60 61 Context("when not system action", func() { 62 BeforeEach(func() { 63 fakeAccess.IsSystemReturns(false) 64 }) 65 66 Context("when the action should be skipped", func() { 67 BeforeEach(func() { 68 policyFilter.ActionsToSkip = []string{"some-action"} 69 }) 70 It("should pass", func() { 71 Expect(checkErr).ToNot(HaveOccurred()) 72 Expect(result.Allowed).To(BeTrue()) 73 }) 74 It("Agent should not be called", func() { 75 Expect(fakePolicyAgent.CheckCallCount()).To(Equal(0)) 76 }) 77 }) 78 79 Context("when the http method no need to check", func() { 80 BeforeEach(func() { 81 fakeRequest = httptest.NewRequest("GET", "/something", nil) 82 policyFilter.HttpMethods = []string{"PUT"} 83 }) 84 It("should pass", func() { 85 Expect(checkErr).ToNot(HaveOccurred()) 86 Expect(result.Allowed).To(BeTrue()) 87 }) 88 It("Agent should not be called", func() { 89 Expect(fakePolicyAgent.CheckCallCount()).To(Equal(0)) 90 }) 91 }) 92 93 Context("when not in action list", func() { 94 BeforeEach(func() { 95 fakeRequest = httptest.NewRequest("PUT", "/something", nil) 96 policyFilter.HttpMethods = []string{} 97 policyFilter.Actions = []string{} 98 }) 99 It("should pass", func() { 100 Expect(checkErr).ToNot(HaveOccurred()) 101 Expect(result.Allowed).To(BeTrue()) 102 }) 103 It("Agent should not be called", func() { 104 Expect(fakePolicyAgent.CheckCallCount()).To(Equal(0)) 105 }) 106 }) 107 108 Context("when the http method needs to check", func() { 109 BeforeEach(func() { 110 fakeRequest = httptest.NewRequest("PUT", "/something", nil) 111 policyFilter.HttpMethods = []string{"PUT"} 112 }) 113 114 Context("when request body is a bad json", func() { 115 BeforeEach(func() { 116 body := bytes.NewBuffer([]byte("hello")) 117 fakeRequest = httptest.NewRequest("PUT", "/something", body) 118 fakeRequest.Header.Add("Content-type", "application/json") 119 }) 120 121 It("should error", func() { 122 Expect(checkErr).To(HaveOccurred()) 123 Expect(checkErr.Error()).To(Equal(`invalid character 'h' looking for beginning of value`)) 124 Expect(result.Allowed).To(BeFalse()) 125 }) 126 It("Agent should not be called", func() { 127 Expect(fakePolicyAgent.CheckCallCount()).To(Equal(0)) 128 }) 129 }) 130 131 Context("when request body is a bad yaml", func() { 132 BeforeEach(func() { 133 body := bytes.NewBuffer([]byte("a:\nb")) 134 fakeRequest = httptest.NewRequest("PUT", "/something", body) 135 fakeRequest.Header.Add("Content-type", "application/x-yaml") 136 }) 137 138 It("should error", func() { 139 Expect(checkErr).To(HaveOccurred()) 140 Expect(checkErr.Error()).To(Equal(`error converting YAML to JSON: yaml: line 3: could not find expected ':'`)) 141 Expect(result.Allowed).To(BeFalse()) 142 }) 143 It("Agent should not be called", func() { 144 Expect(fakePolicyAgent.CheckCallCount()).To(Equal(0)) 145 }) 146 }) 147 148 Context("when every is ok", func() { 149 BeforeEach(func() { 150 fakeAccess.TeamRolesReturns(map[string][]string{ 151 "some-team": []string{"some-role"}, 152 }) 153 fakeAccess.ClaimsReturns(accessor.Claims{UserName: "some-user"}) 154 body := bytes.NewBuffer([]byte("a: b")) 155 fakeRequest = httptest.NewRequest("PUT", "/something?:team_name=some-team&:pipeline_name=some-pipeline", body) 156 fakeRequest.Header.Add("Content-type", "application/x-yaml") 157 fakeRequest.ParseForm() 158 }) 159 160 It("should not error", func() { 161 Expect(checkErr).ToNot(HaveOccurred()) 162 }) 163 It("Agent should be called", func() { 164 Expect(fakePolicyAgent.CheckCallCount()).To(Equal(1)) 165 }) 166 It("Agent should take correct input", func() { 167 Expect(fakePolicyAgent.CheckArgsForCall(0)).To(Equal(policy.PolicyCheckInput{ 168 Service: "concourse", 169 ClusterName: "some-cluster", 170 ClusterVersion: "some-version", 171 HttpMethod: "PUT", 172 Action: "some-action", 173 User: "some-user", 174 Team: "some-team", 175 Roles: []string{"some-role"}, 176 Pipeline: "some-pipeline", 177 Data: map[string]interface{}{"a": "b"}, 178 })) 179 }) 180 181 It("request body should still be readable", func() { 182 body, err := ioutil.ReadAll(fakeRequest.Body) 183 Expect(err).ToNot(HaveOccurred()) 184 Expect(body).To(Equal([]byte("a: b"))) 185 }) 186 187 Context("when Agent says pass", func() { 188 BeforeEach(func() { 189 fakePolicyAgent.CheckReturns(policy.PassedPolicyCheck(), nil) 190 }) 191 192 It("it should pass", func() { 193 Expect(checkErr).ToNot(HaveOccurred()) 194 Expect(result.Allowed).To(BeTrue()) 195 }) 196 }) 197 198 Context("when Agent says not-pass", func() { 199 BeforeEach(func() { 200 fakePolicyAgent.CheckReturns(policy.PolicyCheckOutput{ 201 Allowed: false, 202 Reasons: []string{"a policy says you can't do that"}, 203 }, nil) 204 }) 205 206 It("should not pass", func() { 207 Expect(checkErr).ToNot(HaveOccurred()) 208 Expect(result.Allowed).To(BeFalse()) 209 Expect(result.Reasons).To(ConsistOf("a policy says you can't do that")) 210 }) 211 }) 212 213 Context("when Agent says error", func() { 214 BeforeEach(func() { 215 fakePolicyAgent.CheckReturns(policy.FailedPolicyCheck(), errors.New("some-error")) 216 }) 217 218 It("should not pass", func() { 219 Expect(checkErr).To(HaveOccurred()) 220 Expect(checkErr.Error()).To(Equal("some-error")) 221 Expect(result.Allowed).To(BeFalse()) 222 }) 223 }) 224 }) 225 }) 226 }) 227}) 228