Improve NetworkIncrement
[csit.git] / resources / libraries / python / IPUtil.py
index fdd7c66..8a8027f 100644 (file)
@@ -1,4 +1,5 @@
 # Copyright (c) 2021 Cisco and/or its affiliates.
+# Copyright (c) 2021 PANTHEON.tech s.r.o.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at:
@@ -16,9 +17,10 @@ import re
 
 from enum import IntEnum
 
-from ipaddress import ip_address
+from ipaddress import ip_address, ip_network
 
 from resources.libraries.python.Constants import Constants
+from resources.libraries.python.IncrementUtil import ObjIncrement
 from resources.libraries.python.InterfaceUtil import InterfaceUtil
 from resources.libraries.python.IPAddress import IPAddress
 from resources.libraries.python.PapiExecutor import PapiSocketExecutor
@@ -89,6 +91,64 @@ class IpDscp(IntEnum):
     IP_API_DSCP_CS7 = 50
 
 
+class NetworkIncrement(ObjIncrement):
+    """
+    An iterator object which accepts an IPv4Network or IPv6Network and
+    returns a new network, its address part incremented by the increment
+    number of network sizes, each time it is iterated or when inc_fmt is called.
+    The increment may be positive, negative or 0
+    (in which case the network is always the same).
+
+    Both initial and subsequent IP address can have host bits set,
+    check the initial value before creating instance if needed.
+    String formatting is configurable via constructor argument.
+    """
+    def __init__(self, initial_value, increment=1, format=u"dash"):
+        """
+        :param initial_value: The initial network. Can have host bits set.
+        :param increment: The current network will be incremented by this
+            amount of network sizes in each iteration/var_str call.
+        :param format: Type of formatting to use, currently only "dash".
+        :type initial_value: Union[ipaddress.IPv4Network, ipaddress.IPv6Network]
+        :type increment: int
+        :type format: str
+        """
+        super().__init__(initial_value, increment)
+        self._prefix_len = self._value.prefixlen
+        host_len = self._value.max_prefixlen - self._prefix_len
+        self._net_increment = self._increment * (1 << host_len)
+        self._format = str(format).lower()
+
+    def _incr(self):
+        """
+        Increment the network, e.g.:
+        '30.0.0.0/24' incremented by 1 (the next network) is '30.0.1.0/24'.
+        '30.0.0.0/24' incremented by 2 is '30.0.2.0/24'.
+        """
+        self._value = ip_network(
+            f"{self._value.network_address + self._net_increment}"
+            f"/{self._prefix_len}", strict=False
+        )
+
+    def _str_fmt(self):
+        """
+        The string representation of the network depend on format.
+        Dash format is '<ip_address_start> - <ip_address_stop>',
+        useful for 'ipsec policy add spd' cli.
+        Slash format is '<ip_address_start>/<prefix_length>'.
+
+        :returns: Current value converted to string according to format.
+        :rtype: str
+        :raises RuntimeError: If the format is not supported.
+        """
+        if self._format == u"dash":
+            return f"{self._value.network_address} - " \
+                   f"{self._value.broadcast_address}"
+        # More formats will be added in subsequent changes.
+        else:
+            raise RuntimeError(f"Unsupported format {self._format}")
+
+
 class IPUtil:
     """Common IP utilities"""