+ def set_responder(self):
+ self.noise.set_as_responder()
+
+ def mk_tunnel_header(self, tx_itf):
+ return (Ether(dst=tx_itf.local_mac, src=tx_itf.remote_mac) /
+ IP(src=self.endpoint, dst=self.itf.src) /
+ UDP(sport=self.port, dport=self.itf.port))
+
+ def noise_init(self, public_key=None):
+ self.noise.set_prologue(NOISE_IDENTIFIER_NAME)
+ self.noise.set_psks(psk=bytes(bytearray(32)))
+
+ if not public_key:
+ public_key = self.itf.public_key
+
+ # local/this private
+ self.noise.set_keypair_from_private_bytes(
+ Keypair.STATIC,
+ private_key_bytes(self.private_key))
+ # remote's public
+ self.noise.set_keypair_from_public_bytes(
+ Keypair.REMOTE_STATIC,
+ public_key_bytes(public_key))
+
+ self.noise.start_handshake()
+
+ def mk_handshake(self, tx_itf, public_key=None):
+ self.noise.set_as_initiator()
+ self.noise_init(public_key)
+
+ p = (Wireguard() / WireguardInitiation())
+
+ p[Wireguard].message_type = 1
+ p[Wireguard].reserved_zero = 0
+ p[WireguardInitiation].sender_index = self.receiver_index
+
+ # some random data for the message
+ # lifted from the noise protocol's wireguard example
+ now = datetime.datetime.now()
+ tai = struct.pack('!qi', 4611686018427387914 + int(now.timestamp()),
+ int(now.microsecond * 1e3))
+ b = self.noise.write_message(payload=tai)
+
+ # load noise into init message
+ p[WireguardInitiation].unencrypted_ephemeral = b[0:32]
+ p[WireguardInitiation].encrypted_static = b[32:80]
+ p[WireguardInitiation].encrypted_timestamp = b[80:108]
+
+ # generate the mac1 hash
+ mac_key = blake2s(b'mac1----' +
+ self.itf.public_key_bytes()).digest()
+ p[WireguardInitiation].mac1 = blake2s(bytes(p)[0:116],
+ digest_size=16,
+ key=mac_key).digest()
+ p[WireguardInitiation].mac2 = bytearray(16)
+
+ p = (self.mk_tunnel_header(tx_itf) / p)
+
+ return p
+
+ def verify_header(self, p):
+ self._test.assertEqual(p[IP].src, self.itf.src)
+ self._test.assertEqual(p[IP].dst, self.endpoint)
+ self._test.assertEqual(p[UDP].sport, self.itf.port)
+ self._test.assertEqual(p[UDP].dport, self.port)
+ self._test.assert_packet_checksums_valid(p)
+
+ def consume_init(self, p, tx_itf):
+ self.noise.set_as_responder()
+ self.noise_init(self.itf.public_key)
+ self.verify_header(p)
+
+ init = Wireguard(p[Raw])
+
+ self._test.assertEqual(init[Wireguard].message_type, 1)
+ self._test.assertEqual(init[Wireguard].reserved_zero, 0)
+
+ self.sender = init[WireguardInitiation].sender_index
+
+ # validate the hash
+ mac_key = blake2s(b'mac1----' +
+ public_key_bytes(self.public_key)).digest()
+ mac1 = blake2s(bytes(init)[0:-32],
+ digest_size=16,
+ key=mac_key).digest()
+ self._test.assertEqual(init[WireguardInitiation].mac1, mac1)
+
+ # this passes only unencrypted_ephemeral, encrypted_static,
+ # encrypted_timestamp fields of the init
+ payload = self.noise.read_message(bytes(init)[8:-32])
+
+ # build the response
+ b = self.noise.write_message()
+ mac_key = blake2s(b'mac1----' +
+ public_key_bytes(self.itf.public_key)).digest()
+ resp = (Wireguard(message_type=2, reserved_zero=0) /
+ WireguardResponse(sender_index=self.receiver_index,
+ receiver_index=self.sender,
+ unencrypted_ephemeral=b[0:32],
+ encrypted_nothing=b[32:]))
+ mac1 = blake2s(bytes(resp)[:-32],
+ digest_size=16,
+ key=mac_key).digest()
+ resp[WireguardResponse].mac1 = mac1
+
+ resp = (self.mk_tunnel_header(tx_itf) / resp)
+ self._test.assertTrue(self.noise.handshake_finished)
+
+ return resp
+
+ def consume_response(self, p):
+ self.verify_header(p)
+
+ resp = Wireguard(p[Raw])
+
+ self._test.assertEqual(resp[Wireguard].message_type, 2)
+ self._test.assertEqual(resp[Wireguard].reserved_zero, 0)
+ self._test.assertEqual(resp[WireguardResponse].receiver_index,
+ self.receiver_index)
+
+ self.sender = resp[Wireguard].sender_index
+
+ payload = self.noise.read_message(bytes(resp)[12:60])
+ self._test.assertEqual(payload, b'')
+ self._test.assertTrue(self.noise.handshake_finished)
+
+ def decrypt_transport(self, p):
+ self.verify_header(p)
+
+ p = Wireguard(p[Raw])
+ self._test.assertEqual(p[Wireguard].message_type, 4)
+ self._test.assertEqual(p[Wireguard].reserved_zero, 0)
+ self._test.assertEqual(p[WireguardTransport].receiver_index,
+ self.receiver_index)
+
+ d = self.noise.decrypt(
+ p[WireguardTransport].encrypted_encapsulated_packet)
+ return d
+
+ def encrypt_transport(self, p):
+ return self.noise.encrypt(bytes(p))
+