IETF: Update MLRsearch draft
[csit.git] / resources / libraries / python / MLRsearch / MultipleLossRatioSearch.py
1 # Copyright (c) 2021 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """Module defining MultipleLossRatioSearch class."""
15
16 import logging
17 import math
18 import time
19
20 from .MeasurementDatabase import MeasurementDatabase
21 from .ProgressState import ProgressState
22 from .ReceiveRateInterval import ReceiveRateInterval
23 from .WidthArithmetics import (
24     multiply_relative_width,
25     step_down,
26     step_up,
27     multiple_step_down,
28     multiple_step_up,
29     half_step_up,
30 )
31
32
33 class MultipleLossRatioSearch:
34     """Optimized binary search algorithm for finding bounds for multiple ratios.
35
36     This is unofficially a subclass of AbstractSearchAlgorithm,
37     but constructor signature is different.
38
39     Traditional binary search algorithm needs initial interval
40     (lower and upper bound), and returns final interval after bisecting
41     (until some exit condition is met).
42     The exit condition is usually related to the interval width,
43     (upper bound value minus lower bound value).
44
45     The optimized algorithm contains several improvements
46     aimed to reduce overall search time.
47
48     One improvement is searching for multiple intervals at once.
49     The intervals differ by the target loss ratio. Lower bound
50     has to have equal or smaller loss ratio, upper bound has to have larger.
51
52     Next improvement is that the initial interval does not need to be valid.
53     Imagine initial interval (10, 11) where loss at 11 is smaller
54     than the searched ratio.
55     The algorithm will try (11, 13) interval next, and if 13 is still smaller,
56     (13, 17) and so on, doubling width until the upper bound is valid.
57     The part when interval expands is called external search,
58     the part when interval is bisected is called internal search.
59
60     Next improvement is that trial measurements at small trial duration
61     can be used to find a reasonable interval for full trial duration search.
62     This results in more trials performed, but smaller overall duration
63     in general.
64
65     Next improvement is bisecting in logarithmic quantities,
66     so that exit criteria can be independent of measurement units.
67
68     Next improvement is basing the initial interval on receive rates.
69
70     Final improvement is exiting early if the minimal value
71     is not a valid lower bound.
72
73     The complete search consist of several phases,
74     each phase performing several trial measurements.
75     Initial phase creates initial interval based on receive rates
76     at maximum rate and at maximum receive rate (MRR).
77     Final phase and preceding intermediate phases are performing
78     external and internal search steps,
79     each resulting interval is the starting point for the next phase.
80     The resulting intervals of final phase is the result of the whole algorithm.
81
82     Each non-initial phase uses its own trial duration.
83     Any non-initial phase stops searching (for all ratios independently)
84     when minimum is not a valid lower bound (at current duration),
85     or all of the following is true:
86     Both bounds are valid, bounds are measured at the current phase
87     trial duration, interval width is less than the width goal
88     for current phase.
89
90     TODO: Review and update this docstring according to rst docs.
91     """
92
93     def __init__(
94             self, measurer, final_relative_width=0.005,
95             final_trial_duration=30.0, initial_trial_duration=1.0,
96             number_of_intermediate_phases=2, timeout=600.0, debug=None,
97             expansion_coefficient=2.0):
98         """Store the measurer object and additional arguments.
99
100         :param measurer: Rate provider to use by this search object.
101         :param final_relative_width: Final lower bound transmit rate
102             cannot be more distant that this multiple of upper bound [1].
103         :param final_trial_duration: Trial duration for the final phase [s].
104         :param initial_trial_duration: Trial duration for the initial phase
105             and also for the first intermediate phase [s].
106         :param number_of_intermediate_phases: Number of intermediate phases
107             to perform before the final phase [1].
108         :param timeout: The search will fail itself when not finished
109             before this overall time [s].
110         :param debug: Callable to use instead of logging.debug().
111         :param expansion_coefficient: External search multiplies width by this.
112         :type measurer: AbstractMeasurer.AbstractMeasurer
113         :type final_relative_width: float
114         :type final_trial_duration: float
115         :type initial_trial_duration: float
116         :type number_of_intermediate_phases: int
117         :type timeout: float
118         :type debug: Optional[Callable[[str], None]]
119         :type expansion_coefficient: float
120         """
121         self.measurer = measurer
122         self.final_trial_duration = float(final_trial_duration)
123         self.final_relative_width = float(final_relative_width)
124         self.number_of_intermediate_phases = int(number_of_intermediate_phases)
125         self.initial_trial_duration = float(initial_trial_duration)
126         self.timeout = float(timeout)
127         self.state = None
128         self.debug = logging.debug if debug is None else debug
129         self.expansion_coefficient = float(expansion_coefficient)
130
131     def narrow_down_intervals(self, min_rate, max_rate, packet_loss_ratios):
132         """Perform initial phase, create state object, proceed with next phases.
133
134         The current implementation requires the ratios so be unique and sorted.
135         Also non-empty.
136
137         :param min_rate: Minimal target transmit rate [tps].
138         :param max_rate: Maximal target transmit rate [tps].
139         :param packet_loss_ratios: Target ratios of packets loss to locate.
140         :type min_rate: float
141         :type max_rate: float
142         :type packet_loss_ratios: Iterable[float]
143         :returns: Structure containing narrowed down intervals
144             and their measurements.
145         :rtype: List[ReceiveRateInterval]
146         :raises RuntimeError: If total duration is larger than timeout.
147             Or if ratios list is (empty or) not sorted or unique.
148         """
149         min_rate = float(min_rate)
150         max_rate = float(max_rate)
151         packet_loss_ratios = [float(ratio) for ratio in packet_loss_ratios]
152         if len(packet_loss_ratios) < 1:
153             raise RuntimeError(u"At least one ratio is required!")
154         if packet_loss_ratios != sorted(set(packet_loss_ratios)):
155             raise RuntimeError(u"Input ratios have to be sorted and unique!")
156         measurements = list()
157         self.debug(f"First measurement at max rate: {max_rate}")
158         measured = self.measurer.measure(
159             duration=self.initial_trial_duration,
160             transmit_rate=max_rate,
161         )
162         measurements.append(measured)
163         initial_width_goal = self.final_relative_width
164         for _ in range(self.number_of_intermediate_phases):
165             initial_width_goal = multiply_relative_width(
166                 initial_width_goal, 2.0
167             )
168         max_lo = step_down(max_rate, initial_width_goal)
169         mrr = max(min_rate, min(max_lo, measured.relative_receive_rate))
170         self.debug(f"Second measurement at mrr: {mrr}")
171         measured = self.measurer.measure(
172             duration=self.initial_trial_duration,
173             transmit_rate=mrr,
174         )
175         measurements.append(measured)
176         # Attempt to get narrower width.
177         if measured.loss_ratio > packet_loss_ratios[0]:
178             max_lo = step_down(mrr, initial_width_goal)
179             mrr2 = min(max_lo, measured.relative_receive_rate)
180         else:
181             mrr2 = step_up(mrr, initial_width_goal)
182         if min_rate < mrr2 < max_rate:
183             self.debug(f"Third measurement at mrr2: {mrr2}")
184             measured = self.measurer.measure(
185                 duration=self.initial_trial_duration,
186                 transmit_rate=mrr2,
187             )
188             measurements.append(measured)
189             # If mrr2 > mrr and mrr2 got zero loss,
190             # it is better to do external search from mrr2 up.
191             # To prevent bisection between mrr2 and max_rate,
192             # we simply remove the max_rate measurement.
193             # Similar logic applies to higher loss ratio goals.
194             # Overall, with mrr2 measurement done, we never need
195             # the first measurement done at max rate.
196             measurements = measurements[1:]
197         database = MeasurementDatabase(measurements)
198         stop_time = time.monotonic() + self.timeout
199         self.state = ProgressState(
200             database, self.number_of_intermediate_phases,
201             self.final_trial_duration, self.final_relative_width,
202             packet_loss_ratios, min_rate, max_rate, stop_time
203         )
204         self.ndrpdr()
205         return self.state.database.get_results(ratio_list=packet_loss_ratios)
206
207     def ndrpdr(self):
208         """Perform trials for this phase. State is updated in-place.
209
210         Recursion to smaller durations is performed (if not performed yet).
211
212         :raises RuntimeError: If total duration is larger than timeout.
213         """
214         state = self.state
215         if state.phases > 0:
216             # We need to finish preceding intermediate phases first.
217             saved_phases = state.phases
218             state.phases -= 1
219             # Preceding phases have shorter duration.
220             saved_duration = state.duration
221             duration_multiplier = state.duration / self.initial_trial_duration
222             phase_exponent = float(state.phases) / saved_phases
223             state.duration = self.initial_trial_duration * math.pow(
224                 duration_multiplier, phase_exponent
225             )
226             # Shorter durations do not need that narrow widths.
227             saved_width = state.width_goal
228             state.width_goal = multiply_relative_width(saved_width, 2.0)
229             # Recurse.
230             self.ndrpdr()
231             # Restore the state for current phase.
232             state.width_goal = saved_width
233             state.duration = saved_duration
234             state.phases = saved_phases  # Not needed, but just in case.
235         self.debug(
236             f"Starting phase with {state.duration} duration"
237             f" and {state.width_goal} relative width goal."
238         )
239         failing_fast = False
240         database = state.database
241         database.set_current_duration(state.duration)
242         while time.monotonic() < state.stop_time:
243             for index, ratio in enumerate(state.packet_loss_ratios):
244                 new_tr = self._select_for_ratio(ratio)
245                 if new_tr is None:
246                     # Either this ratio is fine, or min rate got invalid result.
247                     # If fine, we will continue to handle next ratio.
248                     if index > 0:
249                         # First ratio passed, all next have a valid lower bound.
250                         continue
251                     lower_bound, _, _, _, _, _ = database.get_bounds(ratio)
252                     if lower_bound is None:
253                         failing_fast = True
254                         self.debug(u"No valid lower bound for this iteration.")
255                         break
256                     # First ratio is fine.
257                     continue
258                 # We have transmit rate to measure at.
259                 # We do not check duration versus stop_time here,
260                 # as some measurers can be unpredictably faster
261                 # than what duration suggests.
262                 measurement = self.measurer.measure(
263                     duration=state.duration,
264                     transmit_rate=new_tr,
265                 )
266                 database.add(measurement)
267                 # Restart ratio handling on updated database.
268                 break
269             else:
270                 # No ratio needs measuring, we are done with this phase.
271                 self.debug(u"Phase done.")
272                 break
273             # We have broken out of the for loop.
274             if failing_fast:
275                 # Abort the while loop early.
276                 break
277             # Not failing fast but database got updated, restart the while loop.
278         else:
279             # Time is up.
280             raise RuntimeError(u"Optimized search takes too long.")
281         # Min rate is not valid, but returning what we have
282         # so next duration can recover.
283
284     @staticmethod
285     def improves(new_bound, lower_bound, upper_bound):
286         """Return whether new bound improves upon old bounds.
287
288         To improve, new_bound has to be not None,
289         and between the old bounds (where the bound is not None).
290
291         This piece of logic is commonly used, when we know old bounds
292         from a primary source (e.g. current duration database)
293         and new bound from a secondary source (e.g. previous duration database).
294         Having a function allows "if improves(..):" construction to save space.
295
296         :param new_bound: The bound we consider applying.
297         :param lower_bound: Known bound, new_bound has to be higher to apply.
298         :param upper_bound: Known bound, new_bound has to be lower to apply.
299         :type new_bound: Optional[ReceiveRateMeasurement]
300         :type lower_bound: Optional[ReceiveRateMeasurement]
301         :type upper_bound: Optional[ReceiveRateMeasurement]
302         :returns: Whether we can apply the new bound.
303         :rtype: bool
304         """
305         if new_bound is None:
306             return False
307         if lower_bound is not None:
308             if new_bound.target_tr <= lower_bound.target_tr:
309                 return False
310         if upper_bound is not None:
311             if new_bound.target_tr >= upper_bound.target_tr:
312                 return False
313         return True
314
315     def _select_for_ratio(self, ratio):
316         """Return None or new target_tr to measure at.
317
318         Returning None means either we have narrow enough valid interval
319         for this ratio, or we are hitting min rate and should fail early.
320
321         :param ratio: Loss ratio to ensure narrow valid bounds for.
322         :type ratio: float
323         :returns: The next target transmit rate to measure at.
324         :rtype: Optional[float]
325         :raises RuntimeError: If database inconsistency is detected.
326         """
327         state = self.state
328         data = state.database
329         bounds = data.get_bounds(ratio)
330         cur_lo1, cur_hi1, pre_lo, pre_hi, cur_lo2, cur_hi2 = bounds
331         pre_lo_improves = self.improves(pre_lo, cur_lo1, cur_hi1)
332         pre_hi_improves = self.improves(pre_hi, cur_lo1, cur_hi1)
333         # TODO: Detect also the other case for initial bisect, see below.
334         if pre_lo_improves and pre_hi_improves:
335             # We allowed larger width for previous phase
336             # as single bisect here guarantees only one re-measurement.
337             new_tr = self._bisect(pre_lo, pre_hi)
338             if new_tr is not None:
339                 self.debug(f"Initial bisect for {ratio}, tr: {new_tr}")
340                 return new_tr
341         if pre_lo_improves:
342             new_tr = pre_lo.target_tr
343             self.debug(f"Re-measuring lower bound for {ratio}, tr: {new_tr}")
344             return new_tr
345         if pre_hi_improves:
346             # This can also happen when we did not do initial bisect
347             # for this ratio yet, but the previous duration lower bound
348             # for this ratio got already re-measured as previous duration
349             # upper bound for previous ratio.
350             new_tr = pre_hi.target_tr
351             self.debug(f"Re-measuring upper bound for {ratio}, tr: {new_tr}")
352             return new_tr
353         if cur_lo1 is None and cur_hi1 is None:
354             raise RuntimeError(u"No results found in databases!")
355         if cur_lo1 is None:
356             # Upper bound exists (cur_hi1).
357             # We already tried previous lower bound.
358             # So, we want to extend down.
359             new_tr = self._extend_down(
360                 cur_hi1, cur_hi2, pre_hi, second_needed=False
361             )
362             self.debug(
363                 f"Extending down for {ratio}:"
364                 f" old {cur_hi1.target_tr} new {new_tr}"
365             )
366             return new_tr
367         if cur_hi1 is None:
368             # Lower bound exists (cur_lo1).
369             # We already tried previous upper bound.
370             # So, we want to extend up.
371             new_tr = self._extend_up(cur_lo1, cur_lo2, pre_lo)
372             self.debug(
373                 f"Extending up for {ratio}:"
374                 f" old {cur_lo1.target_tr} new {new_tr}"
375             )
376             return new_tr
377         # Both bounds exist (cur_lo1 and cur_hi1).
378         # cur_lo1 might have been selected for this ratio (we are bisecting)
379         # or for previous ratio (we are extending down for this ratio).
380         # Compute both estimates and choose the higher value.
381         bisected_tr = self._bisect(cur_lo1, cur_hi1)
382         extended_tr = self._extend_down(
383             cur_hi1, cur_hi2, pre_hi, second_needed=True
384         )
385         # Only if both are not None we need to decide.
386         if bisected_tr and extended_tr and extended_tr > bisected_tr:
387             self.debug(
388                 f"Extending down for {ratio}:"
389                 f" old {cur_hi1.target_tr} new {extended_tr}"
390             )
391             new_tr = extended_tr
392         else:
393             self.debug(
394                 f"Bisecting for {ratio}: lower {cur_lo1.target_tr},"
395                 f" upper {cur_hi1.target_tr}, new {bisected_tr}"
396             )
397             new_tr = bisected_tr
398         return new_tr
399
400     def _extend_down(self, cur_hi1, cur_hi2, pre_hi, second_needed=False):
401         """Return extended width below, or None if hitting min rate.
402
403         If no second tightest (nor previous) upper bound is available,
404         the behavior is governed by second_needed argument.
405         If true, return None. If false, start from width goal.
406         This is useful, as if a bisect is possible,
407         we want to give it a chance.
408
409         :param cur_hi1: Tightest upper bound for current duration. Has to exist.
410         :param cur_hi2: Second tightest current upper bound, may not exist.
411         :param pre_hi: Tightest upper bound, previous duration, may not exist.
412         :param second_needed: Whether second tightest bound is required.
413         :type cur_hi1: ReceiveRateMeasurement
414         :type cur_hi2: Optional[ReceiveRateMeasurement]
415         :type pre_hi: Optional[ReceiveRateMeasurement]
416         :type second_needed: bool
417         :returns: The next target transmit rate to measure at.
418         :rtype: Optional[float]
419         """
420         state = self.state
421         old_tr = cur_hi1.target_tr
422         if state.min_rate >= old_tr:
423             self.debug(u"Extend down hits min rate.")
424             return None
425         next_bound = cur_hi2
426         if self.improves(pre_hi, cur_hi1, cur_hi2):
427             next_bound = pre_hi
428         if next_bound is None and second_needed:
429             return None
430         old_width = state.width_goal
431         if next_bound is not None:
432             old_width = ReceiveRateInterval(cur_hi1, next_bound).rel_tr_width
433             old_width = max(old_width, state.width_goal)
434         new_tr = multiple_step_down(
435             old_tr, old_width, self.expansion_coefficient
436         )
437         new_tr = max(new_tr, state.min_rate)
438         return new_tr
439
440     def _extend_up(self, cur_lo1, cur_lo2, pre_lo):
441         """Return extended width above, or None if hitting max rate.
442
443         :param cur_lo1: Tightest lower bound for current duration. Has to exist.
444         :param cur_lo2: Second tightest current lower bound, may not exist.
445         :param pre_lo: Tightest lower bound, previous duration, may not exist.
446         :type cur_lo1: ReceiveRateMeasurement
447         :type cur_lo2: Optional[ReceiveRateMeasurement]
448         :type pre_lo: Optional[ReceiveRateMeasurement]
449         :returns: The next target transmit rate to measure at.
450         :rtype: Optional[float]
451         """
452         state = self.state
453         old_tr = cur_lo1.target_tr
454         if state.max_rate <= old_tr:
455             self.debug(u"Extend up hits max rate.")
456             return None
457         next_bound = cur_lo2
458         if self.improves(pre_lo, cur_lo2, cur_lo1):
459             next_bound = pre_lo
460         old_width = state.width_goal
461         if next_bound is not None:
462             old_width = ReceiveRateInterval(cur_lo1, next_bound).rel_tr_width
463             old_width = max(old_width, state.width_goal)
464         new_tr = multiple_step_up(old_tr, old_width, self.expansion_coefficient)
465         new_tr = min(new_tr, state.max_rate)
466         return new_tr
467
468     def _bisect(self, lower_bound, upper_bound):
469         """Return middle rate or None if width is narrow enough.
470
471         :param lower_bound: Measurement to use as a lower bound. Has to exist.
472         :param upper_bound: Measurement to use as an upper bound. Has to exist.
473         :type lower_bound: ReceiveRateMeasurement
474         :type upper_bound: ReceiveRateMeasurement
475         :returns: The next target transmit rate to measure at.
476         :rtype: Optional[float]
477         :raises RuntimeError: If database inconsistency is detected.
478         """
479         state = self.state
480         width = ReceiveRateInterval(lower_bound, upper_bound).rel_tr_width
481         if width <= state.width_goal:
482             self.debug(u"No more bisects needed.")
483             return None
484         new_tr = half_step_up(lower_bound.target_tr, width, state.width_goal)
485         return new_tr