1// Copyright 2012 Google, Inc. All rights reserved.
2//
3// Use of this source code is governed by a BSD-style license
4// that can be found in the LICENSE file in the root of the source
5// tree.
6
7package tcpreader
8
9import (
10	"bytes"
11	"fmt"
12	"github.com/google/gopacket"
13	"github.com/google/gopacket/layers"
14	"github.com/google/gopacket/tcpassembly"
15	"io"
16	"net"
17	"testing"
18)
19
20var netFlow gopacket.Flow
21
22func init() {
23	netFlow, _ = gopacket.FlowFromEndpoints(
24		layers.NewIPEndpoint(net.IP{1, 2, 3, 4}),
25		layers.NewIPEndpoint(net.IP{5, 6, 7, 8}))
26}
27
// readReturn is the expected outcome of a single Read call on the stream
// under test.
type readReturn struct {
	data []byte // bytes the Read is expected to deliver
	err  error  // error the Read is expected to return (nil when unset)
}
// readSequence pairs the TCP packets fed to the assembler with the Read
// results expected from the reassembled stream.
type readSequence struct {
	in   []layers.TCP // packets handed to the assembler, in slice order
	want []readReturn // expected result of each successive Read call
}
// testReaderFactory is the stream factory the tests pass to
// tcpassembly.NewStreamPool. It embeds a single ReaderStream, and its New
// method returns a pointer to that stream for every flow.
type testReaderFactory struct {
	lossErrors bool // NOTE(review): not referenced anywhere in the visible code
	readSize   int  // NOTE(review): not referenced anywhere in the visible code
	ReaderStream
	output chan []byte // NOTE(review): not referenced anywhere in the visible code
}
42
43func (t *testReaderFactory) New(a, b gopacket.Flow) tcpassembly.Stream {
44	return &t.ReaderStream
45}
46
47func testReadSequence(t *testing.T, lossErrors bool, readSize int, seq readSequence) {
48	f := &testReaderFactory{ReaderStream: NewReaderStream()}
49	f.ReaderStream.LossErrors = lossErrors
50	p := tcpassembly.NewStreamPool(f)
51	a := tcpassembly.NewAssembler(p)
52	buf := make([]byte, readSize)
53	go func() {
54		for i, test := range seq.in {
55			fmt.Println("Assembling", i)
56			a.Assemble(netFlow, &test)
57			fmt.Println("Assembly done")
58		}
59	}()
60	for i, test := range seq.want {
61		fmt.Println("Waiting for read", i)
62		n, err := f.Read(buf[:])
63		fmt.Println("Got read")
64		if n != len(test.data) {
65			t.Errorf("test %d want %d bytes, got %d bytes", i, len(test.data), n)
66		} else if err != test.err {
67			t.Errorf("test %d want err %v, got err %v", i, test.err, err)
68		} else if !bytes.Equal(buf[:n], test.data) {
69			t.Errorf("test %d\nwant: %v\n got: %v\n", i, test.data, buf[:n])
70		}
71	}
72	fmt.Println("All done reads")
73}
74
75func TestRead(t *testing.T) {
76	testReadSequence(t, false, 10, readSequence{
77		in: []layers.TCP{
78			{
79				SYN:       true,
80				SrcPort:   1,
81				DstPort:   2,
82				Seq:       1000,
83				BaseLayer: layers.BaseLayer{Payload: []byte{1, 2, 3}},
84			},
85			{
86				FIN:     true,
87				SrcPort: 1,
88				DstPort: 2,
89				Seq:     1004,
90			},
91		},
92		want: []readReturn{
93			{data: []byte{1, 2, 3}},
94			{err: io.EOF},
95		},
96	})
97}
98
99func TestReadSmallChunks(t *testing.T) {
100	testReadSequence(t, false, 2, readSequence{
101		in: []layers.TCP{
102			{
103				SYN:       true,
104				SrcPort:   1,
105				DstPort:   2,
106				Seq:       1000,
107				BaseLayer: layers.BaseLayer{Payload: []byte{1, 2, 3}},
108			},
109			{
110				FIN:     true,
111				SrcPort: 1,
112				DstPort: 2,
113				Seq:     1004,
114			},
115		},
116		want: []readReturn{
117			{data: []byte{1, 2}},
118			{data: []byte{3}},
119			{err: io.EOF},
120		},
121	})
122}
123
124func ExampleDiscardBytesToEOF() {
125	b := bytes.NewBuffer([]byte{1, 2, 3, 4, 5})
126	fmt.Println(DiscardBytesToEOF(b))
127	// Output:
128	// 5
129}
130