// Copyright 2015 Keybase, Inc. All rights reserved. Use of
// this source code is governed by the included BSD license.

package libkb

import (
	"fmt"
	"sort"
	"sync"
	"time"

	keybase1 "github.com/keybase/client/go/protocol/keybase1"
	"github.com/keybase/go-framed-msgpack-rpc/rpc"
)

// ConnectionID is a sequential integer assigned to each RPC connection
// that this process serves. No IDs are reused. 0 is reserved and never
// handed out: AddConnection increments its counter before assigning, so the
// first connection gets ID 1.
type ConnectionID int

// ApplyFn can be applied to every connection. It is called with the
// connectionID and the RPC transporter. It should return true to keep
// going and false to stop.
type ApplyFn func(i ConnectionID, xp rpc.Transporter) bool

// ApplyDetailsFn can be applied to every connection. It is called with the
// connectionID, the RPC transporter, and the connection's client details.
// details is nil for a connection that has not labeled itself yet; otherwise
// it points at the manager's stored copy, so treat it as read-only.
// It should return true to keep going and false to stop.
type ApplyDetailsFn func(i ConnectionID, xp rpc.Transporter, details *keybase1.ClientDetails) bool

// LabelCb is a callback to be run when a client connects and labels itself.
// It receives the client type the client reported. Each registered callback
// is started in its own goroutine, so it may run concurrently with other
// callbacks and with further ConnectionManager calls.
type LabelCb func(typ keybase1.ClientType)

// rpcConnection is a single entry in a ConnectionManager's table.
type rpcConnection struct {
	// transporter is the RPC transport that was passed to AddConnection.
	transporter rpc.Transporter
	// details stays nil until the client labels itself via Label.
	details *keybase1.ClientStatus
}

// ConnectionManager manages all active connections for a given service.
// It can be called from multiple goroutines. Construct it with
// NewConnectionManager: the zero value has a nil lookup map, and
// AddConnection would panic writing to it.
type ConnectionManager struct {
	// Mutex guards all of the fields below. Because it is embedded,
	// Lock and Unlock are part of ConnectionManager's method set.
	sync.Mutex
	// nxt is the most recently assigned ConnectionID (0 before any).
	nxt ConnectionID
	// lookup holds the currently tracked connections, keyed by ID.
	lookup map[ConnectionID](*rpcConnection)
	// labelCbs are the callbacks added by RegisterLabelCallback; Label
	// starts them.
	labelCbs []LabelCb
}

// AddConnection adds a new connection to the table of connections and
// returns its freshly assigned ConnectionID. If closeListener is non-nil, a
// goroutine waits for a single receive on it (either a sent value or a close)
// and then removes the connection from the table. If closeListener never
// fires, that goroutine and the table entry remain.
50func (c *ConnectionManager) AddConnection(xp rpc.Transporter, closeListener chan error) ConnectionID { 51 c.Lock() 52 c.nxt++ // increment first, since 0 is reserved 53 id := c.nxt 54 c.lookup[id] = &rpcConnection{transporter: xp} 55 c.Unlock() 56 57 if closeListener != nil { 58 go func() { 59 <-closeListener 60 c.removeConnection(id) 61 }() 62 } 63 64 return id 65} 66 67func (c *ConnectionManager) removeConnection(id ConnectionID) { 68 c.Lock() 69 delete(c.lookup, id) 70 c.Unlock() 71} 72 73// LookupConnection looks up a connection given a connectionID, or returns nil 74// if no such connection was found. 75func (c *ConnectionManager) LookupConnection(i ConnectionID) rpc.Transporter { 76 c.Lock() 77 defer c.Unlock() 78 if conn := c.lookup[i]; conn != nil { 79 return conn.transporter 80 } 81 return nil 82} 83 84func (c *ConnectionManager) Shutdown() { 85} 86 87func (c *ConnectionManager) LookupByClientType(clientType keybase1.ClientType) rpc.Transporter { 88 c.Lock() 89 defer c.Unlock() 90 for _, v := range c.lookup { 91 if v.details != nil && v.details.Details.ClientType == clientType { 92 return v.transporter 93 } 94 } 95 return nil 96} 97 98func (c *ConnectionManager) Label(id ConnectionID, d keybase1.ClientDetails) error { 99 c.Lock() 100 defer c.Unlock() 101 102 var err error 103 if conn := c.lookup[id]; conn != nil { 104 conn.details = &keybase1.ClientStatus{ 105 Details: d, 106 ConnectionID: int(id), 107 } 108 } else { 109 err = NotFoundError{Msg: fmt.Sprintf("connection %d not found", id)} 110 } 111 112 // Hit all the callbacks with the client type 113 for _, lloop := range c.labelCbs { 114 go func(l LabelCb) { l(d.ClientType) }(lloop) 115 } 116 117 return err 118} 119 120func (c *ConnectionManager) RegisterLabelCallback(f LabelCb) { 121 c.Lock() 122 c.labelCbs = append(c.labelCbs, f) 123 c.Unlock() 124} 125 126func (c *ConnectionManager) hasClientType(clientType keybase1.ClientType) bool { 127 for _, con := range c.ListAllLabeledConnections() { 128 if 
clientType == con.Details.ClientType { 129 return true 130 } 131 } 132 return false 133} 134 135// WaitForClientType returns true if client type is connected, or waits until timeout for the connection 136func (c *ConnectionManager) WaitForClientType(clientType keybase1.ClientType, timeout time.Duration) bool { 137 if c.hasClientType(clientType) { 138 return true 139 } 140 ticker := time.NewTicker(time.Second) 141 deadline := time.After(timeout) 142 defer ticker.Stop() 143 for { 144 select { 145 case <-ticker.C: 146 if c.hasClientType(clientType) { 147 return true 148 } 149 case <-deadline: 150 return false 151 } 152 } 153} 154 155func (c *ConnectionManager) ListAllLabeledConnections() (ret []keybase1.ClientStatus) { 156 c.Lock() 157 defer c.Unlock() 158 for _, v := range c.lookup { 159 if v.details != nil { 160 ret = append(ret, *v.details) 161 } 162 } 163 sort.Sort(byClientType(ret)) 164 return ret 165} 166 167type byClientType []keybase1.ClientStatus 168 169func (a byClientType) Len() int { return len(a) } 170func (a byClientType) Swap(i, j int) { a[i], a[j] = a[j], a[i] } 171func (a byClientType) Less(i, j int) bool { return a[i].Details.ClientType < a[j].Details.ClientType } 172 173// ApplyAll applies the given function f to all connections in the table. 174// If you're going to do something blocking, please do it in a GoRoutine, 175// since we're holding the lock for all connections as we do this. 176func (c *ConnectionManager) ApplyAll(f ApplyFn) { 177 c.Lock() 178 defer c.Unlock() 179 for k, v := range c.lookup { 180 if !f(k, v.transporter) { 181 break 182 } 183 } 184} 185 186// ApplyAllDetails applies the given function f to all connections in the table. 187// If you're going to do something blocking, please do it in a GoRoutine, 188// since we're holding the lock for all connections as we do this. 
189func (c *ConnectionManager) ApplyAllDetails(f ApplyDetailsFn) { 190 c.Lock() 191 defer c.Unlock() 192 for k, v := range c.lookup { 193 status := v.details 194 var details *keybase1.ClientDetails 195 if status != nil { 196 details = &status.Details 197 } 198 if !f(k, v.transporter, details) { 199 break 200 } 201 } 202} 203 204// NewConnectionManager makes a new ConnectionManager. 205func NewConnectionManager() *ConnectionManager { 206 return &ConnectionManager{ 207 lookup: make(map[ConnectionID](*rpcConnection)), 208 } 209} 210