b48e2e7547d8699f139e3e29e10c967ce8437cf2
[csit.git] / resources / libraries / python / MLRsearch / multiple_loss_ratio_search.py
1 # Copyright (c) 2023 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """Module defining MultipleLossRatioSearch class."""
15
16 import logging
17 import time
18
19 from dataclasses import dataclass
20 from typing import Callable, Optional, Tuple
21
22 from .candidate import Candidate
23 from .config import Config
24 from .dataclass import secondary_field
25 from .discrete_load import DiscreteLoad
26 from .discrete_result import DiscreteResult
27 from .expander import GlobalWidth
28 from .limit_handler import LimitHandler
29 from .load_rounding import LoadRounding
30 from .measurement_database import MeasurementDatabase
31 from .pep3140 import Pep3140Dict
32 from .search_goal import SearchGoal
33 from .selector import Selector
34 from .target_scaling import TargetScaling
35 from .trial_measurement import AbstractMeasurer
36 from .trimmed_stat import TrimmedStat
37
38
39 @dataclass
40 class MultipleLossRatioSearch:
41     """Optimized binary search algorithm for finding conditional throughputs.
42
43     Traditional binary search algorithm needs initial interval
44     (lower and upper bound), and returns final narrow bounds
45     (related to its search goal) after bisecting
46     (until some exit condition is met).
47     The exit condition is usually related to the interval width,
48     (upper bound value minus lower bound value).
49
50     The optimized algorithm in this class contains several improvements
51     aimed to reduce overall search time.
52
53     One improvement is searching for bounds for multiple search goals at once.
54     Specifically, the trial measurement results influence bounds for all goals,
55     even though the selection of trial inputs for next measurement
56     focuses only on one goal. The focus can switch between goals frequently.
57
58     Next improvement is that results of trial measurements
59     with small trial duration can be used to find a reasonable starting interval
60     for full trial duration search.
61     This results in more trials performed, but smaller overall duration
62     in general.
63     Internally, such shorter trials come from "preceding targets",
64     handled in a same way as a search goal "final target".
65     Related improvement is that the "current" interval does not need to be valid
66     (e.g. one of the bounds is missing).
67     In that case, this algorithm will move and expand the interval,
68     in a process called external search. Only when both bounds are found,
69     the interval bisection (called internal search) starts making it narrow.
70
71     Next improvement is bisecting in logarithmic quantities,
72     so that target relative width is independent of measurement units.
73
74     Next improvement is basing the initial interval on forwarding rates
75     of few initial measurements, starting at max load and using forwarding rates
76     seen so far.
77
78     Next improvement is to allow the use of multiple shorter trials
79     instead one big trial, allowing a percentage of trials
80     to exceed the loss ratio target.
81     This makes the result more stable in practice.
82     Conservative behavior (single long trial, zero exceed ratio)
83     is still available using corresponding goal definitions.
84
85     Final improvement is exiting early if the minimal load
86     is not a valid lower bound (at final duration)
87     and also exiting if the overall search duration is too long.
88
89     There are also subtle optimizations related to candidate selection
90     and uneven splitting of intervals, too numerous to list here.
91     """
92
93     config: Config
94     """Arguments required at construction time."""
95     # End of fields required at intance creation.
96     measurer: AbstractMeasurer = secondary_field()
97     """Measurer to use, set at calling search()."""
98     debug: Callable[[str], None] = secondary_field()
99     """Object to call for logging, None means logging.debug."""
100     # Fields below are computed from data above
101     rounding: LoadRounding = secondary_field()
102     """Derived from goals. Instance to use for intended load rounding."""
103     from_float: Callable[[float], DiscreteLoad] = secondary_field()
104     """Conversion method from float [tps] intended load values."""
105     limit_handler: LimitHandler = secondary_field()
106     """Load post-processing utility based on config and rounding."""
107     scaling: TargetScaling = secondary_field()
108     """Utility for creating target chains for search goals."""
109     database: MeasurementDatabase = secondary_field()
110     """Storage for (stats of) measurement results so far."""
111     stop_time: float = secondary_field()
112     """Monotonic time value at which the search should end with failure."""
113
114     def search(
115         self,
116         measurer: AbstractMeasurer,
117         debug: Optional[Callable[[str], None]] = None,
118     ) -> Pep3140Dict[SearchGoal, Optional[TrimmedStat]]:
119         """Perform initial trials, create state object, proceed with main loop.
120
121         Stateful arguments (measurer and debug) are stored.
122         Derived objects are constructed from config.
123
124         :param measurer: Measurement provider to use by this search object.
125         :param debug: Callable to optionally use instead of logging.debug().
126         :returns: Structure containing conditional throughputs and other stats,
127             one for each search goal.
128         :type measurer: AbstractMeasurer
129         :type debug: Optional[Callable[[str], None]]
130         :returns: Mapping from goal to lower bound (none if min load is hit).
131         :rtype: Pep3140Dict[SearchGoal, Optional[TrimmedStat]]
132         :raises RuntimeError: If total duration is larger than timeout,
133             or if min load becomes an upper bound for a search goal
134             that has fail fast true.
135         """
136         self.measurer = measurer
137         self.debug = logging.debug if debug is None else debug
138         self.rounding = LoadRounding(
139             min_load=self.config.min_load,
140             max_load=self.config.max_load,
141             float_goals=[goal.relative_width for goal in self.config.goals],
142         )
143         self.from_float = DiscreteLoad.float_conver(rounding=self.rounding)
144         self.limit_handler = LimitHandler(
145             rounding=self.rounding,
146             debug=self.debug,
147         )
148         self.scaling = TargetScaling(
149             goals=self.config.goals,
150             rounding=self.rounding,
151         )
152         self.database = MeasurementDatabase(self.scaling.targets)
153         self.stop_time = time.monotonic() + self.config.search_duration_max
154         result0, result1 = self.run_initial_trials()
155         self.main_loop(result0.discrete_load, result1.discrete_load)
156         ret_dict = Pep3140Dict()
157         for goal in self.config.goals:
158             target = self.scaling.goal_to_final_target[goal]
159             bounds = self.database.get_relevant_bounds(target=target)
160             ret_dict[goal] = bounds.clo
161         return ret_dict
162
163     def measure(self, duration: float, load: DiscreteLoad) -> DiscreteResult:
164         """Call measurer and put the result to appropriate form in database.
165
166         Also check the argument types and load roundness,
167         and return the result to the caller.
168
169         :param duration: Intended duration for the trial measurement.
170         :param load: Intended load for the trial measurement:
171         :type duration: float
172         :type load: DiscreteLoad
173         :returns: The trial results.
174         :rtype: DiscreteResult
175         :raises RuntimeError: If an argument doed not have the required type.
176         """
177         if not isinstance(duration, float):
178             raise RuntimeError(f"Duration has to be float: {duration!r}")
179         if not isinstance(load, DiscreteLoad):
180             raise RuntimeError(f"Load has to be discrete: {load!r}")
181         if not load.is_round:
182             raise RuntimeError(f"Told to measure unrounded: {load!r}")
183         self.debug(f"Measuring at d={duration},il={int(load)}")
184         result = self.measurer.measure(
185             intended_duration=duration,
186             intended_load=float(load),
187         )
188         self.debug(f"Measured lr={result.loss_ratio}")
189         result = DiscreteResult.with_load(result=result, load=load)
190         self.database.add(result)
191         return result
192
193     def run_initial_trials(self) -> Tuple[DiscreteResult, DiscreteResult]:
194         """Perform trials to get enough data to start the selectors.
195
196         Measurements are done with all initial targets in mind,
197         based on smallest target loss ratio, largest initial trial duration,
198         and largest initial target width.
199
200         Forwarding rate is used as a hint for next intended load.
201         The relative quantity is used, as load can use different units.
202         When the smallest target loss ratio is non-zero, a correction is needed
203         (forwarding rate is only a good hint for zero loss ratio load).
204         The correction is conservative (all increase in load turns to losses).
205
206         Also, warmup trial (if configured) is performed,
207         all other trials are added to the database.
208
209         This could return the initial width, but from implementation perspective
210         it is easier to return two measurements (or the same one twice) here
211         and compute width later. The "one value twice" happens when max load
212         has small loss, or when min load has big loss.
213
214         :returns: Two last measured values, in any order. Or one value twice.
215         :rtype: Tuple[DiscreteResult, DiscreteResult]
216         """
217         max_load = self.limit_handler.max_load
218         ratio, duration, width = None, None, None
219         for target in self.scaling.targets:
220             if target.preceding:
221                 continue
222             if ratio is None or ratio > target.loss_ratio:
223                 ratio = target.loss_ratio
224             if not duration or duration < target.trial_duration:
225                 duration = target.trial_duration
226             if not width or width < target.discrete_width:
227                 width = target.discrete_width
228         self.debug(f"Init ratio {ratio} duration {duration} width {width}")
229         if self.config.warmup_duration:
230             self.debug("Warmup trial.")
231             self.measure(self.config.warmup_duration, max_load)
232             # Warmup should not affect the real results, reset the database.
233             self.database = MeasurementDatabase(self.scaling.targets)
234         self.debug(f"First trial at max rate: {max_load}")
235         result0 = self.measure(duration, max_load)
236         rfr = result0.relative_forwarding_rate
237         corrected_rfr = (self.from_float(rfr) / (1.0 - ratio)).rounded_down()
238         if corrected_rfr >= max_load:
239             self.debug("Small loss, no other initial trials are needed.")
240             return result0, result0
241         mrr = self.limit_handler.handle(corrected_rfr, width, None, max_load)
242         self.debug(f"Second trial at (corrected) mrr: {mrr}")
243         result1 = self.measure(duration, mrr)
244         # Attempt to get narrower width.
245         result_ratio = result1.loss_ratio
246         if result_ratio > ratio:
247             rfr2 = result1.relative_forwarding_rate
248             crfr2 = (self.from_float(rfr2) / (1.0 - ratio)).rounded_down()
249             mrr2 = self.limit_handler.handle(crfr2, width, None, mrr)
250         else:
251             mrr2 = mrr + width
252             mrr2 = self.limit_handler.handle(mrr2, width, mrr, max_load)
253         if not mrr2:
254             self.debug("Close enough, measuring at mrr2 is not needed.")
255             return result1, result1
256         self.debug(f"Third trial at (corrected) mrr2: {mrr2}")
257         result2 = self.measure(duration, mrr2)
258         return result1, result2
259
260     def main_loop(self, load0: DiscreteLoad, load1: DiscreteLoad) -> None:
261         """Initialize selectors and keep measuring the winning candidate.
262
263         Selectors are created, the two input loads are useful starting points.
264
265         The search ends when no selector nominates any candidate,
266         or if the search takes too long (or if a selector raises).
267
268         Winner is selected according to ordering defined in Candidate class.
269         In case of a tie, selectors for earlier goals are preferred.
270
271         As a selector is only allowed to update current width as the winner,
272         the update is done here explicitly.
273
274         :param load0: Discrete load of one of results from run_initial_trials.
275         :param load1: Discrete load of other of results from run_initial_trials.
276         :type load0: DiscreteLoad
277         :type load1: DiscreteLoad
278         :raises RuntimeError: If the search takes too long,
279             or if min load becomes an upper bound for any search goal
280         """
281         if load1 < load0:
282             load0, load1 = load1, load0
283         global_width = GlobalWidth.from_loads(load0, load1)
284         selectors = []
285         for target in self.scaling.goal_to_final_target.values():
286             selector = Selector(
287                 final_target=target,
288                 global_width=global_width,
289                 initial_lower_load=load0,
290                 initial_upper_load=load1,
291                 database=self.database,
292                 handler=self.limit_handler,
293                 debug=self.debug,
294             )
295             selectors.append(selector)
296         while time.monotonic() < self.stop_time:
297             winner = Candidate()
298             for selector in selectors:
299                 # Order of arguments is important
300                 # when two targets nominate the same candidate.
301                 winner = min(Candidate.nomination_from(selector), winner)
302             if not winner:
303                 break
304             # We do not check duration versus stop_time here,
305             # as some measurers can be unpredictably faster
306             # than their intended duration suggests.
307             self.measure(duration=winner.duration, load=winner.load)
308             # Delayed updates.
309             if winner.width:
310                 global_width.width = winner.width
311             winner.won()
312         else:
313             raise RuntimeError("Optimized search takes too long.")
314         self.debug("Search done.")