7 from hashlib import blake2s
8 from scapy.packet import Packet
9 from scapy.packet import Raw
10 from scapy.layers.l2 import Ether, ARP
11 from scapy.layers.inet import IP, UDP
12 from scapy.contrib.wireguard import Wireguard, WireguardResponse, \
13 WireguardInitiation, WireguardTransport
14 from cryptography.hazmat.primitives.asymmetric.x25519 import \
15 X25519PrivateKey, X25519PublicKey
16 from cryptography.hazmat.primitives.serialization import Encoding, \
17 PrivateFormat, PublicFormat, NoEncryption
18 from cryptography.hazmat.primitives.hashes import BLAKE2s, Hash
19 from cryptography.hazmat.primitives.hmac import HMAC
20 from cryptography.hazmat.backends import default_backend
21 from noise.connection import NoiseConnection, Keypair
23 from vpp_ipip_tun_interface import VppIpIpTunInterface
24 from vpp_interface import VppInterface
25 from vpp_object import VppObject
26 from framework import VppTestCase
27 from re import compile
30 """ TestWg is a subclass of the VppTestCase class.
# Return the raw (Encoding.Raw) serialization of X25519 private key ``k``.
# NOTE(review): the remaining private_bytes() arguments are outside this
# excerpt; presumably PrivateFormat.Raw and NoEncryption() (both imported).
37 def private_key_bytes(k):
38 return k.private_bytes(Encoding.Raw,
# Return the raw (Encoding.Raw) serialization of X25519 public key ``k``.
# NOTE(review): the remaining public_bytes() argument is outside this
# excerpt; presumably PublicFormat.Raw (imported above).
43 def public_key_bytes(k):
44 return k.public_bytes(Encoding.Raw,
# Test-side model of a VPP WireGuard interface. Each instance generates its
# own X25519 keypair; add_vpp_config() hands the private half to VPP.
48 class VppWgInterface(VppInterface):
50 VPP WireGuard interface
# NOTE(review): lines storing src/port are not visible in this excerpt, but
# query_vpp_config() and the peer helpers read self.src and self.port.
53 def __init__(self, test, src, port):
54 super(VppWgInterface, self).__init__(test)
# Fresh keypair per interface instance.
58 self.private_key = X25519PrivateKey.generate()
59 self.public_key = self.private_key.public_key()
# Raw bytes of this interface's public key (delegates to the module helper).
61 def public_key_bytes(self):
62 return public_key_bytes(self.public_key)
# Raw bytes of this interface's private key (delegates to the module helper).
64 def private_key_bytes(self):
65 return private_key_bytes(self.private_key)
# Create the interface in VPP with wireguard_interface_create (sending this
# object's private key), store the returned sw_if_index, and register the
# object with the test's registry.
# NOTE(review): user_instance=0xffffffff presumably lets VPP pick the
# instance number — confirm against the wireguard API definition.
67 def add_vpp_config(self):
68 r = self.test.vapi.wireguard_interface_create(interface={
69 'user_instance': 0xffffffff,
72 'private_key': private_key_bytes(self.private_key),
75 self.set_sw_if_index(r.sw_if_index)
76 self.test.registry.register(self, self.test.logger)
def remove_vpp_config(self):
    """Delete this WireGuard interface from VPP.

    The interface is identified by ``self._sw_if_index``.
    """
    vapi = self.test.vapi
    vapi.wireguard_interface_delete(sw_if_index=self._sw_if_index)
# Look in VPP's WireGuard interface dump for an entry whose sw_if_index,
# source IP, port and private key all match this object.
# NOTE(review): the loop header and return statements are outside this
# excerpt; callers presumably treat the result as a boolean.
83 def query_vpp_config(self):
# 0xffffffff as sw_if_index presumably requests every interface (the
# condition below filters by index) — TODO confirm.
84 ts = self.test.vapi.wireguard_interface_dump(sw_if_index=0xffffffff)
86 if t.interface.sw_if_index == self._sw_if_index and \
87 str(t.interface.src_ip) == self.src and \
88 t.interface.port == self.port and \
89 t.interface.private_key == private_key_bytes(self.private_key):
# Fragment of a method whose def line is not visible: returns object_id().
94 return self.object_id()
# Fragment: the object identifier string, built from the sw_if_index.
97 return "wireguard-%d" % self._sw_if_index
# Dump the routes of ``table_id`` and look for one whose table id matches and
# whose prefix equals ``prefix`` (both compared as strings).
# NOTE(review): the loop and return statements are outside this excerpt;
# validate_routing()/validate_no_routing() use the result as a boolean.
# The meaning of the positional False passed to ip_route_dump is not shown.
100 def find_route(test, prefix, table_id=0):
101 routes = test.vapi.ip_route_dump(table_id, False)
104 if table_id == e.route.table_id \
105 and str(e.route.prefix) == str(prefix):
# Noise protocol name for the WireGuard handshake, spelled out component by
# component: pattern IK with psk2, Curve25519 DH, ChaCha20-Poly1305, BLAKE2s.
NOISE_HANDSHAKE_NAME = (
    b"Noise_"
    b"IKpsk2_"
    b"25519_"
    b"ChaChaPoly_"
    b"BLAKE2s"
)
# Identifier installed as the Noise prologue by VppWgPeer.noise_init().
NOISE_IDENTIFIER_NAME = b"WireGuard v1 zx2c4 Jason@zx2c4.com"
# A simulated remote WireGuard peer. It owns its own X25519 keypair and a
# Noise connection used to handshake with, and exchange transport data with,
# the VPP WireGuard interface it is attached to.
# NOTE(review): the start of __init__'s signature and the assignments of
# self._test, self.itf and self.port are outside this excerpt.
114 class VppWgPeer(VppObject):
122 persistent_keepalive=15):
125 self.endpoint = endpoint
127 self.allowed_ips = allowed_ips
128 self.persistent_keepalive = persistent_keepalive
130 # this simulated remote peer's own keypair (private and public halves)
131 self.private_key = X25519PrivateKey.generate()
132 self.public_key = self.private_key.public_key()
# Noise state machine for the handshake; its role is chosen later
# (mk_handshake -> initiator, consume_init -> responder).
134 self.noise = NoiseConnection.from_name(NOISE_HANDSHAKE_NAME)
def validate_routing(self):
    """Assert that VPP has a route for each of this peer's allowed IPs.

    Prefixes are checked in order; the first missing one fails the test.
    """
    test = self._test
    for prefix in self.allowed_ips:
        found = find_route(test, prefix)
        test.assertTrue(found)
def validate_no_routing(self):
    """Assert that VPP has no route for any of this peer's allowed IPs.

    Prefixes are checked in order; the first one still present fails the test.
    """
    test = self._test
    for prefix in self.allowed_ips:
        found = find_route(test, prefix)
        test.assertFalse(found)
# Add this peer to VPP with wireguard_peer_add (sending this peer's public
# key, endpoint, allowed IPs, parent interface and keepalive), store the
# returned peer index, register with the test registry, and assert that a
# route now exists for every allowed IP.
144 def add_vpp_config(self):
145 rv = self._test.vapi.wireguard_peer_add(
147 'public_key': self.public_key_bytes(),
149 'endpoint': self.endpoint,
150 'n_allowed_ips': len(self.allowed_ips),
151 'allowed_ips': self.allowed_ips,
152 'sw_if_index': self.itf.sw_if_index,
153 'persistent_keepalive': self.persistent_keepalive})
154 self.index = rv.peer_index
# Index this simulated peer puts in the sender_index field of its handshake
# messages (mk_handshake / consume_init). The +1 presumably keeps it
# non-zero — confirm.
155 self.receiver_index = self.index + 1
156 self._test.registry.register(self, self._test.logger)
157 self.validate_routing()
def remove_vpp_config(self):
    """Remove this peer from VPP by index, then assert its routes are gone."""
    test = self._test
    test.vapi.wireguard_peer_remove(peer_index=self.index)
    self.validate_no_routing()
# Fragment of the object-id method (its def line is not visible): the id is
# built from the VPP peer index.
165 return ("wireguard-peer-%s" % self.index)
# Raw bytes of this peer's public key (delegates to the module helper).
167 def public_key_bytes(self):
168 return public_key_bytes(self.public_key)
# Look in VPP's peer dump for an entry matching this peer's public key, port,
# endpoint, parent interface and allowed-IP count, then compare the allowed
# IPs one by one as strings.
# NOTE(review): the loop header and return statements are outside this
# excerpt; tests use the result with assertTrue.
170 def query_vpp_config(self):
171 peers = self._test.vapi.wireguard_peers_dump()
174 if p.peer.public_key == self.public_key_bytes() and \
175 p.peer.port == self.port and \
176 str(p.peer.endpoint) == self.endpoint and \
177 p.peer.sw_if_index == self.itf.sw_if_index and \
178 len(self.allowed_ips) == p.peer.n_allowed_ips:
# NOTE(review): two issues with the sorts below.
# 1. self.allowed_ips.sort() sorts in place, mutating the caller's list
#    (__init__ stores the list by reference).
# 2. The two lists are sorted by different keys before the pairwise str()
#    compare: ours by its own element type (strings in the tests), VPP's by
#    its prefix type. A matching config could then be reported as a
#    mismatch. Comparing sorted(str(x) ...) of both sides would avoid both.
179 self.allowed_ips.sort()
180 p.peer.allowed_ips.sort()
182 for (a1, a2) in zip(self.allowed_ips, p.peer.allowed_ips):
183 if str(a1) != str(a2):
def set_responder(self):
    """Switch this peer's Noise connection into the responder role."""
    noise = self.noise
    noise.set_as_responder()
def mk_tunnel_header(self, tx_itf):
    """Build the outer Ether/IP/UDP header for a packet this peer sends.

    The frame is injected on ``tx_itf``. It goes from this peer's endpoint
    and port to the WireGuard interface's source IP and port.
    """
    l2 = Ether(dst=tx_itf.local_mac, src=tx_itf.remote_mac)
    l3 = l2 / IP(src=self.endpoint, dst=self.itf.src)
    return l3 / UDP(sport=self.port, dport=self.itf.port)
# Prepare this peer's Noise connection for a WireGuard handshake. It sets:
# - the WireGuard prologue;
# - an all-zero 32-byte pre-shared key;
# - this peer's static private key;
# - the remote static public key.
# It then starts the handshake. The role (initiator/responder) must already
# be set by the caller.
196 def noise_init(self, public_key=None):
197 self.noise.set_prologue(NOISE_IDENTIFIER_NAME)
198 self.noise.set_psks(psk=bytes(bytearray(32)))
# Presumably falls back to the interface's public key when the caller gave
# none (the guarding condition is not visible in this excerpt).
201 public_key = self.itf.public_key
204 self.noise.set_keypair_from_private_bytes(
206 private_key_bytes(self.private_key))
208 self.noise.set_keypair_from_public_bytes(
209 Keypair.REMOTE_STATIC,
210 public_key_bytes(public_key))
212 self.noise.start_handshake()
# Build a WireGuard handshake initiation (message type 1) from this peer to
# the VPP interface, wrapped in the tunnel header for ``tx_itf``.
# ``public_key`` overrides the remote static key used by Noise (see
# noise_init); mac1 is always keyed with the interface's real public key.
214 def mk_handshake(self, tx_itf, public_key=None):
215 self.noise.set_as_initiator()
216 self.noise_init(public_key)
218 p = (Wireguard() / WireguardInitiation())
220 p[Wireguard].message_type = 1
221 p[Wireguard].reserved_zero = 0
222 p[WireguardInitiation].sender_index = self.receiver_index
224 # timestamp payload for the message: big-endian (2**62 + 10 + unix
225 # seconds, nanoseconds), i.e. TAI64N-style; lifted from the noise protocol's wireguard example
226 now = datetime.datetime.now()
227 tai = struct.pack('!qi', 4611686018427387914 + int(now.timestamp()),
228 int(now.microsecond * 1e3))
229 b = self.noise.write_message(payload=tai)
231 # load noise into init message: 32-byte ephemeral, 48-byte encrypted static, 28-byte encrypted timestamp
232 p[WireguardInitiation].unencrypted_ephemeral = b[0:32]
233 p[WireguardInitiation].encrypted_static = b[32:80]
234 p[WireguardInitiation].encrypted_timestamp = b[80:108]
236 # generate the mac1 hash over the first 116 bytes, keyed by BLAKE2s('mac1----' || interface public key)
237 mac_key = blake2s(b'mac1----' +
238 self.itf.public_key_bytes()).digest()
239 p[WireguardInitiation].mac1 = blake2s(bytes(p)[0:116],
241 key=mac_key).digest()
# mac2 is left as 16 zero bytes.
242 p[WireguardInitiation].mac2 = bytearray(16)
244 p = (self.mk_tunnel_header(tx_itf) / p)
def verify_header(self, p):
    """Check the outer IP/UDP header of a packet VPP sent to this peer.

    It must come from the WireGuard interface's IP and port, be addressed to
    this peer's endpoint and port, and carry valid checksums.
    """
    check = self._test.assertEqual
    check(p[IP].src, self.itf.src)
    check(p[IP].dst, self.endpoint)
    check(p[UDP].sport, self.itf.port)
    check(p[UDP].dport, self.port)
    self._test.assert_packet_checksums_valid(p)
# Act as responder to VPP's handshake initiation in ``p``. Steps:
# - check the outer header, message type and mac1;
# - feed the Noise fields to this peer's Noise connection;
# - build the handshake response (message type 2), wrapped in the tunnel
#   header for ``tx_itf``.
# The return statement is outside this excerpt.
255 def consume_init(self, p, tx_itf):
256 self.noise.set_as_responder()
257 self.noise_init(self.itf.public_key)
258 self.verify_header(p)
260 init = Wireguard(p[Raw])
262 self._test.assertEqual(init[Wireguard].message_type, 1)
263 self._test.assertEqual(init[Wireguard].reserved_zero, 0)
# Remember VPP's sender index; it becomes receiver_index in the response.
265 self.sender = init[WireguardInitiation].sender_index
# mac1 on a message sent to this peer is keyed with this peer's public key
# and covers everything except the trailing 32 bytes (mac1 + mac2).
268 mac_key = blake2s(b'mac1----' +
269 public_key_bytes(self.public_key)).digest()
270 mac1 = blake2s(bytes(init)[0:-32],
272 key=mac_key).digest()
273 self._test.assertEqual(init[WireguardInitiation].mac1, mac1)
275 # this passes only unencrypted_ephemeral, encrypted_static,
276 # encrypted_timestamp fields of the init
277 payload = self.noise.read_message(bytes(init)[8:-32])
# Response mac1 is keyed with the receiver's (the VPP interface's) key.
280 b = self.noise.write_message()
281 mac_key = blake2s(b'mac1----' +
282 public_key_bytes(self.itf.public_key)).digest()
283 resp = (Wireguard(message_type=2, reserved_zero=0) /
284 WireguardResponse(sender_index=self.receiver_index,
285 receiver_index=self.sender,
286 unencrypted_ephemeral=b[0:32],
287 encrypted_nothing=b[32:]))
288 mac1 = blake2s(bytes(resp)[:-32],
290 key=mac_key).digest()
291 resp[WireguardResponse].mac1 = mac1
293 resp = (self.mk_tunnel_header(tx_itf) / resp)
294 self._test.assertTrue(self.noise.handshake_finished)
# Verify VPP's handshake response (message type 2) to this peer's initiation
# and feed it to Noise. The handshake must then be finished with an empty
# payload.
298 def consume_response(self, p):
299 self.verify_header(p)
301 resp = Wireguard(p[Raw])
303 self._test.assertEqual(resp[Wireguard].message_type, 2)
304 self._test.assertEqual(resp[Wireguard].reserved_zero, 0)
305 self._test.assertEqual(resp[WireguardResponse].receiver_index,
# Remember VPP's index so later transport packets can address it.
308 self.sender = resp[Wireguard].sender_index
# Bytes 12:60 presumably cover the ephemeral key (32) and the encrypted
# empty payload (16) after the 12-byte fixed header — confirm.
310 payload = self.noise.read_message(bytes(resp)[12:60])
311 self._test.assertEqual(payload, b'')
312 self._test.assertTrue(self.noise.handshake_finished)
# Check that ``p`` is a WireGuard transport message (type 4) from VPP, then
# decrypt its encapsulated packet with this peer's Noise session.
# The return statement is outside this excerpt; callers use the result as
# bytes.
314 def decrypt_transport(self, p):
315 self.verify_header(p)
317 p = Wireguard(p[Raw])
318 self._test.assertEqual(p[Wireguard].message_type, 4)
319 self._test.assertEqual(p[Wireguard].reserved_zero, 0)
320 self._test.assertEqual(p[WireguardTransport].receiver_index,
323 d = self.noise.decrypt(
324 p[WireguardTransport].encrypted_encapsulated_packet)
def encrypt_transport(self, p):
    """Serialize ``p`` and encrypt it with this peer's Noise session."""
    encrypt = self.noise.encrypt
    return encrypt(bytes(p))
# VPP WireGuard test case: interface creation, handshake hashing, handshakes
# in both directions, and multi-peer setup.
331 class TestWg(VppTestCase):
332 """ Wireguard Test Case """
# Regex matching "Error"; not used in the visible part of this class.
334 error_str = compile(r"Error")
# Fragment of setUpClass (its def line is not visible): set up the base
# class, then create three packet-generator interfaces.
338 super(TestWg, cls).setUpClass()
340 cls.create_pg_interfaces(range(3))
341 for i in cls.pg_interfaces:
# Presumably an error path of setUpClass that tears the base class down
# again (the surrounding lines are not visible) — confirm.
347 super(TestWg, cls).tearDownClass()
# Class teardown simply defers to the base class.
351 def tearDownClass(cls):
352 super(TestWg, cls).tearDownClass()
# Create a WireGuard interface, log the interface table, then delete it.
354 def test_wg_interface(self):
355 """ Simple interface creation """
# NOTE(review): wg0 is bound to add_vpp_config()'s return value, so this
# relies on add_vpp_config returning the interface object (its return
# statement is not visible here).
359 wg0 = VppWgInterface(self,
361 port).add_vpp_config()
363 self.logger.info(self.vapi.cli("sh int"))
366 wg0.remove_vpp_config()
# Check the mac1 computation against a handshake initiation captured from
# Linux WireGuard for the public key below.
368 def test_handshake_hash(self):
369 """ test hashing an init message """
370 # an init packet generated by linux given the key below
371 h = "0100000098b9032b" \
391 b = bytearray.fromhex(h)
# The receiver's static public key, base64-encoded; the assert checks that
# a raw round-trip through X25519PublicKey is lossless.
394 pubb = base64.b64decode("aRuHFTTxICIQNefp05oKWlJv3zgKxb8+WW7JJMh0jyM=")
395 pub = X25519PublicKey.from_public_bytes(pubb)
397 self.assertEqual(pubb, public_key_bytes(pub))
399 # strip the macs and build a new packet
# mac1 key = BLAKE2s('mac1----' || receiver public key).
401 mac_key = blake2s(b'mac1----' + public_key_bytes(pub)).digest()
402 init += blake2s(init,
404 key=mac_key).digest()
# Parse the rebuilt bytes and compare with the expected packet (init and
# tgt are built in lines not visible in this excerpt).
407 act = Wireguard(init)
409 self.assertEqual(tgt, act)
# VPP initiates: it sends a handshake to a newly added peer. The simulated
# peer responds, and the test checks a keepalive plus data in both
# directions.
411 def test_wg_peer_resp(self):
412 """ Send handshake response """
# NOTE(review): these error-counter prefixes are not used in the visible
# part of this test.
413 wg_output_node_name = '/err/wg-output-tun/'
414 wg_input_node_name = '/err/wg-input/'
419 wg0 = VppWgInterface(self,
421 port).add_vpp_config()
425 self.pg_enable_capture(self.pg_interfaces)
428 peer_1 = VppWgPeer(self,
433 "10.11.3.0/24"]).add_vpp_config()
434 self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1)
436 # wait for VPP to send a handshake to the peer
437 rx = self.pg1.get_capture(1, timeout=2)
439 # consume the handshake in the noise protocol and
440 # generate the response
441 resp = peer_1.consume_init(rx[0], self.pg1)
443 # send the response, get keepalive (an empty transport payload)
444 rxs = self.send_and_expect(self.pg1, [resp], self.pg1)
447 b = peer_1.decrypt_transport(rx)
448 self.assertEqual(0, len(b))
450 # send packets that are routed into the tunnel
451 p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
452 IP(src=self.pg0.remote_ip4, dst="10.11.3.2") /
453 UDP(sport=555, dport=556) /
456 rxs = self.send_and_expect(self.pg0, p * 255, self.pg1)
459 rx = IP(peer_1.decrypt_transport(rx))
460 # check the original packet is present, with TTL decremented by one
461 self.assertEqual(rx[IP].dst, p[IP].dst)
462 self.assertEqual(rx[IP].ttl, p[IP].ttl-1)
464 # send packets into the tunnel, expect to receive them on
# pg0, decapsulated and with TTL decremented from 20 to 19.
466 p = [(peer_1.mk_tunnel_header(self.pg1) /
467 Wireguard(message_type=4, reserved_zero=0) /
469 receiver_index=peer_1.sender,
471 encrypted_encapsulated_packet=peer_1.encrypt_transport(
472 (IP(src="10.11.3.1", dst=self.pg0.remote_ip4, ttl=20) /
473 UDP(sport=222, dport=223) /
474 Raw())))) for ii in range(255)]
476 rxs = self.send_and_expect(self.pg1, p, self.pg0)
479 self.assertEqual(rx[IP].dst, self.pg0.remote_ip4)
480 self.assertEqual(rx[IP].ttl, 19)
# The simulated peer initiates. Before any valid handshake, traffic is
# dropped. Handshakes with a bad MAC or a wrong key are rejected with
# specific error counters. A valid handshake followed by a first data packet
# from the peer enables traffic in both directions.
482 def test_wg_peer_init(self):
483 """ Send handshake init """
484 wg_output_node_name = '/err/wg-output-tun/'
485 wg_input_node_name = '/err/wg-input/'
490 wg0 = VppWgInterface(self,
492 port).add_vpp_config()
496 peer_1 = VppWgPeer(self,
501 "10.11.3.0/24"]).add_vpp_config()
502 self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1)
504 # route a packet into the wg interface
505 # use the allowed-ip prefix
506 # this is dropped because the peer is not initiated
507 p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
508 IP(src=self.pg0.remote_ip4, dst="10.11.3.2") /
509 UDP(sport=555, dport=556) /
511 self.send_and_assert_no_replies(self.pg0, [p])
513 kp_error = wg_output_node_name + "Keypair error"
514 self.assertEqual(1, self.statistics.get_err_counter(kp_error))
516 # send a handshake from the peer with an invalid MAC
517 p = peer_1.mk_handshake(self.pg1)
518 p[WireguardInitiation].mac1 = b'foobar'
519 self.send_and_assert_no_replies(self.pg1, [p])
520 self.assertEqual(1, self.statistics.get_err_counter(
521 wg_input_node_name + "Invalid MAC handshake"))
523 # send a handshake from the peer whose Noise handshake targets a random
# public key instead of the interface's (mac1 is still keyed with the
# interface's key, so only the Noise decryption fails).
524 p = peer_1.mk_handshake(self.pg1,
525 X25519PrivateKey.generate().public_key())
526 self.send_and_assert_no_replies(self.pg1, [p])
527 self.assertEqual(1, self.statistics.get_err_counter(
528 wg_input_node_name + "Peer error"))
530 # send a valid handshake init for which we expect a response
531 p = peer_1.mk_handshake(self.pg1)
533 rx = self.send_and_expect(self.pg1, [p], self.pg1)
535 peer_1.consume_response(rx[0])
537 # route a packet into the wg interface
538 # this is dropped because the peer is still not initiated
539 p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
540 IP(src=self.pg0.remote_ip4, dst="10.11.3.2") /
541 UDP(sport=555, dport=556) /
543 self.send_and_assert_no_replies(self.pg0, [p])
544 self.assertEqual(2, self.statistics.get_err_counter(kp_error))
546 # send a data packet from the peer through the tunnel
547 # this completes the handshake
548 p = (IP(src="10.11.3.1", dst=self.pg0.remote_ip4, ttl=20) /
549 UDP(sport=222, dport=223) /
551 d = peer_1.encrypt_transport(p)
552 p = (peer_1.mk_tunnel_header(self.pg1) /
553 (Wireguard(message_type=4, reserved_zero=0) /
554 WireguardTransport(receiver_index=peer_1.sender,
556 encrypted_encapsulated_packet=d)))
557 rxs = self.send_and_expect(self.pg1, [p], self.pg0)
560 self.assertEqual(rx[IP].dst, self.pg0.remote_ip4)
561 self.assertEqual(rx[IP].ttl, 19)
563 # send packets that are routed into the tunnel
564 p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
565 IP(src=self.pg0.remote_ip4, dst="10.11.3.2") /
566 UDP(sport=555, dport=556) /
569 rxs = self.send_and_expect(self.pg0, p * 255, self.pg1)
572 rx = IP(peer_1.decrypt_transport(rx))
574 # check the original packet is present, with TTL decremented by one
575 self.assertEqual(rx[IP].dst, p[IP].dst)
576 self.assertEqual(rx[IP].ttl, p[IP].ttl-1)
578 # send packets into the tunnel, expect to receive them on
# pg0, decapsulated and with TTL decremented from 20 to 19.
580 p = [(peer_1.mk_tunnel_header(self.pg1) /
581 Wireguard(message_type=4, reserved_zero=0) /
583 receiver_index=peer_1.sender,
585 encrypted_encapsulated_packet=peer_1.encrypt_transport(
586 (IP(src="10.11.3.1", dst=self.pg0.remote_ip4, ttl=20) /
587 UDP(sport=222, dport=223) /
588 Raw())))) for ii in range(255)]
590 rxs = self.send_and_expect(self.pg1, p, self.pg0)
593 self.assertEqual(rx[IP].dst, self.pg0.remote_ip4)
594 self.assertEqual(rx[IP].ttl, 19)
596 peer_1.remove_vpp_config()
597 wg0.remove_vpp_config()
# Create two WireGuard interfaces and NUM_PEERS peers for each. Check the
# peer count and each peer's configuration, then remove everything.
599 def test_wg_multi_peer(self):
600 """ multiple peer setup """
604 wg0 = VppWgInterface(self,
606 port).add_vpp_config()
607 wg1 = VppWgInterface(self,
609 port+1).add_vpp_config()
614 self.assertEqual(len(self.vapi.wireguard_peers_dump()), 0)
616 self.pg_enable_capture(self.pg_interfaces)
619 # Create many peers (endpoints are remote hosts on pg1 and pg2)
621 self.pg2.generate_remote_hosts(NUM_PEERS)
622 self.pg2.configure_ipv4_neighbors()
623 self.pg1.generate_remote_hosts(NUM_PEERS)
624 self.pg1.configure_ipv4_neighbors()
# NOTE(review): the allowed-IP octet is the loop index, so this assumes
# NUM_PEERS <= 256 (NUM_PEERS is defined outside this excerpt).
628 for i in range(NUM_PEERS):
629 peers_1.append(VppWgPeer(self,
631 self.pg1.remote_hosts[i].ip4,
633 ["10.0.%d.4/32" % i]).add_vpp_config())
634 peers_2.append(VppWgPeer(self,
636 self.pg2.remote_hosts[i].ip4,
638 ["10.100.%d.4/32" % i]).add_vpp_config())
640 self.assertEqual(len(self.vapi.wireguard_peers_dump()), NUM_PEERS*2)
# Debug output only: these CLI results are logged, not asserted.
# NOTE(review): "adj 37" and the FIB prefixes look tied to one particular
# run — confirm they are still meaningful.
642 self.logger.info(self.vapi.cli("show wireguard peer"))
643 self.logger.info(self.vapi.cli("show wireguard interface"))
644 self.logger.info(self.vapi.cli("show adj 37"))
645 self.logger.info(self.vapi.cli("sh ip fib 172.16.3.17"))
646 self.logger.info(self.vapi.cli("sh ip fib 10.11.3.0"))
# Presumably loops over peers_1 and then peers_2 (the loop headers are not
# visible): each peer must match VPP's dump before it is removed.
650 self.assertTrue(p.query_vpp_config())
651 p.remove_vpp_config()
653 self.assertTrue(p.query_vpp_config())
654 p.remove_vpp_config()
656 wg0.remove_vpp_config()
657 wg1.remove_vpp_config()