Python3: resources and libraries
[csit.git] / resources / libraries / python / PLRsearch / Integrator.py
index 82abe5f..331bd84 100644 (file)
@@ -23,31 +23,81 @@ import copy
 import traceback
 
 import dill
+
 from numpy import random
 
 # TODO: Teach FD.io CSIT to use multiple dirs in PYTHONPATH,
 # then switch to absolute imports within PLRsearch package.
 # Current usage of relative imports is just a short term workaround.
-import stat_trackers  # pylint: disable=relative-import
+from . import stat_trackers
 
 
 def try_estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False):
-    """Call estimate_nd but catch any exception and send traceback."""
+    """Call estimate_nd but catch any exception and send traceback.
+
+    This function does not return anything, computation result
+    is sent via the communication pipe instead.
+
+    TODO: Move scale_coeff to a field of data class
+    with constructor/factory hiding the default value,
+    and receive its instance via pipe, instead of argument.
+
+    :param communication_pipe: Endpoint for communication with parent process.
+    :param scale_coeff: Float number to tweak convergence speed with.
+    :param trace_enabled: Whether to emit trace level debugs.
+        Keeping trace disabled improves speed and saves memory.
+        Enable trace only when debugging the computation itself.
+    :type communication_pipe: multiprocessing.Connection
+    :type scale_coeff: float
+    :type trace_enabled: bool
+    :raises BaseException: Anything raised by interpreter or estimate_nd.
+    """
     try:
-        return estimate_nd(communication_pipe, scale_coeff, trace_enabled)
+        estimate_nd(communication_pipe, scale_coeff, trace_enabled)
     except BaseException:
         # Any subclass could have caused estimate_nd to stop before sending,
         # so we have to catch them all.
         traceback_string = traceback.format_exc()
         communication_pipe.send(traceback_string)
-        # After sendig, re-raise, so usages other than "one process per call"
+        # After sending, re-raise, so usages other than "one process per call"
         # keep behaving correctly.
         raise
 
 
-# TODO: Pylint reports multiple complexity violations.
-# Refactor the code, using less (but structured) variables
-# and function calls for (relatively) loosly coupled code blocks.
+def generate_sample(averages, covariance_matrix, dimension, scale_coeff):
+    """Generate next sample for estimate_nd.
+
+    Arguments control the multivariate normal "focus".
+    Keep generating until the sample point fits into unit area.
+
+    :param averages: Coordinates of the focus center.
+    :param covariance_matrix: Matrix controlling the spread around the average.
+    :param dimension: If N is dimension, average is N vector and matrix is NxN.
+    :param scale_coeff: Coefficient to conformally multiply the spread.
+    :type averages: Indexable of N floats
+    :type covariance_matrix: Indexable of N indexables of N floats
+    :type dimension: int
+    :type scale_coeff: float
+    :returns: The generated sample point.
+    :rtype: N-tuple of float
+    """
+    covariance_matrix = copy.deepcopy(covariance_matrix)
+    for first in range(dimension):
+        for second in range(dimension):
+            covariance_matrix[first][second] *= scale_coeff
+    while 1:
+        sample_point = random.multivariate_normal(
+            averages, covariance_matrix, 1
+        )[0].tolist()
+        # Multivariate Gauss can fall outside (-1, 1) interval
+        for first in range(dimension):
+            sample_coordinate = sample_point[first]
+            if sample_coordinate <= -1.0 or sample_coordinate >= 1.0:
+                break
+        else:
+            return sample_point
+
+
 def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False):
     """Use Bayesian inference from control queue, put result to result queue.
 
@@ -127,27 +177,28 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False):
     In they are not enabled, trace_list will be empty.
     It is recommended to edit some lines manually to debug_list if needed.
 
-    :param communication_pipe: Pipe to comunicate with boss process.
+    :param communication_pipe: Endpoint for communication with parent process.
     :param scale_coeff: Float number to tweak convergence speed with.
     :param trace_enabled: Whether trace list should be populated at all.
-        Default: False
-    :type communication_pipe: multiprocessing.Connection (or compatible)
+    :type communication_pipe: multiprocessing.Connection
     :type scale_coeff: float
-    :type trace_enabled: boolean
+    :type trace_enabled: bool
     :raises OverflowError: If one sample dominates the rest too much.
         Or if value_logweight_function does not handle
         some part of parameter space carefully enough.
     :raises numpy.linalg.LinAlgError: If the focus shape gets singular
         (due to rounding errors). Try changing scale_coeff.
     """
-
     debug_list = list()
     trace_list = list()
     # Block until input object appears.
     dimension, dilled_function, param_focus_tracker, max_samples = (
-        communication_pipe.recv())
-    debug_list.append("Called with param_focus_tracker {tracker!r}"
-                      .format(tracker=param_focus_tracker))
+        communication_pipe.recv()
+    )
+    debug_list.append(
+        f"Called with param_focus_tracker {param_focus_tracker!r}"
+    )
+
     def trace(name, value):
         """
         Add a variable (name and value) to trace list (if enabled).
@@ -162,7 +213,8 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False):
         :type value: object
         """
         if trace_enabled:
-            trace_list.append(name + " " + repr(value))
+            trace_list.append(f"{name} {value!r}")
+
     value_logweight_function = dill.loads(dilled_function)
     samples = 0
     # Importance sampling produces samples of higher weight (important)
@@ -180,54 +232,45 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False):
     else:
         # Focus tracker has probably too high weight.
         param_focus_tracker.log_sum_weight = None
-    # TODO: Teach pylint the used version of numpy.random does have this member.
     random.seed(0)
     while not communication_pipe.poll():
         if max_samples and samples >= max_samples:
             break
-        # Generate next sample.
-        averages = param_focus_tracker.averages
-        covariance_matrix = copy.deepcopy(param_focus_tracker.covariance_matrix)
-        for first in range(dimension):
-            for second in range(dimension):
-                covariance_matrix[first][second] *= scale_coeff
-        while 1:
-            # TODO: Teach pylint that numpy.random does also have this member.
-            sample_point = random.multivariate_normal(
-                averages, covariance_matrix, 1)[0].tolist()
-            # Multivariate Gauss can fall outside (-1, 1) interval
-            for first in range(dimension):
-                sample_coordinate = sample_point[first]
-                if sample_coordinate <= -1.0 or sample_coordinate >= 1.0:
-                    break
-            else:  # These two breaks implement "level two continue".
-                break
-        trace("sample_point", sample_point)
+        sample_point = generate_sample(
+            param_focus_tracker.averages, param_focus_tracker.covariance_matrix,
+            dimension, scale_coeff
+        )
+        trace(u"sample_point", sample_point)
         samples += 1
-        trace("samples", samples)
+        trace(u"samples", samples)
         value, log_weight = value_logweight_function(trace, *sample_point)
-        trace("value", value)
-        trace("log_weight", log_weight)
-        trace("focus tracker before adding", param_focus_tracker)
+        trace(u"value", value)
+        trace(u"log_weight", log_weight)
+        trace(u"focus tracker before adding", param_focus_tracker)
         # Update focus related statistics.
         param_distance = param_focus_tracker.add_without_dominance_get_distance(
-            sample_point, log_weight)
+            sample_point, log_weight
+        )
         # The code above looked at weight (not importance).
         # The code below looks at importance (not weight).
         log_rarity = param_distance / 2.0
-        trace("log_rarity", log_rarity)
+        trace(u"log_rarity", log_rarity)
         log_importance = log_weight + log_rarity
-        trace("log_importance", log_importance)
+        trace(u"log_importance", log_importance)
         value_tracker.add(value, log_importance)
         # Update sampled statistics.
         param_sampled_tracker.add_get_shift(sample_point, log_importance)
-    debug_list.append("integrator used " + str(samples) + " samples")
-    debug_list.append(" ".join([
-        "value_avg", str(value_tracker.average),
-        "param_sampled_avg", repr(param_sampled_tracker.averages),
-        "param_sampled_cov", repr(param_sampled_tracker.covariance_matrix),
-        "value_log_variance", str(value_tracker.log_variance),
-        "value_log_secondary_variance",
-        str(value_tracker.secondary.log_variance)]))
+    debug_list.append(f"integrator used {samples!s} samples")
+    debug_list.append(
+        u" ".join([
+            u"value_avg", str(value_tracker.average),
+            u"param_sampled_avg", repr(param_sampled_tracker.averages),
+            u"param_sampled_cov", repr(param_sampled_tracker.covariance_matrix),
+            u"value_log_variance", str(value_tracker.log_variance),
+            u"value_log_secondary_variance",
+            str(value_tracker.secondary.log_variance)
+        ])
+    )
     communication_pipe.send(
-        (value_tracker, param_focus_tracker, debug_list, trace_list, samples))
+        (value_tracker, param_focus_tracker, debug_list, trace_list, samples)
+    )