PLRsearch: Initial implementation and suites
[csit.git] / resources / libraries / python / PLRsearch / Integrator.py
1 # Copyright (c) 2019 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """Module for numerical integration, tightly coupled to PLRsearch algorithm.
15
16 See log_plus for an explanation why None acts as a special case "float" number.
17
18 TODO: Separate optimizations specific to PLRsearch
19       and distribute as a standalone package so other projects may reuse.
20 """
21
22 import math
23 import traceback
24
25 import dill
26 import numpy
27
28 # TODO: The preferred way to consume this code is via a pip package.
29 # If your project copies code instead, make sure your pylint check does not
30 # require these imports to be absolute and descending from your superpackage.
31 from log_plus import log_plus
32
33
34 def try_estimate_nd(communication_pipe, scale_coeff=10.0, trace_enabled=False):
35     """Call estimate_nd but catch any exception and send traceback."""
36     try:
37         return estimate_nd(communication_pipe, scale_coeff, trace_enabled)
38     except BaseException:
39         # Any subclass could have caused estimate_nd to stop before sending,
40         # so we have to catch them all.
41         traceback_string = traceback.format_exc()
42         communication_pipe.send(traceback_string)
43         # After sendig, re-raise, so usages other than "one process per call"
44         # keep behaving correctly.
45         raise
46
47
48 # FIXME: Pylint reports multiple complexity violations.
49 # Refactor the code, using less (but structured) variables
50 # and function calls for (relatively) loosly coupled code blocks.
51 def estimate_nd(communication_pipe, scale_coeff=10.0, trace_enabled=False):
52     """Use Bayesian inference from control queue, put result to result queue.
53
54     TODO: Use a logging framework that works in a user friendly way.
55     (Note that multiprocessing_logging does not work well with robot
56     and robotbackgroundlogger only works for threads, not processes.
57     Or, wait for https://github.com/robotframework/robotframework/pull/2182
58     Anyway, the current implementation with trace_enabled looks ugly.)
59
60     The result is average and standard deviation for posterior distribution
61     of a single dependent (positive scalar) value.
62     The prior is assumed to be uniform on (-1, 1) for every parameter.
63     Number of parameters and the function for computing
64     the dependent value and likelihood both come from input.
65
66     The likelihood is assumed to be extremely uneven (but never zero),
67     so the function should return the logarithm of the likelihood.
68     The integration method is basically a Monte Carlo
69     (TODO: Add links to notions used here.),
70     but importance sampling is used in order to focus
71     on the part of parameter space with (relatively) non-negligible likelihood.
72
73     Multivariate Gauss distribution is used for focusing,
74     so only unimodal posterior distributions are handled correctly.
75     Initial samples are mostly used for shaping (and shifting)
76     the Gaussian distribution, later samples will probably dominate.
77     Thus, initially the algorithm behavior resembles more "find the maximum",
78     as opposed to "reliably integrate". As for later iterations of PLRsearch,
79     it is assumed that the distribution position does not change rapidly;
80     thus integration algorithm returns also the distribution data,
81     to be used as initial focus in next iteration.
82
83     After some number of samples (depends on dimension),
84     the algorithm starts tracking few most likely samples to base the Gaussian
85     distribution around, mixing with estimates from observed samples.
86     The idea is that even when (usually) one of the samples dominates,
87     first few are treated as if equally likely, to get reasonable focus.
88     During the "find the maximum" phase, this set of best samples
89     frequently takes a wrong shape (compared to observed samples
90     in equilibrium). Therefore scale_coeff argument is left for humans to tweak,
91     so the convergence is reliable and quick.
92     Any data (other than correctly weighted samples) used to keep
93     distribution shape reasonable is called "bias", regardles of
94     whether it comes from input hint, or from tracking top samples.
95
96     Until the distribution locates itself roughly around
97     the maximum likeligood point, the integration results are probably wrong.
98     That means some minimal time is needed for the result to become reliable.
99     The reported standard distribution attempts to signal inconsistence
100     (when one sample has dominating weight compared to the rest of samples),
101     but some human supervision is strongly encouraged.
102
103     To facilitate running in worker processes, arguments and results
104     are communicated via pipe. The computation does not start
105     until arguments appear in the pipe, the computation stops
106     when another item (stop object) is detected in the pipe
107     (and result is put to pipe).
108
109     TODO: Create classes for arguments and results,
110           so their fields are documented (and code perhaps more readable).
111
112     Input/argument object (received from pipe)
113     is a 4-tuple of the following fields:
114     - dimension: Integer, number of parameters to consider.
115     - dilled_function: Function (serialized using dill), which:
116     - - Takes the dimension number of float parameters from (-1, 1).
117     - - Returns float 2-tuple of dependent value and parameter log-likelihood.
118     - param_hint_avg: Dimension-tuple of floats to start searching around.
119     - param_hint_cov: Covariance matrix defining initial focus shape.
120
121     Output/result object (sent to pipe queue)
122     is a 6-tuple of the following fields:
123     - value_avg: Float estimate of posterior average dependent value.
124     - value_stdev: Float estimate of posterior standard deviation of the value.
125     - param_importance_avg: Float tuple, center of Gaussian to use next.
126     - param_importance_cov: Float covariance matrix of the Gaussian to use next.
127     - debug_list: List of debug strings to log at main process.
128     - trace_list: List of trace strings to pass to main process if enabled.
129     Trace strings are very verbose, it is not recommended to enable them.
130     In they are not enabled, trace_list will be empty.
131     It is recommended to edit some lines manually to debug_list if needed.
132
133     :param communication_pipe: Pipe to comunicate with boss process.
134     :param scale_coeff: Float number to tweak convergence speed with.
135     :param trace_enabled: Whether trace list should be populated at all.
136         Default: False
137     :type communication_pipe: multiprocessing.Connection (or compatible)
138     :type scale_coeff: float
139     :type trace_enabled: boolean
140     :raises OverflowError: If one sample dominates the rest too much.
141         Or if value_logweight_function does not handle
142         some part of parameter space carefully enough.
143     :raises numpy.linalg.LinAlgError: If the focus shape gets singular
144         (due to rounding errors). Try changing scale_coeff.
145     """
146
147     # Block until input object appears.
148     dimension, dilled_function, param_hint_avg, param_hint_cov = (
149         communication_pipe.recv())
150     debug_list = list()
151     trace_list = list()
152     def trace(name, value):
153         """
154         Add a variable (name and value) to trace list (if enabled).
155
156         This is a closure (not a pure function),
157         as it accesses trace_list and trace_enabled
158         (without any of them being an explicit argument).
159
160         :param name: Any string identifying the value.
161         :param value: Any object to log repr of.
162         :type name: str
163         :type value: object
164         """
165         if trace_enabled:
166             trace_list.append(name + " " + repr(value))
167     value_logweight_function = dill.loads(dilled_function)
168     len_top = (dimension + 2) * (dimension + 1) / 2
169     top_weight_param = list()
170     samples = 0
171     log_sum_weight = None
172     # Importance sampling produces samples of higher weight (important)
173     # more frequently, and corrects that by adding weight bonus
174     # for the less frequently (unimportant) samples.
175     # But "corrected_weight" is too close to "weight" to be readable,
176     # so "importance" is used instead, even if it runs contrary to what
177     # important region is.
178     log_sum_importance = None
179     log_importance_best = None
180     value_avg = 0.0
181     # 1x1 dimensional covariance matrix is just variance.
182     # As variance is never negative, we can track logarithm.
183     value_log_variance = None
184     # Here "secondary" means "excluding the weightest sample".
185     log_secondary_sum_importance = None
186     value_secondary_avg = 0.0
187     value_log_secondary_variance = None
188     param_sampled_avg = [0.0 for first in range(dimension)]
189     # TODO: Examine whether we can gain speed by tracking triangle only.
190     # Covariance matrix can contain negative element (off-diagonal),
191     # so no logarithm here. This can lead to zeroes on diagonal,
192     # but we have biasing to make sure it does not hurt much.
193     param_sampled_cov = [[0.0 for first in range(dimension)]
194                          for second in range(dimension)]
195     # The next two variables do NOT need to be initialized here,
196     # but pylint is not a mathematician enough to understand that.
197     param_top_avg = [0.0 for first in range(dimension)]
198     param_top_cov = [[0.0 for first in range(dimension)]
199                      for second in range(dimension)]
200     if not (param_hint_avg and param_hint_cov):
201         # First call has Nones instead of useful hints.
202         param_hint_avg = [0.0 for first in range(dimension)]
203         param_hint_cov = [
204             [1.0 if first == second else 0.0 for first in range(dimension)]
205             for second in range(dimension)]
206     while not communication_pipe.poll():
207         # Compute focus data.
208         if len(top_weight_param) < len_top:
209             # Not enough samples for reasonable top, use hint bias.
210             param_focus_avg = param_hint_avg
211             param_focus_cov = param_hint_cov
212         else:
213             # We have both top samples and overall samples.
214             # Mix them according to how much the weightest sample dominates.
215             log_top_weight = top_weight_param[0][0]
216             log_weight_norm = log_plus(log_sum_weight, log_top_weight)
217             top_ratio = math.exp(log_top_weight - log_weight_norm)
218             sampled_ratio = math.exp(log_sum_weight - log_weight_norm)
219             trace("log_top_weight", log_top_weight)
220             trace("log_sum_weight", log_sum_weight)
221             trace("top_ratio", top_ratio)
222             trace("sampled_ratio", sampled_ratio)
223             param_focus_avg = [
224                 sampled_ratio * param_sampled_avg[first]
225                 + top_ratio * param_top_avg[first]
226                 for first in range(dimension)]
227             param_focus_cov = [[
228                 scale_coeff * (
229                     sampled_ratio * param_sampled_cov[first][second]
230                     + top_ratio * param_top_cov[first][second])
231                 for first in range(dimension)] for second in range(dimension)]
232         trace("param_focus_avg", param_focus_avg)
233         trace("param_focus_cov", param_focus_cov)
234         # Generate next sample.
235         while 1:
236             # TODO: Inform pylint that correct version of numpy is available.
237             sample_point = numpy.random.multivariate_normal(
238                 param_focus_avg, param_focus_cov, 1)[0]
239             # Multivariate Gauss can fall outside (-1, 1) interval
240             for first in range(dimension):
241                 sample_coordinate = sample_point[first]
242                 if sample_coordinate <= -1.0 or sample_coordinate >= 1.0:
243                     break
244             else:  # These two breaks implement "level two continue".
245                 break
246         trace("sample_point", sample_point)
247         samples += 1
248         value, log_weight = value_logweight_function(*sample_point)
249         trace("value", value)
250         trace("log_weight", log_weight)
251         # Update bias related statistics.
252         log_sum_weight = log_plus(log_sum_weight, log_weight)
253         if len(top_weight_param) < len_top:
254             top_weight_param.append((log_weight, sample_point))
255         # Hack: top_weight_param[-1] is either the smallest,
256         # or the just appended to len_top-1 item list.
257         if (len(top_weight_param) >= len_top
258                 and log_weight >= top_weight_param[-1][0]):
259             top_weight_param = top_weight_param[:-1]
260             top_weight_param.append((log_weight, sample_point))
261             top_weight_param.sort(key=lambda item: -item[0])
262             trace("top_weight_param", top_weight_param)
263             # top_weight_param has changed, recompute biases.
264             param_top_avg = top_weight_param[0][1]
265             param_top_cov = [[0.0 for first in range(dimension)]
266                              for second in range(dimension)]
267             top_item_count = 1
268             for _, near_top_param in top_weight_param[1:]:
269                 top_item_count += 1
270                 next_item_ratio = 1.0 / top_item_count
271                 previous_items_ratio = 1.0 - next_item_ratio
272                 param_shift = [
273                     near_top_param[first] - param_top_avg[first]
274                     for first in range(dimension)]
275                 # Do not move center from the weightest sample.
276                 for second in range(dimension):
277                     for first in range(dimension):
278                         param_top_cov[first][second] += (
279                             param_shift[first] * param_shift[second]
280                             * next_item_ratio)
281                         param_top_cov[first][second] *= previous_items_ratio
282             trace("param_top_avg", param_top_avg)
283             trace("param_top_cov", param_top_cov)
284         # The code above looked at weight (not importance).
285         # The code below looks at importance (not weight).
286         param_shift = [sample_point[first] - param_focus_avg[first]
287                        for first in range(dimension)]
288         rarity_gradient = numpy.linalg.solve(param_focus_cov, param_shift)
289         rarity_step = numpy.vdot(param_shift, rarity_gradient)
290         log_rarity = rarity_step / 2.0
291         trace("log_rarity", log_rarity)
292         trace("samples", samples)
293         log_importance = log_weight + log_rarity
294         trace("log_importance", log_importance)
295         # Update sampled statistics.
296         old_log_sum_importance = log_sum_importance
297         log_sum_importance = log_plus(old_log_sum_importance, log_importance)
298         trace("new log_sum_weight", log_sum_weight)
299         trace("log_sum_importance", log_sum_importance)
300         if old_log_sum_importance is None:
301             param_sampled_avg = list(sample_point)
302             value_avg = value
303             # Other value related quantities stay None.
304             continue
305         previous_samples_ratio = math.exp(
306             old_log_sum_importance - log_sum_importance)
307         new_sample_ratio = math.exp(log_importance - log_sum_importance)
308         param_shift = [sample_point[first] - param_sampled_avg[first]
309                        for first in range(dimension)]
310         value_shift = value - value_avg
311         for first in range(dimension):
312             param_sampled_avg[first] += param_shift[first] * new_sample_ratio
313         old_value_avg = value_avg
314         value_avg += value_shift * new_sample_ratio
315         value_absolute_shift = abs(value_shift)
316         for second in range(dimension):
317             for first in range(dimension):
318                 param_sampled_cov[first][second] += (
319                     param_shift[first] * param_shift[second] * new_sample_ratio)
320                 param_sampled_cov[first][second] *= previous_samples_ratio
321         trace("param_sampled_avg", param_sampled_avg)
322         trace("param_sampled_cov", param_sampled_cov)
323         update_secondary_stats = True
324         if log_importance_best is None or log_importance > log_importance_best:
325             log_importance_best = log_importance
326             log_secondary_sum_importance = old_log_sum_importance
327             value_secondary_avg = old_value_avg
328             value_log_secondary_variance = value_log_variance
329             update_secondary_stats = False
330             # TODO: Update all primary quantities before secondary ones.
331             # (As opposed to current hybrid code.)
332         if value_absolute_shift > 0.0:
333             value_log_variance = log_plus(
334                 value_log_variance, 2 * math.log(value_absolute_shift)
335                 + log_importance - log_sum_importance)
336         if value_log_variance is not None:
337             value_log_variance -= log_sum_importance - old_log_sum_importance
338         if not update_secondary_stats:
339             continue
340         # TODO: Pylint says the following variable name is bad.
341         # Make sure the refactor uses shorter names.
342         old_log_secondary_sum_importance = log_secondary_sum_importance
343         log_secondary_sum_importance = log_plus(
344             old_log_secondary_sum_importance, log_importance)
345         if old_log_secondary_sum_importance is None:
346             value_secondary_avg = value
347             continue
348         new_sample_secondary_ratio = math.exp(
349             log_importance - log_secondary_sum_importance)
350         # TODO: No would-be variable named old_value_secondary_avg
351         # appears in subsequent computations. Probably means there is a bug.
352         value_secondary_shift = value - value_secondary_avg
353         value_secondary_absolute_shift = abs(value_secondary_shift)
354         value_secondary_avg += (
355             value_secondary_shift * new_sample_secondary_ratio)
356         if value_secondary_absolute_shift > 0.0:
357             value_log_secondary_variance = log_plus(
358                 value_log_secondary_variance, (
359                     2 * math.log(value_secondary_absolute_shift)
360                     + log_importance - log_secondary_sum_importance))
361         if value_log_secondary_variance is not None:
362             value_log_secondary_variance -= (
363                 log_secondary_sum_importance - old_log_secondary_sum_importance)
364     debug_list.append("integrator used " + str(samples) + " samples")
365     debug_list.append(
366         "value_avg " + str(value_avg)
367         + " param_sampled_avg " + repr(param_sampled_avg)
368         + " param_sampled_cov " + repr(param_sampled_cov)
369         + " value_log_variance " + str(value_log_variance)
370         + " value_log_secondary_variance " + str(value_log_secondary_variance))
371     value_stdev = math.exp(
372         (2 * value_log_variance - value_log_secondary_variance) / 2.0)
373     debug_list.append("top_weight_param[0] " + repr(top_weight_param[0]))
374     # Intentionally returning param_focus_avg and param_focus_cov,
375     # instead of possibly hyper-focused bias or sampled.
376     communication_pipe.send(
377         (value_avg, value_stdev, param_focus_avg, param_focus_cov, debug_list,
378          trace_list))