MAP: Convert from DPO to input feature.
[vpp.git] / test / test_ipsec_nat.py
1 #!/usr/bin/env python
2
3 import socket
4
5 from scapy.layers.l2 import Ether
6 from scapy.layers.inet import ICMP, IP, TCP, UDP
7 from scapy.layers.ipsec import SecurityAssociation, ESP
8 from util import ppp, ppc
9 from template_ipsec import TemplateIpsec
10
11
class IPSecNATTestCase(TemplateIpsec):
    """ IPSec/NAT
    TUNNEL MODE:


     public network  |   private network
     ---   encrypt  ---   plain   ---
    |pg0| <------- |VPP| <------ |pg1|
     ---            ---           ---

     ---   decrypt  ---   plain   ---
    |pg0| -------> |VPP| ------> |pg1|
     ---            ---           ---

    ESP traffic on the public side is UDP-encapsulated (NAT traversal):
    VPP's SAs are added with udp_encap=1 and the scapy SAs add a UDP
    header using ``nat_t_port`` as both source and destination port.
    """

    # Ports / ICMP id used by the test streams.  The plain (in2out) stream
    # uses the *_in values as source; the encrypted (out2in) stream uses the
    # *_out values as destination, and verify_capture_plain expects the *_in
    # values after decryption.  They are equal, so no port translation is
    # expected.
    tcp_port_in = 6303
    tcp_port_out = 6303
    udp_port_in = 6304
    udp_port_out = 6304
    icmp_id_in = 6305
    icmp_id_out = 6305

    # UDP port used for ESP-in-UDP encapsulation (NAT-T): matched by the SPD
    # bypass entries and used in the scapy SAs' nat_t_header.
    nat_t_port = 4500

    @classmethod
    def setUpClass(cls):
        """Attach an SPD to pg0, program the NAT-T tunnel and add a route.

        tun_spd_id, ipv4_params, vapi, logger and the pg interfaces are
        presumably provided by TemplateIpsec.setUpClass (not visible here).
        """
        super(IPSecNATTestCase, cls).setUpClass()
        cls.tun_if = cls.pg0
        cls.vapi.ipsec_spd_add_del(cls.tun_spd_id)
        cls.vapi.ipsec_interface_add_del_spd(cls.tun_spd_id,
                                             cls.tun_if.sw_if_index)
        p = cls.ipv4_params
        cls.config_esp_tun(p)
        cls.logger.info(cls.vapi.ppcli("show ipsec"))
        # Route for p.remote_tun_if_host/p.addr_len; tun_if's remote address
        # is presumably the next hop -- confirm against ip_add_del_route.
        src = socket.inet_pton(p.addr_type, p.remote_tun_if_host)
        cls.vapi.ip_add_del_route(src, p.addr_len,
                                  cls.tun_if.remote_addr_n[p.addr_type],
                                  is_ipv6=p.is_ipv6)

    def create_stream_plain(self, src_mac, dst_mac, src_ip, dst_ip):
        """Build one plain TCP, UDP and ICMP echo-request packet (in order).

        Each packet is Ether/IP/<L4> and uses the *_in port or ICMP id as
        its source identifier; the test injects them on the private side.

        :returns: list of three scapy packets.
        """
        l4_headers = [
            TCP(sport=self.tcp_port_in, dport=20),
            UDP(sport=self.udp_port_in, dport=20),
            ICMP(id=self.icmp_id_in, type='echo-request'),
        ]
        return [Ether(src=src_mac, dst=dst_mac) /
                IP(src=src_ip, dst=dst_ip) /
                l4
                for l4 in l4_headers]

    def create_stream_encrypted(self, src_mac, dst_mac, src_ip, dst_ip, sa):
        """Build one ESP-encrypted TCP, UDP and ICMP echo-request packet.

        Each inner IP/<L4> packet is encrypted with ``sa`` (in the order TCP,
        UDP, ICMP, so SA sequence numbers are assigned in that order) and
        prefixed with an Ether header.  The *_out port or ICMP id is used as
        the destination identifier.

        :param sa: scapy SecurityAssociation used to encrypt.
        :returns: list of three scapy packets.
        """
        l4_headers = [
            TCP(dport=self.tcp_port_out, sport=20),
            UDP(dport=self.udp_port_out, sport=20),
            ICMP(id=self.icmp_id_out, type='echo-request'),
        ]
        return [Ether(src=src_mac, dst=dst_mac) /
                sa.encrypt(IP(src=src_ip, dst=dst_ip) / l4)
                for l4 in l4_headers]

    def verify_capture_plain(self, capture):
        """Check packets VPP decrypted and forwarded to the private side.

        Every packet must have valid checksums, go from tun_if's remote
        address to pg1's remote address, carry no leftover UDP (NAT-T)
        header, and keep the *_in TCP/UDP destination port or ICMP id.
        Failures are logged with the packet dump and re-raised.
        """
        for packet in capture:
            try:
                self.assert_packet_checksums_valid(packet)
                self.assert_equal(packet[IP].src, self.tun_if.remote_ip4,
                                  "decrypted packet source address")
                self.assert_equal(packet[IP].dst, self.pg1.remote_ip4,
                                  "decrypted packet destination address")
                if packet.haslayer(TCP):
                    self.assertFalse(
                        packet.haslayer(UDP),
                        "unexpected UDP header in decrypted packet")
                    self.assert_equal(packet[TCP].dport, self.tcp_port_in,
                                      "decrypted packet TCP destination port")
                elif packet.haslayer(UDP):
                    # The inner UDP header itself is expected; only a second
                    # UDP header below it would mean NAT-T was not stripped.
                    payload = packet[UDP].payload
                    if payload:
                        self.assertFalse(
                            payload.haslayer(UDP),
                            "unexpected UDP header in decrypted packet")
                    self.assert_equal(packet[UDP].dport, self.udp_port_in,
                                      "decrypted packet UDP destination port")
                else:
                    # No TCP/UDP here (the elif above already excluded UDP),
                    # so this must be the ICMP echo request.
                    self.assert_equal(packet[ICMP].id, self.icmp_id_in,
                                      "decrypted packet ICMP ID")
            except Exception:
                self.logger.error(
                    ppp("Unexpected or invalid plain packet:", packet))
                raise

    def verify_capture_encrypted(self, capture, sa):
        """Check packets VPP encrypted and sent towards the public side.

        For each packet: the UDP (NAT-T) length written by VPP must equal the
        value scapy recomputes, checksums must be valid, the IP payload must
        contain ESP, and after decryption with ``sa`` the inner packet must
        go from pg1's remote address to tun_if's remote address.  Failures
        are logged with the packet dump and re-raised.

        :param sa: scapy SecurityAssociation used to decrypt.
        """
        for packet in capture:
            try:
                # Re-dissect a copy, drop its UDP length and rebuild so scapy
                # recomputes the field.  bytes() (not str()) yields the raw
                # wire encoding on both Python 2 and Python 3.
                copy = packet.__class__(bytes(packet))
                del copy[UDP].len
                copy = packet.__class__(bytes(copy))
                self.assert_equal(packet[UDP].len, copy[UDP].len,
                                  "UDP header length")
                self.assert_packet_checksums_valid(packet)
                self.assertIn(ESP, packet[IP])
                decrypt_pkt = sa.decrypt(packet[IP])
                self.assert_packet_checksums_valid(decrypt_pkt)
                self.assert_equal(decrypt_pkt[IP].src, self.pg1.remote_ip4,
                                  "encrypted packet source address")
                self.assert_equal(decrypt_pkt[IP].dst, self.tun_if.remote_ip4,
                                  "encrypted packet destination address")
            except Exception:
                self.logger.error(
                    ppp("Unexpected or invalid encrypted packet:", packet))
                raise

    @classmethod
    def config_esp_tun(cls, params):
        """Program the UDP-encapsulated ESP tunnel into VPP.

        Adds two SAD entries, both with udp_encap=1:

        * scapy_tun_sa_id / scapy_tun_spi with tunnel endpoints
          pg1.remote -> tun_if.remote (the test decrypts VPP's outbound
          traffic with the matching scapy SA);
        * vpp_tun_sa_id / vpp_tun_spi with tunnel endpoints
          tun_if.remote -> pg1.remote (the test encrypts inbound traffic
          with the matching scapy SA).

        Then fills SPD ``cls.tun_spd_id`` with:

        * catch-all (addr_any..addr_bcast) entries for ESP and for UDP with
          remote port ``nat_t_port``, inbound and outbound, using the default
          ``policy`` of ipsec_spd_add_del_entry (presumably BYPASS so the
          encapsulated traffic is not itself matched -- confirm);
        * a priority=10, policy=3 pair for traffic between tun_if.remote and
          pg1.remote: inbound bound to the vpp SA, outbound to the scapy SA
          (policy 3 is presumably PROTECT -- confirm).

        :param params: address-family parameter object, e.g.
            ``cls.ipv4_params`` (presumably defined by TemplateIpsec).
        """
        addr_type = params.addr_type
        scapy_tun_sa_id = params.scapy_tun_sa_id
        scapy_tun_spi = params.scapy_tun_spi
        vpp_tun_sa_id = params.vpp_tun_sa_id
        vpp_tun_spi = params.vpp_tun_spi
        auth_algo_vpp_id = params.auth_algo_vpp_id
        auth_key = params.auth_key
        crypt_algo_vpp_id = params.crypt_algo_vpp_id
        crypt_key = params.crypt_key
        addr_any = params.addr_any
        addr_bcast = params.addr_bcast
        nat_t_port = cls.nat_t_port

        # SAD entries (both UDP-encapsulated).
        cls.vapi.ipsec_sad_add_del_entry(scapy_tun_sa_id, scapy_tun_spi,
                                         auth_algo_vpp_id, auth_key,
                                         crypt_algo_vpp_id, crypt_key,
                                         cls.vpp_esp_protocol,
                                         cls.pg1.remote_addr_n[addr_type],
                                         cls.tun_if.remote_addr_n[addr_type],
                                         udp_encap=1)
        cls.vapi.ipsec_sad_add_del_entry(vpp_tun_sa_id, vpp_tun_spi,
                                         auth_algo_vpp_id, auth_key,
                                         crypt_algo_vpp_id, crypt_key,
                                         cls.vpp_esp_protocol,
                                         cls.tun_if.remote_addr_n[addr_type],
                                         cls.pg1.remote_addr_n[addr_type],
                                         udp_encap=1)

        # Catch-all SPD entries for ESP and NAT-T UDP, both directions.
        l_startaddr = r_startaddr = socket.inet_pton(addr_type, addr_any)
        l_stopaddr = r_stopaddr = socket.inet_pton(addr_type, addr_bcast)
        cls.vapi.ipsec_spd_add_del_entry(cls.tun_spd_id, scapy_tun_sa_id,
                                         l_startaddr, l_stopaddr, r_startaddr,
                                         r_stopaddr,
                                         protocol=socket.IPPROTO_ESP)
        cls.vapi.ipsec_spd_add_del_entry(cls.tun_spd_id, scapy_tun_sa_id,
                                         l_startaddr, l_stopaddr, r_startaddr,
                                         r_stopaddr, is_outbound=0,
                                         protocol=socket.IPPROTO_ESP)
        cls.vapi.ipsec_spd_add_del_entry(cls.tun_spd_id, scapy_tun_sa_id,
                                         l_startaddr, l_stopaddr, r_startaddr,
                                         r_stopaddr,
                                         remote_port_start=nat_t_port,
                                         remote_port_stop=nat_t_port,
                                         protocol=socket.IPPROTO_UDP)
        cls.vapi.ipsec_spd_add_del_entry(cls.tun_spd_id, scapy_tun_sa_id,
                                         l_startaddr, l_stopaddr, r_startaddr,
                                         r_stopaddr,
                                         remote_port_start=nat_t_port,
                                         remote_port_stop=nat_t_port,
                                         protocol=socket.IPPROTO_UDP,
                                         is_outbound=0)

        # Host-to-host entries for the inner traffic (policy=3).
        l_startaddr = l_stopaddr = cls.tun_if.remote_addr_n[addr_type]
        r_startaddr = r_stopaddr = cls.pg1.remote_addr_n[addr_type]
        cls.vapi.ipsec_spd_add_del_entry(cls.tun_spd_id, vpp_tun_sa_id,
                                         l_startaddr, l_stopaddr, r_startaddr,
                                         r_stopaddr, priority=10, policy=3,
                                         is_outbound=0)
        # Outbound: local/remote swapped relative to the inbound entry.
        cls.vapi.ipsec_spd_add_del_entry(cls.tun_spd_id, scapy_tun_sa_id,
                                         r_startaddr, r_stopaddr, l_startaddr,
                                         l_stopaddr, priority=10, policy=3)

    def test_ipsec_nat_tun(self):
        """ IPSec/NAT tunnel test case """
        p = self.ipv4_params
        # Mirrors VPP's outbound SA: used to decrypt what VPP sends on tun_if.
        scapy_tun_sa = SecurityAssociation(ESP, spi=p.scapy_tun_spi,
                                           crypt_algo=p.crypt_algo,
                                           crypt_key=p.crypt_key,
                                           auth_algo=p.auth_algo,
                                           auth_key=p.auth_key,
                                           tunnel_header=IP(
                                               src=self.pg1.remote_ip4,
                                               dst=self.tun_if.remote_ip4),
                                           nat_t_header=UDP(
                                               sport=self.nat_t_port,
                                               dport=self.nat_t_port))
        # in2out - from private network to public
        pkts = self.create_stream_plain(
            self.pg1.remote_mac, self.pg1.local_mac,
            self.pg1.remote_ip4, self.tun_if.remote_ip4)
        self.pg1.add_stream(pkts)
        self.pg_enable_capture(self.pg_interfaces)
        self.pg_start()
        capture = self.tun_if.get_capture(len(pkts))
        self.verify_capture_encrypted(capture, scapy_tun_sa)

        # Mirrors VPP's inbound SA: used to encrypt what we send on tun_if.
        vpp_tun_sa = SecurityAssociation(ESP,
                                         spi=p.vpp_tun_spi,
                                         crypt_algo=p.crypt_algo,
                                         crypt_key=p.crypt_key,
                                         auth_algo=p.auth_algo,
                                         auth_key=p.auth_key,
                                         tunnel_header=IP(
                                             src=self.tun_if.remote_ip4,
                                             dst=self.pg1.remote_ip4),
                                         nat_t_header=UDP(
                                             sport=self.nat_t_port,
                                             dport=self.nat_t_port))

        # out2in - from public network to private
        pkts = self.create_stream_encrypted(
            self.tun_if.remote_mac, self.tun_if.local_mac,
            self.tun_if.remote_ip4, self.pg1.remote_ip4, vpp_tun_sa)
        self.logger.info(ppc("Sending packets:", pkts))
        self.tun_if.add_stream(pkts)
        self.pg_enable_capture(self.pg_interfaces)
        self.pg_start()
        capture = self.pg1.get_capture(len(pkts))
        self.verify_capture_plain(capture)