summaryrefslogtreecommitdiffhomepage
path: root/WebSockets.hs
blob: 25f2162d146c1d02b6e3cb25d3f35462e8777b4c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
{-# LANGUAGE OverloadedStrings, DeriveGeneric, GeneralizedNewtypeDeriving, FlexibleContexts, FlexibleInstances #-}

module WebSockets where

import Types
import Serialization
import SessionID

import Network.WebSockets hiding (Message)
import Control.Concurrent.STM
import Control.Concurrent.Async
import Control.Exception
import qualified Data.Aeson
import qualified Data.Text as T
import Data.List
import Data.Maybe

-- | Run a 'ClientApp' over a websocket connection to host "localhost",
-- port 8081, path "/".
runClientApp :: ClientApp a -> IO a
runClientApp app = runClient "localhost" 8081 "/" app

-- | Make a client that sends and receives Messages over a websocket.
--
-- The wire version is negotiated first (and the result discarded), then
-- the requested 'Mode' is sent and the next message from the remote side
-- is read as the 'SessionID'. Two background threads are started: one
-- copies every message received from the socket into the receive
-- channel, the other sends every message written to the send channel.
-- The supplied action gets (send channel, receive channel, session id);
-- both threads are cancelled when it returns or throws.
--
-- NOTE(review): the relay threads' handles are only ever cancelled, never
-- waited on, so an exception inside them (e.g. the socket closing) is not
-- reported to the action — confirm this is intended.
clientApp
	:: (WebSocketsData (Message sent), WebSocketsData (Message received))
	=> Mode
	-> (TChan (Message sent) -> TChan (Message received) -> SessionID -> IO a)
	-> ClientApp a
clientApp mode a conn = do
	_negotiated <- negotiateWireVersion conn
	sendMode conn mode
	sid <- receiveData conn
	bracket startRelays stopRelays $ \(toSend, received, _, _) ->
		a toSend received sid
  where
	-- Channels are created before the threads that use them.
	startRelays = do
		toSend <- newTChanIO
		received <- newTChanIO
		inbound <- async (relayFromSocket conn (atomically . writeTChan received))
		outbound <- async (relayToSocket conn (Just <$> atomically (readTChan toSend)))
		return (toSend, received, inbound, outbound)
	-- Stop the receiving thread first, then the sending one.
	stopRelays (_, _, inbound, outbound) =
		cancel inbound >> cancel outbound

-- | Repeatedly receive a Message from the websocket and hand it to the
-- supplied action. This never returns normally; it ends only when
-- 'receiveData' or the action throws.
relayFromSocket :: WebSocketsData (Message received) => Connection -> (Message received -> IO ()) -> IO ()
relayFromSocket conn sender = loop
  where
	loop = receiveData conn >>= sender >> loop

-- | Repeatedly run the getter. Each @Just@ message is sent with
-- 'sendBinaryData', and @Nothing@ is skipped. This never returns
-- normally; it ends only when the getter or the send throws.
--
-- NOTE(review): if the getter can return @Nothing@ without blocking,
-- this loop busy-waits.
relayToSocket :: WebSocketsData (Message sent) => Connection -> (IO (Maybe (Message sent))) -> IO ()
relayToSocket conn getter = loop
  where
	loop = getter >>= maybe (return ()) (sendBinaryData conn) >> loop

-- | Identifier of one version of the wire protocol. Versions are ordered
-- (and so negotiated) by the ordering of the underlying 'T.Text'.
--
-- NOTE(review): that ordering is lexicographic, so e.g. "10" sorts below
-- "9"; keep that in mind when adding versions.
newtype WireVersion = WireVersion T.Text
	deriving (Eq, Ord, Show, Generic)

instance ToJSON WireVersion
instance FromJSON WireVersion

-- | Lists of wire versions are carried as JSON. A payload that does not
-- decode yields an @error "Unknown WireVersion"@ value, which is raised
-- only when the result is forced.
instance WebSocketsData [WireVersion] where
	-- fromDataMessage = fromLazyByteString . fromDataMessage
	toLazyByteString = Data.Aeson.encode
	fromLazyByteString bytes = case Data.Aeson.decode bytes of
		Nothing -> error "Unknown WireVersion"
		Just versions -> versions

-- | The wire versions this side can speak; offered during negotiation.
supportedWireVersions :: [WireVersion]
supportedWireVersions = map WireVersion ["1"]

-- | Exchange wire-version lists with the remote side. Our
-- 'supportedWireVersions' are sent while theirs are received,
-- concurrently. The greatest version (by 'WireVersion' ordering) that
-- appears in both lists is chosen. If the lists share no version, this
-- fails with 'error'.
negotiateWireVersion :: Connection -> IO WireVersion
negotiateWireVersion conn = do
	(_, theirs) <- concurrently sendOurs (receiveData conn)
	pick theirs
  where
	sendOurs = sendTextData conn supportedWireVersions
	-- Keep our versions that they also offer, then take the greatest.
	pick theirs = case filter (`elem` theirs) supportedWireVersions of
		[] -> error $ "Unable to negotiate a WireVersion. I support: " ++ show supportedWireVersions ++ " They support: " ++ show theirs
		shared -> return (maximum shared)

-- | Modes of operation that can be requested for a websocket connection.
-- The JSON instances below use the classes' default methods.
data Mode
	= InitMode T.Text
	-- ^ Text is unused, but reserved for expansion
	| ConnectMode T.Text
	-- ^ Text specifies the SessionID to connect to
	deriving (Eq, Show, Generic)

instance ToJSON Mode
instance FromJSON Mode

-- | Modes are carried over the websocket as JSON. A payload that does not
-- decode yields an @error "Unknown Mode"@ value, which is raised only
-- when the result is forced.
instance WebSocketsData Mode where
	-- fromDataMessage = fromLazyByteString . fromDataMessage
	toLazyByteString = Data.Aeson.encode
	fromLazyByteString bytes = case Data.Aeson.decode bytes of
		Nothing -> error "Unknown Mode"
		Just m -> m

-- | Send a 'Mode' to the remote side as a text data message, encoded by
-- its 'WebSocketsData' instance.
sendMode :: Connection -> Mode -> IO ()
sendMode conn mode = sendTextData conn mode

-- | Receive the 'Mode' requested by the remote side.
--
-- The 'WebSocketsData' instance for 'Mode' reports an undecodable payload
-- with a lazy @error "Unknown Mode"@ thunk. If 'receiveData' returns the
-- decoded value unforced, that error would surface wherever the Mode is
-- first inspected, possibly outside any handler the caller placed around
-- this call. Forcing the value with 'evaluate' raises the 'ErrorCall'
-- here instead. If the value is already evaluated, 'evaluate' is a no-op.
getMode :: Connection -> IO Mode
getMode conn = receiveData conn >>= evaluate