From dc2e399d6480ccb323b72637d70e3055510a411c Mon Sep 17 00:00:00 2001 From: Trolli Schmittlauch Date: Thu, 4 Jun 2020 22:29:11 +0200 Subject: [PATCH] protect concurrent node state access with STM - for allowing concurrent access to predecessors and successors, the whole LocalNodeState is passed wrapped into an STM TVar - this allows keeping the tests for the mostly pure data type, compared to protecting only the successor and predecessor list contributes to #28 --- src/Hash2Pub/DHTProtocol.hs | 117 ++++++++++++++++++++------------- src/Hash2Pub/FediChord.hs | 46 +++++++------ src/Hash2Pub/FediChordTypes.hs | 4 ++ src/Hash2Pub/Main.hs | 8 ++- 4 files changed, 107 insertions(+), 68 deletions(-) diff --git a/src/Hash2Pub/DHTProtocol.hs b/src/Hash2Pub/DHTProtocol.hs index 9f1e4c7..8857597 100644 --- a/src/Hash2Pub/DHTProtocol.hs +++ b/src/Hash2Pub/DHTProtocol.hs @@ -53,7 +53,8 @@ import System.Timeout import Hash2Pub.ASN1Coding import Hash2Pub.FediChordTypes (CacheEntry (..), - LocalNodeState (..), NodeCache, + LocalNodeState (..), + LocalNodeStateSTM, NodeCache, NodeID, NodeState (..), RemoteNodeState (..), cacheGetNodeStateUnvalidated, @@ -169,67 +170,93 @@ ackRequest ownID req@Request{} = serialiseMessage sendMessageSize $ Response { } -handleIncomingRequest :: LocalNodeState -- ^ the handling node +handleIncomingRequest :: LocalNodeStateSTM -- ^ the handling node -> TQueue (BS.ByteString, SockAddr) -- ^ send queue -> Set.Set FediChordMessage -- ^ all parts of the request to handle -> SockAddr -- ^ source address of the request -> IO () -handleIncomingRequest ns sendQ msgSet sourceAddr = do +handleIncomingRequest nsSTM sendQ msgSet sourceAddr = do + ns <- readTVarIO nsSTM -- add nodestate to cache now <- getPOSIXTime - aPart <- headMay . Set.elems $ msgSet - case aPart of + case headMay . Set.elems $ msgSet of Nothing -> pure () - Just aPart' -> - queueAddEntries (Identity . RemoteCacheEntry (sender aPart') $ now) ns + Just aPart -> do + queueAddEntries (Identity $ RemoteCacheEntry (sender aPart) now) ns -- distinguish on whether and how to respond. If responding, pass message to response generating function and write responses to send queue - maybe (pure ()) (\respSet -> - forM_ (\resp -> atomically $ writeTQueue sendQ (resp, sourceAddr))) - (case action aPart' of - Ping -> Just respondPing ns msgSet - Join -> Just respondJoin ns msgSet - -- ToDo: figure out what happens if not joined - QueryID -> Just respondQueryID ns msgSet - -- only when joined - Leave -> if isJoined_ ns then Just respondLeave ns msgSet else Nothing - -- only when joined - Stabilise -> if isJoined_ ns then Just respondStabilise ns msgSet else Nothing - ) - -- for single part request, response starts with part number 1. For multipart requests, response starts with part number n+1. - - -- TODO: determine request type only from first part, but catch RecSelError on each record access when folding, because otherwise different request type parts can make this crash - -- TODO: test case: mixed message types of parts - --- ....... response sending ....... + maybe (pure ()) ( + mapM_ (\resp -> atomically $ writeTQueue sendQ (resp, sourceAddr)) + ) + (case action aPart of + _ -> Just Map.empty) -- placeholder +-- Ping -> Just respondPing nsSTM msgSet +-- Join -> Just respondJoin nsSTM msgSet +-- -- ToDo: figure out what happens if not joined +-- QueryID -> Just respondQueryID nsSTM msgSet +-- -- only when joined +-- Leave -> if isJoined_ ns then Just respondLeave nsSTM msgSet else Nothing +-- -- only when joined +-- Stabilise -> if isJoined_ ns then Just respondStabilise nsSTM msgSet else Nothing +-- ) +-- -- for single part request, response starts with part number 1. For multipart requests, response starts with part number n+1. +-- +-- -- TODO: determine request type only from first part, but catch RecSelError on each record access when folding, because otherwise different request type parts can make this crash +-- -- TODO: test case: mixed message types of parts +-- +---- ....... response sending ....... +-- +---- this modifies node state, so locking and IO seems to be necessary. +---- Still try to keep as much code as possible pure +--respondJoin :: LocalNodeStateSTM -> Set.Set FediChordMessage -> Map Integer BS.ByteString +--respondJoin nsSTM msgSet = +-- -- check whether the joining node falls into our responsibility +-- -- if yes, adjust own predecessors/ successors and return those in a response +-- -- if no: empty response or send a QueryID forwards response? +-- -- TODO: notify service layer to copy over data now handled by the new joined node -- ....... request sending ....... -- | send a join request and return the joined 'LocalNodeState' including neighbours requestJoin :: NodeState a => a -- ^ currently responsible node to be contacted - -> LocalNodeState -- ^ joining NodeState - -> IO (Either String LocalNodeState) -- ^ node after join with all its new information -requestJoin toJoinOn ownState = + -> LocalNodeStateSTM -- ^ joining NodeState + -> IO (Either String LocalNodeStateSTM) -- ^ node after join with all its new information +requestJoin toJoinOn ownStateSTM = bracket (mkSendSocket (getDomain toJoinOn) (getDhtPort toJoinOn)) close (\sock -> do + -- extract own state for getting request information + ownState <- readTVarIO ownStateSTM responses <- sendRequestTo 5000 3 (\rid -> Request rid (toRemoteNodeState ownState) 1 True Join (Just JoinRequestPayload)) sock - joinedStateUnsorted <- foldM - (\nsAcc msg -> case payload msg of - Nothing -> pure nsAcc - Just msgPl -> do - -- add transfered cache entries to global NodeCache - queueAddEntries (joinCache msgPl) nsAcc - -- add received predecessors and successors - let - addPreds ns' = setPredecessors (foldr' (:) (predecessors ns') (joinPredecessors msgPl)) ns' - addSuccs ns' = setSuccessors (foldr' (:) (successors ns') (joinSuccessors msgPl)) ns' - pure $ addSuccs . addPreds $ nsAcc - ) - -- reset predecessors and successors - (setPredecessors [] . setSuccessors [] $ ownState) - responses + (cacheInsertQ, joinedState) <- atomically $ do + stateSnap <- readTVar ownStateSTM + let + (cacheInsertQ, joinedStateUnsorted) = foldl' + (\(insertQ, nsAcc) msg -> + let + insertQ' = maybe insertQ (\msgPl -> + -- collect list of insertion statements into global cache + queueAddEntries (joinCache msgPl) : insertQ + ) $ payload msg + -- add received predecessors and successors + addPreds ns' = maybe ns' (\msgPl -> + setPredecessors (foldr' (:) (predecessors ns') (joinPredecessors msgPl)) ns' + ) $ payload msg + addSuccs ns' = maybe ns' (\msgPl -> + setSuccessors (foldr' (:) (successors ns') (joinSuccessors msgPl)) ns' + ) $ payload msg + in + (insertQ', addSuccs . addPreds $ nsAcc) + ) + -- reset predecessors and successors + ([], setPredecessors [] . setSuccessors [] $ ownState) + responses + -- sort successors and predecessors + newState = setSuccessors (sortBy localCompare $ successors joinedStateUnsorted) . setPredecessors (sortBy localCompare $ predecessors joinedStateUnsorted) $ joinedStateUnsorted + writeTVar ownStateSTM newState + pure (cacheInsertQ, newState) + -- execute the cache insertions + mapM_ (\f -> f joinedState) cacheInsertQ if responses == Set.empty then pure . Left $ "join error: got no response from " <> show (getNid toJoinOn) - -- sort successors and predecessors - else pure . Right . setSuccessors (sortBy localCompare $ successors joinedStateUnsorted) . setPredecessors (sortBy localCompare $ predecessors joinedStateUnsorted) $ joinedStateUnsorted + else pure $ Right ownStateSTM ) `catch` (\e -> pure . Left $ displayException (e :: IOException)) diff --git a/src/Hash2Pub/FediChord.hs b/src/Hash2Pub/FediChord.hs index c8b2b2e..43de152 100644 --- a/src/Hash2Pub/FediChord.hs +++ b/src/Hash2Pub/FediChord.hs @@ -64,6 +64,7 @@ import Control.Concurrent import Control.Concurrent.Async import Control.Concurrent.STM import Control.Concurrent.STM.TQueue +import Control.Concurrent.STM.TVar import Control.Monad (forM_, forever) import Crypto.Hash import qualified Data.ByteArray as BA @@ -84,11 +85,12 @@ import Debug.Trace (trace) -- | initialise data structures, compute own IDs and bind to listening socket -- ToDo: load persisted state, thus this function already operates in IO -fediChordInit :: FediChordConf -> IO (Socket, LocalNodeState) +fediChordInit :: FediChordConf -> IO (Socket, LocalNodeStateSTM) fediChordInit conf = do initialState <- nodeStateInit conf + initialStateSTM <- newTVarIO initialState serverSock <- mkServerSocket (getIpAddr initialState) (getDhtPort initialState) - pure (serverSock, initialState) + pure (serverSock, initialStateSTM) -- | initialises the 'NodeState' for this local node. -- Separated from 'fediChordInit' to be usable in tests. @@ -120,15 +122,16 @@ nodeStateInit conf = do -- | Join a new node into the DHT, using a provided bootstrap node as initial cache seed -- for resolving the new node's position. -fediChordBootstrapJoin :: LocalNodeState -- ^ the local 'NodeState' +fediChordBootstrapJoin :: LocalNodeStateSTM -- ^ the local 'NodeState' -> (String, PortNumber) -- ^ domain and port of a bootstrapping node - -> IO (Either String LocalNodeState) -- ^ the joined 'NodeState' after a + -> IO (Either String LocalNodeStateSTM) -- ^ the joined 'NodeState' after a -- successful join, otherwise an error message -fediChordBootstrapJoin ns (joinHost, joinPort) = +fediChordBootstrapJoin nsSTM (joinHost, joinPort) = -- can be invoked multiple times with all known bootstrapping nodes until successfully joined bracket (mkSendSocket joinHost joinPort) close (\sock -> do -- 1. get routed to placement of own ID until FOUND: -- Initialise an empty cache only with the responses from a bootstrapping node + ns <- readTVarIO nsSTM bootstrapResponse <- sendQueryIdMessage (getNid ns) ns sock if bootstrapResponse == Set.empty then pure . Left $ "Bootstrapping node " <> show joinHost <> " gave no response." @@ -143,7 +146,7 @@ fediChordBootstrapJoin ns (joinHost, joinPort) = Just (FORWARD resultset) -> foldr' (addCacheEntryPure now) cacheAcc resultset ) initCache bootstrapResponse - fediChordJoin bootstrapCache ns + fediChordJoin bootstrapCache nsSTM ) `catch` (\e -> pure . Left $ "Error at bootstrap joining: " <> displayException (e :: IOException)) @@ -151,15 +154,16 @@ fediChordBootstrapJoin ns (joinHost, joinPort) = -- node's position. fediChordJoin :: NodeCache -- ^ a snapshot of the NodeCache to -- use for ID lookup - -> LocalNodeState -- ^ the local 'NodeState' - -> IO (Either String LocalNodeState) -- ^ the joined 'NodeState' after a + -> LocalNodeStateSTM -- ^ the local 'NodeState' + -> IO (Either String LocalNodeStateSTM) -- ^ the joined 'NodeState' after a -- successful join, otherwise an error message -fediChordJoin cacheSnapshot ns = do +fediChordJoin cacheSnapshot nsSTM = do + ns <- readTVarIO nsSTM -- get routed to the currently responsible node, based on the response -- from the bootstrapping node currentlyResponsible <- queryIdLookupLoop cacheSnapshot ns $ getNid ns -- 2. then send a join to the currently responsible node - joinResult <- requestJoin currentlyResponsible ns + joinResult <- requestJoin currentlyResponsible nsSTM case joinResult of Left err -> pure . Left $ "Error joining on " <> err Right joinedNS -> pure . Right $ joinedNS @@ -167,8 +171,9 @@ fediChordJoin cacheSnapshot ns = do -- | cache updater thread that waits for incoming NodeCache update instructions on -- the node's cacheWriteQueue and then modifies the NodeCache as the single writer. -cacheWriter :: LocalNodeState -> IO () -cacheWriter ns = do +cacheWriter :: LocalNodeStateSTM -> IO () +cacheWriter nsSTM = do + ns <- readTVarIO nsSTM let writeQueue' = cacheWriteQueue ns forever $ do f <- atomically $ readTQueue writeQueue' @@ -196,14 +201,14 @@ sendThread sock sendQ = forever $ do sendAllTo sock packet addr -- | Sets up and manages the main server threads of FediChord -fediMainThreads :: Socket -> LocalNodeState -> IO () -fediMainThreads sock ns = do +fediMainThreads :: Socket -> LocalNodeStateSTM -> IO () +fediMainThreads sock nsSTM = do sendQ <- newTQueueIO recvQ <- newTQueueIO -- concurrently launch all handler threads, if one of them throws an exception -- all get cancelled concurrently_ - (fediMessageHandler sendQ recvQ ns) $ + (fediMessageHandler sendQ recvQ nsSTM) $ concurrently (sendThread sock sendQ) (recvThread sock recvQ) @@ -236,9 +241,10 @@ requestMapPurge mapVar = forever $ do -- and pass them to their specific handling function. fediMessageHandler :: TQueue (BS.ByteString, SockAddr) -- ^ send queue -> TQueue (BS.ByteString, SockAddr) -- ^ receive queue - -> LocalNodeState -- ^ acting NodeState + -> LocalNodeStateSTM -- ^ acting NodeState -> IO () -fediMessageHandler sendQ recvQ ns = do +fediMessageHandler sendQ recvQ nsSTM = do + nsSnap <- readTVarIO nsSTM -- handling multipart messages: -- Request parts can be insert into a map (key: (sender IP against spoofing, request ID), value: timestamp + set of message parts, handle all of them when size of set == parts) before being handled. This map needs to be purged periodically by a separate thread and can be protected by an MVar for fairness. requestMap <- newMVar (Map.empty :: RequestMap) @@ -257,7 +263,7 @@ fediMessageHandler sendQ recvQ ns = do aRequest@Request{} -- if not a multipart message, handle immediately. Response is at the same time an ACK | part aRequest == 1 && isFinalPart aRequest -> - forkIO (handleIncomingRequest ns sendQ (Set.singleton aRequest) sourceAddr) >> pure () + forkIO (handleIncomingRequest nsSTM sendQ (Set.singleton aRequest) sourceAddr) >> pure () -- otherwise collect all message parts first before handling the whole request | otherwise -> do now <- getPOSIXTime @@ -277,14 +283,14 @@ fediMessageHandler sendQ recvQ ns = do -- put map back into MVar, end of critical section putMVar requestMap newMapState -- ACK the received part - forM_ (ackRequest (getNid ns) aRequest) $ + forM_ (ackRequest (getNid nsSnap) aRequest) $ \msg -> atomically $ writeTQueue sendQ (msg, sourceAddr) -- if all parts received, then handle request. let (RequestMapEntry theseParts mayMaxParts _) = fromJust $ Map.lookup thisKey newMapState numParts = Set.size theseParts if maybe False (numParts ==) (fromIntegral <$> mayMaxParts) - then forkIO (handleIncomingRequest ns sendQ theseParts sourceAddr) >> pure() + then forkIO (handleIncomingRequest nsSTM sendQ theseParts sourceAddr) >> pure() else pure() -- Responses should never arrive on the main server port, as they are always -- responses to requests sent from dedicated sockets on another port diff --git a/src/Hash2Pub/FediChordTypes.hs b/src/Hash2Pub/FediChordTypes.hs index 7e3565d..2feea08 100644 --- a/src/Hash2Pub/FediChordTypes.hs +++ b/src/Hash2Pub/FediChordTypes.hs @@ -9,6 +9,7 @@ module Hash2Pub.FediChordTypes ( , toNodeID , NodeState (..) , LocalNodeState (..) + , LocalNodeStateSTM , RemoteNodeState (..) , setSuccessors , setPredecessors @@ -40,6 +41,7 @@ import Network.Socket -- for hashing and ID conversion import Control.Concurrent.STM import Control.Concurrent.STM.TQueue +import Control.Concurrent.STM.TVar import Control.Monad (forever) import Crypto.Hash import qualified Data.ByteArray as BA @@ -144,6 +146,8 @@ data LocalNodeState = LocalNodeState } deriving (Show, Eq) +type LocalNodeStateSTM = TVar LocalNodeState + -- | class for various NodeState representations, providing -- getters and setters for common values class NodeState a where diff --git a/src/Hash2Pub/Main.hs b/src/Hash2Pub/Main.hs index 554585f..fc9299d 100644 --- a/src/Hash2Pub/Main.hs +++ b/src/Hash2Pub/Main.hs @@ -2,9 +2,11 @@ module Main where import Control.Concurrent import Control.Concurrent.Async +import Control.Concurrent.STM +import Control.Concurrent.STM.TVar import Control.Exception import Data.Either -import Data.IP (IPv6, toHostAddress6) +import Data.IP (IPv6, toHostAddress6) import System.Environment import Hash2Pub.FediChord @@ -16,7 +18,7 @@ main = do conf <- readConfig -- ToDo: load persisted caches, bootstrapping nodes … (serverSock, thisNode) <- fediChordInit conf - print thisNode + print =<< readTVarIO thisNode print serverSock -- currently no masking is necessary, as there is nothing to clean up cacheWriterThread <- forkIO $ cacheWriter thisNode @@ -38,7 +40,7 @@ main = do ) (\joinedNS -> do -- launch main eventloop with successfully joined state - putStrLn ("successful join at " <> (show . getNid $ joinedNS)) + putStrLn "successful join" wait =<< async (fediMainThreads serverSock thisNode) ) joinedState