1// Copyright (C) 2019 Storj Labs, Inc.
2// See LICENSE for copying information.
3
4package drpcmux
5
6import (
7	"reflect"
8
9	"github.com/zeebo/errs"
10
11	"storj.io/drpc"
12)
13
// Mux is an implementation of Handler to serve drpc connections to the
// appropriate Receivers registered by Descriptions.
type Mux struct {
	// rpcs maps each registered rpc name to the data recorded for it by
	// Register. New initializes it to an empty map.
	rpcs map[string]rpcData
}
19
20// New constructs a new Mux.
21func New() *Mux {
22	return &Mux{
23		rpcs: make(map[string]rpcData),
24	}
25}
26
27var (
28	streamType  = reflect.TypeOf((*drpc.Stream)(nil)).Elem()
29	messageType = reflect.TypeOf((*drpc.Message)(nil)).Elem()
30)
31
// rpcData holds what registerOne records about a single rpc:
//   - srv is the server value that was passed to Register.
//   - enc and receiver are the values the Description returned for the rpc.
//   - in1 is the drpc.Message parameter type for unitary-input methods, or
//     streamType for stream-input methods.
//   - in2 is streamType for unitary-input/stream-output methods and is left
//     nil otherwise.
//   - unitary is true when the method's signature returns two values
//     (unitary input, unitary output).
//
// NOTE(review): the dispatch code that consumes these fields is not visible
// here; the notes above describe only how registerOne populates them.
type rpcData struct {
	srv      interface{}
	enc      drpc.Encoding
	receiver drpc.Receiver
	in1      reflect.Type
	in2      reflect.Type
	unitary  bool
}
40
41// Register associates the RPCs described by the description in the server.
42// It returns an error if there was a problem registering it.
43func (m *Mux) Register(srv interface{}, desc drpc.Description) error {
44	n := desc.NumMethods()
45	for i := 0; i < n; i++ {
46		rpc, enc, receiver, method, ok := desc.Method(i)
47		if !ok {
48			return errs.New("Description returned invalid method for index %d", i)
49		}
50		if err := m.registerOne(srv, rpc, enc, receiver, method); err != nil {
51			return err
52		}
53	}
54	return nil
55}
56
57// registerOne does the work to register a single rpc.
58func (m *Mux) registerOne(srv interface{}, rpc string, enc drpc.Encoding, receiver drpc.Receiver, method interface{}) error {
59	data := rpcData{srv: srv, enc: enc, receiver: receiver}
60
61	switch mt := reflect.TypeOf(method); {
62	// unitary input, unitary output
63	case mt.NumOut() == 2:
64		data.unitary = true
65		data.in1 = mt.In(2)
66		if !data.in1.Implements(messageType) {
67			return errs.New("input argument not a drpc message: %v", data.in1)
68		}
69
70	// unitary input, stream output
71	case mt.NumIn() == 3:
72		data.in1 = mt.In(1)
73		if !data.in1.Implements(messageType) {
74			return errs.New("input argument not a drpc message: %v", data.in1)
75		}
76		data.in2 = streamType
77
78	// stream input
79	case mt.NumIn() == 2:
80		data.in1 = streamType
81
82	// code gen bug?
83	default:
84		return errs.New("unknown method type: %v", mt)
85	}
86
87	m.rpcs[rpc] = data
88	return nil
89}
90