MLRsearch: Support other than just two ratios
[csit.git] / resources / libraries / python / MLRsearch / MeasurementDatabase.py
diff --git a/resources/libraries/python/MLRsearch/MeasurementDatabase.py b/resources/libraries/python/MLRsearch/MeasurementDatabase.py
new file mode 100644 (file)
index 0000000..2f601d6
--- /dev/null
@@ -0,0 +1,157 @@
+# Copyright (c) 2021 Cisco and/or its affiliates.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module defining MeasurementDatabase class."""
+
+from .ReceiveRateInterval import ReceiveRateInterval
+from .PerDurationDatabase import PerDurationDatabase
+
+
+class MeasurementDatabase:
+    """A structure holding measurement results.
+
+    The implementation uses a dict from duration values
+    to PerDurationDatabase instances.
+
+    Several utility methods are added, accomplishing tasks useful for MLRsearch.
+
+    This class contains the "find tightest bounds" parts of logic required
+    by MLRsearch. One exception is lack of any special handling for maximal
+    or minimal rates.
+    """
+
+    def __init__(self, measurements):
+        """Store measurement results in per-duration databases.
+
+        TODO: Move processing to a factory method,
+        keep constructor only to store (presumably valid) values.
+
+        If the measurements argument contains is a dict,
+        the constructor assumes it contains the processed databases.
+
+        :param measurements: The measurement results to store.
+        :type measurements: Iterable[ReceiveRateMeasurement]
+        """
+        if isinstance(measurements, dict):
+            self.data_for_duration = measurements
+        else:
+            self.data_for_duration = dict()
+            # TODO: There is overlap with add() code. Worth extracting?
+            for measurement in measurements:
+                duration = measurement.duration
+                if duration in self.data_for_duration:
+                    self.data_for_duration[duration].add(measurement)
+                else:
+                    self.data_for_duration[duration] = PerDurationDatabase(
+                        duration, [measurement]
+                    )
+        durations = sorted(self.data_for_duration.keys())
+        self.current_duration = durations[-1] if duration else None
+        self.previous_duration = durations[-2] if len(durations) > 1 else None
+
+    def __repr__(self):
+        """Return string executable to get equivalent instance.
+
+        :returns: Code to construct equivalent instance.
+        :rtype: str
+        """
+        return f"MeasurementDatabase(measurements={self.data_for_duration!r})"
+
+    def set_current_duration(self, duration):
+        """Remember what MLRsearch considers the current duration.
+
+        Setting the same duration is allowed, setting smaller is not allowed.
+
+        :param duration: Target trial duration of current phase, in seconds.
+        :type duration: float
+        :raises ValueError: If the duration is smaller than previous.
+        """
+        if duration < self.current_duration:
+            raise ValueError(
+                f"Duration {duration} shorter than current duration"
+                f" {self.current_duration}"
+            )
+        if duration > self.current_duration:
+            self.previous_duration = self.current_duration
+            self.current_duration = duration
+            self.data_for_duration[duration] = PerDurationDatabase(
+                duration, list()
+            )
+        # Else no-op.
+
+    def add(self, measurement):
+        """Add a measurement. Duration has to match the set one.
+
+        :param measurement: Measurement result to add to the database.
+        :type measurement: ReceiveRateMeasurement
+        """
+        duration = measurement.duration
+        if duration != self.current_duration:
+            raise ValueError(
+                f"{measurement!r} duration different than"
+                f" {self.current_duration}"
+            )
+        self.data_for_duration[duration].add(measurement)
+
+    def get_bounds(self, ratio):
+        """Return 6 bounds: lower/upper, current/previous, tightest/second.
+
+        Second tightest bounds are only returned for current duration.
+        None instead of a measurement if there is no measurement of that type.
+
+        The result cotains bounds in this order:
+        1. Tightest lower bound for current duration.
+        2. Tightest upper bound for current duration.
+        3. Tightest lower bound for previous duration.
+        4. Tightest upper bound for previous duration.
+        5. Second tightest lower bound for current duration.
+        6. Second tightest upper bound for current duration.
+
+        :param ratio: Target ratio, valid has to be lower or equal.
+        :type ratio: float
+        :returns: Measurements acting as various bounds.
+        :rtype: 6-tuple of Optional[PerDurationDatabase]
+        """
+        cur_lo1, cur_hi1, pre_lo, pre_hi, cur_lo2, cur_hi2 = [None] * 6
+        duration = self.current_duration
+        if duration is not None:
+            data = self.data_for_duration[duration]
+            cur_lo1, cur_hi1, cur_lo2, cur_hi2 = data.get_valid_bounds(ratio)
+        duration = self.previous_duration
+        if duration is not None:
+            data = self.data_for_duration[duration]
+            pre_lo, pre_hi, _, _ = data.get_valid_bounds(ratio)
+        return cur_lo1, cur_hi1, pre_lo, pre_hi, cur_lo2, cur_hi2
+
+    def get_results(self, ratio_list):
+        """Return list of intervals for given ratios, from current duration.
+
+        Attempt to construct valid intervals. If a valid bound is missing,
+        use smallest/biggest target_tr for lower/upper bound.
+        This can result in degenerate intervals.
+
+        :param ratio_list: Ratios to create intervals for.
+        :type ratio_list: Iterable[float]
+        :returns: List of intervals.
+        :rtype: List[ReceiveRateInterval]
+        """
+        ret_list = list()
+        current_data = self.data_for_duration[self.current_duration]
+        for ratio in ratio_list:
+            lower_bound, upper_bound, _, _, _, _ = self.get_bounds(ratio)
+            if lower_bound is None:
+                lower_bound = current_data.measurements[0]
+            if upper_bound is None:
+                upper_bound = current_data.measurements[-1]
+            ret_list.append(ReceiveRateInterval(lower_bound, upper_bound))
+        return ret_list