feat(MLRsearch): MLRsearch v7
[csit.git] / resources / libraries / python / MLRsearch / discrete_width.py
diff --git a/resources/libraries/python/MLRsearch/discrete_width.py b/resources/libraries/python/MLRsearch/discrete_width.py
new file mode 100644 (file)
index 0000000..8a4845a
--- /dev/null
@@ -0,0 +1,197 @@
+# Copyright (c) 2023 Cisco and/or its affiliates.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module defining DiscreteWidth class."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+
+from .load_rounding import LoadRounding
+
+
+# TODO: Make properly frozen.
+@dataclass(order=True)
+class DiscreteWidth:
+    """Structure to store float width together with its rounded integer form.
+
+    The width does not have to be positive, i.e. the computed integer width
+    does not have to be larger than zero.
+
+    LoadRounding instance is needed to enable conversion between two forms.
+
+    Conversion and arithmetic methods are added for convenience.
+    Division and non-integer multiplication are intentionally not supported,
+    as MLRsearch should not seek unround widths when round ones are available.
+
+    The instance is effectively immutable, but not hashable as it refers
+    to the rounding instance, which is implemented as mutable
+    (although the mutations are not visible).
+    """
+
+    # For most debugs, rounding in repr just takes space.
+    rounding: LoadRounding = field(repr=False, compare=False)
+    """Rounding instance to use for conversion."""
+    float_width: float = None
+    """Relative width of float intended load.
+    This is treated as a constructor argument, and does not need to match
+    the int width. Int width is computed to be no wider than this."""
+    int_width: int = field(compare=False, default=None)
+    """Integer form, difference of integer loads.
+    This is the primary quantity used by most computations."""
+
+    def __post_init__(self) -> None:
+        """Ensure types, compute missing information.
+
+        At this point, it is allowed for float width to be slightly larger
+        than the implied int width.
+
+        If both forms are specified, the float form is taken as primary
+        (thus the integer form is recomputed to match).
+
+        :raises RuntimeError: If both init arguments are None.
+        """
+        if self.float_width is None and self.int_width is None:
+            raise RuntimeError("Float or int value is needed.")
+        if self.float_width is None:
+            self.int_width = int(self.int_width)
+            min_load = self.rounding.int2float(0)
+            increased_load = self.rounding.int2float(self.int_width)
+            self.float_width = (increased_load - min_load) / increased_load
+            return
+        self.float_width = float(self.float_width)
+        min_load = self.rounding.int2float(0)
+        increased_load = min_load / (1.0 - self.float_width)
+        int_load = self.rounding.float2int(increased_load)
+        verify_load = self.rounding.int2float(int_load)
+        if verify_load > increased_load:
+            int_load -= 1
+        self.int_width = int_load
+
+    def __str__(self) -> str:
+        """Convert into a short human-readable string.
+
+        :returns: The short string.
+        :rtype: str
+        """
+        return f"int_width={int(self)}"
+
+    def __int__(self) -> int:
+        """Return the integer form.
+
+        :returns: The int field value.
+        :rtype: int
+        """
+        return self.int_width
+
+    def __float__(self) -> float:
+        """Return the float form.
+
+        :returns: The float field value.
+        :rtype: float
+        """
+        return self.float_width
+
+    def __hash__(self) -> int:
+        """Return a hash based on the float value.
+
+        With this, the instance can be used as if it was immutable and hashable,
+        e.g. it can be a key in a dict.
+
+        :returns: Hash value for this instance.
+        :rtype: int
+        """
+        return hash(float(self))
+
+    def rounded_down(self) -> DiscreteWidth:
+        """Create and return new instance with float form matching int.
+
+        :returns: New instance with same int form and float form rounded down.
+        :rtype: DiscreteWidth
+        """
+        return DiscreteWidth(rounding=self.rounding, int_width=int(self))
+
+    def __add__(self, width: DiscreteWidth) -> DiscreteWidth:
+        """Return newly constructed instance with int widths added.
+
+        Rounding instance (reference) is copied from self.
+
+        Argument type is checked, to avoid caller adding something unsupported.
+
+        :param width: Value to add to int width.
+        :type width: DiscreteWidth
+        :returns: New instance.
+        :rtype: DiscreteWidth
+        :raises RuntimeError: When argument has unexpected type.
+        """
+        if not isinstance(width, DiscreteWidth):
+            raise RuntimeError(f"Not width: {width!r}")
+        return DiscreteWidth(
+            rounding=self.rounding,
+            int_width=self.int_width + int(width),
+        )
+
+    def __sub__(self, width: DiscreteWidth) -> DiscreteWidth:
+        """Return newly constructed instance with int widths subtracted.
+
+        Rounding instance (reference) is copied from self.
+
+        Argument type is checked, to avoid caller adding something unsupported.
+        Non-positive results are disallowed by constructor.
+
+        :param width: Value to subtract to int width.
+        :type width: DiscreteWidth
+        :returns: New instance.
+        :rtype: DiscreteWidth
+        :raises RuntimeError: When argument has unexpected type.
+        """
+        if not isinstance(width, DiscreteWidth):
+            raise RuntimeError(f"Not width: {type(width)}")
+        return DiscreteWidth(
+            rounding=self.rounding,
+            int_width=self.int_width - int(width),
+        )
+
+    def __mul__(self, coefficient: int) -> DiscreteWidth:
+        """Construct new instance with int value multiplied.
+
+        Rounding instance (reference) is copied from self.
+
+        :param coefficient: Constant to multiply int width with.
+        :type coefficient: int
+        :returns: New instance with multiplied int width.
+        :rtype: DiscreteWidth
+        :raises RuntimeError: If argument value does not meet requirements.
+        """
+        if not isinstance(coefficient, int):
+            raise RuntimeError(f"Coefficient not int: {coefficient!r}")
+        if coefficient < 1:
+            raise RuntimeError(f"Coefficient not positive: {coefficient!r}")
+        return DiscreteWidth(
+            rounding=self.rounding,
+            int_width=self.int_width * coefficient,
+        )
+
+    def half_rounded_down(self) -> DiscreteWidth:
+        """Contruct new instance of half the integer width.
+
+        If the current integer width is odd, round the half width down.
+
+        :returns: New instance with half int width.
+        :rtype: DiscreteWidth
+        :raises RuntimeError: If the resulting integerl width is not positive.
+        """
+        return DiscreteWidth(
+            rounding=self.rounding,
+            int_width=self.int_width // 2,
+        )