feat(jumpavg): support small values via unit param
[csit.git] / resources / libraries / python / jumpavg / classify.py
index 87d2502..cc3cdcc 100644 (file)
@@ -1,4 +1,4 @@
-# Copyright (c) 2022 Cisco and/or its affiliates.
+# Copyright (c) 2023 Cisco and/or its affiliates.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at:
 
 """Module holding the classify function
 
-Classification os one of primary purposes of this package.
+Classification is one of primary purposes of this package.
 
 Minimal message length principle is used
 for grouping results into the list of groups,
 assuming each group is a population of different Gaussian distribution.
 """
 
-import typing
+from typing import Iterable, Optional, Union
 
-from .AvgStdevStats import AvgStdevStats
-from .BitCountingGroupList import BitCountingGroupList
+from .avg_stdev_stats import AvgStdevStats
+from .bit_counting_group_list import BitCountingGroupList
 
 
 def classify(
-    values: typing.Iterable[typing.Union[float, typing.Iterable[float]]]
+    values: Iterable[Union[float, Iterable[float]]],
+    unit: Optional[float] = None,
+    sbps: Optional[float] = None,
 ) -> BitCountingGroupList:
     """Return the values in groups of optimal bit count.
 
@@ -38,12 +40,27 @@ def classify(
     Internally, such sequence is replaced by AvgStdevStats
     after maximal value is found.
 
+    If the values are smaller than expected (below one unit),
+    the underlying assumption break down and the classification is wrong.
+    Use the "unit" parameter to hint at what the input resolution is.
+
+    If the correct value of unit is not known beforehand,
+    the argument "sbps" (Significant Bits Per Sample) can be used
+    to set unit such that maximal sample value is this many ones in binary.
+    If neither "unit" nor "sbps" are given, "sbps" of 12 is used by default.
+
     :param values: Sequence of runs to classify.
+    :param unit: Typical resolution of the values.
+        Zero and None means no unit given.
+    :param sbps: Significant Bits Per Sample. None on zero means 12.
+        If units is not set, this is used to compute unit from max sample value.
     :type values: Iterable[Union[float, Iterable[float]]]
+    :type unit: Optional[float]
+    :type sbps: Optional[float]
     :returns: Classified group list.
     :rtype: BitCountingGroupList
     """
-    processed_values = list()
+    processed_values = []
     max_value = 0.0
     for value in values:
         if isinstance(value, (float, int)):
@@ -55,9 +72,14 @@ def classify(
                 if subvalue > max_value:
                     max_value = subvalue
             processed_values.append(AvgStdevStats.for_runs(value))
+    if not unit:
+        if not sbps:
+            sbps = 12.0
+        max_in_units = pow(2.0, sbps + 1.0) - 1.0
+        unit = max_value / max_in_units
     # Glist means group list (BitCountingGroupList).
-    open_glists = list()
-    record_glist = BitCountingGroupList(max_value=max_value)
+    open_glists = []
+    record_glist = BitCountingGroupList(max_value=max_value, unit=unit)
     for value in processed_values:
         new_open_glist = record_glist.copy_fast().append_group_of_runs([value])
         record_glist = new_open_glist
@@ -68,9 +90,7 @@ def classify(
         open_glists.append(new_open_glist)
     previous_average = record_glist[0].stats.avg
     for group in record_glist:
-        if group.stats.avg == previous_average:
-            group.comment = "normal"
-        elif group.stats.avg < previous_average:
+        if group.stats.avg < previous_average:
             group.comment = "regression"
         elif group.stats.avg > previous_average:
             group.comment = "progression"