1package ztest
2
import (
	"bytes"
	"encoding/json"
	"io"
	"mime/multipart"
	"net/http"
	"net/http/httptest"
	"sort"
	"testing"

	"zgo.at/zstd/zstring"
)
14
15// Code checks if the error code in the recoder matches the desired one, and
16// will stop the test with t.Fatal() if it doesn't.
17func Code(t *testing.T, recorder *httptest.ResponseRecorder, want int) {
18	t.Helper()
19	if recorder.Code != want {
20		t.Errorf("wrong response code\nhave: %d %s\nwant: %d %s\nbody: %v",
21			recorder.Code, http.StatusText(recorder.Code),
22			want, http.StatusText(want),
23			zstring.ElideLeft(recorder.Body.String(), 500))
24	}
25}
26
// DefaultHost is the Host that NewRequest() assigns when the request's Host is
// empty or httptest's placeholder "example.com".
var DefaultHost = "example.com"

// DefaultContentType is the Content-Type header that NewRequest() sets when the
// request doesn't carry one yet.
var DefaultContentType = "application/json"
32
33// NewRequest creates a new request with some sensible defaults set.
34func NewRequest(method, target string, body io.Reader) *http.Request {
35	r := httptest.NewRequest(method, target, body)
36	if r.Host == "" || r.Host == "example.com" {
37		r.Host = DefaultHost
38	}
39	if r.Header.Get("Content-Type") == "" {
40		r.Header.Set("Content-Type", DefaultContentType)
41	}
42	return r
43}
44
// Body encodes a as JSON and returns it as a *bytes.Reader, which is handy as a
// request body. For example:
//
//   ztest.NewRequest("POST", "/", ztest.Body(someStruct{
//       Foo: "bar",
//   }))
//
// It panics if a cannot be marshalled; this is intended for tests, where such a
// failure is a programming error.
func Body(a interface{}) *bytes.Reader {
	encoded, err := json.Marshal(a)
	if err == nil {
		return bytes.NewReader(encoded)
	}
	panic(err)
}
58
// HTTP serves r with h and returns the recorded response.
//
// If r is nil, a GET request for "/" is created with httptest.NewRequest, so h
// receives a proper incoming server request (Path, RequestURI, Host and
// RemoteAddr are all set).
//
// For example:
//
//   rr := ztest.HTTP(t, nil, MyHandler)
//
// Or for a POST request:
//
//   r := ztest.NewRequest("POST", "/v1/email", nil)
//   rr := ztest.HTTP(t, r, MyHandler)
func HTTP(t *testing.T, r *http.Request, h http.Handler) *httptest.ResponseRecorder {
	t.Helper()

	if r == nil {
		// http.NewRequest builds a *client* request; with an empty URL the
		// handler would see an empty Path and RequestURI. httptest.NewRequest
		// builds a server-side request and cannot fail for this static target.
		r = httptest.NewRequest(http.MethodGet, "/", nil)
	}

	rr := httptest.NewRecorder()
	h.ServeHTTP(rr, r)
	return rr
}
87
// MultipartForm writes the keys and values from params to a multipart form.
//
// The first (optional) parameter is used for "multipart/form-data" key/value
// fields; the optional second parameter is used for creating file parts, where
// the key is used as both the field name and the file name. Any further maps are
// ignored. Parts are written in sorted key order, so the output is
// deterministic apart from the random boundary.
//
// Don't forget to set the Content-Type from the return value:
//
//   r.Header.Set("Content-Type", contentType)
func MultipartForm(params ...map[string]string) (b *bytes.Buffer, contentType string, err error) {
	b = &bytes.Buffer{}
	w := multipart.NewWriter(b)

	// Guard the index: calling MultipartForm() with no arguments used to panic.
	if len(params) > 0 {
		err = writeMultipartParts(params[0], w.CreateFormField)
		if err != nil {
			return nil, "", err
		}
	}
	if len(params) > 1 {
		err = writeMultipartParts(params[1], func(name string) (io.Writer, error) {
			return w.CreateFormFile(name, name)
		})
		if err != nil {
			return nil, "", err
		}
	}

	// Close writes the trailing boundary; without it the body is incomplete.
	if err = w.Close(); err != nil {
		return nil, "", err
	}
	return b, w.FormDataContentType(), nil
}

// writeMultipartParts creates one part per key of parts (in sorted key order)
// with create, and writes the key's value as the part's content.
func writeMultipartParts(parts map[string]string, create func(name string) (io.Writer, error)) error {
	keys := make([]string, 0, len(parts))
	for k := range parts {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	for _, k := range keys {
		part, err := create(k)
		if err != nil {
			return err
		}
		if _, err := part.Write([]byte(parts[k])); err != nil {
			return err
		}
	}
	return nil
}
130