1{-# LANGUAGE DeriveDataTypeable #-}
2{-# LANGUAGE ViewPatterns #-}
3{-# LANGUAGE NoImplicitPrelude #-}
4{-# LANGUAGE CPP #-}
5-- |
6-- Module      : Network.Socks5.Command
7-- License     : BSD-style
8-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
9-- Stability   : experimental
10-- Portability : unknown
11--
12module Network.Socks5.Command
13    ( establish
14    , Connect(..)
15    , Command(..)
16    , connectIPV4
17    , connectIPV6
18    , connectDomainName
19    -- * lowlevel interface
20    , rpc
21    , rpc_
22    , sendSerialized
23    , waitSerialized
24    ) where
25
26import Basement.Compat.Base
27import Data.ByteString (ByteString)
28import qualified Data.ByteString as B
29import qualified Data.ByteString.Char8 as BC
30import qualified Prelude
31import Data.Serialize
32
33import Network.Socket (Socket, PortNumber, HostAddress, HostAddress6)
34import Network.Socket.ByteString
35
36import Network.Socks5.Types
37import Network.Socks5.Wire
38
-- | Run the SOCKS5 method-negotiation handshake on @socket@.
--
-- Sends a 'SocksHello' offering @methods@, reads the server's answer and
-- returns the authentication method it selected.
establish :: SocksVersion -> Socket -> [SocksMethod] -> IO SocksMethod
establish SocksVer5 socket methods = do
    let hello = SocksHello methods
    sendAll socket (encode hello)
    response <- runGetDone get (recv socket 4096)
    return (getSocksHelloResponseMethod response)
43
-- | A CONNECT command aimed at the wrapped destination address and port.
newtype Connect = Connect SocksAddress
    deriving (Show, Eq, Ord)
45
-- | Values that can be carried on the wire as a 'SocksRequest'.
class Command a where
    -- | Turn the value into the request that will be serialized.
    toRequest :: a -> SocksRequest
    -- | Recover the value from a request, or 'Nothing' if the request does
    -- not describe this kind of command.
    fromRequest :: SocksRequest -> Maybe a
49
-- | A raw 'SocksRequest' is already a command: encoding returns it unchanged
-- and decoding always succeeds.
instance Command SocksRequest where
    toRequest request = request
    fromRequest request = Just request
53
-- | 'Connect' maps onto a request whose command is 'SocksCommandConnect';
-- any other command decodes to 'Nothing'.
instance Command Connect where
    toRequest (Connect (SocksAddress host dstPort)) =
        SocksRequest { requestCommand = SocksCommandConnect
                     , requestDstAddr = host
                     , requestDstPort = Prelude.fromIntegral dstPort
                     }
    fromRequest request
        | requestCommand request == SocksCommandConnect =
            Just (Connect (SocksAddress (requestDstAddr request) (requestDstPort request)))
        | otherwise = Nothing
63
-- | Ask the SOCKS server to CONNECT to an IPv4 destination.
--
-- Returns the bind address and port from the server's reply. A failure
-- reply is thrown as its 'SocksError' (via 'rpc_'). If the reply carries a
-- non-IPv4 bind address, an 'error' is raised when this action runs.
--
-- NOTE(review): in RFC 1928 the reply's BND.ADDR is the proxy's own bound
-- address, whose type need not match the requested destination, so this
-- check may be stricter than the protocol requires.
connectIPV4 :: Socket -> HostAddress -> PortNumber -> IO (HostAddress, PortNumber)
connectIPV4 socket hostaddr port = do
    reply <- rpc_ socket (Connect $ SocksAddress (SocksAddrIPV4 hostaddr) port)
    -- Match inside IO so an unexpected address type fails here, at the call,
    -- instead of lazily wherever (or whether) the caller forces the result.
    case reply of
        (SocksAddrIPV4 h, p) -> return (h, p)
        _                    -> error "ipv4 requested, got something different"
68
-- | Ask the SOCKS server to CONNECT to an IPv6 destination.
--
-- Returns the bind address and port from the server's reply. A failure
-- reply is thrown as its 'SocksError' (via 'rpc_'). If the reply carries a
-- non-IPv6 bind address, an 'error' is raised when this action runs.
--
-- NOTE(review): in RFC 1928 the reply's BND.ADDR is the proxy's own bound
-- address, whose type need not match the requested destination, so this
-- check may be stricter than the protocol requires.
connectIPV6 :: Socket -> HostAddress6 -> PortNumber -> IO (HostAddress6, PortNumber)
connectIPV6 socket hostaddr6 port = do
    reply <- rpc_ socket (Connect $ SocksAddress (SocksAddrIPV6 hostaddr6) port)
    -- Match inside IO so an unexpected address type fails here, at the call,
    -- instead of lazily wherever (or whether) the caller forces the result.
    case reply of
        (SocksAddrIPV6 h, p) -> return (h, p)
        _                    -> error "ipv6 requested, got something different"
73
-- | Ask the SOCKS server to CONNECT to a destination given by domain name.
--
-- The name is sent as raw bytes produced by 'BC.pack', which silently
-- truncates any 'Char' above @'\\xff'@. Names are therefore required to be
-- ASCII; internationalised names should be punycode-encoded by the caller.
-- SOCKS5 also stores the name's length in a single octet, so longer than
-- 255 bytes cannot be represented. A name breaking either rule is rejected
-- with an 'error' before anything is sent to the server.
--
-- Returns the bind address and port from the server's reply. A failure
-- reply is thrown as its 'SocksError' (via 'rpc_').
--
-- TODO: a dedicated FQDN type could enforce these rules at construction.
connectDomainName :: Socket -> [Char] -> PortNumber -> IO (SocksHostAddress, PortNumber)
connectDomainName socket fqdn port
    | Prelude.any (Prelude.> '\x7f') fqdn =
        error "connectDomainName: domain name must only contain ASCII characters"
    -- After the ASCII check, characters and bytes coincide one-to-one.
    | Prelude.length fqdn Prelude.> 255 =
        error "connectDomainName: domain name longer than 255 bytes"
    | otherwise =
        rpc_ socket $ Connect $ SocksAddress (SocksAddrDomainName $ BC.pack fqdn) port
78
-- | Serialize a value with its 'Serialize' instance and write all of the
-- resulting bytes to the socket.
sendSerialized :: Serialize a => Socket -> a -> IO ()
sendSerialized sock = sendAll sock . encode
81
-- | Read from the socket until one complete value has been decoded.
-- The value must account for every byte received.
waitSerialized :: Serialize a => Socket -> IO a
waitSerialized = runGetDone get . getMore
84
-- | Send a command to the server and wait for its reply.
--
-- A success reply yields the bind address and port from the response; an
-- error reply yields the 'SocksError' it carries.
rpc :: Command a => Socket -> a -> IO (Either SocksError (SocksHostAddress, PortNumber))
rpc socket req = do
    sendSerialized socket (toRequest req)
    interpret <$> waitSerialized socket
  where
    interpret response = case responseReply response of
        SocksReplyError e -> Left e
        SocksReplySuccess ->
            Right (responseBindAddr response, Prelude.fromIntegral (responseBindPort response))
93
-- | Like 'rpc', but an error reply is thrown with 'throwIO' instead of being
-- returned.
rpc_ :: Command a => Socket -> a -> IO (SocksHostAddress, PortNumber)
rpc_ socket req = do
    outcome <- rpc socket req
    case outcome of
        Left err     -> throwIO err
        Right result -> return result
96
-- | Run @getter@ over chunks pulled from @ioget@, fetching another chunk
-- whenever the parser asks for more input.
--
-- Raises an 'error' if the parser fails, or if bytes remain after a
-- successful parse. This suits request/response exchanges, where each
-- reply arrives on its own. It is not suitable when one party sends several
-- messages back-to-back: the surplus bytes would be rejected.
runGetDone :: Serialize a => Get a -> IO ByteString -> IO a
runGetDone getter ioget = ioget >>= step . runGetPartial getter
  where
    step result = case result of
#if MIN_VERSION_cereal(0,4,0)
        Fail msg _ -> error msg
#else
        Fail msg -> error msg
#endif
        Partial k -> ioget >>= step . k
        Done v leftover
            | B.null leftover -> return v
            | otherwise       -> error "got too many bytes while receiving data"
110
-- | Receive the next chunk of at most 4096 bytes from the socket.
getMore :: Socket -> IO ByteString
getMore sock = recv sock maxChunk
  where
    maxChunk = 4096
113