memif: fix crash caused by zero pkt len in memif and clear dirty cache while interfac...
[vpp.git] / src / plugins / memif / socket.c
1 /*
2  *------------------------------------------------------------------
3  * Copyright (c) 2016 Cisco and/or its affiliates.
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at:
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *------------------------------------------------------------------
16  */
17
18 #define _GNU_SOURCE
19 #include <stdint.h>
20 #include <net/if.h>
21 #include <sys/types.h>
22 #include <fcntl.h>
23 #include <sys/ioctl.h>
24 #include <sys/socket.h>
25 #include <sys/un.h>
26 #include <sys/uio.h>
27 #include <sys/mman.h>
28 #include <sys/prctl.h>
29 #include <sys/eventfd.h>
30 #include <inttypes.h>
31 #include <limits.h>
32
33 #include <vlib/vlib.h>
34 #include <vlib/unix/unix.h>
35 #include <vnet/plugin/plugin.h>
36 #include <vnet/ethernet/ethernet.h>
37 #include <vpp/app/version.h>
38
39 #include <memif/memif.h>
40 #include <memif/private.h>
41
42 void
43 memif_socket_close (clib_socket_t ** s)
44 {
45   memif_file_del_by_index ((*s)->private_data);
46   clib_mem_free (*s);
47   *s = 0;
48 }
49
50 static u8 *
51 memif_str2vec (uint8_t * str, int len)
52 {
53   u8 *s = 0;
54   int i;
55
56   if (str[0] == 0)
57     return s;
58
59   for (i = 0; i < len; i++)
60     {
61       vec_add1 (s, str[i]);
62       if (str[i] == 0)
63         return s;
64     }
65   vec_add1 (s, 0);
66
67   return s;
68 }
69
70 static void
71 memif_msg_enq_ack (memif_if_t * mif)
72 {
73   memif_msg_fifo_elt_t *e;
74   clib_fifo_add2 (mif->msg_queue, e);
75
76   e->msg.type = MEMIF_MSG_TYPE_ACK;
77   e->fd = -1;
78 }
79
80 static clib_error_t *
81 memif_msg_enq_hello (clib_socket_t * sock)
82 {
83   u8 *s;
84   memif_msg_t msg = { 0 };
85   memif_msg_hello_t *h = &msg.hello;
86   msg.type = MEMIF_MSG_TYPE_HELLO;
87   h->min_version = MEMIF_VERSION;
88   h->max_version = MEMIF_VERSION;
89   h->max_m2s_ring = MEMIF_MAX_M2S_RING;
90   h->max_s2m_ring = MEMIF_MAX_M2S_RING;
91   h->max_region = MEMIF_MAX_REGION;
92   h->max_log2_ring_size = MEMIF_MAX_LOG2_RING_SIZE;
93   s = format (0, "VPP %s%c", VPP_BUILD_VER, 0);
94   strncpy ((char *) h->name, (char *) s, sizeof (h->name) - 1);
95   vec_free (s);
96   return clib_socket_sendmsg (sock, &msg, sizeof (memif_msg_t), 0, 0);
97 }
98
99 static void
100 memif_msg_enq_init (memif_if_t * mif)
101 {
102   u8 *s;
103   memif_msg_fifo_elt_t *e;
104   clib_fifo_add2 (mif->msg_queue, e);
105   memif_msg_init_t *i = &e->msg.init;
106
107   e->msg.type = MEMIF_MSG_TYPE_INIT;
108   e->fd = -1;
109   i->version = MEMIF_VERSION;
110   i->id = mif->id;
111   i->mode = mif->mode;
112   s = format (0, "VPP %s%c", VPP_BUILD_VER, 0);
113   strncpy ((char *) i->name, (char *) s, sizeof (i->name) - 1);
114   if (mif->secret)
115     strncpy ((char *) i->secret, (char *) mif->secret,
116              sizeof (i->secret) - 1);
117   vec_free (s);
118 }
119
120 static void
121 memif_msg_enq_add_region (memif_if_t * mif, u8 region)
122 {
123   memif_msg_fifo_elt_t *e;
124   clib_fifo_add2 (mif->msg_queue, e);
125   memif_msg_add_region_t *ar = &e->msg.add_region;
126
127   e->msg.type = MEMIF_MSG_TYPE_ADD_REGION;
128   e->fd = mif->regions[region].fd;
129   ar->index = region;
130   ar->size = mif->regions[region].region_size;
131 }
132
133 static void
134 memif_msg_enq_add_ring (memif_if_t * mif, u8 index, u8 direction)
135 {
136   memif_msg_fifo_elt_t *e;
137   clib_fifo_add2 (mif->msg_queue, e);
138   memif_msg_add_ring_t *ar = &e->msg.add_ring;
139   memif_queue_t *mq;
140
141   ASSERT ((mif->flags & MEMIF_IF_FLAG_IS_SLAVE) != 0);
142
143   e->msg.type = MEMIF_MSG_TYPE_ADD_RING;
144
145   if (direction == MEMIF_RING_M2S)
146     mq = vec_elt_at_index (mif->rx_queues, index);
147   else
148     mq = vec_elt_at_index (mif->tx_queues, index);
149
150   e->fd = mq->int_fd;
151   ar->index = index;
152   ar->region = mq->region;
153   ar->offset = mq->offset;
154   ar->log2_ring_size = mq->log2_ring_size;
155   ar->flags = (direction == MEMIF_RING_S2M) ? MEMIF_MSG_ADD_RING_FLAG_S2M : 0;
156 }
157
158 static void
159 memif_msg_enq_connect (memif_if_t * mif)
160 {
161   memif_msg_fifo_elt_t *e;
162   clib_fifo_add2 (mif->msg_queue, e);
163   memif_msg_connect_t *c = &e->msg.connect;
164   u8 *s;
165
166   e->msg.type = MEMIF_MSG_TYPE_CONNECT;
167   e->fd = -1;
168   s = format (0, "%U%c", format_memif_device_name, mif->dev_instance, 0);
169   strncpy ((char *) c->if_name, (char *) s, sizeof (c->if_name) - 1);
170   vec_free (s);
171 }
172
173 static void
174 memif_msg_enq_connected (memif_if_t * mif)
175 {
176   memif_msg_fifo_elt_t *e;
177   clib_fifo_add2 (mif->msg_queue, e);
178   memif_msg_connected_t *c = &e->msg.connected;
179   u8 *s;
180
181   e->msg.type = MEMIF_MSG_TYPE_CONNECTED;
182   e->fd = -1;
183   s = format (0, "%U%c", format_memif_device_name, mif->dev_instance, 0);
184   strncpy ((char *) c->if_name, (char *) s, sizeof (c->if_name) - 1);
185   vec_free (s);
186 }
187
188 clib_error_t *
189 memif_msg_send_disconnect (memif_if_t * mif, clib_error_t * err)
190 {
191   memif_msg_t msg = { 0 };
192   msg.type = MEMIF_MSG_TYPE_DISCONNECT;
193   memif_msg_disconnect_t *d = &msg.disconnect;
194
195   d->code = err->code;
196   strncpy ((char *) d->string, (char *) err->what, sizeof (d->string) - 1);
197
198   return clib_socket_sendmsg (mif->sock, &msg, sizeof (memif_msg_t), 0, 0);
199 }
200
201 static clib_error_t *
202 memif_msg_receive_hello (memif_if_t * mif, memif_msg_t * msg)
203 {
204   memif_msg_hello_t *h = &msg->hello;
205
206   if (msg->hello.min_version > MEMIF_VERSION ||
207       msg->hello.max_version < MEMIF_VERSION)
208     return clib_error_return (0, "incompatible protocol version");
209
210   mif->run.num_s2m_rings = clib_min (h->max_s2m_ring + 1,
211                                      mif->cfg.num_s2m_rings);
212   mif->run.num_m2s_rings = clib_min (h->max_m2s_ring + 1,
213                                      mif->cfg.num_m2s_rings);
214   mif->run.log2_ring_size = clib_min (h->max_log2_ring_size,
215                                       mif->cfg.log2_ring_size);
216   mif->run.buffer_size = mif->cfg.buffer_size;
217
218   mif->remote_name = memif_str2vec (h->name, sizeof (h->name));
219
220   return 0;
221 }
222
223 static clib_error_t *
224 memif_msg_receive_init (memif_if_t ** mifp, memif_msg_t * msg,
225                         clib_socket_t * sock, uword socket_file_index)
226 {
227   memif_main_t *mm = &memif_main;
228   memif_socket_file_t *msf =
229     vec_elt_at_index (mm->socket_files, socket_file_index);
230   memif_msg_init_t *i = &msg->init;
231   memif_if_t *mif, tmp;
232   clib_error_t *err;
233   uword *p;
234
235   if (i->version != MEMIF_VERSION)
236     {
237       memif_file_del_by_index (sock->private_data);
238       return clib_error_return (0, "unsupported version");
239     }
240
241   p = mhash_get (&msf->dev_instance_by_id, &i->id);
242
243   if (!p)
244     {
245       err = clib_error_return (0, "unmatched interface id");
246       goto error;
247     }
248
249   mif = vec_elt_at_index (mm->interfaces, p[0]);
250
251   if (mif->flags & MEMIF_IF_FLAG_IS_SLAVE)
252     {
253       err = clib_error_return (0, "cannot connect to slave");
254       goto error;
255     }
256
257   if (mif->sock)
258     {
259       err = clib_error_return (0, "already connected");
260       goto error;
261     }
262
263   if (i->mode != mif->mode)
264     {
265       err = clib_error_return (0, "mode mismatch");
266       goto error;
267     }
268
269   mif->sock = sock;
270   hash_set (msf->dev_instance_by_fd, mif->sock->fd, mif->dev_instance);
271   mif->remote_name = memif_str2vec (i->name, sizeof (i->name));
272   *mifp = mif;
273
274   if (mif->secret)
275     {
276       u8 *s;
277       int r;
278       s = memif_str2vec (i->secret, sizeof (i->secret));
279       if (s == 0)
280         return clib_error_return (0, "secret required");
281
282       r = vec_cmp (s, mif->secret);
283       vec_free (s);
284
285       if (r)
286         return clib_error_return (0, "incorrect secret");
287     }
288
289   return 0;
290
291 error:
292   tmp.sock = sock;
293   memif_msg_send_disconnect (&tmp, err);
294   memif_socket_close (&sock);
295   return err;
296 }
297
298 static clib_error_t *
299 memif_msg_receive_add_region (memif_if_t * mif, memif_msg_t * msg, int fd)
300 {
301   memif_msg_add_region_t *ar = &msg->add_region;
302   memif_region_t *mr;
303   if (fd < 0)
304     return clib_error_return (0, "missing memory region fd");
305
306   if (ar->index != vec_len (mif->regions))
307     return clib_error_return (0, "unexpected region index");
308
309   if (ar->index > MEMIF_MAX_REGION)
310     return clib_error_return (0, "too many regions");
311
312   vec_validate_aligned (mif->regions, ar->index, CLIB_CACHE_LINE_BYTES);
313   mr = vec_elt_at_index (mif->regions, ar->index);
314   mr->fd = fd;
315   mr->region_size = ar->size;
316
317   return 0;
318 }
319
320 static clib_error_t *
321 memif_msg_receive_add_ring (memif_if_t * mif, memif_msg_t * msg, int fd)
322 {
323   memif_msg_add_ring_t *ar = &msg->add_ring;
324   memif_queue_t *mq;
325
326   if (fd < 0)
327     return clib_error_return (0, "missing ring interrupt fd");
328
329   if (ar->flags & MEMIF_MSG_ADD_RING_FLAG_S2M)
330     {
331       if (ar->index != vec_len (mif->rx_queues))
332         return clib_error_return (0, "unexpected ring index");
333
334       if (ar->index > MEMIF_MAX_S2M_RING)
335         return clib_error_return (0, "too many rings");
336
337       vec_validate_aligned (mif->rx_queues, ar->index, CLIB_CACHE_LINE_BYTES);
338       mq = vec_elt_at_index (mif->rx_queues, ar->index);
339       mif->run.num_s2m_rings = vec_len (mif->rx_queues);
340     }
341   else
342     {
343       if (ar->index != vec_len (mif->tx_queues))
344         return clib_error_return (0, "unexpected ring index");
345
346       if (ar->index > MEMIF_MAX_M2S_RING)
347         return clib_error_return (0, "too many rings");
348
349       vec_validate_aligned (mif->tx_queues, ar->index, CLIB_CACHE_LINE_BYTES);
350       mq = vec_elt_at_index (mif->tx_queues, ar->index);
351       mif->run.num_m2s_rings = vec_len (mif->tx_queues);
352     }
353
354   // clear previous cache data if interface reconncected
355   memset (mq, 0, sizeof (memif_queue_t));
356   mq->int_fd = fd;
357   mq->int_clib_file_index = ~0;
358   mq->log2_ring_size = ar->log2_ring_size;
359   mq->region = ar->region;
360   mq->offset = ar->offset;
361   mq->type =
362     (ar->flags & MEMIF_MSG_ADD_RING_FLAG_S2M) ? MEMIF_RING_S2M :
363     MEMIF_RING_M2S;
364
365   return 0;
366 }
367
368 static clib_error_t *
369 memif_msg_receive_connect (memif_if_t * mif, memif_msg_t * msg)
370 {
371   clib_error_t *err;
372   memif_msg_connect_t *c = &msg->connect;
373
374   if ((err = memif_connect (mif)))
375     return err;
376
377   mif->remote_if_name = memif_str2vec (c->if_name, sizeof (c->if_name));
378
379   return 0;
380 }
381
382 static clib_error_t *
383 memif_msg_receive_connected (memif_if_t * mif, memif_msg_t * msg)
384 {
385   clib_error_t *err;
386   memif_msg_connected_t *c = &msg->connected;
387
388   if ((err = memif_connect (mif)))
389     return err;
390
391   mif->remote_if_name = memif_str2vec (c->if_name, sizeof (c->if_name));
392   return 0;
393 }
394
395 static clib_error_t *
396 memif_msg_receive_disconnect (memif_if_t * mif, memif_msg_t * msg)
397 {
398   memif_msg_disconnect_t *d = &msg->disconnect;
399
400   mif->remote_disc_string = memif_str2vec (d->string, sizeof (d->string));
401   return clib_error_return (0, "disconnect received");
402 }
403
404 static clib_error_t *
405 memif_msg_receive (memif_if_t ** mifp, clib_socket_t * sock, clib_file_t * uf)
406 {
407   memif_msg_t msg = { 0 };
408   clib_error_t *err = 0;
409   int fd = -1;
410   int i;
411   memif_if_t *mif = *mifp;
412
413   err = clib_socket_recvmsg (sock, &msg, sizeof (memif_msg_t), &fd, 1);
414   if (err)
415     return err;
416
417   if (mif == 0 && msg.type != MEMIF_MSG_TYPE_INIT)
418     {
419       memif_socket_close (&sock);
420       return clib_error_return (0, "unexpected message received");
421     }
422
423   DBG ("Message type %u received", msg.type);
424   /* process the message based on its type */
425   switch (msg.type)
426     {
427     case MEMIF_MSG_TYPE_ACK:
428       break;
429
430     case MEMIF_MSG_TYPE_HELLO:
431       if ((err = memif_msg_receive_hello (mif, &msg)))
432         return err;
433       if ((err = memif_init_regions_and_queues (mif)))
434         return err;
435       memif_msg_enq_init (mif);
436       memif_msg_enq_add_region (mif, 0);
437       vec_foreach_index (i, mif->tx_queues)
438         memif_msg_enq_add_ring (mif, i, MEMIF_RING_S2M);
439       vec_foreach_index (i, mif->rx_queues)
440         memif_msg_enq_add_ring (mif, i, MEMIF_RING_M2S);
441       memif_msg_enq_connect (mif);
442       break;
443
444     case MEMIF_MSG_TYPE_INIT:
445       if ((err = memif_msg_receive_init (mifp, &msg, sock, uf->private_data)))
446         return err;
447       mif = *mifp;
448       vec_reset_length (uf->description);
449       uf->description = format (uf->description, "%U ctl",
450                                 format_memif_device_name, mif->dev_instance);
451       memif_msg_enq_ack (mif);
452       break;
453
454     case MEMIF_MSG_TYPE_ADD_REGION:
455       if ((err = memif_msg_receive_add_region (mif, &msg, fd)))
456         return err;
457       memif_msg_enq_ack (mif);
458       break;
459
460     case MEMIF_MSG_TYPE_ADD_RING:
461       if ((err = memif_msg_receive_add_ring (mif, &msg, fd)))
462         return err;
463       memif_msg_enq_ack (mif);
464       break;
465
466     case MEMIF_MSG_TYPE_CONNECT:
467       if ((err = memif_msg_receive_connect (mif, &msg)))
468         return err;
469       memif_msg_enq_connected (mif);
470       break;
471
472     case MEMIF_MSG_TYPE_CONNECTED:
473       if ((err = memif_msg_receive_connected (mif, &msg)))
474         return err;
475       break;
476
477     case MEMIF_MSG_TYPE_DISCONNECT:
478       if ((err = memif_msg_receive_disconnect (mif, &msg)))
479         return err;
480       break;
481
482     default:
483       err = clib_error_return (0, "unknown message type (0x%x)", msg.type);
484       return err;
485     }
486
487   if (clib_fifo_elts (mif->msg_queue))
488     clib_file_set_data_available_to_write (&file_main,
489                                            mif->sock->private_data, 1);
490   return 0;
491 }
492
493 clib_error_t *
494 memif_master_conn_fd_read_ready (clib_file_t * uf)
495 {
496   memif_main_t *mm = &memif_main;
497   memif_socket_file_t *msf =
498     pool_elt_at_index (mm->socket_files, uf->private_data);
499   uword *p;
500   memif_if_t *mif = 0;
501   clib_socket_t *sock = 0;
502   clib_error_t *err = 0;
503
504   p = hash_get (msf->dev_instance_by_fd, uf->file_descriptor);
505   if (p)
506     {
507       mif = vec_elt_at_index (mm->interfaces, p[0]);
508       sock = mif->sock;
509     }
510   else
511     {
512       /* This is new connection, remove index from pending vector */
513       int i;
514       vec_foreach_index (i, msf->pending_clients)
515         if (msf->pending_clients[i]->fd == uf->file_descriptor)
516         {
517           sock = msf->pending_clients[i];
518           vec_del1 (msf->pending_clients, i);
519           break;
520         }
521       ASSERT (sock != 0);
522     }
523   err = memif_msg_receive (&mif, sock, uf);
524   if (err)
525     {
526       memif_disconnect (mif, err);
527       clib_error_free (err);
528     }
529   return 0;
530 }
531
532 clib_error_t *
533 memif_slave_conn_fd_read_ready (clib_file_t * uf)
534 {
535   memif_main_t *mm = &memif_main;
536   clib_error_t *err;
537   memif_if_t *mif = vec_elt_at_index (mm->interfaces, uf->private_data);
538   err = memif_msg_receive (&mif, mif->sock, uf);
539   if (err)
540     {
541       memif_disconnect (mif, err);
542       clib_error_free (err);
543     }
544   return 0;
545 }
546
547 static clib_error_t *
548 memif_conn_fd_write_ready (clib_file_t * uf, memif_if_t * mif)
549 {
550   memif_msg_fifo_elt_t *e;
551   clib_fifo_sub2 (mif->msg_queue, e);
552   clib_file_set_data_available_to_write (&file_main,
553                                          mif->sock->private_data, 0);
554   return clib_socket_sendmsg (mif->sock, &e->msg, sizeof (memif_msg_t),
555                               &e->fd, e->fd > -1 ? 1 : 0);
556 }
557
558 clib_error_t *
559 memif_master_conn_fd_write_ready (clib_file_t * uf)
560 {
561   memif_main_t *mm = &memif_main;
562   memif_socket_file_t *msf =
563     pool_elt_at_index (mm->socket_files, uf->private_data);
564   uword *p;
565   memif_if_t *mif;
566
567   p = hash_get (msf->dev_instance_by_fd, uf->file_descriptor);
568   if (!p)
569     return 0;
570
571   mif = vec_elt_at_index (mm->interfaces, p[0]);
572   return memif_conn_fd_write_ready (uf, mif);
573 }
574
575 clib_error_t *
576 memif_slave_conn_fd_write_ready (clib_file_t * uf)
577 {
578   memif_main_t *mm = &memif_main;
579   memif_if_t *mif = vec_elt_at_index (mm->interfaces, uf->private_data);
580   return memif_conn_fd_write_ready (uf, mif);
581 }
582
583 clib_error_t *
584 memif_slave_conn_fd_error (clib_file_t * uf)
585 {
586   memif_main_t *mm = &memif_main;
587   memif_if_t *mif = vec_elt_at_index (mm->interfaces, uf->private_data);
588   clib_error_t *err;
589
590   err = clib_error_return (0, "connection fd error");
591   memif_disconnect (mif, err);
592   clib_error_free (err);
593
594   return 0;
595 }
596
597 clib_error_t *
598 memif_master_conn_fd_error (clib_file_t * uf)
599 {
600   memif_main_t *mm = &memif_main;
601   memif_socket_file_t *msf =
602     pool_elt_at_index (mm->socket_files, uf->private_data);
603   uword *p;
604
605
606   p = hash_get (msf->dev_instance_by_fd, uf->file_descriptor);
607   if (p)
608     {
609       memif_if_t *mif;
610       clib_error_t *err;
611       mif = vec_elt_at_index (mm->interfaces, p[0]);
612       err = clib_error_return (0, "connection fd error");
613       memif_disconnect (mif, err);
614       clib_error_free (err);
615     }
616   else
617     {
618       int i;
619       vec_foreach_index (i, msf->pending_clients)
620         if (msf->pending_clients[i]->fd == uf->file_descriptor)
621         {
622           clib_socket_t *s = msf->pending_clients[i];
623           memif_socket_close (&s);
624           vec_del1 (msf->pending_clients, i);
625           return 0;
626         }
627     }
628
629   clib_warning ("Error on unknown file descriptor %d", uf->file_descriptor);
630   memif_file_del (uf);
631   return 0;
632 }
633
634
635 clib_error_t *
636 memif_conn_fd_accept_ready (clib_file_t * uf)
637 {
638   memif_main_t *mm = &memif_main;
639   memif_socket_file_t *msf =
640     pool_elt_at_index (mm->socket_files, uf->private_data);
641   clib_file_t template = { 0 };
642   clib_error_t *err;
643   clib_socket_t *client;
644
645   client = clib_mem_alloc (sizeof (clib_socket_t));
646   memset (client, 0, sizeof (clib_socket_t));
647   err = clib_socket_accept (msf->sock, client);
648   if (err)
649     goto error;
650
651   template.read_function = memif_master_conn_fd_read_ready;
652   template.write_function = memif_master_conn_fd_write_ready;
653   template.error_function = memif_master_conn_fd_error;
654   template.file_descriptor = client->fd;
655   template.private_data = uf->private_data;
656   template.description = format (0, "memif in conn on %s", msf->filename);
657
658   memif_file_add (&client->private_data, &template);
659
660   err = memif_msg_enq_hello (client);
661   if (err)
662     {
663       clib_socket_close (client);
664       goto error;
665     }
666
667   vec_add1 (msf->pending_clients, client);
668
669   return 0;
670
671 error:
672   clib_error_report (err);
673   clib_mem_free (client);
674   return err;
675 }
676
677 /*
678  * fd.io coding-style-patch-verification: ON
679  *
680  * Local Variables:
681  * eval: (c-set-style "gnu")
682  * End:
683  */