ODPM 266: Go-libmemif + 2 examples.
[govpp.git] / vendor / github.com / google / gopacket / layers / tcpip.go
1 // Copyright 2012 Google, Inc. All rights reserved.
2 // Copyright 2009-2011 Andreas Krennmair. All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style license
5 // that can be found in the LICENSE file in the root of the source
6 // tree.
7
8 package layers
9
10 import (
11         "errors"
12         "fmt"
13
14         "github.com/google/gopacket"
15 )
16
17 // Checksum computation for TCP/UDP.
18 type tcpipchecksum struct {
19         pseudoheader tcpipPseudoHeader
20 }
21
22 type tcpipPseudoHeader interface {
23         pseudoheaderChecksum() (uint32, error)
24 }
25
26 func (ip *IPv4) pseudoheaderChecksum() (csum uint32, err error) {
27         if err := ip.AddressTo4(); err != nil {
28                 return 0, err
29         }
30         csum += (uint32(ip.SrcIP[0]) + uint32(ip.SrcIP[2])) << 8
31         csum += uint32(ip.SrcIP[1]) + uint32(ip.SrcIP[3])
32         csum += (uint32(ip.DstIP[0]) + uint32(ip.DstIP[2])) << 8
33         csum += uint32(ip.DstIP[1]) + uint32(ip.DstIP[3])
34         return csum, nil
35 }
36
37 func (ip *IPv6) pseudoheaderChecksum() (csum uint32, err error) {
38         if err := ip.AddressTo16(); err != nil {
39                 return 0, err
40         }
41         for i := 0; i < 16; i += 2 {
42                 csum += uint32(ip.SrcIP[i]) << 8
43                 csum += uint32(ip.SrcIP[i+1])
44                 csum += uint32(ip.DstIP[i]) << 8
45                 csum += uint32(ip.DstIP[i+1])
46         }
47         return csum, nil
48 }
49
50 // Calculate the TCP/IP checksum defined in rfc1071.  The passed-in csum is any
51 // initial checksum data that's already been computed.
52 func tcpipChecksum(data []byte, csum uint32) uint16 {
53         // to handle odd lengths, we loop to length - 1, incrementing by 2, then
54         // handle the last byte specifically by checking against the original
55         // length.
56         length := len(data) - 1
57         for i := 0; i < length; i += 2 {
58                 // For our test packet, doing this manually is about 25% faster
59                 // (740 ns vs. 1000ns) than doing it by calling binary.BigEndian.Uint16.
60                 csum += uint32(data[i]) << 8
61                 csum += uint32(data[i+1])
62         }
63         if len(data)%2 == 1 {
64                 csum += uint32(data[length]) << 8
65         }
66         for csum > 0xffff {
67                 csum = (csum >> 16) + (csum & 0xffff)
68         }
69         return ^uint16(csum)
70 }
71
72 // computeChecksum computes a TCP or UDP checksum.  headerAndPayload is the
73 // serialized TCP or UDP header plus its payload, with the checksum zero'd
74 // out. headerProtocol is the IP protocol number of the upper-layer header.
75 func (c *tcpipchecksum) computeChecksum(headerAndPayload []byte, headerProtocol IPProtocol) (uint16, error) {
76         if c.pseudoheader == nil {
77                 return 0, errors.New("TCP/IP layer 4 checksum cannot be computed without network layer... call SetNetworkLayerForChecksum to set which layer to use")
78         }
79         length := uint32(len(headerAndPayload))
80         csum, err := c.pseudoheader.pseudoheaderChecksum()
81         if err != nil {
82                 return 0, err
83         }
84         csum += uint32(headerProtocol)
85         csum += length & 0xffff
86         csum += length >> 16
87         return tcpipChecksum(headerAndPayload, csum), nil
88 }
89
90 // SetNetworkLayerForChecksum tells this layer which network layer is wrapping it.
91 // This is needed for computing the checksum when serializing, since TCP/IP transport
92 // layer checksums depends on fields in the IPv4 or IPv6 layer that contains it.
93 // The passed in layer must be an *IPv4 or *IPv6.
94 func (i *tcpipchecksum) SetNetworkLayerForChecksum(l gopacket.NetworkLayer) error {
95         switch v := l.(type) {
96         case *IPv4:
97                 i.pseudoheader = v
98         case *IPv6:
99                 i.pseudoheader = v
100         default:
101                 return fmt.Errorf("cannot use layer type %v for tcp checksum network layer", l.LayerType())
102         }
103         return nil
104 }