Improve PLRsearch yet again
[csit.git] / resources / libraries / python / PLRsearch / stat_trackers.py
index 168b09a..58ad98f 100644 (file)
@@ -29,7 +29,7 @@ import numpy
 # TODO: Teach FD.io CSIT to use multiple dirs in PYTHONPATH,
 # then switch to absolute imports within PLRsearch package.
 # Current usage of relative imports is just a short term workaround.
-from log_plus import log_plus  # pylint: disable=relative-import
+from .log_plus import log_plus, safe_exp
 
 
 class ScalarStatTracker(object):
@@ -59,7 +59,11 @@ class ScalarStatTracker(object):
         self.log_variance = log_variance
 
     def __repr__(self):
-        """Return string, which interpreted constructs state of self."""
+        """Return string, which interpreted constructs state of self.
+
+        :returns: Expression contructing an equivalent instance.
+        :rtype: str
+        """
         return ("ScalarStatTracker(log_sum_weight={lsw!r},average={a!r},"
                 "log_variance={lv!r})".format(
                     lsw=self.log_sum_weight, a=self.average,
@@ -168,7 +172,11 @@ class ScalarDualStatTracker(ScalarStatTracker):
         self.max_log_weight = max_log_weight
 
     def __repr__(self):
-        """Return string, which interpreted constructs state of self."""
+        """Return string, which interpreted constructs state of self.
+
+        :returns: Expression contructing an equivalent instance.
+        :rtype: str
+        """
         sec = self.secondary
         return (
             "ScalarDualStatTracker(log_sum_weight={lsw!r},average={a!r},"
@@ -202,6 +210,27 @@ class ScalarDualStatTracker(ScalarStatTracker):
         return self
 
 
+    def get_pessimistic_variance(self):
+        """Return estimate of variance reflecting weight effects.
+
+        Typical scenario is the primary tracker dominated by a single sample.
+        In worse case, secondary tracker is also dominated by
+        a single (but different) sample.
+
+        Current implementation simply returns variance of average
+        of the two trackers, as if they were independent.
+
+        :returns: Pessimistic estimate of variance (not stdev, no log).
+        :rtype: float
+        """
+        var_primary = safe_exp(self.log_variance)
+        var_secondary = safe_exp(self.secondary.log_variance)
+        var_combined = (var_primary + var_secondary) / 2
+        avg_half_diff = (self.average - self.secondary.average) / 2
+        var_combined += avg_half_diff * avg_half_diff
+        return var_combined
+
+
 class VectorStatTracker(object):
     """Class for tracking multi-dimensional samples.
 
@@ -245,7 +274,8 @@ class VectorStatTracker(object):
         """Return string, which interpreted constructs state of self.
 
         :returns: Expression contructing an equivalent instance.
-        :rtype: str"""
+        :rtype: str
+        """
         return (
             "VectorStatTracker(dimension={d!r},log_sum_weight={lsw!r},"
             "averages={a!r},covariance_matrix={cm!r})".format(
@@ -262,8 +292,8 @@ class VectorStatTracker(object):
         :rtype: VectorStatTracker
         """
         return VectorStatTracker(
-            self.dimension, self.log_sum_weight, self.averages,
-            self.covariance_matrix)
+            self.dimension, self.log_sum_weight, self.averages[:],
+            copy.deepcopy(self.covariance_matrix))
 
     def reset(self):
         """Return state set to empty data of proper dimensionality.
@@ -288,6 +318,7 @@ class VectorStatTracker(object):
         self.reset()
         for index in range(self.dimension):
             self.covariance_matrix[index][index] = 1.0
+        return self
 
     def add_get_shift(self, vector_value, log_weight=0.0):
         """Return shift and update state to addition of another sample.
@@ -300,8 +331,8 @@ class VectorStatTracker(object):
             Default: 0.0 (as log of 1.0).
         :type vector_value: iterable of float
         :type log_weight: float
-        :returns: Updated self.
-        :rtype: VectorStatTracker
+        :returns: Shift vector
+        :rtype: list of float
         """
         dimension = self.dimension
         old_log_sum_weight = self.log_sum_weight