1// Copyright 2011 Google Inc. All rights reserved.
2// Use of this source code is governed by the Apache 2.0
3// license that can be found in the LICENSE file.
4
5// +build !appengine
6
7package internal
8
9import (
10	"bytes"
11	"errors"
12	"fmt"
13	"io/ioutil"
14	"log"
15	"net"
16	"net/http"
17	"net/url"
18	"os"
19	"runtime"
20	"strconv"
21	"strings"
22	"sync"
23	"sync/atomic"
24	"time"
25
26	"github.com/golang/protobuf/proto"
27	netcontext "golang.org/x/net/context"
28
29	basepb "google.golang.org/appengine/internal/base"
30	logpb "google.golang.org/appengine/internal/log"
31	remotepb "google.golang.org/appengine/internal/remote_api"
32)
33
// apiPath is the URL path on the service bridge that accepts API calls.
const apiPath = "/rpc_http"

// defaultTicketSuffix is appended to an app ID to build a fallback API
// ticket when the incoming request does not carry one.
const defaultTicketSuffix = "/default.20150612t184001.0"
38
39var (
40	// Incoming headers.
41	ticketHeader       = http.CanonicalHeaderKey("X-AppEngine-API-Ticket")
42	dapperHeader       = http.CanonicalHeaderKey("X-Google-DapperTraceInfo")
43	traceHeader        = http.CanonicalHeaderKey("X-Cloud-Trace-Context")
44	curNamespaceHeader = http.CanonicalHeaderKey("X-AppEngine-Current-Namespace")
45	userIPHeader       = http.CanonicalHeaderKey("X-AppEngine-User-IP")
46	remoteAddrHeader   = http.CanonicalHeaderKey("X-AppEngine-Remote-Addr")
47
48	// Outgoing headers.
49	apiEndpointHeader      = http.CanonicalHeaderKey("X-Google-RPC-Service-Endpoint")
50	apiEndpointHeaderValue = []string{"app-engine-apis"}
51	apiMethodHeader        = http.CanonicalHeaderKey("X-Google-RPC-Service-Method")
52	apiMethodHeaderValue   = []string{"/VMRemoteAPI.CallRemoteAPI"}
53	apiDeadlineHeader      = http.CanonicalHeaderKey("X-Google-RPC-Service-Deadline")
54	apiContentType         = http.CanonicalHeaderKey("Content-Type")
55	apiContentTypeValue    = []string{"application/octet-stream"}
56	logFlushHeader         = http.CanonicalHeaderKey("X-AppEngine-Log-Flush-Count")
57
58	apiHTTPClient = &http.Client{
59		Transport: &http.Transport{
60			Proxy: http.ProxyFromEnvironment,
61			Dial:  limitDial,
62		},
63	}
64
65	defaultTicketOnce     sync.Once
66	defaultTicket         string
67	backgroundContextOnce sync.Once
68	backgroundContext     netcontext.Context
69)
70
// apiURL returns the service-bridge URL that API calls are POSTed to.
//
// The host and port default to appengine.googleapis.internal:10001. They can
// be overridden with the API_HOST and API_PORT environment variables.
// API_HOST may be a hostname, an IPv4 address, or an IPv6 literal either
// bare ("::1") or bracketed ("[::1]").
func apiURL() *url.URL {
	host, port := "appengine.googleapis.internal", "10001"
	if h := os.Getenv("API_HOST"); h != "" {
		host = h
	}
	if p := os.Getenv("API_PORT"); p != "" {
		port = p
	}
	// net.JoinHostPort adds brackets around IPv6 literals. Plain
	// concatenation would yield an unparseable "::1:10001". Strip brackets the
	// caller already supplied so they are not doubled.
	if len(host) >= 2 && strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		host = host[1 : len(host)-1]
	}
	return &url.URL{
		Scheme: "http",
		Host:   net.JoinHostPort(host, port),
		Path:   apiPath,
	}
}
85
// handleHTTP serves a single incoming request.
//
// It attaches a *context to the request, normalizes r.RemoteAddr, and runs
// the handler registered on http.DefaultServeMux via executeRequestSafely,
// with c standing in as the ResponseWriter. Header changes go directly to
// w.Header(), because c.outHeader is that map. The status code and body are
// buffered in c and copied to w after the handler returns. App logs are
// flushed periodically while the handler runs and once more at the end.
// The function does not return until that final flush completes.
func handleHTTP(w http.ResponseWriter, r *http.Request) {
	c := &context{
		req:       r,
		outHeader: w.Header(),
		apiURL:    apiURL(),
	}
	// Attach c to the request's context, then point c at the derived request
	// so the two refer to each other.
	r = r.WithContext(withContext(r.Context(), c))
	c.req = r

	stopFlushing := make(chan int)

	// Patch up RemoteAddr so it looks reasonable.
	// Preference order: user-IP header, then remote-addr header, then a fixed
	// loopback default.
	if addr := r.Header.Get(userIPHeader); addr != "" {
		r.RemoteAddr = addr
	} else if addr = r.Header.Get(remoteAddrHeader); addr != "" {
		r.RemoteAddr = addr
	} else {
		// Should not normally reach here, but pick a sensible default anyway.
		r.RemoteAddr = "127.0.0.1"
	}
	// The address in the headers will most likely be of these forms:
	//	123.123.123.123
	//	2001:db8::1
	// net/http.Request.RemoteAddr is specified to be in "IP:port" form.
	if _, _, err := net.SplitHostPort(r.RemoteAddr); err != nil {
		// Assume the remote address is only a host; add a default port.
		r.RemoteAddr = net.JoinHostPort(r.RemoteAddr, "80")
	}

	// Start goroutine responsible for flushing app logs.
	// It is started only after c is attached to the request context, and it is
	// stopped before the final flush below, because flushing logs requires
	// making an API call.
	go c.logFlusher(stopFlushing)

	executeRequestSafely(c, r)
	c.outHeader = nil // make sure header changes aren't respected any more

	stopFlushing <- 1 // any logging beyond this point will be dropped

	// Flush any pending logs asynchronously.
	// The reported flush count is the number of flush attempts so far, plus
	// one for the final flush below if lines are still pending.
	c.pendingLogs.Lock()
	flushes := c.pendingLogs.flushes
	if len(c.pendingLogs.lines) > 0 {
		flushes++
	}
	c.pendingLogs.Unlock()
	flushed := make(chan struct{})
	go func() {
		defer close(flushed)
		// Force a log flush, because with very short requests we
		// may not ever flush logs.
		c.flushLog(true)
	}()
	w.Header().Set(logFlushHeader, strconv.Itoa(flushes))

	// Avoid nil Write call if c.Write is never called.
	if c.outCode != 0 {
		w.WriteHeader(c.outCode)
	}
	if c.outBody != nil {
		w.Write(c.outBody)
	}
	// Wait for the last flush to complete before returning,
	// otherwise the security ticket will not be valid.
	<-flushed
}
152
// executeRequestSafely dispatches r to http.DefaultServeMux, with c as the
// ResponseWriter. A panic raised by the handler is recovered and logged at
// critical level (4) together with its stack trace. The response status is
// then set to 500.
func executeRequestSafely(c *context, r *http.Request) {
	defer func() {
		x := recover()
		if x == nil {
			return
		}
		logf(c, 4, "%s", renderPanic(x)) // 4 == critical
		c.outCode = 500
	}()

	http.DefaultServeMux.ServeHTTP(c, r)
}
163
// renderPanic formats the recovered panic value x together with the current
// goroutine's stack trace. The trace is captured into a 16 KB buffer, so very
// deep stacks are truncated. The result starts with "panic: <x>" followed by
// a blank line.
//
// When the trace contains this function's frame, that frame and the one
// after it (the recover closure in the caller) are removed, so the trace is
// rooted at the site of the panic. If the frame cannot be found, for example
// because the package path differs, the full trace is kept unchanged instead
// of panicking.
func renderPanic(x interface{}) string {
	buf := make([]byte, 16<<10) // 16 KB should be plenty
	buf = buf[:runtime.Stack(buf, false)]

	// Remove the first few stack frames:
	//   this func
	//   the recover closure in the caller
	// That will root the stack trace at the site of the panic.
	const (
		skipStart  = "internal.renderPanic"
		skipFrames = 2 // each frame is two lines: the call and its file:line
	)
	if start := bytes.Index(buf, []byte(skipStart)); start >= 0 {
		// skipStart matches in the middle of the fully qualified function
		// name. Back up to the beginning of that line so no partial package
		// path is left glued onto the next frame.
		start = bytes.LastIndexByte(buf[:start], '\n') + 1
		end := start
		for i := 0; i < skipFrames*2 && end < len(buf); i++ {
			nl := bytes.IndexByte(buf[end:], '\n')
			if nl < 0 {
				end = len(buf)
				break
			}
			end += nl + 1
		}
		// Drop buf[start:end] in place.
		buf = append(buf[:start], buf[end:]...)
	}

	// Prepend the panic heading without truncating the tail of the trace.
	return fmt.Sprintf("panic: %v\n\n", x) + string(buf)
}
202
// context represents the context of an in-flight HTTP request.
// It implements http.ResponseWriter (see Header, Write and WriteHeader). For
// requests served by handleHTTP, the status and body written by the handler
// are buffered here and copied to the real ResponseWriter afterwards.
type context struct {
	// req is the incoming request. Its headers supply the API ticket and the
	// tracing and namespace headers used when making API calls.
	req *http.Request

	outCode   int         // response status; 0 until WriteHeader or Write is called
	outHeader http.Header // response headers; handleHTTP sets this to nil after the handler returns
	outBody   []byte      // buffered response body

	// pendingLogs holds log lines not yet sent by flushLog, guarded by the
	// embedded mutex. flushes counts attempted Flush RPCs.
	pendingLogs struct {
		sync.Mutex
		lines   []*logpb.UserAppLogLine
		flushes int
	}

	apiURL *url.URL // service-bridge endpoint that post sends API calls to
}
220
// contextKey's address (&contextKey) is the key under which a *context is
// stored in a netcontext.Context. Only the pointer's identity matters; the
// string value is merely descriptive.
var contextKey = "holds a *context"
222
223// jointContext joins two contexts in a superficial way.
224// It takes values and timeouts from a base context, and only values from another context.
225type jointContext struct {
226	base       netcontext.Context
227	valuesOnly netcontext.Context
228}
229
230func (c jointContext) Deadline() (time.Time, bool) {
231	return c.base.Deadline()
232}
233
234func (c jointContext) Done() <-chan struct{} {
235	return c.base.Done()
236}
237
238func (c jointContext) Err() error {
239	return c.base.Err()
240}
241
242func (c jointContext) Value(key interface{}) interface{} {
243	if val := c.base.Value(key); val != nil {
244		return val
245	}
246	return c.valuesOnly.Value(key)
247}
248
249// fromContext returns the App Engine context or nil if ctx is not
250// derived from an App Engine context.
251func fromContext(ctx netcontext.Context) *context {
252	c, _ := ctx.Value(&contextKey).(*context)
253	return c
254}
255
256func withContext(parent netcontext.Context, c *context) netcontext.Context {
257	ctx := netcontext.WithValue(parent, &contextKey, c)
258	if ns := c.req.Header.Get(curNamespaceHeader); ns != "" {
259		ctx = withNamespace(ctx, ns)
260	}
261	return ctx
262}
263
264func toContext(c *context) netcontext.Context {
265	return withContext(netcontext.Background(), c)
266}
267
// IncomingHeaders returns the headers of the request associated with ctx.
// It returns nil if ctx is not derived from an App Engine context.
func IncomingHeaders(ctx netcontext.Context) http.Header {
	c := fromContext(ctx)
	if c == nil {
		return nil
	}
	return c.req.Header
}
274
// ReqContext returns the context attached to req.
func ReqContext(req *http.Request) netcontext.Context {
	ctx := req.Context()
	return ctx
}
278
279func WithContext(parent netcontext.Context, req *http.Request) netcontext.Context {
280	return jointContext{
281		base:       parent,
282		valuesOnly: req.Context(),
283	}
284}
285
// DefaultTicket returns a ticket used for background context or dev_appserver.
//
// The value is computed once. On dev_appserver it is "testapp" followed by
// defaultTicketSuffix. Otherwise it has the form
// "<app>/<module>.<major version>.<instance>", where ':' and '.' in the app ID
// are replaced by '_'.
func DefaultTicket() string {
	defaultTicketOnce.Do(func() {
		if IsDevAppServer() {
			defaultTicket = "testapp" + defaultTicketSuffix
			return
		}
		// One pass that maps both ':' and '.' to '_'.
		escAppID := strings.NewReplacer(":", "_", ".", "_").Replace(partitionlessAppID())

		// Keep only the part before the first '.', unless the version
		// begins with '.' (or contains none), in which case keep all of it.
		majVersion := VersionID(nil)
		if dot := strings.Index(majVersion, "."); dot > 0 {
			majVersion = majVersion[:dot]
		}

		defaultTicket = fmt.Sprintf("%s/%s.%s.%s", escAppID, ModuleName(nil), majVersion, InstanceID())
	})
	return defaultTicket
}
303
// BackgroundContext returns a package-wide context that is not tied to any
// incoming request. Its API calls carry DefaultTicket as their ticket. The
// first call also starts a log flusher goroutine, which is never stopped.
func BackgroundContext() netcontext.Context {
	backgroundContextOnce.Do(func() {
		// Compute background security ticket.
		hdr := http.Header{}
		hdr[ticketHeader] = []string{DefaultTicket()}

		c := &context{
			req:    &http.Request{Header: hdr},
			apiURL: apiURL(),
		}
		backgroundContext = toContext(c)

		// TODO(dsymonds): Wire up the shutdown handler to do a final flush.
		go c.logFlusher(make(chan int))
	})
	return backgroundContext
}
325
// RegisterTestRequest registers the HTTP request req for testing, such that
// any API calls are sent to the provided URL. It returns the request with the
// registration attached to its context, and a closure to delete the
// registration. That closure is a no-op, because the registration lives only
// in the returned request's context.
// It should only be used by aetest package.
func RegisterTestRequest(req *http.Request, apiURL *url.URL, decorate func(netcontext.Context) netcontext.Context) (*http.Request, func()) {
	c := &context{req: req, apiURL: apiURL}
	decorated := decorate(req.Context())
	newReq := req.WithContext(withContext(decorated, c))
	c.req = newReq
	unregister := func() {}
	return newReq, unregister
}
340
341var errTimeout = &CallError{
342	Detail:  "Deadline exceeded",
343	Code:    int32(remotepb.RpcError_CANCELLED),
344	Timeout: true,
345}
346
// Header returns the response header map. For contexts created by
// handleHTTP, this is the underlying ResponseWriter's header map while the
// handler runs. handleHTTP clears it to nil afterwards.
func (c *context) Header() http.Header {
	h := c.outHeader
	return h
}
348
// bodyAllowedForStatus reports whether a response with the given status may
// carry a body. Mirrors the helper of the same name in net/http's
// transfer.go: 1xx, 204 (No Content) and 304 (Not Modified) responses permit
// no body (nor entity headers such as Content-Length or Content-Type).
func bodyAllowedForStatus(status int) bool {
	if status >= 100 && status < 200 {
		return false
	}
	switch status {
	case 204, 304: // No Content, Not Modified
		return false
	default:
		return true
	}
}
363
// Write appends b to the buffered response body. If no status has been set
// yet, it first sets 200 OK. When the status forbids a body, a non-empty b is
// rejected with http.ErrBodyNotAllowed and nothing is buffered.
func (c *context) Write(b []byte) (int, error) {
	if c.outCode == 0 {
		c.WriteHeader(http.StatusOK)
	}
	n := len(b)
	if n > 0 && !bodyAllowedForStatus(c.outCode) {
		return 0, http.ErrBodyNotAllowed
	}
	c.outBody = append(c.outBody, b...)
	return n, nil
}
374
375func (c *context) WriteHeader(code int) {
376	if c.outCode != 0 {
377		logf(c, 3, "WriteHeader called multiple times on request.") // error level
378		return
379	}
380	c.outCode = code
381}
382
// post sends body to the service bridge at c.apiURL and returns the raw
// response body. Tracing headers from the incoming request are forwarded.
//
// The request is cancelled once timeout elapses, in which case errTimeout is
// returned, replacing any other error. Transport failures, non-200 statuses
// and unreadable response bodies are reported as *CallError with code
// UNKNOWN.
func (c *context) post(body []byte, timeout time.Duration) (b []byte, err error) {
	deadline := strconv.FormatFloat(timeout.Seconds(), 'f', -1, 64)
	hreq := &http.Request{
		Method: "POST",
		URL:    c.apiURL,
		Header: http.Header{
			apiEndpointHeader: apiEndpointHeaderValue,
			apiMethodHeader:   apiMethodHeaderValue,
			apiContentType:    apiContentTypeValue,
			apiDeadlineHeader: []string{deadline},
		},
		Body:          ioutil.NopCloser(bytes.NewReader(body)),
		ContentLength: int64(len(body)),
		Host:          c.apiURL.Host,
	}
	// Forward tracing information from the incoming request, when present.
	for _, name := range []string{dapperHeader, traceHeader} {
		if v := c.req.Header.Get(name); v != "" {
			hreq.Header.Set(name, v)
		}
	}

	tr := apiHTTPClient.Transport.(*http.Transport)

	var timedOut int32 // set atomically to 1 by the timer callback
	cancel := func() {
		atomic.StoreInt32(&timedOut, 1)
		tr.CancelRequest(hreq)
	}
	t := time.AfterFunc(timeout, cancel)
	defer t.Stop()
	defer func() {
		// A fired timer overrides whatever error (or success) was produced.
		if atomic.LoadInt32(&timedOut) != 0 {
			err = errTimeout
		}
	}()

	hresp, doErr := apiHTTPClient.Do(hreq)
	if doErr != nil {
		return nil, &CallError{
			Detail: fmt.Sprintf("service bridge HTTP failed: %v", doErr),
			Code:   int32(remotepb.RpcError_UNKNOWN),
		}
	}
	defer hresp.Body.Close()

	hrespBody, readErr := ioutil.ReadAll(hresp.Body)
	switch {
	case hresp.StatusCode != 200:
		return nil, &CallError{
			Detail: fmt.Sprintf("service bridge returned HTTP %d (%q)", hresp.StatusCode, hrespBody),
			Code:   int32(remotepb.RpcError_UNKNOWN),
		}
	case readErr != nil:
		return nil, &CallError{
			Detail: fmt.Sprintf("service bridge response bad: %v", readErr),
			Code:   int32(remotepb.RpcError_UNKNOWN),
		}
	}
	return hrespBody, nil
}
442
// Call makes an API call. It sends in to the given method of service through
// the service bridge and unmarshals the reply into out.
//
// Before sending:
//   - If ctx carries a namespace, it is applied to in through
//     NamespaceMods[service].
//   - If ctx carries a transaction, that transaction is applied to in.
//
// If ctx carries a call override, the call is delegated to it after the
// namespace step, and nothing is sent here. The RPC timeout is the time left
// until ctx's deadline, or 60s if ctx has none.
//
// Errors returned include:
//   - ctx.Err() if ctx is already done;
//   - errNotAppEngineContext if ctx carries no *context;
//   - an error if ctx's transaction has finished;
//   - *CallError for transport and RPC failures, with Timeout set for
//     CANCELLED and DEADLINE_EXCEEDED;
//   - *APIError for application errors reported by the service.
func Call(ctx netcontext.Context, service, method string, in, out proto.Message) error {
	if ns := NamespaceFromContext(ctx); ns != "" {
		if fn, ok := NamespaceMods[service]; ok {
			fn(in, ns)
		}
	}

	// A registered call override handles the call entirely.
	if f, ctx, ok := callOverrideFromContext(ctx); ok {
		return f(ctx, service, method, in, out)
	}

	// Handle already-done contexts quickly.
	select {
	case <-ctx.Done():
		return ctx.Err()
	default:
	}

	c := fromContext(ctx)
	if c == nil {
		// Give a good error message rather than a panic lower down.
		return errNotAppEngineContext
	}

	// Apply transaction modifications if we're in a transaction.
	if t := transactionFromContext(ctx); t != nil {
		if t.finished {
			return errors.New("transaction context has expired")
		}
		applyTransaction(in, &t.transaction)
	}

	// Default RPC timeout is 60s.
	timeout := 60 * time.Second
	if deadline, ok := ctx.Deadline(); ok {
		// May be zero or negative if the deadline has just passed; post's
		// timer then fires immediately.
		timeout = deadline.Sub(time.Now())
	}

	data, err := proto.Marshal(in)
	if err != nil {
		return err
	}

	// Ticket precedence: incoming request header, then app ID override,
	// then DefaultTicket.
	ticket := c.req.Header.Get(ticketHeader)
	// Use a test ticket under test environment.
	if ticket == "" {
		if appid := ctx.Value(&appIDOverrideKey); appid != nil {
			ticket = appid.(string) + defaultTicketSuffix
		}
	}
	// Fall back to use background ticket when the request ticket is not available in Flex or dev_appserver.
	if ticket == "" {
		ticket = DefaultTicket()
	}
	req := &remotepb.Request{
		ServiceName: &service,
		Method:      &method,
		Request:     data,
		RequestId:   &ticket,
	}
	hreqBody, err := proto.Marshal(req)
	if err != nil {
		return err
	}

	hrespBody, err := c.post(hreqBody, timeout)
	if err != nil {
		return err
	}

	res := &remotepb.Response{}
	if err := proto.Unmarshal(hrespBody, res); err != nil {
		return err
	}
	if res.RpcError != nil {
		// NOTE(review): Code is dereferenced directly, which assumes the
		// bridge always sets it. GetCode() would avoid a nil panic.
		ce := &CallError{
			Detail: res.RpcError.GetDetail(),
			Code:   *res.RpcError.Code,
		}
		switch remotepb.RpcError_ErrorCode(ce.Code) {
		case remotepb.RpcError_CANCELLED, remotepb.RpcError_DEADLINE_EXCEEDED:
			ce.Timeout = true
		}
		return ce
	}
	if res.ApplicationError != nil {
		return &APIError{
			Service: *req.ServiceName,
			Detail:  res.ApplicationError.GetDetail(),
			Code:    *res.ApplicationError.Code,
		}
	}
	if res.Exception != nil || res.JavaException != nil {
		// This shouldn't happen, but let's be defensive.
		return &CallError{
			Detail: "service bridge returned exception",
			Code:   int32(remotepb.RpcError_UNKNOWN),
		}
	}
	return proto.Unmarshal(res.Response, out)
}
544
// Request returns the incoming HTTP request that c was created for.
func (c *context) Request() *http.Request { return c.req }
548
// addLogLine appends ll to c's queue of pending log lines, to be sent by a
// later flushLog.
//
// Messages longer than 8 KiB are truncated to at most 8 KiB and end in
// "...(length N)", where N is the original length in bytes. The cut is moved
// back to a UTF-8 character boundary, so a multi-byte character is never
// split into invalid UTF-8. A line with no message is queued unchanged.
func (c *context) addLogLine(ll *logpb.UserAppLogLine) {
	// Truncate long log lines.
	// TODO(dsymonds): Check if this is still necessary.
	const lim = 8 << 10
	if msg := ll.GetMessage(); len(msg) > lim {
		suffix := fmt.Sprintf("...(length %d)", len(msg))
		cut := lim - len(suffix)
		// Step back over UTF-8 continuation bytes (0b10xxxxxx) so the cut
		// lands on the first byte of a character.
		for cut > 0 && msg[cut]&0xC0 == 0x80 {
			cut--
		}
		ll.Message = proto.String(msg[:cut] + suffix)
	}

	c.pendingLogs.Lock()
	c.pendingLogs.lines = append(c.pendingLogs.lines, ll)
	c.pendingLogs.Unlock()
}
562
// logLevelName maps the numeric levels passed to logf to the names printed
// when logf echoes a message to the standard logger. Unknown levels map to "".
var logLevelName = map[int64]string{
	0: "DEBUG",
	1: "INFO",
	2: "WARNING",
	3: "ERROR",
	4: "CRITICAL",
}
570
// logf formats a message and strips its trailing newlines. It then queues the
// message on c at the given level (see logLevelName) so it is flushed with
// the request's logs. The untruncated message is also written to the standard
// logger, prefixed by the level name. logf panics if c is nil.
func logf(c *context, level int64, format string, args ...interface{}) {
	if c == nil {
		panic("not an App Engine context")
	}
	msg := strings.TrimRight(fmt.Sprintf(format, args...), "\n") // Remove any trailing newline characters.
	usec := time.Now().UnixNano() / 1e3
	c.addLogLine(&logpb.UserAppLogLine{
		TimestampUsec: &usec,
		Level:         &level,
		Message:       &msg,
	})
	log.Print(logLevelName[level] + ": " + msg)
}
584
// flushLog attempts to flush any pending logs to the appserver.
// It should not be called concurrently.
//
// It takes lines from the front of the pending queue until an estimated
// 30 MB budget would be exceeded, and sends them in a single logservice.Flush
// API call. When there is nothing to send and force is false, no call is
// made. If marshaling or the RPC fails, the taken lines are put back at the
// front of the queue so a later flush can retry them. pendingLogs.flushes is
// incremented before every attempted RPC, whether or not it succeeds.
// It reports whether the Flush RPC succeeded.
func (c *context) flushLog(force bool) (flushed bool) {
	c.pendingLogs.Lock()
	// Grab up to 30 MB. We can get away with up to 32 MB, but let's be cautious.
	n, rem := 0, 30<<20
	for ; n < len(c.pendingLogs.lines); n++ {
		ll := c.pendingLogs.lines[n]
		// Each log line will require about 3 bytes of overhead.
		nb := proto.Size(ll) + 3
		if nb > rem {
			break
		}
		rem -= nb
	}
	// lines and the remaining queue share one backing array.
	lines := c.pendingLogs.lines[:n]
	c.pendingLogs.lines = c.pendingLogs.lines[n:]
	c.pendingLogs.Unlock()

	if len(lines) == 0 && !force {
		// Nothing to flush.
		return false
	}

	// On failure, put the taken lines back ahead of anything queued since.
	rescueLogs := false
	defer func() {
		if rescueLogs {
			c.pendingLogs.Lock()
			c.pendingLogs.lines = append(lines, c.pendingLogs.lines...)
			c.pendingLogs.Unlock()
		}
	}()

	buf, err := proto.Marshal(&logpb.UserAppLogGroup{
		LogLine: lines,
	})
	if err != nil {
		log.Printf("internal.flushLog: marshaling UserAppLogGroup: %v", err)
		rescueLogs = true
		return false
	}

	req := &logpb.FlushRequest{
		Logs: buf,
	}
	res := &basepb.VoidProto{}
	c.pendingLogs.Lock()
	c.pendingLogs.flushes++
	c.pendingLogs.Unlock()
	if err := Call(toContext(c), "logservice", "Flush", req, res); err != nil {
		log.Printf("internal.flushLog: Flush RPC: %v", err)
		rescueLogs = true
		return false
	}
	return true
}
641
642const (
643	// Log flushing parameters.
644	flushInterval      = 1 * time.Second
645	forceFlushInterval = 60 * time.Second
646)
647
// logFlusher flushes c's pending logs every flushInterval until a value is
// received on stop. A flush is forced (sent even when empty) once more than
// forceFlushInterval has passed since the last successful flush.
func (c *context) logFlusher(stop <-chan int) {
	lastFlush := time.Now()
	ticker := time.NewTicker(flushInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			if c.flushLog(time.Since(lastFlush) > forceFlushInterval) {
				lastFlush = time.Now()
			}
		case <-stop:
			// Request finished.
			return
		}
	}
}
665
// ContextForTesting returns a background-derived context carrying a *context
// for req. No API URL is set on that *context.
func ContextForTesting(req *http.Request) netcontext.Context {
	c := new(context)
	c.req = req
	return toContext(c)
}
669