Framework: Refactor complex functions in PLRSearch (Gerrit ref 58/21358/2)
Author: Miroslav Los <miroslav.los@pantheon.tech>
Fri, 16 Aug 2019 13:09:39 +0000 (15:09 +0200)
Committer: Vratko Polak <vrpolak@cisco.com>
Mon, 19 Aug 2019 08:41:05 +0000 (08:41 +0000)
Signed-off-by: Miroslav Los <miroslav.los@pantheon.tech>
Change-Id: Ie2f19a2e3b37e8d85656ab31ece59b89c76bea25

resources/libraries/python/PLRsearch/Integrator.py
resources/libraries/python/PLRsearch/PLRsearch.py

index 82abe5f..035afd8 100644 (file)
@@ -45,9 +45,24 @@ def try_estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False):
         raise
 
 
-# TODO: Pylint reports multiple complexity violations.
-# Refactor the code, using less (but structured) variables
-# and function calls for (relatively) loosly coupled code blocks.
+def generate_sample(averages, covariance_matrix, dimension, scale_coeff):
+    """Generate next sample for estimate_nd"""
+    covariance_matrix = copy.deepcopy(covariance_matrix)
+    for first in range(dimension):
+        for second in range(dimension):
+            covariance_matrix[first][second] *= scale_coeff
+    while 1:
+        sample_point = random.multivariate_normal(
+            averages, covariance_matrix, 1)[0].tolist()
+        # Multivariate Gauss can fall outside (-1, 1) interval
+        for first in range(dimension):
+            sample_coordinate = sample_point[first]
+            if sample_coordinate <= -1.0 or sample_coordinate >= 1.0:
+                break
+        else:
+            return sample_point
+
+
 def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False):
     """Use Bayesian inference from control queue, put result to result queue.
 
@@ -148,6 +163,7 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False):
         communication_pipe.recv())
     debug_list.append("Called with param_focus_tracker {tracker!r}"
                       .format(tracker=param_focus_tracker))
+
     def trace(name, value):
         """
         Add a variable (name and value) to trace list (if enabled).
@@ -163,6 +179,7 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False):
         """
         if trace_enabled:
             trace_list.append(name + " " + repr(value))
+
     value_logweight_function = dill.loads(dilled_function)
     samples = 0
     # Importance sampling produces samples of higher weight (important)
@@ -180,28 +197,14 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False):
     else:
         # Focus tracker has probably too high weight.
         param_focus_tracker.log_sum_weight = None
-    # TODO: Teach pylint the used version of numpy.random does have this member.
     random.seed(0)
     while not communication_pipe.poll():
         if max_samples and samples >= max_samples:
             break
-        # Generate next sample.
-        averages = param_focus_tracker.averages
-        covariance_matrix = copy.deepcopy(param_focus_tracker.covariance_matrix)
-        for first in range(dimension):
-            for second in range(dimension):
-                covariance_matrix[first][second] *= scale_coeff
-        while 1:
-            # TODO: Teach pylint that numpy.random does also have this member.
-            sample_point = random.multivariate_normal(
-                averages, covariance_matrix, 1)[0].tolist()
-            # Multivariate Gauss can fall outside (-1, 1) interval
-            for first in range(dimension):
-                sample_coordinate = sample_point[first]
-                if sample_coordinate <= -1.0 or sample_coordinate >= 1.0:
-                    break
-            else:  # These two breaks implement "level two continue".
-                break
+        sample_point = generate_sample(param_focus_tracker.averages,
+                                       param_focus_tracker.covariance_matrix,
+                                       dimension,
+                                       scale_coeff)
         trace("sample_point", sample_point)
         samples += 1
         trace("samples", samples)
index db870c5..4205818 100644 (file)
@@ -17,17 +17,17 @@ import logging
 import math
 import multiprocessing
 import time
+from collections import namedtuple
 
 import dill
-# TODO: Inform pylint about scipy (of correct version) being available.
 from scipy.special import erfcx, erfc
 
 # TODO: Teach FD.io CSIT to use multiple dirs in PYTHONPATH,
 # then switch to absolute imports within PLRsearch package.
 # Current usage of relative imports is just a short term workaround.
-import Integrator  # pylint: disable=relative-import
-from log_plus import log_plus, log_minus  # pylint: disable=relative-import
-import stat_trackers  # pylint: disable=relative-import
+from . import Integrator
+from .log_plus import log_plus, log_minus
+from . import stat_trackers
 
 
 class PLRsearch(object):
@@ -461,8 +461,6 @@ class PLRsearch(object):
             trace("log_trial_likelihood", log_trial_likelihood)
         return log_likelihood
 
-    # TODO: Refactor (somehow) so pylint stops complaining about
-    # too many local variables.
     def measure_and_compute(
             self, trial_duration, transmit_rate, trial_result_list,
             min_rate, max_rate, focus_trackers=(None, None), max_samples=None):
@@ -531,6 +529,7 @@ class PLRsearch(object):
             erf_focus_tracker = stat_trackers.VectorStatTracker(dimension)
             erf_focus_tracker.unit_reset()
         old_trackers = stretch_focus_tracker.copy(), erf_focus_tracker.copy()
+
         def start_computing(fitting_function, focus_tracker):
             """Just a block of code to be used for each fitting function.
 
@@ -546,6 +545,7 @@ class PLRsearch(object):
             :returns: Boss end of communication pipe.
             :rtype: multiprocessing.Connection
             """
+
             def value_logweight_func(trace, x_mrr, x_spread):
                 """Return log of critical rate and log of likelihood.
 
@@ -585,6 +585,7 @@ class PLRsearch(object):
                     trace, fitting_function, min_rate, max_rate,
                     self.packet_loss_ratio_target, mrr, spread))
                 return value, logweight
+
             dilled_function = dill.dumps(value_logweight_func)
             boss_pipe_end, worker_pipe_end = multiprocessing.Pipe()
             boss_pipe_end.send(
@@ -595,12 +596,15 @@ class PLRsearch(object):
             worker.daemon = True
             worker.start()
             return boss_pipe_end
+
         erf_pipe = start_computing(
             self.lfit_erf, erf_focus_tracker)
         stretch_pipe = start_computing(
             self.lfit_stretch, stretch_focus_tracker)
+
         # Measurement phase.
         measurement = self.measurer.measure(trial_duration, transmit_rate)
+
         # Processing phase.
         def stop_computing(name, pipe):
             """Just a block of code to be used for each worker.
@@ -637,30 +641,42 @@ class PLRsearch(object):
                 logging.debug(message)
             logging.debug("trackers: value %(val)r focus %(foc)r", {
                 "val": value_tracker, "foc": focus_tracker})
-            return value_tracker, focus_tracker, sampls
-        stretch_value_tracker, stretch_focus_tracker, stretch_samples = (
-            stop_computing("stretch", stretch_pipe))
-        erf_value_tracker, erf_focus_tracker, erf_samples = (
-            stop_computing("erf", erf_pipe))
-        stretch_avg = stretch_value_tracker.average
-        erf_avg = erf_value_tracker.average
-        # TODO: Take into account secondary stats.
-        stretch_stdev = math.exp(stretch_value_tracker.log_variance / 2)
-        erf_stdev = math.exp(erf_value_tracker.log_variance / 2)
-        avg = math.exp((stretch_avg + erf_avg) / 2.0)
-        var = (stretch_stdev * stretch_stdev + erf_stdev * erf_stdev) / 2.0
-        var += (stretch_avg - erf_avg) * (stretch_avg - erf_avg) / 4.0
-        stdev = avg * math.sqrt(var)
-        focus_trackers = (stretch_focus_tracker, erf_focus_tracker)
+            return _PartialResult(value_tracker, focus_tracker, sampls)
+
+        stretch_result = stop_computing("stretch", stretch_pipe)
+        erf_result = stop_computing("erf", erf_pipe)
+        result = PLRsearch._get_result(measurement, stretch_result, erf_result)
         logging.info(
             "measure_and_compute finished with trial result %(res)r "
             "avg %(avg)r stdev %(stdev)r stretch %(a1)r erf %(a2)r "
             "new trackers %(nt)r old trackers %(ot)r stretch samples %(ss)r "
             "erf samples %(es)r",
-            {"res": measurement, "avg": avg, "stdev": stdev,
-             "a1": math.exp(stretch_avg), "a2": math.exp(erf_avg),
-             "nt": focus_trackers, "ot": old_trackers, "ss": stretch_samples,
-             "es": erf_samples})
-        return (
-            measurement, avg, stdev, math.exp(stretch_avg),
-            math.exp(erf_avg), focus_trackers)
+            {"res": result.measurement,
+             "avg": result.avg, "stdev": result.stdev,
+             "a1": result.stretch_exp_avg, "a2": result.erf_exp_avg,
+             "nt": result.trackers, "ot": old_trackers,
+             "ss": stretch_result.samples, "es": erf_result.samples})
+        return result
+
+    @staticmethod
+    def _get_result(measurement, stretch_result, erf_result):
+        """Collate results from measure_and_compute"""
+        stretch_avg = stretch_result.value_tracker.average
+        erf_avg = erf_result.value_tracker.average
+        # TODO: Take into account secondary stats.
+        stretch_stdev = math.exp(stretch_result.value_tracker.log_variance / 2)
+        erf_stdev = math.exp(erf_result.value_tracker.log_variance / 2)
+        avg = math.exp((stretch_avg + erf_avg) / 2.0)
+        var = (stretch_stdev * stretch_stdev + erf_stdev * erf_stdev) / 2.0
+        var += (stretch_avg - erf_avg) * (stretch_avg - erf_avg) / 4.0
+        stdev = avg * math.sqrt(var)
+        trackers = (stretch_result.focus_tracker, erf_result.focus_tracker)
+        sea = math.exp(stretch_avg)
+        eea = math.exp(erf_avg)
+        return _ComputeResult(measurement, avg, stdev, sea, eea, trackers)
+
+
+_PartialResult = namedtuple('_PartialResult',
+                            'value_tracker focus_tracker samples')
+_ComputeResult = namedtuple('_ComputeResult', 'measurement avg stdev ' +
+                            'stretch_exp_avg erf_exp_avg trackers')