protect concurrent node state access with STM

- to allow concurrent access to predecessors and successors, the
  whole LocalNodeState is passed around wrapped in an STM TVar
- this keeps the existing tests for the mostly pure data type usable,
  as opposed to protecting only the successor and predecessor lists

contributes to #28
This commit is contained in:
Trolli Schmittlauch 2020-06-04 22:29:11 +02:00
parent f42dfb2137
commit dc2e399d64
4 changed files with 107 additions and 68 deletions

View file

@ -53,7 +53,8 @@ import System.Timeout
import Hash2Pub.ASN1Coding
import Hash2Pub.FediChordTypes (CacheEntry (..),
LocalNodeState (..), NodeCache,
LocalNodeState (..),
LocalNodeStateSTM, NodeCache,
NodeID, NodeState (..),
RemoteNodeState (..),
cacheGetNodeStateUnvalidated,
@ -169,67 +170,93 @@ ackRequest ownID req@Request{} = serialiseMessage sendMessageSize $ Response {
}
handleIncomingRequest :: LocalNodeState -- ^ the handling node
handleIncomingRequest :: LocalNodeStateSTM -- ^ the handling node
-> TQueue (BS.ByteString, SockAddr) -- ^ send queue
-> Set.Set FediChordMessage -- ^ all parts of the request to handle
-> SockAddr -- ^ source address of the request
-> IO ()
handleIncomingRequest ns sendQ msgSet sourceAddr = do
handleIncomingRequest nsSTM sendQ msgSet sourceAddr = do
ns <- readTVarIO nsSTM
-- add nodestate to cache
now <- getPOSIXTime
aPart <- headMay . Set.elems $ msgSet
case aPart of
case headMay . Set.elems $ msgSet of
Nothing -> pure ()
Just aPart' ->
queueAddEntries (Identity . RemoteCacheEntry (sender aPart') $ now) ns
Just aPart -> do
queueAddEntries (Identity $ RemoteCacheEntry (sender aPart) now) ns
-- distinguish on whether and how to respond. If responding, pass message to response generating function and write responses to send queue
maybe (pure ()) (\respSet ->
forM_ (\resp -> atomically $ writeTQueue sendQ (resp, sourceAddr)))
(case action aPart' of
Ping -> Just respondPing ns msgSet
Join -> Just respondJoin ns msgSet
-- ToDo: figure out what happens if not joined
QueryID -> Just respondQueryID ns msgSet
-- only when joined
Leave -> if isJoined_ ns then Just respondLeave ns msgSet else Nothing
-- only when joined
Stabilise -> if isJoined_ ns then Just respondStabilise ns msgSet else Nothing
)
-- for single part request, response starts with part number 1. For multipart requests, response starts with part number n+1.
-- TODO: determine request type only from first part, but catch RecSelError on each record access when folding, because otherwise different request type parts can make this crash
-- TODO: test case: mixed message types of parts
-- ....... response sending .......
maybe (pure ()) (
mapM_ (\resp -> atomically $ writeTQueue sendQ (resp, sourceAddr))
)
(case action aPart of
_ -> Just Map.empty) -- placeholder
-- Ping -> Just respondPing nsSTM msgSet
-- Join -> Just respondJoin nsSTM msgSet
-- -- ToDo: figure out what happens if not joined
-- QueryID -> Just respondQueryID nsSTM msgSet
-- -- only when joined
-- Leave -> if isJoined_ ns then Just respondLeave nsSTM msgSet else Nothing
-- -- only when joined
-- Stabilise -> if isJoined_ ns then Just respondStabilise nsSTM msgSet else Nothing
-- )
-- -- for single part request, response starts with part number 1. For multipart requests, response starts with part number n+1.
--
-- -- TODO: determine request type only from first part, but catch RecSelError on each record access when folding, because otherwise different request type parts can make this crash
-- -- TODO: test case: mixed message types of parts
--
---- ....... response sending .......
--
---- this modifies node state, so locking and IO seems to be necessary.
---- Still try to keep as much code as possible pure
--respondJoin :: LocalNodeStateSTM -> Set.Set FediChordMessage -> Map Integer BS.ByteString
--respondJoin nsSTM msgSet =
-- -- check whether the joining node falls into our responsibility
-- -- if yes, adjust own predecessors/ successors and return those in a response
-- -- if no: empty response or send a QueryID forwards response?
-- -- TODO: notify service layer to copy over data now handled by the new joined node
-- ....... request sending .......
-- | send a join request and return the joined 'LocalNodeState' including neighbours
requestJoin :: NodeState a => a -- ^ currently responsible node to be contacted
-> LocalNodeState -- ^ joining NodeState
-> IO (Either String LocalNodeState) -- ^ node after join with all its new information
requestJoin toJoinOn ownState =
-> LocalNodeStateSTM -- ^ joining NodeState
-> IO (Either String LocalNodeStateSTM) -- ^ node after join with all its new information
requestJoin toJoinOn ownStateSTM =
bracket (mkSendSocket (getDomain toJoinOn) (getDhtPort toJoinOn)) close (\sock -> do
-- extract own state for getting request information
ownState <- readTVarIO ownStateSTM
responses <- sendRequestTo 5000 3 (\rid -> Request rid (toRemoteNodeState ownState) 1 True Join (Just JoinRequestPayload)) sock
joinedStateUnsorted <- foldM
(\nsAcc msg -> case payload msg of
Nothing -> pure nsAcc
Just msgPl -> do
-- add transfered cache entries to global NodeCache
queueAddEntries (joinCache msgPl) nsAcc
-- add received predecessors and successors
let
addPreds ns' = setPredecessors (foldr' (:) (predecessors ns') (joinPredecessors msgPl)) ns'
addSuccs ns' = setSuccessors (foldr' (:) (successors ns') (joinSuccessors msgPl)) ns'
pure $ addSuccs . addPreds $ nsAcc
)
-- reset predecessors and successors
(setPredecessors [] . setSuccessors [] $ ownState)
responses
(cacheInsertQ, joinedState) <- atomically $ do
stateSnap <- readTVar ownStateSTM
let
(cacheInsertQ, joinedStateUnsorted) = foldl'
(\(insertQ, nsAcc) msg ->
let
insertQ' = maybe insertQ (\msgPl ->
-- collect list of insertion statements into global cache
queueAddEntries (joinCache msgPl) : insertQ
) $ payload msg
-- add received predecessors and successors
addPreds ns' = maybe ns' (\msgPl ->
setPredecessors (foldr' (:) (predecessors ns') (joinPredecessors msgPl)) ns'
) $ payload msg
addSuccs ns' = maybe ns' (\msgPl ->
setSuccessors (foldr' (:) (successors ns') (joinSuccessors msgPl)) ns'
) $ payload msg
in
(insertQ', addSuccs . addPreds $ nsAcc)
)
-- reset predecessors and successors
([], setPredecessors [] . setSuccessors [] $ ownState)
responses
-- sort successors and predecessors
newState = setSuccessors (sortBy localCompare $ successors joinedStateUnsorted) . setPredecessors (sortBy localCompare $ predecessors joinedStateUnsorted) $ joinedStateUnsorted
writeTVar ownStateSTM newState
pure (cacheInsertQ, newState)
-- execute the cache insertions
mapM_ (\f -> f joinedState) cacheInsertQ
if responses == Set.empty
then pure . Left $ "join error: got no response from " <> show (getNid toJoinOn)
-- sort successors and predecessors
else pure . Right . setSuccessors (sortBy localCompare $ successors joinedStateUnsorted) . setPredecessors (sortBy localCompare $ predecessors joinedStateUnsorted) $ joinedStateUnsorted
else pure $ Right ownStateSTM
)
`catch` (\e -> pure . Left $ displayException (e :: IOException))

View file

@ -64,6 +64,7 @@ import Control.Concurrent
import Control.Concurrent.Async
import Control.Concurrent.STM
import Control.Concurrent.STM.TQueue
import Control.Concurrent.STM.TVar
import Control.Monad (forM_, forever)
import Crypto.Hash
import qualified Data.ByteArray as BA
@ -84,11 +85,12 @@ import Debug.Trace (trace)
-- | initialise data structures, compute own IDs and bind to listening socket
-- ToDo: load persisted state, thus this function already operates in IO
fediChordInit :: FediChordConf -> IO (Socket, LocalNodeState)
fediChordInit :: FediChordConf -> IO (Socket, LocalNodeStateSTM)
fediChordInit conf = do
initialState <- nodeStateInit conf
initialStateSTM <- newTVarIO initialState
serverSock <- mkServerSocket (getIpAddr initialState) (getDhtPort initialState)
pure (serverSock, initialState)
pure (serverSock, initialStateSTM)
-- | initialises the 'NodeState' for this local node.
-- Separated from 'fediChordInit' to be usable in tests.
@ -120,15 +122,16 @@ nodeStateInit conf = do
-- | Join a new node into the DHT, using a provided bootstrap node as initial cache seed
-- for resolving the new node's position.
fediChordBootstrapJoin :: LocalNodeState -- ^ the local 'NodeState'
fediChordBootstrapJoin :: LocalNodeStateSTM -- ^ the local 'NodeState'
-> (String, PortNumber) -- ^ domain and port of a bootstrapping node
-> IO (Either String LocalNodeState) -- ^ the joined 'NodeState' after a
-> IO (Either String LocalNodeStateSTM) -- ^ the joined 'NodeState' after a
-- successful join, otherwise an error message
fediChordBootstrapJoin ns (joinHost, joinPort) =
fediChordBootstrapJoin nsSTM (joinHost, joinPort) =
-- can be invoked multiple times with all known bootstrapping nodes until successfully joined
bracket (mkSendSocket joinHost joinPort) close (\sock -> do
-- 1. get routed to placement of own ID until FOUND:
-- Initialise an empty cache only with the responses from a bootstrapping node
ns <- readTVarIO nsSTM
bootstrapResponse <- sendQueryIdMessage (getNid ns) ns sock
if bootstrapResponse == Set.empty
then pure . Left $ "Bootstrapping node " <> show joinHost <> " gave no response."
@ -143,7 +146,7 @@ fediChordBootstrapJoin ns (joinHost, joinPort) =
Just (FORWARD resultset) -> foldr' (addCacheEntryPure now) cacheAcc resultset
)
initCache bootstrapResponse
fediChordJoin bootstrapCache ns
fediChordJoin bootstrapCache nsSTM
)
`catch` (\e -> pure . Left $ "Error at bootstrap joining: " <> displayException (e :: IOException))
@ -151,15 +154,16 @@ fediChordBootstrapJoin ns (joinHost, joinPort) =
-- node's position.
fediChordJoin :: NodeCache -- ^ a snapshot of the NodeCache to
-- use for ID lookup
-> LocalNodeState -- ^ the local 'NodeState'
-> IO (Either String LocalNodeState) -- ^ the joined 'NodeState' after a
-> LocalNodeStateSTM -- ^ the local 'NodeState'
-> IO (Either String LocalNodeStateSTM) -- ^ the joined 'NodeState' after a
-- successful join, otherwise an error message
fediChordJoin cacheSnapshot ns = do
fediChordJoin cacheSnapshot nsSTM = do
ns <- readTVarIO nsSTM
-- get routed to the currently responsible node, based on the response
-- from the bootstrapping node
currentlyResponsible <- queryIdLookupLoop cacheSnapshot ns $ getNid ns
-- 2. then send a join to the currently responsible node
joinResult <- requestJoin currentlyResponsible ns
joinResult <- requestJoin currentlyResponsible nsSTM
case joinResult of
Left err -> pure . Left $ "Error joining on " <> err
Right joinedNS -> pure . Right $ joinedNS
@ -167,8 +171,9 @@ fediChordJoin cacheSnapshot ns = do
-- | cache updater thread that waits for incoming NodeCache update instructions on
-- the node's cacheWriteQueue and then modifies the NodeCache as the single writer.
cacheWriter :: LocalNodeState -> IO ()
cacheWriter ns = do
cacheWriter :: LocalNodeStateSTM -> IO ()
cacheWriter nsSTM = do
ns <- readTVarIO nsSTM
let writeQueue' = cacheWriteQueue ns
forever $ do
f <- atomically $ readTQueue writeQueue'
@ -196,14 +201,14 @@ sendThread sock sendQ = forever $ do
sendAllTo sock packet addr
-- | Sets up and manages the main server threads of FediChord
fediMainThreads :: Socket -> LocalNodeState -> IO ()
fediMainThreads sock ns = do
fediMainThreads :: Socket -> LocalNodeStateSTM -> IO ()
fediMainThreads sock nsSTM = do
sendQ <- newTQueueIO
recvQ <- newTQueueIO
-- concurrently launch all handler threads, if one of them throws an exception
-- all get cancelled
concurrently_
(fediMessageHandler sendQ recvQ ns) $
(fediMessageHandler sendQ recvQ nsSTM) $
concurrently
(sendThread sock sendQ)
(recvThread sock recvQ)
@ -236,9 +241,10 @@ requestMapPurge mapVar = forever $ do
-- and pass them to their specific handling function.
fediMessageHandler :: TQueue (BS.ByteString, SockAddr) -- ^ send queue
-> TQueue (BS.ByteString, SockAddr) -- ^ receive queue
-> LocalNodeState -- ^ acting NodeState
-> LocalNodeStateSTM -- ^ acting NodeState
-> IO ()
fediMessageHandler sendQ recvQ ns = do
fediMessageHandler sendQ recvQ nsSTM = do
nsSnap <- readTVarIO nsSTM
-- handling multipart messages:
-- Request parts can be inserted into a map (key: (sender IP against spoofing, request ID), value: timestamp + set of message parts, handle all of them when size of set == parts) before being handled. This map needs to be purged periodically by a separate thread and can be protected by an MVar for fairness.
requestMap <- newMVar (Map.empty :: RequestMap)
@ -257,7 +263,7 @@ fediMessageHandler sendQ recvQ ns = do
aRequest@Request{}
-- if not a multipart message, handle immediately. Response is at the same time an ACK
| part aRequest == 1 && isFinalPart aRequest ->
forkIO (handleIncomingRequest ns sendQ (Set.singleton aRequest) sourceAddr) >> pure ()
forkIO (handleIncomingRequest nsSTM sendQ (Set.singleton aRequest) sourceAddr) >> pure ()
-- otherwise collect all message parts first before handling the whole request
| otherwise -> do
now <- getPOSIXTime
@ -277,14 +283,14 @@ fediMessageHandler sendQ recvQ ns = do
-- put map back into MVar, end of critical section
putMVar requestMap newMapState
-- ACK the received part
forM_ (ackRequest (getNid ns) aRequest) $
forM_ (ackRequest (getNid nsSnap) aRequest) $
\msg -> atomically $ writeTQueue sendQ (msg, sourceAddr)
-- if all parts received, then handle request.
let
(RequestMapEntry theseParts mayMaxParts _) = fromJust $ Map.lookup thisKey newMapState
numParts = Set.size theseParts
if maybe False (numParts ==) (fromIntegral <$> mayMaxParts)
then forkIO (handleIncomingRequest ns sendQ theseParts sourceAddr) >> pure()
then forkIO (handleIncomingRequest nsSTM sendQ theseParts sourceAddr) >> pure()
else pure()
-- Responses should never arrive on the main server port, as they are always
-- responses to requests sent from dedicated sockets on another port

View file

@ -9,6 +9,7 @@ module Hash2Pub.FediChordTypes (
, toNodeID
, NodeState (..)
, LocalNodeState (..)
, LocalNodeStateSTM
, RemoteNodeState (..)
, setSuccessors
, setPredecessors
@ -40,6 +41,7 @@ import Network.Socket
-- for hashing and ID conversion
import Control.Concurrent.STM
import Control.Concurrent.STM.TQueue
import Control.Concurrent.STM.TVar
import Control.Monad (forever)
import Crypto.Hash
import qualified Data.ByteArray as BA
@ -144,6 +146,8 @@ data LocalNodeState = LocalNodeState
}
deriving (Show, Eq)
type LocalNodeStateSTM = TVar LocalNodeState
-- | class for various NodeState representations, providing
-- getters and setters for common values
class NodeState a where

View file

@ -2,9 +2,11 @@ module Main where
import Control.Concurrent
import Control.Concurrent.Async
import Control.Concurrent.STM
import Control.Concurrent.STM.TVar
import Control.Exception
import Data.Either
import Data.IP (IPv6, toHostAddress6)
import Data.IP (IPv6, toHostAddress6)
import System.Environment
import Hash2Pub.FediChord
@ -16,7 +18,7 @@ main = do
conf <- readConfig
-- ToDo: load persisted caches, bootstrapping nodes …
(serverSock, thisNode) <- fediChordInit conf
print thisNode
print =<< readTVarIO thisNode
print serverSock
-- currently no masking is necessary, as there is nothing to clean up
cacheWriterThread <- forkIO $ cacheWriter thisNode
@ -38,7 +40,7 @@ main = do
)
(\joinedNS -> do
-- launch main eventloop with successfully joined state
putStrLn ("successful join at " <> (show . getNid $ joinedNS))
putStrLn "successful join"
wait =<< async (fediMainThreads serverSock thisNode)
)
joinedState