1 // Copyright 2012 Google, Inc. All rights reserved.
3 // Use of this source code is governed by a BSD-style license
4 // that can be found in the LICENSE file in the root of the source
12 "github.com/google/gopacket"
13 "github.com/google/gopacket/layers"
14 "github.com/google/gopacket/tcpassembly"
// netFlow is the network-layer flow passed to Assemble in testReadSequence.
// It pairs the IP endpoints 1.2.3.4 and 5.6.7.8 and is populated below.
20 var netFlow gopacket.Flow
// The error from FlowFromEndpoints is deliberately discarded. The function
// only fails when the two endpoints have different endpoint types, and both
// endpoints here are built with layers.NewIPEndpoint.
23 netFlow, _ = gopacket.FlowFromEndpoints(
24 layers.NewIPEndpoint(net.IP{1, 2, 3, 4}),
25 layers.NewIPEndpoint(net.IP{5, 6, 7, 8}))
// readReturn describes one expected Read outcome.
// NOTE(review): this is presumably the element type of readSequence.want.
// testReadSequence reads .data (expected bytes) and .err (expected error)
// from those elements — confirm against the field list.
28 type readReturn struct {
// readSequence is one test scenario for testReadSequence:
//   - each element of in is passed by address to Assembler.Assemble;
//   - each element of want is the expected result of one Read call
//     (expected byte slice in .data, expected error in .err).
32 type readSequence struct {
// testReaderFactory is the stream factory handed to tcpassembly.NewStreamPool
// in these tests. It holds a single ReaderStream, and its New method returns
// that same stream for every connection.
// NOTE(review): f.Read is called directly on the factory in testReadSequence,
// which suggests ReaderStream is an embedded field — confirm.
36 type testReaderFactory struct {
// New satisfies the tcpassembly stream-factory contract. It ignores the two
// flows it is given and always returns a pointer to the factory's one
// ReaderStream. As a result, every connection the assembler creates feeds its
// reassembled data into that same stream.
43 func (t *testReaderFactory) New(a, b gopacket.Flow) tcpassembly.Stream {
44 return &t.ReaderStream
// testReadSequence drives one ReaderStream scenario:
//   - It wires a fresh ReaderStream, with LossErrors set from lossErrors,
//     into a tcpassembly Assembler.
//   - It assembles every packet in seq.in on netFlow.
//   - It performs one Read per entry in seq.want, each into a buffer of
//     readSize bytes.
//   - It reports a test error whenever the returned byte count, error, or
//     bytes differ from that entry.
47 func testReadSequence(t *testing.T, lossErrors bool, readSize int, seq readSequence) {
// The pool obtains streams from f.New, which always returns &f.ReaderStream.
// Everything assembled below is therefore delivered to this one stream.
48 f := &testReaderFactory{ReaderStream: NewReaderStream()}
49 f.ReaderStream.LossErrors = lossErrors
50 p := tcpassembly.NewStreamPool(f)
51 a := tcpassembly.NewAssembler(p)
// One buffer is reused for every Read; its length caps how many bytes a
// single Read can return.
52 buf := make([]byte, readSize)
54 for i, test := range seq.in {
// NOTE(review): fmt.Println writes progress to stdout on every run; t.Logf
// would only show it under -v or on failure.
55 fmt.Println("Assembling", i)
// NOTE(review): &test is the address of the range variable. Under pre-Go 1.22
// loop semantics that variable is reused each iteration, so this relies on
// Assemble not retaining the pointer after it returns — confirm.
56 a.Assemble(netFlow, &test)
57 fmt.Println("Assembly done")
60 for i, test := range seq.want {
61 fmt.Println("Waiting for read", i)
62 n, err := f.Read(buf[:])
63 fmt.Println("Got read")
// The checks are chained with else-if, so only the first mismatch is reported
// for each expected read, in this order: length, then error, then contents.
// Errors are compared with !=, so test.err must be the identical error value
// that Read returns.
64 if n != len(test.data) {
65 t.Errorf("test %d want %d bytes, got %d bytes", i, len(test.data), n)
66 } else if err != test.err {
67 t.Errorf("test %d want err %v, got err %v", i, test.err, err)
68 } else if !bytes.Equal(buf[:n], test.data) {
69 t.Errorf("test %d\nwant: %v\n got: %v\n", i, test.data, buf[:n])
72 fmt.Println("All done reads")
// TestRead runs testReadSequence with loss errors disabled and a 10-byte read
// buffer. The visible fragments feed a packet whose payload is {1, 2, 3} and
// expect a Read that returns exactly {1, 2, 3}.
75 func TestRead(t *testing.T) {
76 testReadSequence(t, false, 10, readSequence{
83 BaseLayer: layers.BaseLayer{Payload: []byte{1, 2, 3}},
93 {data: []byte{1, 2, 3}},
// TestReadSmallChunks runs testReadSequence with loss errors disabled and a
// read buffer of only 2 bytes. A 3-byte payload {1, 2, 3} therefore cannot be
// returned by a single Read. The visible expectation is a first Read of
// {1, 2}, which exercises splitting a payload across Reads.
99 func TestReadSmallChunks(t *testing.T) {
100 testReadSequence(t, false, 2, readSequence{
107 BaseLayer: layers.BaseLayer{Payload: []byte{1, 2, 3}},
117 {data: []byte{1, 2}},
// ExampleDiscardBytesToEOF drains a 5-byte buffer with DiscardBytesToEOF and
// prints the function's return value.
// NOTE(review): the return value is presumably the number of bytes discarded;
// confirm against DiscardBytesToEOF. Also confirm the example ends with an
// "// Output:" comment, otherwise go test compiles but does not run it.
124 func ExampleDiscardBytesToEOF() {
125 b := bytes.NewBuffer([]byte{1, 2, 3, 4, 5})
126 fmt.Println(DiscardBytesToEOF(b))