initial commit
[govpp.git] / vendor / golang.org / x / sys / unix / creds_test.go
1 // Copyright 2012 The Go Authors.  All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 // +build linux
6
7 package unix_test
8
9 import (
10         "bytes"
11         "net"
12         "os"
13         "syscall"
14         "testing"
15
16         "golang.org/x/sys/unix"
17 )
18
19 // TestSCMCredentials tests the sending and receiving of credentials
20 // (PID, UID, GID) in an ancillary message between two UNIX
21 // sockets. The SO_PASSCRED socket option is enabled on the sending
22 // socket for this to work.
23 func TestSCMCredentials(t *testing.T) {
24         fds, err := unix.Socketpair(unix.AF_LOCAL, unix.SOCK_STREAM, 0)
25         if err != nil {
26                 t.Fatalf("Socketpair: %v", err)
27         }
28         defer unix.Close(fds[0])
29         defer unix.Close(fds[1])
30
31         err = unix.SetsockoptInt(fds[0], unix.SOL_SOCKET, unix.SO_PASSCRED, 1)
32         if err != nil {
33                 t.Fatalf("SetsockoptInt: %v", err)
34         }
35
36         srvFile := os.NewFile(uintptr(fds[0]), "server")
37         defer srvFile.Close()
38         srv, err := net.FileConn(srvFile)
39         if err != nil {
40                 t.Errorf("FileConn: %v", err)
41                 return
42         }
43         defer srv.Close()
44
45         cliFile := os.NewFile(uintptr(fds[1]), "client")
46         defer cliFile.Close()
47         cli, err := net.FileConn(cliFile)
48         if err != nil {
49                 t.Errorf("FileConn: %v", err)
50                 return
51         }
52         defer cli.Close()
53
54         var ucred unix.Ucred
55         if os.Getuid() != 0 {
56                 ucred.Pid = int32(os.Getpid())
57                 ucred.Uid = 0
58                 ucred.Gid = 0
59                 oob := unix.UnixCredentials(&ucred)
60                 _, _, err := cli.(*net.UnixConn).WriteMsgUnix(nil, oob, nil)
61                 if op, ok := err.(*net.OpError); ok {
62                         err = op.Err
63                 }
64                 if sys, ok := err.(*os.SyscallError); ok {
65                         err = sys.Err
66                 }
67                 if err != syscall.EPERM {
68                         t.Fatalf("WriteMsgUnix failed with %v, want EPERM", err)
69                 }
70         }
71
72         ucred.Pid = int32(os.Getpid())
73         ucred.Uid = uint32(os.Getuid())
74         ucred.Gid = uint32(os.Getgid())
75         oob := unix.UnixCredentials(&ucred)
76
77         // this is going to send a dummy byte
78         n, oobn, err := cli.(*net.UnixConn).WriteMsgUnix(nil, oob, nil)
79         if err != nil {
80                 t.Fatalf("WriteMsgUnix: %v", err)
81         }
82         if n != 0 {
83                 t.Fatalf("WriteMsgUnix n = %d, want 0", n)
84         }
85         if oobn != len(oob) {
86                 t.Fatalf("WriteMsgUnix oobn = %d, want %d", oobn, len(oob))
87         }
88
89         oob2 := make([]byte, 10*len(oob))
90         n, oobn2, flags, _, err := srv.(*net.UnixConn).ReadMsgUnix(nil, oob2)
91         if err != nil {
92                 t.Fatalf("ReadMsgUnix: %v", err)
93         }
94         if flags != 0 {
95                 t.Fatalf("ReadMsgUnix flags = 0x%x, want 0", flags)
96         }
97         if n != 1 {
98                 t.Fatalf("ReadMsgUnix n = %d, want 1 (dummy byte)", n)
99         }
100         if oobn2 != oobn {
101                 // without SO_PASSCRED set on the socket, ReadMsgUnix will
102                 // return zero oob bytes
103                 t.Fatalf("ReadMsgUnix oobn = %d, want %d", oobn2, oobn)
104         }
105         oob2 = oob2[:oobn2]
106         if !bytes.Equal(oob, oob2) {
107                 t.Fatal("ReadMsgUnix oob bytes don't match")
108         }
109
110         scm, err := unix.ParseSocketControlMessage(oob2)
111         if err != nil {
112                 t.Fatalf("ParseSocketControlMessage: %v", err)
113         }
114         newUcred, err := unix.ParseUnixCredentials(&scm[0])
115         if err != nil {
116                 t.Fatalf("ParseUnixCredentials: %v", err)
117         }
118         if *newUcred != ucred {
119                 t.Fatalf("ParseUnixCredentials = %+v, want %+v", newUcred, ucred)
120         }
121 }