X-Git-Url: https://gerrit.fd.io/r/gitweb?p=csit.git;a=blobdiff_plain;f=resources%2Flibraries%2Fpython%2FPLRsearch%2FIntegrator.py;h=035afd848caa1f70730d3c4735b1f91e9dc3677f;hp=82abe5f8a3d9f6f77362e1d73cc1b925d8de2551;hb=51511689eb9f93134878c314ba0349f28ef2ec4f;hpb=d95a4661e5ea956f71e9f3c5aed46491cf20a9a3 diff --git a/resources/libraries/python/PLRsearch/Integrator.py b/resources/libraries/python/PLRsearch/Integrator.py index 82abe5f8a3..035afd848c 100644 --- a/resources/libraries/python/PLRsearch/Integrator.py +++ b/resources/libraries/python/PLRsearch/Integrator.py @@ -45,9 +45,24 @@ def try_estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): raise -# TODO: Pylint reports multiple complexity violations. -# Refactor the code, using less (but structured) variables -# and function calls for (relatively) loosly coupled code blocks. +def generate_sample(averages, covariance_matrix, dimension, scale_coeff): + """Generate next sample for estimate_nd""" + covariance_matrix = copy.deepcopy(covariance_matrix) + for first in range(dimension): + for second in range(dimension): + covariance_matrix[first][second] *= scale_coeff + while 1: + sample_point = random.multivariate_normal( + averages, covariance_matrix, 1)[0].tolist() + # Multivariate Gauss can fall outside (-1, 1) interval + for first in range(dimension): + sample_coordinate = sample_point[first] + if sample_coordinate <= -1.0 or sample_coordinate >= 1.0: + break + else: + return sample_point + + def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): """Use Bayesian inference from control queue, put result to result queue. @@ -148,6 +163,7 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): communication_pipe.recv()) debug_list.append("Called with param_focus_tracker {tracker!r}" .format(tracker=param_focus_tracker)) + def trace(name, value): """ Add a variable (name and value) to trace list (if enabled). @@ -163,6 +179,7 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): """ if trace_enabled: trace_list.append(name + " " + repr(value)) + value_logweight_function = dill.loads(dilled_function) samples = 0 # Importance sampling produces samples of higher weight (important) @@ -180,28 +197,14 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): else: # Focus tracker has probably too high weight. param_focus_tracker.log_sum_weight = None - # TODO: Teach pylint the used version of numpy.random does have this member. random.seed(0) while not communication_pipe.poll(): if max_samples and samples >= max_samples: break - # Generate next sample. - averages = param_focus_tracker.averages - covariance_matrix = copy.deepcopy(param_focus_tracker.covariance_matrix) - for first in range(dimension): - for second in range(dimension): - covariance_matrix[first][second] *= scale_coeff - while 1: - # TODO: Teach pylint that numpy.random does also have this member. - sample_point = random.multivariate_normal( - averages, covariance_matrix, 1)[0].tolist() - # Multivariate Gauss can fall outside (-1, 1) interval - for first in range(dimension): - sample_coordinate = sample_point[first] - if sample_coordinate <= -1.0 or sample_coordinate >= 1.0: - break - else: # These two breaks implement "level two continue". - break + sample_point = generate_sample(param_focus_tracker.averages, + param_focus_tracker.covariance_matrix, + dimension, + scale_coeff) trace("sample_point", sample_point) samples += 1 trace("samples", samples)