1// Copyright (c) 2017 Uber Technologies, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package server
16
17import (
18	"context"
19	"encoding/json"
20	"fmt"
21	"io/ioutil"
22	"net"
23	"net/http"
24	"strings"
25	"sync"
26
27	"github.com/opentracing/opentracing-go"
28	"github.com/opentracing/opentracing-go/ext"
29
30	"github.com/uber/jaeger-client-go/crossdock/common"
31	"github.com/uber/jaeger-client-go/crossdock/endtoend"
32	"github.com/uber/jaeger-client-go/crossdock/log"
33	"github.com/uber/jaeger-client-go/crossdock/thrift/tracetest"
34)
35
// Server implements S1-S3 servers
//
// Fields:
//   - HostPortHTTP: address to listen on. Start replaces an empty value with
//     ":" + common.DefaultServerPortHTTP, and after binding overwrites it with
//     the listener's actual address (e.g. the real port when ":0" was given).
//   - AgentHostPort, SamplingServerURL: passed to endtoend.NewHandler by Start;
//     that handler serves the /create_traces endpoint.
//   - Tracer: used by the JSON endpoints to extract the caller's span context
//     from HTTP headers and to start the server-side span.
//   - listener: the TCP listener created by Start; closed by Close.
//   - eHandler: the end-to-end handler created by Start.
type Server struct {
	HostPortHTTP      string
	AgentHostPort     string
	SamplingServerURL string
	Tracer            opentracing.Tracer
	listener          net.Listener
	eHandler          *endtoend.Handler
}
45
46// Start starts the test server called by the Client and other upstream servers.
47func (s *Server) Start() error {
48	if s.HostPortHTTP == "" {
49		s.HostPortHTTP = ":" + common.DefaultServerPortHTTP
50	}
51
52	s.eHandler = endtoend.NewHandler(s.AgentHostPort, s.SamplingServerURL)
53
54	mux := http.NewServeMux()
55	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { return }) // health check
56	mux.HandleFunc("/start_trace", func(w http.ResponseWriter, r *http.Request) {
57		s.handleJSON(w, r, func() interface{} {
58			return tracetest.NewStartTraceRequest()
59		}, func(ctx context.Context, req interface{}) (interface{}, error) {
60			return s.doStartTrace(req.(*tracetest.StartTraceRequest))
61		})
62	})
63	mux.HandleFunc("/join_trace", func(w http.ResponseWriter, r *http.Request) {
64		s.handleJSON(w, r, func() interface{} {
65			return tracetest.NewJoinTraceRequest()
66		}, func(ctx context.Context, req interface{}) (interface{}, error) {
67			return s.doJoinTrace(ctx, req.(*tracetest.JoinTraceRequest))
68		})
69	})
70	mux.HandleFunc("/create_traces", s.eHandler.GenerateTraces)
71
72	listener, err := net.Listen("tcp", s.HostPortHTTP)
73	if err != nil {
74		return err
75	}
76	s.listener = listener
77	s.HostPortHTTP = listener.Addr().String()
78
79	var started sync.WaitGroup
80	started.Add(1)
81	go func() {
82		started.Done()
83		http.Serve(listener, mux)
84	}()
85	started.Wait()
86	log.Printf("Started http server at %s\n", s.HostPortHTTP)
87	return nil
88}
89
90// URL returns URL of the HTTP server
91func (s *Server) URL() string {
92	return fmt.Sprintf("http://%s/", s.HostPortHTTP)
93}
94
95// Close stops the server
96func (s *Server) Close() error {
97	return s.listener.Close()
98}
99
100// GetPortHTTP returns the network port the server listens to.
101func (s *Server) GetPortHTTP() string {
102	hostPort := s.HostPortHTTP
103	hostPortSplit := strings.Split(hostPort, ":")
104	port := hostPortSplit[len(hostPortSplit)-1]
105	return port
106}
107
108func (s *Server) handleJSON(
109	w http.ResponseWriter,
110	r *http.Request,
111	newReq func() interface{},
112	handle func(ctx context.Context, req interface{}) (interface{}, error),
113) {
114	spanCtx, err := s.Tracer.Extract(opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(r.Header))
115	if err != nil && err != opentracing.ErrSpanContextNotFound {
116		http.Error(w, fmt.Sprintf("Cannot read request body: %+v", err), http.StatusBadRequest)
117		return
118	}
119	span := s.Tracer.StartSpan("post", ext.RPCServerOption(spanCtx))
120	ctx := opentracing.ContextWithSpan(context.Background(), span)
121	defer span.Finish()
122
123	body, err := ioutil.ReadAll(r.Body)
124	if err != nil {
125		http.Error(w, fmt.Sprintf("Cannot read request body: %+v", err), http.StatusInternalServerError)
126		return
127	}
128	log.Printf("Server request: %s", string(body))
129	req := newReq()
130	if err := json.Unmarshal(body, req); err != nil {
131		http.Error(w, fmt.Sprintf("Cannot parse request JSON: %+v. body=[%s]", err, string(body)), http.StatusBadRequest)
132		return
133	}
134	resp, err := handle(ctx, req)
135	if err != nil {
136		log.Printf("Handle error: %s", err.Error())
137		http.Error(w, fmt.Sprintf("Execution error: %+v", err), http.StatusInternalServerError)
138		return
139	}
140	json, err := json.Marshal(resp)
141	if err != nil {
142		http.Error(w, fmt.Sprintf("Cannot marshall response to JSON: %+v", err), http.StatusInternalServerError)
143		return
144	}
145	log.Printf("Server response: %s", string(json))
146	w.Header().Add("Content-Type", "application/json")
147	if _, err := w.Write(json); err != nil {
148		return
149	}
150}
151