sock api: add infra for bootstrapping shm clients
[vpp.git] / src / vlibmemory / socket_client.c
1 /*
2  *------------------------------------------------------------------
3  * socket_client.c - API message handling over sockets, client code.
4  *
5  * Copyright (c) 2017 Cisco and/or its affiliates.
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at:
9  *
10  *     http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  *------------------------------------------------------------------
18  */
19
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <setjmp.h>
23 #include <sys/types.h>
24 #define __USE_GNU
25 #include <sys/socket.h>
26 #include <sys/mman.h>
27 #include <sys/stat.h>
28 #include <netinet/in.h>
29 #include <signal.h>
30 #include <pthread.h>
31 #include <unistd.h>
32 #include <time.h>
33 #include <fcntl.h>
34 #include <string.h>
35 #include <vppinfra/clib.h>
36 #include <vppinfra/vec.h>
37 #include <vppinfra/hash.h>
38 #include <vppinfra/bitmap.h>
39 #include <vppinfra/fifo.h>
40 #include <vppinfra/time.h>
41 #include <vppinfra/mheap.h>
42 #include <vppinfra/heap.h>
43 #include <vppinfra/pool.h>
44 #include <vppinfra/format.h>
45
46 #include <vlib/vlib.h>
47 #include <vlib/unix/unix.h>
48 #include <svm/memfd.h>
49 #include <vlibmemory/api.h>
50
51 #include <vlibmemory/vl_memory_msg_enum.h>
52
53 #define vl_typedefs             /* define message structures */
54 #include <vlibmemory/vl_memory_api_h.h>
55 #undef vl_typedefs
56
57 #define vl_endianfun            /* define message structures */
58 #include <vlibmemory/vl_memory_api_h.h>
59 #undef vl_endianfun
60
61 /* instantiate all the print functions we know about */
62 #define vl_print(handle, ...) clib_warning (__VA_ARGS__)
63 #define vl_printfun
64 #include <vlibmemory/vl_memory_api_h.h>
65 #undef vl_printfun
66
67 socket_client_main_t socket_client_main;
68
69 /* Debug aid */
70 u32 vl (void *p) __attribute__ ((weak));
71
72 u32
73 vl (void *p)
74 {
75   return vec_len (p);
76 }
77
78 int
79 vl_socket_client_read (int wait)
80 {
81   socket_client_main_t *scm = &socket_client_main;
82   int n, current_rx_index;
83   msgbuf_t *mbp = 0;
84   f64 timeout;
85
86   if (scm->socket_fd == 0)
87     return -1;
88
89   if (wait)
90     timeout = clib_time_now (&scm->clib_time) + wait;
91
92   while (1)
93     {
94       current_rx_index = vec_len (scm->socket_rx_buffer);
95       while (vec_len (scm->socket_rx_buffer) <
96              sizeof (*mbp) + 2 /* msg id */ )
97         {
98           vec_validate (scm->socket_rx_buffer, current_rx_index
99                         + scm->socket_buffer_size - 1);
100           _vec_len (scm->socket_rx_buffer) = current_rx_index;
101           n = read (scm->socket_fd, scm->socket_rx_buffer + current_rx_index,
102                     scm->socket_buffer_size);
103           if (n < 0)
104             {
105               clib_unix_warning ("socket_read");
106               return -1;
107             }
108           _vec_len (scm->socket_rx_buffer) += n;
109         }
110
111 #if CLIB_DEBUG > 1
112       if (n > 0)
113         clib_warning ("read %d bytes", n);
114 #endif
115
116       if (mbp == 0)
117         mbp = (msgbuf_t *) (scm->socket_rx_buffer);
118
119       if (vec_len (scm->socket_rx_buffer) >= ntohl (mbp->data_len)
120           + sizeof (*mbp))
121         {
122           vl_msg_api_socket_handler ((void *) (mbp->data));
123
124           if (vec_len (scm->socket_rx_buffer) == ntohl (mbp->data_len)
125               + sizeof (*mbp))
126             _vec_len (scm->socket_rx_buffer) = 0;
127           else
128             vec_delete (scm->socket_rx_buffer, ntohl (mbp->data_len)
129                         + sizeof (*mbp), 0);
130           mbp = 0;
131
132           /* Quit if we're out of data, and not expecting a ping reply */
133           if (vec_len (scm->socket_rx_buffer) == 0
134               && scm->control_pings_outstanding == 0)
135             break;
136         }
137
138       if (wait && clib_time_now (&scm->clib_time) >= timeout)
139         return -1;
140     }
141   return 0;
142 }
143
144 int
145 vl_socket_client_write (void)
146 {
147   socket_client_main_t *scm = &socket_client_main;
148   int n;
149
150   msgbuf_t msgbuf = {
151     .q = 0,
152     .gc_mark_timestamp = 0,
153     .data_len = htonl (scm->socket_tx_nbytes),
154   };
155
156   n = write (scm->socket_fd, &msgbuf, sizeof (msgbuf));
157   if (n < sizeof (msgbuf))
158     {
159       clib_unix_warning ("socket write (msgbuf)");
160       return -1;
161     }
162
163   n = write (scm->socket_fd, scm->socket_tx_buffer, scm->socket_tx_nbytes);
164   if (n < scm->socket_tx_nbytes)
165     {
166       clib_unix_warning ("socket write (msg)");
167       return -1;
168     }
169
170   return n;
171 }
172
173 void *
174 vl_socket_client_msg_alloc (int nbytes)
175 {
176   socket_client_main.socket_tx_nbytes = nbytes;
177   return ((void *) socket_client_main.socket_tx_buffer);
178 }
179
180 void
181 vl_socket_client_disconnect (void)
182 {
183   socket_client_main_t *scm = &socket_client_main;
184   if (scm->socket_fd && (close (scm->socket_fd) < 0))
185     clib_unix_warning ("close");
186   scm->socket_fd = 0;
187 }
188
189 void
190 vl_socket_client_enable_disable (int enable)
191 {
192   socket_client_main_t *scm = &socket_client_main;
193   scm->socket_enable = enable;
194 }
195
196 static clib_error_t *
197 receive_fd_msg (int socket_fd, int *my_fd)
198 {
199   char msgbuf[16];
200   char ctl[CMSG_SPACE (sizeof (int)) + CMSG_SPACE (sizeof (struct ucred))];
201   struct msghdr mh = { 0 };
202   struct iovec iov[1];
203   ssize_t size;
204   struct ucred *cr = 0;
205   struct cmsghdr *cmsg;
206   pid_t pid __attribute__ ((unused));
207   uid_t uid __attribute__ ((unused));
208   gid_t gid __attribute__ ((unused));
209
210   iov[0].iov_base = msgbuf;
211   iov[0].iov_len = 5;
212   mh.msg_iov = iov;
213   mh.msg_iovlen = 1;
214   mh.msg_control = ctl;
215   mh.msg_controllen = sizeof (ctl);
216
217   memset (ctl, 0, sizeof (ctl));
218
219   /* receive the incoming message */
220   size = recvmsg (socket_fd, &mh, 0);
221   if (size != 5)
222     {
223       return (size == 0) ? clib_error_return (0, "disconnected") :
224         clib_error_return_unix (0, "recvmsg: malformed message (fd %d)",
225                                 socket_fd);
226     }
227
228   cmsg = CMSG_FIRSTHDR (&mh);
229   while (cmsg)
230     {
231       if (cmsg->cmsg_level == SOL_SOCKET)
232         {
233           if (cmsg->cmsg_type == SCM_CREDENTIALS)
234             {
235               cr = (struct ucred *) CMSG_DATA (cmsg);
236               uid = cr->uid;
237               gid = cr->gid;
238               pid = cr->pid;
239             }
240           else if (cmsg->cmsg_type == SCM_RIGHTS)
241             {
242               clib_memcpy (my_fd, CMSG_DATA (cmsg), sizeof (int));
243             }
244         }
245       cmsg = CMSG_NXTHDR (&mh, cmsg);
246     }
247   return 0;
248 }
249
250 static void vl_api_sock_init_shm_reply_t_handler
251   (vl_api_sock_init_shm_reply_t * mp)
252 {
253   socket_client_main_t *scm = &socket_client_main;
254   int my_fd = -1;
255   clib_error_t *error;
256   i32 retval = ntohl (mp->retval);
257   memfd_private_t memfd;
258   api_main_t *am = &api_main;
259   u8 *new_name;
260
261   if (retval)
262     {
263       clib_warning ("failed to init shmem");
264       return;
265     }
266
267   /*
268    * Check the socket for the magic fd
269    */
270   error = receive_fd_msg (scm->socket_fd, &my_fd);
271   if (error)
272     {
273       retval = -99;
274       return;
275     }
276
277   memset (&memfd, 0, sizeof (memfd));
278   memfd.fd = my_fd;
279
280   /* Note: this closes memfd.fd */
281   retval = memfd_slave_init (&memfd);
282   if (retval)
283     clib_warning ("WARNING: segment map returned %d", retval);
284
285   /*
286    * Pivot to the memory client segment that vpp just created
287    */
288   am->vlib_rp = (void *) (memfd.requested_va + MMAP_PAGESIZE);
289   am->shmem_hdr = (void *) am->vlib_rp->user_ctx;
290
291   new_name = format (0, "%v[shm]%c", scm->name, 0);
292   vl_client_install_client_message_handlers ();
293   vl_client_connect_to_vlib_no_map ("pvt", (char *) new_name,
294                                     32 /* input_queue_length */ );
295   vl_socket_client_enable_disable (0);
296   vec_free (new_name);
297 }
298
299 static void
300 vl_api_sockclnt_create_reply_t_handler (vl_api_sockclnt_create_reply_t * mp)
301 {
302   socket_client_main_t *scm = &socket_client_main;
303   if (!mp->response)
304     scm->socket_enable = 1;
305 }
306
307 #define foreach_sock_client_api_msg                             \
308 _(SOCKCLNT_CREATE_REPLY, sockclnt_create_reply)                 \
309 _(SOCK_INIT_SHM_REPLY, sock_init_shm_reply)                     \
310
311 static void
312 noop_handler (void *notused)
313 {
314 }
315
316 void
317 vl_sock_client_install_message_handlers (void)
318 {
319
320 #define _(N,n)                                                  \
321     vl_msg_api_set_handlers(VL_API_##N, #n,                     \
322                             vl_api_##n##_t_handler,             \
323                             noop_handler,                       \
324                             vl_api_##n##_t_endian,              \
325                             vl_api_##n##_t_print,               \
326                             sizeof(vl_api_##n##_t), 1);
327   foreach_sock_client_api_msg;
328 #undef _
329 }
330
331 int
332 vl_socket_client_connect (char *socket_path, char *client_name,
333                           u32 socket_buffer_size)
334 {
335   socket_client_main_t *scm = &socket_client_main;
336   vl_api_sockclnt_create_t *mp;
337   clib_socket_t *sock;
338   clib_error_t *error;
339
340   /* Already connected? */
341   if (scm->socket_fd)
342     return (-2);
343
344   /* bogus call? */
345   if (socket_path == 0 || client_name == 0)
346     return (-3);
347
348   sock = &scm->client_socket;
349   sock->config = socket_path;
350   sock->flags = CLIB_SOCKET_F_IS_CLIENT | CLIB_SOCKET_F_SEQPACKET;
351
352   if ((error = clib_socket_init (sock)))
353     {
354       clib_error_report (error);
355       return (-1);
356     }
357
358   vl_sock_client_install_message_handlers ();
359
360   scm->socket_fd = sock->fd;
361   scm->socket_buffer_size = socket_buffer_size ? socket_buffer_size :
362     SOCKET_CLIENT_DEFAULT_BUFFER_SIZE;
363   vec_validate (scm->socket_tx_buffer, scm->socket_buffer_size - 1);
364   vec_validate (scm->socket_rx_buffer, scm->socket_buffer_size - 1);
365   _vec_len (scm->socket_rx_buffer) = 0;
366   _vec_len (scm->socket_tx_buffer) = 0;
367   scm->name = format (0, "%s", client_name);
368
369   mp = vl_socket_client_msg_alloc (sizeof (*mp));
370   mp->_vl_msg_id = htons (VL_API_SOCKCLNT_CREATE);
371   strncpy ((char *) mp->name, client_name, sizeof (mp->name) - 1);
372   mp->name[sizeof (mp->name) - 1] = 0;
373   mp->context = 0xfeedface;
374
375   if (vl_socket_client_write () <= 0)
376     return (-1);
377
378   if (vl_socket_client_read (1))
379     return (-1);
380
381   clib_time_init (&scm->clib_time);
382   return (0);
383 }
384
385 int
386 vl_socket_client_init_shm (vl_api_shm_elem_config_t * config)
387 {
388   vl_api_sock_init_shm_t *mp;
389   int rv, i;
390   u64 *cfg;
391
392   mp = vl_socket_client_msg_alloc (sizeof (*mp) +
393                                    vec_len (config) * sizeof (u64));
394   memset (mp, 0, sizeof (*mp));
395   mp->_vl_msg_id = clib_host_to_net_u16 (VL_API_SOCK_INIT_SHM);
396   mp->client_index = ~0;
397   mp->requested_size = 64 << 20;
398
399   if (config)
400     {
401       for (i = 0; i < vec_len (config); i++)
402         {
403           cfg = (u64 *) & config[i];
404           mp->configs[i] = *cfg;
405         }
406       mp->nitems = vec_len (config);
407     }
408   rv = vl_socket_client_write ();
409   if (rv <= 0)
410     return rv;
411
412   if (vl_socket_client_read (1))
413     return -1;
414
415   return 0;
416 }
417
418 /*
419  * fd.io coding-style-patch-verification: ON
420  *
421  * Local Variables:
422  * eval: (c-set-style "gnu")
423  * End:
424  */