From 3adfdf1ae27cd4b6419ce5be14ffb3712339065a Mon Sep 17 00:00:00 2001 From: Joey Hess Date: Sat, 22 Apr 2017 15:14:03 -0400 Subject: add framing protocol for websockets --- WebSockets.hs | 146 +++++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 104 insertions(+), 42 deletions(-) (limited to 'WebSockets.hs') diff --git a/WebSockets.hs b/WebSockets.hs index 0ec0c10..395a707 100644 --- a/WebSockets.hs +++ b/WebSockets.hs @@ -1,19 +1,33 @@ {-# LANGUAGE OverloadedStrings, DeriveGeneric, GeneralizedNewtypeDeriving, FlexibleContexts, FlexibleInstances #-} -module WebSockets where +module WebSockets ( + runClientApp, + clientApp, + protocolError, + relayFromSocket, + relayToSocket, + negotiateWireVersion, + WireProtocol(..), + Mode(..), + ClientSends(..), + ServerSends(..), +) where import Types -import Serialization import SessionID import Network.WebSockets hiding (Message) import Control.Concurrent.STM import Control.Concurrent.Async import Control.Exception +import GHC.Generics (Generic) +import Data.Aeson (FromJSON, ToJSON) import qualified Data.Aeson +import qualified Data.Binary import qualified Data.Text as T +import qualified Data.ByteString.Lazy as L import Data.List -import Data.Maybe +import Data.Monoid runClientApp :: ClientApp () -> IO () runClientApp app = catchJust catchconnclosed @@ -25,41 +39,52 @@ runClientApp app = catchJust catchconnclosed catchconnclosed ConnectionClosed = Just () catchconnclosed _ = Nothing --- | Make a client that sends and receives Messages over a websocket. +-- | Make a client that sends and receives LogMessages over a websocket. clientApp - :: (WebSocketsData (Message sent), WebSocketsData (Message received)) - => Mode - -> (TChan (Message sent) -> TChan (Message received) -> SessionID -> IO a) + :: Mode + -> (sent -> LogMessage) + -> (LogMessage -> Maybe received) + -> (TChan sent -> TChan received -> SessionID -> IO a) -> ClientApp a -clientApp mode a conn = do +clientApp mode mksent filterreceived a conn = do _v <- negotiateWireVersion conn - sendMode conn mode - sid <- receiveData conn - bracket setup cleanup (go sid) + sendBinaryData conn (SelectMode ClientSends mode) + r <- receiveData conn + case r of + Ready ServerSends sid -> bracket setup cleanup (go sid) + WireProtocolError e -> error e + _ -> protocolError conn "Did not get expected Ready message from server" where setup = do schan <- newTChanIO rchan <- newTChanIO - sthread <- async $ relayFromSocket conn $ - atomically . writeTChan rchan + sthread <- async $ relayFromSocket conn $ \v -> + case filterreceived v of + Nothing -> return () + Just r -> atomically $ writeTChan rchan r rthread <- async $ relayToSocket conn $ - Just <$> atomically (readTChan schan) + Just . mksent <$> atomically (readTChan schan) return (schan, rchan, sthread, rthread) cleanup (_, _, sthread, rthread) = do - sendClose conn ("done" :: T.Text) + sendBinaryData conn Done cancel sthread cancel rthread go sid (schan, rchan, _, _) = a schan rchan sid -relayFromSocket :: WebSocketsData (Message received) => Connection -> (Message received -> IO ()) -> IO () +relayFromSocket :: Connection -> (LogMessage -> IO ()) -> IO () relayFromSocket conn sender = go where go = do - msg <- receiveData conn - sender msg - go + r <- receiveData conn + case r of + LogMessage msg -> do + sender msg + go + Done -> return () + WireProtocolError e -> protocolError conn e + _ -> protocolError conn "Protocol error" -relayToSocket :: WebSocketsData (Message sent) => Connection -> (IO (Maybe (Message sent))) -> IO () +relayToSocket :: Connection -> (IO (Maybe LogMessage)) -> IO () relayToSocket conn getter = go where go = do @@ -67,20 +92,65 @@ relayToSocket conn getter = go case mmsg of Nothing -> go Just msg -> do - sendBinaryData conn msg + sendBinaryData conn (LogMessage msg) go +-- | Framing protocol used over a websocket connection. +-- +-- This is an asynchronous protocol; both client and server can send +-- messages at the same time. +-- +-- Messages that only one can send are tagged with ClientSends or +-- ServerSends. +data WireProtocol + = Version [WireVersion] + | SelectMode ClientSends Mode + | Ready ServerSends SessionID + | LogMessage LogMessage + | Done + | WireProtocolError String + +data ServerSends = ServerSends +data ClientSends = ClientSends + +instance WebSocketsData WireProtocol where + toLazyByteString (Version v) = "V" <> Data.Aeson.encode v + toLazyByteString (SelectMode _ m) = "M" <> Data.Aeson.encode m + toLazyByteString (Ready _ sid) = "R" <> Data.Aeson.encode sid + toLazyByteString (LogMessage msg) = "L" <> Data.Binary.encode msg + toLazyByteString Done = "D" + toLazyByteString (WireProtocolError s) = "E" <> Data.Aeson.encode s + fromLazyByteString b = case L.splitAt 1 b of + ("V", v) -> maybe (WireProtocolError "invalid JSON in Version") + Version + (Data.Aeson.decode v) + ("M", m) -> maybe (WireProtocolError "invalid JSON in Mode") + (SelectMode ClientSends) + (Data.Aeson.decode m) + ("R", sid) -> maybe (WireProtocolError "invalid JSON in SessionID") + (Ready ServerSends) + (Data.Aeson.decode sid) + ("L", l) -> case Data.Binary.decodeOrFail l of + Left (_, _, err) -> WireProtocolError $ "Binary decode error: " ++ err + Right (_, _, msg) -> LogMessage msg + ("D", "") -> Done + ("E", s) -> maybe (WireProtocolError "invalid JSON in WireProtocolError") + WireProtocolError + (Data.Aeson.decode s) + _ -> WireProtocolError "received unknown websocket message" + +protocolError :: Connection -> String -> IO a +protocolError conn err = do + sendBinaryData conn (WireProtocolError err) + sendClose conn Done + error err + newtype WireVersion = WireVersion T.Text deriving (Show, Eq, Generic, Ord) instance FromJSON WireVersion instance ToJSON WireVersion -instance WebSocketsData [WireVersion] where - -- fromDataMessage = fromLazyByteString . fromDataMessage - fromLazyByteString = fromMaybe (error "Unknown WireVersion") . Data.Aeson.decode - toLazyByteString = Data.Aeson.encode - supportedWireVersions :: [WireVersion] supportedWireVersions = [WireVersion "1"] @@ -88,12 +158,15 @@ supportedWireVersions = [WireVersion "1"] -- the remote side. The highest version present in both lists will be used. negotiateWireVersion :: Connection -> IO WireVersion negotiateWireVersion conn = do - (_, remoteversions) <- concurrently - (sendTextData conn supportedWireVersions) + (_, resp) <- concurrently + (sendBinaryData conn $ Version supportedWireVersions) (receiveData conn) - case reverse (intersect (sort supportedWireVersions) (sort remoteversions)) of - (v:_) -> return v - [] -> error $ "Unable to negotiate a WireVersion. I support: " ++ show supportedWireVersions ++ " They support: " ++ show remoteversions + case resp of + Version remoteversions -> case reverse (intersect (sort supportedWireVersions) (sort remoteversions)) of + (v:_) -> return v + [] -> protocolError conn $ + "Unable to negotiate protocol Version. I support: " ++ show supportedWireVersions ++ " They support: " ++ show remoteversions + _ -> protocolError conn "Protocol error, did not receive Version" -- | Modes of operation that can be requested for a websocket connection. data Mode @@ -103,14 +176,3 @@ data Mode instance FromJSON Mode instance ToJSON Mode where - -instance WebSocketsData Mode where - -- fromDataMessage = fromLazyByteString . fromDataMessage - fromLazyByteString = fromMaybe (error "Unknown Mode") . Data.Aeson.decode - toLazyByteString = Data.Aeson.encode - -sendMode :: Connection -> Mode -> IO () -sendMode = sendTextData - -getMode :: Connection -> IO Mode -getMode = receiveData -- cgit v1.2.3