-# Copyright (c) 2019 Cisco and/or its affiliates.
+# Copyright (c) 2021 Cisco and/or its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
import traceback
import dill
+
from numpy import random
# TODO: Teach FD.io CSIT to use multiple dirs in PYTHONPATH,
# then switch to absolute imports within PLRsearch package.
# Current usage of relative imports is just a short term workaround.
-import stat_trackers # pylint: disable=relative-import
+from . import stat_trackers
def try_estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False):
- """Call estimate_nd but catch any exception and send traceback."""
+ """Call estimate_nd but catch any exception and send traceback.
+
+ This function does not return anything, computation result
+ is sent via the communication pipe instead.
+
+ TODO: Move scale_coeff to a field of data class
+ with constructor/factory hiding the default value,
+ and receive its instance via pipe, instead of argument.
+
+ :param communication_pipe: Endpoint for communication with parent process.
+ :param scale_coeff: Float number to tweak convergence speed with.
+ :param trace_enabled: Whether to emit trace level debugs.
+ Keeping trace disabled improves speed and saves memory.
+ Enable trace only when debugging the computation itself.
+ :type communication_pipe: multiprocessing.Connection
+ :type scale_coeff: float
+ :type trace_enabled: bool
+ :raises BaseException: Anything raised by interpreter or estimate_nd.
+ """
try:
- return estimate_nd(communication_pipe, scale_coeff, trace_enabled)
+ estimate_nd(communication_pipe, scale_coeff, trace_enabled)
except BaseException:
# Any subclass could have caused estimate_nd to stop before sending,
# so we have to catch them all.
traceback_string = traceback.format_exc()
communication_pipe.send(traceback_string)
- # After sendig, re-raise, so usages other than "one process per call"
+ # After sending, re-raise, so usages other than "one process per call"
# keep behaving correctly.
raise
-# TODO: Pylint reports multiple complexity violations.
-# Refactor the code, using less (but structured) variables
-# and function calls for (relatively) loosly coupled code blocks.
+def generate_sample(averages, covariance_matrix, dimension, scale_coeff):
+ """Generate next sample for estimate_nd.
+
+ Arguments control the multivariate normal "focus".
+ Keep generating until the sample point fits into unit area.
+
+ :param averages: Coordinates of the focus center.
+ :param covariance_matrix: Matrix controlling the spread around the average.
+ :param dimension: If N is dimension, average is N vector and matrix is NxN.
+ :param scale_coeff: Coefficient to conformally multiply the spread.
+ :type averages: Indexable of N floats
+ :type covariance_matrix: Indexable of N indexables of N floats
+ :type dimension: int
+ :type scale_coeff: float
+ :returns: The generated sample point.
+ :rtype: N-tuple of float
+ """
+ covariance_matrix = copy.deepcopy(covariance_matrix)
+ for first in range(dimension):
+ for second in range(dimension):
+ covariance_matrix[first][second] *= scale_coeff
+ while 1:
+ sample_point = random.multivariate_normal(
+ averages, covariance_matrix, 1
+ )[0].tolist()
+ # Multivariate Gauss can fall outside (-1, 1) interval
+ for first in range(dimension):
+ sample_coordinate = sample_point[first]
+ if sample_coordinate <= -1.0 or sample_coordinate >= 1.0:
+ break
+ else:
+ return sample_point
+
+
def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False):
"""Use Bayesian inference from control queue, put result to result queue.
If they are not enabled, trace_list will be empty.
It is recommended to edit some lines manually to debug_list if needed.
- :param communication_pipe: Pipe to comunicate with boss process.
+ :param communication_pipe: Endpoint for communication with parent process.
:param scale_coeff: Float number to tweak convergence speed with.
:param trace_enabled: Whether trace list should be populated at all.
- Default: False
- :type communication_pipe: multiprocessing.Connection (or compatible)
+ :type communication_pipe: multiprocessing.Connection
:type scale_coeff: float
- :type trace_enabled: boolean
+ :type trace_enabled: bool
:raises OverflowError: If one sample dominates the rest too much.
Or if value_logweight_function does not handle
some part of parameter space carefully enough.
:raises numpy.linalg.LinAlgError: If the focus shape gets singular
(due to rounding errors). Try changing scale_coeff.
"""
-
debug_list = list()
trace_list = list()
# Block until input object appears.
dimension, dilled_function, param_focus_tracker, max_samples = (
- communication_pipe.recv())
- debug_list.append("Called with param_focus_tracker {tracker!r}"
- .format(tracker=param_focus_tracker))
+ communication_pipe.recv()
+ )
+ debug_list.append(
+ f"Called with param_focus_tracker {param_focus_tracker!r}"
+ )
+
def trace(name, value):
"""
Add a variable (name and value) to trace list (if enabled).
:type value: object
"""
if trace_enabled:
- trace_list.append(name + " " + repr(value))
+ trace_list.append(f"{name} {value!r}")
+
value_logweight_function = dill.loads(dilled_function)
samples = 0
# Importance sampling produces samples of higher weight (important)
else:
# Focus tracker probably has too high a weight.
param_focus_tracker.log_sum_weight = None
- # TODO: Teach pylint the used version of numpy.random does have this member.
random.seed(0)
while not communication_pipe.poll():
if max_samples and samples >= max_samples:
break
- # Generate next sample.
- averages = param_focus_tracker.averages
- covariance_matrix = copy.deepcopy(param_focus_tracker.covariance_matrix)
- for first in range(dimension):
- for second in range(dimension):
- covariance_matrix[first][second] *= scale_coeff
- while 1:
- # TODO: Teach pylint that numpy.random does also have this member.
- sample_point = random.multivariate_normal(
- averages, covariance_matrix, 1)[0].tolist()
- # Multivariate Gauss can fall outside (-1, 1) interval
- for first in range(dimension):
- sample_coordinate = sample_point[first]
- if sample_coordinate <= -1.0 or sample_coordinate >= 1.0:
- break
- else: # These two breaks implement "level two continue".
- break
- trace("sample_point", sample_point)
+ sample_point = generate_sample(
+ param_focus_tracker.averages, param_focus_tracker.covariance_matrix,
+ dimension, scale_coeff
+ )
+ trace(u"sample_point", sample_point)
samples += 1
- trace("samples", samples)
+ trace(u"samples", samples)
value, log_weight = value_logweight_function(trace, *sample_point)
- trace("value", value)
- trace("log_weight", log_weight)
- trace("focus tracker before adding", param_focus_tracker)
+ trace(u"value", value)
+ trace(u"log_weight", log_weight)
+ trace(u"focus tracker before adding", param_focus_tracker)
# Update focus related statistics.
param_distance = param_focus_tracker.add_without_dominance_get_distance(
- sample_point, log_weight)
+ sample_point, log_weight
+ )
# The code above looked at weight (not importance).
# The code below looks at importance (not weight).
log_rarity = param_distance / 2.0
- trace("log_rarity", log_rarity)
+ trace(u"log_rarity", log_rarity)
log_importance = log_weight + log_rarity
- trace("log_importance", log_importance)
+ trace(u"log_importance", log_importance)
value_tracker.add(value, log_importance)
# Update sampled statistics.
param_sampled_tracker.add_get_shift(sample_point, log_importance)
- debug_list.append("integrator used " + str(samples) + " samples")
- debug_list.append(" ".join([
- "value_avg", str(value_tracker.average),
- "param_sampled_avg", repr(param_sampled_tracker.averages),
- "param_sampled_cov", repr(param_sampled_tracker.covariance_matrix),
- "value_log_variance", str(value_tracker.log_variance),
- "value_log_secondary_variance",
- str(value_tracker.secondary.log_variance)]))
+ debug_list.append(f"integrator used {samples!s} samples")
+ debug_list.append(
+ u" ".join([
+ u"value_avg", str(value_tracker.average),
+ u"param_sampled_avg", repr(param_sampled_tracker.averages),
+ u"param_sampled_cov", repr(param_sampled_tracker.covariance_matrix),
+ u"value_log_variance", str(value_tracker.log_variance),
+ u"value_log_secondary_variance",
+ str(value_tracker.secondary.log_variance)
+ ])
+ )
communication_pipe.send(
- (value_tracker, param_focus_tracker, debug_list, trace_list, samples))
+ (value_tracker, param_focus_tracker, debug_list, trace_list, samples)
+ )