1// Copyright (C) 2019 Storj Labs, Inc. 2// See LICENSE for copying information. 3 4package drpcmux 5 6import ( 7 "reflect" 8 9 "github.com/zeebo/errs" 10 11 "storj.io/drpc" 12) 13 14// Mux is an implementation of Handler to serve drpc connections to the 15// appropriate Receivers registered by Descriptions. 16type Mux struct { 17 rpcs map[string]rpcData 18} 19 20// New constructs a new Mux. 21func New() *Mux { 22 return &Mux{ 23 rpcs: make(map[string]rpcData), 24 } 25} 26 27var ( 28 streamType = reflect.TypeOf((*drpc.Stream)(nil)).Elem() 29 messageType = reflect.TypeOf((*drpc.Message)(nil)).Elem() 30) 31 32type rpcData struct { 33 srv interface{} 34 enc drpc.Encoding 35 receiver drpc.Receiver 36 in1 reflect.Type 37 in2 reflect.Type 38 unitary bool 39} 40 41// Register associates the RPCs described by the description in the server. 42// It returns an error if there was a problem registering it. 43func (m *Mux) Register(srv interface{}, desc drpc.Description) error { 44 n := desc.NumMethods() 45 for i := 0; i < n; i++ { 46 rpc, enc, receiver, method, ok := desc.Method(i) 47 if !ok { 48 return errs.New("Description returned invalid method for index %d", i) 49 } 50 if err := m.registerOne(srv, rpc, enc, receiver, method); err != nil { 51 return err 52 } 53 } 54 return nil 55} 56 57// registerOne does the work to register a single rpc. 58func (m *Mux) registerOne(srv interface{}, rpc string, enc drpc.Encoding, receiver drpc.Receiver, method interface{}) error { 59 data := rpcData{srv: srv, enc: enc, receiver: receiver} 60 61 switch mt := reflect.TypeOf(method); { 62 // unitary input, unitary output 63 case mt.NumOut() == 2: 64 data.unitary = true 65 data.in1 = mt.In(2) 66 if !data.in1.Implements(messageType) { 67 return errs.New("input argument not a drpc message: %v", data.in1) 68 } 69 70 // unitary input, stream output 71 case mt.NumIn() == 3: 72 data.in1 = mt.In(1) 73 if !data.in1.Implements(messageType) { 74 return errs.New("input argument not a drpc message: %v", data.in1) 75 } 76 data.in2 = streamType 77 78 // stream input 79 case mt.NumIn() == 2: 80 data.in1 = streamType 81 82 // code gen bug? 83 default: 84 return errs.New("unknown method type: %v", mt) 85 } 86 87 m.rpcs[rpc] = data 88 return nil 89} 90