1{-# LANGUAGE CPP #-}
2
3#include "HsNet.h"
4##include "HsNetDef.h"
5
6module Network.Socket.Unix (
7    isUnixDomainSocketAvailable
8  , socketPair
9  , sendFd
10  , recvFd
11  , getPeerCredential
12  , getPeerCred
13  , getPeerEid
14  ) where
15
16import System.Posix.Types (Fd(..))
17
18import Network.Socket.Buffer
19import Network.Socket.Imports
20import Network.Socket.Posix.Cmsg
21import Network.Socket.Types
22
23#if defined(HAVE_GETPEEREID)
24import System.IO.Error (catchIOError)
25#endif
26#ifdef HAVE_GETPEEREID
27import Foreign.Marshal.Alloc (alloca)
28#endif
29#ifdef DOMAIN_SOCKET_SUPPORT
30import Foreign.Marshal.Alloc (allocaBytes)
31import Foreign.Marshal.Array (peekArray)
32
33import Network.Socket.Fcntl
34import Network.Socket.Internal
35#endif
36#ifdef HAVE_STRUCT_UCRED_SO_PEERCRED
37import Network.Socket.Options
38#endif
39
-- | Get the process ID, user ID and group ID of the peer of a
--   UNIX-domain socket.
--
--   This is implemented with SO_PEERCRED on Linux and getpeereid()
--   on BSD variants.
--
--   * With SO_PEERCRED, a user ID equal to 'maxBound' is treated as
--     "no credentials" and all three components are 'Nothing'.
--   * With getpeereid() the process ID is always 'Nothing', and any
--     'IOError' raised by the lookup yields all three as 'Nothing'.
--   * On platforms with neither mechanism all three are 'Nothing'.
--
--   Unfortunately, on some BSD variants getpeereid() returns unexpected
--   results, rather than an error, for AF_INET sockets. It is the user's
--   responsibility to make sure that the socket is a UNIX-domain socket.
--   Also, on some BSD variants, getpeereid() does not return credentials
--   for sockets created via 'socketPair'; only separately created and
--   then explicitly connected UNIX-domain sockets work on such systems.
--
--   @since 2.7.0.0
getPeerCredential :: Socket -> IO (Maybe CUInt, Maybe CUInt, Maybe CUInt)
#ifdef HAVE_STRUCT_UCRED_SO_PEERCRED
getPeerCredential sock = do
    (pid, uid, gid) <- getPeerCred sock
    -- Force the decision before returning, as a plain if/else would.
    return $! classify pid uid gid
  where
    classify p u g
      | u == maxBound = (Nothing, Nothing, Nothing)
      | otherwise     = (Just p, Just u, Just g)
#elif defined(HAVE_GETPEEREID)
getPeerCredential sock = catchIOError lookupEid noCredentials
  where
    lookupEid = do
        (uid, gid) <- getPeerEid sock
        return (Nothing, Just uid, Just gid)
    -- Any IOError from getpeereid() is reported as "no credentials".
    noCredentials _ = return (Nothing, Nothing, Nothing)
#else
getPeerCredential _ = return (Nothing, Nothing, Nothing)
#endif
70
-- | Returns the process ID, user ID and group ID of the peer of
--   a UNIX-domain socket, read through the SO_PEERCRED socket option.
--
-- Only available on platforms that support SO_PEERCRED; on other
-- platforms this returns @(0, 0, 0)@.
getPeerCred :: Socket -> IO (CUInt, CUInt, CUInt)
#ifdef HAVE_STRUCT_UCRED_SO_PEERCRED
getPeerCred s = unPeerCred <$> getSockOpt s peerCredOption
  where
    peerCredOption = SockOpt (#const SOL_SOCKET) (#const SO_PEERCRED)
    unPeerCred (PeerCred cred) = cred

-- Newtype over the (pid, uid, gid) triple giving @struct ucred@ a
-- 'Storable' instance so it can be decoded by 'getSockOpt'.
-- Only 'peek' does real work; 'poke' is a no-op.
newtype PeerCred = PeerCred (CUInt, CUInt, CUInt)

instance Storable PeerCred where
    sizeOf    _ = (#const sizeof(struct ucred))
    alignment _ = alignment (0 :: CInt)
    poke _ _    = return ()
    peek p      = PeerCred <$> ((,,) <$> (#peek struct ucred, pid) p
                                     <*> (#peek struct ucred, uid) p
                                     <*> (#peek struct ucred, gid) p)
#else
getPeerCred _ = return (0, 0, 0)
#endif
{-# DEPRECATED getPeerCred "Use getPeerCredential instead" #-}
96
-- | Returns the user ID and group ID of the peer of a UNIX-domain
--   socket, as reported by getpeereid().
--
--  Only available on platforms that support getpeereid(); on other
--  platforms this returns @(0, 0)@. A failing call (other than one
--  that is retried) is raised as an exception.
getPeerEid :: Socket -> IO (CUInt, CUInt)
#ifdef HAVE_GETPEEREID
getPeerEid s =
    alloca $ \uidPtr ->
    alloca $ \gidPtr -> do
        withFdSocket s $ \fd ->
            throwSocketErrorIfMinus1Retry_ "Network.Socket.getPeerEid" $
                c_getpeereid fd uidPtr gidPtr
        (,) <$> peek uidPtr <*> peek gidPtr

foreign import CALLCONV unsafe "getpeereid"
  c_getpeereid :: CInt -> Ptr CUInt -> Ptr CUInt -> IO CInt
#else
getPeerEid _ = return (0, 0)
#endif

{-# DEPRECATED getPeerEid "Use getPeerCredential instead" #-}
120
-- | Whether or not UNIX-domain sockets are available on this platform.
--   The answer is fixed at build time by the DOMAIN_SOCKET_SUPPORT macro.
--
--   @since 2.7.0.0
isUnixDomainSocketAvailable :: Bool
isUnixDomainSocketAvailable =
#if defined(DOMAIN_SOCKET_SUPPORT)
    True
#else
    False
#endif
130
-- Placeholder address passed to 'sendBufMsg' / 'recvBufMsg' when no
-- real peer address is wanted: it occupies zero bytes, reading it always
-- yields 'NullSockAddr', and writing it does nothing.
data NullSockAddr = NullSockAddr

instance SocketAddress NullSockAddr where
    sizeOfSocketAddress = const 0
    peekSocketAddress   = const (return NullSockAddr)
    pokeSocketAddress _ = const (return ())
137
-- | Send a file descriptor over a UNIX-domain socket.
--   The descriptor is carried as a control message alongside a single
--   payload byte. That byte is zeroed before sending, so no uninitialised
--   memory is transmitted to the peer.
--   Use this function in the case where 'isUnixDomainSocketAvailable' is
--   'True'.
sendFd :: Socket -> CInt -> IO ()
#if defined(DOMAIN_SOCKET_SUPPORT)
sendFd s outfd = void $ allocaBytes dummyBufSize $ \buf -> do
    -- allocaBytes does not initialise the memory it hands out; write a
    -- defined value so the payload byte cannot leak stale process memory.
    poke buf 0
    let cmsg = encodeCmsg $ Fd outfd
    sendBufMsg s NullSockAddr [(buf, dummyBufSize)] [cmsg] mempty
  where
    -- One byte of ordinary data accompanies the control message
    -- (presumably because ancillary data is not reliably delivered
    -- without any regular payload on all platforms).
    dummyBufSize = 1
#else
sendFd _ _ = error "Network.Socket.sendFd"
#endif
151
-- | Receive a file descriptor over a UNIX-domain socket.
--   Returns @-1@ when the received message carries no file-descriptor
--   control message. Note that the resulting file descriptor may have to
--   be put into non-blocking mode in order to be used safely. See
--   'setNonBlockIfNeeded'.
--   Use this function in the case where 'isUnixDomainSocketAvailable' is
--   'True'.
recvFd :: Socket -> IO CInt
#if defined(DOMAIN_SOCKET_SUPPORT)
recvFd s = allocaBytes dummyBufSize $ \buf -> do
    (NullSockAddr, _, cmsgs, _) <- recvBufMsg s [(buf, dummyBufSize)] 32 mempty
    let received = lookupCmsg CmsgIdFd cmsgs >>= decodeCmsg :: Maybe Fd
    case received of
      Just (Fd fd) -> return fd
      Nothing      -> return (-1)
  where
    dummyBufSize = 16
#else
recvFd _ = error "Network.Socket.recvFd"
#endif
169
-- | Build a pair of connected socket objects.
--   For portability, use this function in the case
--   where 'isUnixDomainSocketAvailable' is 'True'
--   and specify 'AF_UNIX' to the first argument.
--   Both descriptors are passed through 'setNonBlockIfNeeded' before
--   being wrapped as 'Socket's.
socketPair :: Family              -- ^ Family name (usually 'AF_UNIX')
           -> SocketType          -- ^ Socket type (usually 'Stream')
           -> ProtocolNumber      -- ^ Protocol number
           -> IO (Socket, Socket) -- ^ Unnamed and connected.
#if defined(DOMAIN_SOCKET_SUPPORT)
socketPair family stype protocol =
    allocaBytes (2 * sizeOf (1 :: CInt)) $ \fds -> do
      throwSocketErrorIfMinus1Retry_ "Network.Socket.socketpair" $
          c_socketpair (packFamily family) (packSocketType stype) protocol fds
      [fd1, fd2] <- peekArray 2 fds
      mapM_ setNonBlockIfNeeded [fd1, fd2]
      (,) <$> mkSocket fd1 <*> mkSocket fd2

foreign import ccall unsafe "socketpair"
  c_socketpair :: CInt -> CInt -> CInt -> Ptr CInt -> IO CInt
#else
socketPair _ _ _ = error "Network.Socket.socketPair"
#endif
196