summaryrefslogtreecommitdiffhomepage
path: root/WebSockets.hs
blob: 1f18b14000b6c7aa163e2d8a2b53f67b48c7214c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
{-# LANGUAGE OverloadedStrings, DeriveGeneric, GeneralizedNewtypeDeriving, FlexibleContexts #-}

module WebSockets where

import Types
import Serialization

import Control.Concurrent.STM
import Control.Concurrent.Async
import Control.Exception
import qualified Data.Aeson
import qualified Data.Binary
import qualified Network.WebSockets as WS
import qualified Data.Text as T
import Data.List
import Data.Maybe

-- | Run a websocket client application against @localhost@, port 8080,
-- path @/@, and return whatever the application returns.
runClientApp :: WS.ClientApp a -> IO a
runClientApp app = WS.runClient host port path app
  where
    host :: String
    host = "localhost"
    port :: Int
    port = 8080
    path :: String
    path = "/"

-- | Make a client that sends and receives Messages over a websocket.
--
-- The wire version is negotiated and the requested 'Mode' is sent before
-- any relay thread exists. Nothing else may read from the connection
-- during that handshake: a receiving relay started earlier would race
-- 'negotiateWireVersion' for the remote version list and, on winning,
-- die on the unexpected Text message while the negotiation blocks forever.
--
-- After the handshake, two background threads run for the duration of the
-- supplied action:
--
--   * a receiver that decodes incoming binary messages and writes them to
--     the receive channel;
--   * a sender that takes messages from the send channel and writes them
--     to the socket.
--
-- 'withAsync' cancels both threads when the action returns or throws.
--
-- NOTE(review): an exception that kills a relay thread (e.g. a decode
-- error) is not rethrown here; the action may then block reading the
-- receive channel. Consider 'link' if that should abort the client.
clientApp
  :: (Binary (Message sent), Binary (Message received))
  => Mode
  -> (TChan (Message sent) -> TChan (Message received) -> IO a)
  -- ^ action, given the send channel and the receive channel
  -> WS.ClientApp a
clientApp mode a conn = do
  -- The negotiated version is not used yet; negotiation just has to
  -- succeed before the mode is sent.
  _ <- negotiateWireVersion conn
  sendMode conn mode
  schan <- newTChanIO
  rchan <- newTChanIO
  withAsync (relayFromSocket conn (atomically . writeTChan rchan)) $ \_receiver ->
    withAsync (relayToSocket conn (Just <$> atomically (readTChan schan))) $ \_sender ->
      a schan rchan

-- | Receive data messages from the websocket in a loop, decode each binary
-- message with 'Data.Binary' and pass the result to the callback.
--
-- Calls 'error' when a binary message fails to decode or when a text
-- message arrives. Exceptions from 'WS.receiveDataMessage' propagate
-- unchanged. The loop has no other exit.
relayFromSocket :: Binary (Message received) => WS.Connection -> (Message received -> IO ()) -> IO ()
relayFromSocket conn send = loop
  where
    loop = WS.receiveDataMessage conn >>= dispatch

    dispatch (WS.Text _) = error "Unexpected Text received on websocket"
    dispatch (WS.Binary bytes) =
      either failed delivered (Data.Binary.decodeOrFail bytes)

    failed (_, _, err) = error $ "Deserialization error: " ++ err

    delivered (_, _, msg) = send msg >> loop

-- | Repeatedly run the supplied action and send every message it yields
-- to the websocket as a 'Data.Binary'-encoded binary message. Returns as
-- soon as the action yields 'Nothing'.
relayToSocket :: Binary (Message sent) => WS.Connection -> (IO (Maybe (Message sent))) -> IO ()
relayToSocket conn get = loop
  where
    loop = get >>= maybe (return ()) sendThenLoop

    sendThenLoop msg = do
      WS.sendDataMessage conn (WS.Binary (Data.Binary.encode msg))
      loop

-- | A protocol version identifier, exchanged as JSON while the two sides
-- agree on a version.
--
-- The JSON instances rely on the classes' default (presumably
-- 'Generic'-based) methods. Ordering compares the underlying 'T.Text',
-- i.e. lexicographically.
newtype WireVersion = WireVersion T.Text
  deriving (Eq, Ord, Show, Generic)

instance ToJSON WireVersion
instance FromJSON WireVersion

-- | Every wire version this side is able to speak.
supportedWireVersions :: [WireVersion]
supportedWireVersions = map WireVersion ["1"]

-- | Send 'supportedWireVersions' to the peer as a JSON text message.
sendWireVersions :: WS.Connection -> IO ()
sendWireVersions conn = WS.sendTextData conn payload
  where
    payload = Data.Aeson.encode supportedWireVersions

-- | Send 'supportedWireVersions' and at the same time receive the remote
-- side's list. The highest version present in both lists is returned.
--
-- Sending and receiving run concurrently, so neither side has to go first
-- and two peers running this at the same moment cannot deadlock waiting
-- for each other.
--
-- Throws an 'ErrorCall' (via 'error') if the received message does not
-- decode as a list of versions, or if the two lists share no version.
--
-- NOTE(review): 'WireVersion' is ordered by its 'T.Text', so "highest" is
-- lexicographic ("10" < "9"). Harmless while only "1" exists.
negotiateWireVersion :: WS.Connection -> IO WireVersion
negotiateWireVersion conn = do
  (_, remoteversions) <- concurrently (sendWireVersions conn) (WS.receiveData conn)
  case Data.Aeson.decode remoteversions of
    Nothing -> error "Protocol error: WireVersion list was not sent"
    Just l -> case reverse (intersect (sort supportedWireVersions) (sort l)) of
      -- intersect keeps the order of its (sorted) first argument, so after
      -- reversing, the head is the largest common version.
      (v:_) -> return v
      [] -> error $ "Unable to negotiate a WireVersion. I support: "
        ++ show supportedWireVersions
        ++ " They support: " ++ show l

-- | Modes of operation that can be requested for a websocket connection.
--
-- Written to the socket with 'sendMode' and read back with 'getMode'. The
-- JSON instances rely on the classes' default (presumably 'Generic'-based)
-- methods.
data Mode
  = InitMode T.Text
  | ConnectMode T.Text
  deriving (Eq, Show, Generic)

instance ToJSON Mode
instance FromJSON Mode

-- | Send the requested 'Mode' to the peer as a JSON text message.
sendMode :: WS.Connection -> Mode -> IO ()
sendMode conn = WS.sendTextData conn . Data.Aeson.encode

-- | Receive one message and decode it as the 'Mode' requested by the peer.
--
-- Throws @ErrorCall "Unknown mode"@ from within this action when the
-- message does not decode as a 'Mode'. The previous
-- @fromMaybe (error ...)@ form returned a lazy bottom instead, so the
-- failure only surfaced later, wherever the caller first inspected the
-- mode, outside any handler wrapped around this call.
getMode :: WS.Connection -> IO Mode
getMode conn = do
  msg <- WS.receiveData conn
  case Data.Aeson.decode msg of
    Just mode -> return mode
    Nothing -> throwIO (ErrorCall "Unknown mode")