// Copyright 2019 DeepMap, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14package testutil
15
// This is a set of fluent request builders for tests, which help us to
// simplify constructing and unmarshaling test objects. For example, to post
// a body and decode the response, you would do something like:
//
//   // t is *testing.T, from a unit test
//   // e is *echo.Echo
//   var body RequestBody
//   var response ResponseBody
//   result := NewRequest().Post("/path").WithJsonBody(body).Go(t, e)
//   err := result.UnmarshalBodyToObject(&response)
26import (
27	"bytes"
28	"encoding/json"
29	"fmt"
30	"io"
31	"net/http"
32	"net/http/httptest"
33	"strings"
34	"testing"
35
36	"github.com/labstack/echo/v4"
37)
38
// NewRequest returns an empty RequestBuilder with its header map already
// allocated. Chain the builder methods to describe the request, then call Go
// or GoWithHTTPHandler to execute it.
func NewRequest() *RequestBuilder {
	return &RequestBuilder{
		Headers: make(map[string]string),
	}
}

// RequestBuilder caches request settings as we build up the request. Every
// builder method mutates the receiver and returns it, so calls can be chained.
type RequestBuilder struct {
	// Method is the HTTP method, set by WithMethod and its shortcuts.
	Method string
	// Path is the request target, set together with Method.
	Path string
	// Headers maps a header name to its single value; setting the same name
	// again replaces the previous value.
	Headers map[string]string
	// Body is the raw request body. It stays nil when no body was supplied.
	Body []byte
	// Error holds a failure recorded while building the request (currently
	// only a JSON marshalling failure from WithJsonBody).
	Error error
	// Cookies are the cookies collected by WithCookie, in insertion order.
	Cookies []*http.Cookie
}

// Path operations

// WithMethod sets the HTTP method and the request path in one call.
func (r *RequestBuilder) WithMethod(method string, path string) *RequestBuilder {
	r.Method = method
	r.Path = path
	return r
}

// Get configures a GET request for path.
func (r *RequestBuilder) Get(path string) *RequestBuilder {
	return r.WithMethod(http.MethodGet, path)
}

// Post configures a POST request for path.
func (r *RequestBuilder) Post(path string) *RequestBuilder {
	return r.WithMethod(http.MethodPost, path)
}

// Put configures a PUT request for path.
func (r *RequestBuilder) Put(path string) *RequestBuilder {
	return r.WithMethod(http.MethodPut, path)
}

// Patch configures a PATCH request for path.
func (r *RequestBuilder) Patch(path string) *RequestBuilder {
	return r.WithMethod(http.MethodPatch, path)
}

// Delete configures a DELETE request for path.
func (r *RequestBuilder) Delete(path string) *RequestBuilder {
	return r.WithMethod(http.MethodDelete, path)
}

// Header operations

// WithHeader sets header to value, replacing any value stored earlier under
// the same name.
func (r *RequestBuilder) WithHeader(header, value string) *RequestBuilder {
	// A builder created as a zero value or struct literal (rather than via
	// NewRequest) has a nil map; allocate it so the write cannot panic.
	if r.Headers == nil {
		r.Headers = make(map[string]string)
	}
	r.Headers[header] = value
	return r
}

// WithHost sets the Host header.
func (r *RequestBuilder) WithHost(value string) *RequestBuilder {
	return r.WithHeader("Host", value)
}

// WithContentType sets the Content-Type header.
func (r *RequestBuilder) WithContentType(value string) *RequestBuilder {
	return r.WithHeader("Content-Type", value)
}

// WithJsonContentType sets the Content-Type header to application/json.
func (r *RequestBuilder) WithJsonContentType() *RequestBuilder {
	return r.WithContentType("application/json")
}

// WithAccept sets the Accept header.
func (r *RequestBuilder) WithAccept(value string) *RequestBuilder {
	return r.WithHeader("Accept", value)
}

// WithAcceptJson sets the Accept header to application/json.
func (r *RequestBuilder) WithAcceptJson() *RequestBuilder {
	return r.WithAccept("application/json")
}

// Request body operations

// WithBody uses body verbatim as the request body. The slice is stored, not
// copied, and no Content-Type is set.
func (r *RequestBuilder) WithBody(body []byte) *RequestBuilder {
	r.Body = body
	return r
}

// WithJsonBody marshals obj to JSON, uses the result as the request body and
// sets Content-Type: application/json. If marshalling fails, the body is left
// nil and the failure is stored in Error, wrapping the underlying error.
func (r *RequestBuilder) WithJsonBody(obj interface{}) *RequestBuilder {
	body, err := json.Marshal(obj)
	if err != nil {
		r.Error = fmt.Errorf("failed to marshal json object: %w", err)
	}
	r.Body = body
	return r.WithJsonContentType()
}

// Cookie operations

// WithCookie appends c to the cookies sent with the request.
func (r *RequestBuilder) WithCookie(c *http.Cookie) *RequestBuilder {
	r.Cookies = append(r.Cookies, c)
	return r
}

// WithCookieNameValue appends a cookie that carries only a name and a value.
func (r *RequestBuilder) WithCookieNameValue(name, value string) *RequestBuilder {
	return r.WithCookie(&http.Cookie{Name: name, Value: value})
}
135
136// GoWithHTTPHandler performs the request, it takes a pointer to a testing context
137// to print messages, and a http handler for request handling.
138func (r *RequestBuilder) GoWithHTTPHandler(t *testing.T, handler http.Handler) *CompletedRequest {
139	if r.Error != nil {
140		// Fail the test if we had an error
141		t.Errorf("error constructing request: %s", r.Error)
142		return nil
143	}
144	var bodyReader io.Reader
145	if r.Body != nil {
146		bodyReader = bytes.NewReader(r.Body)
147	}
148
149	req := httptest.NewRequest(r.Method, r.Path, bodyReader)
150	for h, v := range r.Headers {
151		req.Header.Add(h, v)
152	}
153	if host, ok := r.Headers["Host"]; ok {
154		req.Host = host
155	}
156	for _, c := range r.Cookies {
157		req.AddCookie(c)
158	}
159
160	rec := httptest.NewRecorder()
161	handler.ServeHTTP(rec, req)
162
163	return &CompletedRequest{
164		Recorder: rec,
165	}
166}
167
168// Go performs the request, it takes a pointer to a testing context
169// to print messages, and a pointer to an echo context for request handling.
170func (r *RequestBuilder) Go(t *testing.T, e *echo.Echo) *CompletedRequest {
171	return r.GoWithHTTPHandler(t, e)
172}
173
// CompletedRequest is the result of calling Go() on the request builder. It
// wraps the ResponseRecorder with some helper functions for inspecting and
// decoding the response.
type CompletedRequest struct {
	// Recorder holds the response captured while the handler ran.
	Recorder *httptest.ResponseRecorder

	// When set to true, decoders will be more strict. In the default JSON
	// decoder, unknown fields will cause errors.
	Strict bool
}

// DisallowUnknownFields switches the completed request into strict decoding
// mode by setting Strict to true.
func (c *CompletedRequest) DisallowUnknownFields() {
	c.Strict = true
}
187
188// This function takes a destination object as input, and unmarshals the object
189// in the response based on the Content-Type header.
190func (c *CompletedRequest) UnmarshalBodyToObject(obj interface{}) error {
191	ctype := c.Recorder.Header().Get("Content-Type")
192
193	// Content type can have an annotation after ;
194	contentParts := strings.Split(ctype, ";")
195	content := strings.TrimSpace(contentParts[0])
196	handler := getHandler(content)
197	if handler == nil {
198		return fmt.Errorf("unhandled content: %s", content)
199	}
200
201	return handler(ctype, c.Recorder.Body, obj, c.Strict)
202}
203
204// This function assumes that the response contains JSON and unmarshals it
205// into the specified object.
206func (c *CompletedRequest) UnmarshalJsonToObject(obj interface{}) error {
207	return json.Unmarshal(c.Recorder.Body.Bytes(), obj)
208}
209
210// Shortcut for response code
211func (c *CompletedRequest) Code() int {
212	return c.Recorder.Code
213}
214