MLRsearch: Support other than just two ratios
[csit.git] / resources / libraries / python / MLRsearch / MeasurementDatabase.py
1 # Copyright (c) 2021 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """Module defining MeasurementDatabase class."""
15
16 from .ReceiveRateInterval import ReceiveRateInterval
17 from .PerDurationDatabase import PerDurationDatabase
18
19
20 class MeasurementDatabase:
21     """A structure holding measurement results.
22
23     The implementation uses a dict from duration values
24     to PerDurationDatabase instances.
25
26     Several utility methods are added, accomplishing tasks useful for MLRsearch.
27
28     This class contains the "find tightest bounds" parts of logic required
29     by MLRsearch. One exception is lack of any special handling for maximal
30     or minimal rates.
31     """
32
33     def __init__(self, measurements):
34         """Store measurement results in per-duration databases.
35
36         TODO: Move processing to a factory method,
37         keep constructor only to store (presumably valid) values.
38
39         If the measurements argument contains is a dict,
40         the constructor assumes it contains the processed databases.
41
42         :param measurements: The measurement results to store.
43         :type measurements: Iterable[ReceiveRateMeasurement]
44         """
45         if isinstance(measurements, dict):
46             self.data_for_duration = measurements
47         else:
48             self.data_for_duration = dict()
49             # TODO: There is overlap with add() code. Worth extracting?
50             for measurement in measurements:
51                 duration = measurement.duration
52                 if duration in self.data_for_duration:
53                     self.data_for_duration[duration].add(measurement)
54                 else:
55                     self.data_for_duration[duration] = PerDurationDatabase(
56                         duration, [measurement]
57                     )
58         durations = sorted(self.data_for_duration.keys())
59         self.current_duration = durations[-1] if duration else None
60         self.previous_duration = durations[-2] if len(durations) > 1 else None
61
62     def __repr__(self):
63         """Return string executable to get equivalent instance.
64
65         :returns: Code to construct equivalent instance.
66         :rtype: str
67         """
68         return f"MeasurementDatabase(measurements={self.data_for_duration!r})"
69
70     def set_current_duration(self, duration):
71         """Remember what MLRsearch considers the current duration.
72
73         Setting the same duration is allowed, setting smaller is not allowed.
74
75         :param duration: Target trial duration of current phase, in seconds.
76         :type duration: float
77         :raises ValueError: If the duration is smaller than previous.
78         """
79         if duration < self.current_duration:
80             raise ValueError(
81                 f"Duration {duration} shorter than current duration"
82                 f" {self.current_duration}"
83             )
84         if duration > self.current_duration:
85             self.previous_duration = self.current_duration
86             self.current_duration = duration
87             self.data_for_duration[duration] = PerDurationDatabase(
88                 duration, list()
89             )
90         # Else no-op.
91
92     def add(self, measurement):
93         """Add a measurement. Duration has to match the set one.
94
95         :param measurement: Measurement result to add to the database.
96         :type measurement: ReceiveRateMeasurement
97         """
98         duration = measurement.duration
99         if duration != self.current_duration:
100             raise ValueError(
101                 f"{measurement!r} duration different than"
102                 f" {self.current_duration}"
103             )
104         self.data_for_duration[duration].add(measurement)
105
106     def get_bounds(self, ratio):
107         """Return 6 bounds: lower/upper, current/previous, tightest/second.
108
109         Second tightest bounds are only returned for current duration.
110         None instead of a measurement if there is no measurement of that type.
111
112         The result cotains bounds in this order:
113         1. Tightest lower bound for current duration.
114         2. Tightest upper bound for current duration.
115         3. Tightest lower bound for previous duration.
116         4. Tightest upper bound for previous duration.
117         5. Second tightest lower bound for current duration.
118         6. Second tightest upper bound for current duration.
119
120         :param ratio: Target ratio, valid has to be lower or equal.
121         :type ratio: float
122         :returns: Measurements acting as various bounds.
123         :rtype: 6-tuple of Optional[PerDurationDatabase]
124         """
125         cur_lo1, cur_hi1, pre_lo, pre_hi, cur_lo2, cur_hi2 = [None] * 6
126         duration = self.current_duration
127         if duration is not None:
128             data = self.data_for_duration[duration]
129             cur_lo1, cur_hi1, cur_lo2, cur_hi2 = data.get_valid_bounds(ratio)
130         duration = self.previous_duration
131         if duration is not None:
132             data = self.data_for_duration[duration]
133             pre_lo, pre_hi, _, _ = data.get_valid_bounds(ratio)
134         return cur_lo1, cur_hi1, pre_lo, pre_hi, cur_lo2, cur_hi2
135
136     def get_results(self, ratio_list):
137         """Return list of intervals for given ratios, from current duration.
138
139         Attempt to construct valid intervals. If a valid bound is missing,
140         use smallest/biggest target_tr for lower/upper bound.
141         This can result in degenerate intervals.
142
143         :param ratio_list: Ratios to create intervals for.
144         :type ratio_list: Iterable[float]
145         :returns: List of intervals.
146         :rtype: List[ReceiveRateInterval]
147         """
148         ret_list = list()
149         current_data = self.data_for_duration[self.current_duration]
150         for ratio in ratio_list:
151             lower_bound, upper_bound, _, _, _, _ = self.get_bounds(ratio)
152             if lower_bound is None:
153                 lower_bound = current_data.measurements[0]
154             if upper_bound is None:
155                 upper_bound = current_data.measurements[-1]
156             ret_list.append(ReceiveRateInterval(lower_bound, upper_bound))
157         return ret_list