4 from scapy.packet import Packet
5 from scapy.packet import Raw
6 from scapy.layers.l2 import Ether
7 from scapy.layers.inet import IP, UDP
8 from scapy.contrib.wireguard import Wireguard, WireguardResponse, \
10 from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
11 from cryptography.hazmat.primitives.serialization import Encoding, \
12 PrivateFormat, PublicFormat, NoEncryption
14 from vpp_ipip_tun_interface import VppIpIpTunInterface
15 from vpp_interface import VppInterface
16 from vpp_object import VppObject
17 from framework import VppTestCase
18 from re import compile
""" TestWg is a subclass of the VppTestCase class.
28 class VppWgInterface(VppInterface):
30 VPP WireGuard interface
33 def __init__(self, test, src, port, key=None):
34 super(VppWgInterface, self).__init__(test)
44 def add_vpp_config(self):
45 r = self.test.vapi.wireguard_interface_create(interface={
46 'user_instance': 0xffffffff,
49 'private_key': self.key_bytes()
51 self.set_sw_if_index(r.sw_if_index)
52 self.test.registry.register(self, self.test.logger)
57 return self.key.private_bytes(Encoding.Raw,
def remove_vpp_config(self):
    """Delete this WireGuard interface from VPP.

    Issues the ``wireguard_interface_delete`` API call for the
    sw_if_index recorded on this object when it was created.
    """
    if_index = self._sw_if_index
    self.test.vapi.wireguard_interface_delete(sw_if_index=if_index)
67 def query_vpp_config(self):
68 ts = self.test.vapi.wireguard_interface_dump(sw_if_index=0xffffffff)
70 if t.interface.sw_if_index == self._sw_if_index and \
71 str(t.interface.src_ip) == self.src and \
72 t.interface.port == self.port and \
73 t.interface.private_key == self.key_bytes():
78 return self.object_id()
81 return "wireguard-%d" % self._sw_if_index
84 def find_route(test, prefix, table_id=0):
85 routes = test.vapi.ip_route_dump(table_id, False)
88 if table_id == e.route.table_id \
89 and str(e.route.prefix) == str(prefix):
94 class VppWgPeer(VppObject):
102 persistent_keepalive=15):
105 self.endpoint = endpoint
107 self.allowed_ips = allowed_ips
108 self.persistent_keepalive = persistent_keepalive
109 self.private_key = X25519PrivateKey.generate()
110 self.public_key = self.private_key.public_key()
111 self.hash = bytearray(16)
def validate_routing(self):
    """Assert that a route exists for every allowed-IP prefix of this peer.

    Each prefix is looked up with ``find_route`` (table 0, its default),
    and the result is checked with the owning test case's ``assertTrue``.
    """
    tc = self._test
    for prefix in self.allowed_ips:
        tc.assertTrue(find_route(tc, prefix))
def validate_no_routing(self):
    """Assert that no route remains for any allowed-IP prefix of this peer.

    Each prefix is looked up with ``find_route`` (table 0, its default),
    and the result is checked with the owning test case's ``assertFalse``.
    """
    tc = self._test
    for prefix in self.allowed_ips:
        tc.assertFalse(find_route(tc, prefix))
121 def add_vpp_config(self):
122 rv = self._test.vapi.wireguard_peer_add(
124 'public_key': self.public_key_bytes(),
126 'endpoint': self.endpoint,
127 'n_allowed_ips': len(self.allowed_ips),
128 'allowed_ips': self.allowed_ips,
129 'sw_if_index': self.itf.sw_if_index,
130 'persistent_keepalive': self.persistent_keepalive})
131 self.index = rv.peer_index
132 self._test.registry.register(self, self._test.logger)
133 self.validate_routing()
def remove_vpp_config(self):
    """Remove this peer from VPP, then check that its routes are gone.

    Calls ``wireguard_peer_remove`` with the peer index captured when the
    peer was added, followed by ``validate_no_routing``.
    """
    api = self._test.vapi
    api.wireguard_peer_remove(peer_index=self.index)
    # The removal must happen before the route check, or the check fails.
    self.validate_no_routing()
141 return ("wireguard-peer-%s" % self.index)
143 def public_key_bytes(self):
144 return self.public_key.public_bytes(Encoding.Raw,
147 def private_key_bytes(self):
148 return self.private_key.private_bytes(Encoding.Raw,
152 def query_vpp_config(self):
153 peers = self._test.vapi.wireguard_peers_dump()
156 if p.peer.public_key == self.public_key_bytes() and \
157 p.peer.port == self.port and \
158 str(p.peer.endpoint) == self.endpoint and \
159 p.peer.sw_if_index == self.itf.sw_if_index and \
160 len(self.allowed_ips) == p.peer.n_allowed_ips:
161 self.allowed_ips.sort()
162 p.peer.allowed_ips.sort()
164 for (a1, a2) in zip(self.allowed_ips, p.peer.allowed_ips):
165 if str(a1) != str(a2):
171 class TestWg(VppTestCase):
172 """ Wireguard Test Case """
174 error_str = compile(r"Error")
178 super(TestWg, cls).setUpClass()
180 cls.create_pg_interfaces(range(3))
181 for i in cls.pg_interfaces:
187 super(TestWg, cls).tearDownClass()
def tearDownClass(cls):
    """Class-level teardown: delegate entirely to the parent test case."""
    parent = super(TestWg, cls)
    parent.tearDownClass()
194 def test_wg_interface(self):
198 wg0 = VppWgInterface(self,
200 port).add_vpp_config()
202 self.logger.info(self.vapi.cli("sh int"))
205 wg0.remove_vpp_config()
207 def test_wg_peer(self):
208 wg_output_node_name = '/err/wg-output-tun/'
209 wg_input_node_name = '/err/wg-input/'
214 wg0 = VppWgInterface(self,
217 key=X25519PrivateKey.generate()).add_vpp_config()
218 wg1 = VppWgInterface(self,
220 port+1).add_vpp_config()
225 self.assertEqual(len(self.vapi.wireguard_peers_dump()), 0)
227 self.pg_enable_capture(self.pg_interfaces)
230 peer_1 = VppWgPeer(self,
235 "10.11.3.0/24"]).add_vpp_config()
236 self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1)
238 # wait for the peer to send a handshake
239 capture = self.pg1.get_capture(1, timeout=2)
240 handshake = capture[0]
242 self.assertEqual(handshake[IP].src, wg0.src)
243 self.assertEqual(handshake[IP].dst, peer_1.endpoint)
244 self.assertEqual(handshake[UDP].sport, wg0.port)
245 self.assertEqual(handshake[UDP].dport, peer_1.port)
246 handshake = Wireguard(handshake[Raw])
247 self.assertEqual(handshake.message_type, 1) # "initiate")
248 init = handshake[WireguardInitiation]
250 # route a packet into the wg interface
251 # use the allowed-ip prefix
252 p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
253 IP(src=self.pg0.remote_ip4, dst="10.11.3.2") /
254 UDP(sport=555, dport=556) /
256 # rx = self.send_and_expect(self.pg0, [p], self.pg1)
257 rx = self.send_and_assert_no_replies(self.pg0, [p])
259 self.logger.info(self.vapi.cli("sh error"))
260 init_sent = wg_output_node_name + "Keypair error"
261 self.assertEqual(1, self.statistics.get_err_counter(init_sent))
263 # Create many peers on sencond interface
265 self.pg2.generate_remote_hosts(NUM_PEERS)
266 self.pg2.configure_ipv4_neighbors()
269 for i in range(NUM_PEERS):
270 peers.append(VppWgPeer(self,
272 self.pg2.remote_hosts[i].ip4,
274 ["10.10.%d.4/32" % i]).add_vpp_config())
275 self.assertEqual(len(self.vapi.wireguard_peers_dump()), i+2)
277 self.logger.info(self.vapi.cli("show wireguard peer"))
278 self.logger.info(self.vapi.cli("show wireguard interface"))
279 self.logger.info(self.vapi.cli("show adj 37"))
280 self.logger.info(self.vapi.cli("sh ip fib 172.16.3.17"))
281 self.logger.info(self.vapi.cli("sh ip fib 10.11.3.0"))
285 self.assertTrue(p.query_vpp_config())
286 p.remove_vpp_config()
287 self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1)
288 peer_1.remove_vpp_config()
289 self.assertEqual(len(self.vapi.wireguard_peers_dump()), 0)
291 wg0.remove_vpp_config()
292 # wg1.remove_vpp_config()