style(PLRsearch): format according to black
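
Most of the changes below are mechanical rewrapping by black rather than logic changes, as the hunks show: string literals lose the redundant u prefix, redundant parentheses around assignment right-hand sides are dropped, and a signature or call that does not fit on one line is exploded to one argument per line with a trailing comma. A minimal before/after sketch of the same transformations (the function names and body are illustrative only, loosely modeled on measure_and_compute, and assume the packed argument list would still exceed the configured line length):

# Before: hanging-indent wrapping, u"" prefixes, parenthesized right-hand side.
def measure_example_before(
        trial_duration, transmit_rate, trial_result_list,
        min_rate, max_rate, focus_trackers=(None, None), max_samples=None):
    print(u"transmit_rate", transmit_rate)
    next_load = (transmit_rate / (
        1.0 - 0.005))
    return next_load


# After black: one argument per line with a trailing comma, u prefix dropped,
# redundant outer parentheses on the assignment removed.
def measure_example_after(
    trial_duration,
    transmit_rate,
    trial_result_list,
    min_rate,
    max_rate,
    focus_trackers=(None, None),
    max_samples=None,
):
    print("transmit_rate", transmit_rate)
    next_load = transmit_rate / (1.0 - 0.005)
    return next_load
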
[csit.git] / resources / libraries / python / PLRsearch / PLRsearch.py
index 7599a9e..e0eea23 100644 (file)
@@ -1,4 +1,4 @@
-# Copyright (c) 2023 Cisco and/or its affiliates.
+# Copyright (c) 2024 Cisco and/or its affiliates.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at:
@@ -53,8 +53,14 @@ class PLRsearch:
     log_xerfcx_10 = math.log(xerfcx_limit - math.exp(10) * erfcx(math.exp(10)))
 
     def __init__(
-            self, measurer, trial_duration_per_trial, packet_loss_ratio_target,
-            trial_number_offset=0, timeout=7200.0, trace_enabled=False):
+        self,
+        measurer,
+        trial_duration_per_trial,
+        packet_loss_ratio_target,
+        trial_number_offset=0,
+        timeout=7200.0,
+        trace_enabled=False,
+    ):
         """Store rate measurer and additional parameters.
 
         The measurer must never report negative loss count.
@@ -186,8 +192,12 @@ class PLRsearch:
             trial_number += 1
             logging.info(f"Trial {trial_number!r}")
             results = self.measure_and_compute(
-                self.trial_duration_per_trial * trial_number, transmit_rate,
-                trial_result_list, min_rate, max_rate, focus_trackers
+                self.trial_duration_per_trial * trial_number,
+                transmit_rate,
+                trial_result_list,
+                min_rate,
+                max_rate,
+                focus_trackers,
             )
             measurement, average, stdev, avg1, avg2, focus_trackers = results
             zeros += 1
@@ -205,15 +215,16 @@ class PLRsearch:
             if (trial_number - self.trial_number_offset) <= 1:
                 next_load = max_rate
             elif (trial_number - self.trial_number_offset) <= 3:
-                next_load = (measurement.relative_forwarding_rate / (
-                    1.0 - self.packet_loss_ratio_target))
+                next_load = measurement.relative_forwarding_rate / (
+                    1.0 - self.packet_loss_ratio_target
+                )
             else:
                 next_load = (avg1 + avg2) / 2.0
                 if zeros > 0:
                     if lossy_loads[0] > next_load:
                         diminisher = math.pow(2.0, 1 - zeros)
                         next_load = lossy_loads[0] + diminisher * next_load
-                        next_load /= (1.0 + diminisher)
+                        next_load /= 1.0 + diminisher
                     # On zero measurement, we need to drain obsoleted low losses
                     # even if we did not use them to increase next_load,
                     # in order to get to usable losses at higher loads.
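
One detail worth noting in the hunk above: dropping the parentheses in "next_load /= (1.0 + diminisher)" is purely cosmetic, because Python applies an augmented assignment operator to the whole right-hand side expression. A standalone check (the numbers are made up, not taken from the search loop):

# x /= a + b  is equivalent to  x = x / (a + b), never (x / a) + b.
diminisher = 0.5

with_parens = 10.0
with_parens /= (1.0 + diminisher)

without_parens = 10.0
without_parens /= 1.0 + diminisher

assert with_parens == without_parens == 10.0 / 1.5
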
@@ -263,22 +274,22 @@ class PLRsearch:
         # TODO: chi is from https://en.wikipedia.org/wiki/Nondimensionalization
         chi = (load - mrr) / spread
         chi0 = -mrr / spread
-        trace(u"stretch: load", load)
-        trace(u"mrr", mrr)
-        trace(u"spread", spread)
-        trace(u"chi", chi)
-        trace(u"chi0", chi0)
+        trace("stretch: load", load)
+        trace("mrr", mrr)
+        trace("spread", spread)
+        trace("chi", chi)
+        trace("chi0", chi0)
         if chi > 0:
             log_lps = math.log(
                 load - mrr + (log_plus(0, -chi) - log_plus(0, chi0)) * spread
             )
-            trace(u"big loss direct log_lps", log_lps)
+            trace("big loss direct log_lps", log_lps)
         else:
             two_positive = log_plus(chi, 2 * chi0 - log_2)
             two_negative = log_plus(chi0, 2 * chi - log_2)
             if two_positive <= two_negative:
                 log_lps = log_minus(chi, chi0) + log_spread
-                trace(u"small loss crude log_lps", log_lps)
+                trace("small loss crude log_lps", log_lps)
                 return log_lps
             two = log_minus(two_positive, two_negative)
             three_positive = log_plus(two_positive, 3 * chi - log_3)
@@ -286,11 +297,11 @@ class PLRsearch:
             three = log_minus(three_positive, three_negative)
             if two == three:
                 log_lps = two + log_spread
-                trace(u"small loss approx log_lps", log_lps)
+                trace("small loss approx log_lps", log_lps)
             else:
                 log_lps = math.log(log_plus(0, chi) - log_plus(0, chi0))
                 log_lps += log_spread
-                trace(u"small loss direct log_lps", log_lps)
+                trace("small loss direct log_lps", log_lps)
         return log_lps
 
     @staticmethod
@@ -329,26 +340,26 @@ class PLRsearch:
         # TODO: The stretch sign is just to have less minuses. Worth changing?
         chi = (mrr - load) / spread
         chi0 = mrr / spread
-        trace(u"Erf: load", load)
-        trace(u"mrr", mrr)
-        trace(u"spread", spread)
-        trace(u"chi", chi)
-        trace(u"chi0", chi0)
+        trace("Erf: load", load)
+        trace("mrr", mrr)
+        trace("spread", spread)
+        trace("chi", chi)
+        trace("chi0", chi0)
         if chi >= -1.0:
-            trace(u"positive, b roughly bigger than m", None)
+            trace("positive, b roughly bigger than m", None)
             if chi > math.exp(10):
                 first = PLRsearch.log_xerfcx_10 + 2 * (math.log(chi) - 10)
-                trace(u"approximated first", first)
+                trace("approximated first", first)
             else:
                 first = math.log(PLRsearch.xerfcx_limit - chi * erfcx(chi))
-                trace(u"exact first", first)
+                trace("exact first", first)
             first -= chi * chi
             second = math.log(PLRsearch.xerfcx_limit - chi * erfcx(chi0))
             second -= chi0 * chi0
             intermediate = log_minus(first, second)
-            trace(u"first", first)
+            trace("first", first)
         else:
-            trace(u"negative, b roughly smaller than m", None)
+            trace("negative, b roughly smaller than m", None)
             exp_first = PLRsearch.xerfcx_limit + chi * erfcx(-chi)
             exp_first *= math.exp(-chi * chi)
             exp_first -= 2 * chi
@@ -359,17 +370,17 @@ class PLRsearch:
             second = math.log(PLRsearch.xerfcx_limit - chi * erfcx(chi0))
             second -= chi0 * chi0
             intermediate = math.log(exp_first - math.exp(second))
-            trace(u"exp_first", exp_first)
-        trace(u"second", second)
-        trace(u"intermediate", intermediate)
+            trace("exp_first", exp_first)
+        trace("second", second)
+        trace("intermediate", intermediate)
         result = intermediate + math.log(spread) - math.log(erfc(-chi0))
-        trace(u"result", result)
+        trace("result", result)
         return result
 
     @staticmethod
     def find_critical_rate(
-            trace, lfit_func, min_rate, max_rate, loss_ratio_target,
-            mrr, spread):
+        trace, lfit_func, min_rate, max_rate, loss_ratio_target, mrr, spread
+    ):
         """Given ratio target and parameters, return the achieving offered load.
 
         This is basically an inverse function to lfit_func
@@ -411,12 +422,12 @@ class PLRsearch:
             loss_rate = math.exp(lfit_func(trace, rate, mrr, spread))
             loss_ratio = loss_rate / rate
             if loss_ratio > loss_ratio_target:
-                trace(u"halving down", rate)
+                trace("halving down", rate)
                 rate_hi = rate
             elif loss_ratio < loss_ratio_target:
-                trace(u"halving up", rate)
+                trace("halving up", rate)
                 rate_lo = rate
-        trace(u"found", rate)
+        trace("found", rate)
         return rate
 
     @staticmethod
@@ -457,12 +468,12 @@ class PLRsearch:
         :rtype: float
         """
         log_likelihood = 0.0
-        trace(u"log_weight for mrr", mrr)
-        trace(u"spread", spread)
+        trace("log_weight for mrr", mrr)
+        trace("spread", spread)
         for result in trial_result_list:
-            trace(u"for tr", result.intended_load)
-            trace(u"lc", result.loss_count)
-            trace(u"d", result.intended_duration)
+            trace("for tr", result.intended_load)
+            trace("lc", result.loss_count)
+            trace("d", result.intended_duration)
             # _rel_ values use units of intended_load (transactions per second).
             log_avg_rel_loss_per_second = lfit_func(
                 trace, result.intended_load, mrr, spread
@@ -477,13 +488,20 @@ class PLRsearch:
             log_trial_likelihood *= -result.loss_count
             log_trial_likelihood -= log_plus(0.0, +log_avg_abs_loss_per_trial)
             log_likelihood += log_trial_likelihood
-            trace(u"avg_loss_per_trial", math.exp(log_avg_abs_loss_per_trial))
-            trace(u"log_trial_likelihood", log_trial_likelihood)
+            trace("avg_loss_per_trial", math.exp(log_avg_abs_loss_per_trial))
+            trace("log_trial_likelihood", log_trial_likelihood)
         return log_likelihood
 
     def measure_and_compute(
-            self, trial_duration, transmit_rate, trial_result_list,
-            min_rate, max_rate, focus_trackers=(None, None), max_samples=None):
+        self,
+        trial_duration,
+        transmit_rate,
+        trial_result_list,
+        min_rate,
+        max_rate,
+        focus_trackers=(None, None),
+        max_samples=None,
+    ):
         """Perform both measurement and computation at once.
 
         High level steps: Prepare and launch computation worker processes,
@@ -572,7 +590,7 @@ class PLRsearch:
             # See https://stackoverflow.com/questions/15137292/large-objects-and-multiprocessing-pipes-and-send
             worker = multiprocessing.Process(
                 target=Integrator.try_estimate_nd,
-                args=(worker_pipe_end, 5.0, self.trace_enabled)
+                args=(worker_pipe_end, 5.0, self.trace_enabled),
             )
             worker.daemon = True
             worker.start()
@@ -616,8 +634,13 @@ class PLRsearch:
                 )
                 value = math.log(
                     self.find_critical_rate(
-                        trace, fitting_function, min_rate, max_rate,
-                        self.packet_loss_ratio_target, mrr, spread
+                        trace,
+                        fitting_function,
+                        min_rate,
+                        max_rate,
+                        self.packet_loss_ratio_target,
+                        mrr,
+                        spread,
                     )
                 )
                 return value, logweight
@@ -664,9 +687,13 @@ class PLRsearch:
                 raise RuntimeError(f"Worker {name} did not finish!")
             result_or_traceback = pipe.recv()
             try:
-                value_tracker, focus_tracker, debug_list, trace_list, sampls = (
-                    result_or_traceback
-                )
+                (
+                    value_tracker,
+                    focus_tracker,
+                    debug_list,
+                    trace_list,
+                    sampls,
+                ) = result_or_traceback
             except ValueError:
                 raise RuntimeError(
                     f"Worker {name} failed with the following traceback:\n"
@@ -682,8 +709,8 @@ class PLRsearch:
             )
             return _PartialResult(value_tracker, focus_tracker, sampls)
 
-        stretch_result = stop_computing(u"stretch", stretch_pipe)
-        erf_result = stop_computing(u"erf", erf_pipe)
+        stretch_result = stop_computing("stretch", stretch_pipe)
+        erf_result = stop_computing("erf", erf_pipe)
         result = PLRsearch._get_result(measurement, stretch_result, erf_result)
         logging.info(
             f"measure_and_compute finished with trial result "
@@ -730,7 +757,7 @@ class PLRsearch:
 
 # Named tuples, for multiple local variables to be passed as return value.
 _PartialResult = namedtuple(
-    u"_PartialResult", u"value_tracker focus_tracker samples"
+    "_PartialResult", "value_tracker focus_tracker samples"
 )
 """Two stat trackers and sample counter.
 
@@ -743,8 +770,8 @@ _PartialResult = namedtuple(
 """
 
 _ComputeResult = namedtuple(
-    u"_ComputeResult",
-    u"measurement avg stdev stretch_exp_avg erf_exp_avg trackers"
+    "_ComputeResult",
+    "measurement avg stdev stretch_exp_avg erf_exp_avg trackers",
 )
 """Measurement, 4 computation result values, pair of trackers.