From 3adfdf1ae27cd4b6419ce5be14ffb3712339065a Mon Sep 17 00:00:00 2001 From: Joey Hess Date: Sat, 22 Apr 2017 15:14:03 -0400 Subject: add framing protocol for websockets --- CmdLine.hs | 8 +-- Log.hs | 15 ------ Role/Developer.hs | 15 ++++-- Role/Downloader.hs | 6 +-- Role/User.hs | 6 ++- Role/Watcher.hs | 6 +-- Server.hs | 86 +++++++++++++++++++------------ SessionID.hs | 21 +++----- TODO | 6 ++- Types.hs | 32 ++++++------ WebSockets.hs | 146 ++++++++++++++++++++++++++++++++++++++--------------- 11 files changed, 212 insertions(+), 135 deletions(-) diff --git a/CmdLine.hs b/CmdLine.hs index 663c63e..f00f0be 100644 --- a/CmdLine.hs +++ b/CmdLine.hs @@ -21,16 +21,18 @@ data UserOpts = UserOpts { cmdToRun :: Maybe (String, [String]) } +type UrlString = String + data DeveloperOpts = DeveloperOpts - { debugUrl :: String + { debugUrl :: UrlString } data DownloadOpts = DownloadOpts - { downloadUrl :: String + { downloadUrl :: UrlString } data WatchOpts = WatchOpts - { watchUrl :: String + { watchUrl :: UrlString } data GraphvizOpts = GraphvizOpts diff --git a/Log.hs b/Log.hs index eb7bf3c..948ab19 100644 --- a/Log.hs +++ b/Log.hs @@ -34,21 +34,6 @@ instance DataSize Log where instance ToJSON Log instance FromJSON Log -data LogMessage - = User (Message Seen) - | Developer (Message Entered) - deriving (Show, Generic) - -instance DataSize LogMessage where - dataSize (User a) = dataSize a - dataSize (Developer a) = dataSize a - -instance ToJSON LogMessage where - toJSON = genericToJSON sumOptions - toEncoding = genericToEncoding sumOptions -instance FromJSON LogMessage where - parseJSON = genericParseJSON sumOptions - mkLog :: LogMessage -> POSIXTime -> Log mkLog m now = Log { loggedMessage = m diff --git a/Role/Developer.hs b/Role/Developer.hs index 89f6ea9..4248591 100644 --- a/Role/Developer.hs +++ b/Role/Developer.hs @@ -19,14 +19,23 @@ import qualified Data.Text as T import Data.List run :: DeveloperOpts -> IO () -run os = runClientApp $ clientApp (ConnectMode (T.pack (debugUrl os))) developer +run = run' developer . debugUrl + +run' :: (TChan (Message Entered) -> TChan (Message Seen) -> SessionID -> IO ()) -> UrlString -> IO () +run' runner url = runClientApp $ clientApp connect Developer userMessages runner + where + connect = ConnectMode (T.pack url) + +userMessages :: LogMessage -> Maybe (Message Seen) +userMessages (User m) = Just m +userMessages (Developer _) = Nothing developer :: TChan (Message Entered) -> TChan (Message Seen) -> SessionID -> IO () -developer ichan ochan _ = inRawMode $ withLogger "debug-me-developer.log" $ \logger -> do +developer ichan ochan _ = withLogger "debug-me-developer.log" $ \logger -> do devstate <- processSessionStart ochan logger ok <- authUser ichan ochan devstate logger if ok - then do + then inRawMode $ do _ <- sendTtyInput ichan devstate logger `concurrently` sendTtyOutput ochan devstate logger return () diff --git a/Role/Downloader.hs b/Role/Downloader.hs index 3981227..d327c8c 100644 --- a/Role/Downloader.hs +++ b/Role/Downloader.hs @@ -3,15 +3,13 @@ module Role.Downloader where import Types import Log import CmdLine -import WebSockets import SessionID import Control.Concurrent.STM -import qualified Data.Text as T -import Role.Developer (processSessionStart, getUserMessage, Output(..)) +import Role.Developer (run', processSessionStart, getUserMessage, Output(..)) run :: DownloadOpts -> IO () -run os = runClientApp $ clientApp (ConnectMode (T.pack (downloadUrl os))) downloader +run = run' downloader . downloadUrl downloader :: TChan (Message Entered) -> TChan (Message Seen) -> SessionID -> IO () downloader _ichan ochan sid = do diff --git a/Role/User.hs b/Role/User.hs index daaaa71..1d1702e 100644 --- a/Role/User.hs +++ b/Role/User.hs @@ -33,7 +33,7 @@ run os = do putStr "Connecting to debug-me server..." hFlush stdout esv <- newEmptyTMVarIO - runClientApp $ clientApp (InitMode mempty) $ \ichan ochan sid -> do + runClientApp $ clientApp (InitMode mempty) User developerMessages $ \ichan ochan sid -> do let url = sessionIDUrl sid "localhost" 8081 putStrLn "" putStrLn "Others can connect to this session and help you debug by running:" @@ -47,6 +47,10 @@ run os = do sessionDone fromMaybe (ExitFailure 101) <$> atomically (tryReadTMVar esv) +developerMessages :: LogMessage -> Maybe (Message Entered) +developerMessages (Developer m) = Just m +developerMessages (User _) = Nothing + shellCommand :: UserOpts -> IO (String, [String]) shellCommand os = case cmdToRun os of Just v -> return v diff --git a/Role/Watcher.hs b/Role/Watcher.hs index fddd59f..620733c 100644 --- a/Role/Watcher.hs +++ b/Role/Watcher.hs @@ -4,15 +4,13 @@ import Types import Log import Pty import CmdLine -import WebSockets import SessionID import Control.Concurrent.STM -import qualified Data.Text as T -import Role.Developer (processSessionStart, getUserMessage, emitOutput) +import Role.Developer (run', processSessionStart, getUserMessage, emitOutput) run :: WatchOpts -> IO () -run os = runClientApp $ clientApp (ConnectMode (T.pack (watchUrl os))) watcher +run = run' watcher . watchUrl watcher :: TChan (Message Entered) -> TChan (Message Seen) -> SessionID -> IO () watcher _ichan ochan _ = inRawMode $ do diff --git a/Server.hs b/Server.hs index c2589f1..c1a302a 100644 --- a/Server.hs +++ b/Server.hs @@ -2,6 +2,7 @@ module Server where +import Types import CmdLine import WebSockets import SessionID @@ -22,6 +23,7 @@ import qualified Data.Map as M import qualified Data.Text as T import Data.Time.Clock.POSIX import System.IO +import System.Directory type ServerState = M.Map SessionID Session @@ -92,16 +94,18 @@ websocketApp :: ServerOpts -> TVar ServerState -> WS.ServerApp websocketApp o ssv pending_conn = do conn <- WS.acceptRequest pending_conn _v <- negotiateWireVersion conn - theirmode <- getMode conn - case theirmode of - InitMode _ -> user o ssv conn - ConnectMode t -> case mkSessionID (T.unpack t) of - Nothing -> error "Invalid session id!" - Just sid -> developer o ssv sid conn + r <- receiveData conn + case r of + SelectMode ClientSends (InitMode _) -> user o ssv conn + SelectMode ClientSends (ConnectMode t) -> + case mkSessionID (T.unpack t) of + Nothing -> protocolError conn "Invalid session id!" + Just sid -> developer o ssv sid conn + _ -> protocolError conn "Expected SelectMode" user :: ServerOpts -> TVar ServerState -> WS.Connection -> IO () user o ssv conn = withSessionID (serverDirectory o) $ \(loghv, sid) -> do - sendTextData conn sid + sendBinaryData conn (Ready ServerSends sid) bracket (setup sid loghv) (cleanup sid) go where setup sid loghv = do @@ -109,28 +113,32 @@ user o ssv conn = withSessionID (serverDirectory o) $ \(loghv, sid) -> do atomically $ modifyTVar' ssv $ M.insert sid session return session - cleanup sid session = atomically $ do - closeSession session - modifyTVar' ssv $ M.delete sid + cleanup sid session = do + atomically $ do + closeSession session + modifyTVar' ssv $ M.delete sid go session = do userchan <- atomically $ listenSession session _ <- relaytouser userchan - `concurrently` relayfromuser session + `race` relayfromuser session return () -- Relay all messages from the user's websocket to the -- session broadcast channel. + -- (The user is allowed to send Developer messages too.. perhaps + -- they got them from a developer connected to them some other + -- way.) relayfromuser session = relayFromSocket conn $ \msg -> do - l <- mkLog (User msg) <$> getPOSIXTime + l <- mkLog msg <$> getPOSIXTime writeSession session l - -- Relay developer messages from the channel to the user's websocket. + -- Relay Developer messages from the channel to the user's websocket. relaytouser userchan = relayToSocket conn $ do v <- atomically $ readTMChan userchan return $ case v of Just l -> case loggedMessage l of - Developer m -> Just m + Developer m -> Just (Developer m) User _ -> Nothing Nothing -> Nothing @@ -139,29 +147,39 @@ developer o ssv sid conn = bracket setup cleanup go where setup = atomically $ M.lookup sid <$> readTVar ssv cleanup _ = return () - go Nothing = error "Invalid session id!" + go Nothing = do + exists <- doesFileExist $ + sessionLogFile (serverDirectory o) sid + if exists + then do + sendBinaryData conn (Ready ServerSends sid) + replayBacklog o sid conn + sendBinaryData conn Done + else protocolError conn "Unknown session ID" go (Just session) = do - -- Sending the SessionID to the developer is redundant, but - -- is done to make the protocol startup sequence the same as - -- it is for the user. - sendTextData conn sid - devchan <- replayBacklog o sid session conn + sendBinaryData conn (Ready ServerSends sid) + devchan <- replayBacklogAndListen o sid session conn _ <- relayfromdeveloper session `concurrently` relaytodeveloper devchan return () - -- Relay all messages from the developer's websocket to the - -- broadcast channel. - relayfromdeveloper session = relayFromSocket conn $ \msg -> do - l <- mkLog (Developer msg) <$> getPOSIXTime - writeSession session l + -- Relay all Developer amessages from the developer's websocket + -- to the broadcast channel. + relayfromdeveloper session = relayFromSocket conn $ \msg -> case msg of + Developer _ -> do + l <- mkLog msg <$> getPOSIXTime + writeSession session l + User _ -> return () -- developer cannot send User messages -- Relay user messages from the channel to the developer's websocket. relaytodeveloper devchan = relayToSocket conn $ do v <- atomically $ readTMChan devchan return $ case v of Just l -> case loggedMessage l of - User m -> Just m + User m -> Just (User m) + -- TODO: Relay messages from other + -- developers, without looping back + -- the developer's own messages. Developer _ -> Nothing Nothing -> Nothing @@ -174,13 +192,15 @@ developer o ssv sid conn = bracket setup cleanup go -- -- Note that the session may appear to freeze for other users while -- this is running. -replayBacklog :: ServerOpts -> SessionID -> Session -> WS.Connection -> IO (TMChan Log) -replayBacklog o sid session conn = preventWriteWhile session o sid $ do +replayBacklogAndListen :: ServerOpts -> SessionID -> Session -> WS.Connection -> IO (TMChan Log) +replayBacklogAndListen o sid session conn = + preventWriteWhile session o sid $ do + replayBacklog o sid conn + atomically $ listenSession session + +replayBacklog :: ServerOpts -> SessionID -> WS.Connection -> IO () +replayBacklog o sid conn = do ls <- streamLog (sessionLogFile (serverDirectory o) sid) forM_ ls $ \l -> case loggedMessage <$> l of - Right (User m) -> sendBinaryData conn m - Right (Developer _) -> return () - -- This should not happen, since writes to the log - -- are blocked. Unless there's a disk error.. + Right m -> sendBinaryData conn (LogMessage m) Left _ -> return () - atomically $ listenSession session diff --git a/SessionID.hs b/SessionID.hs index 8bf8f7d..449f58c 100644 --- a/SessionID.hs +++ b/SessionID.hs @@ -14,9 +14,6 @@ import System.FilePath import System.IO import System.Directory import Network.Wai.Handler.Warp (Port) -import Network.WebSockets hiding (Message) -import qualified Data.Aeson -import Data.Maybe import Data.List import Data.UUID import Data.UUID.V4 @@ -28,17 +25,15 @@ import Control.Exception newtype SessionID = SessionID FilePath deriving (Show, Eq, Ord, Generic) +-- | Custom JSON deserialization so we can check smart constructor +-- to verify it's legal. +instance FromJSON SessionID where + parseJSON v = verify =<< genericParseJSON defaultOptions v + where + verify (SessionID unverified) = + maybe (fail "illegal SessionID") return + (mkSessionID unverified) instance ToJSON SessionID -instance FromJSON SessionID - -instance WebSocketsData SessionID where - -- fromDataMessage = fromLazyByteString . fromDataMessage - fromLazyByteString b = - -- Down't trust a legal SessionID to be deserialized; - -- use smart constructor to verify it's legal. - let SessionID unverified = fromMaybe (error "bad SessionID serialization") (Data.Aeson.decode b) - in fromMaybe (error "illegal SessionID") (mkSessionID unverified) - toLazyByteString = Data.Aeson.encode -- | Smart constructor that enforces legal SessionID contents. -- diff --git a/TODO b/TODO index 8affef0..2562b22 100644 --- a/TODO +++ b/TODO @@ -22,12 +22,16 @@ multiple developers, as each time a developer gets an Activity Seen, they can update their state to use the Activity Entered that it points to. +* When Role.Developer.processSessionStart throws an error, it's caught + somewhere, and the process exits quietly with exit code 0. +* The "debug me session is done" is only shown to the user; + it ought to be included in the session log. * --watch and --download only get Seen messages, not Entered messages, because the server does not send Developer messages to them. To fix, need a way to avoid looping Entered messages sent by a developer back to themselves. * Improve error message when developer fails to connect due to the session - ID being invalid or expored. + ID being invalid or expired. * Use protobuf for serialization, to make non-haskell implementations easier? * Leave the prevMessage out of Activity serialization to save BW. diff --git a/Types.hs b/Types.hs index 76a30a2..04855f4 100644 --- a/Types.hs +++ b/Types.hs @@ -15,9 +15,6 @@ module Types ( import Val import Memory import Serialization -import Network.WebSockets (WebSocketsData(..)) -import qualified Data.Binary -import qualified Data.ByteString.Lazy as L -- | Things that the developer sees. data Seen = Seen @@ -139,6 +136,22 @@ newtype GpgSig = GpgSig Val instance DataSize GpgSig where dataSize (GpgSig s) = dataSize s +data LogMessage + = User (Message Seen) + | Developer (Message Entered) + deriving (Show, Generic) + +instance DataSize LogMessage where + dataSize (User a) = dataSize a + dataSize (Developer a) = dataSize a + +instance Binary LogMessage +instance ToJSON LogMessage where + toJSON = genericToJSON sumOptions + toEncoding = genericToEncoding sumOptions +instance FromJSON LogMessage where + parseJSON = genericParseJSON sumOptions + instance Binary Seen instance ToJSON Seen instance FromJSON Seen @@ -194,16 +207,3 @@ instance ToJSON ControlAction where toEncoding = genericToEncoding sumOptions instance FromJSON ControlAction where parseJSON = genericParseJSON sumOptions - -instance WebSocketsData (Message Seen) where - fromLazyByteString = decodeBinaryMessage - toLazyByteString = Data.Binary.encode - -instance WebSocketsData (Message Entered) where - fromLazyByteString = decodeBinaryMessage - toLazyByteString = Data.Binary.encode - -decodeBinaryMessage :: Binary (Message a) => L.ByteString -> Message a -decodeBinaryMessage b = case Data.Binary.decodeOrFail b of - Right (_, _, msg) -> msg - Left (_, _, err) -> error $ "Binary decode error: " ++ err diff --git a/WebSockets.hs b/WebSockets.hs index 0ec0c10..395a707 100644 --- a/WebSockets.hs +++ b/WebSockets.hs @@ -1,19 +1,33 @@ {-# LANGUAGE OverloadedStrings, DeriveGeneric, GeneralizedNewtypeDeriving, FlexibleContexts, FlexibleInstances #-} -module WebSockets where +module WebSockets ( + runClientApp, + clientApp, + protocolError, + relayFromSocket, + relayToSocket, + negotiateWireVersion, + WireProtocol(..), + Mode(..), + ClientSends(..), + ServerSends(..), +) where import Types -import Serialization import SessionID import Network.WebSockets hiding (Message) import Control.Concurrent.STM import Control.Concurrent.Async import Control.Exception +import GHC.Generics (Generic) +import Data.Aeson (FromJSON, ToJSON) import qualified Data.Aeson +import qualified Data.Binary import qualified Data.Text as T +import qualified Data.ByteString.Lazy as L import Data.List -import Data.Maybe +import Data.Monoid runClientApp :: ClientApp () -> IO () runClientApp app = catchJust catchconnclosed @@ -25,41 +39,52 @@ runClientApp app = catchJust catchconnclosed catchconnclosed ConnectionClosed = Just () catchconnclosed _ = Nothing --- | Make a client that sends and receives Messages over a websocket. +-- | Make a client that sends and receives LogMessages over a websocket. clientApp - :: (WebSocketsData (Message sent), WebSocketsData (Message received)) - => Mode - -> (TChan (Message sent) -> TChan (Message received) -> SessionID -> IO a) + :: Mode + -> (sent -> LogMessage) + -> (LogMessage -> Maybe received) + -> (TChan sent -> TChan received -> SessionID -> IO a) -> ClientApp a -clientApp mode a conn = do +clientApp mode mksent filterreceived a conn = do _v <- negotiateWireVersion conn - sendMode conn mode - sid <- receiveData conn - bracket setup cleanup (go sid) + sendBinaryData conn (SelectMode ClientSends mode) + r <- receiveData conn + case r of + Ready ServerSends sid -> bracket setup cleanup (go sid) + WireProtocolError e -> error e + _ -> protocolError conn "Did not get expected Ready message from server" where setup = do schan <- newTChanIO rchan <- newTChanIO - sthread <- async $ relayFromSocket conn $ - atomically . writeTChan rchan + sthread <- async $ relayFromSocket conn $ \v -> + case filterreceived v of + Nothing -> return () + Just r -> atomically $ writeTChan rchan r rthread <- async $ relayToSocket conn $ - Just <$> atomically (readTChan schan) + Just . mksent <$> atomically (readTChan schan) return (schan, rchan, sthread, rthread) cleanup (_, _, sthread, rthread) = do - sendClose conn ("done" :: T.Text) + sendBinaryData conn Done cancel sthread cancel rthread go sid (schan, rchan, _, _) = a schan rchan sid -relayFromSocket :: WebSocketsData (Message received) => Connection -> (Message received -> IO ()) -> IO () +relayFromSocket :: Connection -> (LogMessage -> IO ()) -> IO () relayFromSocket conn sender = go where go = do - msg <- receiveData conn - sender msg - go + r <- receiveData conn + case r of + LogMessage msg -> do + sender msg + go + Done -> return () + WireProtocolError e -> protocolError conn e + _ -> protocolError conn "Protocol error" -relayToSocket :: WebSocketsData (Message sent) => Connection -> (IO (Maybe (Message sent))) -> IO () +relayToSocket :: Connection -> (IO (Maybe LogMessage)) -> IO () relayToSocket conn getter = go where go = do @@ -67,20 +92,65 @@ relayToSocket conn getter = go case mmsg of Nothing -> go Just msg -> do - sendBinaryData conn msg + sendBinaryData conn (LogMessage msg) go +-- | Framing protocol used over a websocket connection. +-- +-- This is an asynchronous protocol; both client and server can send +-- messages at the same time. +-- +-- Messages that only one can send are tagged with ClientSends or +-- ServerSends. +data WireProtocol + = Version [WireVersion] + | SelectMode ClientSends Mode + | Ready ServerSends SessionID + | LogMessage LogMessage + | Done + | WireProtocolError String + +data ServerSends = ServerSends +data ClientSends = ClientSends + +instance WebSocketsData WireProtocol where + toLazyByteString (Version v) = "V" <> Data.Aeson.encode v + toLazyByteString (SelectMode _ m) = "M" <> Data.Aeson.encode m + toLazyByteString (Ready _ sid) = "R" <> Data.Aeson.encode sid + toLazyByteString (LogMessage msg) = "L" <> Data.Binary.encode msg + toLazyByteString Done = "D" + toLazyByteString (WireProtocolError s) = "E" <> Data.Aeson.encode s + fromLazyByteString b = case L.splitAt 1 b of + ("V", v) -> maybe (WireProtocolError "invalid JSON in Version") + Version + (Data.Aeson.decode v) + ("M", m) -> maybe (WireProtocolError "invalid JSON in Mode") + (SelectMode ClientSends) + (Data.Aeson.decode m) + ("R", sid) -> maybe (WireProtocolError "invalid JSON in SessionID") + (Ready ServerSends) + (Data.Aeson.decode sid) + ("L", l) -> case Data.Binary.decodeOrFail l of + Left (_, _, err) -> WireProtocolError $ "Binary decode error: " ++ err + Right (_, _, msg) -> LogMessage msg + ("D", "") -> Done + ("E", s) -> maybe (WireProtocolError "invalid JSON in WireProtocolError") + WireProtocolError + (Data.Aeson.decode s) + _ -> WireProtocolError "received unknown websocket message" + +protocolError :: Connection -> String -> IO a +protocolError conn err = do + sendBinaryData conn (WireProtocolError err) + sendClose conn Done + error err + newtype WireVersion = WireVersion T.Text deriving (Show, Eq, Generic, Ord) instance FromJSON WireVersion instance ToJSON WireVersion -instance WebSocketsData [WireVersion] where - -- fromDataMessage = fromLazyByteString . fromDataMessage - fromLazyByteString = fromMaybe (error "Unknown WireVersion") . Data.Aeson.decode - toLazyByteString = Data.Aeson.encode - supportedWireVersions :: [WireVersion] supportedWireVersions = [WireVersion "1"] @@ -88,12 +158,15 @@ supportedWireVersions = [WireVersion "1"] -- the remote side. The highest version present in both lists will be used. negotiateWireVersion :: Connection -> IO WireVersion negotiateWireVersion conn = do - (_, remoteversions) <- concurrently - (sendTextData conn supportedWireVersions) + (_, resp) <- concurrently + (sendBinaryData conn $ Version supportedWireVersions) (receiveData conn) - case reverse (intersect (sort supportedWireVersions) (sort remoteversions)) of - (v:_) -> return v - [] -> error $ "Unable to negotiate a WireVersion. I support: " ++ show supportedWireVersions ++ " They support: " ++ show remoteversions + case resp of + Version remoteversions -> case reverse (intersect (sort supportedWireVersions) (sort remoteversions)) of + (v:_) -> return v + [] -> protocolError conn $ + "Unable to negotiate protocol Version. I support: " ++ show supportedWireVersions ++ " They support: " ++ show remoteversions + _ -> protocolError conn "Protocol error, did not receive Version" -- | Modes of operation that can be requested for a websocket connection. data Mode @@ -103,14 +176,3 @@ data Mode instance FromJSON Mode instance ToJSON Mode where - -instance WebSocketsData Mode where - -- fromDataMessage = fromLazyByteString . fromDataMessage - fromLazyByteString = fromMaybe (error "Unknown Mode") . Data.Aeson.decode - toLazyByteString = Data.Aeson.encode - -sendMode :: Connection -> Mode -> IO () -sendMode = sendTextData - -getMode :: Connection -> IO Mode -getMode = receiveData -- cgit v1.2.3