PLRsearch: Use stat trackers to shorten Integrator
[csit.git] / resources / libraries / python / PLRsearch / Integrator.py
1 # Copyright (c) 2019 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """Module for numerical integration, tightly coupled to PLRsearch algorithm.
15
16 See log_plus for an explanation why None acts as a special case "float" number.
17
18 TODO: Separate optimizations specific to PLRsearch and distribute the rest
19       as a standalone package so other projects may reuse.
20 """
21
22 import copy
23 import traceback
24
25 import dill
26 from numpy import random
27
28 # TODO: Teach FD.io CSIT to use multiple dirs in PYTHONPATH,
29 # then switch to absolute imports within PLRsearch package.
30 # Current usage of relative imports is just a short term workaround.
31 import stat_trackers  # pylint: disable=relative-import
32
33
34 def try_estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False):
35     """Call estimate_nd but catch any exception and send traceback."""
36     try:
37         return estimate_nd(communication_pipe, scale_coeff, trace_enabled)
38     except BaseException:
39         # Any subclass could have caused estimate_nd to stop before sending,
40         # so we have to catch them all.
41         traceback_string = traceback.format_exc()
42         communication_pipe.send(traceback_string)
43         # After sendig, re-raise, so usages other than "one process per call"
44         # keep behaving correctly.
45         raise
46
47
48 # TODO: Pylint reports multiple complexity violations.
49 # Refactor the code, using less (but structured) variables
50 # and function calls for (relatively) loosly coupled code blocks.
51 def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False):
52     """Use Bayesian inference from control queue, put result to result queue.
53
54     TODO: Use a logging framework that works in a user friendly way.
55     (Note that multiprocessing_logging does not work well with robot
56     and robotbackgroundlogger only works for threads, not processes.
57     Or, wait for https://github.com/robotframework/robotframework/pull/2182
58     Anyway, the current implementation with trace_enabled looks ugly.)
59
60     The result is average and standard deviation for posterior distribution
61     of a single dependent (scalar, float) value.
62     The prior is assumed to be uniform on (-1, 1) for every parameter.
63     Number of parameters and the function for computing
64     the dependent value and likelihood both come from input.
65
66     The likelihood is assumed to be extremely uneven (but never zero),
67     so the function should return the logarithm of the likelihood.
68     The integration method is basically a Monte Carlo
69     (TODO: Add links to notions used here.),
70     but importance sampling is used in order to focus
71     on the part of parameter space with (relatively) non-negligible likelihood.
72
73     Multivariate Gauss distribution is used for focusing,
74     so only unimodal posterior distributions are handled correctly.
75     Initial samples are mostly used for shaping (and shifting)
76     the Gaussian distribution, later samples will probably dominate.
77     Thus, initially the algorithm behavior resembles more "find the maximum",
78     as opposed to "reliably integrate". As for later iterations of PLRsearch,
79     it is assumed that the distribution position does not change rapidly;
80     thus integration algorithm returns also the distribution data,
81     to be used as initial focus in next iteration.
82
83     There are workarounds in place that allow old or default focus tracker
84     to be updated reasonably, even when initial samples
85     of new iteration have way smaller (or larger) weights.
86
87     During the "find the maximum" phase, the focus tracker frequently takes
88     a wrong shape (compared to observed samples in equilibrium).
89     Therefore scale_coeff argument is left for humans to tweak,
90     so the convergence is reliable and quick.
91
92     Until the distribution locates itself roughly around
93     the maximum likeligood point, the integration results are probably wrong.
94     That means some minimal time is needed for the result to become reliable.
95
96     TODO: The folowing is not currently implemented.
97     The reported standard distribution attempts to signal inconsistence
98     (when one sample has dominating weight compared to the rest of samples),
99     but some human supervision is strongly encouraged.
100
101     To facilitate running in worker processes, arguments and results
102     are communicated via a pipe. The computation does not start
103     until arguments appear in the pipe, the computation stops
104     when another item (stop object) is detected in the pipe
105     (and result is put to pipe).
106
107     TODO: Create classes for arguments and results,
108           so their fields are documented (and code perhaps more readable).
109
110     Input/argument object (received from pipe)
111     is a 4-tuple of the following fields:
112     - dimension: Integer, number of parameters to consider.
113     - dilled_function: Function (serialized using dill), which:
114     - - Takes the dimension number of float parameters from (-1, 1).
115     - - Returns float 2-tuple of dependent value and parameter log-likelihood.
116     - param_focus_tracker: VectorStatTracker to use for initial focus.
117     - max_samples: None or a limit for samples to use.
118
119     Output/result object (sent to pipe queue)
120     is a 5-tuple of the following fields:
121     - value_tracker: ScalarDualStatTracker estimate of value posterior.
122     - param_focus_tracker: VectorStatTracker to use for initial focus next.
123     - debug_list: List of debug strings to log at main process.
124     - trace_list: List of trace strings to pass to main process if enabled.
125     - samples: Number of samples used in computation (to make it reproducible).
126     Trace strings are very verbose, it is not recommended to enable them.
127     In they are not enabled, trace_list will be empty.
128     It is recommended to edit some lines manually to debug_list if needed.
129
130     :param communication_pipe: Pipe to comunicate with boss process.
131     :param scale_coeff: Float number to tweak convergence speed with.
132     :param trace_enabled: Whether trace list should be populated at all.
133         Default: False
134     :type communication_pipe: multiprocessing.Connection (or compatible)
135     :type scale_coeff: float
136     :type trace_enabled: boolean
137     :raises OverflowError: If one sample dominates the rest too much.
138         Or if value_logweight_function does not handle
139         some part of parameter space carefully enough.
140     :raises numpy.linalg.LinAlgError: If the focus shape gets singular
141         (due to rounding errors). Try changing scale_coeff.
142     """
143
144     debug_list = list()
145     trace_list = list()
146     # Block until input object appears.
147     dimension, dilled_function, param_focus_tracker, max_samples = (
148         communication_pipe.recv())
149     debug_list.append("Called with param_focus_tracker {tracker!r}"
150                       .format(tracker=param_focus_tracker))
151     def trace(name, value):
152         """
153         Add a variable (name and value) to trace list (if enabled).
154
155         This is a closure (not a pure function),
156         as it accesses trace_list and trace_enabled
157         (without any of them being an explicit argument).
158
159         :param name: Any string identifying the value.
160         :param value: Any object to log repr of.
161         :type name: str
162         :type value: object
163         """
164         if trace_enabled:
165             trace_list.append(name + " " + repr(value))
166     value_logweight_function = dill.loads(dilled_function)
167     samples = 0
168     # Importance sampling produces samples of higher weight (important)
169     # more frequently, and corrects that by adding weight bonus
170     # for the less frequently (unimportant) samples.
171     # But "corrected_weight" is too close to "weight" to be readable,
172     # so "importance" is used instead, even if it runs contrary to what
173     # important region is.
174     value_tracker = stat_trackers.ScalarDualStatTracker()
175     param_sampled_tracker = stat_trackers.VectorStatTracker(dimension).reset()
176     if not param_focus_tracker:
177         # First call has None instead of a real (even empty) tracker.
178         param_focus_tracker = stat_trackers.VectorStatTracker(dimension)
179         param_focus_tracker.unit_reset()
180     else:
181         # Focus tracker has probably too high weight.
182         param_focus_tracker.log_sum_weight = None
183     # TODO: Teach pylint the used version of numpy.random does have this member.
184     random.seed(0)
185     while not communication_pipe.poll():
186         if max_samples and samples >= max_samples:
187             break
188         # Generate next sample.
189         averages = param_focus_tracker.averages
190         covariance_matrix = copy.deepcopy(param_focus_tracker.covariance_matrix)
191         for first in range(dimension):
192             for second in range(dimension):
193                 covariance_matrix[first][second] *= scale_coeff
194         while 1:
195             # TODO: Teach pylint that numpy.random does also have this member.
196             sample_point = random.multivariate_normal(
197                 averages, covariance_matrix, 1)[0].tolist()
198             # Multivariate Gauss can fall outside (-1, 1) interval
199             for first in range(dimension):
200                 sample_coordinate = sample_point[first]
201                 if sample_coordinate <= -1.0 or sample_coordinate >= 1.0:
202                     break
203             else:  # These two breaks implement "level two continue".
204                 break
205         trace("sample_point", sample_point)
206         samples += 1
207         trace("samples", samples)
208         value, log_weight = value_logweight_function(trace, *sample_point)
209         trace("value", value)
210         trace("log_weight", log_weight)
211         trace("focus tracker before adding", param_focus_tracker)
212         # Update focus related statistics.
213         param_distance = param_focus_tracker.add_without_dominance_get_distance(
214             sample_point, log_weight)
215         # The code above looked at weight (not importance).
216         # The code below looks at importance (not weight).
217         log_rarity = param_distance / 2.0
218         trace("log_rarity", log_rarity)
219         log_importance = log_weight + log_rarity
220         trace("log_importance", log_importance)
221         value_tracker.add(value, log_importance)
222         # Update sampled statistics.
223         param_sampled_tracker.add_get_shift(sample_point, log_importance)
224     debug_list.append("integrator used " + str(samples) + " samples")
225     debug_list.append(" ".join([
226         "value_avg", str(value_tracker.average),
227         "param_sampled_avg", repr(param_sampled_tracker.averages),
228         "param_sampled_cov", repr(param_sampled_tracker.covariance_matrix),
229         "value_log_variance", str(value_tracker.log_variance),
230         "value_log_secondary_variance",
231         str(value_tracker.secondary.log_variance)]))
232     communication_pipe.send(
233         (value_tracker, param_focus_tracker, debug_list, trace_list, samples))