Framework: Refactor complex functions in PLRsearch
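
Extract the result-collation tail of measure_and_compute() into a
_get_result() static method, pass the grouped values around as named
tuples (_PartialResult, _ComputeResult) instead of bare tuples, and
replace the implicit relative imports (and their pylint workarounds)
with explicit ones.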
diff --git a/resources/libraries/python/PLRsearch/PLRsearch.py b/resources/libraries/python/PLRsearch/PLRsearch.py
index db870c5..4205818 100644
--- a/resources/libraries/python/PLRsearch/PLRsearch.py
+++ b/resources/libraries/python/PLRsearch/PLRsearch.py
@@ -17,17 +17,17 @@ import logging
 import math
 import multiprocessing
 import time
+from collections import namedtuple
 
 import dill
-# TODO: Inform pylint about scipy (of correct version) being available.
 from scipy.special import erfcx, erfc
 
 # TODO: Teach FD.io CSIT to use multiple dirs in PYTHONPATH,
 # then switch to absolute imports within PLRsearch package.
 # Current usage of relative imports is just a short term workaround.
-import Integrator  # pylint: disable=relative-import
-from log_plus import log_plus, log_minus  # pylint: disable=relative-import
-import stat_trackers  # pylint: disable=relative-import
+from . import Integrator
+from .log_plus import log_plus, log_minus
+from . import stat_trackers
 
 
 class PLRsearch(object):
@@ -461,8 +461,6 @@ class PLRsearch(object):
             trace("log_trial_likelihood", log_trial_likelihood)
         return log_likelihood
 
-    # TODO: Refactor (somehow) so pylint stops complaining about
-    # too many local variables.
     def measure_and_compute(
             self, trial_duration, transmit_rate, trial_result_list,
             min_rate, max_rate, focus_trackers=(None, None), max_samples=None):
@@ -531,6 +529,7 @@ class PLRsearch(object):
             erf_focus_tracker = stat_trackers.VectorStatTracker(dimension)
             erf_focus_tracker.unit_reset()
         old_trackers = stretch_focus_tracker.copy(), erf_focus_tracker.copy()
+
         def start_computing(fitting_function, focus_tracker):
             """Just a block of code to be used for each fitting function.
 
@@ -546,6 +545,7 @@ class PLRsearch(object):
             :returns: Boss end of communication pipe.
             :rtype: multiprocessing.Connection
             """
+
             def value_logweight_func(trace, x_mrr, x_spread):
                 """Return log of critical rate and log of likelihood.
 
@@ -585,6 +585,7 @@ class PLRsearch(object):
                     trace, fitting_function, min_rate, max_rate,
                     self.packet_loss_ratio_target, mrr, spread))
                 return value, logweight
+
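+            # Plain pickle cannot serialize a nested closure such as
+            # value_logweight_func, so dill is used to ship it
+            # to the worker process.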
             dilled_function = dill.dumps(value_logweight_func)
             boss_pipe_end, worker_pipe_end = multiprocessing.Pipe()
             boss_pipe_end.send(
@@ -595,12 +596,15 @@ class PLRsearch(object):
             worker.daemon = True
             worker.start()
             return boss_pipe_end
+
         erf_pipe = start_computing(
             self.lfit_erf, erf_focus_tracker)
         stretch_pipe = start_computing(
             self.lfit_stretch, stretch_focus_tracker)
+
         # Measurement phase.
         measurement = self.measurer.measure(trial_duration, transmit_rate)
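+        # Both integrator workers keep sampling in parallel
+        # while the trial measurement runs in the main process.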
+
         # Processing phase.
         def stop_computing(name, pipe):
             """Just a block of code to be used for each worker.
@@ -637,30 +641,42 @@ class PLRsearch(object):
                 logging.debug(message)
             logging.debug("trackers: value %(val)r focus %(foc)r", {
                 "val": value_tracker, "foc": focus_tracker})
-            return value_tracker, focus_tracker, sampls
-        stretch_value_tracker, stretch_focus_tracker, stretch_samples = (
-            stop_computing("stretch", stretch_pipe))
-        erf_value_tracker, erf_focus_tracker, erf_samples = (
-            stop_computing("erf", erf_pipe))
-        stretch_avg = stretch_value_tracker.average
-        erf_avg = erf_value_tracker.average
-        # TODO: Take into account secondary stats.
-        stretch_stdev = math.exp(stretch_value_tracker.log_variance / 2)
-        erf_stdev = math.exp(erf_value_tracker.log_variance / 2)
-        avg = math.exp((stretch_avg + erf_avg) / 2.0)
-        var = (stretch_stdev * stretch_stdev + erf_stdev * erf_stdev) / 2.0
-        var += (stretch_avg - erf_avg) * (stretch_avg - erf_avg) / 4.0
-        stdev = avg * math.sqrt(var)
-        focus_trackers = (stretch_focus_tracker, erf_focus_tracker)
+            return _PartialResult(value_tracker, focus_tracker, sampls)
+
+        stretch_result = stop_computing("stretch", stretch_pipe)
+        erf_result = stop_computing("erf", erf_pipe)
+        result = PLRsearch._get_result(measurement, stretch_result, erf_result)
         logging.info(
             "measure_and_compute finished with trial result %(res)r "
             "avg %(avg)r stdev %(stdev)r stretch %(a1)r erf %(a2)r "
             "new trackers %(nt)r old trackers %(ot)r stretch samples %(ss)r "
             "erf samples %(es)r",
-            {"res": measurement, "avg": avg, "stdev": stdev,
-             "a1": math.exp(stretch_avg), "a2": math.exp(erf_avg),
-             "nt": focus_trackers, "ot": old_trackers, "ss": stretch_samples,
-             "es": erf_samples})
-        return (
-            measurement, avg, stdev, math.exp(stretch_avg),
-            math.exp(erf_avg), focus_trackers)
+            {"res": result.measurement,
+             "avg": result.avg, "stdev": result.stdev,
+             "a1": result.stretch_exp_avg, "a2": result.erf_exp_avg,
+             "nt": result.trackers, "ot": old_trackers,
+             "ss": stretch_result.samples, "es": erf_result.samples})
+        return result
+
+    @staticmethod
+    def _get_result(measurement, stretch_result, erf_result):
+        """Collate results from measure_and_compute"""
+        stretch_avg = stretch_result.value_tracker.average
+        erf_avg = erf_result.value_tracker.average
+        # TODO: Take into account secondary stats.
+        stretch_stdev = math.exp(stretch_result.value_tracker.log_variance / 2)
+        erf_stdev = math.exp(erf_result.value_tracker.log_variance / 2)
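+        # Equal-weight mixture of the two models: in log space,
+        # the mixture mean is the average of the two means, and the
+        # mixture variance is the average of the variances plus the
+        # variance of the means, (stretch_avg - erf_avg)^2 / 4.
+        # To first order, the linear-scale stdev is the linear-scale
+        # average times the log-scale stdev.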
+        avg = math.exp((stretch_avg + erf_avg) / 2.0)
+        var = (stretch_stdev * stretch_stdev + erf_stdev * erf_stdev) / 2.0
+        var += (stretch_avg - erf_avg) * (stretch_avg - erf_avg) / 4.0
+        stdev = avg * math.sqrt(var)
+        trackers = (stretch_result.focus_tracker, erf_result.focus_tracker)
+        stretch_exp_avg = math.exp(stretch_avg)
+        erf_exp_avg = math.exp(erf_avg)
+        return _ComputeResult(
+            measurement, avg, stdev, stretch_exp_avg, erf_exp_avg, trackers)
+
+
+# Named tuples, for stable grouping of multiple return values.
+_PartialResult = namedtuple(
+    '_PartialResult', 'value_tracker focus_tracker samples')
+_ComputeResult = namedtuple(
+    '_ComputeResult',
+    'measurement avg stdev stretch_exp_avg erf_exp_avg trackers')
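
With the namedtuple return value, a caller reads results by field name
instead of unpacking a six-element tuple. A minimal sketch of the
consuming side (the instance and argument names are illustrative, not
part of this change):

    result = plrsearch.measure_and_compute(
        trial_duration, transmit_rate, trial_result_list,
        min_rate, max_rate)
    print(result.avg, result.stdev)  # combined critical rate estimate
    print(result.stretch_exp_avg, result.erf_exp_avg)  # per-model averages
    stretch_tracker, erf_tracker = result.trackers  # for the next trial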