1// Copyright (C) 2021 Storj Labs, Inc.
2// See LICENSE for copying information.
3
4package drpchttp
5
6import (
7	"encoding/base64"
8	"encoding/binary"
9	"encoding/json"
10	"errors"
11	"io"
12	"io/ioutil"
13
14	"github.com/zeebo/errs"
15
16	"storj.io/drpc"
17)
18
// maxSize is the largest message size, in bytes, that the read helpers in this
// file work with (4 MiB).
const maxSize = 4 * 1024 * 1024
20
21type (
22	marshalFunc   = func(msg drpc.Message, enc drpc.Encoding) ([]byte, error)
23	unmarshalFunc = func(buf []byte, msg drpc.Message, enc drpc.Encoding) error
24	writeFunc     = func(w io.Writer, buf []byte) error
25	readFunc      = func(r io.Reader) ([]byte, error)
26)
27
28// JSONMarshal looks for a JSONMarshal method on the encoding and calls that if it
29// exists. Otherwise, it does a normal message marshal before doing a JSON marshal.
30func JSONMarshal(msg drpc.Message, enc drpc.Encoding) ([]byte, error) {
31	if enc, ok := enc.(interface {
32		JSONMarshal(msg drpc.Message) ([]byte, error)
33	}); ok {
34		return enc.JSONMarshal(msg)
35	}
36
37	// fallback to normal Marshal + JSON Marshal
38	buf, err := enc.Marshal(msg)
39	if err != nil {
40		return nil, err
41	}
42	return json.Marshal(buf)
43}
44
45// JSONUnmarshal looks for a JSONUnmarshal method on the encoding and calls that
46// if it exists. Otherwise, it JSON unmarshals the buf before doing a normal
47// message unmarshal.
48func JSONUnmarshal(buf []byte, msg drpc.Message, enc drpc.Encoding) error {
49	if enc, ok := enc.(interface {
50		JSONUnmarshal(buf []byte, msg drpc.Message) error
51	}); ok {
52		return enc.JSONUnmarshal(buf, msg)
53	}
54
55	// fallback to JSON Unmarshal + normal Unmarshal
56	var data []byte
57	if err := json.Unmarshal(buf, &data); err != nil {
58		return err
59	}
60	return enc.Unmarshal(data, msg)
61}
62
63func protoMarshal(msg drpc.Message, enc drpc.Encoding) ([]byte, error) {
64	return enc.Marshal(msg)
65}
66
67func protoUnmarshal(buf []byte, msg drpc.Message, enc drpc.Encoding) error {
68	return enc.Unmarshal(buf, msg)
69}
70
// normalWrite writes buf to w unmodified and reports any error from the
// write. The byte count returned by w.Write is discarded.
func normalWrite(w io.Writer, buf []byte) error {
	if _, err := w.Write(buf); err != nil {
		return err
	}
	return nil
}
75
76func base64Write(wf writeFunc) writeFunc {
77	return func(w io.Writer, buf []byte) error {
78		tmp := make([]byte, base64.StdEncoding.EncodedLen(len(buf)))
79		base64.StdEncoding.Encode(tmp, buf)
80		return wf(w, tmp)
81	}
82}
83
// readExactly allocates a buffer of n bytes and fills it from r using
// io.ReadFull. The buffer is returned together with any error from the read,
// so on failure it may be only partially filled.
func readExactly(r io.Reader, n uint64) ([]byte, error) {
	out := make([]byte, n)
	if _, err := io.ReadFull(r, out); err != nil {
		return out, err
	}
	return out, nil
}
89
90func grpcRead(r io.Reader) ([]byte, error) {
91	if tmp, err := readExactly(r, 5); err != nil {
92		return nil, err
93	} else if size := binary.BigEndian.Uint32(tmp[1:5]); size > maxSize {
94		return nil, errs.New("message too large")
95	} else if data, err := readExactly(r, uint64(size)); errors.Is(err, io.EOF) {
96		return nil, io.ErrUnexpectedEOF
97	} else if err != nil {
98		return nil, err
99	} else {
100		return data, nil
101	}
102}
103
104func twirpRead(r io.Reader) ([]byte, error) {
105	if data, err := ioutil.ReadAll(io.LimitReader(r, maxSize)); err != nil {
106		return nil, err
107	} else if len(data) > maxSize {
108		return nil, errs.New("message too large")
109	} else {
110		return data, nil
111	}
112}
113
114func base64Read(rf readFunc) readFunc {
115	return func(r io.Reader) ([]byte, error) {
116		return rf(base64.NewDecoder(base64.StdEncoding, r))
117	}
118}
119