1// Copyright 2011 Google Inc. All rights reserved.
2// Use of this source code is governed by the Apache 2.0
3// license that can be found in the LICENSE file.
4
5// +build !appengine
6
7package internal
8
9import (
10	"bytes"
11	"errors"
12	"fmt"
13	"io/ioutil"
14	"log"
15	"net"
16	"net/http"
17	"net/url"
18	"os"
19	"runtime"
20	"strconv"
21	"strings"
22	"sync"
23	"sync/atomic"
24	"time"
25
26	"github.com/golang/protobuf/proto"
27	netcontext "golang.org/x/net/context"
28
29	basepb "google.golang.org/appengine/internal/base"
30	logpb "google.golang.org/appengine/internal/log"
31	remotepb "google.golang.org/appengine/internal/remote_api"
32)
33
34const (
35	apiPath             = "/rpc_http"
36	defaultTicketSuffix = "/default.20150612t184001.0"
37)
38
39var (
40	// Incoming headers.
41	ticketHeader       = http.CanonicalHeaderKey("X-AppEngine-API-Ticket")
42	dapperHeader       = http.CanonicalHeaderKey("X-Google-DapperTraceInfo")
43	traceHeader        = http.CanonicalHeaderKey("X-Cloud-Trace-Context")
44	curNamespaceHeader = http.CanonicalHeaderKey("X-AppEngine-Current-Namespace")
45	userIPHeader       = http.CanonicalHeaderKey("X-AppEngine-User-IP")
46	remoteAddrHeader   = http.CanonicalHeaderKey("X-AppEngine-Remote-Addr")
47	devRequestIdHeader = http.CanonicalHeaderKey("X-Appengine-Dev-Request-Id")
48
49	// Outgoing headers.
50	apiEndpointHeader      = http.CanonicalHeaderKey("X-Google-RPC-Service-Endpoint")
51	apiEndpointHeaderValue = []string{"app-engine-apis"}
52	apiMethodHeader        = http.CanonicalHeaderKey("X-Google-RPC-Service-Method")
53	apiMethodHeaderValue   = []string{"/VMRemoteAPI.CallRemoteAPI"}
54	apiDeadlineHeader      = http.CanonicalHeaderKey("X-Google-RPC-Service-Deadline")
55	apiContentType         = http.CanonicalHeaderKey("Content-Type")
56	apiContentTypeValue    = []string{"application/octet-stream"}
57	logFlushHeader         = http.CanonicalHeaderKey("X-AppEngine-Log-Flush-Count")
58
59	apiHTTPClient = &http.Client{
60		Transport: &http.Transport{
61			Proxy:               http.ProxyFromEnvironment,
62			Dial:                limitDial,
63			MaxIdleConns:        1000,
64			MaxIdleConnsPerHost: 10000,
65			IdleConnTimeout:     90 * time.Second,
66		},
67	}
68
69	defaultTicketOnce     sync.Once
70	defaultTicket         string
71	backgroundContextOnce sync.Once
72	backgroundContext     netcontext.Context
73)
74
75func apiURL() *url.URL {
76	host, port := "appengine.googleapis.internal", "10001"
77	if h := os.Getenv("API_HOST"); h != "" {
78		host = h
79	}
80	if p := os.Getenv("API_PORT"); p != "" {
81		port = p
82	}
83	return &url.URL{
84		Scheme: "http",
85		Host:   host + ":" + port,
86		Path:   apiPath,
87	}
88}
89
// handleHTTP is the top-level handler for an incoming App Engine request.
//
// It attaches a per-request *context to r and normalizes r.RemoteAddr into
// "IP:port" form. It then dispatches r through http.DefaultServeMux with
// panic recovery (executeRequestSafely), buffering the handler's status and
// body in c. After the handler returns it copies the buffered status and
// body to w.
//
// Application logs are flushed by a background logFlusher while the handler
// runs, and once more (forced) afterwards. The response carries a flush count
// in the logFlushHeader header: the flushes attempted so far, plus one if
// lines are still pending.
func handleHTTP(w http.ResponseWriter, r *http.Request) {
	c := &context{
		req:       r,
		outHeader: w.Header(),
		apiURL:    apiURL(),
	}
	r = r.WithContext(withContext(r.Context(), c))
	// Point c at the derived request so c.Request() returns the request that
	// carries the App Engine context.
	c.req = r

	// Unbuffered: the send below blocks until logFlusher receives it.
	stopFlushing := make(chan int)

	// Patch up RemoteAddr so it looks reasonable.
	if addr := r.Header.Get(userIPHeader); addr != "" {
		r.RemoteAddr = addr
	} else if addr = r.Header.Get(remoteAddrHeader); addr != "" {
		r.RemoteAddr = addr
	} else {
		// Should not normally reach here, but pick a sensible default anyway.
		r.RemoteAddr = "127.0.0.1"
	}
	// The address in the headers will most likely be of these forms:
	//	123.123.123.123
	//	2001:db8::1
	// net/http.Request.RemoteAddr is specified to be in "IP:port" form.
	if _, _, err := net.SplitHostPort(r.RemoteAddr); err != nil {
		// Assume the remote address is only a host; add a default port.
		r.RemoteAddr = net.JoinHostPort(r.RemoteAddr, "80")
	}

	// Start goroutine responsible for flushing app logs.
	// This is done after c has been attached to the request context (and is
	// stopped before the response is finalized) because flushing logs
	// requires making an API call.
	go c.logFlusher(stopFlushing)

	executeRequestSafely(c, r)
	c.outHeader = nil // make sure header changes aren't respected any more

	stopFlushing <- 1 // any logging beyond this point will be dropped

	// Flush any pending logs asynchronously.
	c.pendingLogs.Lock()
	flushes := c.pendingLogs.flushes
	if len(c.pendingLogs.lines) > 0 {
		flushes++
	}
	c.pendingLogs.Unlock()
	flushed := make(chan struct{})
	go func() {
		defer close(flushed)
		// Force a log flush, because with very short requests we
		// may not ever flush logs.
		c.flushLog(true)
	}()
	w.Header().Set(logFlushHeader, strconv.Itoa(flushes))

	// Avoid nil Write call if c.Write is never called.
	if c.outCode != 0 {
		w.WriteHeader(c.outCode)
	}
	if c.outBody != nil {
		w.Write(c.outBody)
	}
	// Wait for the last flush to complete before returning,
	// otherwise the security ticket will not be valid.
	<-flushed
}
156
157func executeRequestSafely(c *context, r *http.Request) {
158	defer func() {
159		if x := recover(); x != nil {
160			logf(c, 4, "%s", renderPanic(x)) // 4 == critical
161			c.outCode = 500
162		}
163	}()
164
165	http.DefaultServeMux.ServeHTTP(c, r)
166}
167
// renderPanic formats a recovered panic value x together with the current
// goroutine's stack trace as "panic: <x>\n\n<stack>".
//
// It is meant to be called from a deferred recover closure. The first two
// frames of the trace are removed so it starts closer to the panic site:
// renderPanic itself, and its caller (the recover closure). If renderPanic's
// own frame cannot be found in the trace, the trace is returned untrimmed
// rather than failing. At most 16 KB of stack is captured.
func renderPanic(x interface{}) string {
	buf := make([]byte, 16<<10) // 16 KB should be plenty
	buf = buf[:runtime.Stack(buf, false)]

	// Frames to drop: this func and the recover closure in the caller. Each
	// frame occupies two lines: "pkg.func(args)" and "\tfile:line +off".
	const skipFrames = 2

	// Find this function's frame by its fully qualified runtime name, so the
	// trim keeps working if the package is vendored or renamed; fall back to
	// the historical suffix.
	self := "internal.renderPanic"
	if pc, _, _, ok := runtime.Caller(0); ok {
		if fn := runtime.FuncForPC(pc); fn != nil {
			self = fn.Name()
		}
	}

	if start := bytes.Index(buf, []byte(self)); start >= 0 {
		// Back up to the start of the line so no package-path prefix is left
		// dangling in front of the next remaining frame.
		start = bytes.LastIndexByte(buf[:start], '\n') + 1
		end := start
		for i := 0; i < skipFrames*2; i++ {
			nl := bytes.IndexByte(buf[end:], '\n')
			if nl < 0 {
				break // truncated trace: drop only the complete lines found
			}
			end += nl + 1
		}
		buf = append(buf[:start], buf[end:]...)
	}

	// Prepend the heading without overwriting (and losing) the tail of the
	// trace.
	return fmt.Sprintf("panic: %v\n\n", x) + string(buf)
}
206
// context represents the context of an in-flight HTTP request.
// It implements the appengine.Context and http.ResponseWriter interfaces.
//
// Responses written through it are buffered in outCode/outBody; handleHTTP
// copies them to the real http.ResponseWriter once the handler returns.
type context struct {
	// req is the incoming request. API calls read the security ticket and
	// tracing headers from it.
	req *http.Request

	// Buffered response state. outHeader is the real ResponseWriter's header
	// map when created by handleHTTP. It is nil for background/test contexts,
	// and is reset to nil once the handler has returned.
	outCode   int
	outHeader http.Header
	outBody   []byte

	// pendingLogs holds application log lines awaiting a flush, plus the
	// number of Flush RPCs attempted so far. Guarded by the embedded Mutex.
	pendingLogs struct {
		sync.Mutex
		lines   []*logpb.UserAppLogLine
		flushes int
	}

	// apiURL is the service bridge endpoint that API calls are POSTed to.
	apiURL *url.URL
}

// The address of contextKey (not its value) is the key under which the
// *context is stored in a netcontext.Context; see withContext and fromContext.
var contextKey = "holds a *context"
226
227// jointContext joins two contexts in a superficial way.
228// It takes values and timeouts from a base context, and only values from another context.
229type jointContext struct {
230	base       netcontext.Context
231	valuesOnly netcontext.Context
232}
233
234func (c jointContext) Deadline() (time.Time, bool) {
235	return c.base.Deadline()
236}
237
238func (c jointContext) Done() <-chan struct{} {
239	return c.base.Done()
240}
241
242func (c jointContext) Err() error {
243	return c.base.Err()
244}
245
246func (c jointContext) Value(key interface{}) interface{} {
247	if val := c.base.Value(key); val != nil {
248		return val
249	}
250	return c.valuesOnly.Value(key)
251}
252
253// fromContext returns the App Engine context or nil if ctx is not
254// derived from an App Engine context.
255func fromContext(ctx netcontext.Context) *context {
256	c, _ := ctx.Value(&contextKey).(*context)
257	return c
258}
259
260func withContext(parent netcontext.Context, c *context) netcontext.Context {
261	ctx := netcontext.WithValue(parent, &contextKey, c)
262	if ns := c.req.Header.Get(curNamespaceHeader); ns != "" {
263		ctx = withNamespace(ctx, ns)
264	}
265	return ctx
266}
267
268func toContext(c *context) netcontext.Context {
269	return withContext(netcontext.Background(), c)
270}
271
272func IncomingHeaders(ctx netcontext.Context) http.Header {
273	if c := fromContext(ctx); c != nil {
274		return c.req.Header
275	}
276	return nil
277}
278
279func ReqContext(req *http.Request) netcontext.Context {
280	return req.Context()
281}
282
283func WithContext(parent netcontext.Context, req *http.Request) netcontext.Context {
284	return jointContext{
285		base:       parent,
286		valuesOnly: req.Context(),
287	}
288}
289
290// DefaultTicket returns a ticket used for background context or dev_appserver.
291func DefaultTicket() string {
292	defaultTicketOnce.Do(func() {
293		if IsDevAppServer() {
294			defaultTicket = "testapp" + defaultTicketSuffix
295			return
296		}
297		appID := partitionlessAppID()
298		escAppID := strings.Replace(strings.Replace(appID, ":", "_", -1), ".", "_", -1)
299		majVersion := VersionID(nil)
300		if i := strings.Index(majVersion, "."); i > 0 {
301			majVersion = majVersion[:i]
302		}
303		defaultTicket = fmt.Sprintf("%s/%s.%s.%s", escAppID, ModuleName(nil), majVersion, InstanceID())
304	})
305	return defaultTicket
306}
307
308func BackgroundContext() netcontext.Context {
309	backgroundContextOnce.Do(func() {
310		// Compute background security ticket.
311		ticket := DefaultTicket()
312
313		c := &context{
314			req: &http.Request{
315				Header: http.Header{
316					ticketHeader: []string{ticket},
317				},
318			},
319			apiURL: apiURL(),
320		}
321		backgroundContext = toContext(c)
322
323		// TODO(dsymonds): Wire up the shutdown handler to do a final flush.
324		go c.logFlusher(make(chan int))
325	})
326
327	return backgroundContext
328}
329
330// RegisterTestRequest registers the HTTP request req for testing, such that
331// any API calls are sent to the provided URL. It returns a closure to delete
332// the registration.
333// It should only be used by aetest package.
334func RegisterTestRequest(req *http.Request, apiURL *url.URL, decorate func(netcontext.Context) netcontext.Context) (*http.Request, func()) {
335	c := &context{
336		req:    req,
337		apiURL: apiURL,
338	}
339	ctx := withContext(decorate(req.Context()), c)
340	req = req.WithContext(ctx)
341	c.req = req
342	return req, func() {}
343}
344
345var errTimeout = &CallError{
346	Detail:  "Deadline exceeded",
347	Code:    int32(remotepb.RpcError_CANCELLED),
348	Timeout: true,
349}
350
351func (c *context) Header() http.Header { return c.outHeader }
352
// bodyAllowedForStatus reports whether a response with the given status code
// may carry a body (and entity headers such as Content-Length/Content-Type).
// Informational 1xx responses, 204 No Content and 304 Not Modified may not.
// Mirrors the helper of the same name in net/http's transfer.go.
func bodyAllowedForStatus(status int) bool {
	if status >= 100 && status <= 199 {
		return false
	}
	return status != http.StatusNoContent && status != http.StatusNotModified
}
367
368func (c *context) Write(b []byte) (int, error) {
369	if c.outCode == 0 {
370		c.WriteHeader(http.StatusOK)
371	}
372	if len(b) > 0 && !bodyAllowedForStatus(c.outCode) {
373		return 0, http.ErrBodyNotAllowed
374	}
375	c.outBody = append(c.outBody, b...)
376	return len(b), nil
377}
378
379func (c *context) WriteHeader(code int) {
380	if c.outCode != 0 {
381		logf(c, 3, "WriteHeader called multiple times on request.") // error level
382		return
383	}
384	c.outCode = code
385}
386
387func (c *context) post(body []byte, timeout time.Duration) (b []byte, err error) {
388	hreq := &http.Request{
389		Method: "POST",
390		URL:    c.apiURL,
391		Header: http.Header{
392			apiEndpointHeader: apiEndpointHeaderValue,
393			apiMethodHeader:   apiMethodHeaderValue,
394			apiContentType:    apiContentTypeValue,
395			apiDeadlineHeader: []string{strconv.FormatFloat(timeout.Seconds(), 'f', -1, 64)},
396		},
397		Body:          ioutil.NopCloser(bytes.NewReader(body)),
398		ContentLength: int64(len(body)),
399		Host:          c.apiURL.Host,
400	}
401	if info := c.req.Header.Get(dapperHeader); info != "" {
402		hreq.Header.Set(dapperHeader, info)
403	}
404	if info := c.req.Header.Get(traceHeader); info != "" {
405		hreq.Header.Set(traceHeader, info)
406	}
407
408	tr := apiHTTPClient.Transport.(*http.Transport)
409
410	var timedOut int32 // atomic; set to 1 if timed out
411	t := time.AfterFunc(timeout, func() {
412		atomic.StoreInt32(&timedOut, 1)
413		tr.CancelRequest(hreq)
414	})
415	defer t.Stop()
416	defer func() {
417		// Check if timeout was exceeded.
418		if atomic.LoadInt32(&timedOut) != 0 {
419			err = errTimeout
420		}
421	}()
422
423	hresp, err := apiHTTPClient.Do(hreq)
424	if err != nil {
425		return nil, &CallError{
426			Detail: fmt.Sprintf("service bridge HTTP failed: %v", err),
427			Code:   int32(remotepb.RpcError_UNKNOWN),
428		}
429	}
430	defer hresp.Body.Close()
431	hrespBody, err := ioutil.ReadAll(hresp.Body)
432	if hresp.StatusCode != 200 {
433		return nil, &CallError{
434			Detail: fmt.Sprintf("service bridge returned HTTP %d (%q)", hresp.StatusCode, hrespBody),
435			Code:   int32(remotepb.RpcError_UNKNOWN),
436		}
437	}
438	if err != nil {
439		return nil, &CallError{
440			Detail: fmt.Sprintf("service bridge response bad: %v", err),
441			Code:   int32(remotepb.RpcError_UNKNOWN),
442		}
443	}
444	return hrespBody, nil
445}
446
447func Call(ctx netcontext.Context, service, method string, in, out proto.Message) error {
448	if ns := NamespaceFromContext(ctx); ns != "" {
449		if fn, ok := NamespaceMods[service]; ok {
450			fn(in, ns)
451		}
452	}
453
454	if f, ctx, ok := callOverrideFromContext(ctx); ok {
455		return f(ctx, service, method, in, out)
456	}
457
458	// Handle already-done contexts quickly.
459	select {
460	case <-ctx.Done():
461		return ctx.Err()
462	default:
463	}
464
465	c := fromContext(ctx)
466	if c == nil {
467		// Give a good error message rather than a panic lower down.
468		return errNotAppEngineContext
469	}
470
471	// Apply transaction modifications if we're in a transaction.
472	if t := transactionFromContext(ctx); t != nil {
473		if t.finished {
474			return errors.New("transaction context has expired")
475		}
476		applyTransaction(in, &t.transaction)
477	}
478
479	// Default RPC timeout is 60s.
480	timeout := 60 * time.Second
481	if deadline, ok := ctx.Deadline(); ok {
482		timeout = deadline.Sub(time.Now())
483	}
484
485	data, err := proto.Marshal(in)
486	if err != nil {
487		return err
488	}
489
490	ticket := c.req.Header.Get(ticketHeader)
491	// Use a test ticket under test environment.
492	if ticket == "" {
493		if appid := ctx.Value(&appIDOverrideKey); appid != nil {
494			ticket = appid.(string) + defaultTicketSuffix
495		}
496	}
497	// Fall back to use background ticket when the request ticket is not available in Flex or dev_appserver.
498	if ticket == "" {
499		ticket = DefaultTicket()
500	}
501	if dri := c.req.Header.Get(devRequestIdHeader); IsDevAppServer() && dri != "" {
502		ticket = dri
503	}
504	req := &remotepb.Request{
505		ServiceName: &service,
506		Method:      &method,
507		Request:     data,
508		RequestId:   &ticket,
509	}
510	hreqBody, err := proto.Marshal(req)
511	if err != nil {
512		return err
513	}
514
515	hrespBody, err := c.post(hreqBody, timeout)
516	if err != nil {
517		return err
518	}
519
520	res := &remotepb.Response{}
521	if err := proto.Unmarshal(hrespBody, res); err != nil {
522		return err
523	}
524	if res.RpcError != nil {
525		ce := &CallError{
526			Detail: res.RpcError.GetDetail(),
527			Code:   *res.RpcError.Code,
528		}
529		switch remotepb.RpcError_ErrorCode(ce.Code) {
530		case remotepb.RpcError_CANCELLED, remotepb.RpcError_DEADLINE_EXCEEDED:
531			ce.Timeout = true
532		}
533		return ce
534	}
535	if res.ApplicationError != nil {
536		return &APIError{
537			Service: *req.ServiceName,
538			Detail:  res.ApplicationError.GetDetail(),
539			Code:    *res.ApplicationError.Code,
540		}
541	}
542	if res.Exception != nil || res.JavaException != nil {
543		// This shouldn't happen, but let's be defensive.
544		return &CallError{
545			Detail: "service bridge returned exception",
546			Code:   int32(remotepb.RpcError_UNKNOWN),
547		}
548	}
549	return proto.Unmarshal(res.Response, out)
550}
551
552func (c *context) Request() *http.Request {
553	return c.req
554}
555
556func (c *context) addLogLine(ll *logpb.UserAppLogLine) {
557	// Truncate long log lines.
558	// TODO(dsymonds): Check if this is still necessary.
559	const lim = 8 << 10
560	if len(*ll.Message) > lim {
561		suffix := fmt.Sprintf("...(length %d)", len(*ll.Message))
562		ll.Message = proto.String((*ll.Message)[:lim-len(suffix)] + suffix)
563	}
564
565	c.pendingLogs.Lock()
566	c.pendingLogs.lines = append(c.pendingLogs.lines, ll)
567	c.pendingLogs.Unlock()
568}
569
570var logLevelName = map[int64]string{
571	0: "DEBUG",
572	1: "INFO",
573	2: "WARNING",
574	3: "ERROR",
575	4: "CRITICAL",
576}
577
578func logf(c *context, level int64, format string, args ...interface{}) {
579	if c == nil {
580		panic("not an App Engine context")
581	}
582	s := fmt.Sprintf(format, args...)
583	s = strings.TrimRight(s, "\n") // Remove any trailing newline characters.
584	c.addLogLine(&logpb.UserAppLogLine{
585		TimestampUsec: proto.Int64(time.Now().UnixNano() / 1e3),
586		Level:         &level,
587		Message:       &s,
588	})
589	// Only duplicate log to stderr if not running on App Engine second generation
590	if !IsSecondGen() {
591		log.Print(logLevelName[level] + ": " + s)
592	}
593}
594
// flushLog attempts to flush any pending logs to the appserver.
// It should not be called concurrently.
//
// It takes as many queued lines as fit in about 30 MB, wraps them in a
// UserAppLogGroup, and sends it with a logservice.Flush API call.
//
// If no lines are pending and force is false, it returns false without
// calling the API; with force set it sends even an empty group. On a marshal
// or RPC failure, the taken lines are put back at the front of the queue and
// it returns false. It returns true after a successful Flush call.
//
// c.pendingLogs.flushes is incremented for every attempted Flush RPC,
// whether or not it succeeds.
func (c *context) flushLog(force bool) (flushed bool) {
	c.pendingLogs.Lock()
	// Grab up to 30 MB. We can get away with up to 32 MB, but let's be cautious.
	n, rem := 0, 30<<20
	for ; n < len(c.pendingLogs.lines); n++ {
		ll := c.pendingLogs.lines[n]
		// Each log line will require about 3 bytes of overhead.
		nb := proto.Size(ll) + 3
		if nb > rem {
			break
		}
		rem -= nb
	}
	lines := c.pendingLogs.lines[:n]
	c.pendingLogs.lines = c.pendingLogs.lines[n:]
	c.pendingLogs.Unlock()

	if len(lines) == 0 && !force {
		// Nothing to flush.
		return false
	}

	rescueLogs := false
	defer func() {
		if rescueLogs {
			// Put the unsent lines back ahead of anything logged meanwhile,
			// preserving their order.
			c.pendingLogs.Lock()
			c.pendingLogs.lines = append(lines, c.pendingLogs.lines...)
			c.pendingLogs.Unlock()
		}
	}()

	buf, err := proto.Marshal(&logpb.UserAppLogGroup{
		LogLine: lines,
	})
	if err != nil {
		log.Printf("internal.flushLog: marshaling UserAppLogGroup: %v", err)
		rescueLogs = true
		return false
	}

	req := &logpb.FlushRequest{
		Logs: buf,
	}
	res := &basepb.VoidProto{}
	c.pendingLogs.Lock()
	c.pendingLogs.flushes++
	c.pendingLogs.Unlock()
	// toContext roots the call at netcontext.Background, so the flush carries
	// no request deadline and Call applies its default timeout.
	if err := Call(toContext(c), "logservice", "Flush", req, res); err != nil {
		log.Printf("internal.flushLog: Flush RPC: %v", err)
		rescueLogs = true
		return false
	}
	return true
}
651
652const (
653	// Log flushing parameters.
654	flushInterval      = 1 * time.Second
655	forceFlushInterval = 60 * time.Second
656)
657
658func (c *context) logFlusher(stop <-chan int) {
659	lastFlush := time.Now()
660	tick := time.NewTicker(flushInterval)
661	for {
662		select {
663		case <-stop:
664			// Request finished.
665			tick.Stop()
666			return
667		case <-tick.C:
668			force := time.Now().Sub(lastFlush) > forceFlushInterval
669			if c.flushLog(force) {
670				lastFlush = time.Now()
671			}
672		}
673	}
674}
675
676func ContextForTesting(req *http.Request) netcontext.Context {
677	return toContext(&context{req: req})
678}
679