1 # Copyright (c) 2023 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
6 # http://www.apache.org/licenses/LICENSE-2.0
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
14 """Module defining DiscreteWidth class."""
16 from __future__ import annotations
18 from dataclasses import dataclass, field
20 from .load_rounding import LoadRounding
23 # TODO: Make properly frozen.
@dataclass(order=True)
class DiscreteWidth:
    """Structure to store float width together with its rounded integer form.

    The width does not have to be positive, i.e. the computed integer width
    does not have to be larger than zero.

    LoadRounding instance is needed to enable conversion between two forms.

    Conversion and arithmetic methods are added for convenience.
    Division and non-integer multiplication are intentionally not supported,
    as MLRsearch should not seek unround widths when round ones are available.

    The instance is effectively immutable (the methods below return new
    instances instead of mutating), but the dataclass is not frozen.
    The referenced rounding instance is implemented as mutable
    (although the mutations are not visible).
    Equality and ordering compare only the float form (rounding and int
    width are excluded from comparison), and __hash__ is defined explicitly
    on the float form to match, so instances can be used e.g. as dict keys.
    """

    # For most debugs, rounding in repr just takes space.
    rounding: LoadRounding = field(repr=False, compare=False)
    """Rounding instance to use for conversion."""
    float_width: float = None
    """Relative width of float intended load.
    This is treated as a constructor argument, and does not need to match
    the int width. Int width is computed to be no wider than this."""
    # Excluded from comparisons; equality and ordering use float_width only.
    int_width: int = field(compare=False, default=None)
    """Integer form, difference of integer loads.
    This is the primary quantity used by most computations."""

    def __post_init__(self) -> None:
        """Ensure types, compute missing information.

        At this point, it is allowed for float width to be slightly larger
        than the implied int width.

        If both forms are specified, the float form is taken as primary
        (thus the integer form is recomputed to match).

        :raises RuntimeError: If both init arguments are None.
        """
        if self.float_width is None and self.int_width is None:
            raise RuntimeError("Float or int value is needed.")
        # Integer load zero is the base both width formulas are relative to.
        min_load = self.rounding.int2float(0)
        if self.float_width is None:
            # Only the int form was given; derive the float form from it.
            self.int_width = int(self.int_width)
            increased_load = self.rounding.int2float(self.int_width)
            self.float_width = (increased_load - min_load) / increased_load
            return
        # The float form is primary; any int form passed in is recomputed.
        self.float_width = float(self.float_width)
        # NOTE(review): float_width == 1.0 would divide by zero here;
        # presumably callers never pass it - confirm.
        increased_load = min_load / (1.0 - self.float_width)
        int_load = self.rounding.float2int(increased_load)
        verify_load = self.rounding.int2float(int_load)
        if verify_load > increased_load:
            # float2int landed above the requested load; step one integer
            # down so the int width is no wider than the float width.
            int_load -= 1
        self.int_width = int_load

    def __str__(self) -> str:
        """Convert into a short human-readable string.

        :returns: The short string.
        :rtype: str
        """
        return f"int_width={int(self)}"

    def __int__(self) -> int:
        """Return the integer form.

        :returns: The int field value.
        :rtype: int
        """
        return self.int_width

    def __float__(self) -> float:
        """Return the float form.

        :returns: The float field value.
        :rtype: float
        """
        return self.float_width

    def __hash__(self) -> int:
        """Return a hash based on the float value.

        With this, the instance can be used as if it was immutable and hashable,
        e.g. it can be a key in a dict. The float form is also the only field
        taking part in equality, so equal instances hash equally.

        :returns: Hash value for this instance.
        :rtype: int
        """
        return hash(float(self))

    def rounded_down(self) -> DiscreteWidth:
        """Create and return new instance with float form matching int.

        :returns: New instance with same int form and float form rounded down.
        :rtype: DiscreteWidth
        """
        return DiscreteWidth(rounding=self.rounding, int_width=int(self))

    def __add__(self, width: DiscreteWidth) -> DiscreteWidth:
        """Return newly constructed instance with int widths added.

        Rounding instance (reference) is copied from self.

        Argument type is checked, to avoid caller adding something unsupported.

        :param width: Value to add to int width.
        :type width: DiscreteWidth
        :returns: New instance.
        :rtype: DiscreteWidth
        :raises RuntimeError: When argument has unexpected type.
        """
        if not isinstance(width, DiscreteWidth):
            raise RuntimeError(f"Not width: {width!r}")
        return DiscreteWidth(
            rounding=self.rounding,
            int_width=self.int_width + int(width),
        )

    def __sub__(self, width: DiscreteWidth) -> DiscreteWidth:
        """Return newly constructed instance with int widths subtracted.

        Rounding instance (reference) is copied from self.

        Argument type is checked, to avoid caller subtracting something
        unsupported. This method does not reject zero or negative results.

        :param width: Value to subtract from int width.
        :type width: DiscreteWidth
        :returns: New instance.
        :rtype: DiscreteWidth
        :raises RuntimeError: When argument has unexpected type.
        """
        if not isinstance(width, DiscreteWidth):
            raise RuntimeError(f"Not width: {type(width)}")
        return DiscreteWidth(
            rounding=self.rounding,
            int_width=self.int_width - int(width),
        )

    def __mul__(self, coefficient: int) -> DiscreteWidth:
        """Construct new instance with int value multiplied.

        Rounding instance (reference) is copied from self.

        :param coefficient: Constant to multiply int width with.
        :type coefficient: int
        :returns: New instance with multiplied int width.
        :rtype: DiscreteWidth
        :raises RuntimeError: If coefficient is not an int or is less than one.
        """
        if not isinstance(coefficient, int):
            raise RuntimeError(f"Coefficient not int: {coefficient!r}")
        if coefficient < 1:
            raise RuntimeError(f"Coefficient not positive: {coefficient!r}")
        return DiscreteWidth(
            rounding=self.rounding,
            int_width=self.int_width * coefficient,
        )

    def half_rounded_down(self) -> DiscreteWidth:
        """Construct new instance of half the integer width.

        If the current integer width is odd, round the half width down
        (floor division). The result may be zero, e.g. for int width one.

        :returns: New instance with half int width.
        :rtype: DiscreteWidth
        """
        return DiscreteWidth(
            rounding=self.rounding,
            int_width=self.int_width // 2,
        )