{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE CPP #-}
-- |
-- Module      : Network.Socks5.Command
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
module Network.Socks5.Command
    ( establish
    , Connect(..)
    , Command(..)
    , connectIPV4
    , connectIPV6
    , connectDomainName
    -- * lowlevel interface
    , rpc
    , rpc_
    , sendSerialized
    , waitSerialized
    ) where

import Basement.Compat.Base
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC
import qualified Prelude
import Data.Serialize

import Network.Socket (Socket, PortNumber, HostAddress, HostAddress6)
import Network.Socket.ByteString

import Network.Socks5.Types
import Network.Socks5.Wire

-- | Perform the initial method negotiation on @socket@: send a 'SocksHello'
-- advertising @methods@, decode the server's 'SocksHelloResponse', and return
-- the method it carries.
--
-- Only 'SocksVer5' is matched here; NOTE(review): any other 'SocksVersion'
-- constructor (if one exists) would be a pattern-match failure.
establish :: SocksVersion -> Socket -> [SocksMethod] -> IO SocksMethod
establish SocksVer5 socket methods = do
    sendAll socket (encode (SocksHello methods))
    hello <- runGetDone get (getMore socket)
    return (getSocksHelloResponseMethod hello)

-- | A CONNECT command targeting the wrapped destination 'SocksAddress'.
-- Its 'Command' instance turns it into a 'SocksRequest' whose command is
-- 'SocksCommandConnect'.
newtype Connect = Connect SocksAddress
    deriving (Show, Eq, Ord)

-- | Values that can be turned into a 'SocksRequest' and, when the request
-- matches, recovered from one.
class Command a where
    -- | Build the 'SocksRequest' describing this command.
    toRequest :: a -> SocksRequest
    -- | Recover the command from a request, or 'Nothing' when the request
    -- does not describe this kind of command.
    fromRequest :: SocksRequest -> Maybe a

-- | A raw 'SocksRequest' is already a command: both directions are the
-- identity, and 'fromRequest' always succeeds.
instance Command SocksRequest where
    toRequest request = request
    fromRequest request = Just request

-- | 'Connect' maps to a 'SocksRequest' with 'SocksCommandConnect' and the
-- wrapped destination address/port; any other request command yields
-- 'Nothing' from 'fromRequest'.
instance Command Connect where
    toRequest (Connect (SocksAddress dstHost dstPort)) =
        SocksRequest { requestCommand = SocksCommandConnect
                     , requestDstAddr = dstHost
                     , requestDstPort = dstPort
                     }
    fromRequest request
        | requestCommand request /= SocksCommandConnect = Nothing
        | otherwise =
            Just (Connect (SocksAddress (requestDstAddr request) (requestDstPort request)))

-- | Issue a CONNECT to an IPv4 destination and return the bind address and
-- port from the server's reply.
--
-- Failures reported by the server are thrown by 'rpc_'. If the reply's bind
-- address is not IPv4, an 'error' (ErrorCall) is raised while this IO action
-- runs. Earlier this check lived in a lazily applied 'fmap', so the exception
-- only surfaced when the caller forced the result. That let it escape
-- 'catch'/'try' wrapped around the call.
--
-- NOTE(review): a SOCKS5 server may report a bind address whose family
-- differs from the requested destination; confirm rejecting that is intended.
connectIPV4 :: Socket -> HostAddress -> PortNumber -> IO (HostAddress, PortNumber)
connectIPV4 socket hostaddr port =
    rpc_ socket (Connect $ SocksAddress (SocksAddrIPV4 hostaddr) port) >>= onReply
  where
    onReply :: (SocksHostAddress, PortNumber) -> IO (HostAddress, PortNumber)
    onReply (SocksAddrIPV4 h, p) = return (h, p)
    onReply _                    = error "ipv4 requested, got something different"

-- | Issue a CONNECT to an IPv6 destination and return the bind address and
-- port from the server's reply.
--
-- Failures reported by the server are thrown by 'rpc_'. If the reply's bind
-- address is not IPv6, an 'error' (ErrorCall) is raised while this IO action
-- runs. Earlier this check lived in a lazily applied 'fmap', so the exception
-- only surfaced when the caller forced the result. That let it escape
-- 'catch'/'try' wrapped around the call.
--
-- NOTE(review): a SOCKS5 server may report a bind address whose family
-- differs from the requested destination; confirm rejecting that is intended.
connectIPV6 :: Socket -> HostAddress6 -> PortNumber -> IO (HostAddress6, PortNumber)
connectIPV6 socket hostaddr6 port =
    rpc_ socket (Connect $ SocksAddress (SocksAddrIPV6 hostaddr6) port) >>= onReply
  where
    onReply :: (SocksHostAddress, PortNumber) -> IO (HostAddress6, PortNumber)
    onReply (SocksAddrIPV6 h, p) = return (h, p)
    onReply _                    = error "ipv6 requested, got something different"

-- | Issue a CONNECT to a destination given by domain name and return the
-- bind address and port from the server's reply.
--
-- The name is encoded with 'BC.pack', which keeps only the low 8 bits of each
-- 'Char'. A non-ASCII name would therefore be silently mangled into a
-- different hostname. Such names are rejected up front with an 'IOError'
-- ('userError') instead.
--
-- TODO: a dedicated FQDN type would let the ASCII invariant be checked once,
-- at construction, rather than here.
connectDomainName :: Socket -> [Char] -> PortNumber -> IO (SocksHostAddress, PortNumber)
connectDomainName socket fqdn port
    | Prelude.any (Prelude.> '\x7f') fqdn =
        Prelude.ioError (Prelude.userError "connectDomainName: domain name must be ASCII")
    | otherwise =
        rpc_ socket $ Connect $ SocksAddress (SocksAddrDomainName (BC.pack fqdn)) port

-- | Encode a value with its 'Serialize' instance and send all of the
-- resulting bytes on the socket.
sendSerialized :: Serialize a => Socket -> a -> IO ()
sendSerialized sock = sendAll sock . encode

-- | Receive from the socket until a complete value can be decoded with its
-- 'Serialize' instance. Decoding is done by 'runGetDone', so leftover bytes
-- or a parse failure are reported with 'error'.
waitSerialized :: Serialize a => Socket -> IO a
waitSerialized = runGetDone get . getMore

-- | Send a command as a 'SocksRequest' and decode the server's
-- 'SocksResponse'.
--
-- Returns @Right (bindAddr, bindPort)@ when the reply is 'SocksReplySuccess',
-- or @Left err@ carrying the 'SocksError' from a 'SocksReplyError' reply.
rpc :: Command a => Socket -> a -> IO (Either SocksError (SocksHostAddress, PortNumber))
rpc socket req = do
    sendSerialized socket (toRequest req)
    response <- runGetDone get (getMore socket)
    return (interpret response)
  where
    interpret res = case responseReply res of
        SocksReplySuccess -> Right (responseBindAddr res, responseBindPort res)
        SocksReplyError e -> Left e

-- | Like 'rpc', but a 'SocksError' reply is thrown with 'throwIO' instead of
-- being returned.
rpc_ :: Command a => Socket -> a -> IO (SocksHostAddress, PortNumber)
rpc_ socket req = do
    outcome <- rpc socket req
    case outcome of
        Left err      -> throwIO err
        Right success -> return success

-- | Run @getter@ incrementally over chunks obtained from @ioget@. @ioget@ is
-- called again each time the parser reports that it needs more input.
--
-- This function expects all received data to be consumed by the parse:
-- leftover bytes after a successful parse are reported with 'error'. That
-- fits the strictly alternating request/response exchanges in this module,
-- but may not suit a peer that sends several messages back to back.
--
-- A parse failure is also reported with 'error', using cereal's message.
runGetDone :: Serialize a => Get a -> IO ByteString -> IO a
runGetDone getter ioget = ioget >>= step . runGetPartial getter
  where
#if MIN_VERSION_cereal(0,4,0)
    step (Fail msg _)   = error msg
#else
    step (Fail msg)     = error msg
#endif
    step (Partial feed) = ioget >>= step . feed
    step (Done val rest)
        | B.null rest   = return val
        | otherwise     = error "got too many bytes while receiving data"

-- | Read one chunk of at most 4096 bytes from the socket.
getMore :: Socket -> IO ByteString
getMore sock = recv sock readSize
  where
    readSize = 4096