api: verify message size on receipt
[vpp.git] / src / vlibmemory / socket_client.c
1 /*
2  *------------------------------------------------------------------
3  * socket_client.c - API message handling over sockets, client code.
4  *
5  * Copyright (c) 2017 Cisco and/or its affiliates.
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at:
9  *
10  *     http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  *------------------------------------------------------------------
18  */
19
20 #include <stdio.h>
21 #define __USE_GNU
22 #define _GNU_SOURCE
23 #include <sys/socket.h>
24
25 #include <svm/ssvm.h>
26 #include <vlibmemory/socket_client.h>
27 #include <vlibmemory/memory_client.h>
28
29 #include <vlibmemory/vl_memory_msg_enum.h>
30
31 #define vl_typedefs             /* define message structures */
32 #include <vlibmemory/vl_memory_api_h.h>
33 #undef vl_typedefs
34
35 #define vl_endianfun            /* define message structures */
36 #include <vlibmemory/vl_memory_api_h.h>
37 #undef vl_endianfun
38
39 #define vl_calcsizefun
40 #include <vlibmemory/vl_memory_api_h.h>
41 #undef vl_calcsizefun
42
43 /* instantiate all the print functions we know about */
44 #define vl_print(handle, ...) clib_warning (__VA_ARGS__)
45 #define vl_printfun
46 #include <vlibmemory/vl_memory_api_h.h>
47 #undef vl_printfun
48
49 socket_client_main_t socket_client_main;
50 __thread socket_client_main_t *socket_client_ctx = &socket_client_main;
51
52 /* Debug aid */
53 u32 vl (void *p) __attribute__ ((weak));
54
55 u32
56 vl (void *p)
57 {
58   return vec_len (p);
59 }
60
61 static socket_client_main_t *
62 vl_socket_client_ctx_push (socket_client_main_t * ctx)
63 {
64   socket_client_main_t *old = socket_client_ctx;
65   socket_client_ctx = ctx;
66   return old;
67 }
68
69 static void
70 vl_socket_client_ctx_pop (socket_client_main_t * old_ctx)
71 {
72   socket_client_ctx = old_ctx;
73 }
74
75 static int
76 vl_socket_client_read_internal (socket_client_main_t * scm, int wait)
77 {
78   u32 data_len = 0, msg_size;
79   int n, current_rx_index;
80   msgbuf_t *mbp = 0;
81   f64 timeout;
82
83   if (scm->socket_fd == 0)
84     return -1;
85
86   if (wait)
87     timeout = clib_time_now (&scm->clib_time) + wait;
88
89   while (1)
90     {
91       while (vec_len (scm->socket_rx_buffer) < sizeof (*mbp))
92         {
93           current_rx_index = vec_len (scm->socket_rx_buffer);
94           vec_validate (scm->socket_rx_buffer, current_rx_index
95                         + scm->socket_buffer_size - 1);
96           _vec_len (scm->socket_rx_buffer) = current_rx_index;
97           n = read (scm->socket_fd, scm->socket_rx_buffer + current_rx_index,
98                     scm->socket_buffer_size);
99           if (n < 0)
100             {
101               if (errno == EAGAIN)
102                 continue;
103
104               clib_unix_warning ("socket_read");
105               return -1;
106             }
107           _vec_len (scm->socket_rx_buffer) += n;
108         }
109
110 #if CLIB_DEBUG > 1
111       if (n > 0)
112         clib_warning ("read %d bytes", n);
113 #endif
114
115       mbp = (msgbuf_t *) (scm->socket_rx_buffer);
116       data_len = ntohl (mbp->data_len);
117       current_rx_index = vec_len (scm->socket_rx_buffer);
118       vec_validate (scm->socket_rx_buffer, current_rx_index + data_len);
119       _vec_len (scm->socket_rx_buffer) = current_rx_index;
120       mbp = (msgbuf_t *) (scm->socket_rx_buffer);
121       msg_size = data_len + sizeof (*mbp);
122
123       while (vec_len (scm->socket_rx_buffer) < msg_size)
124         {
125           n = read (scm->socket_fd,
126                     scm->socket_rx_buffer + vec_len (scm->socket_rx_buffer),
127                     msg_size - vec_len (scm->socket_rx_buffer));
128           if (n < 0)
129             {
130               if (errno == EAGAIN)
131                 continue;
132
133               clib_unix_warning ("socket_read");
134               return -1;
135             }
136           _vec_len (scm->socket_rx_buffer) += n;
137         }
138
139       if (vec_len (scm->socket_rx_buffer) >= data_len + sizeof (*mbp))
140         {
141           vl_msg_api_socket_handler ((void *) (mbp->data), data_len);
142
143           if (vec_len (scm->socket_rx_buffer) == data_len + sizeof (*mbp))
144             _vec_len (scm->socket_rx_buffer) = 0;
145           else
146             vec_delete (scm->socket_rx_buffer, data_len + sizeof (*mbp), 0);
147           mbp = 0;
148
149           /* Quit if we're out of data, and not expecting a ping reply */
150           if (vec_len (scm->socket_rx_buffer) == 0
151               && scm->control_pings_outstanding == 0)
152             break;
153         }
154       if (wait && clib_time_now (&scm->clib_time) >= timeout)
155         return -1;
156     }
157   return 0;
158 }
159
160 int
161 vl_socket_client_read (int wait)
162 {
163   return vl_socket_client_read_internal (socket_client_ctx, wait);
164 }
165
166 int
167 vl_socket_client_read2 (socket_client_main_t * scm, int wait)
168 {
169   socket_client_main_t *old_ctx;
170   int rv;
171
172   old_ctx = vl_socket_client_ctx_push (scm);
173   rv = vl_socket_client_read_internal (scm, wait);
174   vl_socket_client_ctx_pop (old_ctx);
175   return rv;
176 }
177
178 static int
179 vl_socket_client_write_internal (socket_client_main_t * scm)
180 {
181   int n;
182
183   msgbuf_t msgbuf = {
184     .q = 0,
185     .gc_mark_timestamp = 0,
186     .data_len = htonl (scm->socket_tx_nbytes),
187   };
188
189   n = write (scm->socket_fd, &msgbuf, sizeof (msgbuf));
190   if (n < sizeof (msgbuf))
191     {
192       clib_unix_warning ("socket write (msgbuf)");
193       return -1;
194     }
195
196   n = write (scm->socket_fd, scm->socket_tx_buffer, scm->socket_tx_nbytes);
197   if (n < scm->socket_tx_nbytes)
198     {
199       clib_unix_warning ("socket write (msg)");
200       return -1;
201     }
202
203   return n;
204 }
205
206 int
207 vl_socket_client_write (void)
208 {
209   return vl_socket_client_write_internal (socket_client_ctx);
210 }
211
212 int
213 vl_socket_client_write2 (socket_client_main_t * scm)
214 {
215   socket_client_main_t *old_ctx;
216   int rv;
217
218   old_ctx = vl_socket_client_ctx_push (scm);
219   rv = vl_socket_client_write_internal (scm);
220   vl_socket_client_ctx_pop (old_ctx);
221   return rv;
222 }
223
224 void *
225 vl_socket_client_msg_alloc2 (socket_client_main_t * scm, int nbytes)
226 {
227   scm->socket_tx_nbytes = nbytes;
228   return ((void *) scm->socket_tx_buffer);
229 }
230
231 void *
232 vl_socket_client_msg_alloc (int nbytes)
233 {
234   return vl_socket_client_msg_alloc2 (socket_client_ctx, nbytes);
235 }
236
237 void
238 vl_socket_client_disconnect2 (socket_client_main_t * scm)
239 {
240   if (vl_mem_client_is_connected ())
241     {
242       vl_client_disconnect_from_vlib_no_unmap ();
243       ssvm_delete_memfd (&scm->memfd_segment);
244     }
245   if (scm->socket_fd && (close (scm->socket_fd) < 0))
246     clib_unix_warning ("close");
247   scm->socket_fd = 0;
248 }
249
250 void
251 vl_socket_client_disconnect (void)
252 {
253   vl_socket_client_disconnect2 (socket_client_ctx);
254 }
255
256 void
257 vl_socket_client_enable_disable2 (socket_client_main_t * scm, int enable)
258 {
259   scm->socket_enable = enable;
260 }
261
262 void
263 vl_socket_client_enable_disable (int enable)
264 {
265   vl_socket_client_enable_disable2 (socket_client_ctx, enable);
266 }
267
268 static clib_error_t *
269 vl_sock_api_recv_fd_msg_internal (socket_client_main_t * scm, int fds[],
270                                   int n_fds, u32 wait)
271 {
272   char msgbuf[16];
273   char ctl[CMSG_SPACE (sizeof (int) * n_fds)
274            + CMSG_SPACE (sizeof (struct ucred))];
275   struct msghdr mh = { 0 };
276   struct iovec iov[1];
277   ssize_t size = 0;
278   struct ucred *cr = 0;
279   struct cmsghdr *cmsg;
280   pid_t pid __attribute__ ((unused));
281   uid_t uid __attribute__ ((unused));
282   gid_t gid __attribute__ ((unused));
283   int socket_fd;
284   f64 timeout;
285
286   socket_fd = scm->client_socket.fd;
287
288   iov[0].iov_base = msgbuf;
289   iov[0].iov_len = 5;
290   mh.msg_iov = iov;
291   mh.msg_iovlen = 1;
292   mh.msg_control = ctl;
293   mh.msg_controllen = sizeof (ctl);
294
295   clib_memset (ctl, 0, sizeof (ctl));
296
297   if (wait != ~0)
298     {
299       timeout = clib_time_now (&scm->clib_time) + wait;
300       while (size != 5 && clib_time_now (&scm->clib_time) < timeout)
301         size = recvmsg (socket_fd, &mh, MSG_DONTWAIT);
302     }
303   else
304     size = recvmsg (socket_fd, &mh, 0);
305
306   if (size != 5)
307     {
308       return (size == 0) ? clib_error_return (0, "disconnected") :
309         clib_error_return_unix (0, "recvmsg: malformed message (fd %d)",
310                                 socket_fd);
311     }
312
313   cmsg = CMSG_FIRSTHDR (&mh);
314   while (cmsg)
315     {
316       if (cmsg->cmsg_level == SOL_SOCKET)
317         {
318           if (cmsg->cmsg_type == SCM_CREDENTIALS)
319             {
320               cr = (struct ucred *) CMSG_DATA (cmsg);
321               uid = cr->uid;
322               gid = cr->gid;
323               pid = cr->pid;
324             }
325           else if (cmsg->cmsg_type == SCM_RIGHTS)
326             {
327               clib_memcpy_fast (fds, CMSG_DATA (cmsg), sizeof (int) * n_fds);
328             }
329         }
330       cmsg = CMSG_NXTHDR (&mh, cmsg);
331     }
332   return 0;
333 }
334
335 clib_error_t *
336 vl_sock_api_recv_fd_msg (int socket_fd, int fds[], int n_fds, u32 wait)
337 {
338   return vl_sock_api_recv_fd_msg_internal (socket_client_ctx, fds, n_fds,
339                                            wait);
340 }
341
342 clib_error_t *
343 vl_sock_api_recv_fd_msg2 (socket_client_main_t * scm, int socket_fd,
344                           int fds[], int n_fds, u32 wait)
345 {
346   socket_client_main_t *old_ctx;
347   clib_error_t *error;
348
349   old_ctx = vl_socket_client_ctx_push (scm);
350   error = vl_sock_api_recv_fd_msg_internal (scm, fds, n_fds, wait);
351   vl_socket_client_ctx_pop (old_ctx);
352   return error;
353 }
354
355 static void vl_api_sock_init_shm_reply_t_handler
356   (vl_api_sock_init_shm_reply_t * mp)
357 {
358   socket_client_main_t *scm = socket_client_ctx;
359   ssvm_private_t *memfd = &scm->memfd_segment;
360   i32 retval = ntohl (mp->retval);
361   api_main_t *am = vlibapi_get_main ();
362   clib_error_t *error;
363   int my_fd = -1;
364   u8 *new_name;
365
366   if (retval)
367     {
368       clib_warning ("failed to init shmem");
369       return;
370     }
371
372   /*
373    * Check the socket for the magic fd
374    */
375   error = vl_sock_api_recv_fd_msg (scm->socket_fd, &my_fd, 1, 5);
376   if (error)
377     {
378       clib_error_report (error);
379       retval = -99;
380       return;
381     }
382
383   clib_memset (memfd, 0, sizeof (*memfd));
384   memfd->fd = my_fd;
385
386   /* Note: this closes memfd.fd */
387   retval = ssvm_client_init_memfd (memfd);
388   if (retval)
389     clib_warning ("WARNING: segment map returned %d", retval);
390
391   /*
392    * Pivot to the memory client segment that vpp just created
393    */
394   am->vlib_rp = (void *) (memfd->requested_va + MMAP_PAGESIZE);
395   am->shmem_hdr = (void *) am->vlib_rp->user_ctx;
396
397   new_name = format (0, "%v[shm]%c", scm->name, 0);
398   vl_client_install_client_message_handlers ();
399   if (scm->want_shm_pthread)
400     {
401       vl_client_connect_to_vlib_no_map ("pvt", (char *) new_name,
402                                         32 /* input_queue_length */ );
403     }
404   else
405     {
406       vl_client_connect_to_vlib_no_rx_pthread_no_map ("pvt",
407                                                       (char *) new_name, 32
408                                                       /* input_queue_length */
409         );
410     }
411   vl_socket_client_enable_disable (0);
412   vec_free (new_name);
413 }
414
415 static void
416 vl_api_sockclnt_create_reply_t_handler (vl_api_sockclnt_create_reply_t * mp)
417 {
418   socket_client_main_t *scm = socket_client_ctx;
419   if (!mp->response)
420     {
421       scm->socket_enable = 1;
422       scm->client_index = clib_net_to_host_u32 (mp->index);
423     }
424 }
425
426 #define foreach_sock_client_api_msg                             \
427 _(SOCKCLNT_CREATE_REPLY, sockclnt_create_reply)                 \
428 _(SOCK_INIT_SHM_REPLY, sock_init_shm_reply)                     \
429
430 static void
431 noop_handler (void *notused)
432 {
433 }
434
435 void
436 vl_sock_client_install_message_handlers (void)
437 {
438
439 #define _(N, n)                                                               \
440   vl_msg_api_set_handlers (                                                   \
441     VL_API_##N, #n, vl_api_##n##_t_handler, noop_handler,                     \
442     vl_api_##n##_t_endian, vl_api_##n##_t_print, sizeof (vl_api_##n##_t), 0,  \
443     vl_api_##n##_t_print_json, vl_api_##n##_t_tojson,                         \
444     vl_api_##n##_t_fromjson, vl_api_##n##_t_calc_size);
445   foreach_sock_client_api_msg;
446 #undef _
447 }
448
449 int
450 vl_socket_client_connect_internal (socket_client_main_t * scm,
451                                    char *socket_path, char *client_name,
452                                    u32 socket_buffer_size)
453 {
454   vl_api_sockclnt_create_t *mp;
455   clib_socket_t *sock;
456   clib_error_t *error;
457
458   /* Already connected? */
459   if (scm->socket_fd)
460     return (-2);
461
462   /* bogus call? */
463   if (socket_path == 0 || client_name == 0)
464     return (-3);
465
466   sock = &scm->client_socket;
467   sock->config = socket_path;
468   sock->flags = CLIB_SOCKET_F_IS_CLIENT;
469
470   if ((error = clib_socket_init (sock)))
471     {
472       clib_error_report (error);
473       return (-1);
474     }
475
476   vl_sock_client_install_message_handlers ();
477
478   scm->socket_fd = sock->fd;
479   scm->socket_buffer_size = socket_buffer_size ? socket_buffer_size :
480     SOCKET_CLIENT_DEFAULT_BUFFER_SIZE;
481   vec_validate (scm->socket_tx_buffer, scm->socket_buffer_size - 1);
482   vec_validate (scm->socket_rx_buffer, scm->socket_buffer_size - 1);
483   _vec_len (scm->socket_rx_buffer) = 0;
484   _vec_len (scm->socket_tx_buffer) = 0;
485   scm->name = format (0, "%s", client_name);
486
487   mp = vl_socket_client_msg_alloc2 (scm, sizeof (*mp));
488   mp->_vl_msg_id = htons (VL_API_SOCKCLNT_CREATE);
489   strncpy ((char *) mp->name, client_name, sizeof (mp->name) - 1);
490   mp->name[sizeof (mp->name) - 1] = 0;
491   mp->context = 0xfeedface;
492
493   clib_time_init (&scm->clib_time);
494
495   if (vl_socket_client_write_internal (scm) <= 0)
496     return (-1);
497
498   if (vl_socket_client_read_internal (scm, 5))
499     return (-1);
500
501   return (0);
502 }
503
504 int
505 vl_socket_client_connect (char *socket_path, char *client_name,
506                           u32 socket_buffer_size)
507 {
508   return vl_socket_client_connect_internal (socket_client_ctx, socket_path,
509                                             client_name, socket_buffer_size);
510 }
511
512 int
513 vl_socket_client_connect2 (socket_client_main_t * scm, char *socket_path,
514                            char *client_name, u32 socket_buffer_size)
515 {
516   socket_client_main_t *old_ctx;
517   int rv;
518
519   old_ctx = vl_socket_client_ctx_push (scm);
520   rv = vl_socket_client_connect_internal (socket_client_ctx, socket_path,
521                                           client_name, socket_buffer_size);
522   vl_socket_client_ctx_pop (old_ctx);
523   return rv;
524 }
525
526 int
527 vl_socket_client_init_shm_internal (socket_client_main_t * scm,
528                                     vl_api_shm_elem_config_t * config,
529                                     int want_pthread)
530 {
531   vl_api_sock_init_shm_t *mp;
532   int rv, i;
533   u64 *cfg;
534
535   scm->want_shm_pthread = want_pthread;
536
537   mp = vl_socket_client_msg_alloc2 (scm, sizeof (*mp) +
538                                     vec_len (config) * sizeof (u64));
539   clib_memset (mp, 0, sizeof (*mp));
540   mp->_vl_msg_id = clib_host_to_net_u16 (VL_API_SOCK_INIT_SHM);
541   mp->client_index = clib_host_to_net_u32 (scm->client_index);
542   mp->requested_size = 64 << 20;
543
544   if (config)
545     {
546       for (i = 0; i < vec_len (config); i++)
547         {
548           cfg = (u64 *) & config[i];
549           mp->configs[i] = *cfg;
550         }
551       mp->nitems = vec_len (config);
552     }
553   rv = vl_socket_client_write_internal (scm);
554   if (rv <= 0)
555     return rv;
556
557   if (vl_socket_client_read_internal (scm, 1))
558     return -1;
559
560   return 0;
561 }
562
563 int
564 vl_socket_client_init_shm (vl_api_shm_elem_config_t * config,
565                            int want_pthread)
566 {
567   return vl_socket_client_init_shm_internal (socket_client_ctx, config,
568                                              want_pthread);
569 }
570
571 int
572 vl_socket_client_init_shm2 (socket_client_main_t * scm,
573                             vl_api_shm_elem_config_t * config,
574                             int want_pthread)
575 {
576   socket_client_main_t *old_ctx;
577   int rv;
578
579   old_ctx = vl_socket_client_ctx_push (scm);
580   rv = vl_socket_client_init_shm_internal (socket_client_ctx, config,
581                                            want_pthread);
582   vl_socket_client_ctx_pop (old_ctx);
583   return rv;
584 }
585
586 clib_error_t *
587 vl_socket_client_recv_fd_msg2 (socket_client_main_t * scm, int fds[],
588                                int n_fds, u32 wait)
589 {
590   if (!scm->socket_fd)
591     return clib_error_return (0, "no socket");
592   return vl_sock_api_recv_fd_msg_internal (scm, fds, n_fds, wait);
593 }
594
595 clib_error_t *
596 vl_socket_client_recv_fd_msg (int fds[], int n_fds, u32 wait)
597 {
598   return vl_socket_client_recv_fd_msg2 (socket_client_ctx, fds, n_fds, wait);
599 }
600
601 /*
602  * fd.io coding-style-patch-verification: ON
603  *
604  * Local Variables:
605  * eval: (c-set-style "gnu")
606  * End:
607  */