1// SPDX-License-Identifier: ISC
2// Copyright (c) 2014-2020 Bitmark Inc.
3// Use of this source code is governed by an ISC
4// license that can be found in the LICENSE file.
5
6package listeners
7
8import (
9	"crypto/tls"
10	"net"
11	"net/rpc"
12	"net/rpc/jsonrpc"
13	"strings"
14
15	"github.com/bitmark-inc/bitmarkd/announce"
16	"github.com/bitmark-inc/bitmarkd/counter"
17	"github.com/bitmark-inc/bitmarkd/fault"
18	"github.com/bitmark-inc/bitmarkd/util"
19	"github.com/bitmark-inc/logger"
20)
21
// Names and limits used by the client RPC listener.
const (
	// minBandwidth is the lowest acceptable configured bandwidth, in bits
	// per second (1Mbps).
	minBandwidth = 1000000

	// connectionLimit is used by NewRPC as the initial capacity of the
	// announce payload buffer.
	connectionLimit = 100

	// logName identifies this listener in log messages.
	logName = "client_rpc"
)
27
// rpcListener serves JSON-RPC over TLS on one or more listen addresses.
// Instances are built by NewRPC and started with Serve.
type rpcListener struct {
	// log receives start-up and accept-loop messages.
	log             *logger.L
	// listener holds a TLS listener.
	// NOTE(review): no code in this file reads this field; confirm it is
	// still needed.
	listener        net.Listener
	// count is the connection counter handed to every accept loop; it is
	// compared against maxConnections to decide whether to serve a connection.
	count           *counter.Counter
	// server dispatches the decoded JSON-RPC calls.
	server          *rpc.Server
	// maxConnections is the limit checked against count for each accepted
	// connection.
	maxConnections  uint64
	// tlsConfig is used for every tls.Listen call.
	tlsConfig       *tls.Config
	// ipType holds the network name ("tcp", "tcp4" or "tcp6") for the
	// listen address at the same index in listenIPAndPort.
	ipType          []string
	// listenIPAndPort holds the "host:port" addresses to listen on. It is the
	// configuration's Listen slice (same backing array), so the in-place
	// "*:PORT" -> "[::]:PORT" rewrite done during validation is visible here.
	listenIPAndPort []string
}
38
39func (r rpcListener) Serve() error {
40	var err error
41	for i, listen := range r.listenIPAndPort {
42		r.log.Infof("starting RPC server: %s", listen)
43		r.listener, err = tls.Listen(r.ipType[i], listen, r.tlsConfig)
44		if err != nil {
45			r.log.Errorf("rpc server listen error: %s", err)
46			return err
47		}
48
49		go doServeRPC(r.listener, r.server, r.maxConnections, r.log, r.count)
50	}
51	return nil
52}
53
54func doServeRPC(listen net.Listener, server *rpc.Server, maximumConnections uint64, log *logger.L, count *counter.Counter) {
55serve_loop:
56	for {
57		conn, err := listen.Accept()
58		if err != nil {
59			log.Errorf("rpc.server terminated: accept error:", err)
60			break serve_loop
61		}
62		if count.Increment() <= maximumConnections {
63			go func() {
64				server.ServeCodec(jsonrpc.NewServerCodec(conn))
65				_ = conn.Close()
66				count.Decrement()
67			}()
68		} else {
69			count.Decrement()
70			_ = conn.Close()
71		}
72
73	}
74	_ = listen.Close()
75	log.Error("RPC accept terminated")
76}
77
// RPCConfiguration is the configuration file data for RPC setup. It is
// validated and consumed by NewRPC.
type RPCConfiguration struct {
	// MaximumConnections is the connection limit handed to each accept loop;
	// NewRPC rejects values below minConnectionCount.
	MaximumConnections uint64   `gluamapper:"maximum_connections" json:"maximum_connections"`
	// Bandwidth is in bits per second and is validated against minBandwidth
	// (1Mbps) by NewRPC; nothing else in this file reads it.
	Bandwidth          float64  `gluamapper:"bandwidth" json:"bandwidth"`
	// Listen holds the "host:port" addresses to serve on; at least one is
	// required. The host must be a literal IP address ("[...]" for IPv6) or
	// "*". NewRPC rewrites "*:PORT" entries in place to "[::]:PORT".
	Listen             []string `gluamapper:"listen" json:"listen"`
	// Certificate is not read in this file.
	// NOTE(review): presumably the TLS certificate behind the tls.Config and
	// fingerprint passed to NewRPC — confirm.
	Certificate        string   `gluamapper:"certificate" json:"certificate"`
	// PrivateKey is not read in this file.
	// NOTE(review): presumably the key matching Certificate — confirm.
	PrivateKey         string   `gluamapper:"private_key" json:"private_key"`
	// Announce holds addresses published through announce together with the
	// certificate fingerprint. Each non-empty entry is parsed with
	// util.NewConnection; empty entries are skipped.
	Announce           []string `gluamapper:"announce" json:"announce"`
}
87
88func NewRPC(
89	configuration *RPCConfiguration,
90	log *logger.L,
91	count *counter.Counter,
92	server *rpc.Server,
93	ann announce.Announce,
94	tlsConfig *tls.Config,
95	certificateFingerprint [32]byte,
96) (Listener, error) {
97	if configuration.MaximumConnections < minConnectionCount {
98		log.Errorf("invalid %s maximum connection limit: %d", logName, configuration.MaximumConnections)
99		return nil, fault.MissingParameters
100	}
101	if configuration.Bandwidth <= minBandwidth { // fail if < 1Mbps
102		log.Errorf("invalid %s bandwidth: %d bps < 1Mbps", logName, configuration.Bandwidth)
103		return nil, fault.MissingParameters
104	}
105
106	r := rpcListener{
107		log:             log,
108		maxConnections:  configuration.MaximumConnections,
109		listenIPAndPort: configuration.Listen,
110		server:          server,
111		count:           count,
112		tlsConfig:       tlsConfig,
113	}
114
115	if 0 == len(configuration.Listen) {
116		log.Errorf("missing %s listen", logName)
117		return nil, fault.MissingParameters
118	}
119
120	log.Infof("%s: SHA3-256 fingerprint: %x", logName, certificateFingerprint)
121
122	// setup announce
123	l := make([]byte, 0, connectionLimit) // ***** FIX THIS: need a better default size
124
125config_loop:
126	for _, address := range configuration.Announce {
127		if "" == address {
128			continue config_loop
129		}
130		c, err := util.NewConnection(address)
131		if nil != err {
132			log.Errorf("invalid %s listen announce: %q  error: %s", logName, address, err)
133			return nil, err
134		}
135		l = append(l, c.Pack()...)
136	}
137
138	err := ann.Set(certificateFingerprint, l)
139	if nil != err {
140		log.Criticalf("announce.Set error: %s", err)
141		return nil, err
142	}
143
144	// validate all listen addresses
145	r.ipType, err = parseListenAddress(configuration.Listen, r.log)
146	if nil != err {
147		return nil, err
148	}
149
150	return &r, nil
151}
152
153func parseListenAddress(addrs []string, log *logger.L) ([]string, error) {
154	parsed := make([]string, len(addrs))
155	for i, listen := range addrs {
156		if '*' == listen[0] {
157			// change "*:PORT" to "[::]:PORT"
158			// on the assumption that this will listen on tcp4 and tcp6
159			addrs[i] = "[::]" + ":" + strings.Split(listen, ":")[1]
160			listen = "::"
161			parsed[i] = "tcp"
162		} else if '[' == listen[0] {
163			listen = strings.Split(listen[1:], "]:")[0]
164			parsed[i] = "tcp6"
165		} else {
166			listen = strings.Split(listen, ":")[0]
167			parsed[i] = "tcp4"
168		}
169
170		if ip := net.ParseIP(listen); nil == ip {
171			err := fault.InvalidIpAddress
172			log.Errorf("rpc server listen error: %s", err)
173			return nil, err
174		}
175	}
176
177	return parsed, nil
178}
179