1 # Copyright (c) 2023 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
6 # http://www.apache.org/licenses/LICENSE-2.0
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
14 """Module defining TargetScaling class."""
16 from dataclasses import dataclass
17 from typing import Dict, Tuple
19 from .dataclass import secondary_field
20 from .discrete_width import DiscreteWidth
21 from .load_rounding import LoadRounding
22 from .search_goal import SearchGoal
23 from .search_goal_tuple import SearchGoalTuple
24 from .target_spec import TargetSpec
29 """Encapsulate targets derived from goals.
31 No default values for primaries, constructor call has to specify everything.
34 goals: SearchGoalTuple
35 """Set of goals to generate targets for."""
36 rounding: LoadRounding
37 """Rounding instance to use (targets have discrete width)."""
39 targets: Tuple[TargetSpec] = secondary_field()
40 """The generated targets, linked into chains."""
41 goal_to_final_target: Dict[SearchGoal, TargetSpec] = secondary_field()
42 """Mapping from a goal to its corresponding final target."""
44 def __post_init__(self) -> None:
45 """For each goal create final, and non-final targets and link them."""
# Start with an empty goal -> final target mapping; one entry is stored per goal below.
47 self.goal_to_final_target = {}
48 for goal in self.goals:
# Unlinked targets for this goal, kept in the order they are created.
49 standalone_targets = []
# The width is built from the shared rounding instance and the goal's relative width.
51 width = DiscreteWidth(
52 rounding=self.rounding,
53 float_width=goal.relative_width,
55 duration_sum = goal.duration_sum
# These keyword values come straight from the goal; the trial duration used here
# is the goal's final trial duration and the duration sum is the goal's own.
57 loss_ratio=goal.loss_ratio,
58 exceed_ratio=goal.exceed_ratio,
60 trial_duration=goal.final_trial_duration,
61 duration_sum=duration_sum,
62 expansion_coefficient=goal.expansion_coefficient,
63 fail_fast=goal.fail_fast,
# First target appended for this goal.
66 standalone_targets.append(target)
68 preceding_targets = goal.preceding_targets
# NOTE(review): presumably these two expressions are the base and exponent
# used to compute `multiplier` for the loop below - confirm.
71 goal.initial_trial_duration / duration_sum,
72 1.0 / preceding_targets,
# One additional target is created per iteration, goal.preceding_targets in total.
77 for count in range(preceding_targets):
# Duration sum is scaled by multiplier raised to (count + 1) ...
78 preceding_sum = duration_sum * pow(multiplier, count + 1)
# ... except in the last iteration, where the goal's initial trial duration
# is used directly instead of the computed power.
79 if count + 1 >= preceding_targets:
80 preceding_sum = goal.initial_trial_duration
# Trial duration is capped by the goal's final trial duration.
81 trial_duration = min(goal.final_trial_duration, preceding_sum)
84 loss_ratio=goal.loss_ratio,
85 exceed_ratio=goal.exceed_ratio,
87 trial_duration=trial_duration,
88 duration_sum=preceding_sum,
89 expansion_coefficient=goal.expansion_coefficient,
93 standalone_targets.append(target)
94 # Link preceding targets.
# Targets are visited in reverse order of creation; each one is linked to the
# previously linked target (None for the first one visited).
95 preceding_target = None
96 for target in reversed(standalone_targets):
97 linked_target = target.with_preceding(preceding_target)
98 linked_targets.append(linked_target)
99 preceding_target = linked_target
100 # Associate final target to the goal.
# Because of the reversed iteration above, the last linked target is the one
# that was appended to standalone_targets first.
101 self.goal_to_final_target[goal] = linked_targets[-1]
102 # Store all targets as a tuple.
103 self.targets = tuple(linked_targets)