clib_socket: add sendmsg / recvmsg with ancillary data support
[vpp.git] / src / vppinfra / socket.c
1 /*
2  * Copyright (c) 2015 Cisco and/or its affiliates.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at:
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 /*
16   Copyright (c) 2001, 2002, 2003, 2005 Eliot Dresselhaus
17
18   Permission is hereby granted, free of charge, to any person obtaining
19   a copy of this software and associated documentation files (the
20   "Software"), to deal in the Software without restriction, including
21   without limitation the rights to use, copy, modify, merge, publish,
22   distribute, sublicense, and/or sell copies of the Software, and to
23   permit persons to whom the Software is furnished to do so, subject to
24   the following conditions:
25
26   The above copyright notice and this permission notice shall be
27   included in all copies or substantial portions of the Software.
28
29   THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
30   EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
31   MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
32   NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
33   LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
34   OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
35   WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
36 */
37
38 #include <stdio.h>
39 #include <string.h>             /* strchr */
40 #define __USE_GNU
41 #include <sys/types.h>
42 #include <sys/socket.h>
43 #include <sys/un.h>
44 #include <sys/stat.h>
45 #include <netinet/in.h>
46 #include <arpa/inet.h>
47 #include <netdb.h>
48 #include <unistd.h>
49 #include <fcntl.h>
50
51 #include <vppinfra/mem.h>
52 #include <vppinfra/vec.h>
53 #include <vppinfra/socket.h>
54 #include <vppinfra/format.h>
55 #include <vppinfra/error.h>
56
57 void
58 clib_socket_tx_add_formatted (clib_socket_t * s, char *fmt, ...)
59 {
60   va_list va;
61   va_start (va, fmt);
62   clib_socket_tx_add_va_formatted (s, fmt, &va);
63   va_end (va);
64 }
65
66 /* Return and bind to an unused port. */
67 static word
68 find_free_port (word sock)
69 {
70   word port;
71
72   for (port = IPPORT_USERRESERVED; port < 1 << 16; port++)
73     {
74       struct sockaddr_in a;
75
76       memset (&a, 0, sizeof (a));       /* Warnings be gone */
77
78       a.sin_family = PF_INET;
79       a.sin_addr.s_addr = INADDR_ANY;
80       a.sin_port = htons (port);
81
82       if (bind (sock, (struct sockaddr *) &a, sizeof (a)) >= 0)
83         break;
84     }
85
86   return port < 1 << 16 ? port : -1;
87 }
88
89 /* Convert a config string to a struct sockaddr and length for use
90    with bind or connect. */
91 static clib_error_t *
92 socket_config (char *config,
93                void *addr, socklen_t * addr_len, u32 ip4_default_address)
94 {
95   clib_error_t *error = 0;
96
97   if (!config)
98     config = "";
99
100   /* Anything that begins with a / is a local PF_LOCAL socket. */
101   if (config[0] == '/')
102     {
103       struct sockaddr_un *su = addr;
104       su->sun_family = PF_LOCAL;
105       clib_memcpy (&su->sun_path, config,
106                    clib_min (sizeof (su->sun_path), 1 + strlen (config)));
107       *addr_len = sizeof (su[0]);
108     }
109
110   /* Hostname or hostname:port or port. */
111   else
112     {
113       char *host_name;
114       int port = -1;
115       struct sockaddr_in *sa = addr;
116
117       host_name = 0;
118       port = -1;
119       if (config[0] != 0)
120         {
121           unformat_input_t i;
122
123           unformat_init_string (&i, config, strlen (config));
124           if (unformat (&i, "%s:%d", &host_name, &port)
125               || unformat (&i, "%s:0x%x", &host_name, &port))
126             ;
127           else if (unformat (&i, "%s", &host_name))
128             ;
129           else
130             error = clib_error_return (0, "unknown input `%U'",
131                                        format_unformat_error, &i);
132           unformat_free (&i);
133
134           if (error)
135             goto done;
136         }
137
138       sa->sin_family = PF_INET;
139       *addr_len = sizeof (sa[0]);
140       if (port != -1)
141         sa->sin_port = htons (port);
142       else
143         sa->sin_port = 0;
144
145       if (host_name)
146         {
147           struct in_addr host_addr;
148
149           /* Recognize localhost to avoid host lookup in most common cast. */
150           if (!strcmp (host_name, "localhost"))
151             sa->sin_addr.s_addr = htonl (INADDR_LOOPBACK);
152
153           else if (inet_aton (host_name, &host_addr))
154             sa->sin_addr = host_addr;
155
156           else if (host_name && strlen (host_name) > 0)
157             {
158               struct hostent *host = gethostbyname (host_name);
159               if (!host)
160                 error = clib_error_return (0, "unknown host `%s'", config);
161               else
162                 clib_memcpy (&sa->sin_addr.s_addr, host->h_addr_list[0],
163                              host->h_length);
164             }
165
166           else
167             sa->sin_addr.s_addr = htonl (ip4_default_address);
168
169           vec_free (host_name);
170           if (error)
171             goto done;
172         }
173     }
174
175 done:
176   return error;
177 }
178
179 static clib_error_t *
180 default_socket_write (clib_socket_t * s)
181 {
182   clib_error_t *err = 0;
183   word written = 0;
184   word fd = 0;
185   word tx_len;
186
187   fd = s->fd;
188
189   /* Map standard input to standard output.
190      Typically, fd is a socket for which read/write both work. */
191   if (fd == 0)
192     fd = 1;
193
194   tx_len = vec_len (s->tx_buffer);
195   written = write (fd, s->tx_buffer, tx_len);
196
197   /* Ignore certain errors. */
198   if (written < 0 && !unix_error_is_fatal (errno))
199     written = 0;
200
201   /* A "real" error occurred. */
202   if (written < 0)
203     {
204       err = clib_error_return_unix (0, "write %wd bytes (fd %d, '%s')",
205                                     tx_len, s->fd, s->config);
206       vec_free (s->tx_buffer);
207       goto done;
208     }
209
210   /* Reclaim the transmitted part of the tx buffer on successful writes. */
211   else if (written > 0)
212     {
213       if (written == tx_len)
214         _vec_len (s->tx_buffer) = 0;
215       else
216         vec_delete (s->tx_buffer, written, 0);
217     }
218
219   /* If a non-fatal error occurred AND
220      the buffer is full, then we must free it. */
221   else if (written == 0 && tx_len > 64 * 1024)
222     {
223       vec_free (s->tx_buffer);
224     }
225
226 done:
227   return err;
228 }
229
230 static clib_error_t *
231 default_socket_read (clib_socket_t * sock, int n_bytes)
232 {
233   word fd, n_read;
234   u8 *buf;
235
236   /* RX side of socket is down once end of file is reached. */
237   if (sock->flags & CLIB_SOCKET_F_RX_END_OF_FILE)
238     return 0;
239
240   fd = sock->fd;
241
242   n_bytes = clib_max (n_bytes, 4096);
243   vec_add2 (sock->rx_buffer, buf, n_bytes);
244
245   if ((n_read = read (fd, buf, n_bytes)) < 0)
246     {
247       n_read = 0;
248
249       /* Ignore certain errors. */
250       if (!unix_error_is_fatal (errno))
251         goto non_fatal;
252
253       return clib_error_return_unix (0, "read %d bytes (fd %d, '%s')",
254                                      n_bytes, sock->fd, sock->config);
255     }
256
257   /* Other side closed the socket. */
258   if (n_read == 0)
259     sock->flags |= CLIB_SOCKET_F_RX_END_OF_FILE;
260
261 non_fatal:
262   _vec_len (sock->rx_buffer) += n_read - n_bytes;
263
264   return 0;
265 }
266
267 static clib_error_t *
268 default_socket_close (clib_socket_t * s)
269 {
270   if (close (s->fd) < 0)
271     return clib_error_return_unix (0, "close (fd %d, %s)", s->fd, s->config);
272   return 0;
273 }
274
275 static clib_error_t *
276 default_socket_sendmsg (clib_socket_t * s, void *msg, int msglen,
277                         int fds[], int num_fds)
278 {
279   struct msghdr mh = { 0 };
280   struct iovec iov[1];
281   char ctl[CMSG_SPACE (sizeof (int)) * num_fds];
282   int rv;
283
284   iov[0].iov_base = msg;
285   iov[0].iov_len = msglen;
286   mh.msg_iov = iov;
287   mh.msg_iovlen = 1;
288
289   if (num_fds > 0)
290     {
291       struct cmsghdr *cmsg;
292       memset (&ctl, 0, sizeof (ctl));
293       mh.msg_control = ctl;
294       mh.msg_controllen = sizeof (ctl);
295       cmsg = CMSG_FIRSTHDR (&mh);
296       cmsg->cmsg_len = CMSG_LEN (sizeof (int) * num_fds);
297       cmsg->cmsg_level = SOL_SOCKET;
298       cmsg->cmsg_type = SCM_RIGHTS;
299       memcpy (CMSG_DATA (cmsg), fds, sizeof (int) * num_fds);
300     }
301   rv = sendmsg (s->fd, &mh, 0);
302   if (rv < 0)
303     return clib_error_return_unix (0, "sendmsg");
304   return 0;
305 }
306
307
308 static clib_error_t *
309 default_socket_recvmsg (clib_socket_t * s, void *msg, int msglen,
310                         int fds[], int num_fds)
311 {
312   char ctl[CMSG_SPACE (sizeof (int) * num_fds) +
313            CMSG_SPACE (sizeof (struct ucred))];
314   struct msghdr mh = { 0 };
315   struct iovec iov[1];
316   ssize_t size;
317   struct ucred *cr = 0;
318   struct cmsghdr *cmsg;
319
320   iov[0].iov_base = msg;
321   iov[0].iov_len = msglen;
322   mh.msg_iov = iov;
323   mh.msg_iovlen = 1;
324   mh.msg_control = ctl;
325   mh.msg_controllen = sizeof (ctl);
326
327   memset (ctl, 0, sizeof (ctl));
328
329   /* receive the incoming message */
330   size = recvmsg (s->fd, &mh, 0);
331   if (size != msglen)
332     {
333       return (size == 0) ? clib_error_return (0, "disconnected") :
334         clib_error_return_unix (0, "recvmsg: malformed message (fd %d, '%s')",
335                                 s->fd, s->config);
336     }
337
338   cmsg = CMSG_FIRSTHDR (&mh);
339   while (cmsg)
340     {
341       if (cmsg->cmsg_level == SOL_SOCKET)
342         {
343           if (cmsg->cmsg_type == SCM_CREDENTIALS)
344             {
345               cr = (struct ucred *) CMSG_DATA (cmsg);
346               s->uid = cr->uid;
347               s->gid = cr->gid;
348               s->pid = cr->pid;
349             }
350           else if (cmsg->cmsg_type == SCM_RIGHTS)
351             {
352               clib_memcpy (fds, CMSG_DATA (cmsg), num_fds * sizeof (int));
353             }
354         }
355       cmsg = CMSG_NXTHDR (&mh, cmsg);
356     }
357   return 0;
358 }
359
360 static void
361 socket_init_funcs (clib_socket_t * s)
362 {
363   if (!s->write_func)
364     s->write_func = default_socket_write;
365   if (!s->read_func)
366     s->read_func = default_socket_read;
367   if (!s->close_func)
368     s->close_func = default_socket_close;
369   if (!s->sendmsg_func)
370     s->sendmsg_func = default_socket_sendmsg;
371   if (!s->recvmsg_func)
372     s->recvmsg_func = default_socket_recvmsg;
373 }
374
375 clib_error_t *
376 clib_socket_init (clib_socket_t * s)
377 {
378   union
379   {
380     struct sockaddr sa;
381     struct sockaddr_un su;
382   } addr;
383   socklen_t addr_len = 0;
384   int socket_type;
385   clib_error_t *error = 0;
386   word port;
387
388   error = socket_config (s->config, &addr.sa, &addr_len,
389                          (s->flags & CLIB_SOCKET_F_IS_SERVER
390                           ? INADDR_LOOPBACK : INADDR_ANY));
391   if (error)
392     goto done;
393
394   socket_init_funcs (s);
395
396   socket_type = s->flags & CLIB_SOCKET_F_SEQPACKET ?
397     SOCK_SEQPACKET : SOCK_STREAM;
398
399   s->fd = socket (addr.sa.sa_family, socket_type, 0);
400   if (s->fd < 0)
401     {
402       error = clib_error_return_unix (0, "socket (fd %d, '%s')",
403                                       s->fd, s->config);
404       goto done;
405     }
406
407   port = 0;
408   if (addr.sa.sa_family == PF_INET)
409     port = ((struct sockaddr_in *) &addr)->sin_port;
410
411   if (s->flags & CLIB_SOCKET_F_IS_SERVER)
412     {
413       uword need_bind = 1;
414
415       if (addr.sa.sa_family == PF_INET)
416         {
417           if (port == 0)
418             {
419               port = find_free_port (s->fd);
420               if (port < 0)
421                 {
422                   error = clib_error_return (0, "no free port (fd %d, '%s')",
423                                              s->fd, s->config);
424                   goto done;
425                 }
426               need_bind = 0;
427             }
428         }
429       if (addr.sa.sa_family == PF_LOCAL)
430         unlink (((struct sockaddr_un *) &addr)->sun_path);
431
432       /* Make address available for multiple users. */
433       {
434         int v = 1;
435         if (setsockopt (s->fd, SOL_SOCKET, SO_REUSEADDR, &v, sizeof (v)) < 0)
436           clib_unix_warning ("setsockopt SO_REUSEADDR fails");
437       }
438
439       if (addr.sa.sa_family == PF_LOCAL && s->flags & CLIB_SOCKET_F_PASSCRED)
440         {
441           int x = 1;
442           if (setsockopt (s->fd, SOL_SOCKET, SO_PASSCRED, &x, sizeof (x)) < 0)
443             {
444               error = clib_error_return_unix (0, "setsockopt (SO_PASSCRED, "
445                                               "fd %d, '%s')", s->fd,
446                                               s->config);
447               goto done;
448             }
449         }
450
451       if (need_bind && bind (s->fd, &addr.sa, addr_len) < 0)
452         {
453           error = clib_error_return_unix (0, "bind (fd %d, '%s')",
454                                           s->fd, s->config);
455           goto done;
456         }
457
458       if (listen (s->fd, 5) < 0)
459         {
460           error = clib_error_return_unix (0, "listen (fd %d, '%s')",
461                                           s->fd, s->config);
462           goto done;
463         }
464       if (addr.sa.sa_family == PF_LOCAL
465           && s->flags & CLIB_SOCKET_F_ALLOW_GROUP_WRITE)
466         {
467           struct stat st = { 0 };
468           if (stat (((struct sockaddr_un *) &addr)->sun_path, &st) < 0)
469             {
470               error = clib_error_return_unix (0, "stat (fd %d, '%s')",
471                                               s->fd, s->config);
472               goto done;
473             }
474           st.st_mode |= S_IWGRP;
475           if (chmod (((struct sockaddr_un *) &addr)->sun_path, st.st_mode) <
476               0)
477             {
478               error =
479                 clib_error_return_unix (0, "chmod (fd %d, '%s', mode %o)",
480                                         s->fd, s->config, st.st_mode);
481               goto done;
482             }
483         }
484     }
485   else
486     {
487       if ((s->flags & CLIB_SOCKET_F_NON_BLOCKING_CONNECT)
488           && fcntl (s->fd, F_SETFL, O_NONBLOCK) < 0)
489         {
490           error = clib_error_return_unix (0, "fcntl NONBLOCK (fd %d, '%s')",
491                                           s->fd, s->config);
492           goto done;
493         }
494
495       if (connect (s->fd, &addr.sa, addr_len) < 0
496           && !((s->flags & CLIB_SOCKET_F_NON_BLOCKING_CONNECT) &&
497                errno == EINPROGRESS))
498         {
499           error = clib_error_return_unix (0, "connect (fd %d, '%s')",
500                                           s->fd, s->config);
501           goto done;
502         }
503     }
504
505   return error;
506
507 done:
508   if (s->fd > 0)
509     close (s->fd);
510   return error;
511 }
512
513 clib_error_t *
514 clib_socket_accept (clib_socket_t * server, clib_socket_t * client)
515 {
516   clib_error_t *err = 0;
517   socklen_t len = 0;
518
519   memset (client, 0, sizeof (client[0]));
520
521   /* Accept the new socket connection. */
522   client->fd = accept (server->fd, 0, 0);
523   if (client->fd < 0)
524     return clib_error_return_unix (0, "accept (fd %d, '%s')",
525                                    server->fd, server->config);
526
527   /* Set the new socket to be non-blocking. */
528   if (fcntl (client->fd, F_SETFL, O_NONBLOCK) < 0)
529     {
530       err = clib_error_return_unix (0, "fcntl O_NONBLOCK (fd %d)",
531                                     client->fd);
532       goto close_client;
533     }
534
535   /* Get peer info. */
536   len = sizeof (client->peer);
537   if (getpeername (client->fd, (struct sockaddr *) &client->peer, &len) < 0)
538     {
539       err = clib_error_return_unix (0, "getpeername (fd %d)", client->fd);
540       goto close_client;
541     }
542
543   client->flags = CLIB_SOCKET_F_IS_CLIENT;
544
545   socket_init_funcs (client);
546   return 0;
547
548 close_client:
549   close (client->fd);
550   return err;
551 }
552
553 /*
554  * fd.io coding-style-patch-verification: ON
555  *
556  * Local Variables:
557  * eval: (c-set-style "gnu")
558  * End:
559  */