{- Copyright 2017 Joey Hess - - Licensed under the GNU AGPL version 3 or higher. -} {-# LANGUAGE OverloadedStrings, DeriveGeneric, GeneralizedNewtypeDeriving #-} {-# LANGUAGE FlexibleContexts, FlexibleInstances, ScopedTypeVariables #-} {-# LANGUAGE CPP #-} module WebSockets ( connectionOptions, runClientApp, clientApp, protocolError, relayFromSocket, relayToSocket, negotiateWireVersion, WireProtocol(..), Mode(..), EmailAddress, ClientSends(..), ServerSends(..), ) where import Types import SessionID import ProtocolBuffers import PrevActivity import Network.WebSockets hiding (Message) import Control.Concurrent.STM import Control.Concurrent.STM.TMChan import Control.Concurrent.Async import Control.Exception import GHC.Generics (Generic) import Data.Aeson (FromJSON, ToJSON) import Data.ProtocolBuffers import qualified Data.Aeson import qualified Data.Serialize import qualified Data.Text as T import qualified Data.ByteString.Lazy as L import Data.List import Data.Maybe import Text.Read import Control.Monad import Network.URI import System.IO import Data.Monoid import Prelude -- | Framing protocol used over a websocket connection. -- -- This is an asynchronous protocol; both client and server can send -- messages at the same time. -- -- Messages that only one can send are tagged with ClientSends or -- ServerSends. 
data WireProtocol
    = Version [WireVersion]
    -- ^ The wire protocol versions the sender supports.
    | SelectMode ClientSends Mode
    -- ^ The client's request for a 'Mode' of operation.
    | Ready ServerSends SessionID
    -- ^ The server's reply once the session is ready.
    | AnyMessage AnyMessage
    -- ^ A message, carried on the wire as a protocol buffer.
    | Done
    -- ^ End of the message stream; 'relayFromSocket' stops on it.
    | WireProtocolError String
    -- ^ Report of a protocol error.

-- | Tag for messages that only the server sends.
data ServerSends = ServerSends

-- | Tag for messages that only the client sends.
data ClientSends = ClientSends

-- | Each frame starts with a one-character tag selecting the constructor:
-- @V@ Version, @M@ SelectMode, @R@ Ready and @E@ WireProtocolError carry
-- a JSON payload, @L@ AnyMessage carries a protocol buffers payload, and
-- @D@ Done carries nothing. Malformed input decodes to a
-- 'WireProtocolError' rather than throwing.
instance WebSocketsData WireProtocol where
    toLazyByteString (Version v) = "V" <> Data.Aeson.encode v
    toLazyByteString (SelectMode _ m) = "M" <> Data.Aeson.encode m
    toLazyByteString (Ready _ sid) = "R" <> Data.Aeson.encode sid
    toLazyByteString (AnyMessage msg) = "L" <>
        let pmsg = toProtocolBuffer msg :: AnyMessageP
        in Data.Serialize.runPutLazy (encodeMessage pmsg)
    toLazyByteString Done = "D"
    toLazyByteString (WireProtocolError s) = "E" <> Data.Aeson.encode s

    fromLazyByteString b = case L.splitAt 1 b of
        ("V", v) -> maybe (WireProtocolError "invalid JSON in Version")
            Version
            (Data.Aeson.decode v)
        ("M", m) -> maybe (WireProtocolError "invalid JSON in Mode")
            (SelectMode ClientSends)
            (Data.Aeson.decode m)
        ("R", sid) -> maybe (WireProtocolError "invalid JSON in SessionID")
            (Ready ServerSends)
            (Data.Aeson.decode sid)
        ("L", l) -> case Data.Serialize.runGetLazy decodeMessage l of
            Left err -> WireProtocolError $
                "Protocol buffers decode error: " ++ err
            Right (pmsg :: AnyMessageP) ->
                AnyMessage (fromProtocolBuffer pmsg)
        ("D", "") -> Done
        ("E", s) -> maybe (WireProtocolError "invalid JSON in WireProtocolError")
            WireProtocolError
            (Data.Aeson.decode s)
        _ -> WireProtocolError "received unknown websocket message"

#if MIN_VERSION_websockets(0,11,0)
    -- The inner fromDataMessage yields a lazy ByteString (the input type
    -- of fromLazyByteString), which is then decoded as above.
    fromDataMessage = fromLazyByteString . fromDataMessage
#endif

-- | Modes of operation that can be requested for a websocket connection.
data Mode
    = InitMode EmailAddress
    -- ^ initialize a new debug-me session.
    | ConnectMode T.Text
    -- ^ Text specifies the SessionID to connect to
    deriving (Show, Eq, Generic)

instance FromJSON Mode
instance ToJSON Mode

-- | A version of the wire protocol, exchanged in 'Version' messages.
--
-- NOTE(review): the derived Ord compares the underlying Text
-- lexicographically, so WireVersion "10" sorts below WireVersion "9".
-- Revisit before adding a multi-digit version, since negotiation picks
-- the highest common version.
newtype WireVersion = WireVersion T.Text
    deriving (Show, Eq, Generic, Ord)

instance FromJSON WireVersion
instance ToJSON WireVersion

-- | Wire protocol versions this side supports.
supportedWireVersions :: [WireVersion]
supportedWireVersions = [WireVersion "1"]

-- | Options used for every websocket connection.
connectionOptions :: ConnectionOptions
connectionOptions = defaultConnectionOptions
#if MIN_VERSION_websockets(0,11,0)
    -- Enable compression.
    { connectionCompressionOptions =
        PermessageDeflateCompression defaultPermessageDeflate
    }
#endif

-- | Run a 'ClientApp' against the server at the given URI. Returns the
-- app's result, or Nothing if the app did not complete.
--
-- An exception thrown by the app is printed to stderr and rethrown.
-- The port defaults to 80 when the URI has none (or it does not parse),
-- and the path defaults to "/".
--
-- For some reason, runClient throws ConnectionClosed
-- when the server hangs up cleanly. Catch this unwanted exception.
-- See https://github.com/jaspervdj/websockets/issues/142
runClientApp :: URI -> ClientApp a -> IO (Maybe a)
runClientApp serverurl app = do
    rv <- newEmptyTMVarIO
    let go conn = do
            r <- app conn `catch` showerr
            atomically $ putTMVar rv r
    catchJust catchconnclosed
        (runClientWith host port endpoint connectionOptions [] go)
        (\_ -> return ())
    atomically (tryReadTMVar rv)
  where
    serverauth = fromMaybe (error "bad server url") (uriAuthority serverurl)
    host = uriRegName serverauth
    port = case uriPort serverauth of
        (':':s) -> fromMaybe 80 (readMaybe s)
        _ -> 80
    endpoint = case uriPath serverurl of
        [] -> "/"
        p -> p
    catchconnclosed ConnectionClosed = Just ()
    catchconnclosed _ = Nothing
    showerr :: SomeException -> IO a
    showerr e = do
        hPutStrLn stderr (show e)
        throwIO e

-- | Make a client that sends and receives AnyMessages over a websocket.
--
-- Negotiates the wire version, sends the requested 'Mode', and waits for
-- the server's 'Ready'. It then runs the action with three arguments: a
-- channel whose values are converted with @mksent@ and sent, a channel of
-- received messages that passed @filterreceived@, and the 'SessionID'.
-- When the action finishes, the queued outgoing messages are flushed,
-- 'Done' is sent, and the receiving thread is cancelled.
clientApp
    :: Mode
    -> (sent -> AnyMessage)
    -> (AnyMessage -> Maybe received)
    -> (TMChan sent -> TMChan (MissingHashes received) -> SessionID -> IO a)
    -> ClientApp a
clientApp mode mksent filterreceived a conn = do
    -- Ping every 30 seconds to avoid timeouts caused by proxies etc.
    forkPingThread conn 30
    _v <- negotiateWireVersion conn
    sendBinaryData conn (SelectMode ClientSends mode)
    r <- receiveData conn
    case r of
        Ready ServerSends sid -> bracket setup cleanup (go sid)
        WireProtocolError e -> error e
        _ -> protocolError conn "Did not get expected Ready message from server"
  where
    setup = do
        schan <- newTMChanIO
        rchan <- newTMChanIO
        sthread <- async $
            relayToSocket conn mksent $ atomically (readTMChan schan)
        -- Relay incoming messages until the server sends Done. The
        -- channels are closed in a finally, so readers of rchan also see
        -- the end of input when relaying fails with an exception, rather
        -- than waiting for messages that will never arrive.
        rthread <- async $
            relayFromSocket conn (deliver rchan)
                `finally` atomically (closeTMChan schan >> closeTMChan rchan)
        return (schan, rchan, sthread, rthread)

    -- Queue a received message for the action, if the filter keeps it.
    deliver rchan v = case filterreceived v of
        Nothing -> return ()
        Just r -> atomically $ writeTMChan rchan (MissingHashes r)

    cleanup (schan, _, sthread, rthread) = do
        -- Close schan first and let the sending thread drain it, so every
        -- message the action queued reaches the server before Done does.
        atomically $ closeTMChan schan
        -- The sending thread often dies with a ConnectionClosed.
        void $ waitCatch sthread
        -- Stop the receiving thread even if sending Done fails.
        sendBinaryData conn Done
            `finally` (cancel rthread >> void (waitCatch rthread))

    go sid (schan, rchan, _, _) = a schan rchan sid

-- | Receive AnyMessages from the websocket and pass each to the sender
-- action, until the peer sends 'Done'. A 'WireProtocolError' or any other
-- frame is reported with 'protocolError'.
relayFromSocket :: Connection -> (AnyMessage -> IO ()) -> IO ()
relayFromSocket conn sender = go
  where
    go = do
        r <- receiveData conn
        case r of
            AnyMessage msg -> do
                sender msg
                go
            Done -> return ()
            WireProtocolError e -> protocolError conn e
            _ -> protocolError conn "Protocol error"

-- | Repeatedly run the getter and send each value it yields over the
-- websocket, converted with mksent and with its hashes removed. Stops
-- when the getter returns Nothing.
relayToSocket :: Connection -> (received -> AnyMessage) -> IO (Maybe received) -> IO ()
relayToSocket conn mksent getter = go
  where
    go = do
        mmsg <- getter
        case mmsg of
            Nothing -> return ()
            Just msg -> do
                let MissingHashes wiremsg = removeHashes $ mksent msg
                sendBinaryData conn $ AnyMessage wiremsg
                go

-- | Send supportedWireVersions to the remote side while concurrently
-- receiving the remote side's list of versions. The highest version
-- present in both lists is used.
negotiateWireVersion :: Connection -> IO WireVersion
negotiateWireVersion conn = do
    -- Our list goes out while the remote side's list comes in.
    theirs <- fmap snd $ concurrently
        (sendBinaryData conn (Version supportedWireVersions))
        (receiveData conn)
    case theirs of
        Version remoteversions ->
            case intersect supportedWireVersions remoteversions of
                [] -> protocolError conn $
                    "Unable to negotiate protocol Version. I support: "
                        ++ show supportedWireVersions
                        ++ " They support: "
                        ++ show remoteversions
                -- Both sides know these; pick the greatest under Ord.
                common -> return (maximum common)
        _ -> protocolError conn "Protocol error, did not receive Version"

-- | Report a protocol error to the peer as a 'WireProtocolError' frame,
-- close the connection with 'Done' as the close payload, and then raise
-- the message as an exception via 'error'.
protocolError :: Connection -> String -> IO a
protocolError conn err =
    sendBinaryData conn (WireProtocolError err)
        >> sendClose conn Done
        >> error err