vppinfra: add abstract socket & netns fns
[vpp.git] / src / vppinfra / socket.c
1 /*
2  * Copyright (c) 2015 Cisco and/or its affiliates.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at:
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 /*
16   Copyright (c) 2001, 2002, 2003, 2005 Eliot Dresselhaus
17
18   Permission is hereby granted, free of charge, to any person obtaining
19   a copy of this software and associated documentation files (the
20   "Software"), to deal in the Software without restriction, including
21   without limitation the rights to use, copy, modify, merge, publish,
22   distribute, sublicense, and/or sell copies of the Software, and to
23   permit persons to whom the Software is furnished to do so, subject to
24   the following conditions:
25
26   The above copyright notice and this permission notice shall be
27   included in all copies or substantial portions of the Software.
28
29   THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
30   EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
31   MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
32   NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
33   LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
34   OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
35   WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
36 */
37
38 #include <stdio.h>
39 #include <string.h>             /* strchr */
40 #define __USE_GNU
41 #define _GNU_SOURCE
42 #include <sys/types.h>
43 #include <sys/socket.h>
44 #include <sys/un.h>
45 #include <sys/stat.h>
46 #include <netinet/in.h>
47 #include <arpa/inet.h>
48 #include <netdb.h>
49 #include <unistd.h>
50 #include <fcntl.h>
51
52 #include <vppinfra/mem.h>
53 #include <vppinfra/vec.h>
54 #include <vppinfra/socket.h>
55 #include <vppinfra/linux/netns.h>
56 #include <vppinfra/format.h>
57 #include <vppinfra/error.h>
58
59 #ifndef __GLIBC__
60 /* IPPORT_USERRESERVED is not part of musl libc. */
61 #define IPPORT_USERRESERVED 5000
62 #endif
63
64 __clib_export void
65 clib_socket_tx_add_formatted (clib_socket_t * s, char *fmt, ...)
66 {
67   va_list va;
68   va_start (va, fmt);
69   clib_socket_tx_add_va_formatted (s, fmt, &va);
70   va_end (va);
71 }
72
73 /* Return and bind to an unused port. */
74 static word
75 find_free_port (word sock)
76 {
77   word port;
78
79   for (port = IPPORT_USERRESERVED; port < 1 << 16; port++)
80     {
81       struct sockaddr_in a;
82
83       clib_memset (&a, 0, sizeof (a));  /* Warnings be gone */
84
85       a.sin_family = PF_INET;
86       a.sin_addr.s_addr = INADDR_ANY;
87       a.sin_port = htons (port);
88
89       if (bind (sock, (struct sockaddr *) &a, sizeof (a)) >= 0)
90         break;
91     }
92
93   return port < 1 << 16 ? port : -1;
94 }
95
96 /* Convert a config string to a struct sockaddr and length for use
97    with bind or connect. */
98 static clib_error_t *
99 socket_config (char *config,
100                void *addr, socklen_t * addr_len, u32 ip4_default_address)
101 {
102   clib_error_t *error = 0;
103
104   if (!config)
105     config = "";
106
107   /* Anything that begins with a / is a local PF_LOCAL socket. */
108   if (config[0] == '/')
109     {
110       struct sockaddr_un *su = addr;
111       su->sun_family = PF_LOCAL;
112       clib_memcpy (&su->sun_path, config,
113                    clib_min (sizeof (su->sun_path), 1 + strlen (config)));
114       *addr_len = sizeof (su[0]);
115     }
116
117   /* Treat everything that starts with @ as an abstract socket. */
118   else if (config[0] == '@')
119     {
120       struct sockaddr_un *su = addr;
121       su->sun_family = PF_LOCAL;
122       clib_memcpy (&su->sun_path, config,
123                    clib_min (sizeof (su->sun_path), 1 + strlen (config)));
124
125       *addr_len = sizeof (su->sun_family) + strlen (config);
126       su->sun_path[0] = '\0';
127     }
128
129   /* Hostname or hostname:port or port. */
130   else
131     {
132       char *host_name;
133       int port = -1;
134       struct sockaddr_in *sa = addr;
135
136       host_name = 0;
137       port = -1;
138       if (config[0] != 0)
139         {
140           unformat_input_t i;
141
142           unformat_init_string (&i, config, strlen (config));
143           if (unformat (&i, "%s:%d", &host_name, &port)
144               || unformat (&i, "%s:0x%x", &host_name, &port))
145             ;
146           else if (unformat (&i, "%s", &host_name))
147             ;
148           else
149             error = clib_error_return (0, "unknown input `%U'",
150                                        format_unformat_error, &i);
151           unformat_free (&i);
152
153           if (error)
154             goto done;
155         }
156
157       sa->sin_family = PF_INET;
158       *addr_len = sizeof (sa[0]);
159       if (port != -1)
160         sa->sin_port = htons (port);
161       else
162         sa->sin_port = 0;
163
164       if (host_name)
165         {
166           struct in_addr host_addr;
167
168           /* Recognize localhost to avoid host lookup in most common cast. */
169           if (!strcmp (host_name, "localhost"))
170             sa->sin_addr.s_addr = htonl (INADDR_LOOPBACK);
171
172           else if (inet_aton (host_name, &host_addr))
173             sa->sin_addr = host_addr;
174
175           else if (host_name && strlen (host_name) > 0)
176             {
177               struct hostent *host = gethostbyname (host_name);
178               if (!host)
179                 error = clib_error_return (0, "unknown host `%s'", config);
180               else
181                 clib_memcpy (&sa->sin_addr.s_addr, host->h_addr_list[0],
182                              host->h_length);
183             }
184
185           else
186             sa->sin_addr.s_addr = htonl (ip4_default_address);
187
188           vec_free (host_name);
189           if (error)
190             goto done;
191         }
192     }
193
194 done:
195   return error;
196 }
197
198 static clib_error_t *
199 default_socket_write (clib_socket_t * s)
200 {
201   clib_error_t *err = 0;
202   word written = 0;
203   word fd = 0;
204   word tx_len;
205
206   fd = s->fd;
207
208   /* Map standard input to standard output.
209      Typically, fd is a socket for which read/write both work. */
210   if (fd == 0)
211     fd = 1;
212
213   tx_len = vec_len (s->tx_buffer);
214   written = write (fd, s->tx_buffer, tx_len);
215
216   /* Ignore certain errors. */
217   if (written < 0 && !unix_error_is_fatal (errno))
218     written = 0;
219
220   /* A "real" error occurred. */
221   if (written < 0)
222     {
223       err = clib_error_return_unix (0, "write %wd bytes (fd %d, '%s')",
224                                     tx_len, s->fd, s->config);
225       vec_free (s->tx_buffer);
226       goto done;
227     }
228
229   /* Reclaim the transmitted part of the tx buffer on successful writes. */
230   else if (written > 0)
231     {
232       if (written == tx_len)
233         _vec_len (s->tx_buffer) = 0;
234       else
235         vec_delete (s->tx_buffer, written, 0);
236     }
237
238   /* If a non-fatal error occurred AND
239      the buffer is full, then we must free it. */
240   else if (written == 0 && tx_len > 64 * 1024)
241     {
242       vec_free (s->tx_buffer);
243     }
244
245 done:
246   return err;
247 }
248
249 static clib_error_t *
250 default_socket_read (clib_socket_t * sock, int n_bytes)
251 {
252   word fd, n_read;
253   u8 *buf;
254
255   /* RX side of socket is down once end of file is reached. */
256   if (sock->flags & CLIB_SOCKET_F_RX_END_OF_FILE)
257     return 0;
258
259   fd = sock->fd;
260
261   n_bytes = clib_max (n_bytes, 4096);
262   vec_add2 (sock->rx_buffer, buf, n_bytes);
263
264   if ((n_read = read (fd, buf, n_bytes)) < 0)
265     {
266       n_read = 0;
267
268       /* Ignore certain errors. */
269       if (!unix_error_is_fatal (errno))
270         goto non_fatal;
271
272       return clib_error_return_unix (0, "read %d bytes (fd %d, '%s')",
273                                      n_bytes, sock->fd, sock->config);
274     }
275
276   /* Other side closed the socket. */
277   if (n_read == 0)
278     sock->flags |= CLIB_SOCKET_F_RX_END_OF_FILE;
279
280 non_fatal:
281   _vec_len (sock->rx_buffer) += n_read - n_bytes;
282
283   return 0;
284 }
285
286 static clib_error_t *
287 default_socket_close (clib_socket_t * s)
288 {
289   if (close (s->fd) < 0)
290     return clib_error_return_unix (0, "close (fd %d, %s)", s->fd, s->config);
291   return 0;
292 }
293
294 static clib_error_t *
295 default_socket_sendmsg (clib_socket_t * s, void *msg, int msglen,
296                         int fds[], int num_fds)
297 {
298   struct msghdr mh = { 0 };
299   struct iovec iov[1];
300   char ctl[CMSG_SPACE (sizeof (int) * num_fds)];
301   int rv;
302
303   iov[0].iov_base = msg;
304   iov[0].iov_len = msglen;
305   mh.msg_iov = iov;
306   mh.msg_iovlen = 1;
307
308   if (num_fds > 0)
309     {
310       struct cmsghdr *cmsg;
311       clib_memset (&ctl, 0, sizeof (ctl));
312       mh.msg_control = ctl;
313       mh.msg_controllen = sizeof (ctl);
314       cmsg = CMSG_FIRSTHDR (&mh);
315       cmsg->cmsg_len = CMSG_LEN (sizeof (int) * num_fds);
316       cmsg->cmsg_level = SOL_SOCKET;
317       cmsg->cmsg_type = SCM_RIGHTS;
318       memcpy (CMSG_DATA (cmsg), fds, sizeof (int) * num_fds);
319     }
320   rv = sendmsg (s->fd, &mh, 0);
321   if (rv < 0)
322     return clib_error_return_unix (0, "sendmsg");
323   return 0;
324 }
325
326
327 static clib_error_t *
328 default_socket_recvmsg (clib_socket_t * s, void *msg, int msglen,
329                         int fds[], int num_fds)
330 {
331 #ifdef __linux__
332   char ctl[CMSG_SPACE (sizeof (int) * num_fds) +
333            CMSG_SPACE (sizeof (struct ucred))];
334   struct ucred *cr = 0;
335 #else
336   char ctl[CMSG_SPACE (sizeof (int) * num_fds)];
337 #endif
338   struct msghdr mh = { 0 };
339   struct iovec iov[1];
340   ssize_t size;
341   struct cmsghdr *cmsg;
342
343   iov[0].iov_base = msg;
344   iov[0].iov_len = msglen;
345   mh.msg_iov = iov;
346   mh.msg_iovlen = 1;
347   mh.msg_control = ctl;
348   mh.msg_controllen = sizeof (ctl);
349
350   clib_memset (ctl, 0, sizeof (ctl));
351
352   /* receive the incoming message */
353   size = recvmsg (s->fd, &mh, 0);
354   if (size != msglen)
355     {
356       return (size == 0) ? clib_error_return (0, "disconnected") :
357         clib_error_return_unix (0, "recvmsg: malformed message (fd %d, '%s')",
358                                 s->fd, s->config);
359     }
360
361   cmsg = CMSG_FIRSTHDR (&mh);
362   while (cmsg)
363     {
364       if (cmsg->cmsg_level == SOL_SOCKET)
365         {
366 #ifdef __linux__
367           if (cmsg->cmsg_type == SCM_CREDENTIALS)
368             {
369               cr = (struct ucred *) CMSG_DATA (cmsg);
370               s->uid = cr->uid;
371               s->gid = cr->gid;
372               s->pid = cr->pid;
373             }
374           else
375 #endif
376           if (cmsg->cmsg_type == SCM_RIGHTS)
377             {
378               clib_memcpy_fast (fds, CMSG_DATA (cmsg),
379                                 num_fds * sizeof (int));
380             }
381         }
382       cmsg = CMSG_NXTHDR (&mh, cmsg);
383     }
384   return 0;
385 }
386
387 static void
388 socket_init_funcs (clib_socket_t * s)
389 {
390   if (!s->write_func)
391     s->write_func = default_socket_write;
392   if (!s->read_func)
393     s->read_func = default_socket_read;
394   if (!s->close_func)
395     s->close_func = default_socket_close;
396   if (!s->sendmsg_func)
397     s->sendmsg_func = default_socket_sendmsg;
398   if (!s->recvmsg_func)
399     s->recvmsg_func = default_socket_recvmsg;
400 }
401
402 __clib_export clib_error_t *
403 clib_socket_init (clib_socket_t * s)
404 {
405   union
406   {
407     struct sockaddr sa;
408     struct sockaddr_un su;
409   } addr;
410   socklen_t addr_len = 0;
411   int socket_type, rv;
412   clib_error_t *error = 0;
413   word port;
414
415   error = socket_config (s->config, &addr.sa, &addr_len,
416                          (s->flags & CLIB_SOCKET_F_IS_SERVER
417                           ? INADDR_LOOPBACK : INADDR_ANY));
418   if (error)
419     goto done;
420
421   socket_init_funcs (s);
422
423   socket_type = s->flags & CLIB_SOCKET_F_SEQPACKET ?
424     SOCK_SEQPACKET : SOCK_STREAM;
425
426   s->fd = socket (addr.sa.sa_family, socket_type, 0);
427   if (s->fd < 0)
428     {
429       error = clib_error_return_unix (0, "socket (fd %d, '%s')",
430                                       s->fd, s->config);
431       goto done;
432     }
433
434   port = 0;
435   if (addr.sa.sa_family == PF_INET)
436     port = ((struct sockaddr_in *) &addr)->sin_port;
437
438   if (s->flags & CLIB_SOCKET_F_IS_SERVER)
439     {
440       uword need_bind = 1;
441
442       if (addr.sa.sa_family == PF_INET)
443         {
444           if (port == 0)
445             {
446               port = find_free_port (s->fd);
447               if (port < 0)
448                 {
449                   error = clib_error_return (0, "no free port (fd %d, '%s')",
450                                              s->fd, s->config);
451                   goto done;
452                 }
453               need_bind = 0;
454             }
455         }
456       if (addr.sa.sa_family == PF_LOCAL &&
457           ((struct sockaddr_un *) &addr)->sun_path[0] != 0)
458         unlink (((struct sockaddr_un *) &addr)->sun_path);
459
460       /* Make address available for multiple users. */
461       {
462         int v = 1;
463         if (setsockopt (s->fd, SOL_SOCKET, SO_REUSEADDR, &v, sizeof (v)) < 0)
464           clib_unix_warning ("setsockopt SO_REUSEADDR fails");
465       }
466
467 #if __linux__
468       if (addr.sa.sa_family == PF_LOCAL && s->flags & CLIB_SOCKET_F_PASSCRED)
469         {
470           int x = 1;
471           if (setsockopt (s->fd, SOL_SOCKET, SO_PASSCRED, &x, sizeof (x)) < 0)
472             {
473               error = clib_error_return_unix (0, "setsockopt (SO_PASSCRED, "
474                                               "fd %d, '%s')", s->fd,
475                                               s->config);
476               goto done;
477             }
478         }
479 #endif
480
481       if (need_bind && bind (s->fd, &addr.sa, addr_len) < 0)
482         {
483           error = clib_error_return_unix (0, "bind (fd %d, '%s')",
484                                           s->fd, s->config);
485           goto done;
486         }
487
488       if (listen (s->fd, 5) < 0)
489         {
490           error = clib_error_return_unix (0, "listen (fd %d, '%s')",
491                                           s->fd, s->config);
492           goto done;
493         }
494       if (addr.sa.sa_family == PF_LOCAL &&
495           s->flags & CLIB_SOCKET_F_ALLOW_GROUP_WRITE &&
496           ((struct sockaddr_un *) &addr)->sun_path[0] != 0)
497         {
498           struct stat st = { 0 };
499           if (stat (((struct sockaddr_un *) &addr)->sun_path, &st) < 0)
500             {
501               error = clib_error_return_unix (0, "stat (fd %d, '%s')",
502                                               s->fd, s->config);
503               goto done;
504             }
505           st.st_mode |= S_IWGRP;
506           if (chmod (((struct sockaddr_un *) &addr)->sun_path, st.st_mode) <
507               0)
508             {
509               error =
510                 clib_error_return_unix (0, "chmod (fd %d, '%s', mode %o)",
511                                         s->fd, s->config, st.st_mode);
512               goto done;
513             }
514         }
515     }
516   else
517     {
518       if ((s->flags & CLIB_SOCKET_F_NON_BLOCKING_CONNECT)
519           && fcntl (s->fd, F_SETFL, O_NONBLOCK) < 0)
520         {
521           error = clib_error_return_unix (0, "fcntl NONBLOCK (fd %d, '%s')",
522                                           s->fd, s->config);
523           goto done;
524         }
525
526       while ((rv = connect (s->fd, &addr.sa, addr_len)) < 0
527              && errno == EAGAIN)
528         ;
529       if (rv < 0 && !((s->flags & CLIB_SOCKET_F_NON_BLOCKING_CONNECT) &&
530                       errno == EINPROGRESS))
531         {
532           error = clib_error_return_unix (0, "connect (fd %d, '%s')",
533                                           s->fd, s->config);
534           goto done;
535         }
536       /* Connect was blocking so set fd to non-blocking now unless
537        * blocking mode explicitly requested. */
538       if (!(s->flags & CLIB_SOCKET_F_NON_BLOCKING_CONNECT) &&
539           !(s->flags & CLIB_SOCKET_F_BLOCKING) &&
540           fcntl (s->fd, F_SETFL, O_NONBLOCK) < 0)
541         {
542           error = clib_error_return_unix (0, "fcntl NONBLOCK2 (fd %d, '%s')",
543                                           s->fd, s->config);
544           goto done;
545         }
546     }
547
548   return error;
549
550 done:
551   if (s->fd > 0)
552     close (s->fd);
553   return error;
554 }
555
556 __clib_export clib_error_t *
557 clib_socket_init_netns (clib_socket_t *s, u8 *namespace)
558 {
559   if (namespace == NULL || namespace[0] == 0)
560     return clib_socket_init (s);
561
562   clib_error_t *error;
563   int old_netns_fd, nfd;
564
565   old_netns_fd = clib_netns_open (NULL /* self */);
566   if ((nfd = clib_netns_open (namespace)) == -1)
567     {
568       error = clib_error_return_unix (0, "clib_netns_open '%s'", namespace);
569       goto done;
570     }
571
572   if (clib_setns (nfd) == -1)
573     {
574       error = clib_error_return_unix (0, "setns '%s'", namespace);
575       goto done;
576     }
577
578   error = clib_socket_init (s);
579
580 done:
581   if (clib_setns (old_netns_fd) == -1)
582     clib_warning ("Cannot set old ns");
583   close (old_netns_fd);
584
585   return error;
586 }
587
588 __clib_export clib_error_t *
589 clib_socket_accept (clib_socket_t * server, clib_socket_t * client)
590 {
591   clib_error_t *err = 0;
592   socklen_t len = 0;
593
594   clib_memset (client, 0, sizeof (client[0]));
595
596   /* Accept the new socket connection. */
597   client->fd = accept (server->fd, 0, 0);
598   if (client->fd < 0)
599     return clib_error_return_unix (0, "accept (fd %d, '%s')",
600                                    server->fd, server->config);
601
602   /* Set the new socket to be non-blocking. */
603   if (fcntl (client->fd, F_SETFL, O_NONBLOCK) < 0)
604     {
605       err = clib_error_return_unix (0, "fcntl O_NONBLOCK (fd %d)",
606                                     client->fd);
607       goto close_client;
608     }
609
610   /* Get peer info. */
611   len = sizeof (client->peer);
612   if (getpeername (client->fd, (struct sockaddr *) &client->peer, &len) < 0)
613     {
614       err = clib_error_return_unix (0, "getpeername (fd %d)", client->fd);
615       goto close_client;
616     }
617
618   client->flags = CLIB_SOCKET_F_IS_CLIENT;
619
620   socket_init_funcs (client);
621   return 0;
622
623 close_client:
624   close (client->fd);
625   return err;
626 }
627
628 /*
629  * fd.io coding-style-patch-verification: ON
630  *
631  * Local Variables:
632  * eval: (c-set-style "gnu")
633  * End:
634  */