1// Copyright 2015 Google Inc. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Package martiantest provides helper utilities for testing
16// modifiers.
17package martiantest
18
import (
	"net/http"
	"sync"
	"sync/atomic"
)
23
// Modifier keeps track of the number of requests and responses it has modified
// and can be configured to return errors or run custom functions.
//
// The counters are updated atomically and the configured errors and functions
// are guarded by a mutex. A test may therefore reconfigure or Reset a Modifier
// while other goroutines (for example, a proxy under test) call ModifyRequest
// and ModifyResponse. The zero value is ready to use. A Modifier must not be
// copied after first use.
type Modifier struct {
	reqcount int32 // atomic
	rescount int32 // atomic

	mu      sync.Mutex // guards reqerr, reserr, reqfunc and resfunc
	reqerr  error
	reserr  error
	reqfunc func(*http.Request)
	resfunc func(*http.Response)
}

// NewModifier returns a new test modifier.
func NewModifier() *Modifier {
	return &Modifier{}
}

// RequestCount returns the number of requests modified.
func (m *Modifier) RequestCount() int32 {
	return atomic.LoadInt32(&m.reqcount)
}

// ResponseCount returns the number of responses modified.
func (m *Modifier) ResponseCount() int32 {
	return atomic.LoadInt32(&m.rescount)
}

// RequestModified reports whether at least one request has been modified.
func (m *Modifier) RequestModified() bool {
	return m.RequestCount() != 0
}

// ResponseModified reports whether at least one response has been modified.
func (m *Modifier) ResponseModified() bool {
	return m.ResponseCount() != 0
}

// RequestError overrides the error returned by ModifyRequest. A nil err makes
// ModifyRequest succeed again.
func (m *Modifier) RequestError(err error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.reqerr = err
}

// ResponseError overrides the error returned by ModifyResponse. A nil err
// makes ModifyResponse succeed again.
func (m *Modifier) ResponseError(err error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.reserr = err
}

// RequestFunc sets a function to run during ModifyRequest. A nil reqfunc
// disables the callback.
func (m *Modifier) RequestFunc(reqfunc func(req *http.Request)) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.reqfunc = reqfunc
}

// ResponseFunc sets a function to run during ModifyResponse. A nil resfunc
// disables the callback.
func (m *Modifier) ResponseFunc(resfunc func(res *http.Response)) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.resfunc = resfunc
}

// ModifyRequest increments the request count and runs the function set by
// RequestFunc (if any) with req. It then returns the error set by
// RequestError, which is nil unless configured.
func (m *Modifier) ModifyRequest(req *http.Request) error {
	atomic.AddInt32(&m.reqcount, 1)

	m.mu.Lock()
	reqfunc := m.reqfunc
	m.mu.Unlock()

	// Invoke the callback without holding the lock so that it may itself
	// reconfigure this Modifier (e.g. call RequestError) without deadlocking.
	if reqfunc != nil {
		reqfunc(req)
	}

	// Read the error only after the callback has run, so an error set by the
	// callback is the one returned from this call.
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.reqerr
}

// ModifyResponse increments the response count and runs the function set by
// ResponseFunc (if any) with res. It then returns the error set by
// ResponseError, which is nil unless configured.
func (m *Modifier) ModifyResponse(res *http.Response) error {
	atomic.AddInt32(&m.rescount, 1)

	m.mu.Lock()
	resfunc := m.resfunc
	m.mu.Unlock()

	// Invoke the callback without holding the lock so that it may itself
	// reconfigure this Modifier (e.g. call ResponseError) without deadlocking.
	if resfunc != nil {
		resfunc(res)
	}

	// Read the error only after the callback has run, so an error set by the
	// callback is the one returned from this call.
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.reserr
}

// Reset resets the request and response counts to zero and clears the
// custom functions and the modifier errors.
func (m *Modifier) Reset() {
	atomic.StoreInt32(&m.reqcount, 0)
	atomic.StoreInt32(&m.rescount, 0)

	m.mu.Lock()
	defer m.mu.Unlock()

	m.reqfunc = nil
	m.resfunc = nil

	m.reqerr = nil
	m.reserr = nil
}
114
// Matcher is a stubbed matcher used in tests. MatchRequest and MatchResponse
// ignore their arguments and return the values configured with
// RequestEvaluatesTo and ResponseEvaluatesTo. The zero value returns false
// from both methods; use NewMatcher for a matcher that returns true.
type Matcher struct {
	resval bool
	reqval bool
}

// NewMatcher returns a pointer to martiantest.Matcher with the return values
// for MatchRequest and MatchResponse initialized to true.
func NewMatcher() *Matcher {
	return &Matcher{resval: true, reqval: true}
}

// ResponseEvaluatesTo sets the value returned by MatchResponse.
func (tm *Matcher) ResponseEvaluatesTo(value bool) {
	tm.resval = value
}

// RequestEvaluatesTo sets the value returned by MatchRequest.
func (tm *Matcher) RequestEvaluatesTo(value bool) {
	tm.reqval = value
}

// MatchRequest returns the stubbed request value; the request is ignored.
func (tm *Matcher) MatchRequest(*http.Request) bool {
	return tm.reqval
}

// MatchResponse returns the stubbed response value; the response is ignored.
func (tm *Matcher) MatchResponse(*http.Response) bool {
	return tm.resval
}
146