Revert "fix(IPsecUtil): Delete keywords no longer used"
[csit.git] / resources / libraries / python / MLRsearch / selector.py
1 # Copyright (c) 2023 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """Module defining Selector class."""
15
16
17 from dataclasses import dataclass, field
18 from typing import Callable, List, Optional, Tuple
19
20 from .dataclass import secondary_field
21 from .discrete_load import DiscreteLoad
22 from .discrete_width import DiscreteWidth
23 from .expander import TargetedExpander
24 from .global_width import GlobalWidth
25 from .limit_handler import LimitHandler
26 from .measurement_database import MeasurementDatabase
27 from .relevant_bounds import RelevantBounds
28 from .target_spec import TargetSpec
29 from .strategy import StrategyBase, STRATEGY_CLASSES
30
31
32 @dataclass
33 class Selector:
34     """A selector is an abstraction that focuses on only one of search goals.
35
36     While lower-level logic is hidden in strategy classes,
37     the code in this class is responsible for initializing strategies
38     and shifting targets towards the final target.
39
40     While the public methods have the same names and meaning as the ones
41     in strategy classes, their signature is different.
42     Selector adds the current target trial duration to the output of nominate(),
43     and adds the current bounds to the input of won().
44
45     The nominate method does not return a complete Candidate instance,
46     as we need to avoid circular dependencies
47     (candidate will refer to selector).
48     """
49
50     final_target: TargetSpec
51     """The target this selector is trying to ultimately achieve."""
52     global_width: GlobalWidth
53     """Reference to the global width tracking instance."""
54     initial_lower_load: DiscreteLoad
55     """Smaller of the two loads distinguished at instance creation.
56     During operation, this field is reused to store preceding target bound."""
57     initial_upper_load: DiscreteLoad
58     """Larger of the two loads distinguished at instance creation.
59     During operation, this field is reused to store preceding target bound."""
60     database: MeasurementDatabase = field(repr=False)
61     """Reference to the common database used by all selectors."""
62     handler: LimitHandler = field(repr=False)
63     """Reference to the class used to avoid too narrow intervals."""
64     debug: Callable[[str], None] = field(repr=False)
65     """Injectable function for debug logging."""
66     # Primary above, derived below.
67     current_target: TargetSpec = secondary_field()
68     """The target the selector is focusing on currently."""
69     target_stack: List[TargetSpec] = secondary_field()
70     """Stack of targets. When current target is achieved, next is popped."""
71     strategies: Tuple[StrategyBase] = secondary_field()
72     """Instances implementing particular selection strategies."""
73     current_strategy: Optional[StrategyBase] = secondary_field()
74     """Reference to strategy used for last nomination, needed for won()."""
75     # Cache.
76     bounds: RelevantBounds = secondary_field()
77     """New relevant bounds for this round of candidate selection."""
78
79     def __post_init__(self) -> None:
80         """Initialize derived values."""
81         self.target_stack = [self.final_target]
82         while preceding_target := self.target_stack[-1].preceding:
83             self.target_stack.append(preceding_target)
84         self.current_target = self.target_stack.pop()
85         self._recreate_strategies()
86
87     def _recreate_strategies(self) -> None:
88         """Recreate strategies after current target has changed.
89
90         Width expander is recreated as target width is now smaller.
91         For convenience, strategies get injectable debug
92         which prints also the current target.
93         """
94         expander = TargetedExpander(
95             target=self.current_target,
96             global_width=self.global_width,
97             initial_lower_load=self.initial_lower_load,
98             initial_upper_load=self.initial_upper_load,
99             handler=self.handler,
100             debug=self.debug,
101         )
102
103         def wrapped_debug(text: str) -> None:
104             """Call self debug with current target info prepended.
105
106             :param text: Message to log at debug level.
107             :type text: str
108             """
109             self.debug(f"Target {self.current_target}: {text}")
110
111         self.strategies = tuple(
112             cls(
113                 target=self.current_target,
114                 expander=expander,
115                 initial_lower_load=self.initial_lower_load,
116                 initial_upper_load=self.initial_upper_load,
117                 handler=self.handler,
118                 debug=wrapped_debug,
119             )
120             for cls in STRATEGY_CLASSES
121         )
122         self.current_strategy = None
123         self.debug(f"Created strategies for: {self.current_target}")
124
125     def _update_bounds(self) -> None:
126         """Before each iteration, call this to update bounds cache."""
127         self.bounds = self.database.get_relevant_bounds(self.current_target)
128
129     def nominate(
130         self,
131     ) -> Tuple[Optional[DiscreteLoad], float, Optional[DiscreteWidth]]:
132         """Find first strategy that wants to nominate, return trial inputs.
133
134         Returned load is None if no strategy wants to nominate.
135
136         Current target is shifted when (now preceding) target is reached.
137         As each strategy never becomes done before at least one
138         bound relevant to the current target becomes available,
139         it is never needed to revert to the preceding target after the shift.
140
141         As the initial trials had inputs relevant to all initial targets,
142         the only way for this not to nominate a load
143         is when the final target is reached (including hitting min or max load).
144         The case of hitting min load raises, so search fails early.
145
146         :returns: Nominated load, duration, and global width to set if winning.
147         :rtype: Tuple[Optional[DiscreteLoad], float, Optional[DiscreteWidth]]
148         :raises RuntimeError: If internal inconsistency is detected,
149             or if min load becomes an upper bound.
150         """
151         self._update_bounds()
152         self.current_strategy = None
153         while 1:
154             for strategy in self.strategies:
155                 load, width = strategy.nominate(self.bounds)
156                 if load:
157                     self.current_strategy = strategy
158                     return load, self.current_target.trial_duration, width
159             if not self.bounds.clo and not self.bounds.chi:
160                 raise RuntimeError("Internal error: no clo nor chi.")
161             if not self.target_stack:
162                 if not self.bounds.clo and self.current_target.fail_fast:
163                     raise RuntimeError(f"No lower bound: {self.bounds.chi!r}")
164                 self.debug(f"Goal {self.current_target} reached: {self.bounds}")
165                 return None, self.current_target.trial_duration, None
166             # Everything is ready for next target in the chain.
167             self.current_target = self.target_stack.pop()
168             # Debug logs look better if we forget bounds are TrimmedStat.
169             # Abuse rounding (if not None) to convert to pure DiscreteLoad.
170             clo, chi = self.bounds.clo, self.bounds.chi
171             self.initial_lower_load = clo.rounded_down() if clo else clo
172             self.initial_upper_load = chi.rounded_down() if chi else chi
173             self._update_bounds()
174             self._recreate_strategies()
175
176     def won(self, load: DiscreteLoad) -> None:
177         """Update any private info when candidate became a winner.
178
179         :param load: The load previously nominated by current strategy.
180         :type load: DiscreteLoad
181         """
182         self._update_bounds()
183         self.current_strategy.won(bounds=self.bounds, load=load)