protect concurrent node state access with STM
- for allowing concurrent access to predecessors and successors, the whole LocalNodeState is passed wrapped into an STM TVar - this allows keeping the tests for the mostly pure data type, compared to protecting only the successor and predecessor list contributes to #28
This commit is contained in:
		
							parent
							
								
									f42dfb2137
								
							
						
					
					
						commit
						dc2e399d64
					
				
					 4 changed files with 107 additions and 68 deletions
				
			
		|  | @ -53,7 +53,8 @@ import           System.Timeout | |||
| 
 | ||||
| import           Hash2Pub.ASN1Coding | ||||
| import           Hash2Pub.FediChordTypes        (CacheEntry (..), | ||||
|                                                  LocalNodeState (..), NodeCache, | ||||
|                                                  LocalNodeState (..), | ||||
|                                                  LocalNodeStateSTM, NodeCache, | ||||
|                                                  NodeID, NodeState (..), | ||||
|                                                  RemoteNodeState (..), | ||||
|                                                  cacheGetNodeStateUnvalidated, | ||||
|  | @ -169,67 +170,93 @@ ackRequest ownID req@Request{} = serialiseMessage sendMessageSize $ Response { | |||
|                                                                              } | ||||
| 
 | ||||
| 
 | ||||
| handleIncomingRequest :: LocalNodeState                     -- ^ the handling node | ||||
| handleIncomingRequest :: LocalNodeStateSTM                     -- ^ the handling node | ||||
|                       -> TQueue (BS.ByteString, SockAddr)   -- ^ send queue | ||||
|                       -> Set.Set FediChordMessage           -- ^ all parts of the request to handle | ||||
|                       -> SockAddr                           -- ^ source address of the request | ||||
|                       -> IO () | ||||
| handleIncomingRequest ns sendQ msgSet sourceAddr = do | ||||
| handleIncomingRequest nsSTM sendQ msgSet sourceAddr = do | ||||
|     ns <- readTVarIO nsSTM | ||||
|     -- add nodestate to cache | ||||
|     now <- getPOSIXTime | ||||
|     aPart <- headMay . Set.elems $ msgSet | ||||
|     case aPart of | ||||
|     case headMay . Set.elems $ msgSet of | ||||
|       Nothing -> pure () | ||||
|       Just aPart' -> | ||||
|         queueAddEntries (Identity . RemoteCacheEntry (sender aPart') $ now) ns | ||||
|       Just aPart -> do | ||||
|         queueAddEntries (Identity $ RemoteCacheEntry (sender aPart) now) ns | ||||
|         -- distinguish on whether and how to respond. If responding, pass message to response generating function and write responses to send queue | ||||
|         maybe (pure ()) (\respSet -> | ||||
|             forM_ (\resp -> atomically $ writeTQueue sendQ (resp, sourceAddr))) | ||||
|             (case action aPart' of | ||||
|                 Ping -> Just  respondPing ns msgSet | ||||
|                 Join -> Just respondJoin ns msgSet | ||||
|                 -- ToDo: figure out what happens if not joined | ||||
|                 QueryID -> Just respondQueryID ns msgSet | ||||
|                 -- only when joined | ||||
|                 Leave -> if isJoined_ ns then Just respondLeave ns msgSet else Nothing | ||||
|                 -- only when joined | ||||
|                 Stabilise -> if isJoined_ ns then Just respondStabilise ns msgSet else Nothing | ||||
|             ) | ||||
|       -- for single part request, response starts with part number 1. For multipart requests, response starts with part number n+1. | ||||
| 
 | ||||
|       -- TODO: determine request type only from first part, but catch RecSelError on each record access when folding, because otherwise different request type parts can make this crash | ||||
|       -- TODO: test case: mixed message types of parts | ||||
| 
 | ||||
| -- ....... response sending ....... | ||||
|         maybe (pure ()) ( | ||||
|             mapM_ (\resp -> atomically $ writeTQueue sendQ (resp, sourceAddr)) | ||||
|                         ) | ||||
|             (case action aPart of | ||||
|                   _ -> Just Map.empty) -- placeholder | ||||
| --                Ping -> Just  respondPing nsSTM msgSet | ||||
| --                Join -> Just respondJoin nsSTM msgSet | ||||
| --                -- ToDo: figure out what happens if not joined | ||||
| --                QueryID -> Just respondQueryID nsSTM msgSet | ||||
| --                -- only when joined | ||||
| --                Leave -> if isJoined_ ns then Just respondLeave nsSTM msgSet else Nothing | ||||
| --                -- only when joined | ||||
| --                Stabilise -> if isJoined_ ns then Just respondStabilise nsSTM msgSet else Nothing | ||||
| --            ) | ||||
| --      -- for single part request, response starts with part number 1. For multipart requests, response starts with part number n+1. | ||||
| -- | ||||
| --      -- TODO: determine request type only from first part, but catch RecSelError on each record access when folding, because otherwise different request type parts can make this crash | ||||
| --      -- TODO: test case: mixed message types of parts | ||||
| -- | ||||
| ---- ....... response sending ....... | ||||
| -- | ||||
| ---- this modifies node state, so locking and IO seems to be necessary. | ||||
| ---- Still try to keep as much code as possible pure | ||||
| --respondJoin :: LocalNodeStateSTM -> Set.Set FediChordMessage -> Map Integer BS.ByteString | ||||
| --respondJoin nsSTM msgSet = | ||||
| --    -- check whether the joining node falls into our responsibility | ||||
| --    -- if yes, adjust own predecessors/ successors and return those in a response | ||||
| --    -- if no: empty response or send a QueryID forwards response? | ||||
| --    -- TODO: notify service layer to copy over data now handled by the new joined node | ||||
| 
 | ||||
| -- ....... request sending ....... | ||||
| 
 | ||||
| -- | send a join request and return the joined 'LocalNodeState' including neighbours | ||||
| requestJoin :: NodeState a => a             -- ^ currently responsible node to be contacted | ||||
|             -> LocalNodeState               -- ^ joining NodeState | ||||
|             -> IO (Either String LocalNodeState)    -- ^ node after join with all its new information | ||||
| requestJoin toJoinOn ownState = | ||||
|             -> LocalNodeStateSTM               -- ^ joining NodeState | ||||
|             -> IO (Either String LocalNodeStateSTM)    -- ^ node after join with all its new information | ||||
| requestJoin toJoinOn ownStateSTM = | ||||
|     bracket (mkSendSocket (getDomain toJoinOn) (getDhtPort toJoinOn)) close (\sock -> do | ||||
|         -- extract own state for getting request information | ||||
|         ownState <- readTVarIO ownStateSTM | ||||
|         responses <- sendRequestTo 5000 3 (\rid -> Request rid (toRemoteNodeState ownState) 1 True Join (Just JoinRequestPayload)) sock | ||||
|         joinedStateUnsorted <- foldM | ||||
|             (\nsAcc msg -> case payload msg of | ||||
|               Nothing -> pure nsAcc | ||||
|               Just msgPl -> do | ||||
|                 -- add transfered cache entries to global NodeCache | ||||
|                 queueAddEntries (joinCache msgPl) nsAcc | ||||
|                 -- add received predecessors and successors | ||||
|                 let | ||||
|                     addPreds ns' = setPredecessors (foldr' (:) (predecessors ns') (joinPredecessors msgPl)) ns' | ||||
|                     addSuccs ns' = setSuccessors (foldr' (:) (successors ns') (joinSuccessors msgPl)) ns' | ||||
|                 pure $ addSuccs . addPreds $ nsAcc | ||||
|                 ) | ||||
|             -- reset predecessors and successors | ||||
|             (setPredecessors [] . setSuccessors [] $ ownState) | ||||
|             responses | ||||
|         (cacheInsertQ, joinedState) <- atomically $ do | ||||
|             stateSnap <- readTVar ownStateSTM | ||||
|             let | ||||
|                 (cacheInsertQ, joinedStateUnsorted) = foldl' | ||||
|                     (\(insertQ, nsAcc) msg -> | ||||
|                       let | ||||
|                         insertQ' = maybe insertQ (\msgPl -> | ||||
|                             -- collect list of insertion statements into global cache | ||||
|                             queueAddEntries (joinCache msgPl) : insertQ | ||||
|                                                  ) $ payload msg | ||||
|                         -- add received predecessors and successors | ||||
|                         addPreds ns' = maybe ns' (\msgPl -> | ||||
|                             setPredecessors (foldr' (:) (predecessors ns') (joinPredecessors msgPl)) ns' | ||||
|                                                  ) $ payload msg | ||||
|                         addSuccs ns' = maybe ns' (\msgPl -> | ||||
|                             setSuccessors (foldr' (:) (successors ns') (joinSuccessors msgPl)) ns' | ||||
|                                                  ) $ payload msg | ||||
|                         in | ||||
|                         (insertQ', addSuccs . addPreds $ nsAcc) | ||||
|                     ) | ||||
|                     -- reset predecessors and successors | ||||
|                     ([], setPredecessors [] . setSuccessors [] $ ownState) | ||||
|                     responses | ||||
|                 -- sort successors and predecessors | ||||
|                 newState = setSuccessors (sortBy localCompare $ successors joinedStateUnsorted) . setPredecessors (sortBy localCompare $ predecessors joinedStateUnsorted) $ joinedStateUnsorted | ||||
|             writeTVar ownStateSTM newState | ||||
|             pure (cacheInsertQ, newState) | ||||
|         -- execute the cache insertions | ||||
|         mapM_ (\f -> f joinedState) cacheInsertQ | ||||
|         if responses == Set.empty | ||||
|            then pure . Left $ "join error: got no response from " <> show (getNid toJoinOn) | ||||
|             -- sort successors and predecessors | ||||
|            else pure . Right . setSuccessors (sortBy localCompare $ successors joinedStateUnsorted) . setPredecessors (sortBy localCompare $ predecessors joinedStateUnsorted) $ joinedStateUnsorted | ||||
|            else pure $ Right ownStateSTM | ||||
|                                                                             ) | ||||
|         `catch` (\e -> pure . Left $ displayException (e :: IOException)) | ||||
| 
 | ||||
|  |  | |||
|  | @ -64,6 +64,7 @@ import           Control.Concurrent | |||
| import           Control.Concurrent.Async | ||||
| import           Control.Concurrent.STM | ||||
| import           Control.Concurrent.STM.TQueue | ||||
| import           Control.Concurrent.STM.TVar | ||||
| import           Control.Monad                 (forM_, forever) | ||||
| import           Crypto.Hash | ||||
| import qualified Data.ByteArray                as BA | ||||
|  | @ -84,11 +85,12 @@ import           Debug.Trace                   (trace) | |||
| 
 | ||||
| -- | initialise data structures, compute own IDs and bind to listening socket | ||||
| -- ToDo: load persisted state, thus this function already operates in IO | ||||
| fediChordInit :: FediChordConf -> IO (Socket, LocalNodeState) | ||||
| fediChordInit :: FediChordConf -> IO (Socket, LocalNodeStateSTM) | ||||
| fediChordInit conf = do | ||||
|     initialState <- nodeStateInit conf | ||||
|     initialStateSTM <- newTVarIO initialState | ||||
|     serverSock <- mkServerSocket (getIpAddr initialState) (getDhtPort initialState) | ||||
|     pure (serverSock, initialState) | ||||
|     pure (serverSock, initialStateSTM) | ||||
| 
 | ||||
| -- | initialises the 'NodeState' for this local node. | ||||
| -- Separated from 'fediChordInit' to be usable in tests. | ||||
|  | @ -120,15 +122,16 @@ nodeStateInit conf = do | |||
| 
 | ||||
| -- | Join a new node into the DHT, using a provided bootstrap node as initial cache seed | ||||
| -- for resolving the new node's position. | ||||
| fediChordBootstrapJoin :: LocalNodeState              -- ^ the local 'NodeState' | ||||
| fediChordBootstrapJoin :: LocalNodeStateSTM              -- ^ the local 'NodeState' | ||||
|                         -> (String, PortNumber)   -- ^ domain and port of a bootstrapping node | ||||
|                         -> IO (Either String LocalNodeState) -- ^ the joined 'NodeState' after a | ||||
|                         -> IO (Either String LocalNodeStateSTM) -- ^ the joined 'NodeState' after a | ||||
|                                             -- successful join, otherwise an error message | ||||
| fediChordBootstrapJoin ns (joinHost, joinPort) = | ||||
| fediChordBootstrapJoin nsSTM (joinHost, joinPort) = | ||||
|     -- can be invoked multiple times with all known bootstrapping nodes until successfully joined | ||||
|     bracket (mkSendSocket joinHost joinPort) close (\sock -> do | ||||
|         -- 1. get routed to placement of own ID until FOUND: | ||||
|         -- Initialise an empty cache only with the responses from a bootstrapping node | ||||
|         ns <- readTVarIO nsSTM | ||||
|         bootstrapResponse <- sendQueryIdMessage (getNid ns) ns sock | ||||
|         if bootstrapResponse == Set.empty | ||||
|            then pure . Left $ "Bootstrapping node " <> show joinHost <> " gave no response." | ||||
|  | @ -143,7 +146,7 @@ fediChordBootstrapJoin ns (joinHost, joinPort) = | |||
|                            Just (FORWARD resultset) -> foldr' (addCacheEntryPure now) cacheAcc resultset | ||||
|                               ) | ||||
|                            initCache bootstrapResponse | ||||
|                fediChordJoin bootstrapCache ns | ||||
|                fediChordJoin bootstrapCache nsSTM | ||||
|                                                    ) | ||||
|        `catch` (\e -> pure . Left $ "Error at bootstrap joining: " <> displayException (e :: IOException)) | ||||
| 
 | ||||
|  | @ -151,15 +154,16 @@ fediChordBootstrapJoin ns (joinHost, joinPort) = | |||
| -- node's position. | ||||
| fediChordJoin :: NodeCache                          -- ^ a snapshot of the NodeCache to | ||||
|                                                     -- use for ID lookup | ||||
|               -> LocalNodeState                     -- ^ the local 'NodeState' | ||||
|               -> IO (Either String LocalNodeState)  -- ^ the joined 'NodeState' after a | ||||
|               -> LocalNodeStateSTM                    -- ^ the local 'NodeState' | ||||
|               -> IO (Either String LocalNodeStateSTM)  -- ^ the joined 'NodeState' after a | ||||
|                                                     -- successful join, otherwise an error message | ||||
| fediChordJoin cacheSnapshot ns = do | ||||
| fediChordJoin cacheSnapshot nsSTM = do | ||||
|     ns <- readTVarIO nsSTM | ||||
|     -- get routed to the currently responsible node, based on the response | ||||
|     -- from the bootstrapping node | ||||
|     currentlyResponsible <- queryIdLookupLoop cacheSnapshot ns $ getNid ns | ||||
|     -- 2. then send a join to the currently responsible node | ||||
|     joinResult <- requestJoin currentlyResponsible ns | ||||
|     joinResult <- requestJoin currentlyResponsible nsSTM | ||||
|     case joinResult of | ||||
|       Left err       -> pure . Left $ "Error joining on " <> err | ||||
|       Right joinedNS -> pure . Right $ joinedNS | ||||
|  | @ -167,8 +171,9 @@ fediChordJoin cacheSnapshot ns = do | |||
| 
 | ||||
| -- | cache updater thread that waits for incoming NodeCache update instructions on | ||||
| -- the node's cacheWriteQueue and then modifies the NodeCache as the single writer. | ||||
| cacheWriter :: LocalNodeState -> IO () | ||||
| cacheWriter ns = do | ||||
| cacheWriter :: LocalNodeStateSTM -> IO () | ||||
| cacheWriter nsSTM = do | ||||
|     ns <- readTVarIO nsSTM | ||||
|     let writeQueue' = cacheWriteQueue ns | ||||
|     forever $ do | ||||
|       f <- atomically $ readTQueue writeQueue' | ||||
|  | @ -196,14 +201,14 @@ sendThread sock sendQ = forever $ do | |||
|     sendAllTo sock packet addr | ||||
| 
 | ||||
| -- | Sets up and manages the main server threads of FediChord | ||||
| fediMainThreads :: Socket -> LocalNodeState -> IO () | ||||
| fediMainThreads sock ns = do | ||||
| fediMainThreads :: Socket -> LocalNodeStateSTM -> IO () | ||||
| fediMainThreads sock nsSTM = do | ||||
|     sendQ <- newTQueueIO | ||||
|     recvQ <- newTQueueIO | ||||
|     -- concurrently launch all handler threads, if one of them throws an exception | ||||
|     -- all get cancelled | ||||
|     concurrently_ | ||||
|         (fediMessageHandler sendQ recvQ ns) $ | ||||
|         (fediMessageHandler sendQ recvQ nsSTM) $ | ||||
|         concurrently | ||||
|             (sendThread sock sendQ) | ||||
|             (recvThread sock recvQ) | ||||
|  | @ -236,9 +241,10 @@ requestMapPurge mapVar = forever $ do | |||
| -- and pass them to their specific handling function. | ||||
| fediMessageHandler :: TQueue (BS.ByteString, SockAddr)  -- ^ send queue | ||||
|                    -> TQueue (BS.ByteString, SockAddr)  -- ^ receive queue | ||||
|                    -> LocalNodeState                    -- ^ acting NodeState | ||||
|                    -> LocalNodeStateSTM                    -- ^ acting NodeState | ||||
|                    -> IO () | ||||
| fediMessageHandler sendQ recvQ ns = do | ||||
| fediMessageHandler sendQ recvQ nsSTM = do | ||||
|     nsSnap <- readTVarIO nsSTM | ||||
|     -- handling multipart messages: | ||||
|     -- Request parts can be insert into a map (key: (sender IP against spoofing, request ID), value: timestamp + set of message parts, handle all of them when size of set == parts) before being handled. This map needs to be purged periodically by a separate thread and can be protected by an MVar for fairness. | ||||
|     requestMap <- newMVar (Map.empty :: RequestMap) | ||||
|  | @ -257,7 +263,7 @@ fediMessageHandler sendQ recvQ ns = do | |||
|               aRequest@Request{} | ||||
|                 -- if not a multipart message, handle immediately. Response is at the same time an ACK | ||||
|                 | part aRequest == 1 && isFinalPart aRequest -> | ||||
|                   forkIO (handleIncomingRequest ns sendQ (Set.singleton aRequest) sourceAddr) >> pure () | ||||
|                   forkIO (handleIncomingRequest nsSTM sendQ (Set.singleton aRequest) sourceAddr) >> pure () | ||||
|                 -- otherwise collect all message parts first before handling the whole request | ||||
|                 | otherwise -> do | ||||
|                   now <- getPOSIXTime | ||||
|  | @ -277,14 +283,14 @@ fediMessageHandler sendQ recvQ ns = do | |||
|                   -- put map back into MVar, end of critical section | ||||
|                   putMVar requestMap newMapState | ||||
|                   -- ACK the received part | ||||
|                   forM_ (ackRequest (getNid ns) aRequest) $ | ||||
|                   forM_ (ackRequest (getNid nsSnap) aRequest) $ | ||||
|                       \msg -> atomically $ writeTQueue sendQ (msg, sourceAddr) | ||||
|                   -- if all parts received, then handle request. | ||||
|                   let | ||||
|                     (RequestMapEntry theseParts mayMaxParts _) = fromJust $ Map.lookup thisKey newMapState | ||||
|                     numParts = Set.size theseParts | ||||
|                   if maybe False (numParts ==) (fromIntegral <$> mayMaxParts) | ||||
|                      then forkIO (handleIncomingRequest ns sendQ theseParts sourceAddr) >> pure() | ||||
|                      then forkIO (handleIncomingRequest nsSTM sendQ theseParts sourceAddr) >> pure() | ||||
|                      else pure() | ||||
|               -- Responses should never arrive on the main server port, as they are always | ||||
|               -- responses to requests sent from dedicated sockets on another port | ||||
|  |  | |||
|  | @ -9,6 +9,7 @@ module Hash2Pub.FediChordTypes ( | |||
|   , toNodeID | ||||
|   , NodeState (..) | ||||
|   , LocalNodeState (..) | ||||
|   , LocalNodeStateSTM | ||||
|   , RemoteNodeState (..) | ||||
|   , setSuccessors | ||||
|   , setPredecessors | ||||
|  | @ -40,6 +41,7 @@ import           Network.Socket | |||
| -- for hashing and ID conversion | ||||
| import           Control.Concurrent.STM | ||||
| import           Control.Concurrent.STM.TQueue | ||||
| import           Control.Concurrent.STM.TVar | ||||
| import           Control.Monad                 (forever) | ||||
| import           Crypto.Hash | ||||
| import qualified Data.ByteArray                as BA | ||||
|  | @ -144,6 +146,8 @@ data LocalNodeState = LocalNodeState | |||
|     } | ||||
|     deriving (Show, Eq) | ||||
| 
 | ||||
| type LocalNodeStateSTM = TVar LocalNodeState | ||||
| 
 | ||||
| -- | class for various NodeState representations, providing | ||||
| -- getters and setters for common values | ||||
| class NodeState a where | ||||
|  |  | |||
|  | @ -2,9 +2,11 @@ module Main where | |||
| 
 | ||||
| import           Control.Concurrent | ||||
| import           Control.Concurrent.Async | ||||
| import           Control.Concurrent.STM | ||||
| import           Control.Concurrent.STM.TVar | ||||
| import           Control.Exception | ||||
| import           Data.Either | ||||
| import           Data.IP                  (IPv6, toHostAddress6) | ||||
| import           Data.IP                     (IPv6, toHostAddress6) | ||||
| import           System.Environment | ||||
| 
 | ||||
| import           Hash2Pub.FediChord | ||||
|  | @ -16,7 +18,7 @@ main = do | |||
|     conf <- readConfig | ||||
|     -- ToDo: load persisted caches, bootstrapping nodes … | ||||
|     (serverSock, thisNode) <- fediChordInit conf | ||||
|     print thisNode | ||||
|     print =<< readTVarIO thisNode | ||||
|     print serverSock | ||||
|     -- currently no masking is necessary, as there is nothing to clean up | ||||
|     cacheWriterThread <- forkIO $ cacheWriter thisNode | ||||
|  | @ -38,7 +40,7 @@ main = do | |||
|            ) | ||||
|            (\joinedNS -> do | ||||
|         -- launch main eventloop with successfully joined state | ||||
|         putStrLn ("successful join at " <> (show . getNid $ joinedNS)) | ||||
|         putStrLn "successful join" | ||||
|         wait =<< async (fediMainThreads serverSock thisNode) | ||||
|            ) | ||||
|         joinedState | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue