From 9a8d3bc531647d8b96e66e6daabf2176a1df4afb Mon Sep 17 00:00:00 2001 From: Joey Hess Date: Mon, 24 Apr 2017 15:24:52 -0400 Subject: switch to TMChans so they can be closed when a connection is Done --- Role/Developer.hs | 63 ++++++++++++++++++++++++++++++++---------------------- Role/Downloader.hs | 18 ++++++++++------ Role/User.hs | 41 ++++++++++++++++++++--------------- Role/Watcher.hs | 12 +++++++---- 4 files changed, 80 insertions(+), 54 deletions(-) (limited to 'Role') diff --git a/Role/Developer.hs b/Role/Developer.hs index 0b8fdd9..ffba5c4 100644 --- a/Role/Developer.hs +++ b/Role/Developer.hs @@ -13,16 +13,18 @@ import Pty import Control.Concurrent.Async import Control.Concurrent.STM +import Control.Concurrent.STM.TMChan import System.IO import qualified Data.ByteString as B import qualified Data.Text as T import Data.List +import Data.Maybe import Control.Monad run :: DeveloperOpts -> IO () run = run' developer . debugUrl -run' :: (TChan (Message Entered) -> TChan (Message Seen) -> SessionID -> IO ()) -> UrlString -> IO () +run' :: (TMChan (Message Entered) -> TMChan (Message Seen) -> SessionID -> IO ()) -> UrlString -> IO () run' runner url = void $ runClientApp app where connect = ConnectMode (T.pack url) @@ -32,7 +34,7 @@ userMessages :: LogMessage -> Maybe (Message Seen) userMessages (User m) = Just m userMessages (Developer _) = Nothing -developer :: TChan (Message Entered) -> TChan (Message Seen) -> SessionID -> IO () +developer :: TMChan (Message Entered) -> TMChan (Message Seen) -> SessionID -> IO () developer ichan ochan _ = withLogger "debug-me-developer.log" $ \logger -> do devstate <- processSessionStart ochan logger ok <- authUser ichan ochan devstate logger @@ -53,8 +55,8 @@ data DeveloperState = DeveloperState , developerSigVerifier :: SigVerifier } --- | Read things typed by the developer, and forward them to the TChan. -sendTtyInput :: TChan (Message Entered) -> TVar DeveloperState -> Logger -> IO () +-- | Read things typed by the developer, and forward them to the TMChan. +sendTtyInput :: TMChan (Message Entered) -> TVar DeveloperState -> Logger -> IO () sendTtyInput ichan devstate logger = go where go = do @@ -76,7 +78,7 @@ sendTtyInput ichan devstate logger = go } let act = mkSigned (developerSessionKey ds) $ Activity entered (Just $ lastActivity ds) - writeTChan ichan (ActivityMessage act) + writeTMChan ichan (ActivityMessage act) let acth = hash act let ds' = ds { sentSince = sentSince ds ++ [b] @@ -88,31 +90,35 @@ sendTtyInput ichan devstate logger = go logger $ Developer $ ActivityMessage act go --- | Read activity from the TChan and display it to the developer. -sendTtyOutput :: TChan (Message Seen) -> TVar DeveloperState -> Logger -> IO () +-- | Read activity from the TMChan and display it to the developer. +sendTtyOutput :: TMChan (Message Seen) -> TVar DeveloperState -> Logger -> IO () sendTtyOutput ochan devstate logger = go where go = do - (o, msg) <- atomically $ getUserMessage ochan devstate - logger $ User msg - emitOutput o - go + v <- atomically $ getUserMessage ochan devstate + case v of + Nothing -> return () + Just (o, msg) -> do + logger $ User msg + emitOutput o + go -- | Present our session key to the user. -- Wait for them to accept or reject it, while displaying any Seen data -- in the meantime. -authUser :: TChan (Message Entered) -> TChan (Message Seen) -> TVar DeveloperState -> Logger -> IO Bool +authUser :: TMChan (Message Entered) -> TMChan (Message Seen) -> TVar DeveloperState -> Logger -> IO Bool authUser ichan ochan devstate logger = do ds <- atomically $ readTVar devstate pk <- myPublicKey (developerSessionKey ds) let msg = ControlMessage $ mkSigned (developerSessionKey ds) (Control (SessionKey pk)) - atomically $ writeTChan ichan msg + atomically $ writeTMChan ichan msg logger $ Developer msg waitresp pk where waitresp pk = do - (o, msg) <- atomically $ getUserMessage ochan devstate + (o, msg) <- fromMaybe (error "No response from server to our session key") + <$> atomically (getUserMessage ochan devstate) logger $ User msg emitOutput o case o of @@ -142,16 +148,19 @@ emitOutput (GotControl _) = -- | Get messages from user, check their signature, and make sure that they -- are properly chained from past messages, before returning. -getUserMessage :: TChan (Message Seen) -> TVar DeveloperState -> STM (Output, Message Seen) +getUserMessage :: TMChan (Message Seen) -> TVar DeveloperState -> STM (Maybe (Output, Message Seen)) getUserMessage ochan devstate = do - msg <- readTChan ochan - ds <- readTVar devstate - -- Check signature before doing anything else. - if verifySigned (developerSigVerifier ds) msg - then do - o <- process ds msg - return (o, msg) - else getUserMessage ochan devstate + mmsg <- readTMChan ochan + case mmsg of + Nothing -> return Nothing + Just msg -> do + ds <- readTVar devstate + -- Check signature before doing anything else. + if verifySigned (developerSigVerifier ds) msg + then do + o <- process ds msg + return (Just (o, msg)) + else getUserMessage ochan devstate where process ds (ActivityMessage act@(Activity (Seen (Val b)) _ _)) = do let (legal, ds') = isLegalSeen act ds @@ -224,9 +233,10 @@ isLegalSeen act@(Activity (Seen (Val b)) (Just hp) _) ds -- | Start by reading the initial two messages from the user side, -- their session key and the startup message. -processSessionStart :: TChan (Message Seen) -> Logger -> IO (TVar DeveloperState) +processSessionStart :: TMChan (Message Seen) -> Logger -> IO (TVar DeveloperState) processSessionStart ochan logger = do - sessionmsg <- atomically $ readTChan ochan + sessionmsg <- fromMaybe (error "Did not get session initialization message") + <$> atomically (readTMChan ochan) logger $ User sessionmsg sigverifier <- case sessionmsg of ControlMessage c@(Control (SessionKey pk) _) -> @@ -235,7 +245,8 @@ processSessionStart ochan logger = do then return sv else error "Badly signed session initialization message" _ -> error $ "Unexpected session initialization message: " ++ show sessionmsg - startmsg <- atomically $ readTChan ochan + startmsg <- fromMaybe (error "Did not get session startup message") + <$> atomically (readTMChan ochan) logger $ User startmsg starthash <- case startmsg of ActivityMessage act@(Activity (Seen (Val b)) Nothing _) diff --git a/Role/Downloader.hs b/Role/Downloader.hs index d327c8c..55d7b63 100644 --- a/Role/Downloader.hs +++ b/Role/Downloader.hs @@ -6,12 +6,13 @@ import CmdLine import SessionID import Control.Concurrent.STM +import Control.Concurrent.STM.TMChan import Role.Developer (run', processSessionStart, getUserMessage, Output(..)) run :: DownloadOpts -> IO () run = run' downloader . downloadUrl -downloader :: TChan (Message Entered) -> TChan (Message Seen) -> SessionID -> IO () +downloader :: TMChan (Message Entered) -> TMChan (Message Seen) -> SessionID -> IO () downloader _ichan ochan sid = do let logfile = sessionLogFile "." sid putStrLn $ "Starting download to " ++ logfile @@ -21,9 +22,12 @@ downloader _ichan ochan sid = do go logger st where go logger st = do - (o, msg) <- atomically $ getUserMessage ochan st - _ <- logger $ User msg - case o of - ProtocolError e -> error ("Protocol error: " ++ e) - _ -> return () - go logger st + v <- atomically $ getUserMessage ochan st + case v of + Nothing -> return () + Just (o, msg) -> do + _ <- logger $ User msg + case o of + ProtocolError e -> error ("Protocol error: " ++ e) + _ -> return () + go logger st diff --git a/Role/User.hs b/Role/User.hs index 49c263c..fdf4e53 100644 --- a/Role/User.hs +++ b/Role/User.hs @@ -14,6 +14,7 @@ import SessionID import Control.Concurrent.Async import Control.Concurrent.STM +import Control.Concurrent.STM.TMChan import System.Process import System.Exit import qualified Data.ByteString as B @@ -63,11 +64,11 @@ data UserState = UserState , userSigVerifier :: SigVerifier } -user :: B.ByteString -> Pty -> TChan (Message Seen) -> TChan (Message Entered) -> IO () +user :: B.ByteString -> Pty -> TMChan (Message Seen) -> TMChan (Message Entered) -> IO () user starttxt p ochan ichan = withLogger "debug-me.log" $ \logger -> do -- Start by establishing our session key, and displaying the starttxt. let initialmessage msg = do - atomically $ writeTChan ochan msg + atomically $ writeTMChan ochan msg logger $ User msg sk <- genMySessionKey pk <- myPublicKey sk @@ -100,9 +101,9 @@ forwardTtyInputToPty p = do writePty p b forwardTtyInputToPty p --- | Forward things written to the Pty out the TChan, and also display +-- | Forward things written to the Pty out the TMChan, and also display -- it on their Tty. -sendPtyOutput :: Pty -> TChan (Message Seen) -> TVar UserState -> Logger -> IO () +sendPtyOutput :: Pty -> TMChan (Message Seen) -> TVar UserState -> Logger -> IO () sendPtyOutput p ochan us logger = go where go = do @@ -117,7 +118,7 @@ sendPtyOutput p ochan us logger = go go class SendableToDeveloper t where - sendDeveloper :: TChan (Message Seen) -> TVar UserState -> t -> POSIXTime -> STM (Message Seen) + sendDeveloper :: TMChan (Message Seen) -> TVar UserState -> t -> POSIXTime -> STM (Message Seen) instance SendableToDeveloper Seen where sendDeveloper ochan us seen now = do @@ -127,7 +128,7 @@ instance SendableToDeveloper Seen where mkSigned (userSessionKey st) $ Activity seen (loggedHash prev) let l = mkLog (User msg) now - writeTChan ochan msg + writeTMChan ochan msg writeTVar us $ st { backLog = l :| toList bl } return msg @@ -137,23 +138,24 @@ instance SendableToDeveloper ControlAction where let msg = ControlMessage $ mkSigned (userSessionKey st) (Control c) -- Control messages are not kept in the backlog. - writeTChan ochan msg + writeTMChan ochan msg return msg --- | Read things to be entered from the TChan, verify if they're legal, +-- | Read things to be entered from the TMChan, verify if they're legal, -- and send them to the Pty. -sendPtyInput :: TChan (Message Entered) -> TChan (Message Seen) -> Pty -> TVar UserState -> Logger -> IO () +sendPtyInput :: TMChan (Message Entered) -> TMChan (Message Seen) -> Pty -> TVar UserState -> Logger -> IO () sendPtyInput ichan ochan p us logger = go where go = do now <- getPOSIXTime v <- atomically $ getDeveloperMessage ichan ochan us now case v of - InputMessage msg@(ActivityMessage entered) -> do + Nothing -> return () + Just (InputMessage msg@(ActivityMessage entered)) -> do logger $ Developer msg writePty p $ val $ enteredData $ activity entered go - InputMessage msg@(ControlMessage (Control c _)) -> do + Just (InputMessage msg@(ControlMessage (Control c _))) -> do logger $ Developer msg case c of SessionKey pk -> do @@ -162,10 +164,10 @@ sendPtyInput ichan ochan p us logger = go Rejected r -> error $ "User side received a Rejected: " ++ show r SessionKeyAccepted _ -> error "User side received a SessionKeyAccepted" SessionKeyRejected _ -> error "User side received a SessionKeyRejected" - RejectedMessage rej -> do + Just (RejectedMessage rej) -> do logger $ User rej go - BadlySignedMessage _ -> go + Just (BadlySignedMessage _) -> go data Input = InputMessage (Message Entered) @@ -177,9 +179,14 @@ data Input -- signature of the message is only verified against the key in it), and -- make sure it's legal before returning it. If it's not legal, sends a -- Reject message. -getDeveloperMessage :: TChan (Message Entered) -> TChan (Message Seen) -> TVar UserState -> POSIXTime -> STM Input -getDeveloperMessage ichan ochan us now = do - msg <- readTChan ichan +getDeveloperMessage :: TMChan (Message Entered) -> TMChan (Message Seen) -> TVar UserState -> POSIXTime -> STM (Maybe Input) +getDeveloperMessage ichan ochan us now = maybe + (return Nothing) + (\msg -> Just <$> getDeveloperMessage' msg ochan us now) + =<< readTMChan ichan + +getDeveloperMessage' :: Message Entered -> TMChan (Message Seen) -> TVar UserState -> POSIXTime -> STM Input +getDeveloperMessage' msg ochan us now = do st <- readTVar us case msg of ControlMessage (Control (SessionKey pk) _) -> do @@ -209,7 +216,7 @@ getDeveloperMessage ichan ochan us now = do -- | Check if the public key a developer presented is one we want to use, -- and if so, add it to the userSigVerifier. -checkDeveloperPublicKey :: TChan (Message Seen) -> TVar UserState -> Logger -> PublicKey -> IO () +checkDeveloperPublicKey :: TMChan (Message Seen) -> TVar UserState -> Logger -> PublicKey -> IO () checkDeveloperPublicKey ochan us logger pk = do now <- getPOSIXTime -- TODO check gpg sig.. diff --git a/Role/Watcher.hs b/Role/Watcher.hs index 620733c..6ed1a6b 100644 --- a/Role/Watcher.hs +++ b/Role/Watcher.hs @@ -7,17 +7,21 @@ import CmdLine import SessionID import Control.Concurrent.STM +import Control.Concurrent.STM.TMChan import Role.Developer (run', processSessionStart, getUserMessage, emitOutput) run :: WatchOpts -> IO () run = run' watcher . watchUrl -watcher :: TChan (Message Entered) -> TChan (Message Seen) -> SessionID -> IO () +watcher :: TMChan (Message Entered) -> TMChan (Message Seen) -> SessionID -> IO () watcher _ichan ochan _ = inRawMode $ do st <- processSessionStart ochan nullLogger go st where go st = do - (o, _msg) <- atomically $ getUserMessage ochan st - emitOutput o - go st + v <- atomically $ getUserMessage ochan st + case v of + Nothing -> return () + Just (o, _msg) -> do + emitOutput o + go st -- cgit v1.2.3