X-Git-Url: https://gerrit.fd.io/r/gitweb?a=blobdiff_plain;f=resources%2Flibraries%2Fpython%2FPLRsearch%2FIntegrator.py;h=cc8f838fe6ea43baaa9e23d79c6ec1e74b7651aa;hb=HEAD;hp=035afd848caa1f70730d3c4735b1f91e9dc3677f;hpb=281b230ba982f9f6ad589fb6e44f121a6a46531f;p=csit.git diff --git a/resources/libraries/python/PLRsearch/Integrator.py b/resources/libraries/python/PLRsearch/Integrator.py index 035afd848c..cc8f838fe6 100644 --- a/resources/libraries/python/PLRsearch/Integrator.py +++ b/resources/libraries/python/PLRsearch/Integrator.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019 Cisco and/or its affiliates. +# Copyright (c) 2024 Cisco and/or its affiliates. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at: @@ -23,37 +23,72 @@ import copy import traceback import dill + from numpy import random # TODO: Teach FD.io CSIT to use multiple dirs in PYTHONPATH, # then switch to absolute imports within PLRsearch package. # Current usage of relative imports is just a short term workaround. -import stat_trackers # pylint: disable=relative-import +from . import stat_trackers def try_estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): - """Call estimate_nd but catch any exception and send traceback.""" + """Call estimate_nd but catch any exception and send traceback. + + This function does not return anything, computation result + is sent via the communication pipe instead. + + TODO: Move scale_coeff to a field of data class + with constructor/factory hiding the default value, + and receive its instance via pipe, instead of argument. + + :param communication_pipe: Endpoint for communication with parent process. + :param scale_coeff: Float number to tweak convergence speed with. + :param trace_enabled: Whether to emit trace level debugs. + Keeping trace disabled improves speed and saves memory. + Enable trace only when debugging the computation itself. + :type communication_pipe: multiprocessing.Connection + :type scale_coeff: float + :type trace_enabled: bool + :raises BaseException: Anything raised by interpreter or estimate_nd. + """ try: - return estimate_nd(communication_pipe, scale_coeff, trace_enabled) + estimate_nd(communication_pipe, scale_coeff, trace_enabled) except BaseException: # Any subclass could have caused estimate_nd to stop before sending, # so we have to catch them all. traceback_string = traceback.format_exc() communication_pipe.send(traceback_string) - # After sendig, re-raise, so usages other than "one process per call" + # After sending, re-raise, so usages other than "one process per call" # keep behaving correctly. raise def generate_sample(averages, covariance_matrix, dimension, scale_coeff): - """Generate next sample for estimate_nd""" + """Generate next sample for estimate_nd. + + Arguments control the multivariate normal "focus". + Keep generating until the sample point fits into unit area. + + :param averages: Coordinates of the focus center. + :param covariance_matrix: Matrix controlling the spread around the average. + :param dimension: If N is dimension, average is N vector and matrix is NxN. + :param scale_coeff: Coefficient to conformally multiply the spread. + :type averages: Indexable of N floats + :type covariance_matrix: Indexable of N indexables of N floats + :type dimension: int + :type scale_coeff: float + :returns: The generated sample point. + :rtype: N-tuple of float + """ covariance_matrix = copy.deepcopy(covariance_matrix) for first in range(dimension): for second in range(dimension): covariance_matrix[first][second] *= scale_coeff while 1: sample_point = random.multivariate_normal( - averages, covariance_matrix, 1)[0].tolist() + averages, covariance_matrix, 1 + )[0].tolist() # Multivariate Gauss can fall outside (-1, 1) interval for first in range(dimension): sample_coordinate = sample_point[first] @@ -142,27 +177,30 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): In they are not enabled, trace_list will be empty. It is recommended to edit some lines manually to debug_list if needed. - :param communication_pipe: Pipe to comunicate with boss process. + :param communication_pipe: Endpoint for communication with parent process. :param scale_coeff: Float number to tweak convergence speed with. :param trace_enabled: Whether trace list should be populated at all. - Default: False - :type communication_pipe: multiprocessing.Connection (or compatible) + :type communication_pipe: multiprocessing.Connection :type scale_coeff: float - :type trace_enabled: boolean + :type trace_enabled: bool :raises OverflowError: If one sample dominates the rest too much. Or if value_logweight_function does not handle some part of parameter space carefully enough. :raises numpy.linalg.LinAlgError: If the focus shape gets singular (due to rounding errors). Try changing scale_coeff. """ - - debug_list = list() - trace_list = list() + debug_list = [] + trace_list = [] # Block until input object appears. - dimension, dilled_function, param_focus_tracker, max_samples = ( - communication_pipe.recv()) - debug_list.append("Called with param_focus_tracker {tracker!r}" - .format(tracker=param_focus_tracker)) + ( + dimension, + dilled_function, + param_focus_tracker, + max_samples, + ) = communication_pipe.recv() + debug_list.append( + f"Called with param_focus_tracker {param_focus_tracker!r}" + ) def trace(name, value): """ @@ -178,7 +216,7 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): :type value: object """ if trace_enabled: - trace_list.append(name + " " + repr(value)) + trace_list.append(f"{name} {value!r}") value_logweight_function = dill.loads(dilled_function) samples = 0 @@ -201,10 +239,12 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): while not communication_pipe.poll(): if max_samples and samples >= max_samples: break - sample_point = generate_sample(param_focus_tracker.averages, - param_focus_tracker.covariance_matrix, - dimension, - scale_coeff) + sample_point = generate_sample( + param_focus_tracker.averages, + param_focus_tracker.covariance_matrix, + dimension, + scale_coeff, + ) trace("sample_point", sample_point) samples += 1 trace("samples", samples) @@ -214,23 +254,34 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): trace("focus tracker before adding", param_focus_tracker) # Update focus related statistics. param_distance = param_focus_tracker.add_without_dominance_get_distance( - sample_point, log_weight) + sample_point, log_weight + ) # The code above looked at weight (not importance). # The code below looks at importance (not weight). - log_rarity = param_distance / 2.0 + log_rarity = param_distance / 2.0 / scale_coeff trace("log_rarity", log_rarity) log_importance = log_weight + log_rarity trace("log_importance", log_importance) value_tracker.add(value, log_importance) # Update sampled statistics. param_sampled_tracker.add_get_shift(sample_point, log_importance) - debug_list.append("integrator used " + str(samples) + " samples") - debug_list.append(" ".join([ - "value_avg", str(value_tracker.average), - "param_sampled_avg", repr(param_sampled_tracker.averages), - "param_sampled_cov", repr(param_sampled_tracker.covariance_matrix), - "value_log_variance", str(value_tracker.log_variance), - "value_log_secondary_variance", - str(value_tracker.secondary.log_variance)])) + debug_list.append(f"integrator used {samples!s} samples") + debug_list.append( + " ".join( + [ + "value_avg", + str(value_tracker.average), + "param_sampled_avg", + repr(param_sampled_tracker.averages), + "param_sampled_cov", + repr(param_sampled_tracker.covariance_matrix), + "value_log_variance", + str(value_tracker.log_variance), + "value_log_secondary_variance", + str(value_tracker.secondary.log_variance), + ] + ) + ) communication_pipe.send( - (value_tracker, param_focus_tracker, debug_list, trace_list, samples)) + (value_tracker, param_focus_tracker, debug_list, trace_list, samples) + )