1// Copyright 2012 Google, Inc. All rights reserved. 2// 3// Use of this source code is governed by a BSD-style license 4// that can be found in the LICENSE file in the root of the source 5// tree. 6 7package tcpreader 8 9import ( 10 "bytes" 11 "fmt" 12 "github.com/google/gopacket" 13 "github.com/google/gopacket/layers" 14 "github.com/google/gopacket/tcpassembly" 15 "io" 16 "net" 17 "testing" 18) 19 20var netFlow gopacket.Flow 21 22func init() { 23 netFlow, _ = gopacket.FlowFromEndpoints( 24 layers.NewIPEndpoint(net.IP{1, 2, 3, 4}), 25 layers.NewIPEndpoint(net.IP{5, 6, 7, 8})) 26} 27 28type readReturn struct { 29 data []byte 30 err error 31} 32type readSequence struct { 33 in []layers.TCP 34 want []readReturn 35} 36type testReaderFactory struct { 37 lossErrors bool 38 readSize int 39 ReaderStream 40 output chan []byte 41} 42 43func (t *testReaderFactory) New(a, b gopacket.Flow) tcpassembly.Stream { 44 return &t.ReaderStream 45} 46 47func testReadSequence(t *testing.T, lossErrors bool, readSize int, seq readSequence) { 48 f := &testReaderFactory{ReaderStream: NewReaderStream()} 49 f.ReaderStream.LossErrors = lossErrors 50 p := tcpassembly.NewStreamPool(f) 51 a := tcpassembly.NewAssembler(p) 52 buf := make([]byte, readSize) 53 go func() { 54 for i, test := range seq.in { 55 fmt.Println("Assembling", i) 56 a.Assemble(netFlow, &test) 57 fmt.Println("Assembly done") 58 } 59 }() 60 for i, test := range seq.want { 61 fmt.Println("Waiting for read", i) 62 n, err := f.Read(buf[:]) 63 fmt.Println("Got read") 64 if n != len(test.data) { 65 t.Errorf("test %d want %d bytes, got %d bytes", i, len(test.data), n) 66 } else if err != test.err { 67 t.Errorf("test %d want err %v, got err %v", i, test.err, err) 68 } else if !bytes.Equal(buf[:n], test.data) { 69 t.Errorf("test %d\nwant: %v\n got: %v\n", i, test.data, buf[:n]) 70 } 71 } 72 fmt.Println("All done reads") 73} 74 75func TestRead(t *testing.T) { 76 testReadSequence(t, false, 10, readSequence{ 77 in: []layers.TCP{ 78 { 79 SYN: true, 80 SrcPort: 1, 81 DstPort: 2, 82 Seq: 1000, 83 BaseLayer: layers.BaseLayer{Payload: []byte{1, 2, 3}}, 84 }, 85 { 86 FIN: true, 87 SrcPort: 1, 88 DstPort: 2, 89 Seq: 1004, 90 }, 91 }, 92 want: []readReturn{ 93 {data: []byte{1, 2, 3}}, 94 {err: io.EOF}, 95 }, 96 }) 97} 98 99func TestReadSmallChunks(t *testing.T) { 100 testReadSequence(t, false, 2, readSequence{ 101 in: []layers.TCP{ 102 { 103 SYN: true, 104 SrcPort: 1, 105 DstPort: 2, 106 Seq: 1000, 107 BaseLayer: layers.BaseLayer{Payload: []byte{1, 2, 3}}, 108 }, 109 { 110 FIN: true, 111 SrcPort: 1, 112 DstPort: 2, 113 Seq: 1004, 114 }, 115 }, 116 want: []readReturn{ 117 {data: []byte{1, 2}}, 118 {data: []byte{3}}, 119 {err: io.EOF}, 120 }, 121 }) 122} 123 124func ExampleDiscardBytesToEOF() { 125 b := bytes.NewBuffer([]byte{1, 2, 3, 4, 5}) 126 fmt.Println(DiscardBytesToEOF(b)) 127 // Output: 128 // 5 129} 130