added support for string type
[govpp.git] / vendor / github.com / google / gopacket / reassembly / tcpassembly.go
1 // Copyright 2012 Google, Inc. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style license
4 // that can be found in the LICENSE file in the root of the source
5 // tree.
6
7 // Package reassembly provides TCP stream re-assembly.
8 //
9 // The reassembly package implements uni-directional TCP reassembly, for use in
10 // packet-sniffing applications.  The caller reads packets off the wire, then
11 // presents them to an Assembler in the form of gopacket layers.TCP packets
12 // (github.com/google/gopacket, github.com/google/gopacket/layers).
13 //
14 // The Assembler uses a user-supplied
15 // StreamFactory to create a user-defined Stream interface, then passes packet
16 // data in stream order to that object.  A concurrency-safe StreamPool keeps
17 // track of all current Streams being reassembled, so multiple Assemblers may
18 // run at once to assemble packets while taking advantage of multiple cores.
19 //
20 // TODO: Add simplest example
21 package reassembly
22
23 import (
24         "encoding/hex"
25         "flag"
26         "fmt"
27         "log"
28         "sync"
29         "time"
30
31         "github.com/google/gopacket"
32         "github.com/google/gopacket/layers"
33 )
34
35 // TODO:
36 // - push to Stream on Ack
37 // - implement chunked (cheap) reads and Reader() interface
38 // - better organize file: split files: 'mem', 'misc' (seq + flow)
39
40 var defaultDebug = false
41
42 var debugLog = flag.Bool("assembly_debug_log", defaultDebug, "If true, the github.com/google/gopacket/reassembly library will log verbose debugging information (at least one line per packet)")
43
44 const invalidSequence = -1
45 const uint32Max = 0xFFFFFFFF
46
47 // Sequence is a TCP sequence number.  It provides a few convenience functions
48 // for handling TCP wrap-around.  The sequence should always be in the range
49 // [0,0xFFFFFFFF]... its other bits are simply used in wrap-around calculations
50 // and should never be set.
51 type Sequence int64
52
53 // Difference defines an ordering for comparing TCP sequences that's safe for
54 // roll-overs.  It returns:
55 //    > 0 : if t comes after s
56 //    < 0 : if t comes before s
57 //      0 : if t == s
58 // The number returned is the sequence difference, so 4.Difference(8) will
59 // return 4.
60 //
61 // It handles rollovers by considering any sequence in the first quarter of the
62 // uint32 space to be after any sequence in the last quarter of that space, thus
63 // wrapping the uint32 space.
64 func (s Sequence) Difference(t Sequence) int {
65         if s > uint32Max-uint32Max/4 && t < uint32Max/4 {
66                 t += uint32Max
67         } else if t > uint32Max-uint32Max/4 && s < uint32Max/4 {
68                 s += uint32Max
69         }
70         return int(t - s)
71 }
72
73 // Add adds an integer to a sequence and returns the resulting sequence.
74 func (s Sequence) Add(t int) Sequence {
75         return (s + Sequence(t)) & uint32Max
76 }
77
78 // TCPAssemblyStats provides some figures for a ScatterGather
79 type TCPAssemblyStats struct {
80         // For this ScatterGather
81         Chunks  int
82         Packets int
83         // For the half connection, since last call to ReassembledSG()
84         QueuedBytes    int
85         QueuedPackets  int
86         OverlapBytes   int
87         OverlapPackets int
88 }
89
90 // ScatterGather is used to pass reassembled data and metadata of reassembled
91 // packets to a Stream via ReassembledSG
92 type ScatterGather interface {
93         // Returns the length of available bytes and saved bytes
94         Lengths() (int, int)
95         // Returns the bytes up to length (shall be <= available bytes)
96         Fetch(length int) []byte
97         // Tell to keep from offset
98         KeepFrom(offset int)
99         // Return CaptureInfo of packet corresponding to given offset
100         CaptureInfo(offset int) gopacket.CaptureInfo
101         // Return some info about the reassembled chunks
102         Info() (direction TCPFlowDirection, start bool, end bool, skip int)
103         // Return some stats regarding the state of the stream
104         Stats() TCPAssemblyStats
105 }
106
107 // byteContainer is either a page or a livePacket
108 type byteContainer interface {
109         getBytes() []byte
110         length() int
111         convertToPages(*pageCache, int, AssemblerContext) (*page, *page, int)
112         captureInfo() gopacket.CaptureInfo
113         assemblerContext() AssemblerContext
114         release(*pageCache) int
115         isStart() bool
116         isEnd() bool
117         getSeq() Sequence
118         isPacket() bool
119 }
120
121 // Implements a ScatterGather
122 type reassemblyObject struct {
123         all       []byteContainer
124         Skip      int
125         Direction TCPFlowDirection
126         saved     int
127         toKeep    int
128         // stats
129         queuedBytes    int
130         queuedPackets  int
131         overlapBytes   int
132         overlapPackets int
133 }
134
135 func (rl *reassemblyObject) Lengths() (int, int) {
136         l := 0
137         for _, r := range rl.all {
138                 l += r.length()
139         }
140         return l, rl.saved
141 }
142
143 func (rl *reassemblyObject) Fetch(l int) []byte {
144         if l <= rl.all[0].length() {
145                 return rl.all[0].getBytes()[:l]
146         }
147         bytes := make([]byte, 0, l)
148         for _, bc := range rl.all {
149                 bytes = append(bytes, bc.getBytes()...)
150         }
151         return bytes[:l]
152 }
153
154 func (rl *reassemblyObject) KeepFrom(offset int) {
155         rl.toKeep = offset
156 }
157
158 func (rl *reassemblyObject) CaptureInfo(offset int) gopacket.CaptureInfo {
159         current := 0
160         for _, r := range rl.all {
161                 if current >= offset {
162                         return r.captureInfo()
163                 }
164                 current += r.length()
165         }
166         // Invalid offset
167         return gopacket.CaptureInfo{}
168 }
169
170 func (rl *reassemblyObject) Info() (TCPFlowDirection, bool, bool, int) {
171         return rl.Direction, rl.all[0].isStart(), rl.all[len(rl.all)-1].isEnd(), rl.Skip
172 }
173
174 func (rl *reassemblyObject) Stats() TCPAssemblyStats {
175         packets := int(0)
176         for _, r := range rl.all {
177                 if r.isPacket() {
178                         packets++
179                 }
180         }
181         return TCPAssemblyStats{
182                 Chunks:         len(rl.all),
183                 Packets:        packets,
184                 QueuedBytes:    rl.queuedBytes,
185                 QueuedPackets:  rl.queuedPackets,
186                 OverlapBytes:   rl.overlapBytes,
187                 OverlapPackets: rl.overlapPackets,
188         }
189 }
190
191 const pageBytes = 1900
192
193 // TCPFlowDirection distinguish the two half-connections directions.
194 //
195 // TCPDirClientToServer is assigned to half-connection for the first received
196 // packet, hence might be wrong if packets are not received in order.
197 // It's up to the caller (e.g. in Accept()) to decide if the direction should
198 // be interpretted differently.
199 type TCPFlowDirection bool
200
201 // Value are not really useful
202 const (
203         TCPDirClientToServer TCPFlowDirection = false
204         TCPDirServerToClient TCPFlowDirection = true
205 )
206
207 func (dir TCPFlowDirection) String() string {
208         switch dir {
209         case TCPDirClientToServer:
210                 return "client->server"
211         case TCPDirServerToClient:
212                 return "server->client"
213         }
214         return ""
215 }
216
217 // Reverse returns the reversed direction
218 func (dir TCPFlowDirection) Reverse() TCPFlowDirection {
219         return !dir
220 }
221
222 /* page: implements a byteContainer */
223
224 // page is used to store TCP data we're not ready for yet (out-of-order
225 // packets).  Unused pages are stored in and returned from a pageCache, which
226 // avoids memory allocation.  Used pages are stored in a doubly-linked list in
227 // a connection.
228 type page struct {
229         bytes      []byte
230         seq        Sequence
231         prev, next *page
232         buf        [pageBytes]byte
233         ac         AssemblerContext // only set for the first page of a packet
234         seen       time.Time
235         start, end bool
236 }
237
238 func (p *page) getBytes() []byte {
239         return p.bytes
240 }
241 func (p *page) captureInfo() gopacket.CaptureInfo {
242         return p.ac.GetCaptureInfo()
243 }
244 func (p *page) assemblerContext() AssemblerContext {
245         return p.ac
246 }
247 func (p *page) convertToPages(pc *pageCache, skip int, ac AssemblerContext) (*page, *page, int) {
248         if skip != 0 {
249                 p.bytes = p.bytes[skip:]
250                 p.seq = p.seq.Add(skip)
251         }
252         p.prev, p.next = nil, nil
253         return p, p, 1
254 }
255 func (p *page) length() int {
256         return len(p.bytes)
257 }
258 func (p *page) release(pc *pageCache) int {
259         pc.replace(p)
260         return 1
261 }
262 func (p *page) isStart() bool {
263         return p.start
264 }
265 func (p *page) isEnd() bool {
266         return p.end
267 }
268 func (p *page) getSeq() Sequence {
269         return p.seq
270 }
271 func (p *page) isPacket() bool {
272         return p.ac != nil
273 }
274 func (p *page) String() string {
275         return fmt.Sprintf("page@%p{seq: %v, bytes:%d, -> nextSeq:%v} (prev:%p, next:%p)", p, p.seq, len(p.bytes), p.seq+Sequence(len(p.bytes)), p.prev, p.next)
276 }
277
278 /* livePacket: implements a byteContainer */
279 type livePacket struct {
280         bytes []byte
281         start bool
282         end   bool
283         ci    gopacket.CaptureInfo
284         ac    AssemblerContext
285         seq   Sequence
286 }
287
288 func (lp *livePacket) getBytes() []byte {
289         return lp.bytes
290 }
291 func (lp *livePacket) captureInfo() gopacket.CaptureInfo {
292         return lp.ci
293 }
294 func (lp *livePacket) assemblerContext() AssemblerContext {
295         return lp.ac
296 }
297 func (lp *livePacket) length() int {
298         return len(lp.bytes)
299 }
300 func (lp *livePacket) isStart() bool {
301         return lp.start
302 }
303 func (lp *livePacket) isEnd() bool {
304         return lp.end
305 }
306 func (lp *livePacket) getSeq() Sequence {
307         return lp.seq
308 }
309 func (lp *livePacket) isPacket() bool {
310         return true
311 }
312
313 // Creates a page (or set of pages) from a TCP packet: returns the first and last
314 // page in its doubly-linked list of new pages.
315 func (lp *livePacket) convertToPages(pc *pageCache, skip int, ac AssemblerContext) (*page, *page, int) {
316         ts := lp.ci.Timestamp
317         first := pc.next(ts)
318         current := first
319         current.prev = nil
320         first.ac = ac
321         numPages := 1
322         seq, bytes := lp.seq.Add(skip), lp.bytes[skip:]
323         for {
324                 length := min(len(bytes), pageBytes)
325                 current.bytes = current.buf[:length]
326                 copy(current.bytes, bytes)
327                 current.seq = seq
328                 bytes = bytes[length:]
329                 if len(bytes) == 0 {
330                         current.end = lp.isEnd()
331                         current.next = nil
332                         break
333                 }
334                 seq = seq.Add(length)
335                 current.next = pc.next(ts)
336                 current.next.prev = current
337                 current = current.next
338                 current.ac = nil
339                 numPages++
340         }
341         return first, current, numPages
342 }
343 func (lp *livePacket) estimateNumberOfPages() int {
344         return (len(lp.bytes) + pageBytes + 1) / pageBytes
345 }
346
347 func (lp *livePacket) release(*pageCache) int {
348         return 0
349 }
350
351 // Stream is implemented by the caller to handle incoming reassembled
352 // TCP data.  Callers create a StreamFactory, then StreamPool uses
353 // it to create a new Stream for every TCP stream.
354 //
355 // assembly will, in order:
356 //    1) Create the stream via StreamFactory.New
357 //    2) Call ReassembledSG 0 or more times, passing in reassembled TCP data in order
358 //    3) Call ReassemblyComplete one time, after which the stream is dereferenced by assembly.
359 type Stream interface {
360         // Tell whether the TCP packet should be accepted, start could be modified to force a start even if no SYN have been seen
361         Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir TCPFlowDirection, ackSeq Sequence, start *bool, ac AssemblerContext) bool
362
363         // ReassembledSG is called zero or more times.
364         // ScatterGather is reused after each Reassembled call,
365         // so it's important to copy anything you need out of it,
366         // especially bytes (or use KeepFrom())
367         ReassembledSG(sg ScatterGather, ac AssemblerContext)
368
369         // ReassemblyComplete is called when assembly decides there is
370         // no more data for this Stream, either because a FIN or RST packet
371         // was seen, or because the stream has timed out without any new
372         // packet data (due to a call to FlushCloseOlderThan).
373         // It should return true if the connection should be removed from the pool
374         // It can return false if it want to see subsequent packets with Accept(), e.g. to
375         // see FIN-ACK, for deeper state-machine analysis.
376         ReassemblyComplete(ac AssemblerContext) bool
377 }
378
379 // StreamFactory is used by assembly to create a new stream for each
380 // new TCP session.
381 type StreamFactory interface {
382         // New should return a new stream for the given TCP key.
383         New(netFlow, tcpFlow gopacket.Flow, tcp *layers.TCP, ac AssemblerContext) Stream
384 }
385
386 type key [2]gopacket.Flow
387
388 func (k *key) String() string {
389         return fmt.Sprintf("%s:%s", k[0], k[1])
390 }
391
392 func (k *key) Reverse() key {
393         return key{
394                 k[0].Reverse(),
395                 k[1].Reverse(),
396         }
397 }
398
399 const assemblerReturnValueInitialSize = 16
400
401 /* one-way connection, i.e. halfconnection */
402 type halfconnection struct {
403         dir               TCPFlowDirection
404         pages             int      // Number of pages used (both in first/last and saved)
405         saved             *page    // Doubly-linked list of in-order pages (seq < nextSeq) already given to Stream who told us to keep
406         first, last       *page    // Doubly-linked list of out-of-order pages (seq > nextSeq)
407         nextSeq           Sequence // sequence number of in-order received bytes
408         ackSeq            Sequence
409         created, lastSeen time.Time
410         stream            Stream
411         closed            bool
412         // for stats
413         queuedBytes    int
414         queuedPackets  int
415         overlapBytes   int
416         overlapPackets int
417 }
418
419 func (half *halfconnection) String() string {
420         closed := ""
421         if half.closed {
422                 closed = "closed "
423         }
424         return fmt.Sprintf("%screated:%v, last:%v", closed, half.created, half.lastSeen)
425 }
426
427 // Dump returns a string (crypticly) describing the halfconnction
428 func (half *halfconnection) Dump() string {
429         s := fmt.Sprintf("pages: %d\n"+
430                 "nextSeq: %d\n"+
431                 "ackSeq: %d\n"+
432                 "Seen :  %s\n"+
433                 "dir:    %s\n", half.pages, half.nextSeq, half.ackSeq, half.lastSeen, half.dir)
434         nb := 0
435         for p := half.first; p != nil; p = p.next {
436                 s += fmt.Sprintf("      Page[%d] %s len: %d\n", nb, p, len(p.bytes))
437                 nb++
438         }
439         return s
440 }
441
442 /* Bi-directionnal connection */
443
444 type connection struct {
445         key      key // client->server
446         c2s, s2c halfconnection
447         mu       sync.Mutex
448 }
449
450 func (c *connection) reset(k key, s Stream, ts time.Time) {
451         c.key = k
452         base := halfconnection{
453                 nextSeq:  invalidSequence,
454                 ackSeq:   invalidSequence,
455                 created:  ts,
456                 lastSeen: ts,
457                 stream:   s,
458         }
459         c.c2s, c.s2c = base, base
460         c.c2s.dir, c.s2c.dir = TCPDirClientToServer, TCPDirServerToClient
461 }
462
463 func (c *connection) String() string {
464         return fmt.Sprintf("c2s: %s, s2c: %s", &c.c2s, &c.s2c)
465 }
466
467 /*
468  * Assembler
469  */
470
471 // DefaultAssemblerOptions provides default options for an assembler.
472 // These options are used by default when calling NewAssembler, so if
473 // modified before a NewAssembler call they'll affect the resulting Assembler.
474 //
475 // Note that the default options can result in ever-increasing memory usage
476 // unless one of the Flush* methods is called on a regular basis.
477 var DefaultAssemblerOptions = AssemblerOptions{
478         MaxBufferedPagesPerConnection: 0, // unlimited
479         MaxBufferedPagesTotal:         0, // unlimited
480 }
481
482 // AssemblerOptions controls the behavior of each assembler.  Modify the
483 // options of each assembler you create to change their behavior.
484 type AssemblerOptions struct {
485         // MaxBufferedPagesTotal is an upper limit on the total number of pages to
486         // buffer while waiting for out-of-order packets.  Once this limit is
487         // reached, the assembler will degrade to flushing every connection it
488         // gets a packet for.  If <= 0, this is ignored.
489         MaxBufferedPagesTotal int
490         // MaxBufferedPagesPerConnection is an upper limit on the number of pages
491         // buffered for a single connection.  Should this limit be reached for a
492         // particular connection, the smallest sequence number will be flushed, along
493         // with any contiguous data.  If <= 0, this is ignored.
494         MaxBufferedPagesPerConnection int
495 }
496
497 // Assembler handles reassembling TCP streams.  It is not safe for
498 // concurrency... after passing a packet in via the Assemble call, the caller
499 // must wait for that call to return before calling Assemble again.  Callers can
500 // get around this by creating multiple assemblers that share a StreamPool.  In
501 // that case, each individual stream will still be handled serially (each stream
502 // has an individual mutex associated with it), however multiple assemblers can
503 // assemble different connections concurrently.
504 //
505 // The Assembler provides (hopefully) fast TCP stream re-assembly for sniffing
506 // applications written in Go.  The Assembler uses the following methods to be
507 // as fast as possible, to keep packet processing speedy:
508 //
509 // Avoids Lock Contention
510 //
511 // Assemblers locks connections, but each connection has an individual lock, and
512 // rarely will two Assemblers be looking at the same connection.  Assemblers
513 // lock the StreamPool when looking up connections, but they use Reader
514 // locks initially, and only force a write lock if they need to create a new
515 // connection or close one down.  These happen much less frequently than
516 // individual packet handling.
517 //
518 // Each assembler runs in its own goroutine, and the only state shared between
519 // goroutines is through the StreamPool.  Thus all internal Assembler state
520 // can be handled without any locking.
521 //
522 // NOTE:  If you can guarantee that packets going to a set of Assemblers will
523 // contain information on different connections per Assembler (for example,
524 // they're already hashed by PF_RING hashing or some other hashing mechanism),
525 // then we recommend you use a seperate StreamPool per Assembler, thus
526 // avoiding all lock contention.  Only when different Assemblers could receive
527 // packets for the same Stream should a StreamPool be shared between them.
528 //
529 // Avoids Memory Copying
530 //
531 // In the common case, handling of a single TCP packet should result in zero
532 // memory allocations.  The Assembler will look up the connection, figure out
533 // that the packet has arrived in order, and immediately pass that packet on to
534 // the appropriate connection's handling code.  Only if a packet arrives out of
535 // order is its contents copied and stored in memory for later.
536 //
537 // Avoids Memory Allocation
538 //
539 // Assemblers try very hard to not use memory allocation unless absolutely
540 // necessary.  Packet data for sequential packets is passed directly to streams
541 // with no copying or allocation.  Packet data for out-of-order packets is
542 // copied into reusable pages, and new pages are only allocated rarely when the
543 // page cache runs out.  Page caches are Assembler-specific, thus not used
544 // concurrently and requiring no locking.
545 //
546 // Internal representations for connection objects are also reused over time.
547 // Because of this, the most common memory allocation done by the Assembler is
548 // generally what's done by the caller in StreamFactory.New.  If no allocation
549 // is done there, then very little allocation is done ever, mostly to handle
550 // large increases in bandwidth or numbers of connections.
551 //
552 // TODO:  The page caches used by an Assembler will grow to the size necessary
553 // to handle a workload, and currently will never shrink.  This means that
554 // traffic spikes can result in large memory usage which isn't garbage
555 // collected when typical traffic levels return.
556 type Assembler struct {
557         AssemblerOptions
558         ret      []byteContainer
559         pc       *pageCache
560         connPool *StreamPool
561         cacheLP  livePacket
562         cacheSG  reassemblyObject
563         start    bool
564 }
565
566 // NewAssembler creates a new assembler.  Pass in the StreamPool
567 // to use, may be shared across assemblers.
568 //
569 // This sets some sane defaults for the assembler options,
570 // see DefaultAssemblerOptions for details.
571 func NewAssembler(pool *StreamPool) *Assembler {
572         pool.mu.Lock()
573         pool.users++
574         pool.mu.Unlock()
575         return &Assembler{
576                 ret:              make([]byteContainer, assemblerReturnValueInitialSize),
577                 pc:               newPageCache(),
578                 connPool:         pool,
579                 AssemblerOptions: DefaultAssemblerOptions,
580         }
581 }
582
583 // Dump returns a short string describing the page usage of the Assembler
584 func (a *Assembler) Dump() string {
585         s := ""
586         s += fmt.Sprintf("pageCache: used: %d, size: %d, free: %d", a.pc.used, a.pc.size, len(a.pc.free))
587         return s
588 }
589
590 // AssemblerContext provides method to get metadata
591 type AssemblerContext interface {
592         GetCaptureInfo() gopacket.CaptureInfo
593 }
594
595 // Implements AssemblerContext for Assemble()
596 type assemblerSimpleContext gopacket.CaptureInfo
597
598 func (asc *assemblerSimpleContext) GetCaptureInfo() gopacket.CaptureInfo {
599         return gopacket.CaptureInfo(*asc)
600 }
601
602 // Assemble calls AssembleWithContext with the current timestamp, useful for
603 // packets being read directly off the wire.
604 func (a *Assembler) Assemble(netFlow gopacket.Flow, t *layers.TCP) {
605         ctx := assemblerSimpleContext(gopacket.CaptureInfo{Timestamp: time.Now()})
606         a.AssembleWithContext(netFlow, t, &ctx)
607 }
608
609 type assemblerAction struct {
610         nextSeq Sequence
611         queue   bool
612 }
613
614 // AssembleWithContext reassembles the given TCP packet into its appropriate
615 // stream.
616 //
617 // The timestamp passed in must be the timestamp the packet was seen.
618 // For packets read off the wire, time.Now() should be fine.  For packets read
619 // from PCAP files, CaptureInfo.Timestamp should be passed in.  This timestamp
620 // will affect which streams are flushed by a call to FlushCloseOlderThan.
621 //
622 // Each AssembleWithContext call results in, in order:
623 //
624 //    zero or one call to StreamFactory.New, creating a stream
625 //    zero or one call to ReassembledSG on a single stream
626 //    zero or one call to ReassemblyComplete on the same stream
627 func (a *Assembler) AssembleWithContext(netFlow gopacket.Flow, t *layers.TCP, ac AssemblerContext) {
628         var conn *connection
629         var half *halfconnection
630         var rev *halfconnection
631
632         a.ret = a.ret[:0]
633         key := key{netFlow, t.TransportFlow()}
634         ci := ac.GetCaptureInfo()
635         timestamp := ci.Timestamp
636
637         conn, half, rev = a.connPool.getConnection(key, false, timestamp, t, ac)
638         if conn == nil {
639                 if *debugLog {
640                         log.Printf("%v got empty packet on otherwise empty connection", key)
641                 }
642                 return
643         }
644         conn.mu.Lock()
645         defer conn.mu.Unlock()
646         if half.lastSeen.Before(timestamp) {
647                 half.lastSeen = timestamp
648         }
649         a.start = half.nextSeq == invalidSequence && t.SYN
650         if !half.stream.Accept(t, ci, half.dir, rev.ackSeq, &a.start, ac) {
651                 if *debugLog {
652                         log.Printf("Ignoring packet")
653                 }
654                 return
655         }
656         if half.closed {
657                 // this way is closed
658                 return
659         }
660
661         seq, ack, bytes := Sequence(t.Seq), Sequence(t.Ack), t.Payload
662         if t.ACK {
663                 half.ackSeq = ack
664         }
665         // TODO: push when Ack is seen ??
666         action := assemblerAction{
667                 nextSeq: Sequence(invalidSequence),
668                 queue:   true,
669         }
670         a.dump("AssembleWithContext()", half)
671         if half.nextSeq == invalidSequence {
672                 if t.SYN {
673                         if *debugLog {
674                                 log.Printf("%v saw first SYN packet, returning immediately, seq=%v", key, seq)
675                         }
676                         seq = seq.Add(1)
677                         half.nextSeq = seq
678                         action.queue = false
679                 } else if a.start {
680                         if *debugLog {
681                                 log.Printf("%v start forced", key)
682                         }
683                         half.nextSeq = seq
684                         action.queue = false
685                 } else {
686                         if *debugLog {
687                                 log.Printf("%v waiting for start, storing into connection", key)
688                         }
689                 }
690         } else {
691                 diff := half.nextSeq.Difference(seq)
692                 if diff > 0 {
693                         if *debugLog {
694                                 log.Printf("%v gap in sequence numbers (%v, %v) diff %v, storing into connection", key, half.nextSeq, seq, diff)
695                         }
696                 } else {
697                         if *debugLog {
698                                 log.Printf("%v found contiguous data (%v, %v), returning immediately: len:%d", key, seq, half.nextSeq, len(bytes))
699                         }
700                         action.queue = false
701                 }
702         }
703
704         action = a.handleBytes(bytes, seq, half, ci, t.SYN, t.RST || t.FIN, action, ac)
705         if len(a.ret) > 0 {
706                 action.nextSeq = a.sendToConnection(conn, half, ac)
707         }
708         if action.nextSeq != invalidSequence {
709                 half.nextSeq = action.nextSeq
710                 if t.FIN {
711                         half.nextSeq = half.nextSeq.Add(1)
712                 }
713         }
714         if *debugLog {
715                 log.Printf("%v nextSeq:%d", key, half.nextSeq)
716         }
717 }
718
719 // Overlap strategies:
720 //  - new packet overlaps with sent packets:
721 //      1) discard new overlapping part
722 //      2) overwrite old overlapped (TODO)
723 //  - new packet overlaps existing queued packets:
724 //      a) consider "age" by timestamp (TODO)
725 //      b) consider "age" by being present
726 //      Then
727 //      1) discard new overlapping part
728 //      2) overwrite queued part
729
730 func (a *Assembler) checkOverlap(half *halfconnection, queue bool, ac AssemblerContext) {
731         var next *page
732         cur := half.last
733         bytes := a.cacheLP.bytes
734         start := a.cacheLP.seq
735         end := start.Add(len(bytes))
736
737         a.dump("before checkOverlap", half)
738
739         //          [s6           :           e6]
740         //   [s1:e1][s2:e2] -- [s3:e3] -- [s4:e4][s5:e5]
741         //             [s <--ds-- : --de--> e]
742         for cur != nil {
743
744                 if *debugLog {
745                         log.Printf("cur = %p (%s)\n", cur, cur)
746                 }
747
748                 // end < cur.start: continue (5)
749                 if end.Difference(cur.seq) > 0 {
750                         if *debugLog {
751                                 log.Printf("case 5\n")
752                         }
753                         next = cur
754                         cur = cur.prev
755                         continue
756                 }
757
758                 curEnd := cur.seq.Add(len(cur.bytes))
759                 // start > cur.end: stop (1)
760                 if start.Difference(curEnd) <= 0 {
761                         if *debugLog {
762                                 log.Printf("case 1\n")
763                         }
764                         break
765                 }
766
767                 diffStart := start.Difference(cur.seq)
768                 diffEnd := end.Difference(curEnd)
769
770                 // end > cur.end && start < cur.start: drop (3)
771                 if diffEnd <= 0 && diffStart >= 0 {
772                         if *debugLog {
773                                 log.Printf("case 3\n")
774                         }
775                         if cur.isPacket() {
776                                 half.overlapPackets++
777                         }
778                         half.overlapBytes += len(cur.bytes)
779                         // update links
780                         if cur.prev != nil {
781                                 cur.prev.next = cur.next
782                         } else {
783                                 half.first = cur.next
784                         }
785                         if cur.next != nil {
786                                 cur.next.prev = cur.prev
787                         } else {
788                                 half.last = cur.prev
789                         }
790                         tmp := cur.prev
791                         half.pages -= cur.release(a.pc)
792                         cur = tmp
793                         continue
794                 }
795
796                 // end > cur.end && start < cur.end: drop cur's end (2)
797                 if diffEnd < 0 && start.Difference(curEnd) > 0 {
798                         if *debugLog {
799                                 log.Printf("case 2\n")
800                         }
801                         cur.bytes = cur.bytes[:-start.Difference(cur.seq)]
802                         break
803                 } else
804
805                 // start < cur.start && end > cur.start: drop cur's start (4)
806                 if diffStart > 0 && end.Difference(cur.seq) < 0 {
807                         if *debugLog {
808                                 log.Printf("case 4\n")
809                         }
810                         cur.bytes = cur.bytes[-end.Difference(cur.seq):]
811                         cur.seq = cur.seq.Add(-end.Difference(cur.seq))
812                         next = cur
813                 } else
814
815                 // end < cur.end && start > cur.start: replace bytes inside cur (6)
816                 if diffEnd > 0 && diffStart < 0 {
817                         if *debugLog {
818                                 log.Printf("case 6\n")
819                         }
820                         copy(cur.bytes[-diffStart:-diffStart+len(bytes)], bytes)
821                         bytes = bytes[:0]
822                 } else {
823                         if *debugLog {
824                                 log.Printf("no overlap\n")
825                         }
826                         next = cur
827                 }
828                 cur = cur.prev
829         }
830
831         // Split bytes into pages, and insert in queue
832         a.cacheLP.bytes = bytes
833         a.cacheLP.seq = start
834         if len(bytes) > 0 && queue {
835                 p, p2, numPages := a.cacheLP.convertToPages(a.pc, 0, ac)
836                 half.queuedPackets++
837                 half.queuedBytes += len(bytes)
838                 half.pages += numPages
839                 if cur != nil {
840                         if *debugLog {
841                                 log.Printf("adding %s after %s", p, cur)
842                         }
843                         cur.next = p
844                         p.prev = cur
845                 } else {
846                         if *debugLog {
847                                 log.Printf("adding %s as first", p)
848                         }
849                         half.first = p
850                 }
851                 if next != nil {
852                         if *debugLog {
853                                 log.Printf("setting %s as next of new %s", next, p2)
854                         }
855                         p2.next = next
856                         next.prev = p2
857                 } else {
858                         if *debugLog {
859                                 log.Printf("setting %s as last", p2)
860                         }
861                         half.last = p2
862                 }
863         }
864         a.dump("After checkOverlap", half)
865 }
866
867 // Warning: this is a low-level dumper, i.e. a.ret or a.cacheSG might
868 // be strange, but it could be ok.
869 func (a *Assembler) dump(text string, half *halfconnection) {
870         if !*debugLog {
871                 return
872         }
873         log.Printf("%s: dump\n", text)
874         if half != nil {
875                 p := half.first
876                 if p == nil {
877                         log.Printf(" * half.first = %p, no chunks queued\n", p)
878                 } else {
879                         s := 0
880                         nb := 0
881                         log.Printf(" * half.first = %p, queued chunks:", p)
882                         for p != nil {
883                                 log.Printf("\t%s bytes:%s\n", p, hex.EncodeToString(p.bytes))
884                                 s += len(p.bytes)
885                                 nb++
886                                 p = p.next
887                         }
888                         log.Printf("\t%d chunks for %d bytes", nb, s)
889                 }
890                 log.Printf(" * half.last = %p\n", half.last)
891                 log.Printf(" * half.saved = %p\n", half.saved)
892                 p = half.saved
893                 for p != nil {
894                         log.Printf("\tseq:%d %s bytes:%s\n", p.getSeq(), p, hex.EncodeToString(p.bytes))
895                         p = p.next
896                 }
897         }
898         log.Printf(" * a.ret\n")
899         for i, r := range a.ret {
900                 log.Printf("\t%d: %s b:%s\n", i, r.captureInfo(), hex.EncodeToString(r.getBytes()))
901         }
902         log.Printf(" * a.cacheSG.all\n")
903         for i, r := range a.cacheSG.all {
904                 log.Printf("\t%d: %s b:%s\n", i, r.captureInfo(), hex.EncodeToString(r.getBytes()))
905         }
906 }
907
908 func (a *Assembler) overlapExisting(half *halfconnection, start, end Sequence, bytes []byte) ([]byte, Sequence) {
909         if half.nextSeq == invalidSequence {
910                 // no start yet
911                 return bytes, start
912         }
913         diff := start.Difference(half.nextSeq)
914         if diff == 0 {
915                 return bytes, start
916         }
917         s := 0
918         e := len(bytes)
919         // TODO: depending on strategy, we might want to shrink half.saved if possible
920         if e != 0 {
921                 if *debugLog {
922                         log.Printf("Overlap detected: ignoring current packet's first %d bytes", diff)
923                 }
924                 half.overlapPackets++
925                 half.overlapBytes += diff
926         }
927         start = start.Add(diff)
928         s += diff
929         if s >= e {
930                 // Completely included in sent
931                 s = e
932         }
933         bytes = bytes[s:]
934         e -= diff
935         return bytes, start
936 }
937
938 // Prepare send or queue
939 func (a *Assembler) handleBytes(bytes []byte, seq Sequence, half *halfconnection, ci gopacket.CaptureInfo, start bool, end bool, action assemblerAction, ac AssemblerContext) assemblerAction {
940         a.cacheLP.bytes = bytes
941         a.cacheLP.start = start
942         a.cacheLP.end = end
943         a.cacheLP.seq = seq
944         a.cacheLP.ci = ci
945         a.cacheLP.ac = ac
946
947         if action.queue {
948                 a.checkOverlap(half, true, ac)
949                 if (a.MaxBufferedPagesPerConnection > 0 && half.pages >= a.MaxBufferedPagesPerConnection) ||
950                         (a.MaxBufferedPagesTotal > 0 && a.pc.used >= a.MaxBufferedPagesTotal) {
951                         if *debugLog {
952                                 log.Printf("hit max buffer size: %+v, %v, %v", a.AssemblerOptions, half.pages, a.pc.used)
953                         }
954                         action.queue = false
955                         a.addNextFromConn(half)
956                 }
957                 a.dump("handleBytes after queue", half)
958         } else {
959                 a.cacheLP.bytes, a.cacheLP.seq = a.overlapExisting(half, seq, seq.Add(len(bytes)), a.cacheLP.bytes)
960                 a.checkOverlap(half, false, ac)
961                 if len(a.cacheLP.bytes) != 0 || end || start {
962                         a.ret = append(a.ret, &a.cacheLP)
963                 }
964                 a.dump("handleBytes after no queue", half)
965         }
966         return action
967 }
968
969 func (a *Assembler) setStatsToSG(half *halfconnection) {
970         a.cacheSG.queuedBytes = half.queuedBytes
971         half.queuedBytes = 0
972         a.cacheSG.queuedPackets = half.queuedPackets
973         half.queuedPackets = 0
974         a.cacheSG.overlapBytes = half.overlapBytes
975         half.overlapBytes = 0
976         a.cacheSG.overlapPackets = half.overlapPackets
977         half.overlapPackets = 0
978 }
979
980 // Build the ScatterGather object, i.e. prepend saved bytes and
981 // append continuous bytes.
982 func (a *Assembler) buildSG(half *halfconnection) (bool, Sequence) {
983         // find if there are skipped bytes
984         skip := -1
985         if half.nextSeq != invalidSequence {
986                 skip = half.nextSeq.Difference(a.ret[0].getSeq())
987         }
988         last := a.ret[0].getSeq().Add(a.ret[0].length())
989         // Prepend saved bytes
990         saved := a.addPending(half, a.ret[0].getSeq())
991         // Append continuous bytes
992         nextSeq := a.addContiguous(half, last)
993         a.cacheSG.all = a.ret
994         a.cacheSG.Direction = half.dir
995         a.cacheSG.Skip = skip
996         a.cacheSG.saved = saved
997         a.cacheSG.toKeep = -1
998         a.setStatsToSG(half)
999         a.dump("after buildSG", half)
1000         return a.ret[len(a.ret)-1].isEnd(), nextSeq
1001 }
1002
1003 func (a *Assembler) cleanSG(half *halfconnection, ac AssemblerContext) {
1004         cur := 0
1005         ndx := 0
1006         skip := 0
1007
1008         a.dump("cleanSG(start)", half)
1009
1010         var r byteContainer
1011         // Find first page to keep
1012         if a.cacheSG.toKeep < 0 {
1013                 ndx = len(a.cacheSG.all)
1014         } else {
1015                 skip = a.cacheSG.toKeep
1016                 found := false
1017                 for ndx, r = range a.cacheSG.all {
1018                         if a.cacheSG.toKeep < cur+r.length() {
1019                                 found = true
1020                                 break
1021                         }
1022                         cur += r.length()
1023                         if skip >= r.length() {
1024                                 skip -= r.length()
1025                         }
1026                 }
1027                 if !found {
1028                         ndx++
1029                 }
1030         }
1031         // Release consumed pages
1032         for _, r := range a.cacheSG.all[:ndx] {
1033                 if r == half.saved {
1034                         if half.saved.next != nil {
1035                                 half.saved.next.prev = nil
1036                         }
1037                         half.saved = half.saved.next
1038                 } else if r == half.first {
1039                         if half.first.next != nil {
1040                                 half.first.next.prev = nil
1041                         }
1042                         if half.first == half.last {
1043                                 half.first, half.last = nil, nil
1044                         } else {
1045                                 half.first = half.first.next
1046                         }
1047                 }
1048                 half.pages -= r.release(a.pc)
1049         }
1050         a.dump("after consumed release", half)
1051         // Keep un-consumed pages
1052         nbKept := 0
1053         half.saved = nil
1054         var saved *page
1055         for _, r := range a.cacheSG.all[ndx:] {
1056                 first, last, nb := r.convertToPages(a.pc, skip, ac)
1057                 if half.saved == nil {
1058                         half.saved = first
1059                 } else {
1060                         saved.next = first
1061                         first.prev = saved
1062                 }
1063                 saved = last
1064                 nbKept += nb
1065         }
1066         if *debugLog {
1067                 log.Printf("Remaining %d chunks in SG\n", nbKept)
1068                 log.Printf("%s\n", a.Dump())
1069                 a.dump("after cleanSG()", half)
1070         }
1071 }
1072
1073 // sendToConnection sends the current values in a.ret to the connection, closing
1074 // the connection if the last thing sent had End set.
1075 func (a *Assembler) sendToConnection(conn *connection, half *halfconnection, ac AssemblerContext) Sequence {
1076         if *debugLog {
1077                 log.Printf("sendToConnection\n")
1078         }
1079         end, nextSeq := a.buildSG(half)
1080         half.stream.ReassembledSG(&a.cacheSG, ac)
1081         a.cleanSG(half, ac)
1082         if end {
1083                 a.closeHalfConnection(conn, half)
1084         }
1085         if *debugLog {
1086                 log.Printf("after sendToConnection: nextSeq: %d\n", nextSeq)
1087         }
1088         return nextSeq
1089 }
1090
1091 //
1092 func (a *Assembler) addPending(half *halfconnection, firstSeq Sequence) int {
1093         if half.saved == nil {
1094                 return 0
1095         }
1096         s := 0
1097         ret := []byteContainer{}
1098         for p := half.saved; p != nil; p = p.next {
1099                 if *debugLog {
1100                         log.Printf("adding pending @%p %s (%s)\n", p, p, hex.EncodeToString(p.bytes))
1101                 }
1102                 ret = append(ret, p)
1103                 s += len(p.bytes)
1104         }
1105         if half.saved.seq.Add(s) != firstSeq {
1106                 // non-continuous saved: drop them
1107                 var next *page
1108                 for p := half.saved; p != nil; p = next {
1109                         next = p.next
1110                         p.release(a.pc)
1111                 }
1112                 half.saved = nil
1113                 ret = []byteContainer{}
1114                 s = 0
1115         }
1116
1117         a.ret = append(ret, a.ret...)
1118         return s
1119 }
1120
1121 // addContiguous adds contiguous byte-sets to a connection.
1122 func (a *Assembler) addContiguous(half *halfconnection, lastSeq Sequence) Sequence {
1123         page := half.first
1124         if page == nil {
1125                 if *debugLog {
1126                         log.Printf("addContiguous(%d): no pages\n", lastSeq)
1127                 }
1128                 return lastSeq
1129         }
1130         if lastSeq == invalidSequence {
1131                 lastSeq = page.seq
1132         }
1133         for page != nil && lastSeq.Difference(page.seq) == 0 {
1134                 if *debugLog {
1135                         log.Printf("addContiguous: lastSeq: %d, first.seq=%d, page.seq=%d\n", half.nextSeq, half.first.seq, page.seq)
1136                 }
1137                 lastSeq = lastSeq.Add(len(page.bytes))
1138                 a.ret = append(a.ret, page)
1139                 half.first = page.next
1140                 if half.first == nil {
1141                         half.last = nil
1142                 }
1143                 if page.next != nil {
1144                         page.next.prev = nil
1145                 }
1146                 page = page.next
1147         }
1148         return lastSeq
1149 }
1150
1151 // skipFlush skips the first set of bytes we're waiting for and returns the
1152 // first set of bytes we have.  If we have no bytes saved, it closes the
1153 // connection.
1154 func (a *Assembler) skipFlush(conn *connection, half *halfconnection) {
1155         if *debugLog {
1156                 log.Printf("skipFlush %v\n", half.nextSeq)
1157         }
1158         // Well, it's embarassing it there is still something in half.saved
1159         // FIXME: change API to give back saved + new/no packets
1160         if half.first == nil {
1161                 a.closeHalfConnection(conn, half)
1162                 return
1163         }
1164         a.ret = a.ret[:0]
1165         a.addNextFromConn(half)
1166         nextSeq := a.sendToConnection(conn, half, a.ret[0].assemblerContext())
1167         if nextSeq != invalidSequence {
1168                 half.nextSeq = nextSeq
1169         }
1170 }
1171
1172 func (a *Assembler) closeHalfConnection(conn *connection, half *halfconnection) {
1173         if *debugLog {
1174                 log.Printf("%v closing", conn)
1175         }
1176         half.closed = true
1177         for p := half.first; p != nil; p = p.next {
1178                 // FIXME: it should be already empty
1179                 a.pc.replace(p)
1180                 half.pages--
1181         }
1182         if conn.s2c.closed && conn.c2s.closed {
1183                 if half.stream.ReassemblyComplete(nil) { //FIXME: which context to pass ?
1184                         a.connPool.remove(conn)
1185                 }
1186         }
1187 }
1188
1189 // addNextFromConn pops the first page from a connection off and adds it to the
1190 // return array.
1191 func (a *Assembler) addNextFromConn(conn *halfconnection) {
1192         if conn.first == nil {
1193                 return
1194         }
1195         if *debugLog {
1196                 log.Printf("   adding from conn (%v, %v) %v (%d)\n", conn.first.seq, conn.nextSeq, conn.nextSeq-conn.first.seq, len(conn.first.bytes))
1197         }
1198         a.ret = append(a.ret, conn.first)
1199         conn.first = conn.first.next
1200         if conn.first != nil {
1201                 conn.first.prev = nil
1202         } else {
1203                 conn.last = nil
1204         }
1205 }
1206
1207 // FlushOptions provide options for flushing connections.
1208 type FlushOptions struct {
1209         T  time.Time // If nonzero, only connections with data older than T are flushed
1210         TC time.Time // If nonzero, only connections with data older than TC are closed (if no FIN/RST received)
1211 }
1212
1213 // FlushWithOptions finds any streams waiting for packets older than
1214 // the given time T, and pushes through the data they have (IE: tells
1215 // them to stop waiting and skip the data they're waiting for).
1216 //
1217 // It also closes streams older than TC (that can be set to zero, to keep
1218 // long-lived stream alive, but to flush data anyway).
1219 //
1220 // Each Stream maintains a list of zero or more sets of bytes it has received
1221 // out-of-order.  For example, if it has processed up through sequence number
1222 // 10, it might have bytes [15-20), [20-25), [30,50) in its list.  Each set of
1223 // bytes also has the timestamp it was originally viewed.  A flush call will
1224 // look at the smallest subsequent set of bytes, in this case [15-20), and if
1225 // its timestamp is older than the passed-in time, it will push it and all
1226 // contiguous byte-sets out to the Stream's Reassembled function.  In this case,
1227 // it will push [15-20), but also [20-25), since that's contiguous.  It will
1228 // only push [30-50) if its timestamp is also older than the passed-in time,
1229 // otherwise it will wait until the next FlushCloseOlderThan to see if bytes
1230 // [25-30) come in.
1231 //
1232 // Returns the number of connections flushed, and of those, the number closed
1233 // because of the flush.
1234 func (a *Assembler) FlushWithOptions(opt FlushOptions) (flushed, closed int) {
1235         conns := a.connPool.connections()
1236         closes := 0
1237         flushes := 0
1238         for _, conn := range conns {
1239                 remove := false
1240                 conn.mu.Lock()
1241                 for _, half := range []*halfconnection{&conn.s2c, &conn.c2s} {
1242                         flushed, closed := a.flushClose(conn, half, opt.T, opt.TC)
1243                         if flushed {
1244                                 flushes++
1245                         }
1246                         if closed {
1247                                 closes++
1248                         }
1249                 }
1250                 if conn.s2c.closed && conn.c2s.closed && conn.s2c.lastSeen.Before(opt.TC) && conn.c2s.lastSeen.Before(opt.TC) {
1251                         remove = true
1252                 }
1253                 conn.mu.Unlock()
1254                 if remove {
1255                         a.connPool.remove(conn)
1256                 }
1257         }
1258         return flushes, closes
1259 }
1260
1261 // FlushCloseOlderThan flushes and closes streams older than given time
1262 func (a *Assembler) FlushCloseOlderThan(t time.Time) (flushed, closed int) {
1263         return a.FlushWithOptions(FlushOptions{T: t, TC: t})
1264 }
1265
1266 func (a *Assembler) flushClose(conn *connection, half *halfconnection, t time.Time, tc time.Time) (bool, bool) {
1267         flushed, closed := false, false
1268         if half.closed {
1269                 return flushed, closed
1270         }
1271         for half.first != nil && half.first.seen.Before(t) {
1272                 flushed = true
1273                 a.skipFlush(conn, half)
1274                 if half.closed {
1275                         closed = true
1276                 }
1277         }
1278         if !half.closed && half.first == nil && half.lastSeen.Before(tc) {
1279                 a.closeHalfConnection(conn, half)
1280                 closed = true
1281         }
1282         return flushed, closed
1283 }
1284
1285 // FlushAll flushes all remaining data into all remaining connections and closes
1286 // those connections. It returns the total number of connections flushed/closed
1287 // by the call.
1288 func (a *Assembler) FlushAll() (closed int) {
1289         conns := a.connPool.connections()
1290         closed = len(conns)
1291         for _, conn := range conns {
1292                 conn.mu.Lock()
1293                 for _, half := range []*halfconnection{&conn.s2c, &conn.c2s} {
1294                         for !half.closed {
1295                                 a.skipFlush(conn, half)
1296                         }
1297                         if !half.closed {
1298                                 a.closeHalfConnection(conn, half)
1299                         }
1300                 }
1301                 conn.mu.Unlock()
1302         }
1303         return
1304 }
1305
1306 func min(a, b int) int {
1307         if a < b {
1308                 return a
1309         }
1310         return b
1311 }