cbbeb6ef76e937def8cb4d923bb214151d4700a2
[govpp.git] / vendor / github.com / google / gopacket / routing / routing.go
1 // Copyright 2012 Google, Inc. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style license
4 // that can be found in the LICENSE file in the root of the source
5 // tree.
6
7 // +build linux
8
9 // Package routing provides a very basic but mostly functional implementation of
10 // a routing table for IPv4/IPv6 addresses.  It uses a routing table pulled from
11 // the kernel via netlink to find the correct interface, gateway, and preferred
12 // source IP address for packets destined to a particular location.
13 //
14 // The routing package is meant to be used with applications that are sending
15 // raw packet data, which don't have the benefit of having the kernel route
16 // packets for them.
17 package routing
18
19 import (
20         "bytes"
21         "errors"
22         "fmt"
23         "net"
24         "sort"
25         "strings"
26         "syscall"
27         "unsafe"
28 )
29
30 // Pulled from http://man7.org/linux/man-pages/man7/rtnetlink.7.html
31 // See the section on RTM_NEWROUTE, specifically 'struct rtmsg'.
32 type routeInfoInMemory struct {
33         Family byte
34         DstLen byte
35         SrcLen byte
36         TOS    byte
37
38         Table    byte
39         Protocol byte
40         Scope    byte
41         Type     byte
42
43         Flags uint32
44 }
45
46 // rtInfo contains information on a single route.
47 type rtInfo struct {
48         Src, Dst         *net.IPNet
49         Gateway, PrefSrc net.IP
50         // We currently ignore the InputIface.
51         InputIface, OutputIface uint32
52         Priority                uint32
53 }
54
55 // routeSlice implements sort.Interface to sort routes by Priority.
56 type routeSlice []*rtInfo
57
58 func (r routeSlice) Len() int {
59         return len(r)
60 }
61 func (r routeSlice) Less(i, j int) bool {
62         return r[i].Priority < r[j].Priority
63 }
64 func (r routeSlice) Swap(i, j int) {
65         r[i], r[j] = r[j], r[i]
66 }
67
68 type router struct {
69         ifaces []net.Interface
70         addrs  []ipAddrs
71         v4, v6 routeSlice
72 }
73
74 func (r *router) String() string {
75         strs := []string{"ROUTER", "--- V4 ---"}
76         for _, route := range r.v4 {
77                 strs = append(strs, fmt.Sprintf("%+v", *route))
78         }
79         strs = append(strs, "--- V6 ---")
80         for _, route := range r.v6 {
81                 strs = append(strs, fmt.Sprintf("%+v", *route))
82         }
83         return strings.Join(strs, "\n")
84 }
85
86 type ipAddrs struct {
87         v4, v6 net.IP
88 }
89
90 func (r *router) Route(dst net.IP) (iface *net.Interface, gateway, preferredSrc net.IP, err error) {
91         return r.RouteWithSrc(nil, nil, dst)
92 }
93
94 func (r *router) RouteWithSrc(input net.HardwareAddr, src, dst net.IP) (iface *net.Interface, gateway, preferredSrc net.IP, err error) {
95         var ifaceIndex int
96         switch {
97         case dst.To4() != nil:
98                 ifaceIndex, gateway, preferredSrc, err = r.route(r.v4, input, src, dst)
99         case dst.To16() != nil:
100                 ifaceIndex, gateway, preferredSrc, err = r.route(r.v6, input, src, dst)
101         default:
102                 err = errors.New("IP is not valid as IPv4 or IPv6")
103                 return
104         }
105
106         // Interfaces are 1-indexed, but we store them in a 0-indexed array.
107         ifaceIndex--
108
109         iface = &r.ifaces[ifaceIndex]
110         if preferredSrc == nil {
111                 switch {
112                 case dst.To4() != nil:
113                         preferredSrc = r.addrs[ifaceIndex].v4
114                 case dst.To16() != nil:
115                         preferredSrc = r.addrs[ifaceIndex].v6
116                 }
117         }
118         return
119 }
120
121 func (r *router) route(routes routeSlice, input net.HardwareAddr, src, dst net.IP) (iface int, gateway, preferredSrc net.IP, err error) {
122         var inputIndex uint32
123         if input != nil {
124                 for i, iface := range r.ifaces {
125                         if bytes.Equal(input, iface.HardwareAddr) {
126                                 // Convert from zero- to one-indexed.
127                                 inputIndex = uint32(i + 1)
128                                 break
129                         }
130                 }
131         }
132         for _, rt := range routes {
133                 if rt.InputIface != 0 && rt.InputIface != inputIndex {
134                         continue
135                 }
136                 if rt.Src != nil && !rt.Src.Contains(src) {
137                         continue
138                 }
139                 if rt.Dst != nil && !rt.Dst.Contains(dst) {
140                         continue
141                 }
142                 return int(rt.OutputIface), rt.Gateway, rt.PrefSrc, nil
143         }
144         err = fmt.Errorf("no route found for %v", dst)
145         return
146 }
147
148 // New creates a new router object.  The router returned by New currently does
149 // not update its routes after construction... care should be taken for
150 // long-running programs to call New() regularly to take into account any
151 // changes to the routing table which have occurred since the last New() call.
152 func New() (Router, error) {
153         rtr := &router{}
154         tab, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC)
155         if err != nil {
156                 return nil, err
157         }
158         msgs, err := syscall.ParseNetlinkMessage(tab)
159         if err != nil {
160                 return nil, err
161         }
162 loop:
163         for _, m := range msgs {
164                 switch m.Header.Type {
165                 case syscall.NLMSG_DONE:
166                         break loop
167                 case syscall.RTM_NEWROUTE:
168                         rt := (*routeInfoInMemory)(unsafe.Pointer(&m.Data[0]))
169                         routeInfo := rtInfo{}
170                         attrs, err := syscall.ParseNetlinkRouteAttr(&m)
171                         if err != nil {
172                                 return nil, err
173                         }
174                         switch rt.Family {
175                         case syscall.AF_INET:
176                                 rtr.v4 = append(rtr.v4, &routeInfo)
177                         case syscall.AF_INET6:
178                                 rtr.v6 = append(rtr.v6, &routeInfo)
179                         default:
180                                 continue loop
181                         }
182                         for _, attr := range attrs {
183                                 switch attr.Attr.Type {
184                                 case syscall.RTA_DST:
185                                         routeInfo.Dst = &net.IPNet{
186                                                 IP:   net.IP(attr.Value),
187                                                 Mask: net.CIDRMask(int(rt.DstLen), len(attr.Value)*8),
188                                         }
189                                 case syscall.RTA_SRC:
190                                         routeInfo.Src = &net.IPNet{
191                                                 IP:   net.IP(attr.Value),
192                                                 Mask: net.CIDRMask(int(rt.SrcLen), len(attr.Value)*8),
193                                         }
194                                 case syscall.RTA_GATEWAY:
195                                         routeInfo.Gateway = net.IP(attr.Value)
196                                 case syscall.RTA_PREFSRC:
197                                         routeInfo.PrefSrc = net.IP(attr.Value)
198                                 case syscall.RTA_IIF:
199                                         routeInfo.InputIface = *(*uint32)(unsafe.Pointer(&attr.Value[0]))
200                                 case syscall.RTA_OIF:
201                                         routeInfo.OutputIface = *(*uint32)(unsafe.Pointer(&attr.Value[0]))
202                                 case syscall.RTA_PRIORITY:
203                                         routeInfo.Priority = *(*uint32)(unsafe.Pointer(&attr.Value[0]))
204                                 }
205                         }
206                 }
207         }
208         sort.Sort(rtr.v4)
209         sort.Sort(rtr.v6)
210         ifaces, err := net.Interfaces()
211         if err != nil {
212                 return nil, err
213         }
214         for i, iface := range ifaces {
215                 if i != iface.Index-1 {
216                         return nil, fmt.Errorf("out of order iface %d = %v", i, iface)
217                 }
218                 rtr.ifaces = append(rtr.ifaces, iface)
219                 var addrs ipAddrs
220                 ifaceAddrs, err := iface.Addrs()
221                 if err != nil {
222                         return nil, err
223                 }
224                 for _, addr := range ifaceAddrs {
225                         if inet, ok := addr.(*net.IPNet); ok {
226                                 // Go has a nasty habit of giving you IPv4s as ::ffff:1.2.3.4 instead of 1.2.3.4.
227                                 // We want to use mapped v4 addresses as v4 preferred addresses, never as v6
228                                 // preferred addresses.
229                                 if v4 := inet.IP.To4(); v4 != nil {
230                                         if addrs.v4 == nil {
231                                                 addrs.v4 = v4
232                                         }
233                                 } else if addrs.v6 == nil {
234                                         addrs.v6 = inet.IP
235                                 }
236                         }
237                 }
238                 rtr.addrs = append(rtr.addrs, addrs)
239         }
240         return rtr, nil
241 }