1 # Copyright (c) 2023 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
6 # http://www.apache.org/licenses/LICENSE-2.0
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
14 """Module defining Selector class."""
17 from dataclasses import dataclass, field
18 from typing import Callable, List, Optional, Tuple
20 from .dataclass import secondary_field
21 from .discrete_load import DiscreteLoad
22 from .discrete_width import DiscreteWidth
23 from .expander import TargetedExpander
24 from .global_width import GlobalWidth
25 from .limit_handler import LimitHandler
26 from .measurement_database import MeasurementDatabase
27 from .relevant_bounds import RelevantBounds
28 from .target_spec import TargetSpec
29 from .strategy import StrategyBase, STRATEGY_CLASSES
# NOTE(review): the class statement owning this docstring and these
# fields (presumably ``@dataclass`` followed by ``class Selector:``, given
# the module docstring, the dataclasses import, the use of field() and
# __post_init__) appears to be missing from this copy of the file.
"""A selector is an abstraction that focuses on only one of the search goals.

While lower-level logic is hidden in strategy classes,
the code in this class is responsible for initializing strategies
and shifting targets towards the final target.

While the public methods have the same names and meanings as the ones
in strategy classes, their signatures are different.
Selector adds the current target trial duration to the output of nominate(),
and adds the current bounds to the input of won().

The nominate method does not return a complete Candidate instance,
as we need to avoid circular dependencies
(candidate will refer to selector).
"""

final_target: TargetSpec
"""The target this selector is trying to ultimately achieve."""
global_width: GlobalWidth
"""Reference to the global width tracking instance."""
# Optional: on a target shift, nominate() overwrites this with the
# rounded lower bound found for the preceding target, or with the
# missing (presumably None) bound itself when there was none.
initial_lower_load: Optional[DiscreteLoad]
"""Smaller of the two loads distinguished at instance creation.
During operation, this field is reused to store preceding target bound."""
# Optional for the same reason, using the preceding target's upper bound.
initial_upper_load: Optional[DiscreteLoad]
"""Larger of the two loads distinguished at instance creation.
During operation, this field is reused to store preceding target bound."""
database: MeasurementDatabase = field(repr=False)
"""Reference to the common database used by all selectors."""
handler: LimitHandler = field(repr=False)
"""Reference to the class used to avoid too narrow intervals."""
debug: Callable[[str], None] = field(repr=False)
"""Injectable function for debug logging."""
# Primary above, derived below.
# The derived fields are assigned by __post_init__, _recreate_strategies
# and _update_bounds; presumably secondary_field() keeps them out of the
# generated __init__ (TODO confirm in .dataclass).
current_target: TargetSpec = secondary_field()
"""The target the selector is focusing on currently."""
target_stack: List[TargetSpec] = secondary_field()
"""Stack of targets. When current target is achieved, next is popped."""
# Variable length: one strategy instance per entry in STRATEGY_CLASSES.
strategies: Tuple[StrategyBase, ...] = secondary_field()
"""Instances implementing particular selection strategies."""
current_strategy: Optional[StrategyBase] = secondary_field()
"""Reference to strategy used for last nomination, needed for won()."""

bounds: RelevantBounds = secondary_field()
"""New relevant bounds for this round of candidate selection."""
def __post_init__(self) -> None:
    """Initialize derived values: the target stack and first strategies.

    Starting at the final target, the ``preceding`` links are followed
    and every target reached is pushed, so the earliest target ends up
    on top of the stack. That one is popped to become the current target;
    the rest stay on the stack to be popped as targets are reached.
    """
    chain = [self.final_target]
    earlier = self.final_target.preceding
    while earlier:
        chain.append(earlier)
        earlier = earlier.preceding
    self.target_stack = chain
    self.current_target = chain.pop()
    self._recreate_strategies()
def _recreate_strategies(self) -> None:
    """Recreate strategies after current target has changed.

    Width expander is recreated as target width is now smaller.
    For convenience, strategies get injectable debug
    which prints also the current target.

    Also resets current_strategy to None and logs the new current target.
    """
    # NOTE(review): this copy of the file is missing lines inside this
    # method, so the statements below do not parse as shown. The
    # TargetedExpander(...) call has no closing parenthesis and may have
    # lost further keyword arguments. The generator passed to tuple()
    # shows keyword arguments without the call they belong to
    # (presumably a constructor call on ``cls``). In the visible lines
    # neither ``expander`` nor ``wrapped_debug`` is passed anywhere,
    # although the docstring says strategies get the wrapped debug.
    # Restore the method from the full file before changing code here.
    expander = TargetedExpander(
        target=self.current_target,
        global_width=self.global_width,
        initial_lower_load=self.initial_lower_load,
        initial_upper_load=self.initial_upper_load,

    def wrapped_debug(text: str) -> None:
        """Call self debug with current target info prepended.

        :param text: Message to log at debug level.
        :type text: str
        """
        # The f-string is evaluated per call, so the prefix shows
        # whatever self.current_target is at that moment.
        self.debug(f"Target {self.current_target}: {text}")

    self.strategies = tuple(
        target=self.current_target,
        initial_lower_load=self.initial_lower_load,
        initial_upper_load=self.initial_upper_load,
        handler=self.handler,
        for cls in STRATEGY_CLASSES
    # Fresh strategies have not nominated anything yet.
    self.current_strategy = None
    self.debug(f"Created strategies for: {self.current_target}")
def _update_bounds(self) -> None:
    """Refresh the cached bounds; call this before each iteration.

    The bounds are queried from the shared measurement database
    for the target this selector currently focuses on.
    """
    focused = self.current_target
    self.bounds = self.database.get_relevant_bounds(focused)
) -> Tuple[Optional[DiscreteLoad], float, Optional[DiscreteWidth]]:
    """Find first strategy that wants to nominate, return trial inputs.

    Returned load is None if no strategy wants to nominate.

    Current target is shifted when (now preceding) target is reached.
    As each strategy never becomes done before at least one
    bound relevant to the current target becomes available,
    it is never needed to revert to the preceding target after the shift.

    As the initial trials had inputs relevant to all initial targets,
    the only way for this not to nominate a load
    is when the final target is reached (including hitting min or max load).
    When the final target ends up with no lower bound (min load became
    an upper bound) and its fail_fast flag is set, this raises,
    so the search fails early.

    :returns: Nominated load, duration, and global width to set if winning.
    :rtype: Tuple[Optional[DiscreteLoad], float, Optional[DiscreteWidth]]
    :raises RuntimeError: If internal inconsistency is detected
        (neither a lower nor an upper bound is known),
        or if min load becomes an upper bound and fail_fast is set.
    """
    # NOTE(review): this copy of the file lacks the start of this method's
    # signature (presumably ``def nominate(self,``, as the class docstring
    # refers to nominate()) and some control-flow lines flagged below.
    self._update_bounds()
    self.current_strategy = None
    # NOTE(review): the target-shift code at the end falls through without
    # returning, which only works inside an enclosing retry loop around the
    # strategy scan below; that loop header is not present in this copy.
    for strategy in self.strategies:
        load, width = strategy.nominate(self.bounds)
        # NOTE(review): a guard on ``load`` appears to be missing here; as
        # shown, the first strategy would always win even when it declines,
        # contradicting "Find first strategy that wants to nominate".
        self.current_strategy = strategy
        return load, self.current_target.trial_duration, width
    if not self.bounds.clo and not self.bounds.chi:
        raise RuntimeError("Internal error: no clo nor chi.")
    if not self.target_stack:
        # Empty stack: the final target is the current one, nothing follows.
        if not self.bounds.clo and self.current_target.fail_fast:
            raise RuntimeError(f"No lower bound: {self.bounds.chi!r}")
        self.debug(f"Goal {self.current_target} reached: {self.bounds}")
        return None, self.current_target.trial_duration, None
    # Everything is ready for next target in the chain.
    self.current_target = self.target_stack.pop()
    # Debug logs look better if we forget bounds are TrimmedStat.
    # Abuse rounding (if not None) to convert to pure DiscreteLoad.
    clo, chi = self.bounds.clo, self.bounds.chi
    # Bounds of the now-preceding target become the initial loads
    # handed to the strategies recreated for the new current target.
    self.initial_lower_load = clo.rounded_down() if clo else clo
    self.initial_upper_load = chi.rounded_down() if chi else chi
    self._update_bounds()
    self._recreate_strategies()
def won(self, load: DiscreteLoad) -> None:
    """Update any private info when candidate became a winner.

    The cached bounds are refreshed from the database first,
    then the strategy that made the last nomination is notified.

    :param load: The load previously nominated by current strategy.
    :type load: DiscreteLoad
    :raises RuntimeError: If no strategy is recorded as the current one,
        e.g. when the last nominate() call nominated nothing.
    """
    strategy = self.current_strategy
    if strategy is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise RuntimeError(
            "Internal error: won() called without a nominating strategy."
        )
    self._update_bounds()
    strategy.won(bounds=self.bounds, load=load)