1// Copyright 2015 Keybase, Inc. All rights reserved. Use of
2// this source code is governed by the included BSD license.
3
4package libkb
5
6import (
7	"fmt"
8	"net"
9
10	"github.com/keybase/client/go/logger"
11	"github.com/keybase/client/go/protocol/keybase1"
12	"github.com/keybase/go-framed-msgpack-rpc/rpc"
13)
14
// Socket is the platform-specific local socket abstraction. Values of this
// type come from NewSocket() (Socket, err), which is defined in the various
// platform-specific socket_*.go files.
type Socket interface {
	// DialSocket returns a connection to the socket.
	DialSocket() (net.Conn, error)
	// BindToSocket returns a listener bound to the socket.
	BindToSocket() (net.Listener, error)
}
20
// SocketInfo records the file(s) used for the local socket: bindFile is
// returned by GetBindFile and dialFiles by GetDialFiles.
// NOTE(review): presumably this is the concrete Socket built by NewSocket in
// the platform-specific socket_*.go files — confirm.
// testOwner is not read in this file; the //nolint presumably silences an
// unused-field lint on platforms that do not use it — confirm.
type SocketInfo struct {
	log       logger.Logger
	bindFile  string
	dialFiles []string
	testOwner bool //nolint
}
27
28func (s SocketInfo) GetBindFile() string {
29	return s.bindFile
30}
31
32func (s SocketInfo) GetDialFiles() []string {
33	return s.dialFiles
34}
35
// SocketWrapper caches the outcome of one attempt to open the local socket.
// GlobalContext.getSocketLocked creates these and stores them in
// g.SocketWrapper. Conn is the dialed connection. Transporter is the RPC
// transport built on Conn; it is set only when Err is nil. Err is the dial
// error, if any.
type SocketWrapper struct {
	Conn        net.Conn
	Transporter rpc.Transporter
	Err         error
}
41
42func (g *GlobalContext) MakeLoopbackServer() (l net.Listener, err error) {
43	g.socketWrapperMu.Lock()
44	defer g.socketWrapperMu.Unlock()
45	g.LoopbackListener = NewLoopbackListener(g)
46	l = g.LoopbackListener
47	return l, err
48}
49
50func (g *GlobalContext) BindToSocket() (net.Listener, error) {
51	return g.SocketInfo.BindToSocket()
52}
53
54func NewTransportFromSocket(g *GlobalContext, s net.Conn, src keybase1.NetworkSource) rpc.Transporter {
55	return rpc.NewTransport(s, NewRPCLogFactory(g), NetworkInstrumenterStorageFromSrc(g, src), MakeWrapError(g), rpc.DefaultMaxFrameLength)
56}
57
58// ResetSocket clears and returns a new socket
59func (g *GlobalContext) ResetSocket(clearError bool) (net.Conn, rpc.Transporter, bool, error) {
60	g.socketWrapperMu.Lock()
61	defer g.socketWrapperMu.Unlock()
62
63	g.SocketWrapper = nil
64	return g.getSocketLocked(clearError)
65}
66
67func (g *GlobalContext) GetSocket(clearError bool) (conn net.Conn, xp rpc.Transporter, isNew bool, err error) {
68	g.Trace("GetSocket", &err)()
69	g.socketWrapperMu.Lock()
70	defer g.socketWrapperMu.Unlock()
71	return g.getSocketLocked(clearError)
72}
73
74func (g *GlobalContext) getSocketLocked(clearError bool) (conn net.Conn, xp rpc.Transporter, isNew bool, err error) {
75	needWrapper := false
76	if g.SocketWrapper == nil {
77		needWrapper = true
78		g.Log.Debug("| empty socket wrapper; need a new one")
79	} else if g.SocketWrapper.Transporter != nil && !g.SocketWrapper.Transporter.IsConnected() {
80		// need reconnect
81		g.Log.Debug("| rpc transport isn't connected, reconnecting...")
82		needWrapper = true
83	}
84
85	if needWrapper {
86		sw := SocketWrapper{}
87		if g.LoopbackListener != nil {
88			sw.Conn, sw.Err = g.LoopbackListener.Dial()
89		} else if g.SocketInfo == nil {
90			sw.Err = fmt.Errorf("Cannot get socket in standalone mode")
91		} else {
92			sw.Conn, sw.Err = g.SocketInfo.DialSocket()
93			g.Log.Debug("| DialSocket -> %s", ErrToOk(sw.Err))
94			isNew = true
95		}
96		if sw.Err == nil {
97			sw.Transporter = NewTransportFromSocket(g, sw.Conn, keybase1.NetworkSource_LOCAL)
98		}
99		g.SocketWrapper = &sw
100	}
101
102	sw := g.SocketWrapper
103	if sw.Err != nil && clearError {
104		g.SocketWrapper = nil
105	}
106	err = sw.Err
107
108	return sw.Conn, sw.Transporter, isNew, err
109}
110