1// Copyright 2015 Keybase, Inc. All rights reserved. Use of
2// this source code is governed by the included BSD license.
3
4package libkb
5
6import (
7	"fmt"
8	"sort"
9	"sync"
10	"time"
11
12	keybase1 "github.com/keybase/client/go/protocol/keybase1"
13	"github.com/keybase/go-framed-msgpack-rpc/rpc"
14)
15
16// ConnectionID is a sequential integer assigned to each RPC connection
17// that this process serves. No IDs are reused.
18type ConnectionID int
19
20// ApplyFn can be applied to every connection. It is called with the
21// RPC transporter, and also the connectionID. It should return a bool
22// true to keep going and false to stop.
23type ApplyFn func(i ConnectionID, xp rpc.Transporter) bool
24
25// ApplyDetailsFn can be applied to every connection. It is called with the
26// RPC transporter, and also the connectionID. It should return a bool
27// true to keep going and false to stop.
28type ApplyDetailsFn func(i ConnectionID, xp rpc.Transporter, details *keybase1.ClientDetails) bool
29
30// LabelCb is a callback to be run when a client connects and labels itself.
31type LabelCb func(typ keybase1.ClientType)
32
// rpcConnection is the per-connection record stored in
// ConnectionManager.lookup.
type rpcConnection struct {
	// transporter is the RPC transport registered by AddConnection.
	transporter rpc.Transporter
	// details is nil until Label is called for this connection; afterwards it
	// holds the client's self-reported details plus the connection's ID.
	details     *keybase1.ClientStatus
}
37
// ConnectionManager manages all active connections for a given service.
// It can be called from multiple goroutines: every method acquires the
// embedded mutex before touching the fields below.
//
// Create one with NewConnectionManager. The zero value has a nil lookup map,
// so AddConnection on it would panic when writing to that map.
type ConnectionManager struct {
	sync.Mutex
	// nxt is the most recently issued ConnectionID. It is incremented before
	// use, so 0 is never handed out.
	nxt      ConnectionID
	// lookup maps each registered connection's ID to its record. Entries are
	// deleted by removeConnection once the connection's close listener (if
	// one was supplied to AddConnection) fires.
	lookup   map[ConnectionID](*rpcConnection)
	// labelCbs are the callbacks Label invokes with the client type.
	labelCbs []LabelCb
}
46
47// AddConnection adds a new connection to the table of Connection object, with a
48// related closeListener. We'll listen for a close on that channel, and when one occurs,
49// we'll remove the connection from the pool.
50func (c *ConnectionManager) AddConnection(xp rpc.Transporter, closeListener chan error) ConnectionID {
51	c.Lock()
52	c.nxt++ // increment first, since 0 is reserved
53	id := c.nxt
54	c.lookup[id] = &rpcConnection{transporter: xp}
55	c.Unlock()
56
57	if closeListener != nil {
58		go func() {
59			<-closeListener
60			c.removeConnection(id)
61		}()
62	}
63
64	return id
65}
66
67func (c *ConnectionManager) removeConnection(id ConnectionID) {
68	c.Lock()
69	delete(c.lookup, id)
70	c.Unlock()
71}
72
73// LookupConnection looks up a connection given a connectionID, or returns nil
74// if no such connection was found.
75func (c *ConnectionManager) LookupConnection(i ConnectionID) rpc.Transporter {
76	c.Lock()
77	defer c.Unlock()
78	if conn := c.lookup[i]; conn != nil {
79		return conn.transporter
80	}
81	return nil
82}
83
84func (c *ConnectionManager) Shutdown() {
85}
86
87func (c *ConnectionManager) LookupByClientType(clientType keybase1.ClientType) rpc.Transporter {
88	c.Lock()
89	defer c.Unlock()
90	for _, v := range c.lookup {
91		if v.details != nil && v.details.Details.ClientType == clientType {
92			return v.transporter
93		}
94	}
95	return nil
96}
97
98func (c *ConnectionManager) Label(id ConnectionID, d keybase1.ClientDetails) error {
99	c.Lock()
100	defer c.Unlock()
101
102	var err error
103	if conn := c.lookup[id]; conn != nil {
104		conn.details = &keybase1.ClientStatus{
105			Details:      d,
106			ConnectionID: int(id),
107		}
108	} else {
109		err = NotFoundError{Msg: fmt.Sprintf("connection %d not found", id)}
110	}
111
112	// Hit all the callbacks with the client type
113	for _, lloop := range c.labelCbs {
114		go func(l LabelCb) { l(d.ClientType) }(lloop)
115	}
116
117	return err
118}
119
120func (c *ConnectionManager) RegisterLabelCallback(f LabelCb) {
121	c.Lock()
122	c.labelCbs = append(c.labelCbs, f)
123	c.Unlock()
124}
125
126func (c *ConnectionManager) hasClientType(clientType keybase1.ClientType) bool {
127	for _, con := range c.ListAllLabeledConnections() {
128		if clientType == con.Details.ClientType {
129			return true
130		}
131	}
132	return false
133}
134
135// WaitForClientType returns true if client type is connected, or waits until timeout for the connection
136func (c *ConnectionManager) WaitForClientType(clientType keybase1.ClientType, timeout time.Duration) bool {
137	if c.hasClientType(clientType) {
138		return true
139	}
140	ticker := time.NewTicker(time.Second)
141	deadline := time.After(timeout)
142	defer ticker.Stop()
143	for {
144		select {
145		case <-ticker.C:
146			if c.hasClientType(clientType) {
147				return true
148			}
149		case <-deadline:
150			return false
151		}
152	}
153}
154
155func (c *ConnectionManager) ListAllLabeledConnections() (ret []keybase1.ClientStatus) {
156	c.Lock()
157	defer c.Unlock()
158	for _, v := range c.lookup {
159		if v.details != nil {
160			ret = append(ret, *v.details)
161		}
162	}
163	sort.Sort(byClientType(ret))
164	return ret
165}
166
167type byClientType []keybase1.ClientStatus
168
169func (a byClientType) Len() int           { return len(a) }
170func (a byClientType) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
171func (a byClientType) Less(i, j int) bool { return a[i].Details.ClientType < a[j].Details.ClientType }
172
173// ApplyAll applies the given function f to all connections in the table.
174// If you're going to do something blocking, please do it in a GoRoutine,
175// since we're holding the lock for all connections as we do this.
176func (c *ConnectionManager) ApplyAll(f ApplyFn) {
177	c.Lock()
178	defer c.Unlock()
179	for k, v := range c.lookup {
180		if !f(k, v.transporter) {
181			break
182		}
183	}
184}
185
186// ApplyAllDetails applies the given function f to all connections in the table.
187// If you're going to do something blocking, please do it in a GoRoutine,
188// since we're holding the lock for all connections as we do this.
189func (c *ConnectionManager) ApplyAllDetails(f ApplyDetailsFn) {
190	c.Lock()
191	defer c.Unlock()
192	for k, v := range c.lookup {
193		status := v.details
194		var details *keybase1.ClientDetails
195		if status != nil {
196			details = &status.Details
197		}
198		if !f(k, v.transporter, details) {
199			break
200		}
201	}
202}
203
204// NewConnectionManager makes a new ConnectionManager.
205func NewConnectionManager() *ConnectionManager {
206	return &ConnectionManager{
207		lookup: make(map[ConnectionID](*rpcConnection)),
208	}
209}
210