1package pgproto3
2
3import (
4	"bytes"
5	"encoding/binary"
6	"encoding/json"
7	"errors"
8
9	"github.com/jackc/pgio"
10)
11
// CopyOutResponse is the message whose wire form starts with the type byte
// 'H' (see Encode). Its body is the overall format byte, a 16-bit column
// count, and one 16-bit format code per column.
type CopyOutResponse struct {
	// OverallFormat is the single format byte that follows the message length.
	// NOTE(review): the PostgreSQL protocol docs define 0 as textual and 1 as
	// binary; nothing in this file enforces that range.
	OverallFormat     byte
	// ColumnFormatCodes holds one format code per column, in wire order. Its
	// length is written as a 16-bit count when encoding.
	ColumnFormatCodes []uint16
}
16
// Backend is a no-op marker method. Presumably it lets CopyOutResponse satisfy
// a backend-message interface declared elsewhere in the package; confirm
// against that interface definition.
func (*CopyOutResponse) Backend() {}
18
19// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
20// type identifier and 4 byte message length.
21func (dst *CopyOutResponse) Decode(src []byte) error {
22	buf := bytes.NewBuffer(src)
23
24	if buf.Len() < 3 {
25		return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
26	}
27
28	overallFormat := buf.Next(1)[0]
29
30	columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
31	if buf.Len() != columnCount*2 {
32		return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
33	}
34
35	columnFormatCodes := make([]uint16, columnCount)
36	for i := 0; i < columnCount; i++ {
37		columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
38	}
39
40	*dst = CopyOutResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
41
42	return nil
43}
44
45// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
46func (src *CopyOutResponse) Encode(dst []byte) []byte {
47	dst = append(dst, 'H')
48	sp := len(dst)
49	dst = pgio.AppendInt32(dst, -1)
50
51	dst = append(dst, src.OverallFormat)
52
53	dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
54	for _, fc := range src.ColumnFormatCodes {
55		dst = pgio.AppendUint16(dst, fc)
56	}
57
58	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
59
60	return dst
61}
62
63// MarshalJSON implements encoding/json.Marshaler.
64func (src CopyOutResponse) MarshalJSON() ([]byte, error) {
65	return json.Marshal(struct {
66		Type              string
67		ColumnFormatCodes []uint16
68	}{
69		Type:              "CopyOutResponse",
70		ColumnFormatCodes: src.ColumnFormatCodes,
71	})
72}
73
74// UnmarshalJSON implements encoding/json.Unmarshaler.
75func (dst *CopyOutResponse) UnmarshalJSON(data []byte) error {
76	// Ignore null, like in the main JSON package.
77	if string(data) == "null" {
78		return nil
79	}
80
81	var msg struct {
82		OverallFormat     string
83		ColumnFormatCodes []uint16
84	}
85	if err := json.Unmarshal(data, &msg); err != nil {
86		return err
87	}
88
89	if len(msg.OverallFormat) != 1 {
90		return errors.New("invalid length for CopyOutResponse.OverallFormat")
91	}
92
93	dst.OverallFormat = msg.OverallFormat[0]
94	dst.ColumnFormatCodes = msg.ColumnFormatCodes
95	return nil
96}
97