1// Licensed to Elasticsearch B.V. under one or more contributor
2// license agreements. See the NOTICE file distributed with
3// this work for additional information regarding copyright
4// ownership. Elasticsearch B.V. licenses this file to you under
5// the Apache License, Version 2.0 (the "License"); you may
6// not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9//     http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18package apmrestful_test
19
20import (
21	"net"
22	"net/http"
23	"net/http/httptest"
24	"net/url"
25	"testing"
26
27	restful "github.com/emicklei/go-restful"
28	"github.com/stretchr/testify/assert"
29	"github.com/stretchr/testify/require"
30	"github.com/stretchr/testify/suite"
31
32	"go.elastic.co/apm"
33	"go.elastic.co/apm/apmtest"
34	"go.elastic.co/apm/model"
35	"go.elastic.co/apm/module/apmrestful"
36	"go.elastic.co/apm/transport/transporttest"
37)
38
39func TestHandlerHTTPSuite(t *testing.T) {
40	tracer, recorder := transporttest.NewRecorderTracer()
41	var ws restful.WebService
42	ws.Path("/").Consumes(restful.MIME_JSON, restful.MIME_XML).Produces(restful.MIME_JSON, restful.MIME_XML)
43	ws.Route(ws.GET("/implicit_write").To(func(req *restful.Request, resp *restful.Response) {}))
44	ws.Route(ws.GET("/panic_before_write").To(func(req *restful.Request, resp *restful.Response) {
45		panic("boom")
46	}))
47	ws.Route(ws.GET("/panic_after_write").To(func(req *restful.Request, resp *restful.Response) {
48		resp.Write([]byte("hello, world"))
49		panic("boom")
50	}))
51	container := restful.NewContainer()
52	container.Add(&ws)
53	container.Filter(apmrestful.Filter(apmrestful.WithTracer(tracer)))
54
55	suite.Run(t, &apmtest.HTTPTestSuite{
56		Handler:  container,
57		Tracer:   tracer,
58		Recorder: recorder,
59	})
60}
61
62func TestContainerFilter(t *testing.T) {
63	type Thing struct {
64		ID string
65	}
66
67	var ws restful.WebService
68	ws.Path("/things").Consumes(restful.MIME_JSON, restful.MIME_XML).Produces(restful.MIME_JSON, restful.MIME_XML)
69	ws.Route(ws.GET("/{id:[0-1]+}").To(func(req *restful.Request, resp *restful.Response) {
70		if apm.TransactionFromContext(req.Request.Context()) == nil {
71			panic("no transaction in context")
72		}
73		resp.WriteHeaderAndEntity(http.StatusTeapot, Thing{
74			ID: req.PathParameter("id"),
75		})
76	}))
77
78	tracer, transport := transporttest.NewRecorderTracer()
79	defer tracer.Close()
80
81	container := restful.NewContainer()
82	container.Add(&ws)
83	container.Filter(apmrestful.Filter(apmrestful.WithTracer(tracer)))
84
85	server := httptest.NewServer(container)
86	defer server.Close()
87	serverURL, err := url.Parse(server.URL)
88	require.NoError(t, err)
89	serverHost, serverPort, err := net.SplitHostPort(serverURL.Host)
90	require.NoError(t, err)
91
92	resp, err := http.Get(server.URL + "/things/123")
93	require.NoError(t, err)
94	require.NoError(t, resp.Body.Close())
95	assert.Equal(t, http.StatusTeapot, resp.StatusCode)
96	tracer.Flush(nil)
97
98	payloads := transport.Payloads()
99	assert.Len(t, payloads.Transactions, 1)
100	transaction := payloads.Transactions[0]
101
102	assert.Equal(t, "GET /things/{id}", transaction.Name)
103	assert.Equal(t, "request", transaction.Type)
104	assert.Equal(t, "HTTP 4xx", transaction.Result)
105
106	assert.Equal(t, &model.Context{
107		Service: &model.Service{
108			Framework: &model.Framework{
109				Name:    "go-restful",
110				Version: "unspecified",
111			},
112		},
113		Request: &model.Request{
114			Socket: &model.RequestSocket{
115				RemoteAddress: "127.0.0.1",
116			},
117			URL: model.URL{
118				Full:     server.URL + "/things/123",
119				Protocol: "http",
120				Hostname: serverHost,
121				Port:     serverPort,
122				Path:     "/things/123",
123			},
124			Method:      "GET",
125			HTTPVersion: "1.1",
126			Headers: model.Headers{{
127				Key:    "Accept-Encoding",
128				Values: []string{"gzip"},
129			}, {
130				Key:    "User-Agent",
131				Values: []string{"Go-http-client/1.1"},
132			}},
133		},
134		Response: &model.Response{
135			StatusCode: 418,
136			Headers: model.Headers{{
137				Key:    "Content-Type",
138				Values: []string{"application/json"},
139			}},
140		},
141	}, transaction.Context)
142}
143
144func TestContainerFilterPanic(t *testing.T) {
145	var ws restful.WebService
146	ws.Path("/things").Consumes(restful.MIME_JSON, restful.MIME_XML).Produces(restful.MIME_JSON, restful.MIME_XML)
147	ws.Route(ws.GET("/{id}/foo").To(handlePanic))
148
149	tracer, transport := transporttest.NewRecorderTracer()
150	defer tracer.Close()
151
152	container := restful.NewContainer()
153	container.Add(&ws)
154	container.Filter(apmrestful.Filter(apmrestful.WithTracer(tracer)))
155
156	server := httptest.NewServer(container)
157	defer server.Close()
158	resp, err := http.Get(server.URL + "/things/123/foo")
159	require.NoError(t, err)
160	require.NoError(t, resp.Body.Close())
161	assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
162	tracer.Flush(nil)
163
164	payloads := transport.Payloads()
165	require.Len(t, payloads.Transactions, 1)
166	require.Len(t, payloads.Errors, 1)
167	panicError := payloads.Errors[0]
168	assert.Equal(t, payloads.Transactions[0].Context.Service, panicError.Context.Service)
169	assert.Equal(t, payloads.Transactions[0].ID, panicError.ParentID)
170	assert.Equal(t, "kablamo", panicError.Exception.Message)
171	assert.Equal(t, "handlePanic", panicError.Culprit)
172}
173
174func handlePanic(req *restful.Request, resp *restful.Response) {
175	panic("kablamo")
176}
177
178func TestContainerFilterUnknownRoute(t *testing.T) {
179	var ws restful.WebService
180	ws.Path("/things").Consumes(restful.MIME_JSON, restful.MIME_XML).Produces(restful.MIME_JSON, restful.MIME_XML)
181	ws.Route(ws.GET("/{id}/foo").To(handlePanic))
182
183	tracer, transport := transporttest.NewRecorderTracer()
184	defer tracer.Close()
185
186	container := restful.NewContainer()
187	container.Add(&ws)
188	container.Filter(apmrestful.Filter(apmrestful.WithTracer(tracer)))
189
190	server := httptest.NewServer(container)
191	defer server.Close()
192	resp, err := http.Get(server.URL + "/things/123/bar")
193	require.NoError(t, err)
194	require.NoError(t, resp.Body.Close())
195	assert.Equal(t, http.StatusNotFound, resp.StatusCode)
196	tracer.Flush(nil)
197
198	payloads := transport.Payloads()
199	require.Len(t, payloads.Transactions, 1)
200	assert.Equal(t, "GET unknown route", payloads.Transactions[0].Name)
201}
202