1// Copyright 2015 Google Inc. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package body
16
17import (
18	"bytes"
19	"encoding/base64"
20	"fmt"
21	"io"
22	"io/ioutil"
23	"mime/multipart"
24	"net/http"
25	"strings"
26	"testing"
27
28	"github.com/google/martian/v3/messageview"
29	"github.com/google/martian/v3/parse"
30	"github.com/google/martian/v3/proxyutil"
31)
32
33func TestBodyModifier(t *testing.T) {
34	mod := NewModifier([]byte("text"), "text/plain")
35
36	req, err := http.NewRequest("GET", "/", strings.NewReader(""))
37	if err != nil {
38		t.Fatalf("NewRequest(): got %v, want no error", err)
39	}
40	req.Header.Set("Content-Encoding", "gzip")
41
42	if err := mod.ModifyRequest(req); err != nil {
43		t.Fatalf("ModifyRequest(): got %v, want no error", err)
44	}
45
46	if got, want := req.Header.Get("Content-Type"), "text/plain"; got != want {
47		t.Errorf("req.Header.Get(%q): got %v, want %v", "Content-Type", got, want)
48	}
49	if got, want := req.ContentLength, int64(len([]byte("text"))); got != want {
50		t.Errorf("req.ContentLength: got %d, want %d", got, want)
51	}
52	if got, want := req.Header.Get("Content-Encoding"), ""; got != want {
53		t.Errorf("req.Header.Get(%q): got %q, want %q", "Content-Encoding", got, want)
54	}
55
56	got, err := ioutil.ReadAll(req.Body)
57	if err != nil {
58		t.Fatalf("ioutil.ReadAll(): got %v, want no error", err)
59	}
60	req.Body.Close()
61
62	if want := []byte("text"); !bytes.Equal(got, want) {
63		t.Errorf("res.Body: got %q, want %q", got, want)
64	}
65
66	res := proxyutil.NewResponse(200, nil, req)
67	res.Header.Set("Content-Encoding", "gzip")
68
69	if err := mod.ModifyResponse(res); err != nil {
70		t.Fatalf("ModifyResponse(): got %v, want no error", err)
71	}
72
73	if got, want := res.Header.Get("Content-Type"), "text/plain"; got != want {
74		t.Errorf("res.Header.Get(%q): got %v, want %v", "Content-Type", got, want)
75	}
76	if got, want := res.ContentLength, int64(len([]byte("text"))); got != want {
77		t.Errorf("res.ContentLength: got %d, want %d", got, want)
78	}
79	if got, want := res.Header.Get("Content-Encoding"), ""; got != want {
80		t.Errorf("res.Header.Get(%q): got %q, want %q", "Content-Encoding", got, want)
81	}
82
83	got, err = ioutil.ReadAll(res.Body)
84	if err != nil {
85		t.Fatalf("ioutil.ReadAll(): got %v, want no error", err)
86	}
87	res.Body.Close()
88
89	if want := []byte("text"); !bytes.Equal(got, want) {
90		t.Errorf("res.Body: got %q, want %q", got, want)
91	}
92}
93func TestRangeHeaderRequestSingleRange(t *testing.T) {
94	mod := NewModifier([]byte("0123456789"), "text/plain")
95
96	req, err := http.NewRequest("GET", "/", strings.NewReader(""))
97	if err != nil {
98		t.Fatalf("NewRequest(): got %v, want no error", err)
99	}
100	req.Header.Set("Range", "bytes=1-4")
101
102	res := proxyutil.NewResponse(200, nil, req)
103
104	if err := mod.ModifyResponse(res); err != nil {
105		t.Fatalf("ModifyResponse(): got %v, want no error", err)
106	}
107
108	if got, want := res.StatusCode, http.StatusPartialContent; got != want {
109		t.Errorf("res.Status: got %v, want %v", got, want)
110	}
111	if got, want := res.ContentLength, int64(len([]byte("1234"))); got != want {
112		t.Errorf("res.ContentLength: got %d, want %d", got, want)
113	}
114	if got, want := res.Header.Get("Content-Range"), "bytes 1-4/10"; got != want {
115		t.Errorf("res.Header.Get(%q): got %q, want %q", "Content-Encoding", got, want)
116	}
117
118	got, err := ioutil.ReadAll(res.Body)
119	if err != nil {
120		t.Fatalf("ioutil.ReadAll(): got %v, want no error", err)
121	}
122	res.Body.Close()
123
124	if want := []byte("1234"); !bytes.Equal(got, want) {
125		t.Errorf("res.Body: got %q, want %q", got, want)
126	}
127}
128
129func TestRangeHeaderRequestSingleRangeHasAllTheBytes(t *testing.T) {
130	mod := NewModifier([]byte("0123456789"), "text/plain")
131
132	req, err := http.NewRequest("GET", "/", strings.NewReader(""))
133	if err != nil {
134		t.Fatalf("NewRequest(): got %v, want no error", err)
135	}
136	req.Header.Set("Range", "bytes=0-")
137
138	res := proxyutil.NewResponse(200, nil, req)
139
140	if err := mod.ModifyResponse(res); err != nil {
141		t.Fatalf("ModifyResponse(): got %v, want no error", err)
142	}
143
144	if got, want := res.StatusCode, http.StatusPartialContent; got != want {
145		t.Errorf("res.Status: got %v, want %v", got, want)
146	}
147	if got, want := res.ContentLength, int64(len([]byte("0123456789"))); got != want {
148		t.Errorf("res.ContentLength: got %d, want %d", got, want)
149	}
150	if got, want := res.Header.Get("Content-Range"), "bytes 0-9/10"; got != want {
151		t.Errorf("res.Header.Get(%q): got %q, want %q", "Content-Encoding", got, want)
152	}
153
154	got, err := ioutil.ReadAll(res.Body)
155	if err != nil {
156		t.Fatalf("ioutil.ReadAll(): got %v, want no error", err)
157	}
158	res.Body.Close()
159
160	if want := []byte("0123456789"); !bytes.Equal(got, want) {
161		t.Errorf("res.Body: got %q, want %q", got, want)
162	}
163}
164
165func TestRangeNoEndingIndexSpecified(t *testing.T) {
166	mod := NewModifier([]byte("0123456789"), "text/plain")
167
168	req, err := http.NewRequest("GET", "/", strings.NewReader(""))
169	if err != nil {
170		t.Fatalf("NewRequest(): got %v, want no error", err)
171	}
172	req.Header.Set("Range", "bytes=8-")
173
174	res := proxyutil.NewResponse(200, nil, req)
175
176	if err := mod.ModifyResponse(res); err != nil {
177		t.Fatalf("ModifyResponse(): got %v, want no error", err)
178	}
179
180	if got, want := res.StatusCode, http.StatusPartialContent; got != want {
181		t.Errorf("res.Status: got %v, want %v", got, want)
182	}
183	if got, want := res.ContentLength, int64(len([]byte("89"))); got != want {
184		t.Errorf("res.ContentLength: got %d, want %d", got, want)
185	}
186	if got, want := res.Header.Get("Content-Range"), "bytes 8-9/10"; got != want {
187		t.Errorf("res.Header.Get(%q): got %q, want %q", "Content-Encoding", got, want)
188	}
189}
190
191func TestRangeHeaderMultipartRange(t *testing.T) {
192	mod := NewModifier([]byte("0123456789"), "text/plain")
193	bndry := "3d6b6a416f9b5"
194	mod.SetBoundary(bndry)
195
196	req, err := http.NewRequest("GET", "/", strings.NewReader(""))
197	if err != nil {
198		t.Fatalf("NewRequest(): got %v, want no error", err)
199	}
200	req.Header.Set("Range", "bytes=1-4, 7-9")
201
202	res := proxyutil.NewResponse(200, nil, req)
203	if err := mod.ModifyResponse(res); err != nil {
204		t.Fatalf("ModifyResponse(): got %v, want no error", err)
205	}
206
207	if got, want := res.StatusCode, http.StatusPartialContent; got != want {
208		t.Errorf("res.Status: got %v, want %v", got, want)
209	}
210
211	if got, want := res.Header.Get("Content-Type"), "multipart/byteranges; boundary=3d6b6a416f9b5"; got != want {
212		t.Errorf("res.Header.Get(%q): got %q, want %q", "Content-Type", got, want)
213	}
214
215	mv := messageview.New()
216	if err := mv.SnapshotResponse(res); err != nil {
217		t.Fatalf("mv.SnapshotResponse(res): got %v, want no error", err)
218	}
219
220	br, err := mv.BodyReader()
221	if err != nil {
222		t.Fatalf("mv.BodyReader(): got %v, want no error", err)
223	}
224
225	mpr := multipart.NewReader(br, bndry)
226	prt1, err := mpr.NextPart()
227	if err != nil {
228		t.Fatalf("mpr.NextPart(): got %v, want no error", err)
229	}
230	defer prt1.Close()
231
232	if got, want := prt1.Header.Get("Content-Type"), "text/plain"; got != want {
233		t.Errorf("prt1.Header.Get(%q): got %q, want %q", "Content-Type", got, want)
234	}
235
236	if got, want := prt1.Header.Get("Content-Range"), "bytes 1-4/10"; got != want {
237		t.Errorf("prt1.Header.Get(%q): got %q, want %q", "Content-Range", got, want)
238	}
239
240	prt1b, err := ioutil.ReadAll(prt1)
241	if err != nil {
242		t.Errorf("ioutil.Readall(prt1): got %v, want no error", err)
243	}
244
245	if got, want := string(prt1b), "1234"; got != want {
246		t.Errorf("prt1 body: got %s, want %s", got, want)
247	}
248
249	prt2, err := mpr.NextPart()
250	if err != nil {
251		t.Fatalf("mpr.NextPart(): got %v, want no error", err)
252	}
253	defer prt2.Close()
254
255	if got, want := prt2.Header.Get("Content-Type"), "text/plain"; got != want {
256		t.Errorf("prt2.Header.Get(%q): got %q, want %q", "Content-Type", got, want)
257	}
258
259	if got, want := prt2.Header.Get("Content-Range"), "bytes 7-9/10"; got != want {
260		t.Errorf("prt2.Header.Get(%q): got %q, want %q", "Content-Range", got, want)
261	}
262
263	prt2b, err := ioutil.ReadAll(prt2)
264	if err != io.ErrUnexpectedEOF && err != nil {
265		t.Errorf("ioutil.Readall(prt2): got %v, want no error", err)
266	}
267
268	if got, want := string(prt2b), "789"; got != want {
269		t.Errorf("prt2 body: got %s, want %s", got, want)
270	}
271
272	_, err = mpr.NextPart()
273	if err == nil {
274		t.Errorf("mpr.NextPart: want io.EOF, got no error")
275	}
276	if err != io.EOF {
277		t.Errorf("mpr.NextPart: want io.EOF, got %v", err)
278	}
279}
280
281func TestModifierFromJSON(t *testing.T) {
282	data := base64.StdEncoding.EncodeToString([]byte("data"))
283	msg := fmt.Sprintf(`{
284	  "body.Modifier":{
285		  "scope": ["response"],
286  	  "contentType": "text/plain",
287	  	"body": %q
288    }
289	}`, data)
290
291	r, err := parse.FromJSON([]byte(msg))
292	if err != nil {
293		t.Fatalf("parse.FromJSON(): got %v, want no error", err)
294	}
295
296	resmod := r.ResponseModifier()
297
298	if resmod == nil {
299		t.Fatalf("resmod: got nil, want not nil")
300	}
301
302	req, err := http.NewRequest("GET", "/", strings.NewReader(""))
303	if err != nil {
304		t.Fatalf("NewRequest(): got %v, want no error", err)
305	}
306
307	res := proxyutil.NewResponse(200, nil, req)
308	if err := resmod.ModifyResponse(res); err != nil {
309		t.Fatalf("resmod.ModifyResponse(): got %v, want no error", err)
310	}
311
312	if got, want := res.Header.Get("Content-Type"), "text/plain"; got != want {
313		t.Errorf("res.Header.Get(%q): got %v, want %v", "Content-Type", got, want)
314	}
315
316	got, err := ioutil.ReadAll(res.Body)
317	if err != nil {
318		t.Fatalf("ioutil.ReadAll(): got %v, want no error", err)
319	}
320	res.Body.Close()
321
322	if want := []byte("data"); !bytes.Equal(got, want) {
323		t.Errorf("res.Body: got %q, want %q", got, want)
324	}
325}
326