{-# LANGUAGE OverloadedStrings, DeriveGeneric, GeneralizedNewtypeDeriving #-}
{-# LANGUAGE FlexibleContexts, FlexibleInstances, ScopedTypeVariables #-}

module WebSockets
  ( connectionOptions
  , runClientApp
  , clientApp
  , protocolError
  , relayFromSocket
  , relayToSocket
  , negotiateWireVersion
  , WireProtocol(..)
  , Mode(..)
  , ClientSends(..)
  , ServerSends(..)
  ) where

import Types
import SessionID
import ProtocolBuffers
import PrevActivity

import Network.WebSockets hiding (Message)
import Control.Concurrent.STM
import Control.Concurrent.STM.TMChan
import Control.Concurrent.Async
import Control.Exception
import GHC.Generics (Generic)
import Data.Aeson (FromJSON, ToJSON)
import Data.ProtocolBuffers
import qualified Data.Aeson
import qualified Data.Serialize
import qualified Data.Text as T
import qualified Data.ByteString.Lazy as L
import Data.List
import Data.Monoid
import Control.Monad

-- | Connection options used for websocket connections: the defaults,
-- with permessage-deflate compression enabled.
connectionOptions :: ConnectionOptions
connectionOptions = defaultConnectionOptions
  { connectionCompressionOptions =
      PermessageDeflateCompression defaultPermessageDeflate
  }

-- | Run a 'ClientApp' against the server at localhost:8081 path "/".
-- Returns the app's result, or 'Nothing' if the connection was closed
-- before the app produced one.
--
-- For some reason, runClient throws ConnectionClosed
-- when the server hangs up cleanly. Catch this unwanted exception.
-- See https://github.com/jaspervdj/websockets/issues/142
runClientApp :: ClientApp a -> IO (Maybe a)
runClientApp app = do
  rv <- newEmptyTMVarIO
  let go conn = do
        r <- app conn
        atomically $ putTMVar rv r
  catchJust catchconnclosed
    (runClientWith "localhost" 8081 "/" connectionOptions [] go)
    (\_ -> return ())
  atomically (tryReadTMVar rv)
  where
    catchconnclosed ConnectionClosed = Just ()
    catchconnclosed _ = Nothing

-- | Make a client that sends and receives AnyMessages over a websocket.
--
-- The client first negotiates the wire version and sends a SelectMode
-- message for the given mode. It then expects the server's Ready
-- message.
--
-- The action is run with three arguments:
--
-- * a channel of outgoing values, converted with @mksent@;
-- * a channel of incoming messages, those accepted by @filterreceived@;
-- * the SessionID from Ready.
--
-- Both channels are closed when the incoming relay ends, for whatever
-- reason.
clientApp
  :: Mode
  -> RecentActivity
  -> (sent -> AnyMessage)
  -> (AnyMessage -> Maybe received)
  -> (TMChan sent -> TMChan received -> SessionID -> IO a)
  -> ClientApp a
clientApp mode recentactivity mksent filterreceived a conn = do
  -- Ping every 30 seconds to avoid timeouts caused by proxies etc.
  forkPingThread conn 30
  _v <- negotiateWireVersion conn
  sendBinaryData conn (SelectMode ClientSends mode)
  r <- receiveData conn
  case r of
    Ready ServerSends sid -> bracket setup cleanup (go sid)
    WireProtocolError e -> error e
    _ -> protocolError conn "Did not get expected Ready message from server"
  where
    setup = do
      schan <- newTMChanIO
      rchan <- newTMChanIO
      sthread <- async $
        relayToSocket conn mksent $ atomically (readTMChan schan)
      -- Close both channels however the incoming relay ends: the
      -- server sent Done, the connection dropped, or a protocol error
      -- was raised. Without this, the client action could block on
      -- rchan forever after a failure.
      rthread <- async $
        relayFromSocket conn recentactivity (waitTillDrained rchan) (deliver rchan)
          `finally` atomically (closeTMChan schan >> closeTMChan rchan)
      return (schan, rchan, sthread, rthread)

    -- Forward an incoming message to the client action if the filter
    -- accepts it; drop it otherwise.
    deliver rchan v = case filterreceived v of
      Nothing -> return ()
      Just x -> atomically $ writeTMChan rchan x

    cleanup (schan, _, sthread, rthread) = do
      -- Stop accepting outgoing values and let the sender thread flush
      -- whatever is still queued. Done is then the last message sent,
      -- and is never sent concurrently with another message.
      -- The sender often dies with a ConnectionClosed.
      atomically $ closeTMChan schan
      void $ waitCatch sthread
      -- Still stop the receiver if sending Done fails, e.g. because the
      -- server has already hung up.
      sendBinaryData conn Done
        `finally` (cancel rthread >> void (waitCatch rthread))

    go sid (schan, rchan, _, _) = a schan rchan sid

-- | Block until the channel has no pending items.
waitTillDrained :: TMChan a -> IO ()
waitTillDrained c = atomically $ do
  drained <- isEmptyTMChan c
  unless drained retry

-- | Receive messages from the connection until the remote side sends
-- Done.
--
-- The wait action runs before each AnyMessage is handled ('clientApp'
-- uses it to wait until earlier messages have been consumed). The
-- message then has its previous-activity hash restored and is passed
-- to @sender@. WireProtocolError, or any other unexpected message, is
-- handled with 'protocolError'.
relayFromSocket :: Connection -> RecentActivity -> IO () -> (AnyMessage -> IO ()) -> IO ()
relayFromSocket conn recentactivity waitprevprocessed sender = go
  where
    go = do
      r <- receiveData conn
      case r of
        AnyMessage msg -> do
          waitprevprocessed
          msg' <- atomically $ restorePrevActivityHash recentactivity msg
          sender msg'
          go
        Done -> return ()
        WireProtocolError e -> protocolError conn e
        _ -> protocolError conn "Protocol error"

-- | Repeatedly take a value from @getter@, convert it with @mksent@, and
-- send it as an AnyMessage with its previous-activity hash removed.
-- Stops when @getter@ returns 'Nothing'.
relayToSocket :: Connection -> (received -> AnyMessage) -> IO (Maybe received) -> IO ()
relayToSocket conn mksent getter = go
  where
    go = do
      mmsg <- getter
      case mmsg of
        Nothing -> return ()
        Just msg -> do
          sendBinaryData conn $ AnyMessage $ removePrevActivityHash $ mksent msg
          go

-- | Framing protocol used over a websocket connection.
--
-- The protocol is asynchronous: client and server may both send
-- messages at the same time. Constructors for messages that only one
-- side may send carry a ClientSends or ServerSends tag.
data WireProtocol
  = Version [WireVersion]
  | SelectMode ClientSends Mode
  | Ready ServerSends SessionID
  | AnyMessage AnyMessage
  | Done
  | WireProtocolError String

-- | Tag for messages that only the server sends.
data ServerSends = ServerSends

-- | Tag for messages that only the client sends.
data ClientSends = ClientSends

-- | Every message is a one-byte tag. Except for Done, the tag is
-- followed by a payload: an encoded protocol buffer for AnyMessage, and
-- JSON for everything else.
instance WebSocketsData WireProtocol where
  toLazyByteString wp = case wp of
      Version vs -> withTag "V" vs
      SelectMode _ m -> withTag "M" m
      Ready _ sid -> withTag "R" sid
      AnyMessage msg ->
        "L" <> Data.Serialize.runPutLazy
          (encodeMessage (toProtocolBuffer msg :: AnyMessageP))
      Done -> "D"
      WireProtocolError s -> withTag "E" s
    where
      withTag tag payload = tag <> Data.Aeson.encode payload

  fromLazyByteString b = case L.splitAt 1 b of
      ("V", v) -> fromJSONOr "invalid JSON in Version" Version v
      ("M", m) -> fromJSONOr "invalid JSON in Mode" (SelectMode ClientSends) m
      ("R", sid) -> fromJSONOr "invalid JSON in SessionID" (Ready ServerSends) sid
      ("L", l) -> either
        (\err -> WireProtocolError $ "Protocol buffers decode error: " ++ err)
        (\(pmsg :: AnyMessageP) -> AnyMessage (fromProtocolBuffer pmsg))
        (Data.Serialize.runGetLazy decodeMessage l)
      ("D", "") -> Done
      ("E", s) -> fromJSONOr "invalid JSON in WireProtocolError" WireProtocolError s
      _ -> WireProtocolError "received unknown websocket message"
    where
      -- Decode a JSON payload and wrap it with the given constructor,
      -- or report the given error if it does not parse.
      fromJSONOr err mk = maybe (WireProtocolError err) mk . Data.Aeson.decode

  -- The inner fromDataMessage is the lazy ByteString instance's, picked
  -- out by type inference from fromLazyByteString's argument type.
  fromDataMessage = fromLazyByteString . fromDataMessage

-- | Report a protocol error. Send the error to the remote side, send a
-- close request, and then raise the error locally with 'error'.
protocolError :: Connection -> String -> IO a
protocolError conn err =
  sendBinaryData conn (WireProtocolError err)
    >> sendClose conn Done
    >> error err

-- | A wire protocol version label.
--
-- NOTE(review): the derived Ord compares the underlying Text
-- lexicographically, so "10" sorts before "9". Negotiation picks the
-- greatest common version, so keep this in mind if versions ever grow
-- past one digit.
newtype WireVersion = WireVersion T.Text
  deriving (Show, Eq, Generic, Ord)

instance FromJSON WireVersion
instance ToJSON WireVersion

-- | Wire protocol versions this side understands.
supportedWireVersions :: [WireVersion]
supportedWireVersions = [WireVersion "1"]

-- | Send supportedWireVersions while concurrently receiving the remote
-- side's list. The greatest version present in both lists is returned.
-- It is a protocol error if no version is shared, or if the remote side
-- does not send a Version message.
negotiateWireVersion :: Connection -> IO WireVersion
negotiateWireVersion conn = do
  (_, resp) <- concurrently announce (receiveData conn)
  case resp of
    Version theirs -> choose theirs
    _ -> protocolError conn "Protocol error, did not receive Version"
  where
    announce = sendBinaryData conn (Version supportedWireVersions)
    choose theirs =
      case sortBy (flip compare) (intersect supportedWireVersions theirs) of
        (best:_) -> return best
        [] -> protocolError conn $
          "Unable to negotiate protocol Version. I support: "
            ++ show supportedWireVersions
            ++ " They support: " ++ show theirs

-- | Modes of operation that can be requested for a websocket connection.
data Mode
  = InitMode T.Text -- ^ Text is unused, but reserved for expansion
  | ConnectMode T.Text -- ^ Text specifies the SessionID to connect to
  deriving (Show, Eq, Generic)

instance FromJSON Mode
instance ToJSON Mode where