{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}

-- | Websocket client plumbing and the framing protocol ('WireProtocol')
-- that is spoken over a websocket connection.
module WebSockets (
    runClientApp,
    clientApp,
    protocolError,
    relayFromSocket,
    relayToSocket,
    negotiateWireVersion,
    WireProtocol(..),
    Mode(..),
    ClientSends(..),
    ServerSends(..),
) where

import Types
import SessionID

import Network.WebSockets hiding (Message)
import Control.Concurrent.STM
import Control.Concurrent.Async
import Control.Exception
import GHC.Generics (Generic)
import Data.Aeson (FromJSON, ToJSON)
import qualified Data.Aeson
import qualified Data.Binary
import qualified Data.Text as T
import qualified Data.ByteString.Lazy as L
import Data.List
import Data.Monoid

-- | Run a websocket client application against the server at
-- @localhost@, port 8081, path @/@.
--
-- A 'ConnectionClosed' exception is swallowed and treated as a normal
-- return; every other exception propagates to the caller.
runClientApp :: ClientApp () -> IO ()
runClientApp app = catchJust catchconnclosed
    (runClient "localhost" 8081 "/" app)
    (\_ -> return ())
  where
    -- For some reason, runClient throws ConnectionClosed
    -- when the server hangs up cleanly. Catch this unwanted exception.
    catchconnclosed ConnectionClosed = Just ()
    catchconnclosed _ = Nothing

-- | Make a client that sends and receives LogMessages over a websocket.
--
-- The client first negotiates the wire version, then requests the given
-- 'Mode' and waits for the server's 'Ready' message. Once ready, the
-- supplied action is run with:
--
-- * a channel of values to send: each value written to it is converted
--   with the @mksent@ function and sent to the server as a 'LogMessage'
--   frame;
-- * a channel of received values: each 'LogMessage' frame from the server
--   is passed through the @filterreceived@ function and, when that yields
--   'Just', written to this channel;
-- * the 'SessionID' carried by the server's 'Ready' message.
--
-- When the action finishes (normally or by exception), 'Done' is sent to
-- the server and both relay threads are cancelled.
--
-- If the server answers with a 'WireProtocolError', its message is thrown
-- with 'error'. Any other unexpected answer is reported with
-- 'protocolError'.
clientApp
    :: Mode
    -> (sent -> LogMessage)
    -> (LogMessage -> Maybe received)
    -> (TChan sent -> TChan received -> SessionID -> IO a)
    -> ClientApp a
clientApp mode mksent filterreceived a conn = do
    _v <- negotiateWireVersion conn
    sendBinaryData conn (SelectMode ClientSends mode)
    r <- receiveData conn
    case r of
        Ready ServerSends sid -> bracket setup cleanup (go sid)
        WireProtocolError e -> error e
        _ -> protocolError conn "Did not get expected Ready message from server"
  where
    setup = do
        schan <- newTChanIO
        rchan <- newTChanIO
        -- Socket -> rchan: forward the messages the filter accepts.
        fromsock <- async $ relayFromSocket conn $ \v ->
            case filterreceived v of
                Nothing -> return ()
                Just r' -> atomically $ writeTChan rchan r'
        -- schan -> socket: block until there is something to send.
        tosock <- async $ relayToSocket conn $
            Just . mksent <$> atomically (readTChan schan)
        return (schan, rchan, fromsock, tosock)
    -- Sending Done can throw (for example when the peer has already closed
    -- the connection). The relay threads must be cancelled regardless,
    -- otherwise they would outlive the client; the exception from the send
    -- still propagates afterwards.
    cleanup (_, _, fromsock, tosock) =
        sendBinaryData conn Done
            `finally` (cancel fromsock `finally` cancel tosock)
    go sid (schan, rchan, _, _) = a schan rchan sid

-- | Receive frames from the connection, passing the payload of every
-- 'LogMessage' frame to the given action, until a 'Done' frame arrives.
--
-- A 'WireProtocolError' frame, or any other kind of frame, ends the loop
-- via 'protocolError'.
relayFromSocket :: Connection -> (LogMessage -> IO ()) -> IO ()
relayFromSocket conn sender = go
  where
    go = do
        r <- receiveData conn
        case r of
            LogMessage msg -> do
                sender msg
                go
            Done -> return ()
            WireProtocolError e -> protocolError conn e
            _ -> protocolError conn "Protocol error"

-- | Repeatedly run the given action and send each 'Just' result over the
-- connection as a 'LogMessage' frame. A 'Nothing' result sends nothing and
-- the action is simply run again.
--
-- This loop never returns on its own; it ends only when an exception is
-- raised in it (including being cancelled).
relayToSocket :: Connection -> (IO (Maybe LogMessage)) -> IO ()
relayToSocket conn getter = go
  where
    go = do
        mmsg <- getter
        case mmsg of
            Nothing -> go
            Just msg -> do
                sendBinaryData conn (LogMessage msg)
                go

-- | Framing protocol used over a websocket connection.
--
-- This is an asynchronous protocol; both client and server can send
-- messages at the same time.
--
-- Messages that only one can send are tagged with ClientSends or
-- ServerSends.
data WireProtocol
    = Version [WireVersion]
    | SelectMode ClientSends Mode
    | Ready ServerSends SessionID
    | LogMessage LogMessage
    | Done
    | WireProtocolError String

-- | Tag for 'WireProtocol' messages that only the server sends.
data ServerSends = ServerSends

-- | Tag for 'WireProtocol' messages that only the client sends.
data ClientSends = ClientSends

-- | Each frame is a one-byte tag followed by the payload. The payload is
-- JSON, except for 'LogMessage' which uses "Data.Binary" encoding, and
-- 'Done' which has no payload.
--
-- Decoding never fails outright: anything that cannot be decoded is turned
-- into a 'WireProtocolError' value describing the problem.
instance WebSocketsData WireProtocol where
    toLazyByteString (Version v) = "V" <> Data.Aeson.encode v
    toLazyByteString (SelectMode _ m) = "M" <> Data.Aeson.encode m
    toLazyByteString (Ready _ sid) = "R" <> Data.Aeson.encode sid
    toLazyByteString (LogMessage msg) = "L" <> Data.Binary.encode msg
    toLazyByteString Done = "D"
    toLazyByteString (WireProtocolError s) = "E" <> Data.Aeson.encode s
    fromLazyByteString b = case L.splitAt 1 b of
        ("V", v) -> maybe (WireProtocolError "invalid JSON in Version")
            Version
            (Data.Aeson.decode v)
        ("M", m) -> maybe (WireProtocolError "invalid JSON in Mode")
            (SelectMode ClientSends)
            (Data.Aeson.decode m)
        ("R", sid) -> maybe (WireProtocolError "invalid JSON in SessionID")
            (Ready ServerSends)
            (Data.Aeson.decode sid)
        ("L", l) -> case Data.Binary.decodeOrFail l of
            Left (_, _, err) -> WireProtocolError $ "Binary decode error: " ++ err
            Right (_, _, msg) -> LogMessage msg
        -- Done is only recognised when nothing follows the tag byte.
        ("D", "") -> Done
        ("E", s) -> maybe (WireProtocolError "invalid JSON in WireProtocolError")
            WireProtocolError
            (Data.Aeson.decode s)
        _ -> WireProtocolError "received unknown websocket message"

-- | Report a protocol error: send the message to the remote side as a
-- 'WireProtocolError' frame, send a close request carrying 'Done', and
-- then throw the message locally with 'error'. Never returns normally.
protocolError :: Connection -> String -> IO a
protocolError conn err = do
    sendBinaryData conn (WireProtocolError err)
    sendClose conn Done
    error err

-- | A version of the wire protocol, identified by a piece of text.
newtype WireVersion = WireVersion T.Text
    deriving (Show, Eq, Generic, Ord)

instance FromJSON WireVersion
instance ToJSON WireVersion

-- | The wire protocol versions this side is able to speak.
supportedWireVersions :: [WireVersion]
supportedWireVersions = [WireVersion "1"]

-- | Send supportedWireVersions and at the same time receive it from
-- the remote side. The highest version present in both lists will be used.
--
-- \"Highest\" is by the derived 'Ord' instance of 'WireVersion', that is,
-- the ordering of the underlying text rather than a numeric ordering.
--
-- If there is no common version, or the remote side sends something other
-- than a 'Version' message, this fails via 'protocolError'.
negotiateWireVersion :: Connection -> IO WireVersion
negotiateWireVersion conn = do
    (_, resp) <- concurrently
        (sendBinaryData conn $ Version supportedWireVersions)
        (receiveData conn)
    case resp of
        Version remoteversions -> case reverse (intersect (sort supportedWireVersions) (sort remoteversions)) of
            (v:_) -> return v
            [] -> protocolError conn $
                "Unable to negotiate protocol Version. I support: " ++
                show supportedWireVersions ++
                " They support: " ++
                show remoteversions
        _ -> protocolError conn "Protocol error, did not receive Version"

-- | Modes of operation that can be requested for a websocket connection.
data Mode
    = InitMode T.Text
    -- ^ Text is unused, but reserved for expansion
    | ConnectMode T.Text
    -- ^ Text specifies the SessionID to connect to
    deriving (Show, Eq, Generic)

instance FromJSON Mode
-- NOTE(review): nothing follows this instance head in the portion of the
-- file reviewed here; confirm whether a body is intended.
instance ToJSON Mode where