From a5f677919c2db47149e545165c9cacbf2c6b07b4 Mon Sep 17 00:00:00 2001 From: Joey Hess Date: Fri, 21 Apr 2017 18:52:58 -0400 Subject: client now fully working --- WebSockets.hs | 111 +++++++++++++++++++++++++++++++--------------------------- 1 file changed, 60 insertions(+), 51 deletions(-) (limited to 'WebSockets.hs') diff --git a/WebSockets.hs b/WebSockets.hs index 1f18b14..8816b6b 100644 --- a/WebSockets.hs +++ b/WebSockets.hs @@ -1,30 +1,34 @@ -{-# LANGUAGE OverloadedStrings, DeriveGeneric, GeneralizedNewtypeDeriving, FlexibleContexts #-} +{-# LANGUAGE OverloadedStrings, DeriveGeneric, GeneralizedNewtypeDeriving, FlexibleContexts, FlexibleInstances #-} module WebSockets where import Types import Serialization +import Network.WebSockets hiding (Message) import Control.Concurrent.STM import Control.Concurrent.Async import Control.Exception import qualified Data.Aeson import qualified Data.Binary -import qualified Network.WebSockets as WS import qualified Data.Text as T +import qualified Data.ByteString.Lazy as L import Data.List import Data.Maybe -runClientApp :: WS.ClientApp a -> IO a -runClientApp = WS.runClient "localhost" 8080 "/" +runClientApp :: ClientApp a -> IO a +runClientApp = runClient "localhost" 8081 "/" -- | Make a client that sends and receives Messages over a websocket. clientApp - :: (Binary (Message sent), Binary (Message received)) + :: (WebSocketsData (Message sent), WebSocketsData (Message received)) => Mode -> (TChan (Message sent) -> TChan (Message received) -> IO a) - -> WS.ClientApp a -clientApp mode a conn = bracket setup cleanup go + -> ClientApp a +clientApp mode a conn = do + vs <- negotiateWireVersion conn + sendMode conn mode + bracket setup cleanup go where setup = do schan <- newTChanIO @@ -37,39 +41,25 @@ clientApp mode a conn = bracket setup cleanup go cleanup (_, _, sthread, rthread) = do cancel sthread cancel rthread - go (schan, rchan, _, _) = do - print "sendWireVersions start" - print "negotiateWireVersion start" - _ <- negotiateWireVersion conn - --sendWireVersions conn - print "negotiateWireVersion done" - sendMode conn mode - print "sendmode now done" - a schan rchan - -relayFromSocket :: Binary (Message received) => WS.Connection -> (Message received -> IO ()) -> IO () -relayFromSocket conn send = go + go (schan, rchan, _, _) = a schan rchan + +relayFromSocket :: WebSocketsData (Message received) => Connection -> (Message received -> IO ()) -> IO () +relayFromSocket conn sender = go where go = do - dm <- WS.receiveDataMessage conn - case dm of - WS.Binary b -> case Data.Binary.decodeOrFail b of - Right (_, _, msg) -> do - send msg - go - Left (_, _, err) -> error $ "Deserialization error: " ++ err - WS.Text _ -> error "Unexpected Text received on websocket" - -relayToSocket :: Binary (Message sent) => WS.Connection -> (IO (Maybe (Message sent))) -> IO () -relayToSocket conn get = go + msg <- receiveData conn + sender msg + go + +relayToSocket :: WebSocketsData (Message sent) => Connection -> (IO (Maybe (Message sent))) -> IO () +relayToSocket conn getter = go where go = do - mmsg <- get + mmsg <- getter case mmsg of Nothing -> return () Just msg -> do - WS.sendDataMessage conn $ WS.Binary $ - Data.Binary.encode msg + sendBinaryData conn msg go newtype WireVersion = WireVersion T.Text @@ -78,36 +68,55 @@ newtype WireVersion = WireVersion T.Text instance FromJSON WireVersion instance ToJSON WireVersion +instance WebSocketsData [WireVersion] where + -- fromDataMessage = fromLazyByteString . fromDataMessage + fromLazyByteString = fromMaybe (error "Unknown WireVersion") . Data.Aeson.decode + toLazyByteString = Data.Aeson.encode + supportedWireVersions :: [WireVersion] supportedWireVersions = [WireVersion "1"] -sendWireVersions :: WS.Connection -> IO () -sendWireVersions conn = WS.sendTextData conn (Data.Aeson.encode supportedWireVersions) - -- | Send supportedWireVersions and at the same time receive it from -- the remote side. The highest version present in both lists will be used. -negotiateWireVersion :: WS.Connection -> IO WireVersion +negotiateWireVersion :: Connection -> IO WireVersion negotiateWireVersion conn = do - remoteversions <- WS.receiveData conn - print ("got versions" :: String) - case Data.Aeson.decode remoteversions of - Nothing -> error "Protocol error: WireVersion list was not sent" - Just l -> case reverse (intersect (sort supportedWireVersions) (sort l)) of - (v:_) -> return v - [] -> error $ "Unable to negotiate a WireVersion. I support: " ++ show supportedWireVersions ++ " They support: " ++ show l + (_, remoteversions) <- concurrently + (sendTextData conn supportedWireVersions) + (receiveData conn) + print ("got versions" :: String, remoteversions) + case reverse (intersect (sort supportedWireVersions) (sort remoteversions)) of + (v:_) -> return v + [] -> error $ "Unable to negotiate a WireVersion. I support: " ++ show supportedWireVersions ++ " They support: " ++ show remoteversions -- | Modes of operation that can be requested for a websocket connection. data Mode - = InitMode T.Text - | ConnectMode T.Text + = InitMode T.Text -- ^ Text is unused, but reserved for expansion + | ConnectMode T.Text -- ^ Text specifies the SessionID to connect to deriving (Show, Eq, Generic) instance FromJSON Mode instance ToJSON Mode where -sendMode :: WS.Connection -> Mode -> IO () -sendMode conn mode = WS.sendTextData conn (Data.Aeson.encode mode) +instance WebSocketsData Mode where + -- fromDataMessage = fromLazyByteString . fromDataMessage + fromLazyByteString = fromMaybe (error "Unknown Mode") . Data.Aeson.decode + toLazyByteString = Data.Aeson.encode + +sendMode :: Connection -> Mode -> IO () +sendMode = sendTextData + +getMode :: Connection -> IO Mode +getMode = receiveData + +instance WebSocketsData (Message Seen) where + fromLazyByteString = decodeBinaryMessage + toLazyByteString = Data.Binary.encode + +instance WebSocketsData (Message Entered) where + fromLazyByteString = decodeBinaryMessage + toLazyByteString = Data.Binary.encode -getMode :: WS.Connection -> IO Mode -getMode conn = fromMaybe (error "Unknown mode") . Data.Aeson.decode - <$> WS.receiveData conn +decodeBinaryMessage :: Binary (Message a) => L.ByteString -> Message a +decodeBinaryMessage b = case Data.Binary.decodeOrFail b of + Right (_, _, msg) -> msg + Left (_, _, err) -> error $ "Binary decode error: " ++ err -- cgit v1.2.3