1 # Copyright (c) 2021 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
6 # http://www.apache.org/licenses/LICENSE-2.0
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
14 """Module defining MeasurementDatabase class."""
16 from .ReceiveRateInterval import ReceiveRateInterval
17 from .PerDurationDatabase import PerDurationDatabase
class MeasurementDatabase:
    """A structure holding measurement results.

    The implementation uses a dict from duration values
    to PerDurationDatabase instances.

    Several utility methods are added, accomplishing tasks useful for MLRsearch.

    This class contains the "find tightest bounds" parts of logic required
    by MLRsearch. One exception is the lack of any special handling
    for maximal or minimal rates.
    """

    def __init__(self, measurements):
        """Store measurement results in per-duration databases.

        TODO: Move processing to a factory method,
        keep constructor only to store (presumably valid) values.

        If the measurements argument is a dict, the constructor assumes
        it already maps durations to processed per-duration databases
        and stores it as is (without copying).

        The largest duration present becomes the current duration,
        the second largest one becomes the previous duration.
        Either is None if there are not enough distinct durations.

        :param measurements: The measurement results to store.
        :type measurements: Iterable[ReceiveRateMeasurement] or dict
        """
        if isinstance(measurements, dict):
            self.data_for_duration = measurements
        else:
            self.data_for_duration = dict()
            # TODO: There is overlap with add() code. Worth extracting?
            for measurement in measurements:
                duration = measurement.duration
                if duration in self.data_for_duration:
                    self.data_for_duration[duration].add(measurement)
                else:
                    self.data_for_duration[duration] = PerDurationDatabase(
                        duration, [measurement]
                    )
        durations = sorted(self.data_for_duration)
        # The emptiness check has to look at the sorted list. The loop
        # variable "duration" is not bound for dict or empty input,
        # and it is falsy for a (valid) zero duration.
        self.current_duration = durations[-1] if durations else None
        self.previous_duration = durations[-2] if len(durations) > 1 else None

    def __repr__(self):
        """Return string executable to get equivalent instance.

        :returns: Code to construct equivalent instance.
        :rtype: str
        """
        return f"MeasurementDatabase(measurements={self.data_for_duration!r})"

    def set_current_duration(self, duration):
        """Remember what MLRsearch considers the current duration.

        Setting the same duration is allowed, setting smaller is not allowed.
        Setting a larger duration shifts the old current duration
        to previous, and creates an empty per-duration database
        for the new duration.

        If no current duration is set yet (empty database),
        any duration is accepted.

        :param duration: Target trial duration of current phase, in seconds.
        :type duration: float
        :raises ValueError: If the duration is smaller than previous.
        """
        current = self.current_duration
        # None cannot be ordered against numbers, so handle it explicitly.
        if current is not None and duration < current:
            raise ValueError(
                f"Duration {duration} shorter than current duration"
                f" {self.current_duration}"
            )
        if current is None or duration > current:
            self.previous_duration = current
            self.current_duration = duration
            self.data_for_duration[duration] = PerDurationDatabase(
                duration, list()
            )

    def add(self, measurement):
        """Add a measurement. Duration has to match the set one.

        :param measurement: Measurement result to add to the database.
        :type measurement: ReceiveRateMeasurement
        :raises ValueError: If the measurement duration differs
            from the current duration.
        """
        duration = measurement.duration
        if duration != self.current_duration:
            raise ValueError(
                f"{measurement!r} duration different than"
                f" {self.current_duration}"
            )
        self.data_for_duration[duration].add(measurement)

    def get_bounds(self, ratio):
        """Return 6 bounds: lower/upper, current/previous, tightest/second.

        Second tightest bounds are only returned for current duration.
        None instead of a measurement if there is no measurement of that type.

        The result contains bounds in this order:
        1. Tightest lower bound for current duration.
        2. Tightest upper bound for current duration.
        3. Tightest lower bound for previous duration.
        4. Tightest upper bound for previous duration.
        5. Second tightest lower bound for current duration.
        6. Second tightest upper bound for current duration.

        :param ratio: Target ratio, valid has to be lower or equal.
        :type ratio: float
        :returns: Measurements acting as various bounds.
        :rtype: 6-tuple of Optional[ReceiveRateMeasurement]
        """
        cur_lo1, cur_hi1, pre_lo, pre_hi, cur_lo2, cur_hi2 = [None] * 6
        duration = self.current_duration
        if duration is not None:
            data = self.data_for_duration[duration]
            cur_lo1, cur_hi1, cur_lo2, cur_hi2 = data.get_valid_bounds(ratio)
        duration = self.previous_duration
        if duration is not None:
            data = self.data_for_duration[duration]
            # Second tightest bounds are not reported for previous duration.
            pre_lo, pre_hi, _, _ = data.get_valid_bounds(ratio)
        return cur_lo1, cur_hi1, pre_lo, pre_hi, cur_lo2, cur_hi2

    def get_results(self, ratio_list):
        """Return list of intervals for given ratios, from current duration.

        Attempt to construct valid intervals. If a valid bound is missing,
        use smallest/biggest target_tr for lower/upper bound.
        This can result in degenerate intervals.

        The fallback takes the first/last item of the measurements
        of the current duration database, so this presumably relies
        on those measurements being kept sorted by target_tr.

        :param ratio_list: Ratios to create intervals for.
        :type ratio_list: Iterable[float]
        :returns: List of intervals.
        :rtype: List[ReceiveRateInterval]
        """
        ret_list = list()
        current_data = self.data_for_duration[self.current_duration]
        for ratio in ratio_list:
            lower_bound, upper_bound, _, _, _, _ = self.get_bounds(ratio)
            if lower_bound is None:
                lower_bound = current_data.measurements[0]
            if upper_bound is None:
                upper_bound = current_data.measurements[-1]
            ret_list.append(ReceiveRateInterval(lower_bound, upper_bound))
        return ret_list