{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE FlexibleContexts #-}

module WebSockets where

import Types
import Serialization

import Control.Concurrent.STM
import Control.Concurrent.Async
import Control.Exception
import qualified Data.Aeson
import qualified Data.Binary
import qualified Network.WebSockets as WS
import qualified Data.Text as T
import Data.List
import Data.Maybe

-- | Run a websocket client app connected to host "localhost", port 8080,
-- path "/" (all three are hard-coded here).
runClientApp :: WS.ClientApp a -> IO a
runClientApp = WS.runClient "localhost" 8080 "/"

-- | Make a client that sends and receives 'Message's over a websocket.
--
-- The handshake runs first, on the calling thread: the remote side's
-- wire-version list is read ('negotiateWireVersion'), then the requested
-- 'Mode' is sent ('sendMode'). Only after that are two relay threads
-- started:
--
-- * one decodes every incoming binary message and writes it to the
--   received channel;
-- * one takes messages from the sent channel and writes them to the socket.
--
-- The supplied action is then run with the sent and received channels,
-- and both relay threads are cancelled when it returns or throws.
--
-- The relay threads must not start before the handshake: the receiving
-- relay calls 'WS.receiveDataMessage' immediately, and running it
-- concurrently with 'negotiateWireVersion' would make the two race for the
-- same incoming message (the relay rejects Text messages with 'error').
--
-- NOTE(review): this side does not send its own version list
-- ('sendWireVersions'); confirm the server does not expect one.
-- NOTE(review): an exception in a relay thread (e.g. a decode error) stays
-- inside its 'Async' and is never rethrown here, so the action is not told.
clientApp
  :: (Binary (Message sent), Binary (Message received))
  => Mode
  -> (TChan (Message sent) -> TChan (Message received) -> IO a)
  -> WS.ClientApp a
clientApp mode a conn = do
  _ <- negotiateWireVersion conn
  sendMode conn mode
  bracket setup cleanup go
  where
    setup = do
      schan <- newTChanIO
      rchan <- newTChanIO
      fromsocket <- async $ relayFromSocket conn $ atomically . writeTChan rchan
      tosocket <- async $ relayToSocket conn $ Just <$> atomically (readTChan schan)
      return (schan, rchan, fromsocket, tosocket)
    cleanup (_, _, fromsocket, tosocket) = do
      cancel fromsocket
      cancel tosocket
    go (schan, rchan, _, _) = a schan rchan

-- | Receive binary websocket messages in a loop, decoding each one with
-- "Data.Binary" and passing it to the callback.
--
-- Never returns normally: it calls 'error' on a decode failure or when a
-- Text message arrives, and otherwise ends only by an exception raised by
-- the socket or by the callback.
relayFromSocket
  :: Binary (Message received)
  => WS.Connection -> (Message received -> IO ()) -> IO ()
relayFromSocket conn send = go
  where
    go = do
      dm <- WS.receiveDataMessage conn
      case dm of
        WS.Binary b -> case Data.Binary.decodeOrFail b of
          Right (_, _, msg) -> do
            send msg
            go
          Left (_, _, err) -> error $ "Deserialization error: " ++ err
        WS.Text _ -> error "Unexpected Text received on websocket"

-- | Repeatedly run the getter and send each message it yields as a binary
-- websocket message (encoded with "Data.Binary"); return once the getter
-- yields 'Nothing'.
relayToSocket
  :: Binary (Message sent)
  => WS.Connection -> (IO (Maybe (Message sent))) -> IO ()
relayToSocket conn get = go
  where
    go = do
      mmsg <- get
      case mmsg of
        Nothing -> return ()
        Just msg -> do
          WS.sendDataMessage conn $ WS.Binary $ Data.Binary.encode msg
          go

-- | A protocol version identifier, exchanged as JSON.
--
-- NOTE(review): the derived 'Ord' compares the underlying 'T.Text'
-- lexicographically (so "10" sorts before "9"); harmless while the only
-- version is "1".
newtype WireVersion = WireVersion T.Text
  deriving (Show, Eq, Generic, Ord)

instance FromJSON WireVersion
instance ToJSON WireVersion

-- | Wire versions this side can speak.
supportedWireVersions :: [WireVersion]
supportedWireVersions = [WireVersion "1"]

-- | Send 'supportedWireVersions' to the remote side as a JSON text message.
sendWireVersions :: WS.Connection -> IO ()
sendWireVersions conn = WS.sendTextData conn (Data.Aeson.encode supportedWireVersions)

-- | Receive the remote side's JSON wire-version list and return the highest
-- version present both in it and in 'supportedWireVersions'.
--
-- This only receives; it does not send this side's list (that is
-- 'sendWireVersions'). Calls 'error' if the message does not decode as a
-- version list, or if the two lists have no version in common.
negotiateWireVersion :: WS.Connection -> IO WireVersion
negotiateWireVersion conn = do
  remoteversions <- WS.receiveData conn
  case Data.Aeson.decode remoteversions of
    Nothing -> error "Protocol error: WireVersion list was not sent"
    Just l -> case reverse (sort (intersect supportedWireVersions l)) of
      (v:_) -> return v
      [] -> error $ "Unable to negotiate a WireVersion. I support: "
        ++ show supportedWireVersions
        ++ " They support: " ++ show l

-- | Modes of operation that can be requested for a websocket connection.
data Mode
  = InitMode T.Text
  | ConnectMode T.Text
  deriving (Show, Eq, Generic)

instance FromJSON Mode
instance ToJSON Mode

-- | Send a 'Mode' to the remote side as a JSON text message.
sendMode :: WS.Connection -> Mode -> IO ()
sendMode conn mode = WS.sendTextData conn (Data.Aeson.encode mode)

-- | Receive a JSON-encoded 'Mode' from the remote side.
--
-- The decoded value is forced before this action returns, so a message
-- that does not decode raises @error "Unknown mode"@ here, not later at
-- whatever point the returned value is first inspected.
getMode :: WS.Connection -> IO Mode
getMode conn = do
  msg <- WS.receiveData conn
  return $! fromMaybe (error "Unknown mode") (Data.Aeson.decode msg)