1// Copyright 2012 Google, Inc. All rights reserved.
2//
3// Use of this source code is governed by a BSD-style license
4// that can be found in the LICENSE file in the root of the source
5// tree.
6
7package reassembly
8
9import (
10	"encoding/binary"
11	"fmt"
12
13	"github.com/google/gopacket"
14	"github.com/google/gopacket/layers"
15)
16
17/*
18 * Check TCP packet against options (window, MSS)
19 */
20
// tcpStreamOptions holds the TCP parameters advertised by one endpoint of a
// connection. mss and scale come from the endpoint's SYN and are -1 when that
// SYN did not carry the corresponding option. NewTCPOptionCheck starts scale
// at -1 (unknown) and mss at 0 (no MSS limit is enforced while mss <= 0).
// receiveWindow is the window most recently advertised by the endpoint, as
// recorded by TCPOptionCheck.Accept.
type tcpStreamOptions struct {
	mss, scale    int  // MSS and window-scale shift count; negative when absent
	receiveWindow uint // last advertised receive window (0 disables the window check)
}
26
27// TCPOptionCheck contains options for the two directions
28type TCPOptionCheck struct {
29	options [2]tcpStreamOptions
30}
31
32func (t *TCPOptionCheck) getOptions(dir TCPFlowDirection) *tcpStreamOptions {
33	if dir == TCPDirClientToServer {
34		return &t.options[0]
35	}
36	return &t.options[1]
37}
38
39// NewTCPOptionCheck creates default options
40func NewTCPOptionCheck() TCPOptionCheck {
41	return TCPOptionCheck{
42		options: [2]tcpStreamOptions{
43			tcpStreamOptions{
44				mss:           0,
45				scale:         -1,
46				receiveWindow: 0,
47			}, tcpStreamOptions{
48				mss:           0,
49				scale:         -1,
50				receiveWindow: 0,
51			},
52		},
53	}
54}
55
56// Accept checks whether the packet should be accepted by checking TCP options
57func (t *TCPOptionCheck) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir TCPFlowDirection, nextSeq Sequence, start *bool) error {
58	options := t.getOptions(dir)
59	if tcp.SYN {
60		mss := -1
61		scale := -1
62		for _, o := range tcp.Options {
63			// MSS
64			if o.OptionType == 2 {
65				if len(o.OptionData) != 2 {
66					return fmt.Errorf("MSS option data length expected 2, got %d", len(o.OptionData))
67				}
68				mss = int(binary.BigEndian.Uint16(o.OptionData[:2]))
69			}
70			// Window scaling
71			if o.OptionType == 3 {
72				if len(o.OptionData) != 1 {
73					return fmt.Errorf("Window scaling length expected: 1, got %d", len(o.OptionData))
74				}
75				scale = int(o.OptionData[0])
76			}
77		}
78		options.mss = mss
79		options.scale = scale
80	} else {
81		if nextSeq != invalidSequence {
82			revOptions := t.getOptions(dir.Reverse())
83			length := len(tcp.Payload)
84
85			// Check packet is in the correct window
86			diff := nextSeq.Difference(Sequence(tcp.Seq))
87			if diff == -1 && (length == 1 || length == 0) {
88				// This is probably a Keep-alive
89				// TODO: check byte is ok
90			} else if diff < 0 {
91				return fmt.Errorf("Re-emitted packet (diff:%d,seq:%d,rev-ack:%d)", diff,
92					tcp.Seq, nextSeq)
93			} else if revOptions.mss > 0 && length > revOptions.mss {
94				return fmt.Errorf("%d > mss (%d)", length, revOptions.mss)
95			} else if revOptions.receiveWindow != 0 && revOptions.scale < 0 && diff > int(revOptions.receiveWindow) {
96				return fmt.Errorf("%d > receiveWindow(%d)", diff, revOptions.receiveWindow)
97			}
98		}
99	}
100	// Compute receiveWindow
101	options.receiveWindow = uint(tcp.Window)
102	if options.scale > 0 {
103		options.receiveWindow = options.receiveWindow << (uint(options.scale))
104	}
105	return nil
106}
107
108// TCPSimpleFSM implements a very simple TCP state machine
109//
110// Usage:
111// When implementing a Stream interface and to avoid to consider packets that
112// would be rejected due to client/server's TCP stack, the  Accept() can call
113// TCPSimpleFSM.CheckState().
114//
115// Limitations:
116// - packet should be received in-order.
117// - no check on sequence number is performed
118// - no RST
119type TCPSimpleFSM struct {
120	dir     TCPFlowDirection
121	state   int
122	options TCPSimpleFSMOptions
123}
124
// TCPSimpleFSMOptions configures the behavior of a TCPSimpleFSM.
type TCPSimpleFSMOptions struct {
	// SupportMissingEstablishment allows following connections whose
	// handshake (SYN, SYN+ACK, ACK) was not captured. While the FSM is
	// Closed, the first packet that is not a bare SYN is used to guess the
	// connection state.
	SupportMissingEstablishment bool
}
129
// States of a TCPSimpleFSM, as stored in its state field.
const (
	TCPStateClosed      = iota // initial state, and the state after the final ACK
	TCPStateSynSent            // SYN seen, waiting for SYN+ACK
	TCPStateEstablished        // handshake complete, any packet accepted
	TCPStateCloseWait          // first FIN seen
	TCPStateLastAck            // FIN seen from both sides, waiting for the last ACK
	TCPStateReset              // RST seen, every later packet rejected
)
139
140// NewTCPSimpleFSM creates a new TCPSimpleFSM
141func NewTCPSimpleFSM(options TCPSimpleFSMOptions) *TCPSimpleFSM {
142	return &TCPSimpleFSM{
143		state:   TCPStateClosed,
144		options: options,
145	}
146}
147
148func (t *TCPSimpleFSM) String() string {
149	switch t.state {
150	case TCPStateClosed:
151		return "Closed"
152	case TCPStateSynSent:
153		return "SynSent"
154	case TCPStateEstablished:
155		return "Established"
156	case TCPStateCloseWait:
157		return "CloseWait"
158	case TCPStateLastAck:
159		return "LastAck"
160	case TCPStateReset:
161		return "Reset"
162	}
163	return "?"
164}
165
166// CheckState returns false if tcp is invalid wrt current state or update the state machine's state
167func (t *TCPSimpleFSM) CheckState(tcp *layers.TCP, dir TCPFlowDirection) bool {
168	if t.state == TCPStateClosed && t.options.SupportMissingEstablishment && !(tcp.SYN && !tcp.ACK) {
169		/* try to figure out state */
170		switch true {
171		case tcp.SYN && tcp.ACK:
172			t.state = TCPStateSynSent
173			t.dir = dir.Reverse()
174		case tcp.FIN && !tcp.ACK:
175			t.state = TCPStateEstablished
176		case tcp.FIN && tcp.ACK:
177			t.state = TCPStateCloseWait
178			t.dir = dir.Reverse()
179		default:
180			t.state = TCPStateEstablished
181		}
182	}
183
184	switch t.state {
185	/* openning connection */
186	case TCPStateClosed:
187		if tcp.SYN && !tcp.ACK {
188			t.dir = dir
189			t.state = TCPStateSynSent
190			return true
191		}
192	case TCPStateSynSent:
193		if tcp.RST {
194			t.state = TCPStateReset
195			return true
196		}
197
198		if tcp.SYN && tcp.ACK && dir == t.dir.Reverse() {
199			t.state = TCPStateEstablished
200			return true
201		}
202		if tcp.SYN && !tcp.ACK && dir == t.dir {
203			// re-transmission
204			return true
205		}
206	/* established */
207	case TCPStateEstablished:
208		if tcp.RST {
209			t.state = TCPStateReset
210			return true
211		}
212
213		if tcp.FIN {
214			t.state = TCPStateCloseWait
215			t.dir = dir
216			return true
217		}
218		// accept any packet
219		return true
220	/* closing connection */
221	case TCPStateCloseWait:
222		if tcp.RST {
223			t.state = TCPStateReset
224			return true
225		}
226
227		if tcp.FIN && tcp.ACK && dir == t.dir.Reverse() {
228			t.state = TCPStateLastAck
229			return true
230		}
231		if tcp.ACK {
232			return true
233		}
234	case TCPStateLastAck:
235		if tcp.RST {
236			t.state = TCPStateReset
237			return true
238		}
239
240		if tcp.ACK && t.dir == dir {
241			t.state = TCPStateClosed
242			return true
243		}
244	}
245	return false
246}
247