IGMP improvements
[vpp.git] / test / vpp_ip_route.py
index 17a42fe..9e5c531 100644 (file)
@@ -60,6 +60,25 @@ def find_route(test, ip_addr, len, table_id=0, inet=AF_INET):
     return False
 
 
+def find_mroute(test, grp_addr, src_addr, grp_addr_len,
+                table_id=0, inet=AF_INET):
+    if inet == AF_INET:
+        s = 4
+        routes = test.vapi.ip_mfib_dump()
+    else:
+        s = 16
+        routes = test.vapi.ip6_mfib_dump()
+    gaddr = inet_pton(inet, grp_addr)
+    saddr = inet_pton(inet, src_addr)
+    for e in routes:
+        if gaddr == e.grp_address[:s] \
+           and grp_addr_len == e.address_length \
+           and saddr == e.src_address[:s] \
+           and table_id == e.table_id:
+            return True
+    return False
+
+
 class VppIpTable(VppObject):
 
     def __init__(self,
@@ -324,6 +343,8 @@ class VppIpMRoute(VppObject):
         self.is_ip6 = is_ip6
         self.rpf_id = rpf_id
 
+        self.grp_addr_p = grp_addr
+        self.src_addr_p = src_addr
         if is_ip6:
             self.grp_addr = inet_pton(AF_INET6, grp_addr)
             self.src_addr = inet_pton(AF_INET6, src_addr)
@@ -406,17 +427,12 @@ class VppIpMRoute(VppObject):
                                           is_ipv6=self.is_ip6)
 
     def query_vpp_config(self):
-        if self.is_ip6:
-            dump = self._test.vapi.ip6_mfib_dump()
-        else:
-            dump = self._test.vapi.ip_mfib_dump()
-        for e in dump:
-            if self.grp_addr == e.grp_address \
-               and self.grp_addr_len == e.address_length \
-               and self.src_addr == e.src_address \
-               and self.table_id == e.table_id:
-                return True
-        return False
+        return find_mroute(self._test,
+                           self.grp_addr_p,
+                           self.src_addr_p,
+                           self.grp_addr_len,
+                           self.table_id,
+                           inet=AF_INET6 if self.is_ip6 == 1 else AF_INET)
 
     def __str__(self):
         return self.object_id()