From 8d349212b4cb84a62ed657ac014af7a7de3470b2 Mon Sep 17 00:00:00 2001 From: Trolli Schmittlauch Date: Wed, 1 Jul 2020 18:22:28 +0200 Subject: [PATCH] prevent cache invariant querying when not joined --- src/Hash2Pub/DHTProtocol.hs | 9 +++++---- src/Hash2Pub/FediChord.hs | 11 ++++++++--- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/Hash2Pub/DHTProtocol.hs b/src/Hash2Pub/DHTProtocol.hs index 83e32d4..230f7df 100644 --- a/src/Hash2Pub/DHTProtocol.hs +++ b/src/Hash2Pub/DHTProtocol.hs @@ -32,6 +32,7 @@ module Hash2Pub.DHTProtocol , ackRequest , isPossibleSuccessor , isPossiblePredecessor + , isJoined , closestCachePredecessors ) where @@ -213,8 +214,8 @@ markCacheEntryAsVerified timestamp nid = RingMap . Map.adjust adjustFunc nid . g -- | uses the successor and predecessor list of a node as an indicator for whether a -- node has properly joined the DHT -isJoined_ :: LocalNodeState -> Bool -isJoined_ ns = not . all null $ [successors ns, predecessors ns] +isJoined :: LocalNodeState -> Bool +isJoined ns = not . all null $ [successors ns, predecessors ns] -- | the size limit to be used when serialising messages for sending sendMessageSize :: Num i => i @@ -260,8 +261,8 @@ handleIncomingRequest nsSTM sendQ msgSet sourceAddr = do -- ToDo: figure out what happens if not joined QueryID -> Just <$> respondQueryID nsSTM msgSet -- only when joined - Leave -> if isJoined_ ns then Just <$> respondLeave nsSTM msgSet else pure Nothing - Stabilise -> if isJoined_ ns then Just <$> respondStabilise nsSTM msgSet else pure Nothing + Leave -> if isJoined ns then Just <$> respondLeave nsSTM msgSet else pure Nothing + Stabilise -> if isJoined ns then Just <$> respondStabilise nsSTM msgSet else pure Nothing ) -- for single part request, response starts with part number 1. For multipart requests, response starts with part number n+1. diff --git a/src/Hash2Pub/FediChord.hs b/src/Hash2Pub/FediChord.hs index 95617fc..061a74f 100644 --- a/src/Hash2Pub/FediChord.hs +++ b/src/Hash2Pub/FediChord.hs @@ -253,7 +253,9 @@ checkCacheSliceInvariants :: LocalNodeState -> NodeCache -> [NodeID] -- ^ list of middle IDs of slices not -- ^ fulfilling the invariant -checkCacheSliceInvariants ns = checkPredecessorSlice jEntries (getNid ns) startBound lastPred <> checkSuccessorSlice jEntries (getNid ns) startBound lastSucc +checkCacheSliceInvariants ns + | isJoined ns = checkPredecessorSlice jEntries (getNid ns) startBound lastPred <> checkSuccessorSlice jEntries (getNid ns) startBound lastSucc + | otherwise = const [] where jEntries = jEntriesPerSlice ns lastPred = getNid <$> lastMay (predecessors ns) @@ -340,7 +342,7 @@ stabiliseThread nsSTM = forever $ do -- try looking up additional neighbours if list too short forM_ [(length $ predecessors updatedNs)..(kNeighbours updatedNs)] (\_ -> do ns' <- readTVarIO nsSTM - nextEntry <- requestQueryID ns' $ pred . getNid $ atDef (toRemoteNodeState ns') (predecessors ns') (-1) + nextEntry <- requestQueryID ns' $ pred . getNid $ lastDef (toRemoteNodeState ns') (predecessors ns') atomically $ do latestNs <- readTVar nsSTM writeTVar nsSTM $ addPredecessors [nextEntry] latestNs @@ -348,7 +350,7 @@ stabiliseThread nsSTM = forever $ do forM_ [(length $ successors updatedNs)..(kNeighbours updatedNs)] (\_ -> do ns' <- readTVarIO nsSTM - nextEntry <- requestQueryID ns' $ succ . getNid $ atDef (toRemoteNodeState ns') (successors ns') (-1) + nextEntry <- requestQueryID ns' $ succ . getNid $ lastDef (toRemoteNodeState ns') (successors ns') atomically $ do latestNs <- readTVar nsSTM writeTVar nsSTM $ addSuccessors [nextEntry] latestNs @@ -460,6 +462,9 @@ fediMessageHandler :: TQueue (BS.ByteString, SockAddr) -- ^ send queue -> LocalNodeStateSTM -- ^ acting NodeState -> IO () fediMessageHandler sendQ recvQ nsSTM = do + -- Read node state just once, assuming that all relevant data for this function does + -- not change. + -- Other functions are passed the nsSTM reference and thus can get the latest state. nsSnap <- readTVarIO nsSTM -- handling multipart messages: -- Request parts can be insert into a map (key: (sender IP against spoofing, request ID), value: timestamp + set of message parts, handle all of them when size of set == parts) before being handled. This map needs to be purged periodically by a separate thread and can be protected by an MVar for fairness.