protect concurrent node state access with STM

- to allow concurrent access to predecessors and successors, the
  whole LocalNodeState is passed around wrapped in an STM TVar
- compared to protecting only the successor and predecessor lists, this
  allows keeping the tests for the mostly pure data type

contributes to #28
This commit is contained in:
Trolli Schmittlauch 2020-06-04 22:29:11 +02:00
parent f42dfb2137
commit dc2e399d64
4 changed files with 107 additions and 68 deletions

View file

@ -53,7 +53,8 @@ import System.Timeout
import Hash2Pub.ASN1Coding import Hash2Pub.ASN1Coding
import Hash2Pub.FediChordTypes (CacheEntry (..), import Hash2Pub.FediChordTypes (CacheEntry (..),
LocalNodeState (..), NodeCache, LocalNodeState (..),
LocalNodeStateSTM, NodeCache,
NodeID, NodeState (..), NodeID, NodeState (..),
RemoteNodeState (..), RemoteNodeState (..),
cacheGetNodeStateUnvalidated, cacheGetNodeStateUnvalidated,
@ -169,67 +170,93 @@ ackRequest ownID req@Request{} = serialiseMessage sendMessageSize $ Response {
} }
handleIncomingRequest :: LocalNodeState -- ^ the handling node handleIncomingRequest :: LocalNodeStateSTM -- ^ the handling node
-> TQueue (BS.ByteString, SockAddr) -- ^ send queue -> TQueue (BS.ByteString, SockAddr) -- ^ send queue
-> Set.Set FediChordMessage -- ^ all parts of the request to handle -> Set.Set FediChordMessage -- ^ all parts of the request to handle
-> SockAddr -- ^ source address of the request -> SockAddr -- ^ source address of the request
-> IO () -> IO ()
handleIncomingRequest ns sendQ msgSet sourceAddr = do handleIncomingRequest nsSTM sendQ msgSet sourceAddr = do
ns <- readTVarIO nsSTM
-- add nodestate to cache -- add nodestate to cache
now <- getPOSIXTime now <- getPOSIXTime
aPart <- headMay . Set.elems $ msgSet case headMay . Set.elems $ msgSet of
case aPart of
Nothing -> pure () Nothing -> pure ()
Just aPart' -> Just aPart -> do
queueAddEntries (Identity . RemoteCacheEntry (sender aPart') $ now) ns queueAddEntries (Identity $ RemoteCacheEntry (sender aPart) now) ns
-- distinguish on whether and how to respond. If responding, pass message to response generating function and write responses to send queue -- distinguish on whether and how to respond. If responding, pass message to response generating function and write responses to send queue
maybe (pure ()) (\respSet -> maybe (pure ()) (
forM_ (\resp -> atomically $ writeTQueue sendQ (resp, sourceAddr))) mapM_ (\resp -> atomically $ writeTQueue sendQ (resp, sourceAddr))
(case action aPart' of
Ping -> Just respondPing ns msgSet
Join -> Just respondJoin ns msgSet
-- ToDo: figure out what happens if not joined
QueryID -> Just respondQueryID ns msgSet
-- only when joined
Leave -> if isJoined_ ns then Just respondLeave ns msgSet else Nothing
-- only when joined
Stabilise -> if isJoined_ ns then Just respondStabilise ns msgSet else Nothing
) )
-- for single part request, response starts with part number 1. For multipart requests, response starts with part number n+1. (case action aPart of
_ -> Just Map.empty) -- placeholder
-- TODO: determine request type only from first part, but catch RecSelError on each record access when folding, because otherwise different request type parts can make this crash -- Ping -> Just respondPing nsSTM msgSet
-- TODO: test case: mixed message types of parts -- Join -> Just respondJoin nsSTM msgSet
-- -- ToDo: figure out what happens if not joined
-- ....... response sending ....... -- QueryID -> Just respondQueryID nsSTM msgSet
-- -- only when joined
-- Leave -> if isJoined_ ns then Just respondLeave nsSTM msgSet else Nothing
-- -- only when joined
-- Stabilise -> if isJoined_ ns then Just respondStabilise nsSTM msgSet else Nothing
-- )
-- -- for single part request, response starts with part number 1. For multipart requests, response starts with part number n+1.
--
-- -- TODO: determine request type only from first part, but catch RecSelError on each record access when folding, because otherwise different request type parts can make this crash
-- -- TODO: test case: mixed message types of parts
--
---- ....... response sending .......
--
---- this modifies node state, so locking and IO seems to be necessary.
---- Still try to keep as much code as possible pure
--respondJoin :: LocalNodeStateSTM -> Set.Set FediChordMessage -> Map Integer BS.ByteString
--respondJoin nsSTM msgSet =
-- -- check whether the joining node falls into our responsibility
-- -- if yes, adjust own predecessors/ successors and return those in a response
-- -- if no: empty response or send a QueryID forwards response?
-- -- TODO: notify service layer to copy over data now handled by the new joined node
-- ....... request sending ....... -- ....... request sending .......
-- | send a join request and return the joined 'LocalNodeState' including neighbours -- | send a join request and return the joined 'LocalNodeState' including neighbours
requestJoin :: NodeState a => a -- ^ currently responsible node to be contacted requestJoin :: NodeState a => a -- ^ currently responsible node to be contacted
-> LocalNodeState -- ^ joining NodeState -> LocalNodeStateSTM -- ^ joining NodeState
-> IO (Either String LocalNodeState) -- ^ node after join with all its new information -> IO (Either String LocalNodeStateSTM) -- ^ node after join with all its new information
requestJoin toJoinOn ownState = requestJoin toJoinOn ownStateSTM =
bracket (mkSendSocket (getDomain toJoinOn) (getDhtPort toJoinOn)) close (\sock -> do bracket (mkSendSocket (getDomain toJoinOn) (getDhtPort toJoinOn)) close (\sock -> do
-- extract own state for getting request information
ownState <- readTVarIO ownStateSTM
responses <- sendRequestTo 5000 3 (\rid -> Request rid (toRemoteNodeState ownState) 1 True Join (Just JoinRequestPayload)) sock responses <- sendRequestTo 5000 3 (\rid -> Request rid (toRemoteNodeState ownState) 1 True Join (Just JoinRequestPayload)) sock
joinedStateUnsorted <- foldM (cacheInsertQ, joinedState) <- atomically $ do
(\nsAcc msg -> case payload msg of stateSnap <- readTVar ownStateSTM
Nothing -> pure nsAcc
Just msgPl -> do
-- add transfered cache entries to global NodeCache
queueAddEntries (joinCache msgPl) nsAcc
-- add received predecessors and successors
let let
addPreds ns' = setPredecessors (foldr' (:) (predecessors ns') (joinPredecessors msgPl)) ns' (cacheInsertQ, joinedStateUnsorted) = foldl'
addSuccs ns' = setSuccessors (foldr' (:) (successors ns') (joinSuccessors msgPl)) ns' (\(insertQ, nsAcc) msg ->
pure $ addSuccs . addPreds $ nsAcc let
insertQ' = maybe insertQ (\msgPl ->
-- collect list of insertion statements into global cache
queueAddEntries (joinCache msgPl) : insertQ
) $ payload msg
-- add received predecessors and successors
addPreds ns' = maybe ns' (\msgPl ->
setPredecessors (foldr' (:) (predecessors ns') (joinPredecessors msgPl)) ns'
) $ payload msg
addSuccs ns' = maybe ns' (\msgPl ->
setSuccessors (foldr' (:) (successors ns') (joinSuccessors msgPl)) ns'
) $ payload msg
in
(insertQ', addSuccs . addPreds $ nsAcc)
) )
-- reset predecessors and successors -- reset predecessors and successors
(setPredecessors [] . setSuccessors [] $ ownState) ([], setPredecessors [] . setSuccessors [] $ ownState)
responses responses
-- sort successors and predecessors
newState = setSuccessors (sortBy localCompare $ successors joinedStateUnsorted) . setPredecessors (sortBy localCompare $ predecessors joinedStateUnsorted) $ joinedStateUnsorted
writeTVar ownStateSTM newState
pure (cacheInsertQ, newState)
-- execute the cache insertions
mapM_ (\f -> f joinedState) cacheInsertQ
if responses == Set.empty if responses == Set.empty
then pure . Left $ "join error: got no response from " <> show (getNid toJoinOn) then pure . Left $ "join error: got no response from " <> show (getNid toJoinOn)
-- sort successors and predecessors else pure $ Right ownStateSTM
else pure . Right . setSuccessors (sortBy localCompare $ successors joinedStateUnsorted) . setPredecessors (sortBy localCompare $ predecessors joinedStateUnsorted) $ joinedStateUnsorted
) )
`catch` (\e -> pure . Left $ displayException (e :: IOException)) `catch` (\e -> pure . Left $ displayException (e :: IOException))

View file

@ -64,6 +64,7 @@ import Control.Concurrent
import Control.Concurrent.Async import Control.Concurrent.Async
import Control.Concurrent.STM import Control.Concurrent.STM
import Control.Concurrent.STM.TQueue import Control.Concurrent.STM.TQueue
import Control.Concurrent.STM.TVar
import Control.Monad (forM_, forever) import Control.Monad (forM_, forever)
import Crypto.Hash import Crypto.Hash
import qualified Data.ByteArray as BA import qualified Data.ByteArray as BA
@ -84,11 +85,12 @@ import Debug.Trace (trace)
-- | initialise data structures, compute own IDs and bind to listening socket -- | initialise data structures, compute own IDs and bind to listening socket
-- ToDo: load persisted state, thus this function already operates in IO -- ToDo: load persisted state, thus this function already operates in IO
fediChordInit :: FediChordConf -> IO (Socket, LocalNodeState) fediChordInit :: FediChordConf -> IO (Socket, LocalNodeStateSTM)
fediChordInit conf = do fediChordInit conf = do
initialState <- nodeStateInit conf initialState <- nodeStateInit conf
initialStateSTM <- newTVarIO initialState
serverSock <- mkServerSocket (getIpAddr initialState) (getDhtPort initialState) serverSock <- mkServerSocket (getIpAddr initialState) (getDhtPort initialState)
pure (serverSock, initialState) pure (serverSock, initialStateSTM)
-- | initialises the 'NodeState' for this local node. -- | initialises the 'NodeState' for this local node.
-- Separated from 'fediChordInit' to be usable in tests. -- Separated from 'fediChordInit' to be usable in tests.
@ -120,15 +122,16 @@ nodeStateInit conf = do
-- | Join a new node into the DHT, using a provided bootstrap node as initial cache seed -- | Join a new node into the DHT, using a provided bootstrap node as initial cache seed
-- for resolving the new node's position. -- for resolving the new node's position.
fediChordBootstrapJoin :: LocalNodeState -- ^ the local 'NodeState' fediChordBootstrapJoin :: LocalNodeStateSTM -- ^ the local 'NodeState'
-> (String, PortNumber) -- ^ domain and port of a bootstrapping node -> (String, PortNumber) -- ^ domain and port of a bootstrapping node
-> IO (Either String LocalNodeState) -- ^ the joined 'NodeState' after a -> IO (Either String LocalNodeStateSTM) -- ^ the joined 'NodeState' after a
-- successful join, otherwise an error message -- successful join, otherwise an error message
fediChordBootstrapJoin ns (joinHost, joinPort) = fediChordBootstrapJoin nsSTM (joinHost, joinPort) =
-- can be invoked multiple times with all known bootstrapping nodes until successfully joined -- can be invoked multiple times with all known bootstrapping nodes until successfully joined
bracket (mkSendSocket joinHost joinPort) close (\sock -> do bracket (mkSendSocket joinHost joinPort) close (\sock -> do
-- 1. get routed to placement of own ID until FOUND: -- 1. get routed to placement of own ID until FOUND:
-- Initialise an empty cache only with the responses from a bootstrapping node -- Initialise an empty cache only with the responses from a bootstrapping node
ns <- readTVarIO nsSTM
bootstrapResponse <- sendQueryIdMessage (getNid ns) ns sock bootstrapResponse <- sendQueryIdMessage (getNid ns) ns sock
if bootstrapResponse == Set.empty if bootstrapResponse == Set.empty
then pure . Left $ "Bootstrapping node " <> show joinHost <> " gave no response." then pure . Left $ "Bootstrapping node " <> show joinHost <> " gave no response."
@ -143,7 +146,7 @@ fediChordBootstrapJoin ns (joinHost, joinPort) =
Just (FORWARD resultset) -> foldr' (addCacheEntryPure now) cacheAcc resultset Just (FORWARD resultset) -> foldr' (addCacheEntryPure now) cacheAcc resultset
) )
initCache bootstrapResponse initCache bootstrapResponse
fediChordJoin bootstrapCache ns fediChordJoin bootstrapCache nsSTM
) )
`catch` (\e -> pure . Left $ "Error at bootstrap joining: " <> displayException (e :: IOException)) `catch` (\e -> pure . Left $ "Error at bootstrap joining: " <> displayException (e :: IOException))
@ -151,15 +154,16 @@ fediChordBootstrapJoin ns (joinHost, joinPort) =
-- node's position. -- node's position.
fediChordJoin :: NodeCache -- ^ a snapshot of the NodeCache to fediChordJoin :: NodeCache -- ^ a snapshot of the NodeCache to
-- use for ID lookup -- use for ID lookup
-> LocalNodeState -- ^ the local 'NodeState' -> LocalNodeStateSTM -- ^ the local 'NodeState'
-> IO (Either String LocalNodeState) -- ^ the joined 'NodeState' after a -> IO (Either String LocalNodeStateSTM) -- ^ the joined 'NodeState' after a
-- successful join, otherwise an error message -- successful join, otherwise an error message
fediChordJoin cacheSnapshot ns = do fediChordJoin cacheSnapshot nsSTM = do
ns <- readTVarIO nsSTM
-- get routed to the currently responsible node, based on the response -- get routed to the currently responsible node, based on the response
-- from the bootstrapping node -- from the bootstrapping node
currentlyResponsible <- queryIdLookupLoop cacheSnapshot ns $ getNid ns currentlyResponsible <- queryIdLookupLoop cacheSnapshot ns $ getNid ns
-- 2. then send a join to the currently responsible node -- 2. then send a join to the currently responsible node
joinResult <- requestJoin currentlyResponsible ns joinResult <- requestJoin currentlyResponsible nsSTM
case joinResult of case joinResult of
Left err -> pure . Left $ "Error joining on " <> err Left err -> pure . Left $ "Error joining on " <> err
Right joinedNS -> pure . Right $ joinedNS Right joinedNS -> pure . Right $ joinedNS
@ -167,8 +171,9 @@ fediChordJoin cacheSnapshot ns = do
-- | cache updater thread that waits for incoming NodeCache update instructions on -- | cache updater thread that waits for incoming NodeCache update instructions on
-- the node's cacheWriteQueue and then modifies the NodeCache as the single writer. -- the node's cacheWriteQueue and then modifies the NodeCache as the single writer.
cacheWriter :: LocalNodeState -> IO () cacheWriter :: LocalNodeStateSTM -> IO ()
cacheWriter ns = do cacheWriter nsSTM = do
ns <- readTVarIO nsSTM
let writeQueue' = cacheWriteQueue ns let writeQueue' = cacheWriteQueue ns
forever $ do forever $ do
f <- atomically $ readTQueue writeQueue' f <- atomically $ readTQueue writeQueue'
@ -196,14 +201,14 @@ sendThread sock sendQ = forever $ do
sendAllTo sock packet addr sendAllTo sock packet addr
-- | Sets up and manages the main server threads of FediChord -- | Sets up and manages the main server threads of FediChord
fediMainThreads :: Socket -> LocalNodeState -> IO () fediMainThreads :: Socket -> LocalNodeStateSTM -> IO ()
fediMainThreads sock ns = do fediMainThreads sock nsSTM = do
sendQ <- newTQueueIO sendQ <- newTQueueIO
recvQ <- newTQueueIO recvQ <- newTQueueIO
-- concurrently launch all handler threads, if one of them throws an exception -- concurrently launch all handler threads, if one of them throws an exception
-- all get cancelled -- all get cancelled
concurrently_ concurrently_
(fediMessageHandler sendQ recvQ ns) $ (fediMessageHandler sendQ recvQ nsSTM) $
concurrently concurrently
(sendThread sock sendQ) (sendThread sock sendQ)
(recvThread sock recvQ) (recvThread sock recvQ)
@ -236,9 +241,10 @@ requestMapPurge mapVar = forever $ do
-- and pass them to their specific handling function. -- and pass them to their specific handling function.
fediMessageHandler :: TQueue (BS.ByteString, SockAddr) -- ^ send queue fediMessageHandler :: TQueue (BS.ByteString, SockAddr) -- ^ send queue
-> TQueue (BS.ByteString, SockAddr) -- ^ receive queue -> TQueue (BS.ByteString, SockAddr) -- ^ receive queue
-> LocalNodeState -- ^ acting NodeState -> LocalNodeStateSTM -- ^ acting NodeState
-> IO () -> IO ()
fediMessageHandler sendQ recvQ ns = do fediMessageHandler sendQ recvQ nsSTM = do
nsSnap <- readTVarIO nsSTM
-- handling multipart messages: -- handling multipart messages:
-- Request parts can be insert into a map (key: (sender IP against spoofing, request ID), value: timestamp + set of message parts, handle all of them when size of set == parts) before being handled. This map needs to be purged periodically by a separate thread and can be protected by an MVar for fairness. -- Request parts can be insert into a map (key: (sender IP against spoofing, request ID), value: timestamp + set of message parts, handle all of them when size of set == parts) before being handled. This map needs to be purged periodically by a separate thread and can be protected by an MVar for fairness.
requestMap <- newMVar (Map.empty :: RequestMap) requestMap <- newMVar (Map.empty :: RequestMap)
@ -257,7 +263,7 @@ fediMessageHandler sendQ recvQ ns = do
aRequest@Request{} aRequest@Request{}
-- if not a multipart message, handle immediately. Response is at the same time an ACK -- if not a multipart message, handle immediately. Response is at the same time an ACK
| part aRequest == 1 && isFinalPart aRequest -> | part aRequest == 1 && isFinalPart aRequest ->
forkIO (handleIncomingRequest ns sendQ (Set.singleton aRequest) sourceAddr) >> pure () forkIO (handleIncomingRequest nsSTM sendQ (Set.singleton aRequest) sourceAddr) >> pure ()
-- otherwise collect all message parts first before handling the whole request -- otherwise collect all message parts first before handling the whole request
| otherwise -> do | otherwise -> do
now <- getPOSIXTime now <- getPOSIXTime
@ -277,14 +283,14 @@ fediMessageHandler sendQ recvQ ns = do
-- put map back into MVar, end of critical section -- put map back into MVar, end of critical section
putMVar requestMap newMapState putMVar requestMap newMapState
-- ACK the received part -- ACK the received part
forM_ (ackRequest (getNid ns) aRequest) $ forM_ (ackRequest (getNid nsSnap) aRequest) $
\msg -> atomically $ writeTQueue sendQ (msg, sourceAddr) \msg -> atomically $ writeTQueue sendQ (msg, sourceAddr)
-- if all parts received, then handle request. -- if all parts received, then handle request.
let let
(RequestMapEntry theseParts mayMaxParts _) = fromJust $ Map.lookup thisKey newMapState (RequestMapEntry theseParts mayMaxParts _) = fromJust $ Map.lookup thisKey newMapState
numParts = Set.size theseParts numParts = Set.size theseParts
if maybe False (numParts ==) (fromIntegral <$> mayMaxParts) if maybe False (numParts ==) (fromIntegral <$> mayMaxParts)
then forkIO (handleIncomingRequest ns sendQ theseParts sourceAddr) >> pure() then forkIO (handleIncomingRequest nsSTM sendQ theseParts sourceAddr) >> pure()
else pure() else pure()
-- Responses should never arrive on the main server port, as they are always -- Responses should never arrive on the main server port, as they are always
-- responses to requests sent from dedicated sockets on another port -- responses to requests sent from dedicated sockets on another port

View file

@ -9,6 +9,7 @@ module Hash2Pub.FediChordTypes (
, toNodeID , toNodeID
, NodeState (..) , NodeState (..)
, LocalNodeState (..) , LocalNodeState (..)
, LocalNodeStateSTM
, RemoteNodeState (..) , RemoteNodeState (..)
, setSuccessors , setSuccessors
, setPredecessors , setPredecessors
@ -40,6 +41,7 @@ import Network.Socket
-- for hashing and ID conversion -- for hashing and ID conversion
import Control.Concurrent.STM import Control.Concurrent.STM
import Control.Concurrent.STM.TQueue import Control.Concurrent.STM.TQueue
import Control.Concurrent.STM.TVar
import Control.Monad (forever) import Control.Monad (forever)
import Crypto.Hash import Crypto.Hash
import qualified Data.ByteArray as BA import qualified Data.ByteArray as BA
@ -144,6 +146,8 @@ data LocalNodeState = LocalNodeState
} }
deriving (Show, Eq) deriving (Show, Eq)
type LocalNodeStateSTM = TVar LocalNodeState
-- | class for various NodeState representations, providing -- | class for various NodeState representations, providing
-- getters and setters for common values -- getters and setters for common values
class NodeState a where class NodeState a where

View file

@ -2,6 +2,8 @@ module Main where
import Control.Concurrent import Control.Concurrent
import Control.Concurrent.Async import Control.Concurrent.Async
import Control.Concurrent.STM
import Control.Concurrent.STM.TVar
import Control.Exception import Control.Exception
import Data.Either import Data.Either
import Data.IP (IPv6, toHostAddress6) import Data.IP (IPv6, toHostAddress6)
@ -16,7 +18,7 @@ main = do
conf <- readConfig conf <- readConfig
-- ToDo: load persisted caches, bootstrapping nodes … -- ToDo: load persisted caches, bootstrapping nodes …
(serverSock, thisNode) <- fediChordInit conf (serverSock, thisNode) <- fediChordInit conf
print thisNode print =<< readTVarIO thisNode
print serverSock print serverSock
-- currently no masking is necessary, as there is nothing to clean up -- currently no masking is necessary, as there is nothing to clean up
cacheWriterThread <- forkIO $ cacheWriter thisNode cacheWriterThread <- forkIO $ cacheWriter thisNode
@ -38,7 +40,7 @@ main = do
) )
(\joinedNS -> do (\joinedNS -> do
-- launch main eventloop with successfully joined state -- launch main eventloop with successfully joined state
putStrLn ("successful join at " <> (show . getNid $ joinedNS)) putStrLn "successful join"
wait =<< async (fediMainThreads serverSock thisNode) wait =<< async (fediMainThreads serverSock thisNode)
) )
joinedState joinedState