clib_socket: add sendmsg / recvmsg with ancillary data support 85/8385/3
authorDamjan Marion <damarion@cisco.com>
Mon, 11 Sep 2017 14:52:11 +0000 (16:52 +0200)
committerDave Barach <openvpp@barachs.net>
Thu, 14 Sep 2017 11:31:05 +0000 (11:31 +0000)
Change-Id: Ie18580e05ec12291e7026f21ad874e088a712c8e
Signed-off-by: Damjan Marion <damarion@cisco.com>
src/vlib/unix/cli.c
src/vpp/app/vppctl.c
src/vppinfra/socket.c
src/vppinfra/socket.h
src/vppinfra/test_socket.c

index 3936882..1567cc2 100644 (file)
@@ -2664,8 +2664,8 @@ unix_cli_config (vlib_main_t * vm, unformat_input_t * input)
          vec_free (tmp);
        }
 
-      s->flags = SOCKET_IS_SERVER |    /* listen, don't connect */
-       SOCKET_ALLOW_GROUP_WRITE;       /* PF_LOCAL socket only */
+      s->flags = CLIB_SOCKET_F_IS_SERVER |     /* listen, don't connect */
+       CLIB_SOCKET_F_ALLOW_GROUP_WRITE;        /* PF_LOCAL socket only */
       error = clib_socket_init (s);
 
       if (error)
index a8f3eab..980936f 100644 (file)
@@ -158,7 +158,7 @@ main (int argc, char *argv[])
   while (argc--)
     cmd = format (cmd, "%s%c", (argv++)[0], argc ? ' ' : 0);
 
-  s->flags = SOCKET_IS_CLIENT;
+  s->flags = CLIB_SOCKET_F_IS_CLIENT;
 
   error = clib_socket_init (s);
   if (error)
index 37dcbbf..87a9333 100644 (file)
   WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */
 
-#include <sys/un.h>
+#include <stdio.h>
+#include <string.h>            /* strchr */
+#define __USE_GNU
 #include <sys/types.h>
 #include <sys/socket.h>
+#include <sys/un.h>
 #include <sys/stat.h>
 #include <netinet/in.h>
 #include <arpa/inet.h>
 #include <netdb.h>
 #include <unistd.h>
-#include <stdio.h>
 #include <fcntl.h>
-#include <string.h>            /* strchr */
 
 #include <vppinfra/mem.h>
 #include <vppinfra/vec.h>
@@ -233,7 +234,7 @@ default_socket_read (clib_socket_t * sock, int n_bytes)
   u8 *buf;
 
   /* RX side of socket is down once end of file is reached. */
-  if (sock->flags & SOCKET_RX_END_OF_FILE)
+  if (sock->flags & CLIB_SOCKET_F_RX_END_OF_FILE)
     return 0;
 
   fd = sock->fd;
@@ -255,7 +256,7 @@ default_socket_read (clib_socket_t * sock, int n_bytes)
 
   /* Other side closed the socket. */
   if (n_read == 0)
-    sock->flags |= SOCKET_RX_END_OF_FILE;
+    sock->flags |= CLIB_SOCKET_F_RX_END_OF_FILE;
 
 non_fatal:
   _vec_len (sock->rx_buffer) += n_read - n_bytes;
@@ -271,6 +272,91 @@ default_socket_close (clib_socket_t * s)
   return 0;
 }
 
+static clib_error_t *
+default_socket_sendmsg (clib_socket_t * s, void *msg, int msglen,
+                       int fds[], int num_fds)
+{
+  struct msghdr mh = { 0 };
+  struct iovec iov[1];
+  char ctl[CMSG_SPACE (sizeof (int)) * num_fds];
+  int rv;
+
+  iov[0].iov_base = msg;
+  iov[0].iov_len = msglen;
+  mh.msg_iov = iov;
+  mh.msg_iovlen = 1;
+
+  if (num_fds > 0)
+    {
+      struct cmsghdr *cmsg;
+      memset (&ctl, 0, sizeof (ctl));
+      mh.msg_control = ctl;
+      mh.msg_controllen = sizeof (ctl);
+      cmsg = CMSG_FIRSTHDR (&mh);
+      cmsg->cmsg_len = CMSG_LEN (sizeof (int) * num_fds);
+      cmsg->cmsg_level = SOL_SOCKET;
+      cmsg->cmsg_type = SCM_RIGHTS;
+      memcpy (CMSG_DATA (cmsg), fds, sizeof (int) * num_fds);
+    }
+  rv = sendmsg (s->fd, &mh, 0);
+  if (rv < 0)
+    return clib_error_return_unix (0, "sendmsg");
+  return 0;
+}
+
+
+static clib_error_t *
+default_socket_recvmsg (clib_socket_t * s, void *msg, int msglen,
+                       int fds[], int num_fds)
+{
+  char ctl[CMSG_SPACE (sizeof (int) * num_fds) +
+          CMSG_SPACE (sizeof (struct ucred))];
+  struct msghdr mh = { 0 };
+  struct iovec iov[1];
+  ssize_t size;
+  struct ucred *cr = 0;
+  struct cmsghdr *cmsg;
+
+  iov[0].iov_base = msg;
+  iov[0].iov_len = msglen;
+  mh.msg_iov = iov;
+  mh.msg_iovlen = 1;
+  mh.msg_control = ctl;
+  mh.msg_controllen = sizeof (ctl);
+
+  memset (ctl, 0, sizeof (ctl));
+
+  /* receive the incoming message */
+  size = recvmsg (s->fd, &mh, 0);
+  if (size != msglen)
+    {
+      return (size == 0) ? clib_error_return (0, "disconnected") :
+       clib_error_return_unix (0, "recvmsg: malformed message (fd %d, '%s')",
+                               s->fd, s->config);
+    }
+
+  cmsg = CMSG_FIRSTHDR (&mh);
+  while (cmsg)
+    {
+      if (cmsg->cmsg_level == SOL_SOCKET)
+       {
+         if (cmsg->cmsg_type == SCM_CREDENTIALS)
+           {
+             cr = (struct ucred *) CMSG_DATA (cmsg);
+             s->uid = cr->uid;
+             s->gid = cr->gid;
+             s->pid = cr->pid;
+           }
+         else if (cmsg->cmsg_type == SCM_RIGHTS)
+           {
+             clib_memcpy (fds, CMSG_DATA (cmsg), num_fds * sizeof (int));
+           }
+       }
+      cmsg = CMSG_NXTHDR (&mh, cmsg);
+    }
+  return 0;
+}
+
 static void
 socket_init_funcs (clib_socket_t * s)
 {
@@ -280,6 +366,10 @@ socket_init_funcs (clib_socket_t * s)
     s->read_func = default_socket_read;
   if (!s->close_func)
     s->close_func = default_socket_close;
+  if (!s->sendmsg_func)
+    s->sendmsg_func = default_socket_sendmsg;
+  if (!s->recvmsg_func)
+    s->recvmsg_func = default_socket_recvmsg;
 }
 
 clib_error_t *
@@ -291,18 +381,22 @@ clib_socket_init (clib_socket_t * s)
     struct sockaddr_un su;
   } addr;
   socklen_t addr_len = 0;
+  int socket_type;
   clib_error_t *error = 0;
   word port;
 
   error = socket_config (s->config, &addr.sa, &addr_len,
-                        (s->flags & SOCKET_IS_SERVER
+                        (s->flags & CLIB_SOCKET_F_IS_SERVER
                          ? INADDR_LOOPBACK : INADDR_ANY));
   if (error)
     goto done;
 
   socket_init_funcs (s);
 
-  s->fd = socket (addr.sa.sa_family, SOCK_STREAM, 0);
+  socket_type = s->flags & CLIB_SOCKET_F_SEQPACKET ?
+    SOCK_SEQPACKET : SOCK_STREAM;
+
+  s->fd = socket (addr.sa.sa_family, socket_type, 0);
   if (s->fd < 0)
     {
       error = clib_error_return_unix (0, "socket (fd %d, '%s')",
@@ -314,7 +408,7 @@ clib_socket_init (clib_socket_t * s)
   if (addr.sa.sa_family == PF_INET)
     port = ((struct sockaddr_in *) &addr)->sin_port;
 
-  if (s->flags & SOCKET_IS_SERVER)
+  if (s->flags & CLIB_SOCKET_F_IS_SERVER)
     {
       uword need_bind = 1;
 
@@ -342,6 +436,18 @@ clib_socket_init (clib_socket_t * s)
          clib_unix_warning ("setsockopt SO_REUSEADDR fails");
       }
 
+      if (addr.sa.sa_family == PF_LOCAL && s->flags & CLIB_SOCKET_F_PASSCRED)
+       {
+         int x = 1;
+         if (setsockopt (s->fd, SOL_SOCKET, SO_PASSCRED, &x, sizeof (x)) < 0)
+           {
+             error = clib_error_return_unix (0, "setsockopt (SO_PASSCRED, "
+                                             "fd %d, '%s')", s->fd,
+                                             s->config);
+             goto done;
+           }
+       }
+
       if (need_bind && bind (s->fd, &addr.sa, addr_len) < 0)
        {
          error = clib_error_return_unix (0, "bind (fd %d, '%s')",
@@ -356,7 +462,7 @@ clib_socket_init (clib_socket_t * s)
          goto done;
        }
       if (addr.sa.sa_family == PF_LOCAL
-         && s->flags & SOCKET_ALLOW_GROUP_WRITE)
+         && s->flags & CLIB_SOCKET_F_ALLOW_GROUP_WRITE)
        {
          struct stat st = { 0 };
          if (stat (((struct sockaddr_un *) &addr)->sun_path, &st) < 0)
@@ -378,7 +484,7 @@ clib_socket_init (clib_socket_t * s)
     }
   else
     {
-      if ((s->flags & SOCKET_NON_BLOCKING_CONNECT)
+      if ((s->flags & CLIB_SOCKET_F_NON_BLOCKING_CONNECT)
          && fcntl (s->fd, F_SETFL, O_NONBLOCK) < 0)
        {
          error = clib_error_return_unix (0, "fcntl NONBLOCK (fd %d, '%s')",
@@ -387,7 +493,7 @@ clib_socket_init (clib_socket_t * s)
        }
 
       if (connect (s->fd, &addr.sa, addr_len) < 0
-         && !((s->flags & SOCKET_NON_BLOCKING_CONNECT) &&
+         && !((s->flags & CLIB_SOCKET_F_NON_BLOCKING_CONNECT) &&
               errno == EINPROGRESS))
        {
          error = clib_error_return_unix (0, "connect (fd %d, '%s')",
@@ -434,7 +540,7 @@ clib_socket_accept (clib_socket_t * server, clib_socket_t * client)
       goto close_client;
     }
 
-  client->flags = SOCKET_IS_CLIENT;
+  client->flags = CLIB_SOCKET_F_IS_CLIENT;
 
   socket_init_funcs (client);
   return 0;
index 7503720..4f9e950 100644 (file)
@@ -55,13 +55,14 @@ typedef struct _socket_t
   char *config;
 
   u32 flags;
-#define SOCKET_IS_SERVER (1 << 0)
-#define SOCKET_IS_CLIENT (0 << 0)
-#define SOCKET_NON_BLOCKING_CONNECT (1 << 1)
-#define SOCKET_ALLOW_GROUP_WRITE (1 << 2)
+#define CLIB_SOCKET_F_IS_SERVER (1 << 0)
+#define CLIB_SOCKET_F_IS_CLIENT (0 << 0)
+#define CLIB_SOCKET_F_RX_END_OF_FILE (1 << 2)
+#define CLIB_SOCKET_F_NON_BLOCKING_CONNECT (1 << 3)
+#define CLIB_SOCKET_F_ALLOW_GROUP_WRITE (1 << 4)
+#define CLIB_SOCKET_F_SEQPACKET (1 << 5)
+#define CLIB_SOCKET_F_PASSCRED  (1 << 6)
 
-  /* Read returned end-of-file. */
-#define SOCKET_RX_END_OF_FILE (1 << 2)
 
   /* Transmit buffer.  Holds data waiting to be written. */
   u8 *tx_buffer;
@@ -72,10 +73,19 @@ typedef struct _socket_t
   /* Peer socket we are connected to. */
   struct sockaddr_in peer;
 
+  /* Credentials, populated if CLIB_SOCKET_F_PASSCRED is set */
+  pid_t pid;
+  uid_t uid;
+  gid_t gid;
+
   clib_error_t *(*write_func) (struct _socket_t * sock);
   clib_error_t *(*read_func) (struct _socket_t * sock, int min_bytes);
   clib_error_t *(*close_func) (struct _socket_t * sock);
-  void *private_data;
+  clib_error_t *(*recvmsg_func) (struct _socket_t * s, void *msg, int msglen,
+                                int fds[], int num_fds);
+  clib_error_t *(*sendmsg_func) (struct _socket_t * s, void *msg, int msglen,
+                                int fds[], int num_fds);
+  uword private_data;
 } clib_socket_t;
 
 /* socket config format is host:port.
@@ -89,7 +99,7 @@ clib_error_t *clib_socket_accept (clib_socket_t * server,
 always_inline uword
 clib_socket_is_server (clib_socket_t * sock)
 {
-  return (sock->flags & SOCKET_IS_SERVER) != 0;
+  return (sock->flags & CLIB_SOCKET_F_IS_SERVER) != 0;
 }
 
 always_inline uword
@@ -98,10 +108,17 @@ clib_socket_is_client (clib_socket_t * s)
   return !clib_socket_is_server (s);
 }
 
+always_inline uword
+clib_socket_is_connected (clib_socket_t * sock)
+{
+  return sock->fd > 0;
+}
+
+
 always_inline int
 clib_socket_rx_end_of_file (clib_socket_t * s)
 {
-  return s->flags & SOCKET_RX_END_OF_FILE;
+  return s->flags & CLIB_SOCKET_F_RX_END_OF_FILE;
 }
 
 always_inline void *
@@ -130,6 +147,20 @@ clib_socket_rx (clib_socket_t * s, int n_bytes)
   return s->read_func (s, n_bytes);
 }
 
+always_inline clib_error_t *
+clib_socket_sendmsg (clib_socket_t * s, void *msg, int msglen,
+                    int fds[], int num_fds)
+{
+  return s->sendmsg_func (s, msg, msglen, fds, num_fds);
+}
+
+always_inline clib_error_t *
+clib_socket_recvmsg (clib_socket_t * s, void *msg, int msglen,
+                    int fds[], int num_fds)
+{
+  return s->recvmsg_func (s, msg, msglen, fds, num_fds);
+}
+
 always_inline void
 clib_socket_free (clib_socket_t * s)
 {
index 0b05467..2f25ecc 100644 (file)
@@ -50,15 +50,15 @@ test_socket_main (unformat_input_t * input)
   clib_error_t *error;
 
   s->config = "localhost:22";
-  s->flags = SOCKET_IS_CLIENT;
+  s->flags = CLIB_SOCKET_F_IS_CLIENT;
 
   while (unformat_check_input (input) != UNFORMAT_END_OF_INPUT)
     {
       if (unformat (input, "server %s %=", &config,
-                   &s->flags, SOCKET_IS_SERVER))
+                   &s->flags, CLIB_SOCKET_F_IS_SERVER))
        ;
       else if (unformat (input, "client %s %=", &config,
-                        &s->flags, SOCKET_IS_CLIENT))
+                        &s->flags, CLIB_SOCKET_F_IS_CLIENT))
        ;
       else
        {