{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{- |
Module : FediChord
Description : An opinionated implementation of the EpiChord DHT by Leong et al.
Copyright : (c) schmittlauch, 2019-2020
License : AGPL-3
Stability : experimental
Modernised EpiChord + k-choices load balancing
-}
module Hash2Pub.FediChord (
NodeID -- abstract, but newtype constructors cannot be hidden
, getNodeID
, toNodeID
, NodeState (..)
, LocalNodeState (..)
, RemoteNodeState (..)
, setSuccessors
, setPredecessors
, NodeCache
, CacheEntry(..)
, cacheGetNodeStateUnvalidated
, initCache
, cacheLookup
, cacheLookupSucc
, cacheLookupPred
, localCompare
, genNodeID
, genNodeIDBS
, genKeyID
, genKeyIDBS
, byteStringToUInteger
, ipAddrAsBS
, bsAsIpAddr
, FediChordConf(..)
, fediChordInit
, fediChordJoin
, fediChordBootstrapJoin
, tryBootstrapJoining
, fediMainThreads
, RealNode (..)
, nodeStateInit
, mkServerSocket
, mkSendSocket
, resolve
, cacheWriter
, joinOnNewEntriesThread
) where
import Control.Applicative ((<|>))
import Control.Concurrent
import Control.Concurrent.Async
import Control.Concurrent.STM
import Control.Concurrent.STM.TQueue
import Control.Concurrent.STM.TVar
import Control.Exception
import Control.Monad (forM_, forever)
import Control.Monad.Except
import Crypto.Hash
import qualified Data.ByteArray as BA
import qualified Data.ByteString as BS
import qualified Data.ByteString.UTF8 as BSU
import Data.Either (rights)
import Data.Foldable (foldr')
import Data.Functor.Identity
import Data.IP (IPv6, fromHostAddress6,
toHostAddress6)
import Data.List ((\\))
import qualified Data.Map.Strict as Map
import Data.Maybe (catMaybes, fromJust, fromMaybe,
isJust, isNothing, mapMaybe)
import qualified Data.Set as Set
import Data.Time.Clock.POSIX
import Data.Typeable (Typeable (..), typeOf)
import Data.Word
import qualified Network.ByteOrder as NetworkBytes
import Network.Socket hiding (recv, recvFrom, send,
sendTo)
import Network.Socket.ByteString
import Safe
import System.Random (randomRIO)
import Hash2Pub.DHTProtocol
import Hash2Pub.FediChordTypes
import Hash2Pub.Utils
import Debug.Trace (trace)
-- | initialise data structures, compute own IDs and bind to listening socket
-- ToDo: load persisted state; this is why the function already operates in IO
fediChordInit :: FediChordConf -> IO (Socket, LocalNodeStateSTM)
fediChordInit initConf = do
let realNode = RealNode {
vservers = []
, nodeConfig = initConf
, bootstrapNodes = confBootstrapNodes initConf
}
realNodeSTM <- newTVarIO realNode
initialState <- nodeStateInit realNodeSTM
initialStateSTM <- newTVarIO initialState
serverSock <- mkServerSocket (getIpAddr initialState) (getDhtPort initialState)
pure (serverSock, initialStateSTM)
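-- Typical wiring of the exported functions (a sketch only; @myConf@ stands for a
-- fully populated 'FediChordConf' and is an assumption of this example):
--
-- > (serverSock, nsSTM) <- fediChordInit myConf
-- > _ <- forkIO $ cacheWriter nsSTM
-- > joinedState <- tryBootstrapJoining nsSTM
-- > either putStrLn (const $ fediMainThreads serverSock nsSTM) joinedState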
-- | initialises the 'NodeState' for this local node.
-- Separated from 'fediChordInit' to be usable in tests.
nodeStateInit :: RealNodeSTM -> IO LocalNodeState
nodeStateInit realNodeSTM = do
realNode <- readTVarIO realNodeSTM
cacheSTM <- newTVarIO initCache
q <- atomically newTQueue
let
conf = nodeConfig realNode
vsID = 0
containedState = RemoteNodeState {
domain = confDomain conf
, ipAddr = confIP conf
, nid = genNodeID (confIP conf) (confDomain conf) $ fromInteger vsID
, dhtPort = toEnum $ confDhtPort conf
, servicePort = 0
, vServerID = vsID
}
initialState = LocalNodeState {
nodeState = containedState
, nodeCacheSTM = cacheSTM
, cacheWriteQueue = q
, successors = []
, predecessors = []
, kNeighbours = 3
, lNumBestNodes = 3
, pNumParallelQueries = 2
, jEntriesPerSlice = 2
, parentRealNode = realNodeSTM
}
pure initialState
-- | Join a new node into the DHT, using a provided bootstrap node as initial cache seed
-- for resolving the new node's position.
fediChordBootstrapJoin :: LocalNodeStateSTM -- ^ the local 'NodeState'
-> (String, PortNumber) -- ^ domain and port of a bootstrapping node
-> IO (Either String LocalNodeStateSTM) -- ^ the joined 'NodeState' after a
-- successful join, otherwise an error message
fediChordBootstrapJoin nsSTM bootstrapNode = do
-- can be invoked multiple times with all known bootstrapping nodes until successfully joined
ns <- readTVarIO nsSTM
runExceptT $ do
-- 1. get routed to the currently responsible node
lookupResp <- liftIO $ bootstrapQueryId nsSTM bootstrapNode $ getNid ns
currentlyResponsible <- liftEither lookupResp
liftIO . putStrLn $ "Trying to join on " <> show (getNid currentlyResponsible)
-- 2. then send a join to the currently responsible node
joinResult <- liftIO $ requestJoin currentlyResponsible nsSTM
liftEither joinResult
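-- A single join attempt might look like this (bootstrap host and port are
-- hypothetical, purely for illustration):
--
-- > joinResult <- fediChordBootstrapJoin nsSTM ("bootstrap.example.org", 3456)
-- > either (putStrLn . ("bootstrap join failed: " <>)) (const $ pure ()) joinResult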
-- | Periodically look up the own ID through a random bootstrapping node to discover and merge separated DHT clusters.
-- Unjoined nodes try joining instead.
convergenceSampleThread :: LocalNodeStateSTM -> IO ()
convergenceSampleThread nsSTM = forever $ do
nsSnap <- readTVarIO nsSTM
parentNode <- readTVarIO $ parentRealNode nsSnap
if isJoined nsSnap
then
runExceptT (do
-- joined node: choose random node, do queryIDLoop, compare result with own responsibility
let bss = bootstrapNodes parentNode
randIndex <- liftIO $ randomRIO (0, length bss - 1)
chosenNode <- maybe (throwError "invalid bootstrapping node index") pure $ atMay bss randIndex
lookupResult <- liftIO $ bootstrapQueryId nsSTM chosenNode (getNid nsSnap)
currentlyResponsible <- liftEither lookupResult
if getNid currentlyResponsible /= getNid nsSnap
-- if mismatch, stabilise on the result, else do nothing
then do
stabResult <- liftIO $ requestStabilise nsSnap currentlyResponsible
(preds, succs) <- liftEither stabResult
-- TODO: verify neighbours before adding, see #55
liftIO . atomically $ do
ns <- readTVar nsSTM
writeTVar nsSTM $ addPredecessors preds ns
else pure ()
) >> pure ()
-- unjoined node: try joining through all bootstrapping nodes
else tryBootstrapJoining nsSTM >> pure ()
let delaySecs = confBootstrapSamplingInterval . nodeConfig $ parentNode
threadDelay $ delaySecs * 10^6
-- | Try joining the DHT through any of the bootstrapping nodes until it succeeds.
tryBootstrapJoining :: LocalNodeStateSTM -> IO (Either String LocalNodeStateSTM)
tryBootstrapJoining nsSTM = do
bss <- atomically $ do
nsSnap <- readTVar nsSTM
realNodeSnap <- readTVar $ parentRealNode nsSnap
pure $ bootstrapNodes realNodeSnap
tryJoining bss
where
tryJoining (bn:bns) = do
j <- fediChordBootstrapJoin nsSTM bn
case j of
Left err -> putStrLn ("join error: " <> err) >> tryJoining bns
Right joined -> pure $ Right joined
tryJoining [] = pure $ Left "Exhausted all bootstrap points for joining."
-- | Look up a key just based on the responses of a single bootstrapping node.
bootstrapQueryId :: LocalNodeStateSTM -> (String, PortNumber) -> NodeID -> IO (Either String RemoteNodeState)
bootstrapQueryId nsSTM (bootstrapHost, bootstrapPort) targetID = do
ns <- readTVarIO nsSTM
srcAddr <- confIP . nodeConfig <$> readTVarIO (parentRealNode ns)
bootstrapResponse <- bracket (mkSendSocket srcAddr bootstrapHost bootstrapPort) close (
-- Initialise an empty cache only with the responses from a bootstrapping node
fmap Right . sendRequestTo 5000 3 (lookupMessage targetID ns Nothing)
)
`catch` (\e -> pure . Left $ "Error at bootstrap QueryId: " <> displayException (e :: IOException))
case bootstrapResponse of
Left err -> pure $ Left err
Right resp
| resp == Set.empty -> pure . Left $ "Bootstrapping node " <> show bootstrapHost <> " gave no response."
| otherwise -> do
now <- getPOSIXTime
-- create new cache with all returned node responses
let bootstrapCache =
-- traverse response parts
foldr' (\resp cacheAcc -> case queryResult <$> payload resp of
Nothing -> cacheAcc
Just (FOUND result1) -> addCacheEntryPure now (RemoteCacheEntry result1 now) cacheAcc
Just (FORWARD resultset) -> foldr' (addCacheEntryPure now) cacheAcc resultset
)
initCache resp
currentlyResponsible <- queryIdLookupLoop bootstrapCache ns 50 $ getNid ns
pure $ Right currentlyResponsible
-- | Join a node into the DHT using the global node cache for resolving the new
-- node's position.
fediChordJoin :: LocalNodeStateSTM -- ^ the local 'NodeState'
-> IO (Either String LocalNodeStateSTM) -- ^ the joined 'NodeState' after a
-- successful join, otherwise an error message
fediChordJoin nsSTM = do
ns <- readTVarIO nsSTM
-- 1. get routed to the currently responsible node
currentlyResponsible <- requestQueryID ns $ getNid ns
putStrLn $ "Trying to join on " <> show (getNid currentlyResponsible)
-- 2. then send a join to the currently responsible node
joinResult <- requestJoin currentlyResponsible nsSTM
case joinResult of
Left err -> pure . Left $ "Error joining on " <> err
Right joinedNS -> pure . Right $ joinedNS
-- | Wait for new cache entries to appear and then try joining on them.
-- Exits after successful joining.
joinOnNewEntriesThread :: LocalNodeStateSTM -> IO ()
joinOnNewEntriesThread nsSTM = loop
where
loop = do
nsSnap <- readTVarIO nsSTM
(lookupResult, cache) <- atomically $ do
cache <- readTVar $ nodeCacheSTM nsSnap
case queryLocalCache nsSnap cache 1 (getNid nsSnap) of
-- empty cache, block until cache changes and then retry
(FORWARD s) | Set.null s -> retry
result -> pure (result, cache)
case lookupResult of
-- already joined
FOUND _ -> do
print =<< readTVarIO nsSTM
pure ()
-- otherwise try joining
FORWARD _ -> do
joinResult <- fediChordJoin nsSTM
either
-- on join failure, sleep and retry
-- TODO: make delay configurable
(const $ threadDelay (30 * 10^6) >> loop)
(const $ pure ())
joinResult
emptyset = Set.empty -- because pattern matches don't accept qualified names
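-- Without any configured bootstrap nodes, a node can instead fork this thread and
-- join as soon as peers show up in its cache, however they got there. A sketch:
--
-- > _ <- forkIO $ joinOnNewEntriesThread nsSTM -- exits on its own once joined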
-- | cache updater thread that waits for incoming NodeCache update instructions on
-- the node's cacheWriteQueue and then modifies the NodeCache as the single writer.
cacheWriter :: LocalNodeStateSTM -> IO ()
cacheWriter nsSTM =
forever $ atomically $ do
ns <- readTVar nsSTM
cacheModifier <- readTQueue $ cacheWriteQueue ns
modifyTVar' (nodeCacheSTM ns) cacheModifier
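-- Cache modifications from other threads are meant to be submitted as functions
-- through this queue rather than applied directly. For example (reusing
-- 'deleteCacheEntry' as done in 'stabiliseThread'; @someNodeID@ is a placeholder):
--
-- > atomically $ writeTQueue (cacheWriteQueue ns) (deleteCacheEntry someNodeID)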
-- TODO: make max entry age configurable
maxEntryAge :: POSIXTime
maxEntryAge = 600
-- | Periodically iterate through cache, clean up expired entries and verify unverified ones
cacheVerifyThread :: LocalNodeStateSTM -> IO ()
cacheVerifyThread nsSTM = forever $ do
putStrLn "cache verify run: begin"
-- get cache
(ns, cache) <- atomically $ do
ns <- readTVar nsSTM
cache <- readTVar $ nodeCacheSTM ns
pure (ns, cache)
-- iterate entries:
-- To avoid a time syscall per entry, get the current time once before iterating.
now <- getPOSIXTime
forM_ (cacheEntries cache) (\(CacheEntry validated node ts) ->
-- case too old: delete (future work: decide whether pinging and resetting timestamp is better)
if (now - ts) > maxEntryAge
then
queueDeleteEntry (getNid node) ns
-- case unverified: try verifying, otherwise delete
else if not validated
then do
-- marking as verified is done by 'requestPing' as well
pong <- requestPing ns node
either (\_->
queueDeleteEntry (getNid node) ns
)
(\vss ->
if node `notElem` vss
then queueDeleteEntry (getNid node) ns
-- after verifying a node, check whether it can be a closer neighbour
else do
if node `isPossiblePredecessor` ns
then atomically $ do
ns' <- readTVar nsSTM
writeTVar nsSTM $ addPredecessors [node] ns'
else pure ()
if node `isPossibleSuccessor` ns
then atomically $ do
ns' <- readTVar nsSTM
writeTVar nsSTM $ addSuccessors [node] ns'
else pure ()
) pong
else pure ()
)
-- check the cache invariant per slice and, if necessary, do a single lookup to the
-- middle of each slice not fulfilling the invariant
latestNs <- readTVarIO nsSTM
latestCache <- readTVarIO $ nodeCacheSTM latestNs
let nodesToQuery targetID = case queryLocalCache latestNs latestCache (lNumBestNodes latestNs) targetID of
FOUND node -> [node]
FORWARD nodeSet -> remoteNode <$> Set.elems nodeSet
forM_ (checkCacheSliceInvariants latestNs latestCache) (\targetID ->
forkIO $ sendQueryIdMessages targetID latestNs (Just (1 + jEntriesPerSlice latestNs)) (nodesToQuery targetID) >> pure () -- ask for 1 entry more than j because of querying the middle
)
putStrLn "cache verify run: end"
threadDelay $ 10^6 * round maxEntryAge `div` 20
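-- With the default 'maxEntryAge' of 600 seconds, the delay above amounts to
-- @(10^6 * 600) `div` 20@ microseconds, i.e. one verification run every 30 seconds.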
-- | Checks the invariant of at least @jEntries@ per cache slice.
-- If this invariant does not hold, the middle of the slice is returned for
-- making lookups to that ID
checkCacheSliceInvariants :: LocalNodeState
-> NodeCache
-> [NodeID] -- ^ list of middle IDs of slices not
--   fulfilling the invariant
checkCacheSliceInvariants ns
| isJoined ns = checkPredecessorSlice jEntries (getNid ns) startBound lastPred <> checkSuccessorSlice jEntries (getNid ns) startBound lastSucc
| otherwise = const []
where
jEntries = jEntriesPerSlice ns
lastPred = getNid <$> lastMay (predecessors ns)
lastSucc = getNid <$> lastMay (successors ns)
-- start slice boundary: 1/2 key space
startBound = getNid ns + 2^(idBits - 1)
checkSuccessorSlice :: Integral i => i -> NodeID -> NodeID -> Maybe NodeID -> NodeCache -> [NodeID]
checkSuccessorSlice _ _ _ Nothing _ = []
checkSuccessorSlice j ownID upperBound (Just lastSuccID) cache
| (upperBound `localCompare` lastSuccID) == LT = []
| otherwise =
let
diff = getNodeID $ upperBound - ownID
lowerBound = ownID + fromInteger (diff `div` 2)
middleID = lowerBound + fromInteger (diff `div` 4)
lookupResult = Set.map (getNid . remoteNode) $ closestCachePredecessors jEntries upperBound cache
in
-- check whether j entries are in the slice
if length lookupResult == jEntries
&& all (\r -> (r `localCompare` lowerBound) == GT) lookupResult
&& all (\r -> (r `localCompare` upperBound) == LT) lookupResult
then checkSuccessorSlice j ownID (lowerBound - 1) (Just lastSuccID) cache
-- if not enough entries, add the middle of the slice to list
else middleID : checkSuccessorSlice j ownID (lowerBound - 1) (Just lastSuccID) cache
checkPredecessorSlice :: Integral i => i -> NodeID -> NodeID -> Maybe NodeID -> NodeCache -> [NodeID]
checkPredecessorSlice _ _ _ Nothing _ = []
checkPredecessorSlice j ownID lowerBound (Just lastPredID) cache
| (lowerBound `localCompare` lastPredID) == GT = []
| otherwise =
let
diff = getNodeID $ ownID - lowerBound
upperBound = ownID - fromInteger (diff `div` 2)
middleID = lowerBound + fromInteger (diff `div` 4)
lookupResult = Set.map (getNid . remoteNode) $ closestCachePredecessors jEntries upperBound cache
in
-- check whether j entries are in the slice
if length lookupResult == jEntries
&& all (\r -> (r `localCompare` lowerBound) == GT) lookupResult
&& all (\r -> (r `localCompare` upperBound) == LT) lookupResult
then checkPredecessorSlice j ownID (upperBound + 1) (Just lastPredID) cache
-- if not enough entries, add the middle of the slice to list
else middleID : checkPredecessorSlice j ownID (upperBound + 1) (Just lastPredID) cache
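-- A worked example of the slice subdivision (successor side, purely illustrative,
-- assuming an 8-bit ID space and an own ID of 0): starting from the bound at half
-- the key space, the checked slices are roughly (64, 128), (31, 63), (15, 30), …
-- with middle IDs 96, 46 and 22. Each step halves the remaining distance towards
-- the own ID, so the number of checked slices per side is logarithmic in the size
-- of the ID space; recursion stops once the bound passes the last known neighbour.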
-- | Periodically send @StabiliseRequest@s to the closest neighbour nodes, until
-- one responds, and get their neighbours for maintaining the own neighbour lists.
-- If necessary, request new neighbours.
stabiliseThread :: LocalNodeStateSTM -> IO ()
stabiliseThread nsSTM = forever $ do
ns <- readTVarIO nsSTM
putStrLn "stabilise run: begin"
print ns
-- iterate through the same snapshot, collect potential new neighbours
-- and nodes to be deleted, and modify these changes only at the end of
-- each stabilise run.
-- This decision makes iterating through a potentially changing list easier.
-- Don't contact all neighbours at once; only contact the next one if the previous one failed (returned a Left).
predStabilise <- stabiliseClosestResponder ns predecessors 1 []
succStabilise <- stabiliseClosestResponder ns successors 1 []
let
(predDeletes, predNeighbours) = either (const ([], [])) id predStabilise
(succDeletes, succNeighbours) = either (const ([], [])) id succStabilise
allDeletes = predDeletes <> succDeletes
allNeighbours = predNeighbours <> succNeighbours
-- now actually modify the node state's neighbours
updatedNs <- atomically $ do
newerNsSnap <- readTVar nsSTM
let
-- sorting and taking only k neighbours is taken care of by the
-- setSuccessors/ setPredecessors functions
newPreds = (predecessors newerNsSnap \\ allDeletes) <> allNeighbours
newSuccs = (successors newerNsSnap \\ allDeletes) <> allNeighbours
newNs = setPredecessors newPreds . setSuccessors newSuccs $ newerNsSnap
writeTVar nsSTM newNs
pure newNs
-- delete unresponding nodes from cache as well
mapM_ (atomically . writeTQueue (cacheWriteQueue updatedNs) . deleteCacheEntry . getNid) allDeletes
-- try looking up additional neighbours if list too short
forM_ [(length $ predecessors updatedNs)..(kNeighbours updatedNs)] (\_ -> do
ns' <- readTVarIO nsSTM
nextEntry <- requestQueryID ns' $ pred . getNid $ lastDef (toRemoteNodeState ns') (predecessors ns')
atomically $ do
latestNs <- readTVar nsSTM
writeTVar nsSTM $ addPredecessors [nextEntry] latestNs
)
forM_ [(length $ successors updatedNs)..(kNeighbours updatedNs)] (\_ -> do
ns' <- readTVarIO nsSTM
nextEntry <- requestQueryID ns' $ succ . getNid $ lastDef (toRemoteNodeState ns') (successors ns')
atomically $ do
latestNs <- readTVar nsSTM
writeTVar nsSTM $ addSuccessors [nextEntry] latestNs
)
putStrLn "stabilise run: end"
-- TODO: make delay configurable
threadDelay (60 * 10^6)
where
-- | send a stabilise request to the n-th neighbour
-- (specified by the provided getter function) and on failure retry
-- with the n+1-th neighbour.
-- On success, return 2 lists: The failed nodes and the potential neighbours
-- returned by the queried node.
stabiliseClosestResponder :: LocalNodeState -- ^ own node
-> (LocalNodeState -> [RemoteNodeState]) -- ^ getter function for either predecessors or successors
-> Int -- ^ index of neighbour to query
-> [RemoteNodeState] -- ^ delete accumulator
-> IO (Either String ([RemoteNodeState], [RemoteNodeState])) -- ^ (nodes to be deleted, successfully pinged potential neighbours)
stabiliseClosestResponder ns neighbourGetter neighbourNum deleteAcc
| isNothing (currentNeighbour ns neighbourGetter neighbourNum) = pure $ Left "exhausted all neighbours"
| otherwise = do
let node = fromJust $ currentNeighbour ns neighbourGetter neighbourNum
stabResponse <- requestStabilise ns node
case stabResponse of
-- returning @Left@ signifies the need to try again with next from list
Left err -> stabiliseClosestResponder ns neighbourGetter (neighbourNum+1) (node:deleteAcc)
Right (succs, preds) -> do
-- ping each returned node before actually inserting them
-- send pings in parallel, check whether ID is part of the returned IDs
pingThreads <- mapM (async . checkReachability ns) $ preds <> succs
-- ToDo: exception handling, maybe log them
-- filter out own node
checkedNeighbours <- filter (/= toRemoteNodeState ns) . catMaybes . rights <$> mapM waitCatch pingThreads
pure $ Right (deleteAcc, checkedNeighbours)
currentNeighbour ns neighbourGetter = atMay $ neighbourGetter ns
checkReachability :: LocalNodeState -- ^ this node
-> RemoteNodeState -- ^ node to Ping for reachability
-> IO (Maybe RemoteNodeState) -- ^ if the Pinged node handles the requested node state then that one
checkReachability ns toCheck = do
resp <- requestPing ns toCheck
pure $ either (const Nothing) (\vss ->
if toCheck `elem` vss then Just toCheck else Nothing
) resp
-- | Receives UDP packets and passes them to other threads via the given TQueue.
-- Shall be used as the single receiving thread on the server socket, as multiple
-- threads blocking on the same socket degrades performance.
recvThread :: Socket -- ^ server socket to receive packets from
-> TQueue (BS.ByteString, SockAddr) -- ^ receive queue
-> IO ()
recvThread sock recvQ = forever $ do
packet <- recvFrom sock 65535
atomically $ writeTQueue recvQ packet
-- | Only thread to send data it gets from a TQueue through the server socket.
sendThread :: Socket -- ^ server socket used for sending
-> TQueue (BS.ByteString, SockAddr) -- ^ send queue
-> IO ()
sendThread sock sendQ = forever $ do
(packet, addr) <- atomically $ readTQueue sendQ
sendAllTo sock packet addr
-- | Sets up and manages the main server threads of FediChord
fediMainThreads :: Socket -> LocalNodeStateSTM -> IO ()
fediMainThreads sock nsSTM = do
(\x -> putStrLn $ "launching threads, ns: " <> show x) =<< readTVarIO nsSTM
sendQ <- newTQueueIO
recvQ <- newTQueueIO
-- concurrently launch all handler threads; if one of them throws an exception,
-- all get cancelled
concurrently_
(fediMessageHandler sendQ recvQ nsSTM) $
concurrently_ (stabiliseThread nsSTM) $
concurrently_ (cacheVerifyThread nsSTM) $
concurrently_ (convergenceSampleThread nsSTM) $
concurrently_
(sendThread sock sendQ)
(recvThread sock recvQ)
-- Defined here because, for now, the RequestMap is only used by fediMessageHandler.
-- Once that changes, move it to FediChordTypes.
type RequestMap = Map.Map (SockAddr, Integer) RequestMapEntry
data RequestMapEntry = RequestMapEntry (Set.Set FediChordMessage) (Maybe Integer) POSIXTime
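-- When a further part of an already known request arrives, the 'Map.insertWith'
-- combiner in 'fediMessageHandler' unions the part sets, keeps whichever final-part
-- count is already known and refreshes the timestamp. Roughly (with hypothetical
-- parts @p1@ and @p3@ of a 3-part request):
--
-- > RequestMapEntry (Set.fromList [p1, p3]) (Just 3) latestPartTimestamp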
-- TODO: make purge age configurable
responsePurgeAge :: POSIXTime
responsePurgeAge = 60 -- seconds
-- | periodically clean up old request parts
requestMapPurge :: MVar RequestMap -> IO ()
requestMapPurge mapVar = forever $ do
rMapState <- takeMVar mapVar
now <- getPOSIXTime
putMVar mapVar $ Map.filter (\entry@(RequestMapEntry _ _ ts) ->
now - ts < responsePurgeAge
) rMapState
threadDelay $ round responsePurgeAge * 2 * 10^6
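-- With the default 'responsePurgeAge' of 60 seconds a purge pass runs every
-- 120 seconds, so parts of an incomplete request linger for at most roughly
-- three minutes before being dropped.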
-- | Wait for messages, deserialise them, manage parts and acknowledgement status,
-- and pass them to their specific handling function.
fediMessageHandler :: TQueue (BS.ByteString, SockAddr) -- ^ send queue
-> TQueue (BS.ByteString, SockAddr) -- ^ receive queue
-> LocalNodeStateSTM -- ^ acting NodeState
-> IO ()
fediMessageHandler sendQ recvQ nsSTM = do
-- Read node state just once, assuming that all relevant data for this function does
-- not change.
-- Other functions are passed the nsSTM reference and thus can get the latest state.
nsSnap <- readTVarIO nsSTM
-- handling multipart messages:
-- Request parts are inserted into a map (key: (sender IP against spoofing, request ID);
-- value: timestamp + set of message parts) and handled once the size of the set equals
-- the announced number of parts. This map needs to be purged periodically by a separate
-- thread and can be protected by an MVar for fairness.
requestMap <- newMVar (Map.empty :: RequestMap)
-- run receive loop and requestMapPurge concurrently, so that an exception makes
-- both of them fail
concurrently_ (requestMapPurge requestMap) $ forever $ do
-- wait for incoming messages
(rawMsg, sourceAddr) <- atomically $ readTQueue recvQ
let aMsg = deserialiseMessage rawMsg
either (\_ ->
-- drop invalid messages
pure ()
)
(\validMsg ->
case validMsg of
aRequest@Request{}
-- if not a multipart message, handle immediately. Response is at the same time an ACK
| part aRequest == 1 && isFinalPart aRequest ->
forkIO (handleIncomingRequest nsSTM sendQ (Set.singleton aRequest) sourceAddr) >> pure ()
-- otherwise collect all message parts first before handling the whole request
| otherwise -> do
now <- getPOSIXTime
-- critical locking section of requestMap
rMapState <- takeMVar requestMap
-- insert new message and get set
let
theseMaxParts = if isFinalPart aRequest then Just (part aRequest) else Nothing
thisKey = (sourceAddr, requestID aRequest)
newMapState = Map.insertWith (\
(RequestMapEntry thisMsgSet p' ts) (RequestMapEntry oldMsgSet p'' _) ->
RequestMapEntry (thisMsgSet `Set.union` oldMsgSet) (p' <|> p'') ts
)
thisKey
(RequestMapEntry (Set.singleton aRequest) theseMaxParts now)
rMapState
-- put map back into MVar, end of critical section
putMVar requestMap newMapState
-- ACK the received part
forM_ (ackRequest (getNid nsSnap) aRequest) $
\msg -> atomically $ writeTQueue sendQ (msg, sourceAddr)
-- if all parts received, then handle request.
let
(RequestMapEntry theseParts mayMaxParts _) = fromJust $ Map.lookup thisKey newMapState
numParts = Set.size theseParts
if maybe False (numParts ==) (fromIntegral <$> mayMaxParts)
then forkIO (handleIncomingRequest nsSTM sendQ theseParts sourceAddr) >> pure()
else pure()
-- Responses should never arrive on the main server port, as they are always
-- responses to requests sent from dedicated sockets on another port
_ -> pure ()
)
aMsg
pure ()