wireguard: initial implementation of wireguard protocol
[vpp.git] / src / plugins / wireguard / test / test_wireguard.py
1 #!/usr/bin/env python3
2 """ Wg tests """
3
4 from scapy.packet import Packet
5 from scapy.packet import Raw
6 from scapy.layers.l2 import Ether
7 from scapy.layers.inet import IP, UDP
8 from scapy.contrib.wireguard import Wireguard, WireguardResponse, \
9     WireguardInitiation
10 from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
11 from cryptography.hazmat.primitives.serialization import Encoding, \
12     PrivateFormat, PublicFormat, NoEncryption
13
14 from vpp_ipip_tun_interface import VppIpIpTunInterface
15 from vpp_interface import VppInterface
16 from vpp_object import VppObject
17 from framework import VppTestCase
18 from re import compile
19 import unittest
20
""" TestWg is a subclass of the VppTestCase class.

WireGuard interface and peer tests.

"""
26
27
class VppWgInterface(VppInterface):
    """
    VPP WireGuard interface.

    Thin wrapper over the wireguard_interface_{create,delete,dump} API
    calls so the interface can be managed by the test object registry.
    """

    def __init__(self, test, src, port, key=None):
        """
        :param test: the owning test case
        :param src: source IP address (string) of the interface
        :param port: UDP port configured on the interface
        :param key: optional X25519PrivateKey; when it is not given,
                    ``generate`` is True and 32 zero bytes are sent to VPP
                    as the private key
        """
        super(VppWgInterface, self).__init__(test)

        self.src = src
        self.port = port
        self.key = key
        # True when the caller did not supply a private key.
        self.generate = not key

    def add_vpp_config(self):
        """Create the interface in VPP, register it and return self."""
        itf = {
            'user_instance': 0xffffffff,
            'port': self.port,
            'src_ip': self.src,
            'private_key': self.key_bytes(),
        }
        reply = self.test.vapi.wireguard_interface_create(interface=itf)
        self.set_sw_if_index(reply.sw_if_index)
        self.test.registry.register(self, self.test.logger)
        return self

    def key_bytes(self):
        """Return the raw private key bytes, or 32 zero bytes if unset."""
        if not self.key:
            return bytearray(32)
        return self.key.private_bytes(Encoding.Raw,
                                      PrivateFormat.Raw,
                                      NoEncryption())

    def remove_vpp_config(self):
        """Delete the interface from VPP."""
        self.test.vapi.wireguard_interface_delete(
            sw_if_index=self._sw_if_index)

    def query_vpp_config(self):
        """
        Return True if the dump of all wireguard interfaces contains one
        matching this object's sw_if_index, source IP, port and key.
        """
        # NOTE(review): with no key supplied this compares against 32 zero
        # bytes; if VPP generates a key and reports it in the dump, this
        # will not match -- confirm against the API behaviour.
        dump = self.test.vapi.wireguard_interface_dump(sw_if_index=0xffffffff)
        return any(
            d.interface.sw_if_index == self._sw_if_index and
            str(d.interface.src_ip) == self.src and
            d.interface.port == self.port and
            d.interface.private_key == self.key_bytes()
            for d in dump)

    def __str__(self):
        return self.object_id()

    def object_id(self):
        """Return a name identifying this interface in logs/registry."""
        return "wireguard-%d" % self._sw_if_index
82
83
def find_route(test, prefix, table_id=0):
    """
    Return True if ``prefix`` is installed in IP table ``table_id``.

    Routes are fetched with ``test.vapi.ip_route_dump(table_id, False)``
    and matched on table id and on the string form of the prefix.
    """
    wanted = str(prefix)
    return any(table_id == entry.route.table_id and
               str(entry.route.prefix) == wanted
               for entry in test.vapi.ip_route_dump(table_id, False))
92
93
class VppWgPeer(VppObject):
    """
    VPP WireGuard peer attached to a wireguard interface.

    A new X25519 key pair is generated for every instance; the public key
    is what gets configured in VPP.
    """

    def __init__(self,
                 test,
                 itf,
                 endpoint,
                 port,
                 allowed_ips,
                 persistent_keepalive=15):
        """
        :param test: the owning test case
        :param itf: the wireguard interface object the peer is attached to
        :param endpoint: peer endpoint IP address (string)
        :param port: peer UDP port
        :param allowed_ips: list of allowed prefixes, e.g. ["10.11.2.0/24"]
        :param persistent_keepalive: keepalive value passed to VPP
        """
        self._test = test
        self.itf = itf
        self.endpoint = endpoint
        self.port = port
        self.allowed_ips = allowed_ips
        self.persistent_keepalive = persistent_keepalive
        self.private_key = X25519PrivateKey.generate()
        self.public_key = self.private_key.public_key()
        # Not used by the code visible in this file.
        self.hash = bytearray(16)

    def validate_routing(self):
        """Assert a route exists in table 0 for every allowed prefix."""
        for a in self.allowed_ips:
            self._test.assertTrue(find_route(self._test, a))

    def validate_no_routing(self):
        """Assert no route exists in table 0 for any allowed prefix."""
        for a in self.allowed_ips:
            self._test.assertFalse(find_route(self._test, a))

    def add_vpp_config(self):
        """
        Add the peer to VPP, remember the returned peer index, register
        the object and check the allowed-ip routes were installed.
        """
        rv = self._test.vapi.wireguard_peer_add(
            peer={
                'public_key': self.public_key_bytes(),
                'port': self.port,
                'endpoint': self.endpoint,
                'n_allowed_ips': len(self.allowed_ips),
                'allowed_ips': self.allowed_ips,
                'sw_if_index': self.itf.sw_if_index,
                'persistent_keepalive': self.persistent_keepalive})
        self.index = rv.peer_index
        self._test.registry.register(self, self._test.logger)
        self.validate_routing()
        return self

    def remove_vpp_config(self):
        """Remove the peer and check its allowed-ip routes are gone."""
        self._test.vapi.wireguard_peer_remove(peer_index=self.index)
        self.validate_no_routing()

    def object_id(self):
        """Return a name identifying this peer (valid after add)."""
        return ("wireguard-peer-%s" % self.index)

    def public_key_bytes(self):
        """Return the raw 32-byte public key."""
        return self.public_key.public_bytes(Encoding.Raw,
                                            PublicFormat.Raw)

    def private_key_bytes(self):
        """Return the raw 32-byte private key."""
        return self.private_key.private_bytes(Encoding.Raw,
                                              PrivateFormat.Raw,
                                              NoEncryption())

    def query_vpp_config(self):
        """
        Return True if VPP reports a peer with this key, port, endpoint,
        interface and exactly the same set of allowed prefixes.

        Both prefix lists are compared via their sorted string forms. The
        previous version sorted the local strings lexicographically but the
        dumped prefix objects by their own ordering, which can disagree
        (e.g. "10.10.10.0/24" vs "10.10.2.0/24"). It also mutated
        self.allowed_ips in place; neither list is modified here.
        """
        expected = sorted(str(a) for a in self.allowed_ips)
        for p in self._test.vapi.wireguard_peers_dump():
            if p.peer.public_key == self.public_key_bytes() and \
               p.peer.port == self.port and \
               str(p.peer.endpoint) == self.endpoint and \
               p.peer.sw_if_index == self.itf.sw_if_index and \
               len(self.allowed_ips) == p.peer.n_allowed_ips:
                return expected == sorted(str(a) for a in p.peer.allowed_ips)
        return False
169
170
class TestWg(VppTestCase):
    """ Wireguard Test Case """

    error_str = compile(r"Error")

    @classmethod
    def setUpClass(cls):
        super(TestWg, cls).setUpClass()
        try:
            # Three packet-generator interfaces, each up with IPv4 and ARP
            # resolved: pg0 injects traffic, pg1/pg2 carry wg underlay.
            cls.create_pg_interfaces(range(3))
            for i in cls.pg_interfaces:
                i.admin_up()
                i.config_ip4()
                i.resolve_arp()

        except Exception:
            super(TestWg, cls).tearDownClass()
            raise

    @classmethod
    def tearDownClass(cls):
        super(TestWg, cls).tearDownClass()

    def test_wg_interface(self):
        """ Simple interface creation """
        port = 12312

        # Create interface
        wg0 = VppWgInterface(self,
                             self.pg1.local_ip4,
                             port).add_vpp_config()

        self.logger.info(self.vapi.cli("sh int"))

        # delete interface
        wg0.remove_vpp_config()

    def test_wg_peer(self):
        """ Peer creation, handshake initiation and removal """
        wg_output_node_name = '/err/wg-output-tun/'

        port = 12323

        # Create interfaces
        wg0 = VppWgInterface(self,
                             self.pg1.local_ip4,
                             port,
                             key=X25519PrivateKey.generate()).add_vpp_config()
        wg1 = VppWgInterface(self,
                             self.pg2.local_ip4,
                             port+1).add_vpp_config()
        wg0.admin_up()
        wg1.admin_up()

        # Check peer counter
        self.assertEqual(len(self.vapi.wireguard_peers_dump()), 0)

        self.pg_enable_capture(self.pg_interfaces)
        self.pg_start()

        peer_1 = VppWgPeer(self,
                           wg0,
                           self.pg1.remote_ip4,
                           port+1,
                           ["10.11.2.0/24",
                            "10.11.3.0/24"]).add_vpp_config()
        self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1)

        # wait for the peer to send a handshake
        capture = self.pg1.get_capture(1, timeout=2)
        handshake = capture[0]

        self.assertEqual(handshake[IP].src, wg0.src)
        self.assertEqual(handshake[IP].dst, peer_1.endpoint)
        self.assertEqual(handshake[UDP].sport, wg0.port)
        self.assertEqual(handshake[UDP].dport, peer_1.port)
        handshake = Wireguard(handshake[Raw])
        self.assertEqual(handshake.message_type, 1)  # "initiate"
        self.assertTrue(handshake.haslayer(WireguardInitiation))

        # Route a packet into the wg interface using an allowed-ip prefix.
        # The handshake is never answered, so the peer has no keypair and
        # wg-output-tun drops the packet, counting a "Keypair error".
        p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
             IP(src=self.pg0.remote_ip4, dst="10.11.3.2") /
             UDP(sport=555, dport=556) /
             Raw())
        self.send_and_assert_no_replies(self.pg0, [p])

        self.logger.info(self.vapi.cli("sh error"))
        keypair_err = wg_output_node_name + "Keypair error"
        self.assertEqual(1, self.statistics.get_err_counter(keypair_err))

        # Create many peers on the second interface
        NUM_PEERS = 16
        self.pg2.generate_remote_hosts(NUM_PEERS)
        self.pg2.configure_ipv4_neighbors()

        peers = []
        for i in range(NUM_PEERS):
            peers.append(VppWgPeer(self,
                                   wg1,
                                   self.pg2.remote_hosts[i].ip4,
                                   port+1+i,
                                   ["10.10.%d.4/32" % i]).add_vpp_config())
            self.assertEqual(len(self.vapi.wireguard_peers_dump()), i+2)

        # Debug output only; the adjacency index and addresses below are
        # hard-coded and not asserted on.
        self.logger.info(self.vapi.cli("show wireguard peer"))
        self.logger.info(self.vapi.cli("show wireguard interface"))
        self.logger.info(self.vapi.cli("show adj 37"))
        self.logger.info(self.vapi.cli("sh ip fib 172.16.3.17"))
        self.logger.info(self.vapi.cli("sh ip fib 10.11.3.0"))

        # remove peers
        for p in peers:
            self.assertTrue(p.query_vpp_config())
            p.remove_vpp_config()
        self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1)
        peer_1.remove_vpp_config()
        self.assertEqual(len(self.vapi.wireguard_peers_dump()), 0)

        wg0.remove_vpp_config()
        # NOTE(review): wg1 is not removed explicitly here; it stays in the
        # test registry it was registered with in add_vpp_config. Confirm
        # why explicit removal was disabled.
291         wg0.remove_vpp_config()
292         # wg1.remove_vpp_config()