1// Copyright (C) 2021 Storj Labs, Inc. 2// See LICENSE for copying information. 3 4package drpchttp 5 6import ( 7 "encoding/base64" 8 "encoding/binary" 9 "encoding/json" 10 "errors" 11 "io" 12 "io/ioutil" 13 14 "github.com/zeebo/errs" 15 16 "storj.io/drpc" 17) 18 19const maxSize = 4 << 20 20 21type ( 22 marshalFunc = func(msg drpc.Message, enc drpc.Encoding) ([]byte, error) 23 unmarshalFunc = func(buf []byte, msg drpc.Message, enc drpc.Encoding) error 24 writeFunc = func(w io.Writer, buf []byte) error 25 readFunc = func(r io.Reader) ([]byte, error) 26) 27 28// JSONMarshal looks for a JSONMarshal method on the encoding and calls that if it 29// exists. Otherwise, it does a normal message marshal before doing a JSON marshal. 30func JSONMarshal(msg drpc.Message, enc drpc.Encoding) ([]byte, error) { 31 if enc, ok := enc.(interface { 32 JSONMarshal(msg drpc.Message) ([]byte, error) 33 }); ok { 34 return enc.JSONMarshal(msg) 35 } 36 37 // fallback to normal Marshal + JSON Marshal 38 buf, err := enc.Marshal(msg) 39 if err != nil { 40 return nil, err 41 } 42 return json.Marshal(buf) 43} 44 45// JSONUnmarshal looks for a JSONUnmarshal method on the encoding and calls that 46// if it exists. Otherwise, it JSON unmarshals the buf before doing a normal 47// message unmarshal. 48func JSONUnmarshal(buf []byte, msg drpc.Message, enc drpc.Encoding) error { 49 if enc, ok := enc.(interface { 50 JSONUnmarshal(buf []byte, msg drpc.Message) error 51 }); ok { 52 return enc.JSONUnmarshal(buf, msg) 53 } 54 55 // fallback to JSON Unmarshal + normal Unmarshal 56 var data []byte 57 if err := json.Unmarshal(buf, &data); err != nil { 58 return err 59 } 60 return enc.Unmarshal(data, msg) 61} 62 63func protoMarshal(msg drpc.Message, enc drpc.Encoding) ([]byte, error) { 64 return enc.Marshal(msg) 65} 66 67func protoUnmarshal(buf []byte, msg drpc.Message, enc drpc.Encoding) error { 68 return enc.Unmarshal(buf, msg) 69} 70 71func normalWrite(w io.Writer, buf []byte) error { 72 _, err := w.Write(buf) 73 return err 74} 75 76func base64Write(wf writeFunc) writeFunc { 77 return func(w io.Writer, buf []byte) error { 78 tmp := make([]byte, base64.StdEncoding.EncodedLen(len(buf))) 79 base64.StdEncoding.Encode(tmp, buf) 80 return wf(w, tmp) 81 } 82} 83 84func readExactly(r io.Reader, n uint64) ([]byte, error) { 85 buf := make([]byte, n) 86 _, err := io.ReadFull(r, buf) 87 return buf, err 88} 89 90func grpcRead(r io.Reader) ([]byte, error) { 91 if tmp, err := readExactly(r, 5); err != nil { 92 return nil, err 93 } else if size := binary.BigEndian.Uint32(tmp[1:5]); size > maxSize { 94 return nil, errs.New("message too large") 95 } else if data, err := readExactly(r, uint64(size)); errors.Is(err, io.EOF) { 96 return nil, io.ErrUnexpectedEOF 97 } else if err != nil { 98 return nil, err 99 } else { 100 return data, nil 101 } 102} 103 104func twirpRead(r io.Reader) ([]byte, error) { 105 if data, err := ioutil.ReadAll(io.LimitReader(r, maxSize)); err != nil { 106 return nil, err 107 } else if len(data) > maxSize { 108 return nil, errs.New("message too large") 109 } else { 110 return data, nil 111 } 112} 113 114func base64Read(rf readFunc) readFunc { 115 return func(r io.Reader) ([]byte, error) { 116 return rf(base64.NewDecoder(base64.StdEncoding, r)) 117 } 118} 119