1# Copyright 2016 Google Inc. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from __future__ import absolute_import
16
17import datetime
18import httplib
19import unittest2
20from expects import be_none, equal, expect, raise_error
21
22from apitools.base.py import encoding
23
24from google.api.control import caches, label_descriptor, timestamp
25from google.api.control import check_request, messages, metric_value
26
27
class TestSign(unittest2.TestCase):
    """Tests for ``check_request.sign``.

    Covers the inputs that are expected to be rejected with ``ValueError``
    and the operation fields whose presence is expected to alter the
    resulting signature.
    """

    def setUp(self):
        # One operation carrying both a consumer id and an operation name;
        # individual tests mutate it and re-sign the enclosing request.
        self.test_op = messages.Operation(
            consumerId=_TEST_CONSUMER_ID,
            operationName=_TEST_OP_NAME)
        self.test_check_request = messages.CheckRequest(
            operation=self.test_op)

    def test_should_fail_if_operation_is_not_set(self):
        expect(lambda: check_request.sign(messages.CheckRequest())).to(
            raise_error(ValueError))

    def test_should_fail_on_invalid_input(self):
        # Neither None nor an arbitrary object is a CheckRequest.
        for bad_input in (None, object()):
            expect(lambda: check_request.sign(bad_input)).to(
                raise_error(ValueError))

    def test_should_fail_if_operation_has_no_operation_name(self):
        nameless_op = messages.Operation(consumerId=_TEST_CONSUMER_ID)
        expect(
            lambda: check_request.sign(
                messages.CheckRequest(operation=nameless_op))
        ).to(raise_error(ValueError))

    def test_should_fail_if_operation_has_no_consumer_id(self):
        consumerless_op = messages.Operation(operationName=_TEST_OP_NAME)
        expect(
            lambda: check_request.sign(
                messages.CheckRequest(operation=consumerless_op))
        ).to(raise_error(ValueError))

    def test_should_sign_a_valid_check_request(self):
        # Passes as long as signing does not raise.
        check_request.sign(self.test_check_request)

    def test_should_change_signature_when_labels_are_added(self):
        signature_before = check_request.sign(self.test_check_request)
        label_values = {
            'key1': 'value1',
            'key2': 'value2'
        }
        self.test_op.labels = encoding.PyValueToMessage(
            messages.Operation.LabelsValue, label_values)
        signature_after = check_request.sign(self.test_check_request)
        expect(signature_after).not_to(equal(signature_before))

    def test_should_change_signature_when_metric_values_are_added(self):
        signature_before = check_request.sign(self.test_check_request)
        float_metric = metric_value.create(
            labels={
                'key1': 'value1',
                'key2': 'value2'
            },
            doubleValue=1.1,
        )
        self.test_op.metricValueSets = [
            messages.MetricValueSet(
                metricName='a_float',
                metricValues=[float_metric])
        ]
        signature_after = check_request.sign(self.test_check_request)
        expect(signature_after).not_to(equal(signature_before))

    def test_should_change_signature_quota_properties_are_specified(self):
        signature_before = check_request.sign(self.test_check_request)
        self.test_op.quotaProperties = messages.QuotaProperties()
        signature_after = check_request.sign(self.test_check_request)
        expect(signature_after).not_to(equal(signature_before))
95
96
class TestAggregatorCheck(unittest2.TestCase):
    """Tests for ``Aggregator.check`` argument validation and the initial miss.

    The aggregator under test is built with default ``caches.CheckOptions()``.
    """
    SERVICE_NAME = 'service.check'
    FAKE_OPERATION_ID = 'service.general.check'

    def setUp(self):
        # NOTE(review): the timer is created but not passed to the
        # aggregator, so the aggregator uses whatever timer it defaults to.
        self.timer = _DateTimeTimer()
        self.agg = check_request.Aggregator(
            self.SERVICE_NAME, caches.CheckOptions())

    def test_should_fail_if_req_is_bad(self):
        testf = lambda: self.agg.check(object())
        expect(testf).to(raise_error(ValueError))
        testf = lambda: self.agg.check(None)
        expect(testf).to(raise_error(ValueError))

    def test_should_fail_if_service_name_does_not_match(self):
        req = _make_test_request(self.SERVICE_NAME + '-will-not-match')
        testf = lambda: self.agg.check(req)
        expect(testf).to(raise_error(ValueError))

    def test_should_fail_if_check_request_is_missing(self):
        req = messages.ServicecontrolServicesCheckRequest(
            serviceName=self.SERVICE_NAME)
        testf = lambda: self.agg.check(req)
        expect(testf).to(raise_error(ValueError))

    def test_should_fail_if_operation_is_missing(self):
        req = messages.ServicecontrolServicesCheckRequest(
            serviceName=self.SERVICE_NAME,
            checkRequest=messages.CheckRequest())
        testf = lambda: self.agg.check(req)
        expect(testf).to(raise_error(ValueError))

    def test_should_return_none_initially_as_req_is_not_cached(self):
        # No response has been added yet, so the first check must miss.
        req = _make_test_request(self.SERVICE_NAME)
        expect(self.agg.check(req)).to(be_none)
136
137
class TestAggregatorThatCannotCache(unittest2.TestCase):
    """Tests for an ``Aggregator`` built with ``num_entries=-1``."""
    SERVICE_NAME = 'service.no_cache'
    FAKE_OPERATION_ID = 'service.no_cache.op_id'

    def setUp(self):
        # A negative num_entries means no cache is present.
        cacheless_options = caches.CheckOptions(num_entries=-1)
        self.agg = check_request.Aggregator(
            self.SERVICE_NAME, cacheless_options)

    def test_should_not_cache_responses(self):
        request = _make_test_request(self.SERVICE_NAME)
        response = messages.CheckResponse(
            operationId=self.FAKE_OPERATION_ID)
        cacheless = self.agg

        # Every check misses: before a response is added, after it is
        # added, and after the aggregator is cleared.
        expect(cacheless.check(request)).to(be_none)
        cacheless.add_response(request, response)
        expect(cacheless.check(request)).to(be_none)
        cacheless.clear()
        expect(cacheless.check(request)).to(be_none)

    def test_should_have_empty_flush_response(self):
        flushed = self.agg.flush()
        expect(len(flushed)).to(equal(0))

    def test_should_have_none_as_flush_interval(self):
        interval = self.agg.flush_interval
        expect(interval).to(be_none)
164
165
166
class _DateTimeTimer(object):
    """A manually driven clock for tests.

    Calling an instance returns its current ``time``.  The clock starts at
    1970-01-01 00:00:00 (naive) and moves only when ``tick`` is called, or
    on every call when ``auto`` is true.

    NOTE(review): a second ``_DateTimeTimer`` class is defined later in this
    module; being bound later, that one is what the name refers to once the
    module has finished importing.

    Args:
      auto (bool): when true, each call advances the clock by one second
        before returning it.
    """

    def __init__(self, auto=False):
        self.auto = auto
        # Built explicitly rather than via ``utcfromtimestamp(0)``, which is
        # deprecated in newer Python versions; the value is the same naive
        # datetime.
        self.time = datetime.datetime(1970, 1, 1)

    def __call__(self):
        """Returns the current time, ticking first when ``auto`` is set."""
        if self.auto:
            self.tick()
        return self.time

    def tick(self):
        """Advances the clock by exactly one second."""
        self.time += datetime.timedelta(seconds=1)
179
180
class TestCachingAggregator(unittest2.TestCase):
    """Tests for an ``Aggregator`` with a cache, driven by a manual timer.

    ``setUp`` configures a one second flush interval and a two second
    expiration; every ``tick`` of the timer moves time forward one second.
    """
    SERVICE_NAME = 'service.with_cache'
    FAKE_OPERATION_ID = 'service.with_cache.op_id'

    def setUp(self):
        self.timer = _DateTimeTimer()
        self.expiration = datetime.timedelta(seconds=2)
        cache_options = caches.CheckOptions(
            flush_interval=datetime.timedelta(seconds=1),
            expiration=self.expiration)
        self.agg = check_request.Aggregator(
            self.SERVICE_NAME, cache_options, timer=self.timer)

    # -- helpers (not collected as tests: names do not start with "test") --

    def _new_response(self, **fields):
        """Builds a ``CheckResponse`` carrying ``FAKE_OPERATION_ID``."""
        return messages.CheckResponse(
            operationId=self.FAKE_OPERATION_ID, **fields)

    def _advance(self, ticks):
        """Ticks the manual timer ``ticks`` times."""
        for _ in range(ticks):
            self.timer.tick()

    def _expect_miss(self, req):
        """Expects ``check`` to return ``None`` for ``req``."""
        expect(self.agg.check(req)).to(be_none)

    def _expect_hit(self, req, response):
        """Expects ``check`` to return ``response`` for ``req``."""
        expect(self.agg.check(req)).to(equal(response))

    def _expect_flush_count(self, count):
        """Expects ``flush`` to return ``count`` items."""
        expect(len(self.agg.flush())).to(equal(count))

    def _expect_resend_then_expiry(self, response):
        """Runs the miss / hit / resend / expiry sequence for ``response``."""
        req = _make_test_request(self.SERVICE_NAME)
        self._expect_miss(req)
        self.agg.add_response(req, response)
        self._expect_hit(req, response)

        # Now flush interval is reached, but not the response expiry
        self._advance(1)
        self._expect_miss(req)  # None signals the resend

        # Until expiry, the response will continue to be returned
        self._expect_hit(req, response)
        self._expect_hit(req, response)

        # Once expired the cached response is no longer returned
        self._advance(2)
        self._expect_miss(req)
        self._expect_miss(req)  # 2nd check is None as well

    # -- tests --

    def test_should_have_expiration_as_flush_interval(self):
        expect(self.agg.flush_interval).to(equal(self.expiration))

    def test_should_cache_responses(self):
        req = _make_test_request(self.SERVICE_NAME)
        response = self._new_response()
        self._expect_miss(req)
        self.agg.add_response(req, response)
        self._expect_hit(req, response)

    def test_should_not_cache_requests_with_important_operations(self):
        high = messages.Operation.ImportanceValueValuesEnum.HIGH
        req = _make_test_request(self.SERVICE_NAME, importance=high)
        response = self._new_response()
        self._expect_miss(req)
        self.agg.add_response(req, response)
        self._expect_miss(req)

    def test_signals_a_resend_on_1st_call_after_flush_interval(self):
        self._expect_resend_then_expiry(self._new_response())

    def test_signals_resend_on_1st_call_after_flush_interval_with_errors(self):
        not_found = messages.CheckError.CodeValueValuesEnum.NOT_FOUND
        failing_response = self._new_response(
            checkErrors=[messages.CheckError(code=not_found)])
        self._expect_resend_then_expiry(failing_response)

    def test_should_extend_expiration_on_receipt_of_a_response(self):
        req = _make_test_request(self.SERVICE_NAME)
        response = self._new_response()
        self._expect_miss(req)
        self.agg.add_response(req, response)
        self._expect_hit(req, response)

        # Now flush interval is reached, but not the response expiry
        self._advance(1)
        self._expect_miss(req)  # first response is null

        # until expiry, the response will continue to be returned
        self._expect_hit(req, response)
        self._expect_hit(req, response)

        # add a response as the request expires
        self._advance(1)
        self.agg.add_response(req, response)
        # it would have expired, but because the response was added it does not
        self._expect_hit(req, response)
        self._expect_hit(req, response)
        self._advance(1)  # now past the flush interval again
        self._expect_miss(req)
        self._expect_hit(req, response)

    def test_does_not_flush_request_that_has_not_been_updated(self):
        req = _make_test_request(self.SERVICE_NAME)
        response = self._new_response()
        self._expect_miss(req)
        self.agg.add_response(req, response)
        self._advance(1)  # now past the flush_interval
        self._expect_flush_count(0)  # nothing expired
        self._advance(2)  # now past expiry
        self._expect_miss(req)  # confirm nothing in cache
        self._expect_miss(req)  # confirm nothing in cache
        self._expect_flush_count(0)  # no cached check request

    def test_does_flush_requests_that_have_been_updated(self):
        req = _make_test_request(self.SERVICE_NAME)
        response = self._new_response()
        self._expect_miss(req)
        self.agg.add_response(req, response)
        self._expect_hit(req, response)
        self._advance(1)  # now past the flush_interval
        self._expect_flush_count(0)  # nothing expired
        self._advance(2)  # now past expiry
        self._expect_flush_count(1)  # got the cached check request

    def test_should_clear_requests(self):
        req = _make_test_request(self.SERVICE_NAME)
        response = self._new_response()
        self._expect_miss(req)
        self.agg.add_response(req, response)
        self._expect_hit(req, response)
        self.agg.clear()
        self._expect_miss(req)
        self._expect_flush_count(0)
338
339
# Consumer id and operation name used to populate the ``messages.Operation``
# fixtures built in this module (``TestSign`` and ``_make_test_request``).
_TEST_CONSUMER_ID = 'testConsumerID'
_TEST_OP_NAME = 'testOperationName'
342
343
def _make_test_request(service_name, importance=None):
    """Builds a ``ServicecontrolServicesCheckRequest`` for the tests.

    Args:
      service_name: value stored as the request's ``serviceName``.
      importance: importance to set on the operation; when ``None`` the
        ``LOW`` member of ``Operation.ImportanceValueValuesEnum`` is used.

    Returns:
      A request wrapping a ``CheckRequest`` whose operation carries the
      module's test consumer id and operation name.
    """
    effective_importance = (
        messages.Operation.ImportanceValueValuesEnum.LOW
        if importance is None else importance)
    operation = messages.Operation(
        consumerId=_TEST_CONSUMER_ID,
        operationName=_TEST_OP_NAME,
        importance=effective_importance)
    return messages.ServicecontrolServicesCheckRequest(
        serviceName=service_name,
        checkRequest=messages.CheckRequest(operation=operation))
356
357
_WANTED_USER_AGENT = label_descriptor.USER_AGENT
_START_OF_EPOCH = timestamp.to_rfc3339(datetime.datetime(1970, 1, 1, 0, 0, 0))
_TEST_SERVICE_NAME = 'a_service_name'


def _wanted_operation(label_values, **overrides):
    """Builds the ``messages.Operation`` expected by an ``_INFO_TESTS`` row.

    Every expected operation shares LOW importance, the 'an_op_id' /
    'an_op_name' identifiers and start/end times at the start of the epoch.

    Args:
      label_values (dict): label keys and values, converted to an
        ``Operation.LabelsValue`` message.
      **overrides: extra ``Operation`` fields, e.g. ``consumerId``.
    """
    return messages.Operation(
        importance=messages.Operation.ImportanceValueValuesEnum.LOW,
        labels=encoding.PyValueToMessage(
            messages.Operation.LabelsValue, label_values),
        operationId='an_op_id',
        operationName='an_op_name',
        startTime=_START_OF_EPOCH,
        endTime=_START_OF_EPOCH,
        **overrides)


# Pairs of (info, operation expected from info.as_check_request).
_INFO_TESTS = [
    # No api key: the expected operation has no consumerId.
    (
        check_request.Info(
            operation_id='an_op_id',
            operation_name='an_op_name',
            referer='a_referer',
            service_name=_TEST_SERVICE_NAME),
        _wanted_operation({
            'servicecontrol.googleapis.com/user_agent': _WANTED_USER_AGENT,
            'servicecontrol.googleapis.com/referer': 'a_referer'
        }),
    ),
    # Valid api key: the consumerId is derived from the key.
    (
        check_request.Info(
            api_key='an_api_key',
            api_key_valid=True,
            operation_id='an_op_id',
            operation_name='an_op_name',
            referer='a_referer',
            service_name=_TEST_SERVICE_NAME),
        _wanted_operation(
            {
                'servicecontrol.googleapis.com/user_agent': _WANTED_USER_AGENT,
                'servicecontrol.googleapis.com/referer': 'a_referer'
            },
            consumerId='api_key:an_api_key'),
    ),
    # Invalid api key plus a consumer project id: the consumerId is derived
    # from the project, and the client ip shows up as a caller_ip label.
    (
        check_request.Info(
            api_key='an_api_key',
            api_key_valid=False,
            client_ip='127.0.0.1',
            consumer_project_id='project_id',
            operation_id='an_op_id',
            operation_name='an_op_name',
            referer='a_referer',
            service_name=_TEST_SERVICE_NAME),
        _wanted_operation(
            {
                'servicecontrol.googleapis.com/caller_ip': '127.0.0.1',
                'servicecontrol.googleapis.com/user_agent': _WANTED_USER_AGENT,
                'servicecontrol.googleapis.com/referer': 'a_referer'
            },
            consumerId='project:project_id'),
    ),
]

# Infos each lacking one of operation_id, operation_name or service_name.
_INCOMPLETE_INFO_TESTS = [
    # operation_id omitted
    check_request.Info(
        operation_name='an_op_name',
        service_name=_TEST_SERVICE_NAME),
    # operation_name omitted
    check_request.Info(
        operation_id='an_op_id',
        service_name=_TEST_SERVICE_NAME),
    # service_name omitted
    check_request.Info(
        operation_id='an_op_id',
        operation_name='an_op_name'),
]
431
432
class TestInfo(unittest2.TestCase):
    """Tests for ``check_request.Info`` and its ``as_check_request``."""

    def test_should_construct_with_no_args(self):
        blank_info = check_request.Info()
        expect(blank_info).not_to(be_none)

    def test_should_convert_using_as_check_request(self):
        # The timer never ticks here, so every conversion sees the same time.
        fixed_clock = _DateTimeTimer()
        for info, wanted_operation in _INFO_TESTS:
            converted = info.as_check_request(timer=fixed_clock)
            expect(converted.checkRequest.operation).to(
                equal(wanted_operation))
            expect(converted.serviceName).to(equal(_TEST_SERVICE_NAME))

    def test_should_fail_as_check_request_on_incomplete_info(self):
        fixed_clock = _DateTimeTimer()
        for incomplete in _INCOMPLETE_INFO_TESTS:
            # The lambda is invoked within this iteration, so it always
            # sees the current ``incomplete``.
            expect(
                lambda: incomplete.as_check_request(timer=fixed_clock)
            ).to(raise_error(ValueError))
450
451
class TestConvertResponse(unittest2.TestCase):
    """Tests for the code and message from ``check_request.convert_response``."""
    PROJECT_ID = 'test_convert_response'

    def test_should_be_ok_with_no_errors(self):
        clean_response = messages.CheckResponse()
        code, message, _ = check_request.convert_response(
            clean_response, self.PROJECT_ID)
        expect(code).to(equal(httplib.OK))
        expect(message).to(equal(''))

    def test_should_include_project_id_in_error_text_when_needed(self):
        deleted = messages.CheckError.CodeValueValuesEnum.PROJECT_DELETED
        resp = messages.CheckResponse(
            checkErrors=[messages.CheckError(code=deleted)])
        code, got, _ = check_request.convert_response(resp, self.PROJECT_ID)
        want = 'Project %s has been deleted' % (self.PROJECT_ID,)
        expect(code).to(equal(httplib.FORBIDDEN))
        expect(got).to(equal(want))

    def test_should_include_detail_in_error_text_when_needed(self):
        detail = 'details, details, details'
        blocked = messages.CheckError.CodeValueValuesEnum.IP_ADDRESS_BLOCKED
        resp = messages.CheckResponse(
            checkErrors=[messages.CheckError(code=blocked, detail=detail)])
        code, got, _ = check_request.convert_response(resp, self.PROJECT_ID)
        expect(code).to(equal(httplib.FORBIDDEN))
        expect(got).to(equal(detail))
485
486
class _DateTimeTimer(object):
    """A hand-cranked clock for tests.

    Instances are callable and return their current ``time``, which starts
    at 1970-01-01 00:00:00 (naive).  Time moves one second per ``tick``;
    with ``auto`` set, every call ticks before returning.
    """

    def __init__(self, auto=False):
        self.time = datetime.datetime(year=1970, month=1, day=1)
        self.auto = auto

    def __call__(self):
        """Returns the current time, after an automatic tick if enabled."""
        if not self.auto:
            return self.time
        self.tick()
        return self.time

    def tick(self):
        """Moves the clock forward by one second."""
        one_second = datetime.timedelta(seconds=1)
        self.time = self.time + one_second
499