1// Copyright 2015 Keybase, Inc. All rights reserved. Use of
2// this source code is governed by the included BSD license.
3
4package libkb
5
6import (
7	"bytes"
8	"compress/gzip"
9	"crypto/tls"
10	"crypto/x509"
11	"fmt"
12	"io"
13	"io/ioutil"
14	"net"
15	"net/http"
16	"net/http/cookiejar"
17	"net/url"
18	"regexp"
19	"strconv"
20	"strings"
21	"sync"
22	"time"
23
24	"github.com/keybase/go-framed-msgpack-rpc/rpc"
25	"github.com/keybase/go-framed-msgpack-rpc/rpc/resinit"
26	"golang.org/x/net/context"
27)
28
// ClientConfig holds the connection settings NewClient uses to build a Client.
// Keep the field order stable: positional composite literals of this type
// depend on it.
type ClientConfig struct {
	Host       string         // server host name without the port (e.g. from SplitHost)
	Port       int            // server port; 0 when none was given
	UseTLS     bool           // set from an "https" scheme; NOTE(review): XXX unused?
	URL        *url.URL       // parsed server URL
	RootCAs    *x509.CertPool // extra trust roots; when non-nil NewClient uses them as TLS RootCAs
	Prefix     string         // path component of the server URL
	UseCookies bool           // when false, NewClient attaches no cookie jar
	Timeout    time.Duration  // http.Client timeout; 0 means HTTPDefaultTimeout
}
39
// Client pairs an *http.Client with the ClientConfig it was built from; see
// NewClient.
type Client struct {
	cli    *http.Client  // configured client: timeout, optional cookie jar, instrumented transport
	config *ClientConfig // settings the client was built from; may be nil
}
44
// hostRE matches "host" or "host:port", where host contains no colon and port
// is all digits. Bracketed IPv6 literals are therefore not accepted.
var hostRE = regexp.MustCompile("^([^:]+)(:([0-9]+))?$")

// SplitHost splits joined ("host" or "host:port") into its host and numeric
// port. The port is 0 when joined has no ":port" suffix. It returns an error
// when joined does not have that shape or when the port is outside 0-65535;
// on error host is "" and port is 0.
func SplitHost(joined string) (host string, port int, err error) {
	match := hostRE.FindStringSubmatch(joined)
	if match == nil {
		return "", 0, fmt.Errorf("Invalid host/port found: %s", joined)
	}
	if match[3] == "" {
		return match[1], 0, nil
	}
	// A 16-bit parse rejects ports above 65535, including digit strings too
	// long for an int, which Atoi would either accept or clamp.
	p, perr := strconv.ParseUint(match[3], 10, 16)
	if perr != nil {
		return "", 0, fmt.Errorf("Could not convert port in host %s", joined)
	}
	return match[1], int(p), nil
}
63
// ParseCA builds a certificate pool from the PEM-encoded certificates in raw.
// When no certificate in raw can be parsed it returns a nil pool and an error.
func ParseCA(raw string) (*x509.CertPool, error) {
	pool := x509.NewCertPool()
	if !pool.AppendCertsFromPEM([]byte(raw)) {
		return nil, fmt.Errorf("Could not read CA for keybase.io")
	}
	return pool, nil
}
74
// ShortCA abbreviates a PEM blob for logging: its first three lines (or all of
// them, if there are fewer) joined by single spaces, followed by "...".
func ShortCA(raw string) string {
	head := strings.SplitN(raw, "\n", 4)
	if len(head) > 3 {
		head = head[:3]
	}
	return strings.Join(head, " ") + "..."
}
82
83// GenClientConfigForInternalAPI pulls the information out of the environment configuration,
84// and build a Client config that will be used in all API server
85// requests
86func genClientConfigForInternalAPI(g *GlobalContext) (*ClientConfig, error) {
87	e := g.Env
88	serverURI, err := e.GetServerURI()
89
90	if err != nil {
91		return nil, err
92	}
93
94	if e.GetTorMode().Enabled() {
95		serverURI = e.GetTorHiddenAddress()
96	}
97
98	if serverURI == "" {
99		err := fmt.Errorf("Cannot find a server URL")
100		return nil, err
101	}
102	url, err := url.Parse(serverURI)
103	if err != nil {
104		return nil, err
105	}
106
107	if url.Scheme == "" {
108		return nil, fmt.Errorf("Server URL missing Scheme")
109	}
110
111	if url.Host == "" {
112		return nil, fmt.Errorf("Server URL missing Host")
113	}
114
115	useTLS := (url.Scheme == "https")
116	host, port, e2 := SplitHost(url.Host)
117	if e2 != nil {
118		return nil, e2
119	}
120	var rootCAs *x509.CertPool
121	if rawCA := e.GetBundledCA(host); len(rawCA) > 0 {
122		rootCAs, err = ParseCA(rawCA)
123		if err != nil {
124			err = fmt.Errorf("In parsing CAs for %s: %s", host, err)
125			return nil, err
126		}
127		g.Log.Debug(fmt.Sprintf("Using special root CA for %s: %s",
128			host, ShortCA(rawCA)))
129	}
130
131	// If we're using proxies, they might have their own CAs.
132	if rootCAs, err = GetProxyCAs(rootCAs, e.config); err != nil {
133		return nil, err
134	}
135
136	ret := &ClientConfig{host, port, useTLS, url, rootCAs, url.Path, true, e.GetAPITimeout()}
137	return ret, nil
138}
139
140func genClientConfigForScrapers(e *Env) (*ClientConfig, error) {
141	return &ClientConfig{
142		UseCookies: true,
143		Timeout:    e.GetScraperTimeout(),
144	}, nil
145}
146
147func NewClient(g *GlobalContext, config *ClientConfig, needCookie bool) (*Client, error) {
148	extraLog := func(ctx context.Context, msg string, args ...interface{}) {}
149	if g.Env.GetExtraNetLogging() {
150		extraLog = func(ctx context.Context, msg string, args ...interface{}) {
151			if ctx == nil {
152				g.Log.Debug(msg, args...)
153			} else {
154				g.Log.CDebugf(ctx, msg, args...)
155			}
156		}
157	}
158	extraLog(context.TODO(), "api.Client:%v New", needCookie)
159	env := g.Env
160	var jar *cookiejar.Jar
161	if needCookie && (config == nil || config.UseCookies) && env.GetTorMode().UseCookies() {
162		jar, _ = cookiejar.New(nil)
163	}
164
165	// Originally copied from http.DefaultTransport
166	dialer := net.Dialer{
167		Timeout:   30 * time.Second,
168		KeepAlive: 30 * time.Second,
169		DualStack: true,
170	}
171	xprt := http.Transport{
172		// Don't change this without re-testing proxy support. Currently the client supports proxies through
173		// environment variables that ProxyFromEnvironment picks up
174		Proxy:                 http.ProxyFromEnvironment,
175		DialContext:           (&dialer).DialContext,
176		MaxIdleConns:          200,
177		MaxIdleConnsPerHost:   100,
178		IdleConnTimeout:       90 * time.Second,
179		TLSHandshakeTimeout:   10 * time.Second,
180		ExpectContinueTimeout: 1 * time.Second,
181	}
182
183	xprt.DialContext = func(ctx context.Context, network, addr string) (c net.Conn, err error) {
184		c, err = dialer.DialContext(ctx, network, addr)
185		if err != nil {
186			extraLog(ctx, "api.Client:%v transport.Dial err=%v", needCookie, err)
187			// If we get a DNS error, it could be because glibc has cached an
188			// old version of /etc/resolv.conf. The res_init() libc function
189			// busts that cache and keeps us from getting stuck in a state
190			// where DNS requests keep failing even though the network is up.
191			// This is similar to what the Rust standard library does:
192			// https://github.com/rust-lang/rust/blob/028569ab1b/src/libstd/sys_common/net.rs#L186-L190
193			resinit.ResInitIfDNSError(err)
194			return c, err
195		}
196		if err = rpc.DisableSigPipe(c); err != nil {
197			extraLog(ctx, "api.Client:%v transport.Dial DisableSigPipe err=%v", needCookie, err)
198			return c, err
199		}
200		return c, nil
201	}
202
203	if config != nil && config.RootCAs != nil {
204		xprt.TLSClientConfig = &tls.Config{RootCAs: config.RootCAs}
205	}
206
207	xprt.Proxy = MakeProxy(env)
208
209	if !env.GetTorMode().Enabled() && env.GetRunMode() == DevelRunMode {
210		xprt.Proxy = func(req *http.Request) (*url.URL, error) {
211			host, port, err := net.SplitHostPort(req.URL.Host)
212			if err == nil && host == "localhost" {
213				// ProxyFromEnvironment refuses to proxy when the hostname is set to "localhost".
214				// So make a fake copy of the request with the url set to "127.0.0.1".
215				// This makes localhost requests use proxy settings.
216				// The Host could be anything and is only used to != "localhost".
217				url2 := *req.URL
218				url2.Host = "keybase.io:" + port
219				req2 := req
220				req2.URL = &url2
221				return http.ProxyFromEnvironment(req2)
222			}
223			return http.ProxyFromEnvironment(req)
224		}
225	}
226
227	var timeout time.Duration
228	if config == nil || config.Timeout == 0 {
229		timeout = HTTPDefaultTimeout
230	} else {
231		timeout = config.Timeout
232	}
233
234	ret := &Client{
235		cli:    &http.Client{Timeout: timeout},
236		config: config,
237	}
238	if jar != nil {
239		ret.cli.Jar = jar
240	}
241	ret.cli.Transport = NewInstrumentedRoundTripper(g, InstrumentationTagFromRequest, &xprt)
242	return ret, nil
243}
244
245func ServerLookup(env *Env, mode RunMode) (string, error) {
246	if mode == DevelRunMode {
247		return DevelServerURI, nil
248	}
249	if mode == StagingRunMode {
250		return StagingServerURI, nil
251	}
252	if mode == ProductionRunMode {
253		if env.IsCertPinningEnabled() {
254			// In order to disable SSL pinning we switch to doing requests against keybase.io which has a TLS
255			// cert signed by a publicly trusted CA (compared to api-0.keybaseapi.com which has a non-trusted but
256			// pinned certificate
257			return ProductionServerURI, nil
258		}
259		return ProductionSiteURI, nil
260	}
261	return "", fmt.Errorf("Did not find a server to use with the current RunMode!")
262}
263
// InstrumentedBody wraps an HTTP response body and counts the bytes read
// through it, so the response size can be reported to a network instrumenter
// when the body is closed.
type InstrumentedBody struct {
	MetaContextified
	// record receives the measured size and is finished after Close; body is
	// the wrapped response body that reads are delegated to.
	record *rpc.NetworkInstrumenter
	body   io.ReadCloser
	// n counts the bytes returned by Read.
	n int
	// uncompressed indicates if the body was compressed on the wire but
	// uncompressed by the http library. In this case we recompress to
	// instrument the gzipped size.
	//
	// gzipBuf keeps a copy of every byte read while uncompressed is set, for
	// that recompression. gzipGetter returns a gzip.Writer targeting the given
	// io.Writer, plus a func that releases the writer when done.
	uncompressed bool
	gzipBuf      bytes.Buffer
	gzipGetter   func(io.Writer) (*gzip.Writer, func())
}

// Compile-time check that *InstrumentedBody implements io.ReadCloser.
var _ io.ReadCloser = (*InstrumentedBody)(nil)
279
280func NewInstrumentedBody(mctx MetaContext, record *rpc.NetworkInstrumenter, body io.ReadCloser, uncompressed bool,
281	gzipGetter func(io.Writer) (*gzip.Writer, func())) *InstrumentedBody {
282	return &InstrumentedBody{
283		MetaContextified: NewMetaContextified(mctx),
284		record:           record,
285		body:             body,
286		gzipGetter:       gzipGetter,
287		uncompressed:     uncompressed,
288	}
289}
290
291func (b *InstrumentedBody) Read(p []byte) (n int, err error) {
292	n, err = b.body.Read(p)
293	b.n += n
294	if b.uncompressed && n > 0 {
295		if n, err := b.gzipBuf.Write(p[:n]); err != nil {
296			return n, err
297		}
298	}
299	return n, err
300}
301
302func (b *InstrumentedBody) Close() (err error) {
303	// instrument the full body size even if the caller hasn't consumed it.
304	_, _ = io.Copy(ioutil.Discard, b.body)
305	// Do actual instrumentation in the background
306	go func() {
307		if b.uncompressed {
308			// gzip the body we stored and instrument the compressed size
309			var buf bytes.Buffer
310			writer, reclaim := b.gzipGetter(&buf)
311			defer reclaim()
312			if _, err = writer.Write(b.gzipBuf.Bytes()); err != nil {
313				b.M().Debug("InstrumentedBody:unable to write gzip %v", err)
314				return
315			}
316			if err = writer.Close(); err != nil {
317				b.M().Debug("InstrumentedBody:unable to close gzip %v", err)
318				return
319			}
320			b.record.IncrementSize(int64(buf.Len()))
321		} else {
322			b.record.IncrementSize(int64(b.n))
323		}
324		if err := b.record.Finish(b.M().Ctx()); err != nil {
325			b.M().Debug("InstrumentedBody: unable to instrument network request: %s, %s", b.record, err)
326		}
327	}()
328	return b.body.Close()
329}
330
// InstrumentedRoundTripper is an http.RoundTripper that records each request
// with an rpc.NetworkInstrumenter and wraps successful responses in an
// InstrumentedBody, so the response size is recorded when the body is closed.
//
// Fields:
//   - RoundTripper is the wrapped transport that performs the request.
//   - tagger derives the instrumentation tag for a request.
//   - gzipPool caches gzip.Writers used to re-compress decompressed bodies.
//
// Because it holds a sync.Pool, it must not be copied after first use.
type InstrumentedRoundTripper struct {
	Contextified
	RoundTripper http.RoundTripper
	tagger       func(*http.Request) string
	gzipPool     sync.Pool
}

// Compile-time check that *InstrumentedRoundTripper implements http.RoundTripper.
var _ http.RoundTripper = (*InstrumentedRoundTripper)(nil)
339
340func NewInstrumentedRoundTripper(g *GlobalContext, tagger func(*http.Request) string, xprt http.RoundTripper) *InstrumentedRoundTripper {
341	return &InstrumentedRoundTripper{
342		Contextified: NewContextified(g),
343		RoundTripper: xprt,
344		tagger:       tagger,
345		gzipPool: sync.Pool{
346			New: func() interface{} {
347				return gzip.NewWriter(ioutil.Discard)
348			},
349		},
350	}
351}
352
353func (i *InstrumentedRoundTripper) getGzipWriter(writer io.Writer) (*gzip.Writer, func()) {
354	gzipWriter := i.gzipPool.Get().(*gzip.Writer)
355	gzipWriter.Reset(writer)
356	return gzipWriter, func() {
357		i.gzipPool.Put(gzipWriter)
358	}
359}
360
361func (i *InstrumentedRoundTripper) RoundTrip(req *http.Request) (resp *http.Response, err error) {
362	tags := LogTagsFromString(req.Header.Get("X-Keybase-Log-Tags"))
363	mctx := NewMetaContextTODO(i.G()).WithLogTags(tags)
364	record := rpc.NewNetworkInstrumenter(i.G().RemoteNetworkInstrumenterStorage, i.tagger(req))
365	resp, err = i.RoundTripper.RoundTrip(req)
366	record.EndCall()
367	if err != nil {
368		if rerr := record.Finish(mctx.Ctx()); rerr != nil {
369			mctx.Debug("InstrumentedTransport: unable to instrument network request %s, %s", record, rerr)
370		}
371		return resp, err
372	}
373	resp.Body = NewInstrumentedBody(mctx, record, resp.Body, resp.Uncompressed, i.getGzipWriter)
374	return resp, err
375}
376