1// Licensed to Elasticsearch B.V. under one or more contributor 2// license agreements. See the NOTICE file distributed with 3// this work for additional information regarding copyright 4// ownership. Elasticsearch B.V. licenses this file to you under 5// the Apache License, Version 2.0 (the "License"); you may 6// not use this file except in compliance with the License. 7// You may obtain a copy of the License at 8// 9// http://www.apache.org/licenses/LICENSE-2.0 10// 11// Unless required by applicable law or agreed to in writing, 12// software distributed under the License is distributed on an 13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14// KIND, either express or implied. See the License for the 15// specific language governing permissions and limitations 16// under the License. 17 18package apmrestful_test 19 20import ( 21 "net" 22 "net/http" 23 "net/http/httptest" 24 "net/url" 25 "testing" 26 27 restful "github.com/emicklei/go-restful" 28 "github.com/stretchr/testify/assert" 29 "github.com/stretchr/testify/require" 30 "github.com/stretchr/testify/suite" 31 32 "go.elastic.co/apm" 33 "go.elastic.co/apm/apmtest" 34 "go.elastic.co/apm/model" 35 "go.elastic.co/apm/module/apmrestful" 36 "go.elastic.co/apm/transport/transporttest" 37) 38 39func TestHandlerHTTPSuite(t *testing.T) { 40 tracer, recorder := transporttest.NewRecorderTracer() 41 var ws restful.WebService 42 ws.Path("/").Consumes(restful.MIME_JSON, restful.MIME_XML).Produces(restful.MIME_JSON, restful.MIME_XML) 43 ws.Route(ws.GET("/implicit_write").To(func(req *restful.Request, resp *restful.Response) {})) 44 ws.Route(ws.GET("/panic_before_write").To(func(req *restful.Request, resp *restful.Response) { 45 panic("boom") 46 })) 47 ws.Route(ws.GET("/panic_after_write").To(func(req *restful.Request, resp *restful.Response) { 48 resp.Write([]byte("hello, world")) 49 panic("boom") 50 })) 51 container := restful.NewContainer() 52 container.Add(&ws) 53 container.Filter(apmrestful.Filter(apmrestful.WithTracer(tracer))) 54 55 suite.Run(t, &apmtest.HTTPTestSuite{ 56 Handler: container, 57 Tracer: tracer, 58 Recorder: recorder, 59 }) 60} 61 62func TestContainerFilter(t *testing.T) { 63 type Thing struct { 64 ID string 65 } 66 67 var ws restful.WebService 68 ws.Path("/things").Consumes(restful.MIME_JSON, restful.MIME_XML).Produces(restful.MIME_JSON, restful.MIME_XML) 69 ws.Route(ws.GET("/{id:[0-1]+}").To(func(req *restful.Request, resp *restful.Response) { 70 if apm.TransactionFromContext(req.Request.Context()) == nil { 71 panic("no transaction in context") 72 } 73 resp.WriteHeaderAndEntity(http.StatusTeapot, Thing{ 74 ID: req.PathParameter("id"), 75 }) 76 })) 77 78 tracer, transport := transporttest.NewRecorderTracer() 79 defer tracer.Close() 80 81 container := restful.NewContainer() 82 container.Add(&ws) 83 container.Filter(apmrestful.Filter(apmrestful.WithTracer(tracer))) 84 85 server := httptest.NewServer(container) 86 defer server.Close() 87 serverURL, err := url.Parse(server.URL) 88 require.NoError(t, err) 89 serverHost, serverPort, err := net.SplitHostPort(serverURL.Host) 90 require.NoError(t, err) 91 92 resp, err := http.Get(server.URL + "/things/123") 93 require.NoError(t, err) 94 require.NoError(t, resp.Body.Close()) 95 assert.Equal(t, http.StatusTeapot, resp.StatusCode) 96 tracer.Flush(nil) 97 98 payloads := transport.Payloads() 99 assert.Len(t, payloads.Transactions, 1) 100 transaction := payloads.Transactions[0] 101 102 assert.Equal(t, "GET /things/{id}", transaction.Name) 103 assert.Equal(t, "request", transaction.Type) 104 assert.Equal(t, "HTTP 4xx", transaction.Result) 105 106 assert.Equal(t, &model.Context{ 107 Service: &model.Service{ 108 Framework: &model.Framework{ 109 Name: "go-restful", 110 Version: "unspecified", 111 }, 112 }, 113 Request: &model.Request{ 114 Socket: &model.RequestSocket{ 115 RemoteAddress: "127.0.0.1", 116 }, 117 URL: model.URL{ 118 Full: server.URL + "/things/123", 119 Protocol: "http", 120 Hostname: serverHost, 121 Port: serverPort, 122 Path: "/things/123", 123 }, 124 Method: "GET", 125 HTTPVersion: "1.1", 126 Headers: model.Headers{{ 127 Key: "Accept-Encoding", 128 Values: []string{"gzip"}, 129 }, { 130 Key: "User-Agent", 131 Values: []string{"Go-http-client/1.1"}, 132 }}, 133 }, 134 Response: &model.Response{ 135 StatusCode: 418, 136 Headers: model.Headers{{ 137 Key: "Content-Type", 138 Values: []string{"application/json"}, 139 }}, 140 }, 141 }, transaction.Context) 142} 143 144func TestContainerFilterPanic(t *testing.T) { 145 var ws restful.WebService 146 ws.Path("/things").Consumes(restful.MIME_JSON, restful.MIME_XML).Produces(restful.MIME_JSON, restful.MIME_XML) 147 ws.Route(ws.GET("/{id}/foo").To(handlePanic)) 148 149 tracer, transport := transporttest.NewRecorderTracer() 150 defer tracer.Close() 151 152 container := restful.NewContainer() 153 container.Add(&ws) 154 container.Filter(apmrestful.Filter(apmrestful.WithTracer(tracer))) 155 156 server := httptest.NewServer(container) 157 defer server.Close() 158 resp, err := http.Get(server.URL + "/things/123/foo") 159 require.NoError(t, err) 160 require.NoError(t, resp.Body.Close()) 161 assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) 162 tracer.Flush(nil) 163 164 payloads := transport.Payloads() 165 require.Len(t, payloads.Transactions, 1) 166 require.Len(t, payloads.Errors, 1) 167 panicError := payloads.Errors[0] 168 assert.Equal(t, payloads.Transactions[0].Context.Service, panicError.Context.Service) 169 assert.Equal(t, payloads.Transactions[0].ID, panicError.ParentID) 170 assert.Equal(t, "kablamo", panicError.Exception.Message) 171 assert.Equal(t, "handlePanic", panicError.Culprit) 172} 173 174func handlePanic(req *restful.Request, resp *restful.Response) { 175 panic("kablamo") 176} 177 178func TestContainerFilterUnknownRoute(t *testing.T) { 179 var ws restful.WebService 180 ws.Path("/things").Consumes(restful.MIME_JSON, restful.MIME_XML).Produces(restful.MIME_JSON, restful.MIME_XML) 181 ws.Route(ws.GET("/{id}/foo").To(handlePanic)) 182 183 tracer, transport := transporttest.NewRecorderTracer() 184 defer tracer.Close() 185 186 container := restful.NewContainer() 187 container.Add(&ws) 188 container.Filter(apmrestful.Filter(apmrestful.WithTracer(tracer))) 189 190 server := httptest.NewServer(container) 191 defer server.Close() 192 resp, err := http.Get(server.URL + "/things/123/bar") 193 require.NoError(t, err) 194 require.NoError(t, resp.Body.Close()) 195 assert.Equal(t, http.StatusNotFound, resp.StatusCode) 196 tracer.Flush(nil) 197 198 payloads := transport.Payloads() 199 require.Len(t, payloads.Transactions, 1) 200 assert.Equal(t, "GET unknown route", payloads.Transactions[0].Name) 201} 202