diff options
author | Joey Hess <joeyh@joeyh.name> | 2017-04-24 15:24:52 -0400 |
---|---|---|
committer | Joey Hess <joeyh@joeyh.name> | 2017-04-24 16:03:46 -0400 |
commit | 9a8d3bc531647d8b96e66e6daabf2176a1df4afb (patch) | |
tree | 5f198a02e59fbec20b38ad347db37cad97b3ed0d | |
parent | 7b2bcfab392d387b89c3c251f0c9a8b9c0203aa8 (diff) | |
download | debug-me-9a8d3bc531647d8b96e66e6daabf2176a1df4afb.tar.gz |
switch to TMChans so they can be closed when a connection is Done
-rw-r--r-- | Role/Developer.hs | 63 | ||||
-rw-r--r-- | Role/Downloader.hs | 18 | ||||
-rw-r--r-- | Role/User.hs | 41 | ||||
-rw-r--r-- | Role/Watcher.hs | 12 | ||||
-rw-r--r-- | WebSockets.hs | 21 |
5 files changed, 92 insertions, 63 deletions
diff --git a/Role/Developer.hs b/Role/Developer.hs index 0b8fdd9..ffba5c4 100644 --- a/Role/Developer.hs +++ b/Role/Developer.hs @@ -13,16 +13,18 @@ import Pty import Control.Concurrent.Async import Control.Concurrent.STM +import Control.Concurrent.STM.TMChan import System.IO import qualified Data.ByteString as B import qualified Data.Text as T import Data.List +import Data.Maybe import Control.Monad run :: DeveloperOpts -> IO () run = run' developer . debugUrl -run' :: (TChan (Message Entered) -> TChan (Message Seen) -> SessionID -> IO ()) -> UrlString -> IO () +run' :: (TMChan (Message Entered) -> TMChan (Message Seen) -> SessionID -> IO ()) -> UrlString -> IO () run' runner url = void $ runClientApp app where connect = ConnectMode (T.pack url) @@ -32,7 +34,7 @@ userMessages :: LogMessage -> Maybe (Message Seen) userMessages (User m) = Just m userMessages (Developer _) = Nothing -developer :: TChan (Message Entered) -> TChan (Message Seen) -> SessionID -> IO () +developer :: TMChan (Message Entered) -> TMChan (Message Seen) -> SessionID -> IO () developer ichan ochan _ = withLogger "debug-me-developer.log" $ \logger -> do devstate <- processSessionStart ochan logger ok <- authUser ichan ochan devstate logger @@ -53,8 +55,8 @@ data DeveloperState = DeveloperState , developerSigVerifier :: SigVerifier } --- | Read things typed by the developer, and forward them to the TChan. -sendTtyInput :: TChan (Message Entered) -> TVar DeveloperState -> Logger -> IO () +-- | Read things typed by the developer, and forward them to the TMChan. +sendTtyInput :: TMChan (Message Entered) -> TVar DeveloperState -> Logger -> IO () sendTtyInput ichan devstate logger = go where go = do @@ -76,7 +78,7 @@ sendTtyInput ichan devstate logger = go } let act = mkSigned (developerSessionKey ds) $ Activity entered (Just $ lastActivity ds) - writeTChan ichan (ActivityMessage act) + writeTMChan ichan (ActivityMessage act) let acth = hash act let ds' = ds { sentSince = sentSince ds ++ [b] @@ -88,31 +90,35 @@ sendTtyInput ichan devstate logger = go logger $ Developer $ ActivityMessage act go --- | Read activity from the TChan and display it to the developer. -sendTtyOutput :: TChan (Message Seen) -> TVar DeveloperState -> Logger -> IO () +-- | Read activity from the TMChan and display it to the developer. +sendTtyOutput :: TMChan (Message Seen) -> TVar DeveloperState -> Logger -> IO () sendTtyOutput ochan devstate logger = go where go = do - (o, msg) <- atomically $ getUserMessage ochan devstate - logger $ User msg - emitOutput o - go + v <- atomically $ getUserMessage ochan devstate + case v of + Nothing -> return () + Just (o, msg) -> do + logger $ User msg + emitOutput o + go -- | Present our session key to the user. -- Wait for them to accept or reject it, while displaying any Seen data -- in the meantime. -authUser :: TChan (Message Entered) -> TChan (Message Seen) -> TVar DeveloperState -> Logger -> IO Bool +authUser :: TMChan (Message Entered) -> TMChan (Message Seen) -> TVar DeveloperState -> Logger -> IO Bool authUser ichan ochan devstate logger = do ds <- atomically $ readTVar devstate pk <- myPublicKey (developerSessionKey ds) let msg = ControlMessage $ mkSigned (developerSessionKey ds) (Control (SessionKey pk)) - atomically $ writeTChan ichan msg + atomically $ writeTMChan ichan msg logger $ Developer msg waitresp pk where waitresp pk = do - (o, msg) <- atomically $ getUserMessage ochan devstate + (o, msg) <- fromMaybe (error "No response from server to our session key") + <$> atomically (getUserMessage ochan devstate) logger $ User msg emitOutput o case o of @@ -142,16 +148,19 @@ emitOutput (GotControl _) = -- | Get messages from user, check their signature, and make sure that they -- are properly chained from past messages, before returning. -getUserMessage :: TChan (Message Seen) -> TVar DeveloperState -> STM (Output, Message Seen) +getUserMessage :: TMChan (Message Seen) -> TVar DeveloperState -> STM (Maybe (Output, Message Seen)) getUserMessage ochan devstate = do - msg <- readTChan ochan - ds <- readTVar devstate - -- Check signature before doing anything else. - if verifySigned (developerSigVerifier ds) msg - then do - o <- process ds msg - return (o, msg) - else getUserMessage ochan devstate + mmsg <- readTMChan ochan + case mmsg of + Nothing -> return Nothing + Just msg -> do + ds <- readTVar devstate + -- Check signature before doing anything else. + if verifySigned (developerSigVerifier ds) msg + then do + o <- process ds msg + return (Just (o, msg)) + else getUserMessage ochan devstate where process ds (ActivityMessage act@(Activity (Seen (Val b)) _ _)) = do let (legal, ds') = isLegalSeen act ds @@ -224,9 +233,10 @@ isLegalSeen act@(Activity (Seen (Val b)) (Just hp) _) ds -- | Start by reading the initial two messages from the user side, -- their session key and the startup message. -processSessionStart :: TChan (Message Seen) -> Logger -> IO (TVar DeveloperState) +processSessionStart :: TMChan (Message Seen) -> Logger -> IO (TVar DeveloperState) processSessionStart ochan logger = do - sessionmsg <- atomically $ readTChan ochan + sessionmsg <- fromMaybe (error "Did not get session initialization message") + <$> atomically (readTMChan ochan) logger $ User sessionmsg sigverifier <- case sessionmsg of ControlMessage c@(Control (SessionKey pk) _) -> @@ -235,7 +245,8 @@ processSessionStart ochan logger = do then return sv else error "Badly signed session initialization message" _ -> error $ "Unexpected session initialization message: " ++ show sessionmsg - startmsg <- atomically $ readTChan ochan + startmsg <- fromMaybe (error "Did not get session startup message") + <$> atomically (readTMChan ochan) logger $ User startmsg starthash <- case startmsg of ActivityMessage act@(Activity (Seen (Val b)) Nothing _) diff --git a/Role/Downloader.hs b/Role/Downloader.hs index d327c8c..55d7b63 100644 --- a/Role/Downloader.hs +++ b/Role/Downloader.hs @@ -6,12 +6,13 @@ import CmdLine import SessionID import Control.Concurrent.STM +import Control.Concurrent.STM.TMChan import Role.Developer (run', processSessionStart, getUserMessage, Output(..)) run :: DownloadOpts -> IO () run = run' downloader . downloadUrl -downloader :: TChan (Message Entered) -> TChan (Message Seen) -> SessionID -> IO () +downloader :: TMChan (Message Entered) -> TMChan (Message Seen) -> SessionID -> IO () downloader _ichan ochan sid = do let logfile = sessionLogFile "." sid putStrLn $ "Starting download to " ++ logfile @@ -21,9 +22,12 @@ downloader _ichan ochan sid = do go logger st where go logger st = do - (o, msg) <- atomically $ getUserMessage ochan st - _ <- logger $ User msg - case o of - ProtocolError e -> error ("Protocol error: " ++ e) - _ -> return () - go logger st + v <- atomically $ getUserMessage ochan st + case v of + Nothing -> return () + Just (o, msg) -> do + _ <- logger $ User msg + case o of + ProtocolError e -> error ("Protocol error: " ++ e) + _ -> return () + go logger st diff --git a/Role/User.hs b/Role/User.hs index 49c263c..fdf4e53 100644 --- a/Role/User.hs +++ b/Role/User.hs @@ -14,6 +14,7 @@ import SessionID import Control.Concurrent.Async import Control.Concurrent.STM +import Control.Concurrent.STM.TMChan import System.Process import System.Exit import qualified Data.ByteString as B @@ -63,11 +64,11 @@ data UserState = UserState , userSigVerifier :: SigVerifier } -user :: B.ByteString -> Pty -> TChan (Message Seen) -> TChan (Message Entered) -> IO () +user :: B.ByteString -> Pty -> TMChan (Message Seen) -> TMChan (Message Entered) -> IO () user starttxt p ochan ichan = withLogger "debug-me.log" $ \logger -> do -- Start by establishing our session key, and displaying the starttxt. let initialmessage msg = do - atomically $ writeTChan ochan msg + atomically $ writeTMChan ochan msg logger $ User msg sk <- genMySessionKey pk <- myPublicKey sk @@ -100,9 +101,9 @@ forwardTtyInputToPty p = do writePty p b forwardTtyInputToPty p --- | Forward things written to the Pty out the TChan, and also display +-- | Forward things written to the Pty out the TMChan, and also display -- it on their Tty. -sendPtyOutput :: Pty -> TChan (Message Seen) -> TVar UserState -> Logger -> IO () +sendPtyOutput :: Pty -> TMChan (Message Seen) -> TVar UserState -> Logger -> IO () sendPtyOutput p ochan us logger = go where go = do @@ -117,7 +118,7 @@ sendPtyOutput p ochan us logger = go go class SendableToDeveloper t where - sendDeveloper :: TChan (Message Seen) -> TVar UserState -> t -> POSIXTime -> STM (Message Seen) + sendDeveloper :: TMChan (Message Seen) -> TVar UserState -> t -> POSIXTime -> STM (Message Seen) instance SendableToDeveloper Seen where sendDeveloper ochan us seen now = do @@ -127,7 +128,7 @@ instance SendableToDeveloper Seen where mkSigned (userSessionKey st) $ Activity seen (loggedHash prev) let l = mkLog (User msg) now - writeTChan ochan msg + writeTMChan ochan msg writeTVar us $ st { backLog = l :| toList bl } return msg @@ -137,23 +138,24 @@ instance SendableToDeveloper ControlAction where let msg = ControlMessage $ mkSigned (userSessionKey st) (Control c) -- Control messages are not kept in the backlog. - writeTChan ochan msg + writeTMChan ochan msg return msg --- | Read things to be entered from the TChan, verify if they're legal, +-- | Read things to be entered from the TMChan, verify if they're legal, -- and send them to the Pty. -sendPtyInput :: TChan (Message Entered) -> TChan (Message Seen) -> Pty -> TVar UserState -> Logger -> IO () +sendPtyInput :: TMChan (Message Entered) -> TMChan (Message Seen) -> Pty -> TVar UserState -> Logger -> IO () sendPtyInput ichan ochan p us logger = go where go = do now <- getPOSIXTime v <- atomically $ getDeveloperMessage ichan ochan us now case v of - InputMessage msg@(ActivityMessage entered) -> do + Nothing -> return () + Just (InputMessage msg@(ActivityMessage entered)) -> do logger $ Developer msg writePty p $ val $ enteredData $ activity entered go - InputMessage msg@(ControlMessage (Control c _)) -> do + Just (InputMessage msg@(ControlMessage (Control c _))) -> do logger $ Developer msg case c of SessionKey pk -> do @@ -162,10 +164,10 @@ sendPtyInput ichan ochan p us logger = go Rejected r -> error $ "User side received a Rejected: " ++ show r SessionKeyAccepted _ -> error "User side received a SessionKeyAccepted" SessionKeyRejected _ -> error "User side received a SessionKeyRejected" - RejectedMessage rej -> do + Just (RejectedMessage rej) -> do logger $ User rej go - BadlySignedMessage _ -> go + Just (BadlySignedMessage _) -> go data Input = InputMessage (Message Entered) @@ -177,9 +179,14 @@ data Input -- signature of the message is only verified against the key in it), and -- make sure it's legal before returning it. If it's not legal, sends a -- Reject message. -getDeveloperMessage :: TChan (Message Entered) -> TChan (Message Seen) -> TVar UserState -> POSIXTime -> STM Input -getDeveloperMessage ichan ochan us now = do - msg <- readTChan ichan +getDeveloperMessage :: TMChan (Message Entered) -> TMChan (Message Seen) -> TVar UserState -> POSIXTime -> STM (Maybe Input) +getDeveloperMessage ichan ochan us now = maybe + (return Nothing) + (\msg -> Just <$> getDeveloperMessage' msg ochan us now) + =<< readTMChan ichan + +getDeveloperMessage' :: Message Entered -> TMChan (Message Seen) -> TVar UserState -> POSIXTime -> STM Input +getDeveloperMessage' msg ochan us now = do st <- readTVar us case msg of ControlMessage (Control (SessionKey pk) _) -> do @@ -209,7 +216,7 @@ getDeveloperMessage ichan ochan us now = do -- | Check if the public key a developer presented is one we want to use, -- and if so, add it to the userSigVerifier. -checkDeveloperPublicKey :: TChan (Message Seen) -> TVar UserState -> Logger -> PublicKey -> IO () +checkDeveloperPublicKey :: TMChan (Message Seen) -> TVar UserState -> Logger -> PublicKey -> IO () checkDeveloperPublicKey ochan us logger pk = do now <- getPOSIXTime -- TODO check gpg sig.. diff --git a/Role/Watcher.hs b/Role/Watcher.hs index 620733c..6ed1a6b 100644 --- a/Role/Watcher.hs +++ b/Role/Watcher.hs @@ -7,17 +7,21 @@ import CmdLine import SessionID import Control.Concurrent.STM +import Control.Concurrent.STM.TMChan import Role.Developer (run', processSessionStart, getUserMessage, emitOutput) run :: WatchOpts -> IO () run = run' watcher . watchUrl -watcher :: TChan (Message Entered) -> TChan (Message Seen) -> SessionID -> IO () +watcher :: TMChan (Message Entered) -> TMChan (Message Seen) -> SessionID -> IO () watcher _ichan ochan _ = inRawMode $ do st <- processSessionStart ochan nullLogger go st where go st = do - (o, _msg) <- atomically $ getUserMessage ochan st - emitOutput o - go st + v <- atomically $ getUserMessage ochan st + case v of + Nothing -> return () + Just (o, _msg) -> do + emitOutput o + go st diff --git a/WebSockets.hs b/WebSockets.hs index ea6e251..f3712a9 100644 --- a/WebSockets.hs +++ b/WebSockets.hs @@ -19,6 +19,7 @@ import SessionID import Network.WebSockets hiding (Message) import Control.Concurrent.STM +import Control.Concurrent.STM.TMChan import Control.Concurrent.Async import Control.Exception import GHC.Generics (Generic) @@ -59,7 +60,7 @@ clientApp :: Mode -> (sent -> LogMessage) -> (LogMessage -> Maybe received) - -> (TChan sent -> TChan received -> SessionID -> IO a) + -> (TMChan sent -> TMChan received -> SessionID -> IO a) -> ClientApp a clientApp mode mksent filterreceived a conn = do -- Ping every 30 seconds to avoid timeouts caused by proxies etc. @@ -73,19 +74,19 @@ clientApp mode mksent filterreceived a conn = do _ -> protocolError conn "Did not get expected Ready message from server" where setup = do - schan <- newTChanIO - rchan <- newTChanIO + schan <- newTMChanIO + rchan <- newTMChanIO sthread <- async $ relayFromSocket conn $ \v -> case filterreceived v of Nothing -> return () - Just r -> atomically $ writeTChan rchan r + Just r -> atomically $ writeTMChan rchan r rthread <- async $ relayToSocket conn $ - Just . mksent <$> atomically (readTChan schan) + fmap mksent <$> atomically (readTMChan schan) return (schan, rchan, sthread, rthread) cleanup (_, _, sthread, rthread) = do sendBinaryData conn Done - cancel sthread - cancel rthread + () <- wait sthread + wait rthread go sid (schan, rchan, _, _) = a schan rchan sid relayFromSocket :: Connection -> (LogMessage -> IO ()) -> IO () @@ -97,7 +98,9 @@ relayFromSocket conn sender = go LogMessage msg -> do sender msg go - Done -> return () + Done -> do + print "GOT DONE" + return () WireProtocolError e -> protocolError conn e _ -> protocolError conn "Protocol error" @@ -107,7 +110,7 @@ relayToSocket conn getter = go go = do mmsg <- getter case mmsg of - Nothing -> go + Nothing -> return () Just msg -> do sendBinaryData conn (LogMessage msg) go |