1{-# LANGUAGE CPP #-}
2
3#include "HsNet.h"
4##include "HsNetDef.h"
5
6module Network.Socket.Unix (
7    isUnixDomainSocketAvailable
8  , socketPair
9  , sendFd
10  , recvFd
11  , getPeerCredential
12  , getPeerCred
13  , getPeerEid
14  ) where
15
16import System.Posix.Types (Fd(..))
17
18import Network.Socket.Buffer
19import Network.Socket.Imports
20#if !defined(mingw32_HOST_OS)
21import Network.Socket.Posix.Cmsg
22#endif
23import Network.Socket.Types
24
25#if defined(HAVE_GETPEEREID)
26import System.IO.Error (catchIOError)
27#endif
28#ifdef HAVE_GETPEEREID
29import Foreign.Marshal.Alloc (alloca)
30#endif
31#ifdef DOMAIN_SOCKET_SUPPORT
32import Foreign.Marshal.Alloc (allocaBytes)
33import Foreign.Marshal.Array (peekArray)
34
35import Network.Socket.Fcntl
36import Network.Socket.Internal
37#endif
38#ifdef HAVE_STRUCT_UCRED_SO_PEERCRED
39import Network.Socket.Options
40#endif
41
-- | Look up the process ID, user ID and group ID of the peer of a
--   UNIX-domain socket. Each component is 'Nothing' when it could not be
--   obtained.
--
--   The implementation is chosen at build time:
--
--   * With SO_PEERCRED (Linux) all three values come from 'getPeerCred';
--     if the reported user ID equals 'maxBound' every component is
--     'Nothing'.
--
--   * With getpeereid() (BSD variants) the process ID is always 'Nothing'
--     and the other two come from 'getPeerEid'; if that call raises an
--     'IOError' every component is 'Nothing'.
--
--   * Otherwise every component is 'Nothing'.
--
--   Caveats: on some BSD variants getpeereid() yields unexpected results,
--   rather than an error, for AF_INET sockets, so the caller must make
--   sure the socket really is a UNIX-domain socket. Also, on some BSD
--   variants getpeereid() gives no credentials for sockets made by
--   'socketPair'; only separately created and explicitly connected
--   UNIX-domain sockets work there.
--
--   Since 2.7.0.0.
getPeerCredential :: Socket -> IO (Maybe CUInt, Maybe CUInt, Maybe CUInt)
#ifdef HAVE_STRUCT_UCRED_SO_PEERCRED
getPeerCredential sock = getPeerCred sock >>= wrap
  where
    -- A user ID of maxBound is treated as "no credentials available".
    wrap (pid, uid, gid)
      | uid == maxBound = return (Nothing, Nothing, Nothing)
      | otherwise       = return (Just pid, Just uid, Just gid)
#elif defined(HAVE_GETPEEREID)
getPeerCredential sock =
    catchIOError viaPeerEid (const (return (Nothing,Nothing,Nothing)))
  where
    -- getpeereid() reports no process ID, so that slot stays empty.
    viaPeerEid =
        getPeerEid sock >>= \(uid, gid) ->
            return (Nothing, Just uid, Just gid)
#else
getPeerCredential = const (return (Nothing, Nothing, Nothing))
#endif
72
-- | Fetch the (process ID, user ID, group ID) triple of the peer of
--   a UNIX-domain socket.
--
-- Only meaningful on platforms that support SO_PEERCRED; on every other
-- platform this returns @(0, 0, 0)@ without touching the socket.
getPeerCred :: Socket -> IO (CUInt, CUInt, CUInt)
#ifdef HAVE_STRUCT_UCRED_SO_PEERCRED
getPeerCred s =
    getSockOpt s peerCredOpt >>= \(PeerCred triple) -> return triple
  where
    -- The SO_PEERCRED option at the SOL_SOCKET level.
    peerCredOpt = SockOpt (#const SOL_SOCKET) (#const SO_PEERCRED)

-- | Read-only view of the C @struct ucred@ as a (pid, uid, gid) triple.
newtype PeerCred = PeerCred (CUInt, CUInt, CUInt)
instance Storable PeerCred where
    -- Decode the three fields of the C structure.
    peek ptr = do
        pid <- (#peek struct ucred, pid) ptr
        uid <- (#peek struct ucred, uid) ptr
        gid <- (#peek struct ucred, gid) ptr
        return (PeerCred (pid, uid, gid))
    -- Writing is deliberately a no-op: the value is only ever read back.
    poke _ _ = return ()
    sizeOf    _ = (#const sizeof(struct ucred))
    alignment _ = alignment (0 :: CInt)
#else
getPeerCred = const (return (0, 0, 0))
#endif
{-# DEPRECATED getPeerCred "Use getPeerCredential instead" #-}
98
-- | Fetch the (user ID, group ID) pair of the peer of
--   a UNIX-domain socket.
--
--  Only meaningful on platforms that provide getpeereid(); on every other
--  platform this returns @(0, 0)@ without touching the socket.
--  A failing getpeereid() call is reported through
--  @throwSocketErrorIfMinus1Retry_@ (presumably as an 'IOError', which is
--  what 'getPeerCredential' catches).
getPeerEid :: Socket -> IO (CUInt, CUInt)
#ifdef HAVE_GETPEEREID
getPeerEid s =
  alloca $ \uidPtr ->
  alloca $ \gidPtr -> do
    -- Let the C call fill in both out-parameters.
    withFdSocket s $ \fd ->
      throwSocketErrorIfMinus1Retry_ "Network.Socket.getPeerEid" $
        c_getpeereid fd uidPtr gidPtr
    peerUid <- peek uidPtr
    peerGid <- peek gidPtr
    return (peerUid, peerGid)

foreign import CALLCONV unsafe "getpeereid"
  c_getpeereid :: CInt -> Ptr CUInt -> Ptr CUInt -> IO CInt
#else
getPeerEid = const (return (0, 0))
#endif

{-# DEPRECATED getPeerEid "Use getPeerCredential instead" #-}
122
-- | 'True' when this build was configured with UNIX-domain socket
--   support (the DOMAIN_SOCKET_SUPPORT macro is defined), 'False'
--   otherwise.
--
--   Since 2.7.0.0.
isUnixDomainSocketAvailable :: Bool
isUnixDomainSocketAvailable =
#if defined(DOMAIN_SOCKET_SUPPORT)
    True
#else
    False
#endif
132
-- | Zero-sized stand-in address. 'sendFd' and 'recvFd' use it as the
--   address argument of their message calls: it occupies no bytes,
--   writes nothing when poked and always peeks back as 'NullSockAddr'.
data NullSockAddr = NullSockAddr

instance SocketAddress NullSockAddr where
    sizeOfSocketAddress = const 0
    peekSocketAddress   = const (return NullSockAddr)
    pokeSocketAddress   = \_ _ -> return ()
139
-- | Pass a file descriptor to the peer of a UNIX-domain socket.
--
--   The descriptor travels as a control message built with 'encodeCmsg';
--   it is accompanied by a one-byte dummy payload whose contents are never
--   written to before sending. The result of the underlying send is
--   discarded.
--
--   Only call this when 'isUnixDomainSocketAvailable' is 'True'; without
--   UNIX-domain socket support it calls 'error'.
sendFd :: Socket -> CInt -> IO ()
#if defined(DOMAIN_SOCKET_SUPPORT)
sendFd s outfd =
    allocaBytes payloadLen $ \payload ->
        void $ sendBufMsg s NullSockAddr [(payload, payloadLen)] [fdCmsg] mempty
  where
    payloadLen = 1
    fdCmsg     = encodeCmsg (Fd outfd)
#else
sendFd _ _ = error "Network.Socket.sendFd"
#endif
153
-- | Accept a file descriptor sent by the peer of a UNIX-domain socket.
--
--   One message is received into a 16-byte scratch buffer and its control
--   messages are searched for a @CmsgIdFd@ entry. The decoded descriptor
--   is returned; if no such entry is present or it cannot be decoded, the
--   result is @-1@.
--
--   The returned descriptor may have to be switched to non-blocking mode
--   before it can be used safely, see 'setNonBlockIfNeeded'.
--
--   Only call this when 'isUnixDomainSocketAvailable' is 'True'; without
--   UNIX-domain socket support it calls 'error'.
recvFd :: Socket -> IO CInt
#if defined(DOMAIN_SOCKET_SUPPORT)
recvFd s = allocaBytes scratchLen $ \scratch -> do
    (NullSockAddr, _, cmsgs, _) <- recvBufMsg s [(scratch, scratchLen)] 32 mempty
    maybe (return (-1)) unwrapFd (lookupCmsg CmsgIdFd cmsgs >>= decodeCmsg)
  where
    scratchLen = 16
    -- Strip the 'Fd' wrapper to expose the raw descriptor number.
    unwrapFd (Fd fd) = return fd
#else
recvFd _ = error "Network.Socket.recvFd"
#endif
171
-- | Create two unnamed sockets that are already connected to each other,
--   by calling the C function socketpair().
--
--   Both descriptors are passed through 'setNonBlockIfNeeded' before they
--   are wrapped with 'mkSocket'.
--
--   For portability, only call this when 'isUnixDomainSocketAvailable'
--   is 'True', and pass 'AF_UNIX' as the first argument; without
--   UNIX-domain socket support it calls 'error'.
socketPair :: Family              -- Family Name (usually AF_UNIX)
           -> SocketType          -- Socket Type (usually Stream)
           -> ProtocolNumber      -- Protocol Number
           -> IO (Socket, Socket) -- unnamed and connected.
#if defined(DOMAIN_SOCKET_SUPPORT)
socketPair family stype protocol =
    allocaBytes fdArrBytes $ \fdArr -> do
      _rc <- throwSocketErrorIfMinus1Retry "Network.Socket.socketpair" $
               c_socketpair (packFamily family) (packSocketType stype) protocol fdArr
      -- socketpair() stored the two descriptors in the array.
      [fdA, fdB] <- peekArray 2 fdArr
      mapM_ setNonBlockIfNeeded [fdA, fdB]
      sockA <- mkSocket fdA
      sockB <- mkSocket fdB
      return (sockA, sockB)
  where
    -- Room for exactly two C ints.
    fdArrBytes = 2 * sizeOf (1 :: CInt)

foreign import ccall unsafe "socketpair"
  c_socketpair :: CInt -> CInt -> CInt -> Ptr CInt -> IO CInt
#else
socketPair _ _ _ = error "Network.Socket.socketPair"
#endif
198