{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}

-- | Websocket client plumbing: wire-version negotiation, mode selection, and
-- relaying of 'Message's between a websocket and a pair of STM channels.
module WebSockets where

import Types
import Serialization
import SessionID

import Network.WebSockets hiding (Message)
import Control.Concurrent.STM
import Control.Concurrent.Async
import Control.Exception
import Control.Monad (forever)
import qualified Data.Aeson
import qualified Data.Text as T
import Data.List
import Data.Maybe

-- | Run a 'ClientApp' with 'runClient' against host @localhost@, port 8081,
-- path @/@.
--
-- NOTE(review): host, port and path are hard-coded here.
runClientApp :: ClientApp a -> IO a
runClientApp = runClient "localhost" 8081 "/"

-- | Make a client that sends and receives 'Message's over a websocket.
--
-- On connect it first negotiates a wire version ('negotiateWireVersion'),
-- then sends the requested 'Mode', then receives a 'SessionID'. After that,
-- two background threads relay traffic while the callback runs:
--
--   * every message received from the socket is written to the second
--     channel passed to the callback;
--   * every message the callback writes to the first channel is sent over
--     the socket with 'sendBinaryData'.
--
-- Both relay threads are cancelled when the callback returns or throws.
--
-- NOTE(review): an exception inside a relay thread (for example the
-- connection failing) only stops that relay; it is not rethrown to the
-- callback.
clientApp
  :: (WebSocketsData (Message sent), WebSocketsData (Message received))
  => Mode
  -> (TChan (Message sent) -> TChan (Message received) -> SessionID -> IO a)
  -- ^ callback: outgoing channel, incoming channel, session id
  -> ClientApp a
clientApp mode a conn = do
  _ <- negotiateWireVersion conn
  sendMode conn mode
  sid <- receiveData conn
  schan <- newTChanIO
  rchan <- newTChanIO
  let fromSocket = relayFromSocket conn (atomically . writeTChan rchan)
      toSocket = relayToSocket conn (Just <$> atomically (readTChan schan))
  -- withAsync (instead of async inside bracket's masked acquire step) forks
  -- the relays with the caller's masking state, so they stay interruptible,
  -- and guarantees both are cancelled however the callback exits.
  withAsync fromSocket $ \_ ->
    withAsync toSocket $ \_ ->
      a schan rchan sid

-- | Receive messages from the socket forever, handing each one to @sender@.
-- The loop never returns normally; it ends only via an exception.
relayFromSocket
  :: WebSocketsData (Message received)
  => Connection
  -> (Message received -> IO ())
  -> IO ()
relayFromSocket conn sender = forever (receiveData conn >>= sender)

-- | Repeatedly run @getter@ and send each 'Just' result over the socket with
-- 'sendBinaryData'. A 'Nothing' is skipped and @getter@ is run again
-- immediately, so a getter that keeps returning 'Nothing' makes this spin.
-- The loop never returns normally; it ends only via an exception.
relayToSocket
  :: WebSocketsData (Message sent)
  => Connection
  -> IO (Maybe (Message sent))
  -> IO ()
relayToSocket conn getter =
  forever (getter >>= maybe (return ()) (sendBinaryData conn))

-- | Identifier of a wire-protocol version, exchanged as JSON during
-- 'negotiateWireVersion'.
--
-- NOTE(review): the derived 'Ord' compares the underlying 'T.Text'
-- lexicographically, so e.g. "10" sorts below "9"; keep this in mind when
-- adding versions.
newtype WireVersion = WireVersion T.Text
  deriving (Show, Eq, Generic, Ord)

instance FromJSON WireVersion
instance ToJSON WireVersion

-- | Lists of 'WireVersion's travel over the socket as JSON. A payload that
-- fails to decode raises an 'error' ("Unknown WireVersion") when forced.
instance WebSocketsData [WireVersion] where
  -- NOTE(review): left commented out in the original; newer websockets
  -- releases add a fromDataMessage method. Confirm the project's websockets
  -- version before enabling.
  -- fromDataMessage = fromLazyByteString . fromDataMessage
  fromLazyByteString =
    fromMaybe (error "Unknown WireVersion") . Data.Aeson.decode
  toLazyByteString = Data.Aeson.encode

-- | Wire versions this side can speak.
supportedWireVersions :: [WireVersion]
supportedWireVersions = [WireVersion "1"]

-- | Send 'supportedWireVersions' while concurrently receiving the remote
-- side's list. The highest version (by the derived 'Ord') present in both
-- lists is used.
--
-- Throws an 'ErrorCall' when the two lists have no version in common.
negotiateWireVersion :: Connection -> IO WireVersion
negotiateWireVersion conn = do
  (_, remoteversions) <-
    concurrently
      (sendTextData conn supportedWireVersions)
      (receiveData conn)
  case intersect supportedWireVersions remoteversions of
    [] ->
      throwIO . ErrorCall $
        "Unable to negotiate a WireVersion. I support: "
          ++ show supportedWireVersions
          ++ " They support: "
          ++ show remoteversions
    -- Non-empty here, so maximum is safe.
    common -> return (maximum common)

-- | Modes of operation that can be requested for a websocket connection.
data Mode
  = InitMode T.Text
    -- ^ Text is unused, but reserved for expansion
  | ConnectMode T.Text
    -- ^ Text specifies the SessionID to connect to
  deriving (Show, Eq, Generic)

instance FromJSON Mode
instance ToJSON Mode

-- | 'Mode's travel over the socket as JSON. A payload that fails to decode
-- raises an 'error' ("Unknown Mode") when forced.
instance WebSocketsData Mode where
  -- NOTE(review): see the matching note on the [WireVersion] instance.
  -- fromDataMessage = fromLazyByteString . fromDataMessage
  fromLazyByteString = fromMaybe (error "Unknown Mode") . Data.Aeson.decode
  toLazyByteString = Data.Aeson.encode

-- | Send the requested 'Mode' to the other side as a text frame.
sendMode :: Connection -> Mode -> IO ()
sendMode = sendTextData

-- | Receive the 'Mode' requested by the other side.
getMode :: Connection -> IO Mode
getMode = receiveData