1// Copyright 2013 Google Inc. All rights reserved.
2// Use of this source code is governed by the Apache 2.0
3// license that can be found in the LICENSE file.
4
5package remote_api
6
7// This file provides the client for connecting remotely to a user's production
8// application.
9
10import (
11	"bytes"
12	"fmt"
13	"io/ioutil"
14	"log"
15	"math/rand"
16	"net/http"
17	"net/url"
18	"regexp"
19	"strconv"
20	"strings"
21	"time"
22
23	"github.com/golang/protobuf/proto"
24	"golang.org/x/net/context"
25
26	"google.golang.org/appengine/internal"
27	pb "google.golang.org/appengine/internal/remote_api"
28)
29
// Client is a connection to the production APIs for an application.
//
// Values are built by NewClient, which performs the remote_api handshake
// and stores the results in the fields below.
type Client struct {
	// hc is the HTTP client used for all traffic to the remote host.
	hc *http.Client
	// url is the remote_api endpoint, e.g. https://host/_ah/remote_api.
	url string
	// appID is the application ID the server reported during the handshake.
	appID string
}
36
37// NewClient returns a client for the given host. All communication will
38// be performed over SSL unless the host is localhost.
39func NewClient(host string, client *http.Client) (*Client, error) {
40	// Add an appcfg header to outgoing requests.
41	wrapClient := new(http.Client)
42	*wrapClient = *client
43	t := client.Transport
44	if t == nil {
45		t = http.DefaultTransport
46	}
47	wrapClient.Transport = &headerAddingRoundTripper{t}
48
49	url := url.URL{
50		Scheme: "https",
51		Host:   host,
52		Path:   "/_ah/remote_api",
53	}
54	if host == "localhost" || strings.HasPrefix(host, "localhost:") {
55		url.Scheme = "http"
56	}
57	u := url.String()
58	appID, err := getAppID(wrapClient, u)
59	if err != nil {
60		return nil, fmt.Errorf("unable to contact server: %v", err)
61	}
62	return &Client{
63		hc:    wrapClient,
64		url:   u,
65		appID: appID,
66	}, nil
67}
68
69// NewContext returns a copy of parent that will cause App Engine API
70// calls to be sent to the client's remote host.
71func (c *Client) NewContext(parent context.Context) context.Context {
72	ctx := internal.WithCallOverride(parent, c.call)
73	ctx = internal.WithLogOverride(ctx, c.logf)
74	ctx = internal.WithAppIDOverride(ctx, c.appID)
75	return ctx
76}
77
78// NewRemoteContext returns a context that gives access to the production
79// APIs for the application at the given host. All communication will be
80// performed over SSL unless the host is localhost.
81func NewRemoteContext(host string, client *http.Client) (context.Context, error) {
82	c, err := NewClient(host, client)
83	if err != nil {
84		return nil, err
85	}
86	return c.NewContext(context.Background()), nil
87}
88
89var logLevels = map[int64]string{
90	0: "DEBUG",
91	1: "INFO",
92	2: "WARNING",
93	3: "ERROR",
94	4: "CRITICAL",
95}
96
97func (c *Client) logf(level int64, format string, args ...interface{}) {
98	log.Printf(logLevels[level]+": "+format, args...)
99}
100
101func (c *Client) call(ctx context.Context, service, method string, in, out proto.Message) error {
102	req, err := proto.Marshal(in)
103	if err != nil {
104		return fmt.Errorf("error marshalling request: %v", err)
105	}
106
107	remReq := &pb.Request{
108		ServiceName: proto.String(service),
109		Method:      proto.String(method),
110		Request:     req,
111		// NOTE(djd): RequestId is unused in the server.
112	}
113
114	req, err = proto.Marshal(remReq)
115	if err != nil {
116		return fmt.Errorf("proto.Marshal: %v", err)
117	}
118
119	// TODO(djd): Respect ctx.Deadline()?
120	resp, err := c.hc.Post(c.url, "application/octet-stream", bytes.NewReader(req))
121	if err != nil {
122		return fmt.Errorf("error sending request: %v", err)
123	}
124	defer resp.Body.Close()
125
126	body, err := ioutil.ReadAll(resp.Body)
127	if resp.StatusCode != http.StatusOK {
128		return fmt.Errorf("bad response %d; body: %q", resp.StatusCode, body)
129	}
130	if err != nil {
131		return fmt.Errorf("failed reading response: %v", err)
132	}
133	remResp := &pb.Response{}
134	if err := proto.Unmarshal(body, remResp); err != nil {
135		return fmt.Errorf("error unmarshalling response: %v", err)
136	}
137
138	if ae := remResp.GetApplicationError(); ae != nil {
139		return &internal.APIError{
140			Code:    ae.GetCode(),
141			Detail:  ae.GetDetail(),
142			Service: service,
143		}
144	}
145
146	if remResp.Response == nil {
147		return fmt.Errorf("unexpected response: %s", proto.MarshalTextString(remResp))
148	}
149
150	return proto.Unmarshal(remResp.Response, out)
151}
152
// appIDRE pulls the application ID out of the YAML-like handshake body.
// It is intentionally forgiving: both the app_id key and its value may be
// bare, single-quoted or double-quoted.
var appIDRE = regexp.MustCompile(`app_id["']?\s*:\s*['"]?([-a-z0-9.:~]+)`)

// getAppID performs the remote_api handshake against url and returns the
// application ID reported by the server.
//
// A pseudo-random token is sent as the rtok query parameter. The reply
// must have status 200, must echo the token somewhere in its body, and
// must contain an app_id entry matched by appIDRE. Otherwise an error
// describing the failure is returned.
func getAppID(client *http.Client, url string) (string, error) {
	seed := rand.NewSource(time.Now().UnixNano())
	token := strconv.Itoa(rand.New(seed).Int())

	resp, err := client.Get(url + "?rtok=" + token)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	body, readErr := ioutil.ReadAll(resp.Body)
	switch {
	case resp.StatusCode != http.StatusOK:
		return "", fmt.Errorf("bad response %d; body: %q", resp.StatusCode, body)
	case readErr != nil:
		return "", fmt.Errorf("failed reading response: %v", readErr)
	case !bytes.Contains(body, []byte(token)):
		// A reply that does not echo our token is rejected.
		return "", fmt.Errorf("token not found: want %q; body %q", token, body)
	}

	groups := appIDRE.FindSubmatch(body)
	if len(groups) < 2 {
		return "", fmt.Errorf("app ID not found: body %q", body)
	}
	return string(groups[1]), nil
}
186
// headerAddingRoundTripper is an http.RoundTripper that adds the
// X-Appcfg-Api-Version header to every request before delegating to
// Wrapped.
type headerAddingRoundTripper struct {
	Wrapped http.RoundTripper
}

// RoundTrip sends a copy of r, carrying "X-Appcfg-Api-Version: 1", through
// the wrapped transport.
//
// The http.RoundTripper contract says RoundTrip should not modify the
// request. The header is therefore set on a shallow copy of r whose Header
// map (and value slices) are cloned, leaving the caller's request
// untouched. This also avoids a panic when r.Header is nil.
func (t *headerAddingRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
	r2 := new(http.Request)
	*r2 = *r
	r2.Header = make(http.Header, len(r.Header)+1)
	for k, vs := range r.Header {
		r2.Header[k] = append([]string(nil), vs...)
	}
	r2.Header.Set("X-Appcfg-Api-Version", "1")
	return t.Wrapped.RoundTrip(r2)
}
195