29b2a945cb99031bed6fd888950a96b8931703cf
[vpp.git] / src / vppinfra / socket.c
1 /*
2  * Copyright (c) 2015 Cisco and/or its affiliates.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at:
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 /*
16   Copyright (c) 2001, 2002, 2003, 2005 Eliot Dresselhaus
17
18   Permission is hereby granted, free of charge, to any person obtaining
19   a copy of this software and associated documentation files (the
20   "Software"), to deal in the Software without restriction, including
21   without limitation the rights to use, copy, modify, merge, publish,
22   distribute, sublicense, and/or sell copies of the Software, and to
23   permit persons to whom the Software is furnished to do so, subject to
24   the following conditions:
25
26   The above copyright notice and this permission notice shall be
27   included in all copies or substantial portions of the Software.
28
29   THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
30   EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
31   MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
32   NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
33   LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
34   OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
35   WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
36 */
37
38 #include <stdio.h>
39 #include <string.h>             /* strchr */
40 #define __USE_GNU
41 #include <sys/types.h>
42 #include <sys/socket.h>
43 #include <sys/un.h>
44 #include <sys/stat.h>
45 #include <netinet/in.h>
46 #include <arpa/inet.h>
47 #include <netdb.h>
48 #include <unistd.h>
49 #include <fcntl.h>
50
51 #include <vppinfra/mem.h>
52 #include <vppinfra/vec.h>
53 #include <vppinfra/socket.h>
54 #include <vppinfra/format.h>
55 #include <vppinfra/error.h>
56
57 void
58 clib_socket_tx_add_formatted (clib_socket_t * s, char *fmt, ...)
59 {
60   va_list va;
61   va_start (va, fmt);
62   clib_socket_tx_add_va_formatted (s, fmt, &va);
63   va_end (va);
64 }
65
66 /* Return and bind to an unused port. */
67 static word
68 find_free_port (word sock)
69 {
70   word port;
71
72   for (port = IPPORT_USERRESERVED; port < 1 << 16; port++)
73     {
74       struct sockaddr_in a;
75
76       memset (&a, 0, sizeof (a));       /* Warnings be gone */
77
78       a.sin_family = PF_INET;
79       a.sin_addr.s_addr = INADDR_ANY;
80       a.sin_port = htons (port);
81
82       if (bind (sock, (struct sockaddr *) &a, sizeof (a)) >= 0)
83         break;
84     }
85
86   return port < 1 << 16 ? port : -1;
87 }
88
89 /* Convert a config string to a struct sockaddr and length for use
90    with bind or connect. */
91 static clib_error_t *
92 socket_config (char *config,
93                void *addr, socklen_t * addr_len, u32 ip4_default_address)
94 {
95   clib_error_t *error = 0;
96
97   if (!config)
98     config = "";
99
100   /* Anything that begins with a / is a local PF_LOCAL socket. */
101   if (config[0] == '/')
102     {
103       struct sockaddr_un *su = addr;
104       su->sun_family = PF_LOCAL;
105       clib_memcpy (&su->sun_path, config,
106                    clib_min (sizeof (su->sun_path), 1 + strlen (config)));
107       *addr_len = sizeof (su[0]);
108     }
109
110   /* Hostname or hostname:port or port. */
111   else
112     {
113       char *host_name;
114       int port = -1;
115       struct sockaddr_in *sa = addr;
116
117       host_name = 0;
118       port = -1;
119       if (config[0] != 0)
120         {
121           unformat_input_t i;
122
123           unformat_init_string (&i, config, strlen (config));
124           if (unformat (&i, "%s:%d", &host_name, &port)
125               || unformat (&i, "%s:0x%x", &host_name, &port))
126             ;
127           else if (unformat (&i, "%s", &host_name))
128             ;
129           else
130             error = clib_error_return (0, "unknown input `%U'",
131                                        format_unformat_error, &i);
132           unformat_free (&i);
133
134           if (error)
135             goto done;
136         }
137
138       sa->sin_family = PF_INET;
139       *addr_len = sizeof (sa[0]);
140       if (port != -1)
141         sa->sin_port = htons (port);
142       else
143         sa->sin_port = 0;
144
145       if (host_name)
146         {
147           struct in_addr host_addr;
148
149           /* Recognize localhost to avoid host lookup in most common cast. */
150           if (!strcmp (host_name, "localhost"))
151             sa->sin_addr.s_addr = htonl (INADDR_LOOPBACK);
152
153           else if (inet_aton (host_name, &host_addr))
154             sa->sin_addr = host_addr;
155
156           else if (host_name && strlen (host_name) > 0)
157             {
158               struct hostent *host = gethostbyname (host_name);
159               if (!host)
160                 error = clib_error_return (0, "unknown host `%s'", config);
161               else
162                 clib_memcpy (&sa->sin_addr.s_addr, host->h_addr_list[0],
163                              host->h_length);
164             }
165
166           else
167             sa->sin_addr.s_addr = htonl (ip4_default_address);
168
169           vec_free (host_name);
170           if (error)
171             goto done;
172         }
173     }
174
175 done:
176   return error;
177 }
178
179 static clib_error_t *
180 default_socket_write (clib_socket_t * s)
181 {
182   clib_error_t *err = 0;
183   word written = 0;
184   word fd = 0;
185   word tx_len;
186
187   fd = s->fd;
188
189   /* Map standard input to standard output.
190      Typically, fd is a socket for which read/write both work. */
191   if (fd == 0)
192     fd = 1;
193
194   tx_len = vec_len (s->tx_buffer);
195   written = write (fd, s->tx_buffer, tx_len);
196
197   /* Ignore certain errors. */
198   if (written < 0 && !unix_error_is_fatal (errno))
199     written = 0;
200
201   /* A "real" error occurred. */
202   if (written < 0)
203     {
204       err = clib_error_return_unix (0, "write %wd bytes (fd %d, '%s')",
205                                     tx_len, s->fd, s->config);
206       vec_free (s->tx_buffer);
207       goto done;
208     }
209
210   /* Reclaim the transmitted part of the tx buffer on successful writes. */
211   else if (written > 0)
212     {
213       if (written == tx_len)
214         _vec_len (s->tx_buffer) = 0;
215       else
216         vec_delete (s->tx_buffer, written, 0);
217     }
218
219   /* If a non-fatal error occurred AND
220      the buffer is full, then we must free it. */
221   else if (written == 0 && tx_len > 64 * 1024)
222     {
223       vec_free (s->tx_buffer);
224     }
225
226 done:
227   return err;
228 }
229
230 static clib_error_t *
231 default_socket_read (clib_socket_t * sock, int n_bytes)
232 {
233   word fd, n_read;
234   u8 *buf;
235
236   /* RX side of socket is down once end of file is reached. */
237   if (sock->flags & CLIB_SOCKET_F_RX_END_OF_FILE)
238     return 0;
239
240   fd = sock->fd;
241
242   n_bytes = clib_max (n_bytes, 4096);
243   vec_add2 (sock->rx_buffer, buf, n_bytes);
244
245   if ((n_read = read (fd, buf, n_bytes)) < 0)
246     {
247       n_read = 0;
248
249       /* Ignore certain errors. */
250       if (!unix_error_is_fatal (errno))
251         goto non_fatal;
252
253       return clib_error_return_unix (0, "read %d bytes (fd %d, '%s')",
254                                      n_bytes, sock->fd, sock->config);
255     }
256
257   /* Other side closed the socket. */
258   if (n_read == 0)
259     sock->flags |= CLIB_SOCKET_F_RX_END_OF_FILE;
260
261 non_fatal:
262   _vec_len (sock->rx_buffer) += n_read - n_bytes;
263
264   return 0;
265 }
266
267 static clib_error_t *
268 default_socket_close (clib_socket_t * s)
269 {
270   if (close (s->fd) < 0)
271     return clib_error_return_unix (0, "close (fd %d, %s)", s->fd, s->config);
272   return 0;
273 }
274
275 static clib_error_t *
276 default_socket_sendmsg (clib_socket_t * s, void *msg, int msglen,
277                         int fds[], int num_fds)
278 {
279   struct msghdr mh = { 0 };
280   struct iovec iov[1];
281   char ctl[CMSG_SPACE (sizeof (int)) * num_fds];
282   int rv;
283
284   iov[0].iov_base = msg;
285   iov[0].iov_len = msglen;
286   mh.msg_iov = iov;
287   mh.msg_iovlen = 1;
288
289   if (num_fds > 0)
290     {
291       struct cmsghdr *cmsg;
292       memset (&ctl, 0, sizeof (ctl));
293       mh.msg_control = ctl;
294       mh.msg_controllen = sizeof (ctl);
295       cmsg = CMSG_FIRSTHDR (&mh);
296       cmsg->cmsg_len = CMSG_LEN (sizeof (int) * num_fds);
297       cmsg->cmsg_level = SOL_SOCKET;
298       cmsg->cmsg_type = SCM_RIGHTS;
299       memcpy (CMSG_DATA (cmsg), fds, sizeof (int) * num_fds);
300     }
301   rv = sendmsg (s->fd, &mh, 0);
302   if (rv < 0)
303     return clib_error_return_unix (0, "sendmsg");
304   return 0;
305 }
306
307
308 static clib_error_t *
309 default_socket_recvmsg (clib_socket_t * s, void *msg, int msglen,
310                         int fds[], int num_fds)
311 {
312 #ifdef __linux__
313   char ctl[CMSG_SPACE (sizeof (int) * num_fds) +
314            CMSG_SPACE (sizeof (struct ucred))];
315   struct ucred *cr = 0;
316 #else
317   char ctl[CMSG_SPACE (sizeof (int) * num_fds)];
318 #endif
319   struct msghdr mh = { 0 };
320   struct iovec iov[1];
321   ssize_t size;
322   struct cmsghdr *cmsg;
323
324   iov[0].iov_base = msg;
325   iov[0].iov_len = msglen;
326   mh.msg_iov = iov;
327   mh.msg_iovlen = 1;
328   mh.msg_control = ctl;
329   mh.msg_controllen = sizeof (ctl);
330
331   memset (ctl, 0, sizeof (ctl));
332
333   /* receive the incoming message */
334   size = recvmsg (s->fd, &mh, 0);
335   if (size != msglen)
336     {
337       return (size == 0) ? clib_error_return (0, "disconnected") :
338         clib_error_return_unix (0, "recvmsg: malformed message (fd %d, '%s')",
339                                 s->fd, s->config);
340     }
341
342   cmsg = CMSG_FIRSTHDR (&mh);
343   while (cmsg)
344     {
345       if (cmsg->cmsg_level == SOL_SOCKET)
346         {
347 #ifdef __linux__
348           if (cmsg->cmsg_type == SCM_CREDENTIALS)
349             {
350               cr = (struct ucred *) CMSG_DATA (cmsg);
351               s->uid = cr->uid;
352               s->gid = cr->gid;
353               s->pid = cr->pid;
354             }
355           else
356 #endif
357           if (cmsg->cmsg_type == SCM_RIGHTS)
358             {
359               clib_memcpy (fds, CMSG_DATA (cmsg), num_fds * sizeof (int));
360             }
361         }
362       cmsg = CMSG_NXTHDR (&mh, cmsg);
363     }
364   return 0;
365 }
366
367 static void
368 socket_init_funcs (clib_socket_t * s)
369 {
370   if (!s->write_func)
371     s->write_func = default_socket_write;
372   if (!s->read_func)
373     s->read_func = default_socket_read;
374   if (!s->close_func)
375     s->close_func = default_socket_close;
376   if (!s->sendmsg_func)
377     s->sendmsg_func = default_socket_sendmsg;
378   if (!s->recvmsg_func)
379     s->recvmsg_func = default_socket_recvmsg;
380 }
381
382 clib_error_t *
383 clib_socket_init (clib_socket_t * s)
384 {
385   union
386   {
387     struct sockaddr sa;
388     struct sockaddr_un su;
389   } addr;
390   socklen_t addr_len = 0;
391   int socket_type;
392   clib_error_t *error = 0;
393   word port;
394
395   error = socket_config (s->config, &addr.sa, &addr_len,
396                          (s->flags & CLIB_SOCKET_F_IS_SERVER
397                           ? INADDR_LOOPBACK : INADDR_ANY));
398   if (error)
399     goto done;
400
401   socket_init_funcs (s);
402
403   socket_type = s->flags & CLIB_SOCKET_F_SEQPACKET ?
404     SOCK_SEQPACKET : SOCK_STREAM;
405
406   s->fd = socket (addr.sa.sa_family, socket_type, 0);
407   if (s->fd < 0)
408     {
409       error = clib_error_return_unix (0, "socket (fd %d, '%s')",
410                                       s->fd, s->config);
411       goto done;
412     }
413
414   port = 0;
415   if (addr.sa.sa_family == PF_INET)
416     port = ((struct sockaddr_in *) &addr)->sin_port;
417
418   if (s->flags & CLIB_SOCKET_F_IS_SERVER)
419     {
420       uword need_bind = 1;
421
422       if (addr.sa.sa_family == PF_INET)
423         {
424           if (port == 0)
425             {
426               port = find_free_port (s->fd);
427               if (port < 0)
428                 {
429                   error = clib_error_return (0, "no free port (fd %d, '%s')",
430                                              s->fd, s->config);
431                   goto done;
432                 }
433               need_bind = 0;
434             }
435         }
436       if (addr.sa.sa_family == PF_LOCAL)
437         unlink (((struct sockaddr_un *) &addr)->sun_path);
438
439       /* Make address available for multiple users. */
440       {
441         int v = 1;
442         if (setsockopt (s->fd, SOL_SOCKET, SO_REUSEADDR, &v, sizeof (v)) < 0)
443           clib_unix_warning ("setsockopt SO_REUSEADDR fails");
444       }
445
446 #if __linux__
447       if (addr.sa.sa_family == PF_LOCAL && s->flags & CLIB_SOCKET_F_PASSCRED)
448         {
449           int x = 1;
450           if (setsockopt (s->fd, SOL_SOCKET, SO_PASSCRED, &x, sizeof (x)) < 0)
451             {
452               error = clib_error_return_unix (0, "setsockopt (SO_PASSCRED, "
453                                               "fd %d, '%s')", s->fd,
454                                               s->config);
455               goto done;
456             }
457         }
458 #endif
459
460       if (need_bind && bind (s->fd, &addr.sa, addr_len) < 0)
461         {
462           error = clib_error_return_unix (0, "bind (fd %d, '%s')",
463                                           s->fd, s->config);
464           goto done;
465         }
466
467       if (listen (s->fd, 5) < 0)
468         {
469           error = clib_error_return_unix (0, "listen (fd %d, '%s')",
470                                           s->fd, s->config);
471           goto done;
472         }
473       if (addr.sa.sa_family == PF_LOCAL
474           && s->flags & CLIB_SOCKET_F_ALLOW_GROUP_WRITE)
475         {
476           struct stat st = { 0 };
477           if (stat (((struct sockaddr_un *) &addr)->sun_path, &st) < 0)
478             {
479               error = clib_error_return_unix (0, "stat (fd %d, '%s')",
480                                               s->fd, s->config);
481               goto done;
482             }
483           st.st_mode |= S_IWGRP;
484           if (chmod (((struct sockaddr_un *) &addr)->sun_path, st.st_mode) <
485               0)
486             {
487               error =
488                 clib_error_return_unix (0, "chmod (fd %d, '%s', mode %o)",
489                                         s->fd, s->config, st.st_mode);
490               goto done;
491             }
492         }
493     }
494   else
495     {
496       if ((s->flags & CLIB_SOCKET_F_NON_BLOCKING_CONNECT)
497           && fcntl (s->fd, F_SETFL, O_NONBLOCK) < 0)
498         {
499           error = clib_error_return_unix (0, "fcntl NONBLOCK (fd %d, '%s')",
500                                           s->fd, s->config);
501           goto done;
502         }
503
504       if (connect (s->fd, &addr.sa, addr_len) < 0
505           && !((s->flags & CLIB_SOCKET_F_NON_BLOCKING_CONNECT) &&
506                errno == EINPROGRESS))
507         {
508           error = clib_error_return_unix (0, "connect (fd %d, '%s')",
509                                           s->fd, s->config);
510           goto done;
511         }
512     }
513
514   return error;
515
516 done:
517   if (s->fd > 0)
518     close (s->fd);
519   return error;
520 }
521
522 clib_error_t *
523 clib_socket_accept (clib_socket_t * server, clib_socket_t * client)
524 {
525   clib_error_t *err = 0;
526   socklen_t len = 0;
527
528   memset (client, 0, sizeof (client[0]));
529
530   /* Accept the new socket connection. */
531   client->fd = accept (server->fd, 0, 0);
532   if (client->fd < 0)
533     return clib_error_return_unix (0, "accept (fd %d, '%s')",
534                                    server->fd, server->config);
535
536   /* Set the new socket to be non-blocking. */
537   if (fcntl (client->fd, F_SETFL, O_NONBLOCK) < 0)
538     {
539       err = clib_error_return_unix (0, "fcntl O_NONBLOCK (fd %d)",
540                                     client->fd);
541       goto close_client;
542     }
543
544   /* Get peer info. */
545   len = sizeof (client->peer);
546   if (getpeername (client->fd, (struct sockaddr *) &client->peer, &len) < 0)
547     {
548       err = clib_error_return_unix (0, "getpeername (fd %d)", client->fd);
549       goto close_client;
550     }
551
552   client->flags = CLIB_SOCKET_F_IS_CLIENT;
553
554   socket_init_funcs (client);
555   return 0;
556
557 close_client:
558   close (client->fd);
559   return err;
560 }
561
562 /*
563  * fd.io coding-style-patch-verification: ON
564  *
565  * Local Variables:
566  * eval: (c-set-style "gnu")
567  * End:
568  */