1{-# LANGUAGE CPP #-}
2
3#include "HsNet.h"
4##include "HsNetDef.h"
5
6module Network.Socket.Unix (
7    isUnixDomainSocketAvailable
8  , socketPair
9  , sendFd
10  , recvFd
11  , getPeerCredential
12  , getPeerCred
13  , getPeerEid
14  ) where
15
16import Network.Socket.Imports
17import Network.Socket.Types
18
19#if defined(HAVE_GETPEEREID)
20import System.IO.Error (catchIOError)
21#endif
22#ifdef HAVE_STRUCT_UCRED_SO_PEERCRED
23import Foreign.Marshal.Utils (with)
24#endif
25#ifdef HAVE_GETPEEREID
26import Foreign.Marshal.Alloc (alloca)
27#endif
28#ifdef DOMAIN_SOCKET_SUPPORT
29import Control.Monad (void)
30import Foreign.Marshal.Alloc (allocaBytes)
31import Foreign.Marshal.Array (peekArray)
32import Foreign.Ptr (Ptr)
33import Foreign.Storable (Storable(..))
34
35import Network.Socket.Fcntl
36import Network.Socket.Internal
37#endif
38#ifdef HAVE_STRUCT_UCRED_SO_PEERCRED
39import Network.Socket.Options (c_getsockopt)
40#endif
41
-- | Gets the process ID, user ID and group ID of the peer of a
--   UNIX-domain socket. Any component that cannot be determined is
--   reported as 'Nothing'.
--
--   On Linux this uses SO_PEERCRED; on BSD variants it uses
--   getpeereid(), which never yields a process ID. Be aware that some
--   BSD variants return unexpected results, rather than an error, when
--   getpeereid() is applied to an AF_INET socket, so it is up to the
--   caller to make sure the socket really is a UNIX-domain socket.
--   Some BSD variants also do not report credentials for sockets
--   created with 'socketPair'; there, only UNIX-domain sockets that were
--   created separately and then explicitly connected work.
--
--   @since 2.7.0.0
getPeerCredential :: Socket -> IO (Maybe CUInt, Maybe CUInt, Maybe CUInt)
#ifdef HAVE_STRUCT_UCRED_SO_PEERCRED
getPeerCredential sock = do
    cred <- getPeerCred sock
    case cred of
      -- A uid equal to 'maxBound' (i.e. (uid_t)-1) is taken to mean that
      -- no credentials are available, so every component is dropped.
      (_, uid, _) | uid == maxBound -> return (Nothing, Nothing, Nothing)
      (pid, uid, gid)               -> return (Just pid, Just uid, Just gid)
#elif defined(HAVE_GETPEEREID)
getPeerCredential sock = catchIOError fetch noCredentials
  where
    -- getpeereid() has no notion of a peer process ID.
    fetch = do
        (uid, gid) <- getPeerEid sock
        return (Nothing, Just uid, Just gid)
    -- Any IO error while querying is reported as "nothing known".
    noCredentials _ = return (Nothing, Nothing, Nothing)
#else
getPeerCredential _ = pure (Nothing, Nothing, Nothing)
#endif
72
-- | Returns the process ID, user ID and group ID of the peer of a
--   UNIX-domain socket, as read from the @struct ucred@ filled in by
--   @getsockopt(fd, SOL_SOCKET, SO_PEERCRED, ...)@.
--
--   Only meaningful on platforms that support SO_PEERCRED. On other
--   platforms this function still exists, but it does not look at the
--   socket and always returns @(0, 0, 0)@.
--
--   A socket error is raised if @getsockopt@ reports failure (returns -1).
getPeerCred :: Socket -> IO (CUInt, CUInt, CUInt)
#ifdef HAVE_STRUCT_UCRED_SO_PEERCRED
getPeerCred s =
  withFdSocket s $ \fd ->
    allocaBytes sz $ \ptr_cr ->
      -- getsockopt takes the buffer length by reference (in/out).
      with (fromIntegral sz) $ \ptr_sz -> do
        throwSocketErrorIfMinus1Retry_ "Network.Socket.getPeerCred" $
          c_getsockopt fd (#const SOL_SOCKET) (#const SO_PEERCRED) ptr_cr ptr_sz
        pid <- (#peek struct ucred, pid) ptr_cr
        uid <- (#peek struct ucred, uid) ptr_cr
        gid <- (#peek struct ucred, gid) ptr_cr
        return (pid, uid, gid)
  where
    -- Size in bytes of the C @struct ucred@ buffer.
    sz = (#const sizeof(struct ucred))
#else
getPeerCred _ = return (0, 0, 0)
#endif
{-# Deprecated getPeerCred "Use getPeerCredential instead" #-}
93
-- | Returns the user ID and group ID of the peer of a UNIX-domain
--   socket, as reported by @getpeereid()@.
--
--   Only meaningful on platforms that support @getpeereid()@. On other
--   platforms this function still exists, but it does not look at the
--   socket and always returns @(0, 0)@.
getPeerEid :: Socket -> IO (CUInt, CUInt)
#ifdef HAVE_GETPEEREID
getPeerEid s =
  alloca $ \uidPtr ->
  alloca $ \gidPtr -> do
    withFdSocket s $ \fd ->
      throwSocketErrorIfMinus1Retry_ "Network.Socket.getPeerEid" $
        c_getpeereid fd uidPtr gidPtr
    -- Both out-parameters are filled in by the successful call above.
    (,) <$> peek uidPtr <*> peek gidPtr

foreign import CALLCONV unsafe "getpeereid"
  c_getpeereid :: CInt -> Ptr CUInt -> Ptr CUInt -> IO CInt
#else
getPeerEid _ = return (0, 0)
#endif

{-# Deprecated getPeerEid "Use getPeerCredential instead" #-}
117
-- | 'True' when this build of the library supports UNIX-domain sockets
--   (i.e. it was compiled with @DOMAIN_SOCKET_SUPPORT@), 'False' otherwise.
--
--   @since 2.7.0.0
isUnixDomainSocketAvailable :: Bool
isUnixDomainSocketAvailable =
#if defined(DOMAIN_SOCKET_SUPPORT)
  True
#else
  False
#endif
127
-- | Sends the file descriptor @outfd@ over the UNIX-domain socket @s@.
--   Only call this when 'isUnixDomainSocketAvailable' is 'True'; in a
--   build without UNIX-domain socket support it calls 'error'.
sendFd :: Socket -> CInt -> IO ()
#if defined(DOMAIN_SOCKET_SUPPORT)
sendFd s outfd =
  withFdSocket s $ \fd ->
    -- The C helper's return value is only used for error checking.
    void $ throwSocketErrorWaitWrite s "Network.Socket.sendFd" (c_sendFd fd outfd)
foreign import ccall SAFE_ON_WIN "sendFd" c_sendFd :: CInt -> CInt -> IO CInt
#else
sendFd _ _ = error "Network.Socket.sendFd"
#endif
140
-- | Receives a file descriptor over the UNIX-domain socket @s@ and
--   returns it. The received descriptor may need to be switched to
--   non-blocking mode before it can be used safely; see
--   'setNonBlockIfNeeded'.
--   Only call this when 'isUnixDomainSocketAvailable' is 'True'; in a
--   build without UNIX-domain socket support it calls 'error'.
recvFd :: Socket -> IO CInt
#if defined(DOMAIN_SOCKET_SUPPORT)
recvFd s = withFdSocket s receive
  where
    receive fd =
      throwSocketErrorWaitRead s "Network.Socket.recvFd" (c_recvFd fd)
foreign import ccall SAFE_ON_WIN "recvFd" c_recvFd :: CInt -> IO CInt
#else
recvFd _ = error "Network.Socket.recvFd"
#endif
155
-- | Creates a pair of connected, unnamed sockets via @socketpair()@.
--   Both descriptors are passed through 'setNonBlockIfNeeded' before
--   being wrapped as 'Socket's.
--   For portability, only use this when 'isUnixDomainSocketAvailable'
--   is 'True', and pass 'AF_UNIX' as the family. In a build without
--   UNIX-domain socket support it calls 'error'.
socketPair :: Family              -- ^ Family name (usually 'AF_UNIX')
           -> SocketType          -- ^ Socket type (usually 'Stream')
           -> ProtocolNumber      -- ^ Protocol number
           -> IO (Socket, Socket) -- ^ The two connected sockets
#if defined(DOMAIN_SOCKET_SUPPORT)
socketPair family stype protocol = do
    sockType <- packSocketTypeOrThrow "socketPair" stype
    -- socketpair() writes its two descriptors into a C int[2].
    allocaBytes (2 * sizeOf (0 :: CInt)) $ \fds -> do
      throwSocketErrorIfMinus1Retry_ "Network.Socket.socketpair" $
        c_socketpair (packFamily family) sockType protocol fds
      [fd1, fd2] <- peekArray 2 fds
      mapM_ setNonBlockIfNeeded [fd1, fd2]
      (,) <$> mkSocket fd1 <*> mkSocket fd2

foreign import ccall unsafe "socketpair"
  c_socketpair :: CInt -> CInt -> CInt -> Ptr CInt -> IO CInt
#else
socketPair _ _ _ = error "Network.Socket.socketPair"
#endif
182