X-Git-Url: https://gerrit.fd.io/r/gitweb?a=blobdiff_plain;f=resources%2Flibraries%2Fpython%2FMLRsearch%2Ftarget_scaling.py;fp=resources%2Flibraries%2Fpython%2FMLRsearch%2Ftarget_scaling.py;h=25114c311c993abe4208d313c6e2e95f64d0f5a4;hb=e5dbe10d9599b9a53fa07e6fadfaf427ba6d69e3;hp=0000000000000000000000000000000000000000;hpb=c6dfb6c09c5dafd1d522f96b4b86c5ec5efc1c83;p=csit.git diff --git a/resources/libraries/python/MLRsearch/target_scaling.py b/resources/libraries/python/MLRsearch/target_scaling.py new file mode 100644 index 0000000000..25114c311c --- /dev/null +++ b/resources/libraries/python/MLRsearch/target_scaling.py @@ -0,0 +1,103 @@ +# Copyright (c) 2023 Cisco and/or its affiliates. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module defining TargetScaling class.""" + +from dataclasses import dataclass +from typing import Dict, Tuple + +from .dataclass import secondary_field +from .discrete_width import DiscreteWidth +from .load_rounding import LoadRounding +from .search_goal import SearchGoal +from .search_goal_tuple import SearchGoalTuple +from .target_spec import TargetSpec + + +@dataclass +class TargetScaling: + """Encapsulate targets derived from goals. + + No default values for primaries, contructor call has to specify everything. + """ + + goals: SearchGoalTuple + """Set of goals to generate targets for.""" + rounding: LoadRounding + """Rounding instance to use (targets have discrete width).""" + # Derived quantities. + targets: Tuple[TargetSpec] = secondary_field() + """The generated targets, linked into chains.""" + goal_to_final_target: Dict[SearchGoal, TargetSpec] = secondary_field() + """Mapping from a goal to its corresponding final target.""" + + def __post_init__(self) -> None: + """For each goal create final, and non-final targets and link them.""" + linked_targets = [] + self.goal_to_final_target = {} + for goal in self.goals: + standalone_targets = [] + # Final target. + width = DiscreteWidth( + rounding=self.rounding, + float_width=goal.relative_width, + ).rounded_down() + duration_sum = goal.duration_sum + target = TargetSpec( + loss_ratio=goal.loss_ratio, + exceed_ratio=goal.exceed_ratio, + discrete_width=width, + trial_duration=goal.final_trial_duration, + duration_sum=duration_sum, + expansion_coefficient=goal.expansion_coefficient, + fail_fast=goal.fail_fast, + preceding=None, + ) + standalone_targets.append(target) + # Non-final targets. + preceding_targets = goal.preceding_targets + multiplier = ( + pow( + goal.initial_trial_duration / duration_sum, + 1.0 / preceding_targets, + ) + if preceding_targets + else 1.0 + ) + for count in range(preceding_targets): + preceding_sum = duration_sum * pow(multiplier, count + 1) + if count + 1 >= preceding_targets: + preceding_sum = goal.initial_trial_duration + trial_duration = min(goal.final_trial_duration, preceding_sum) + width *= 2 + target = TargetSpec( + loss_ratio=goal.loss_ratio, + exceed_ratio=goal.exceed_ratio, + discrete_width=width, + trial_duration=trial_duration, + duration_sum=preceding_sum, + expansion_coefficient=goal.expansion_coefficient, + fail_fast=False, + preceding=None, + ) + standalone_targets.append(target) + # Link preceding targets. + preceding_target = None + for target in reversed(standalone_targets): + linked_target = target.with_preceding(preceding_target) + linked_targets.append(linked_target) + preceding_target = linked_target + # Associate final target to the goal. + self.goal_to_final_target[goal] = linked_targets[-1] + # Store all targets as a tuple. + self.targets = tuple(linked_targets)