protect concurrent node state access with STM
- To allow concurrent access to predecessors and successors, the whole LocalNodeState is passed wrapped in an STM TVar.
- This allows keeping the tests for the mostly pure data type, in contrast to protecting only the successor and predecessor lists.
Contributes to #28.
This commit is contained in:
parent
f42dfb2137
commit
dc2e399d64
|
@ -53,7 +53,8 @@ import System.Timeout
|
|||
|
||||
import Hash2Pub.ASN1Coding
|
||||
import Hash2Pub.FediChordTypes (CacheEntry (..),
|
||||
LocalNodeState (..), NodeCache,
|
||||
LocalNodeState (..),
|
||||
LocalNodeStateSTM, NodeCache,
|
||||
NodeID, NodeState (..),
|
||||
RemoteNodeState (..),
|
||||
cacheGetNodeStateUnvalidated,
|
||||
|
@ -169,67 +170,93 @@ ackRequest ownID req@Request{} = serialiseMessage sendMessageSize $ Response {
|
|||
}
|
||||
|
||||
|
||||
handleIncomingRequest :: LocalNodeState -- ^ the handling node
|
||||
handleIncomingRequest :: LocalNodeStateSTM -- ^ the handling node
|
||||
-> TQueue (BS.ByteString, SockAddr) -- ^ send queue
|
||||
-> Set.Set FediChordMessage -- ^ all parts of the request to handle
|
||||
-> SockAddr -- ^ source address of the request
|
||||
-> IO ()
|
||||
handleIncomingRequest ns sendQ msgSet sourceAddr = do
|
||||
handleIncomingRequest nsSTM sendQ msgSet sourceAddr = do
|
||||
ns <- readTVarIO nsSTM
|
||||
-- add nodestate to cache
|
||||
now <- getPOSIXTime
|
||||
aPart <- headMay . Set.elems $ msgSet
|
||||
case aPart of
|
||||
case headMay . Set.elems $ msgSet of
|
||||
Nothing -> pure ()
|
||||
Just aPart' ->
|
||||
queueAddEntries (Identity . RemoteCacheEntry (sender aPart') $ now) ns
|
||||
Just aPart -> do
|
||||
queueAddEntries (Identity $ RemoteCacheEntry (sender aPart) now) ns
|
||||
-- distinguish on whether and how to respond. If responding, pass message to response generating function and write responses to send queue
|
||||
maybe (pure ()) (\respSet ->
|
||||
forM_ (\resp -> atomically $ writeTQueue sendQ (resp, sourceAddr)))
|
||||
(case action aPart' of
|
||||
Ping -> Just respondPing ns msgSet
|
||||
Join -> Just respondJoin ns msgSet
|
||||
-- ToDo: figure out what happens if not joined
|
||||
QueryID -> Just respondQueryID ns msgSet
|
||||
-- only when joined
|
||||
Leave -> if isJoined_ ns then Just respondLeave ns msgSet else Nothing
|
||||
-- only when joined
|
||||
Stabilise -> if isJoined_ ns then Just respondStabilise ns msgSet else Nothing
|
||||
)
|
||||
-- for single part request, response starts with part number 1. For multipart requests, response starts with part number n+1.
|
||||
|
||||
-- TODO: determine request type only from first part, but catch RecSelError on each record access when folding, because otherwise different request type parts can make this crash
|
||||
-- TODO: test case: mixed message types of parts
|
||||
|
||||
-- ....... response sending .......
|
||||
maybe (pure ()) (
|
||||
mapM_ (\resp -> atomically $ writeTQueue sendQ (resp, sourceAddr))
|
||||
)
|
||||
(case action aPart of
|
||||
_ -> Just Map.empty) -- placeholder
|
||||
-- Ping -> Just respondPing nsSTM msgSet
|
||||
-- Join -> Just respondJoin nsSTM msgSet
|
||||
-- -- ToDo: figure out what happens if not joined
|
||||
-- QueryID -> Just respondQueryID nsSTM msgSet
|
||||
-- -- only when joined
|
||||
-- Leave -> if isJoined_ ns then Just respondLeave nsSTM msgSet else Nothing
|
||||
-- -- only when joined
|
||||
-- Stabilise -> if isJoined_ ns then Just respondStabilise nsSTM msgSet else Nothing
|
||||
-- )
|
||||
-- -- for single part request, response starts with part number 1. For multipart requests, response starts with part number n+1.
|
||||
--
|
||||
-- -- TODO: determine request type only from first part, but catch RecSelError on each record access when folding, because otherwise different request type parts can make this crash
|
||||
-- -- TODO: test case: mixed message types of parts
|
||||
--
|
||||
---- ....... response sending .......
|
||||
--
|
||||
---- this modifies node state, so locking and IO seems to be necessary.
|
||||
---- Still try to keep as much code as possible pure
|
||||
--respondJoin :: LocalNodeStateSTM -> Set.Set FediChordMessage -> Map Integer BS.ByteString
|
||||
--respondJoin nsSTM msgSet =
|
||||
-- -- check whether the joining node falls into our responsibility
|
||||
-- -- if yes, adjust own predecessors/ successors and return those in a response
|
||||
-- -- if no: empty response or send a QueryID forwards response?
|
||||
-- -- TODO: notify service layer to copy over data now handled by the new joined node
|
||||
|
||||
-- ....... request sending .......
|
||||
|
||||
-- | send a join request and return the joined 'LocalNodeState' including neighbours
|
||||
requestJoin :: NodeState a => a -- ^ currently responsible node to be contacted
|
||||
-> LocalNodeState -- ^ joining NodeState
|
||||
-> IO (Either String LocalNodeState) -- ^ node after join with all its new information
|
||||
requestJoin toJoinOn ownState =
|
||||
-> LocalNodeStateSTM -- ^ joining NodeState
|
||||
-> IO (Either String LocalNodeStateSTM) -- ^ node after join with all its new information
|
||||
requestJoin toJoinOn ownStateSTM =
|
||||
bracket (mkSendSocket (getDomain toJoinOn) (getDhtPort toJoinOn)) close (\sock -> do
|
||||
-- extract own state for getting request information
|
||||
ownState <- readTVarIO ownStateSTM
|
||||
responses <- sendRequestTo 5000 3 (\rid -> Request rid (toRemoteNodeState ownState) 1 True Join (Just JoinRequestPayload)) sock
|
||||
joinedStateUnsorted <- foldM
|
||||
(\nsAcc msg -> case payload msg of
|
||||
Nothing -> pure nsAcc
|
||||
Just msgPl -> do
|
||||
-- add transferred cache entries to global NodeCache
|
||||
queueAddEntries (joinCache msgPl) nsAcc
|
||||
-- add received predecessors and successors
|
||||
let
|
||||
addPreds ns' = setPredecessors (foldr' (:) (predecessors ns') (joinPredecessors msgPl)) ns'
|
||||
addSuccs ns' = setSuccessors (foldr' (:) (successors ns') (joinSuccessors msgPl)) ns'
|
||||
pure $ addSuccs . addPreds $ nsAcc
|
||||
)
|
||||
-- reset predecessors and successors
|
||||
(setPredecessors [] . setSuccessors [] $ ownState)
|
||||
responses
|
||||
(cacheInsertQ, joinedState) <- atomically $ do
|
||||
stateSnap <- readTVar ownStateSTM
|
||||
let
|
||||
(cacheInsertQ, joinedStateUnsorted) = foldl'
|
||||
(\(insertQ, nsAcc) msg ->
|
||||
let
|
||||
insertQ' = maybe insertQ (\msgPl ->
|
||||
-- collect list of insertion statements into global cache
|
||||
queueAddEntries (joinCache msgPl) : insertQ
|
||||
) $ payload msg
|
||||
-- add received predecessors and successors
|
||||
addPreds ns' = maybe ns' (\msgPl ->
|
||||
setPredecessors (foldr' (:) (predecessors ns') (joinPredecessors msgPl)) ns'
|
||||
) $ payload msg
|
||||
addSuccs ns' = maybe ns' (\msgPl ->
|
||||
setSuccessors (foldr' (:) (successors ns') (joinSuccessors msgPl)) ns'
|
||||
) $ payload msg
|
||||
in
|
||||
(insertQ', addSuccs . addPreds $ nsAcc)
|
||||
)
|
||||
-- reset predecessors and successors
|
||||
([], setPredecessors [] . setSuccessors [] $ ownState)
|
||||
responses
|
||||
-- sort successors and predecessors
|
||||
newState = setSuccessors (sortBy localCompare $ successors joinedStateUnsorted) . setPredecessors (sortBy localCompare $ predecessors joinedStateUnsorted) $ joinedStateUnsorted
|
||||
writeTVar ownStateSTM newState
|
||||
pure (cacheInsertQ, newState)
|
||||
-- execute the cache insertions
|
||||
mapM_ (\f -> f joinedState) cacheInsertQ
|
||||
if responses == Set.empty
|
||||
then pure . Left $ "join error: got no response from " <> show (getNid toJoinOn)
|
||||
-- sort successors and predecessors
|
||||
else pure . Right . setSuccessors (sortBy localCompare $ successors joinedStateUnsorted) . setPredecessors (sortBy localCompare $ predecessors joinedStateUnsorted) $ joinedStateUnsorted
|
||||
else pure $ Right ownStateSTM
|
||||
)
|
||||
`catch` (\e -> pure . Left $ displayException (e :: IOException))
|
||||
|
||||
|
|
|
@ -64,6 +64,7 @@ import Control.Concurrent
|
|||
import Control.Concurrent.Async
|
||||
import Control.Concurrent.STM
|
||||
import Control.Concurrent.STM.TQueue
|
||||
import Control.Concurrent.STM.TVar
|
||||
import Control.Monad (forM_, forever)
|
||||
import Crypto.Hash
|
||||
import qualified Data.ByteArray as BA
|
||||
|
@ -84,11 +85,12 @@ import Debug.Trace (trace)
|
|||
|
||||
-- | initialise data structures, compute own IDs and bind to listening socket
|
||||
-- ToDo: load persisted state, thus this function already operates in IO
|
||||
fediChordInit :: FediChordConf -> IO (Socket, LocalNodeState)
|
||||
fediChordInit :: FediChordConf -> IO (Socket, LocalNodeStateSTM)
|
||||
fediChordInit conf = do
|
||||
initialState <- nodeStateInit conf
|
||||
initialStateSTM <- newTVarIO initialState
|
||||
serverSock <- mkServerSocket (getIpAddr initialState) (getDhtPort initialState)
|
||||
pure (serverSock, initialState)
|
||||
pure (serverSock, initialStateSTM)
|
||||
|
||||
-- | initialises the 'NodeState' for this local node.
|
||||
-- Separated from 'fediChordInit' to be usable in tests.
|
||||
|
@ -120,15 +122,16 @@ nodeStateInit conf = do
|
|||
|
||||
-- | Join a new node into the DHT, using a provided bootstrap node as initial cache seed
|
||||
-- for resolving the new node's position.
|
||||
fediChordBootstrapJoin :: LocalNodeState -- ^ the local 'NodeState'
|
||||
fediChordBootstrapJoin :: LocalNodeStateSTM -- ^ the local 'NodeState'
|
||||
-> (String, PortNumber) -- ^ domain and port of a bootstrapping node
|
||||
-> IO (Either String LocalNodeState) -- ^ the joined 'NodeState' after a
|
||||
-> IO (Either String LocalNodeStateSTM) -- ^ the joined 'NodeState' after a
|
||||
-- successful join, otherwise an error message
|
||||
fediChordBootstrapJoin ns (joinHost, joinPort) =
|
||||
fediChordBootstrapJoin nsSTM (joinHost, joinPort) =
|
||||
-- can be invoked multiple times with all known bootstrapping nodes until successfully joined
|
||||
bracket (mkSendSocket joinHost joinPort) close (\sock -> do
|
||||
-- 1. get routed to placement of own ID until FOUND:
|
||||
-- Initialise an empty cache only with the responses from a bootstrapping node
|
||||
ns <- readTVarIO nsSTM
|
||||
bootstrapResponse <- sendQueryIdMessage (getNid ns) ns sock
|
||||
if bootstrapResponse == Set.empty
|
||||
then pure . Left $ "Bootstrapping node " <> show joinHost <> " gave no response."
|
||||
|
@ -143,7 +146,7 @@ fediChordBootstrapJoin ns (joinHost, joinPort) =
|
|||
Just (FORWARD resultset) -> foldr' (addCacheEntryPure now) cacheAcc resultset
|
||||
)
|
||||
initCache bootstrapResponse
|
||||
fediChordJoin bootstrapCache ns
|
||||
fediChordJoin bootstrapCache nsSTM
|
||||
)
|
||||
`catch` (\e -> pure . Left $ "Error at bootstrap joining: " <> displayException (e :: IOException))
|
||||
|
||||
|
@ -151,15 +154,16 @@ fediChordBootstrapJoin ns (joinHost, joinPort) =
|
|||
-- node's position.
|
||||
fediChordJoin :: NodeCache -- ^ a snapshot of the NodeCache to
|
||||
-- use for ID lookup
|
||||
-> LocalNodeState -- ^ the local 'NodeState'
|
||||
-> IO (Either String LocalNodeState) -- ^ the joined 'NodeState' after a
|
||||
-> LocalNodeStateSTM -- ^ the local 'NodeState'
|
||||
-> IO (Either String LocalNodeStateSTM) -- ^ the joined 'NodeState' after a
|
||||
-- successful join, otherwise an error message
|
||||
fediChordJoin cacheSnapshot ns = do
|
||||
fediChordJoin cacheSnapshot nsSTM = do
|
||||
ns <- readTVarIO nsSTM
|
||||
-- get routed to the currently responsible node, based on the response
|
||||
-- from the bootstrapping node
|
||||
currentlyResponsible <- queryIdLookupLoop cacheSnapshot ns $ getNid ns
|
||||
-- 2. then send a join to the currently responsible node
|
||||
joinResult <- requestJoin currentlyResponsible ns
|
||||
joinResult <- requestJoin currentlyResponsible nsSTM
|
||||
case joinResult of
|
||||
Left err -> pure . Left $ "Error joining on " <> err
|
||||
Right joinedNS -> pure . Right $ joinedNS
|
||||
|
@ -167,8 +171,9 @@ fediChordJoin cacheSnapshot ns = do
|
|||
|
||||
-- | cache updater thread that waits for incoming NodeCache update instructions on
|
||||
-- the node's cacheWriteQueue and then modifies the NodeCache as the single writer.
|
||||
cacheWriter :: LocalNodeState -> IO ()
|
||||
cacheWriter ns = do
|
||||
cacheWriter :: LocalNodeStateSTM -> IO ()
|
||||
cacheWriter nsSTM = do
|
||||
ns <- readTVarIO nsSTM
|
||||
let writeQueue' = cacheWriteQueue ns
|
||||
forever $ do
|
||||
f <- atomically $ readTQueue writeQueue'
|
||||
|
@ -196,14 +201,14 @@ sendThread sock sendQ = forever $ do
|
|||
sendAllTo sock packet addr
|
||||
|
||||
-- | Sets up and manages the main server threads of FediChord
|
||||
fediMainThreads :: Socket -> LocalNodeState -> IO ()
|
||||
fediMainThreads sock ns = do
|
||||
fediMainThreads :: Socket -> LocalNodeStateSTM -> IO ()
|
||||
fediMainThreads sock nsSTM = do
|
||||
sendQ <- newTQueueIO
|
||||
recvQ <- newTQueueIO
|
||||
-- concurrently launch all handler threads, if one of them throws an exception
|
||||
-- all get cancelled
|
||||
concurrently_
|
||||
(fediMessageHandler sendQ recvQ ns) $
|
||||
(fediMessageHandler sendQ recvQ nsSTM) $
|
||||
concurrently
|
||||
(sendThread sock sendQ)
|
||||
(recvThread sock recvQ)
|
||||
|
@ -236,9 +241,10 @@ requestMapPurge mapVar = forever $ do
|
|||
-- and pass them to their specific handling function.
|
||||
fediMessageHandler :: TQueue (BS.ByteString, SockAddr) -- ^ send queue
|
||||
-> TQueue (BS.ByteString, SockAddr) -- ^ receive queue
|
||||
-> LocalNodeState -- ^ acting NodeState
|
||||
-> LocalNodeStateSTM -- ^ acting NodeState
|
||||
-> IO ()
|
||||
fediMessageHandler sendQ recvQ ns = do
|
||||
fediMessageHandler sendQ recvQ nsSTM = do
|
||||
nsSnap <- readTVarIO nsSTM
|
||||
-- handling multipart messages:
|
||||
-- Request parts can be inserted into a map (key: (sender IP against spoofing, request ID), value: timestamp + set of message parts, handle all of them when size of set == parts) before being handled. This map needs to be purged periodically by a separate thread and can be protected by an MVar for fairness.
|
||||
requestMap <- newMVar (Map.empty :: RequestMap)
|
||||
|
@ -257,7 +263,7 @@ fediMessageHandler sendQ recvQ ns = do
|
|||
aRequest@Request{}
|
||||
-- if not a multipart message, handle immediately. Response is at the same time an ACK
|
||||
| part aRequest == 1 && isFinalPart aRequest ->
|
||||
forkIO (handleIncomingRequest ns sendQ (Set.singleton aRequest) sourceAddr) >> pure ()
|
||||
forkIO (handleIncomingRequest nsSTM sendQ (Set.singleton aRequest) sourceAddr) >> pure ()
|
||||
-- otherwise collect all message parts first before handling the whole request
|
||||
| otherwise -> do
|
||||
now <- getPOSIXTime
|
||||
|
@ -277,14 +283,14 @@ fediMessageHandler sendQ recvQ ns = do
|
|||
-- put map back into MVar, end of critical section
|
||||
putMVar requestMap newMapState
|
||||
-- ACK the received part
|
||||
forM_ (ackRequest (getNid ns) aRequest) $
|
||||
forM_ (ackRequest (getNid nsSnap) aRequest) $
|
||||
\msg -> atomically $ writeTQueue sendQ (msg, sourceAddr)
|
||||
-- if all parts received, then handle request.
|
||||
let
|
||||
(RequestMapEntry theseParts mayMaxParts _) = fromJust $ Map.lookup thisKey newMapState
|
||||
numParts = Set.size theseParts
|
||||
if maybe False (numParts ==) (fromIntegral <$> mayMaxParts)
|
||||
then forkIO (handleIncomingRequest ns sendQ theseParts sourceAddr) >> pure()
|
||||
then forkIO (handleIncomingRequest nsSTM sendQ theseParts sourceAddr) >> pure()
|
||||
else pure()
|
||||
-- Responses should never arrive on the main server port, as they are always
|
||||
-- responses to requests sent from dedicated sockets on another port
|
||||
|
|
|
@ -9,6 +9,7 @@ module Hash2Pub.FediChordTypes (
|
|||
, toNodeID
|
||||
, NodeState (..)
|
||||
, LocalNodeState (..)
|
||||
, LocalNodeStateSTM
|
||||
, RemoteNodeState (..)
|
||||
, setSuccessors
|
||||
, setPredecessors
|
||||
|
@ -40,6 +41,7 @@ import Network.Socket
|
|||
-- for hashing and ID conversion
|
||||
import Control.Concurrent.STM
|
||||
import Control.Concurrent.STM.TQueue
|
||||
import Control.Concurrent.STM.TVar
|
||||
import Control.Monad (forever)
|
||||
import Crypto.Hash
|
||||
import qualified Data.ByteArray as BA
|
||||
|
@ -144,6 +146,8 @@ data LocalNodeState = LocalNodeState
|
|||
}
|
||||
deriving (Show, Eq)
|
||||
|
||||
-- | A 'LocalNodeState' held inside an STM 'TVar', so the whole node state
-- (including its predecessor and successor lists) can be read and replaced
-- through STM operations while the underlying data type stays unchanged.
type LocalNodeStateSTM = TVar LocalNodeState
|
||||
|
||||
-- | class for various NodeState representations, providing
|
||||
-- getters and setters for common values
|
||||
class NodeState a where
|
||||
|
|
|
@ -2,9 +2,11 @@ module Main where
|
|||
|
||||
import Control.Concurrent
|
||||
import Control.Concurrent.Async
|
||||
import Control.Concurrent.STM
|
||||
import Control.Concurrent.STM.TVar
|
||||
import Control.Exception
|
||||
import Data.Either
|
||||
import Data.IP (IPv6, toHostAddress6)
|
||||
import Data.IP (IPv6, toHostAddress6)
|
||||
import System.Environment
|
||||
|
||||
import Hash2Pub.FediChord
|
||||
|
@ -16,7 +18,7 @@ main = do
|
|||
conf <- readConfig
|
||||
-- ToDo: load persisted caches, bootstrapping nodes …
|
||||
(serverSock, thisNode) <- fediChordInit conf
|
||||
print thisNode
|
||||
print =<< readTVarIO thisNode
|
||||
print serverSock
|
||||
-- currently no masking is necessary, as there is nothing to clean up
|
||||
cacheWriterThread <- forkIO $ cacheWriter thisNode
|
||||
|
@ -38,7 +40,7 @@ main = do
|
|||
)
|
||||
(\joinedNS -> do
|
||||
-- launch main eventloop with successfully joined state
|
||||
putStrLn ("successful join at " <> (show . getNid $ joinedNS))
|
||||
putStrLn "successful join"
|
||||
wait =<< async (fediMainThreads serverSock thisNode)
|
||||
)
|
||||
joinedState
|
||||
|
|
Loading…
Reference in a new issue