diff --git a/src/Network/WebSockets/Connection/Options.hs b/src/Network/WebSockets/Connection/Options.hs index 1255c31..a8f71ac 100644 --- a/src/Network/WebSockets/Connection/Options.hs +++ b/src/Network/WebSockets/Connection/Options.hs @@ -10,6 +10,12 @@ module Network.WebSockets.Connection.Options , SizeLimit (..) , atMostSizeLimit + + , TLSSettings (..) + , defaultTlsSettings + + , CertSettings (..) + , defaultCertSettings ) where @@ -18,6 +24,14 @@ import Data.Int (Int64) import Data.Monoid (Monoid (..)) import Prelude +import qualified Crypto.PubKey.DH as DH +import qualified Data.ByteString as B +import Data.Default.Class (def) +import qualified Data.IORef as IO +import qualified Network.TLS as TLS +import qualified Network.TLS.Extra as TLSExtra +import qualified Network.TLS.SessionManager as SM + -------------------------------------------------------------------------------- -- | Set options for a 'Connection'. Please do not use this constructor @@ -52,6 +66,7 @@ data ConnectionOptions = ConnectionOptions -- compressed messages, as well as the size of the uncompressed messages -- as we are deflating them to ensure we don't use too much memory in any -- case. + , connectionTlsSettings :: !(Maybe TLSSettings) } @@ -70,6 +85,7 @@ defaultConnectionOptions = ConnectionOptions , connectionStrictUnicode = False , connectionFramePayloadSizeLimit = mempty , connectionMessageDataSizeLimit = mempty + , connectionTlsSettings = Nothing } @@ -130,3 +146,89 @@ atMostSizeLimit :: Int64 -> SizeLimit -> Bool atMostSizeLimit _ NoSizeLimit = True atMostSizeLimit s (SizeLimit l) = s <= l {-# INLINE atMostSizeLimit #-} + +-------------------------------------------------------------------------------- +-- | Determines where to load the certificate, chain +-- certificates, and key from. +data CertSettings + = CertFromFile !FilePath ![FilePath] !FilePath + | CertFromMemory !B.ByteString ![B.ByteString] !B.ByteString + | CertFromRef !(IO.IORef B.ByteString) ![IO.IORef B.ByteString] !(IO.IORef B.ByteString) + +-- | The default 'CertSettings'. +defaultCertSettings :: CertSettings +defaultCertSettings = CertFromFile "certificate.pem" [] "key.pem" + +-------------------------------------------------------------------------------- +data TLSSettings = TLSSettings { + certSettings :: CertSettings + -- ^ Where are the certificate, chain certificates, and key + -- loaded from? + -- + -- >>> certSettings defaultTlsSettings + -- tlsSettings "certificate.pem" "key.pem" + , tlsLogging :: TLS.Logging + -- ^ The level of logging to turn on. + -- + -- Default: 'TLS.defaultLogging'. + , tlsAllowedVersions :: [TLS.Version] + -- ^ The TLS versions this server accepts. + -- + -- >>> tlsAllowedVersions defaultTlsSettings + -- [TLS13,TLS12,TLS11,TLS10] + , tlsCiphers :: [TLS.Cipher] + -- ^ The TLS ciphers this server accepts. + -- + -- >>> tlsCiphers defaultTlsSettings + -- [ECDHE-ECDSA-AES256GCM-SHA384,ECDHE-ECDSA-AES128GCM-SHA256,ECDHE-RSA-AES256GCM-SHA384,ECDHE-RSA-AES128GCM-SHA256,DHE-RSA-AES256GCM-SHA384,DHE-RSA-AES128GCM-SHA256,ECDHE-ECDSA-AES256CBC-SHA384,ECDHE-RSA-AES256CBC-SHA384,DHE-RSA-AES256-SHA256,ECDHE-ECDSA-AES256CBC-SHA,ECDHE-RSA-AES256CBC-SHA,DHE-RSA-AES256-SHA1,RSA-AES256GCM-SHA384,RSA-AES256-SHA256,RSA-AES256-SHA1,AES128GCM-SHA256,AES256GCM-SHA384] + , tlsWantClientCert :: Bool + -- ^ Whether or not to demand a certificate from the client. If this + -- is set to True, you must handle received certificates in a server hook + -- or all connections will fail. + -- + -- >>> tlsWantClientCert defaultTlsSettings + -- False + , tlsServerHooks :: TLS.ServerHooks + -- ^ The server-side hooks called by the tls package, including actions + -- to take when a client certificate is received. See the "Network.TLS" + -- module for details. + -- + -- Default: def + , tlsServerDHEParams :: Maybe DH.Params + -- ^ Configuration for ServerDHEParams + -- more function lives in `cryptonite` package + -- + -- Default: Nothing + , tlsSessionManagerConfig :: Maybe SM.Config + -- ^ Configuration for in-memory TLS session manager. + -- If Nothing, 'TLS.noSessionManager' is used. + -- Otherwise, an in-memory TLS session manager is created + -- according to 'Config'. + -- + -- Default: Nothing + , tlsCredentials :: Maybe TLS.Credentials + -- ^ Specifying 'TLS.Credentials' directly. If this value is + -- specified, other fields such as 'certFile' are ignored. + , tlsSessionManager :: Maybe TLS.SessionManager + -- ^ Specifying 'TLS.SessionManager' directly. If this value is + -- specified, 'tlsSessionManagerConfig' is ignored. + } + +defaultTlsSettings :: TLSSettings +defaultTlsSettings = + TLSSettings + { certSettings = defaultCertSettings + , tlsLogging = def + , tlsAllowedVersions = [TLS.TLS13,TLS.TLS12] + , tlsCiphers = ciphers + , tlsWantClientCert = False + , tlsServerHooks = def + , tlsServerDHEParams = Nothing + , tlsSessionManagerConfig = Nothing + , tlsCredentials = Nothing + , tlsSessionManager = Nothing + } + where + -- taken from stunnel example in tls-extra + ciphers :: [TLS.Cipher] + ciphers = TLSExtra.ciphersuite_strong diff --git a/src/Network/WebSockets/Server.hs b/src/Network/WebSockets/Server.hs index fcaeea4..8bf6237 100644 --- a/src/Network/WebSockets/Server.hs +++ b/src/Network/WebSockets/Server.hs @@ -153,7 +153,9 @@ runApp socket opts app = makePendingConnection :: Socket -> ConnectionOptions -> IO PendingConnection makePendingConnection socket opts = do - stream <- Stream.makeSocketStream socket + stream <- case connectionTlsSettings opts of + Nothing -> Stream.makeSocketStream socket + Just tls -> Stream.makeTlsSocketStream tls socket makePendingConnectionFromStream stream opts diff --git a/src/Network/WebSockets/Stream.hs b/src/Network/WebSockets/Stream.hs index d799b7a..ff08ace 100644 --- a/src/Network/WebSockets/Stream.hs +++ b/src/Network/WebSockets/Stream.hs @@ -1,6 +1,7 @@ -------------------------------------------------------------------------------- -- | Lightweight abstraction over an input/output stream. {-# LANGUAGE CPP #-} +{-# LANGUAGE RecordWildCards #-} module Network.WebSockets.Stream ( Stream , makeStream @@ -10,16 +11,23 @@ module Network.WebSockets.Stream , parseBin , write , close + -- * TLS + , makeTlsSocketStream + , streamTlsContext ) where +import Control.Applicative ((<|>)) import Control.Concurrent.MVar (MVar, newEmptyMVar, newMVar, putMVar, takeMVar, withMVar) -import Control.Exception (SomeException, SomeAsyncException, throwIO, catch, try, fromException) +import Control.Exception (SomeException, SomeAsyncException, throwIO, catch, handle, try, fromException) import Control.Monad (forM_) import qualified Data.Attoparsec.ByteString as Atto import qualified Data.Binary.Get as BIN import qualified Data.ByteString as B import qualified Data.ByteString.Lazy as BL +import Data.Default.Class (def) +import Data.Functor ((<&>)) +import qualified Data.IORef as IO import Data.IORef (IORef, atomicModifyIORef', newIORef, readIORef, writeIORef) @@ -33,7 +41,11 @@ import qualified Network.Socket.ByteString as SB (sendAll) #endif import System.IO.Error (isResourceVanishedError) +import qualified Network.TLS as TLS +import qualified Network.TLS.SessionManager as SM import Network.WebSockets.Types +import Network.WebSockets.Connection.Options +import System.IO.Error (isEOFError) -------------------------------------------------------------------------------- @@ -46,12 +58,12 @@ data StreamState -------------------------------------------------------------------------------- -- | Lightweight abstraction over an input/output stream. data Stream = Stream - { streamIn :: IO (Maybe B.ByteString) - , streamOut :: (Maybe BL.ByteString -> IO ()) - , streamState :: !(IORef StreamState) + { streamIn :: IO (Maybe B.ByteString) + , streamOut :: (Maybe BL.ByteString -> IO ()) + , streamState :: !(IORef StreamState) + , streamTlsContext :: Maybe TLS.Context } - -------------------------------------------------------------------------------- -- | Create a stream from a "receive" and "send" action. The following -- properties apply: @@ -73,7 +85,7 @@ makeStream receive send = do ref <- newIORef (Open B.empty) receiveLock <- newMVar () sendLock <- newMVar () - return $ Stream (receive' ref receiveLock) (send' ref sendLock) ref + return $ Stream (receive' ref receiveLock) (send' ref sendLock) ref Nothing where closeRef :: IORef StreamState -> IO () closeRef ref = atomicModifyIORef' ref $ \state -> case state of @@ -111,7 +123,6 @@ makeStream receive send = do Nothing -> what *> pure () throwIO e - -------------------------------------------------------------------------------- makeSocketStream :: S.Socket -> IO Stream makeSocketStream socket = makeStream receive send @@ -133,6 +144,85 @@ makeSocketStream socket = makeStream receive send forM_ (BL.toChunks bs) (SB.sendAll socket) #endif +loadCredentials :: TLSSettings -> IO TLS.Credentials +loadCredentials TLSSettings{ tlsCredentials = Just creds } = return creds +loadCredentials TLSSettings{..} = case certSettings of + CertFromFile cert chainFiles key -> do + cred <- either error id <$> TLS.credentialLoadX509Chain cert chainFiles key + return $ TLS.Credentials [cred] + CertFromRef certRef chainCertsRef keyRef -> do + cert <- IO.readIORef certRef + chainCerts <- mapM IO.readIORef chainCertsRef + key <- IO.readIORef keyRef + cred <- either error return $ TLS.credentialLoadX509ChainFromMemory cert chainCerts key + return $ TLS.Credentials [cred] + CertFromMemory certMemory chainCertsMemory keyMemory -> do + cred <- either error return $ TLS.credentialLoadX509ChainFromMemory certMemory chainCertsMemory keyMemory + return $ TLS.Credentials [cred] + +makeTlsSocketStream :: TLSSettings -> S.Socket -> IO Stream +makeTlsSocketStream stts socket = do + creds <- loadCredentials stts + mgr <- getSessionManager stts + ctx <- TLS.contextNew socket (params mgr creds) + TLS.contextHookSetLogging ctx (tlsLogging stts) + TLS.handshake ctx + makeStream (receive ctx) (send ctx) <&> + \s -> s { streamTlsContext = Just ctx } + where + receive ctx = handle onEOF go + where + onEOF e + | Just TLS.Error_EOF <- fromException e = pure Nothing + | Just ioe <- fromException e, isEOFError ioe = pure Nothing + | otherwise = throwIO e + go = do + x <- TLS.recvData ctx + if B.null x then + go + else + pure $ Just x + + send _ Nothing = return () + send ctx (Just bs) = + TLS.sendData ctx bs + + params mgr creds = def { -- TLS.ServerParams + TLS.serverWantClientCert = tlsWantClientCert stts + , TLS.serverCACertificates = [] + , TLS.serverDHEParams = tlsServerDHEParams stts + , TLS.serverHooks = hooks + , TLS.serverShared = shared mgr creds + , TLS.serverSupported = supported + , TLS.serverEarlyDataSize = 2018 + } + -- Adding alpn to user's tlsServerHooks. + hooks = (tlsServerHooks stts) + { TLS.onALPNClientSuggest = TLS.onALPNClientSuggest (tlsServerHooks stts) + -- <|> (if settingsHTTP2Enabled set then Just alpn else Nothing) + } + + shared mgr creds = def { + TLS.sharedCredentials = creds + , TLS.sharedSessionManager = mgr + } + supported = def { -- TLS.Supported + TLS.supportedVersions = tlsAllowedVersions stts + , TLS.supportedCiphers = tlsCiphers stts + , TLS.supportedCompressions = [TLS.nullCompression] + , TLS.supportedSecureRenegotiation = True + , TLS.supportedClientInitiatedRenegotiation = False + , TLS.supportedSession = True + , TLS.supportedFallbackScsv = True + , TLS.supportedGroups = [TLS.X25519,TLS.P256,TLS.P384] + } + + getSessionManager :: TLSSettings -> IO TLS.SessionManager + getSessionManager TLSSettings{ tlsSessionManager = Just mgr } = return mgr + getSessionManager stts' = case tlsSessionManagerConfig stts' of + Nothing -> return TLS.noSessionManager + Just config -> SM.newSessionManager config + -------------------------------------------------------------------------------- makeEchoStream :: IO Stream diff --git a/websockets.cabal b/websockets.cabal index 03890cf..cfa3a31 100644 --- a/websockets.cabal +++ b/websockets.cabal @@ -63,13 +63,13 @@ Library Network.WebSockets Network.WebSockets.Client Network.WebSockets.Connection + Network.WebSockets.Connection.Options Network.WebSockets.Connection.PingPong Network.WebSockets.Extensions Network.WebSockets.Stream -- Network.WebSockets.Util.PubSub TODO Other-modules: - Network.WebSockets.Connection.Options Network.WebSockets.Extensions.Description Network.WebSockets.Extensions.PermessageDeflate Network.WebSockets.Extensions.StrictUnicode @@ -89,12 +89,17 @@ Library binary >= 0.8.1 && < 0.11, bytestring >= 0.9 && < 0.13, case-insensitive >= 0.3 && < 1.3, + connection, containers >= 0.3 && < 0.7, + cryptonite, + data-default-class, network >= 2.3 && < 3.2, random >= 1.0.1 && < 1.3, SHA >= 1.5 && < 1.7, streaming-commons >= 0.1 && < 0.3, text >= 0.10 && < 2.2, + tls, + tls-session-manager, entropy >= 0.2.1 && < 0.5 Test-suite websockets-tests