1package sctp
2
3import (
4	"encoding/binary"
5	"fmt"
6	"hash/crc32"
7
8	"github.com/pkg/errors"
9)
10
11// Create the crc32 table we'll use for the checksum
12var castagnoliTable = crc32.MakeTable(crc32.Castagnoli) // nolint:gochecknoglobals
13
14// Allocate and zero this data once.
15// We need to use it for the checksum and don't want to allocate/clear each time.
16var fourZeroes [4]byte // nolint:gochecknoglobals
17
18/*
19Packet represents an SCTP packet, defined in https://tools.ietf.org/html/rfc4960#section-3
20An SCTP packet is composed of a common header and chunks.  A chunk
21contains either control information or user data.
22
23
24                      SCTP Packet Format
25 0                   1                   2                   3
26 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
27+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
28|                        Common Header                          |
29+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
30|                          Chunk #1                             |
31+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
32|                           ...                                 |
33+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
34|                          Chunk #n                             |
35+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
36
37
38                SCTP Common Header Format
39
40 0                   1                   2                   3
41 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
42+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
43|     Source Value Number        |     Destination Value Number   |
44+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
45|                      Verification Tag                         |
46+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
47|                           Checksum                            |
48+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
49
50
51*/
52type packet struct {
53	sourcePort      uint16
54	destinationPort uint16
55	verificationTag uint32
56	chunks          []chunk
57}
58
59const (
60	packetHeaderSize = 12
61)
62
63func (p *packet) unmarshal(raw []byte) error {
64	if len(raw) < packetHeaderSize {
65		return errors.Errorf("raw only %d bytes, %d is the minimum length for a SCTP packet", len(raw), packetHeaderSize)
66	}
67
68	p.sourcePort = binary.BigEndian.Uint16(raw[0:])
69	p.destinationPort = binary.BigEndian.Uint16(raw[2:])
70	p.verificationTag = binary.BigEndian.Uint32(raw[4:])
71
72	offset := packetHeaderSize
73	for {
74		// Exact match, no more chunks
75		if offset == len(raw) {
76			break
77		} else if offset+chunkHeaderSize > len(raw) {
78			return errors.Errorf("Unable to parse SCTP chunk, not enough data for complete header: offset %d remaining %d", offset, len(raw))
79		}
80
81		var c chunk
82		switch chunkType(raw[offset]) {
83		case ctInit:
84			c = &chunkInit{}
85		case ctInitAck:
86			c = &chunkInitAck{}
87		case ctAbort:
88			c = &chunkAbort{}
89		case ctCookieEcho:
90			c = &chunkCookieEcho{}
91		case ctCookieAck:
92			c = &chunkCookieAck{}
93		case ctHeartbeat:
94			c = &chunkHeartbeat{}
95		case ctPayloadData:
96			c = &chunkPayloadData{}
97		case ctSack:
98			c = &chunkSelectiveAck{}
99		case ctReconfig:
100			c = &chunkReconfig{}
101		case ctForwardTSN:
102			c = &chunkForwardTSN{}
103		case ctError:
104			c = &chunkError{}
105		default:
106			return errors.Errorf("Failed to unmarshal, contains unknown chunk type %s", chunkType(raw[offset]).String())
107		}
108
109		if err := c.unmarshal(raw[offset:]); err != nil {
110			return err
111		}
112
113		p.chunks = append(p.chunks, c)
114		chunkValuePadding := getPadding(c.valueLength())
115		offset += chunkHeaderSize + c.valueLength() + chunkValuePadding
116	}
117	theirChecksum := binary.LittleEndian.Uint32(raw[8:])
118	ourChecksum := generatePacketChecksum(raw)
119	if theirChecksum != ourChecksum {
120		return errors.Errorf("Checksum mismatch theirs: %d ours: %d", theirChecksum, ourChecksum)
121	}
122	return nil
123}
124
125func (p *packet) marshal() ([]byte, error) {
126	raw := make([]byte, packetHeaderSize)
127
128	// Populate static headers
129	// 8-12 is Checksum which will be populated when packet is complete
130	binary.BigEndian.PutUint16(raw[0:], p.sourcePort)
131	binary.BigEndian.PutUint16(raw[2:], p.destinationPort)
132	binary.BigEndian.PutUint32(raw[4:], p.verificationTag)
133
134	// Populate chunks
135	for _, c := range p.chunks {
136		chunkRaw, err := c.marshal()
137		if err != nil {
138			return nil, err
139		}
140		raw = append(raw, chunkRaw...)
141
142		paddingNeeded := getPadding(len(raw))
143		if paddingNeeded != 0 {
144			raw = append(raw, make([]byte, paddingNeeded)...)
145		}
146	}
147
148	// Checksum is already in BigEndian
149	// Using LittleEndian.PutUint32 stops it from being flipped
150	binary.LittleEndian.PutUint32(raw[8:], generatePacketChecksum(raw))
151	return raw, nil
152}
153
154func generatePacketChecksum(raw []byte) (sum uint32) {
155	// Fastest way to do a crc32 without allocating.
156	sum = crc32.Update(sum, castagnoliTable, raw[0:8])
157	sum = crc32.Update(sum, castagnoliTable, fourZeroes[:])
158	sum = crc32.Update(sum, castagnoliTable, raw[12:])
159	return sum
160}
161
162// String makes packet printable
163func (p *packet) String() string {
164	format := `Packet:
165	sourcePort: %d
166	destinationPort: %d
167	verificationTag: %d
168	`
169	res := fmt.Sprintf(format,
170		p.sourcePort,
171		p.destinationPort,
172		p.verificationTag,
173	)
174	for i, chunk := range p.chunks {
175		res += fmt.Sprintf("Chunk %d:\n %s", i, chunk)
176	}
177	return res
178}
179