1package pgproto3
2
import (
	"bytes"
	"encoding/binary"
	"encoding/json"
	"sort"

	"github.com/jackc/pgx/pgio"
	"github.com/pkg/errors"
)
11
const (
	// ProtocolVersionNumber is protocol version 3.0, encoded as
	// major<<16 | minor (= 196608).
	ProtocolVersionNumber = 3 << 16

	// sslRequestNumber is the code a client sends in place of a protocol
	// version to request an SSL connection; Decode rejects it.
	// It equals 1234<<16 | 5679 (= 80877103).
	sslRequestNumber = 1234<<16 | 5679
)
16
// StartupMessage is a frontend startup message: a protocol version followed
// by a list of string key/value parameters.
type StartupMessage struct {
	// ProtocolVersion is the version carried in the message; Decode only
	// accepts ProtocolVersionNumber.
	ProtocolVersion uint32
	// Parameters holds the key/value pairs carried in the message body.
	Parameters      map[string]string
}
21
// Frontend marks StartupMessage as a frontend message. The body is empty; the
// method presumably exists only so the type satisfies a frontend-message
// interface declared elsewhere in the package (TODO confirm).
func (*StartupMessage) Frontend() {}
23
24func (dst *StartupMessage) Decode(src []byte) error {
25	if len(src) < 4 {
26		return errors.Errorf("startup message too short")
27	}
28
29	dst.ProtocolVersion = binary.BigEndian.Uint32(src)
30	rp := 4
31
32	if dst.ProtocolVersion == sslRequestNumber {
33		return errors.Errorf("can't handle ssl connection request")
34	}
35
36	if dst.ProtocolVersion != ProtocolVersionNumber {
37		return errors.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion)
38	}
39
40	dst.Parameters = make(map[string]string)
41	for {
42		idx := bytes.IndexByte(src[rp:], 0)
43		if idx < 0 {
44			return &invalidMessageFormatErr{messageType: "StartupMesage"}
45		}
46		key := string(src[rp : rp+idx])
47		rp += idx + 1
48
49		idx = bytes.IndexByte(src[rp:], 0)
50		if idx < 0 {
51			return &invalidMessageFormatErr{messageType: "StartupMesage"}
52		}
53		value := string(src[rp : rp+idx])
54		rp += idx + 1
55
56		dst.Parameters[key] = value
57
58		if len(src[rp:]) == 1 {
59			if src[rp] != 0 {
60				return errors.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp])
61			}
62			break
63		}
64	}
65
66	return nil
67}
68
69func (src *StartupMessage) Encode(dst []byte) []byte {
70	sp := len(dst)
71	dst = pgio.AppendInt32(dst, -1)
72
73	dst = pgio.AppendUint32(dst, src.ProtocolVersion)
74	for k, v := range src.Parameters {
75		dst = append(dst, k...)
76		dst = append(dst, 0)
77		dst = append(dst, v...)
78		dst = append(dst, 0)
79	}
80	dst = append(dst, 0)
81
82	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
83
84	return dst
85}
86
87func (src *StartupMessage) MarshalJSON() ([]byte, error) {
88	return json.Marshal(struct {
89		Type            string
90		ProtocolVersion uint32
91		Parameters      map[string]string
92	}{
93		Type:            "StartupMessage",
94		ProtocolVersion: src.ProtocolVersion,
95		Parameters:      src.Parameters,
96	})
97}
98