Revert "fix(IPsecUtil): Delete keywords no longer used"
[csit.git] / resources / libraries / python / PLRsearch / Integrator.py
1 # Copyright (c) 2024 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """Module for numerical integration, tightly coupled to PLRsearch algorithm.
15
16 See log_plus for an explanation why None acts as a special case "float" number.
17
18 TODO: Separate optimizations specific to PLRsearch and distribute the rest
19       as a standalone package so other projects may reuse.
20 """
21
22 import copy
23 import traceback
24
25 import dill
26
27 from numpy import random
28
29 # TODO: Teach FD.io CSIT to use multiple dirs in PYTHONPATH,
30 # then switch to absolute imports within PLRsearch package.
31 # Current usage of relative imports is just a short term workaround.
32 from . import stat_trackers
33
34
35 def try_estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False):
36     """Call estimate_nd but catch any exception and send traceback.
37
38     This function does not return anything, computation result
39     is sent via the communication pipe instead.
40
41     TODO: Move scale_coeff to a field of data class
42     with constructor/factory hiding the default value,
43     and receive its instance via pipe, instead of argument.
44
45     :param communication_pipe: Endpoint for communication with parent process.
46     :param scale_coeff: Float number to tweak convergence speed with.
47     :param trace_enabled: Whether to emit trace level debugs.
48         Keeping trace disabled improves speed and saves memory.
49         Enable trace only when debugging the computation itself.
50     :type communication_pipe: multiprocessing.Connection
51     :type scale_coeff: float
52     :type trace_enabled: bool
53     :raises BaseException: Anything raised by interpreter or estimate_nd.
54     """
55     try:
56         estimate_nd(communication_pipe, scale_coeff, trace_enabled)
57     except BaseException:
58         # Any subclass could have caused estimate_nd to stop before sending,
59         # so we have to catch them all.
60         traceback_string = traceback.format_exc()
61         communication_pipe.send(traceback_string)
62         # After sending, re-raise, so usages other than "one process per call"
63         # keep behaving correctly.
64         raise
65
66
67 def generate_sample(averages, covariance_matrix, dimension, scale_coeff):
68     """Generate next sample for estimate_nd.
69
70     Arguments control the multivariate normal "focus".
71     Keep generating until the sample point fits into unit area.
72
73     :param averages: Coordinates of the focus center.
74     :param covariance_matrix: Matrix controlling the spread around the average.
75     :param dimension: If N is dimension, average is N vector and matrix is NxN.
76     :param scale_coeff: Coefficient to conformally multiply the spread.
77     :type averages: Indexable of N floats
78     :type covariance_matrix: Indexable of N indexables of N floats
79     :type dimension: int
80     :type scale_coeff: float
81     :returns: The generated sample point.
82     :rtype: N-tuple of float
83     """
84     covariance_matrix = copy.deepcopy(covariance_matrix)
85     for first in range(dimension):
86         for second in range(dimension):
87             covariance_matrix[first][second] *= scale_coeff
88     while 1:
89         sample_point = random.multivariate_normal(
90             averages, covariance_matrix, 1
91         )[0].tolist()
92         # Multivariate Gauss can fall outside (-1, 1) interval
93         for first in range(dimension):
94             sample_coordinate = sample_point[first]
95             if sample_coordinate <= -1.0 or sample_coordinate >= 1.0:
96                 break
97         else:
98             return sample_point
99
100
101 def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False):
102     """Use Bayesian inference from control queue, put result to result queue.
103
104     TODO: Use a logging framework that works in a user friendly way.
105     (Note that multiprocessing_logging does not work well with robot
106     and robotbackgroundlogger only works for threads, not processes.
107     Or, wait for https://github.com/robotframework/robotframework/pull/2182
108     Anyway, the current implementation with trace_enabled looks ugly.)
109
110     The result is average and standard deviation for posterior distribution
111     of a single dependent (scalar, float) value.
112     The prior is assumed to be uniform on (-1, 1) for every parameter.
113     Number of parameters and the function for computing
114     the dependent value and likelihood both come from input.
115
116     The likelihood is assumed to be extremely uneven (but never zero),
117     so the function should return the logarithm of the likelihood.
118     The integration method is basically a Monte Carlo
119     (TODO: Add links to notions used here.),
120     but importance sampling is used in order to focus
121     on the part of parameter space with (relatively) non-negligible likelihood.
122
123     Multivariate Gauss distribution is used for focusing,
124     so only unimodal posterior distributions are handled correctly.
125     Initial samples are mostly used for shaping (and shifting)
126     the Gaussian distribution, later samples will probably dominate.
127     Thus, initially the algorithm behavior resembles more "find the maximum",
128     as opposed to "reliably integrate". As for later iterations of PLRsearch,
129     it is assumed that the distribution position does not change rapidly;
130     thus integration algorithm returns also the distribution data,
131     to be used as initial focus in next iteration.
132
133     There are workarounds in place that allow old or default focus tracker
134     to be updated reasonably, even when initial samples
135     of new iteration have way smaller (or larger) weights.
136
137     During the "find the maximum" phase, the focus tracker frequently takes
138     a wrong shape (compared to observed samples in equilibrium).
139     Therefore scale_coeff argument is left for humans to tweak,
140     so the convergence is reliable and quick.
141
142     Until the distribution locates itself roughly around
143     the maximum likeligood point, the integration results are probably wrong.
144     That means some minimal time is needed for the result to become reliable.
145
146     TODO: The folowing is not currently implemented.
147     The reported standard distribution attempts to signal inconsistence
148     (when one sample has dominating weight compared to the rest of samples),
149     but some human supervision is strongly encouraged.
150
151     To facilitate running in worker processes, arguments and results
152     are communicated via a pipe. The computation does not start
153     until arguments appear in the pipe, the computation stops
154     when another item (stop object) is detected in the pipe
155     (and result is put to pipe).
156
157     TODO: Create classes for arguments and results,
158           so their fields are documented (and code perhaps more readable).
159
160     Input/argument object (received from pipe)
161     is a 4-tuple of the following fields:
162     - dimension: Integer, number of parameters to consider.
163     - dilled_function: Function (serialized using dill), which:
164     - - Takes the dimension number of float parameters from (-1, 1).
165     - - Returns float 2-tuple of dependent value and parameter log-likelihood.
166     - param_focus_tracker: VectorStatTracker to use for initial focus.
167     - max_samples: None or a limit for samples to use.
168
169     Output/result object (sent to pipe queue)
170     is a 5-tuple of the following fields:
171     - value_tracker: ScalarDualStatTracker estimate of value posterior.
172     - param_focus_tracker: VectorStatTracker to use for initial focus next.
173     - debug_list: List of debug strings to log at main process.
174     - trace_list: List of trace strings to pass to main process if enabled.
175     - samples: Number of samples used in computation (to make it reproducible).
176     Trace strings are very verbose, it is not recommended to enable them.
177     In they are not enabled, trace_list will be empty.
178     It is recommended to edit some lines manually to debug_list if needed.
179
180     :param communication_pipe: Endpoint for communication with parent process.
181     :param scale_coeff: Float number to tweak convergence speed with.
182     :param trace_enabled: Whether trace list should be populated at all.
183     :type communication_pipe: multiprocessing.Connection
184     :type scale_coeff: float
185     :type trace_enabled: bool
186     :raises OverflowError: If one sample dominates the rest too much.
187         Or if value_logweight_function does not handle
188         some part of parameter space carefully enough.
189     :raises numpy.linalg.LinAlgError: If the focus shape gets singular
190         (due to rounding errors). Try changing scale_coeff.
191     """
192     debug_list = []
193     trace_list = []
194     # Block until input object appears.
195     (
196         dimension,
197         dilled_function,
198         param_focus_tracker,
199         max_samples,
200     ) = communication_pipe.recv()
201     debug_list.append(
202         f"Called with param_focus_tracker {param_focus_tracker!r}"
203     )
204
205     def trace(name, value):
206         """
207         Add a variable (name and value) to trace list (if enabled).
208
209         This is a closure (not a pure function),
210         as it accesses trace_list and trace_enabled
211         (without any of them being an explicit argument).
212
213         :param name: Any string identifying the value.
214         :param value: Any object to log repr of.
215         :type name: str
216         :type value: object
217         """
218         if trace_enabled:
219             trace_list.append(f"{name} {value!r}")
220
221     value_logweight_function = dill.loads(dilled_function)
222     samples = 0
223     # Importance sampling produces samples of higher weight (important)
224     # more frequently, and corrects that by adding weight bonus
225     # for the less frequently (unimportant) samples.
226     # But "corrected_weight" is too close to "weight" to be readable,
227     # so "importance" is used instead, even if it runs contrary to what
228     # important region is.
229     value_tracker = stat_trackers.ScalarDualStatTracker()
230     param_sampled_tracker = stat_trackers.VectorStatTracker(dimension).reset()
231     if not param_focus_tracker:
232         # First call has None instead of a real (even empty) tracker.
233         param_focus_tracker = stat_trackers.VectorStatTracker(dimension)
234         param_focus_tracker.unit_reset()
235     else:
236         # Focus tracker has probably too high weight.
237         param_focus_tracker.log_sum_weight = None
238     random.seed(0)
239     while not communication_pipe.poll():
240         if max_samples and samples >= max_samples:
241             break
242         sample_point = generate_sample(
243             param_focus_tracker.averages,
244             param_focus_tracker.covariance_matrix,
245             dimension,
246             scale_coeff,
247         )
248         trace("sample_point", sample_point)
249         samples += 1
250         trace("samples", samples)
251         value, log_weight = value_logweight_function(trace, *sample_point)
252         trace("value", value)
253         trace("log_weight", log_weight)
254         trace("focus tracker before adding", param_focus_tracker)
255         # Update focus related statistics.
256         param_distance = param_focus_tracker.add_without_dominance_get_distance(
257             sample_point, log_weight
258         )
259         # The code above looked at weight (not importance).
260         # The code below looks at importance (not weight).
261         log_rarity = param_distance / 2.0 / scale_coeff
262         trace("log_rarity", log_rarity)
263         log_importance = log_weight + log_rarity
264         trace("log_importance", log_importance)
265         value_tracker.add(value, log_importance)
266         # Update sampled statistics.
267         param_sampled_tracker.add_get_shift(sample_point, log_importance)
268     debug_list.append(f"integrator used {samples!s} samples")
269     debug_list.append(
270         " ".join(
271             [
272                 "value_avg",
273                 str(value_tracker.average),
274                 "param_sampled_avg",
275                 repr(param_sampled_tracker.averages),
276                 "param_sampled_cov",
277                 repr(param_sampled_tracker.covariance_matrix),
278                 "value_log_variance",
279                 str(value_tracker.log_variance),
280                 "value_log_secondary_variance",
281                 str(value_tracker.secondary.log_variance),
282             ]
283         )
284     )
285     communication_pipe.send(
286         (value_tracker, param_focus_tracker, debug_list, trace_list, samples)
287     )