lb: remove api boilerplate
[vpp.git] / test / remote_test.py
1 #!/usr/bin/env python
2
3 import inspect
4 import os
5 import unittest
6 from framework import VppTestCase
7 from multiprocessing import Process, Pipe
8 from pickle import dumps
9 import six
10 from six import moves
11 import sys
12 from aenum import IntEnum, IntFlag
13
14
15 class SerializableClassCopy(object):
16     """
17     Empty class used as a basis for a serializable copy of another class.
18     """
19     pass
20
21     def __repr__(self):
22         return '<SerializableClassCopy dict=%s>' % self.__dict__
23
24
25 class RemoteClassAttr(object):
26     """
27     Wrapper around attribute of a remotely executed class.
28     """
29
30     def __init__(self, remote, attr):
31         self._path = [attr] if attr else []
32         self._remote = remote
33
34     def path_to_str(self):
35         return '.'.join(self._path)
36
37     def get_remote_value(self):
38         return self._remote._remote_exec(RemoteClass.GET, self.path_to_str())
39
40     def __repr__(self):
41         return self._remote._remote_exec(RemoteClass.REPR, self.path_to_str())
42
43     def __str__(self):
44         return self._remote._remote_exec(RemoteClass.STR, self.path_to_str())
45
46     def __getattr__(self, attr):
47         if attr[0] == '_':
48             if not (attr.startswith('__') and attr.endswith('__')):
49                 raise AttributeError('tried to get private attribute: %s ',
50                                      attr)
51         self._path.append(attr)
52         return self
53
54     def __setattr__(self, attr, val):
55         if attr[0] == '_':
56             if not (attr.startswith('__') and attr.endswith('__')):
57                 super(RemoteClassAttr, self).__setattr__(attr, val)
58                 return
59         self._path.append(attr)
60         self._remote._remote_exec(RemoteClass.SETATTR, self.path_to_str(),
61                                   value=val)
62
63     def __call__(self, *args, **kwargs):
64         return self._remote._remote_exec(RemoteClass.CALL, self.path_to_str(),
65                                          *args, **kwargs)
66
67
68 class RemoteClass(Process):
69     """
70     This class can wrap around and adapt the interface of another class,
71     and then delegate its execution to a newly forked child process.
72     Usage:
73         # Create a remotely executed instance of MyClass
74         object = RemoteClass(MyClass, arg1='foo', arg2='bar')
75         object.start_remote()
76         # Access the object normally as if it was an instance of your class.
77         object.my_attribute = 20
78         print object.my_attribute
79         print object.my_method(object.my_attribute)
80         object.my_attribute.nested_attribute = 'test'
81         # If you need the value of a remote attribute, use .get_remote_value
82         method. This method is automatically called when needed in the context
83         of a remotely executed class. E.g.:
84         if (object.my_attribute.get_remote_value() > 20):
85             object.my_attribute2 = object.my_attribute
86         # Destroy the instance
87         object.quit_remote()
88         object.terminate()
89     """
90
91     GET = 0       # Get attribute remotely
92     CALL = 1      # Call method remotely
93     SETATTR = 2   # Set attribute remotely
94     REPR = 3      # Get representation of a remote object
95     STR = 4       # Get string representation of a remote object
96     QUIT = 5      # Quit remote execution
97
98     PIPE_PARENT = 0  # Parent end of the pipe
99     PIPE_CHILD = 1  # Child end of the pipe
100
101     DEFAULT_TIMEOUT = 2  # default timeout for an operation to execute
102
103     def __init__(self, cls, *args, **kwargs):
104         super(RemoteClass, self).__init__()
105         self._cls = cls
106         self._args = args
107         self._kwargs = kwargs
108         self._timeout = RemoteClass.DEFAULT_TIMEOUT
109         self._pipe = Pipe()  # pipe for input/output arguments
110
111     def __repr__(self):
112         return moves.reprlib.repr(RemoteClassAttr(self, None))
113
114     def __str__(self):
115         return str(RemoteClassAttr(self, None))
116
117     def __call__(self, *args, **kwargs):
118         return self.RemoteClassAttr(self, None)()
119
120     def __getattr__(self, attr):
121         if attr[0] == '_' or not self.is_alive():
122             if not (attr.startswith('__') and attr.endswith('__')):
123                 if hasattr(super(RemoteClass, self), '__getattr__'):
124                     return super(RemoteClass, self).__getattr__(attr)
125                 raise AttributeError('missing: %s', attr)
126         return RemoteClassAttr(self, attr)
127
128     def __setattr__(self, attr, val):
129         if attr[0] == '_' or not self.is_alive():
130             if not (attr.startswith('__') and attr.endswith('__')):
131                 super(RemoteClass, self).__setattr__(attr, val)
132                 return
133         setattr(RemoteClassAttr(self, None), attr, val)
134
135     def _remote_exec(self, op, path=None, *args, **kwargs):
136         """
137         Execute given operation on a given, possibly nested, member remotely.
138         """
139         # automatically resolve remote objects in the arguments
140         mutable_args = list(args)
141         for i, val in enumerate(mutable_args):
142             if isinstance(val, RemoteClass) or \
143                     isinstance(val, RemoteClassAttr):
144                 mutable_args[i] = val.get_remote_value()
145         args = tuple(mutable_args)
146         for key, val in six.iteritems(kwargs):
147             if isinstance(val, RemoteClass) or \
148                     isinstance(val, RemoteClassAttr):
149                 kwargs[key] = val.get_remote_value()
150         # send request
151         args = self._make_serializable(args)
152         kwargs = self._make_serializable(kwargs)
153         self._pipe[RemoteClass.PIPE_PARENT].send((op, path, args, kwargs))
154         timeout = self._timeout
155         # adjust timeout specifically for the .sleep method
156         if path is not None and path.split('.')[-1] == 'sleep':
157             if args and isinstance(args[0], (long, int)):
158                 timeout += args[0]
159             elif 'timeout' in kwargs:
160                 timeout += kwargs['timeout']
161         if not self._pipe[RemoteClass.PIPE_PARENT].poll(timeout):
162             return None
163         try:
164             rv = self._pipe[RemoteClass.PIPE_PARENT].recv()
165             rv = self._deserialize(rv)
166             return rv
167         except EOFError:
168             return None
169
170     def _get_local_object(self, path):
171         """
172         Follow the path to obtain a reference on the addressed nested attribute
173         """
174         obj = self._instance
175         for attr in path:
176             obj = getattr(obj, attr)
177         return obj
178
179     def _get_local_value(self, path):
180         try:
181             return self._get_local_object(path)
182         except AttributeError:
183             return None
184
185     def _call_local_method(self, path, *args, **kwargs):
186         try:
187             method = self._get_local_object(path)
188             return method(*args, **kwargs)
189         except AttributeError:
190             return None
191
192     def _set_local_attr(self, path, value):
193         try:
194             obj = self._get_local_object(path[:-1])
195             setattr(obj, path[-1], value)
196         except AttributeError:
197             pass
198         return None
199
200     def _get_local_repr(self, path):
201         try:
202             obj = self._get_local_object(path)
203             return moves.reprlib.repr(obj)
204         except AttributeError:
205             return None
206
207     def _get_local_str(self, path):
208         try:
209             obj = self._get_local_object(path)
210             return str(obj)
211         except AttributeError:
212             return None
213
214     def _serializable(self, obj):
215         """ Test if the given object is serializable """
216         try:
217             dumps(obj)
218             return True
219         except:
220             return False
221
222     def _make_obj_serializable(self, obj):
223         """
224         Make a serializable copy of an object.
225         Members which are difficult/impossible to serialize are stripped.
226         """
227         if self._serializable(obj):
228             return obj  # already serializable
229
230         copy = SerializableClassCopy()
231
232         """
233         Dictionaries can hold complex values, so we split keys and values into
234         separate lists and serialize them individually.
235         """
236         if (type(obj) is dict):
237             copy.type = type(obj)
238             copy.k_list = list()
239             copy.v_list = list()
240             for k, v in obj.items():
241                 copy.k_list.append(self._make_serializable(k))
242                 copy.v_list.append(self._make_serializable(v))
243             return copy
244
245         # copy at least serializable attributes and properties
246         for name, member in inspect.getmembers(obj):
247             # skip private members and non-writable dunder methods.
248             if name[0] == '_':
249                 if name in ['__weakref__']:
250                     continue
251                 if name in ['__dict__']:
252                     continue
253                 if not (name.startswith('__') and name.endswith('__')):
254                     continue
255             if callable(member) and not isinstance(member, property):
256                 continue
257             if not self._serializable(member):
258                 member = self._make_serializable(member)
259             setattr(copy, name, member)
260         return copy
261
262     def _make_serializable(self, obj):
263         """
264         Make a serializable copy of an object or a list/tuple of objects.
265         Members which are difficult/impossible to serialize are stripped.
266         """
267         if (type(obj) is list) or (type(obj) is tuple):
268             rv = []
269             for item in obj:
270                 rv.append(self._make_serializable(item))
271             if type(obj) is tuple:
272                 rv = tuple(rv)
273             return rv
274         elif (isinstance(obj, IntEnum) or isinstance(obj, IntFlag)):
275             return obj.value
276         else:
277             return self._make_obj_serializable(obj)
278
279     def _deserialize_obj(self, obj):
280         if (hasattr(obj, 'type')):
281             if obj.type is dict:
282                 _obj = dict()
283                 for k, v in zip(obj.k_list, obj.v_list):
284                     _obj[self._deserialize(k)] = self._deserialize(v)
285             return _obj
286         return obj
287
288     def _deserialize(self, obj):
289         if (type(obj) is list) or (type(obj) is tuple):
290             rv = []
291             for item in obj:
292                 rv.append(self._deserialize(item))
293             if type(obj) is tuple:
294                 rv = tuple(rv)
295             return rv
296         else:
297             return self._deserialize_obj(obj)
298
299     def start_remote(self):
300         """ Start remote execution """
301         self.start()
302
303     def quit_remote(self):
304         """ Quit remote execution """
305         self._remote_exec(RemoteClass.QUIT, None)
306
307     def get_remote_value(self):
308         """ Get value of a remotely held object """
309         return RemoteClassAttr(self, None).get_remote_value()
310
311     def set_request_timeout(self, timeout):
312         """ Change request timeout """
313         self._timeout = timeout
314
315     def run(self):
316         """
317         Create instance of the wrapped class and execute operations
318         on it as requested by the parent process.
319         """
320         self._instance = self._cls(*self._args, **self._kwargs)
321         while True:
322             try:
323                 rv = None
324                 # get request from the parent process
325                 (op, path, args,
326                  kwargs) = self._pipe[RemoteClass.PIPE_CHILD].recv()
327                 args = self._deserialize(args)
328                 kwargs = self._deserialize(kwargs)
329                 path = path.split('.') if path else []
330                 if op == RemoteClass.GET:
331                     rv = self._get_local_value(path)
332                 elif op == RemoteClass.CALL:
333                     rv = self._call_local_method(path, *args, **kwargs)
334                 elif op == RemoteClass.SETATTR and 'value' in kwargs:
335                     self._set_local_attr(path, kwargs['value'])
336                 elif op == RemoteClass.REPR:
337                     rv = self._get_local_repr(path)
338                 elif op == RemoteClass.STR:
339                     rv = self._get_local_str(path)
340                 elif op == RemoteClass.QUIT:
341                     break
342                 else:
343                     continue
344                 # send return value
345                 if not self._serializable(rv):
346                     rv = self._make_serializable(rv)
347                 self._pipe[RemoteClass.PIPE_CHILD].send(rv)
348             except EOFError:
349                 break
350         self._instance = None  # destroy the instance
351
352
353 @unittest.skip("Remote Vpp Test Case Class")
354 class RemoteVppTestCase(VppTestCase):
355     """ Re-use VppTestCase to create remote VPP segment
356
357         In your test case:
358
359         @classmethod
360         def setUpClass(cls):
361             # fork new process before client connects to VPP
362             cls.remote_test = RemoteClass(RemoteVppTestCase)
363
364             # start remote process
365             cls.remote_test.start_remote()
366
367             # set up your test case
368             super(MyTestCase, cls).setUpClass()
369
370             # set up remote test
371             cls.remote_test.setUpClass(cls.tempdir)
372
373         @classmethod
374         def tearDownClass(cls):
375             # tear down remote test
376             cls.remote_test.tearDownClass()
377
378             # stop remote process
379             cls.remote_test.quit_remote()
380
381             # tear down your test case
382             super(MyTestCase, cls).tearDownClass()
383     """
384
385     def __init__(self):
386         super(RemoteVppTestCase, self).__init__("emptyTest")
387
388     # Note: __del__ is a 'Finalizer" not a 'Destructor'.
389     # https://docs.python.org/3/reference/datamodel.html#object.__del__
390     def __del__(self):
391         if hasattr(self, "vpp"):
392             self.vpp.poll()
393             if self.vpp.returncode is None:
394                 self.vpp.terminate()
395                 self.vpp.communicate()
396
397     @classmethod
398     def setUpClass(cls, tempdir):
399         # disable features unsupported in remote VPP
400         orig_env = dict(os.environ)
401         if 'STEP' in os.environ:
402             del os.environ['STEP']
403         if 'DEBUG' in os.environ:
404             del os.environ['DEBUG']
405         cls.tempdir_prefix = os.path.basename(tempdir) + "/"
406         super(RemoteVppTestCase, cls).setUpClass()
407         os.environ = orig_env
408
409     @classmethod
410     def tearDownClass(cls):
411         super(RemoteVppTestCase, cls).tearDownClass()
412
413     @unittest.skip("Empty test")
414     def emptyTest(self):
415         """ Do nothing """
416         pass
417
418     def setTestFunctionInfo(self, name, doc):
419         """
420         Store the name and documentation string of currently executed test
421         in the main VPP for logging purposes.
422         """
423         self._testMethodName = name
424         self._testMethodDoc = doc