module Hash2Pub.DHTProtocol
    ( QueryResponse (..)
    , queryLocalCache
    , addCacheEntry
    , addCacheEntryPure
    , deleteCacheEntry
    , deserialiseMessage
    , markCacheEntryAsVerified
    , RemoteCacheEntry(..)
    , toRemoteCacheEntry
    , remoteNode
    , Action(..)
    , ActionPayload(..)
    , FediChordMessage(..)
    , maximumParts
    , sendQueryIdMessage
    , requestQueryID
    , requestJoin
    , queryIdLookupLoop
    , resolve
    , mkSendSocket
    , mkServerSocket
    , handleIncomingRequest
    , ackRequest
    )
        where

import           Control.Concurrent.Async
import           Control.Concurrent.STM
import           Control.Concurrent.STM.TBQueue
import           Control.Concurrent.STM.TQueue
import           Control.Exception
import           Control.Monad                  (foldM, forM, forM_)
import qualified Data.ByteString                as BS
import           Data.Either                    (rights)
import           Data.Foldable                  (foldl', foldr')
import           Data.Functor.Identity
import           Data.IORef
import           Data.IP                        (IPv6, fromHostAddress6,
                                                 toHostAddress6)
import           Data.List                      (sortBy)
import qualified Data.Map                       as Map
import           Data.Maybe                     (fromJust, fromMaybe, mapMaybe,
                                                 maybe)
import qualified Data.Set                       as Set
import           Data.Time.Clock.POSIX
import           Network.Socket                 hiding (recv, recvFrom, send,
                                                 sendTo)
import           Network.Socket.ByteString
import           Safe
import           System.Random
import           System.Timeout

import           Hash2Pub.ASN1Coding
import           Hash2Pub.FediChordTypes        (CacheEntry (..),
                                                 LocalNodeState (..),
                                                 NodeCache, NodeID,
                                                 NodeState (..),
                                                 RemoteNodeState (..),
                                                 cacheGetNodeStateUnvalidated,
                                                 cacheLookup, cacheLookupPred,
                                                 cacheLookupSucc, localCompare,
                                                 localCompare, setPredecessors,
                                                 setSuccessors)
import           Hash2Pub.ProtocolTypes

import           Debug.Trace                    (trace)

-- === queries ===

-- TODO: evaluate more fine-grained argument passing to allow granular locking
-- | Look up an ID to either claim responsibility for it or return the closest
-- l nodes from the local cache.
--
-- Returns 'FOUND' with the own node state when the target ID lies between the
-- first known predecessor (exclusive) and the own ID (inclusive); otherwise
-- 'FORWARD' with the best known next hops.
queryLocalCache :: LocalNodeState   -- ^ state of the local node
                -> NodeCache        -- ^ cache to search for next hops
                -> Int              -- ^ l, the number of best nodes to return
                -> NodeID           -- ^ the ID being looked up
                -> QueryResponse
queryLocalCache ownState nCache lBestNodes targetID
    -- as target ID falls between own ID and first predecessor, it is handled by this node
    | (targetID `localCompare` ownID) `elem` [LT, EQ]
        && maybe False (\p -> targetID `localCompare` p == GT) (headMay preds)
        = FOUND . toRemoteNodeState $ ownState
    -- my interpretation: the "l best next hops" are the l-1 closest preceding nodes and
    -- the closest succeeding node (like with the p initiated parallel queries)
    | otherwise = FORWARD $ closestSuccessor `Set.union` closestPredecessors
  where
    ownID = getNid ownState
    preds = predecessors ownState

    closestSuccessor :: Set.Set RemoteCacheEntry
    closestSuccessor = maybe Set.empty Set.singleton $ toRemoteCacheEntry =<< cacheLookupSucc targetID nCache

    -- The walk starts at the target ID, so that the returned nodes are the
    -- ones preceding the *target*, consistent with 'closestSuccessor'.
    -- Previously it started at the own node ID, which returned the local
    -- node's own predecessors regardless of the ID asked for.
    closestPredecessors :: Set.Set RemoteCacheEntry
    closestPredecessors = closestPredecessor (lBestNodes-1) targetID

    -- collect up to n cache entries by repeatedly stepping to the cached
    -- predecessor of the previously found node
    closestPredecessor :: (Integral n, Show n) => n -> NodeID -> Set.Set RemoteCacheEntry
    closestPredecessor 0 _ = Set.empty
    closestPredecessor remainingLookups lastID
      | remainingLookups < 0 = Set.empty
      | otherwise =
          case toRemoteCacheEntry =<< cacheLookupPred lastID nCache of
            Nothing -> Set.empty
            Just nPred@(RemoteCacheEntry ns _) ->
                Set.insert nPred $ closestPredecessor (remainingLookups-1) (nid ns)

-- cache operations

-- | update or insert a 'RemoteCacheEntry' into the cache,
-- converting it to a local 'CacheEntry'
addCacheEntry :: RemoteCacheEntry   -- ^ a remote cache entry received from network
              -> NodeCache          -- ^ node cache to insert to
              -> IO NodeCache       -- ^ new node cache with the element inserted
addCacheEntry entry cache = do
    now <- getPOSIXTime
    pure $ addCacheEntryPure now entry cache

-- | pure version of 'addCacheEntry' with current time explicitly specified as argument
addCacheEntryPure :: POSIXTime      -- ^ current time
              -> RemoteCacheEntry   -- ^ a remote cache entry received from network
              -> NodeCache          -- ^ node cache to insert to
              -> NodeCache          -- ^ new node cache with the element inserted
addCacheEntryPure now (RemoteCacheEntry ns ts) cache =
    Map.insertWith insertCombineFunction (nid ns) (NodeEntry False ns timestamp') cache
  where
    -- TODO: limit diffSeconds to some maximum value to prevent malicious nodes
    -- from inserting entries valid nearly until eternity
    -- timestamps from the future are clamped to the current time
    timestamp' = min ts now
    -- An existing proxy keeps its pointer and gets the new entry attached.
    -- An existing node entry keeps its validation state, takes over the new
    -- node data and the more recent of both timestamps.
    insertCombineFunction newVal@(NodeEntry _ newNode newTimestamp) oldVal =
        case oldVal of
          ProxyEntry n _ -> ProxyEntry n (Just newVal)
          NodeEntry oldValidationState _ oldTimestamp ->
              NodeEntry oldValidationState newNode (max oldTimestamp newTimestamp)

-- | delete the node with given ID from cache
deleteCacheEntry :: NodeID      -- ^ID of the node to be deleted
                 -> NodeCache   -- ^cache to delete from
                 -> NodeCache   -- ^cache without the specified element
deleteCacheEntry = Map.update modifier
  where
    -- proxies stay in the cache and only lose their attached entry
    modifier (ProxyEntry idPointer _) = Just (ProxyEntry idPointer Nothing)
    modifier NodeEntry {}             = Nothing

-- | Mark a cache entry as verified after pinging it, possibly bumping its timestamp.
markCacheEntryAsVerified :: Maybe POSIXTime     -- ^ the (current) timestamp to be
                                                -- given to the entry, or Nothing
                         -> NodeID              -- ^ which node to mark
                         -> NodeCache           -- ^ current node cache
                         -> NodeCache           -- ^ new NodeCache with the updated entry
markCacheEntryAsVerified timestamp = Map.adjust adjustFunc
  where
    adjustFunc (NodeEntry _ ns ts) = NodeEntry True ns $ fromMaybe ts timestamp
    -- Keep the proxy wrapper and only update the entry attached to it, the
    -- same way 'addCacheEntryPure' and 'deleteCacheEntry' preserve proxies.
    -- Previously the proxy itself was replaced by its inner entry.
    adjustFunc (ProxyEntry pointer (Just entry)) = ProxyEntry pointer (Just $ adjustFunc entry)
    adjustFunc entry = entry

-- | uses the successor and predecessor list of a node as an indicator for whether a
-- node has properly joined the DHT
isJoined_ :: LocalNodeState -> Bool
isJoined_ ns = not . all null $ [successors ns, predecessors ns]

-- | the size limit to be used when serialising messages for sending
sendMessageSize :: Num i => i
sendMessageSize = 1200

-- ====== message send and receive operations ======

-- | Encode the response to a request that just signals successful receipt.
-- Only defined for 'Request' messages.
ackRequest :: NodeID -> FediChordMessage -> Map.Map Integer BS.ByteString
ackRequest ownID req@Request{} = serialiseMessage sendMessageSize $ Response {
    responseTo = requestID req
  , senderID = ownID
  , part = part req
  , isFinalPart = False
  , action = action req
  , payload = Nothing
                                                          }

-- | Handle all parts of an incoming request: add the sender to the node cache
-- and, depending on the requested action, put the serialised response parts
-- into the send queue addressed to the request's source.
--
-- An empty message set, or one whose first element is not a 'Request', is
-- ignored.
handleIncomingRequest :: LocalNodeState                     -- ^ the handling node
                      -> TQueue (BS.ByteString, SockAddr)   -- ^ send queue
                      -> Set.Set FediChordMessage           -- ^ all parts of the request to handle
                      -> SockAddr                           -- ^ source address of the request
                      -> IO ()
handleIncomingRequest ns sendQ msgSet sourceAddr = do
    now <- getPOSIXTime
    case headMay . Set.elems $ msgSet of
      -- 'sender' is only a field of 'Request', so match on the constructor
      -- before accessing it
      Just aPart@Request{} -> do
        -- add nodestate to cache
        queueAddEntries (Identity $ RemoteCacheEntry (sender aPart) now) ns
        -- distinguish on whether and how to respond. If responding, pass message
        -- to response generating function and write responses to send queue
        -- NOTE(review): the respond* functions are not visible in this part of
        -- the file; they are assumed to be pure and to return the serialised
        -- response parts like 'ackRequest' does -- confirm once implemented.
        let responses = case action aPart of
                Ping -> Just (respondPing ns msgSet)
                Join -> Just (respondJoin ns msgSet)
                -- ToDo: figure out what happens if not joined
                QueryID -> Just (respondQueryID ns msgSet)
                -- only when joined
                Leave -> if isJoined_ ns then Just (respondLeave ns msgSet) else Nothing
                -- only when joined
                Stabilise -> if isJoined_ ns then Just (respondStabilise ns msgSet) else Nothing
        maybe (pure ())
              (\respSet -> forM_ respSet (\resp -> atomically $ writeTQueue sendQ (resp, sourceAddr)))
              responses
      _ -> pure ()

-- for single part request, response starts with part number 1. For multipart requests, response starts with part number n+1.
-- TODO: determine request type only from first part, but catch RecSelError on each record access when folding, because otherwise different request type parts can make this crash
-- TODO: test case: mixed message types of parts

-- ....... response sending .......

-- ....... request sending .......

-- | Send a join request to the given node and return the own 'LocalNodeState'
-- extended by the neighbours learned from the responses.
--
-- Yields a 'Left' with an error description if no response arrived or if an
-- 'IOException' occurred while communicating.
requestJoin :: NodeState a => a             -- ^ currently responsible node to be contacted
            -> LocalNodeState               -- ^ joining NodeState
            -> IO (Either String LocalNodeState)    -- ^ node after join with all its new information
requestJoin toJoinOn ownState =
    joinOverSocket `catch` (\e -> pure . Left $ displayException (e :: IOException))
  where
    -- open a socket to the node to join on, run the exchange, always close the socket
    joinOverSocket =
        bracket (mkSendSocket (getDomain toJoinOn) (getDhtPort toJoinOn)) close $ \sock -> do
            responses <- sendRequestTo 5000 3 joinRequest sock
            joinedStateUnsorted <- foldM mergeResponse blankState responses
            pure $ if Set.null responses
                then Left $ "join error: got no response from " <> show (getNid toJoinOn)
                else Right (sortNeighbours joinedStateUnsorted)

    -- the join request, still lacking its request ID
    joinRequest rid = Request rid (toRemoteNodeState ownState) 1 True Join (Just JoinRequestPayload)

    -- own state with predecessors and successors reset
    blankState = setPredecessors [] . setSuccessors [] $ ownState

    -- fold step: responses without payload leave the accumulated state untouched
    mergeResponse nsAcc msg =
        case payload msg of
          Nothing -> pure nsAcc
          Just msgPl -> do
            -- hand the transferred cache entries over to the global NodeCache
            queueAddEntries (joinCache msgPl) nsAcc
            -- prepend the received predecessors and successors to the known ones
            let withPreds ns' = setPredecessors (foldr' (:) (predecessors ns') (joinPredecessors msgPl)) ns'
                withSuccs ns' = setSuccessors (foldr' (:) (successors ns') (joinSuccessors msgPl)) ns'
            pure . withSuccs . withPreds $ nsAcc

    -- sort successors and predecessors
    sortNeighbours joined =
        setSuccessors (sortBy localCompare $ successors joined)
        . setPredecessors (sortBy localCompare $ predecessors joined)
        $ joined

-- | Send a 'QueryID' 'Request' for getting the node that handles a certain key ID.
requestQueryID :: LocalNodeState        -- ^ NodeState of the querying node
               -> NodeID                -- ^ target key ID to look up
               -> IO RemoteNodeState    -- ^ the node responsible for handling that key
-- 1. do a local lookup for the l closest nodes
-- 2. create l sockets
-- 3. send a message async concurrently to all l nodes
-- 4. collect the results, insert them into cache
-- 5. repeat until FOUND (problem: new entries not necessarily already in cache, explicitly compare with closer results)
-- TODO: deal with lookup failures
requestQueryID ns targetID =
    -- start the lookup from a snapshot of the node's current cache
    readIORef (nodeCacheRef ns) >>= \snapshot -> queryIdLookupLoop snapshot ns targetID

-- | like 'requestQueryID', but allows passing of a custom cache, e.g. for joining
--
-- NOTE(review): the loop only ends once a 'FOUND' is obtained; if no queried
-- node ever answers with one, it recurses indefinitely (see the TODO on
-- 'requestQueryID').
queryIdLookupLoop :: NodeCache -> LocalNodeState -> NodeID -> IO RemoteNodeState
queryIdLookupLoop cacheSnapshot ns targetID =
    -- FOUND can only be returned if targetID is owned by local node
    case queryLocalCache ns cacheSnapshot (lNumBestNodes ns) targetID of
      FOUND thisNode -> pure thisNode
      FORWARD nodeSet -> do
        -- create connected sockets to all query targets and use them for request handling
        -- ToDo: make attempts and timeout configurable
        queryThreads <- forM (remoteNode <$> Set.toList nodeSet) $ \resultNode ->
            async $ bracket (mkSendSocket (getDomain resultNode) (getDhtPort resultNode))
                            close
                            (sendQueryIdMessage targetID ns)
        -- ToDo: process results immediately instead of waiting for the last one to finish, see https://stackoverflow.com/a/38815224/9198613
        -- ToDo: exception handling, maybe log them
        -- failed query threads are dropped, the successful ones contribute all their response parts
        responses <- concatMap Set.elems . rights <$> mapM waitCatch queryThreads
        -- insert new cache entries both into global cache as well as in local
        -- copy, to make sure it is already up to date at next lookup
        now <- getPOSIXTime
        newLCache <- foldM (absorbResponse now) cacheSnapshot responses
        -- return the first FOUND if there is one, otherwise look up again
        -- using the updated local cache copy
        case headMay (mapMaybe foundNode responses) of
          Just responsibleNode -> pure responsibleNode
          Nothing              -> queryIdLookupLoop newLCache ns targetID
  where
    -- cache entries carried by a single response
    entriesOf now resp =
        case queryResult <$> payload resp of
          Just (FOUND result1)     -> [RemoteCacheEntry result1 now]
          Just (FORWARD resultset) -> Set.elems resultset
          _                        -> []

    -- fold step: forward the entries to the global cache and insert them into the local copy
    absorbResponse now oldCache resp = do
        let entriesToInsert = entriesOf now resp
        queueAddEntries entriesToInsert ns
        pure $ foldr' (addCacheEntryPure now) oldCache entriesToInsert

    -- the node of a FOUND response, if the response is one
    foundNode resp =
        case queryResult <$> payload resp of
          Just (FOUND ns') -> Just ns'
          _                -> Nothing


sendQueryIdMessage :: NodeID                        -- ^ target key ID to look up
                   -> LocalNodeState                -- ^ node state of the node doing the query
                   -> Socket                        -- ^ connected socket to use for sending
                   -> IO (Set.Set FediChordMessage) -- ^ responses
sendQueryIdMessage targetID ns = sendRequestTo 5000 3 buildRequest
  where
    -- single-part QueryID request, still lacking its request ID
    buildRequest rID = Request rID (toRemoteNodeState ns) 1 True QueryID (Just queryPayload)
    queryPayload = QueryIDRequestPayload
        { queryTargetID = targetID
        , queryLBestNodes = fromIntegral . lNumBestNodes $ ns
        }

-- | Generic function for sending a request over a connected socket and collecting the response.
-- Serialises the message and tries to deliver its parts for a number of attempts within a specified timeout.
sendRequestTo :: Int                    -- ^ timeout in milliseconds, applied to each attempt
              -> Int                    -- ^ number of attempts
              -> (Integer -> FediChordMessage)       -- ^ the message to be sent, still needing a requestID
              -> Socket                 -- ^ connected socket to use for sending
              -> IO (Set.Set FediChordMessage)       -- ^ responses
sendRequestTo timeoutMillis numAttempts msgIncomplete sock = do
    -- give the message a random request ID
    randomID <- randomRIO (0, 2^32-1)
    let requests = serialiseMessage sendMessageSize $ msgIncomplete randomID
    -- create a queue for passing received response messages back, even after a timeout
    responseQ <- newTBQueueIO $ 2*maximumParts  -- keep room for duplicate packets
    -- start sendAndAck with timeout
    -- 'timeout' expects microseconds, so the millisecond value has to be scaled
    _ <- attempts numAttempts . timeout (timeoutMillis * 1000) $ sendAndAck responseQ sock requests
    -- after timeout, check received responses, delete them from unacked message set/ map and rerun sendAndAck with that if necessary.
    recvdParts <- atomically $ flushTBQueue responseQ
    pure $ Set.fromList recvdParts
  where
    -- send all request parts, then collect responses until everything expected has arrived
    sendAndAck :: TBQueue FediChordMessage          -- ^ the queue for putting in the received responses
               -> Socket                            -- ^ the socket used for sending and receiving for this particular remote node
               -> Map.Map Integer BS.ByteString     -- ^ the remaining unacked request parts
               -> IO ()
    sendAndAck responseQueue sock' remainingSends = do
        sendMany sock' $ Map.elems remainingSends
        -- if all requests have been acked/ responded to, return prematurely
        recvLoop responseQueue remainingSends Set.empty Nothing
    recvLoop :: TBQueue FediChordMessage          -- ^ the queue for putting in the received responses
             -> Map.Map Integer BS.ByteString     -- ^ the remaining unacked request parts
             -> Set.Set Integer                   -- ^ already received response part numbers
             -> Maybe Integer                     -- ^ total number of response parts if already known
             -> IO ()
    recvLoop responseQueue remainingSends' receivedPartNums totalParts = do
        -- 65535 is maximum length of UDP packets, as long as
        -- no IPv6 jumbograms are used
        response <- deserialiseMessage <$> recv sock 65535
        case response of
          Right msg@Response{} -> do
            atomically $ writeTBQueue responseQueue msg
            let
                -- the final part's number tells how many parts to expect
                newTotalParts = if isFinalPart msg then Just (part msg) else totalParts
                newRemaining = Map.delete (part msg) remainingSends'
                newReceivedParts = Set.insert (part msg) receivedPartNums
            -- Finished once every request part got answered and all response
            -- parts are in. Both the check and the recursion must use the
            -- updated set of received parts, otherwise the set never grows and
            -- the loop only ever ends by timeout.
            if Map.null newRemaining && maybe False (\p -> Set.size newReceivedParts == fromIntegral p) newTotalParts
               then pure ()
               else recvLoop responseQueue newRemaining newReceivedParts newTotalParts
          -- drop errors and invalid messages, including successfully decoded
          -- messages that are not responses
          _ -> recvLoop responseQueue remainingSends' receivedPartNums totalParts


-- | enqueue a list of RemoteCacheEntries to be added to the global NodeCache
queueAddEntries :: Foldable c => c RemoteCacheEntry
                -> LocalNodeState
                -> IO ()
queueAddEntries entries ns = do
    now <- getPOSIXTime
    -- each queue element is a cache-modifying function with the insertion time already fixed
    forM_ entries $ \entry -> atomically $ writeTQueue (cacheWriteQueue ns) $ addCacheEntryPure now entry

-- | retry an IO action at most *i* times until it delivers a result
attempts :: Int             -- ^ number of retries *i*
         -> IO (Maybe a)    -- ^ action to retry
         -> IO (Maybe a)    -- ^ result after at most *i* retries
attempts i action
    -- also covers negative counts, which would otherwise never terminate
    | i <= 0 = pure Nothing
    | otherwise = do
        actionResult <- action
        case actionResult of
          Nothing  -> attempts (i-1) action
          Just res -> pure $ Just res

-- ====== network socket operations ======

-- | resolve a specified host and return the 'AddrInfo' for it.
-- If no hostname or IP is specified, the 'AddrInfo' can be used to bind to all
-- addresses;
-- if no port is specified an arbitrary free port is selected.
resolve :: Maybe String         -- ^ hostname or IP address to be resolved
        -> Maybe PortNumber     -- ^ port number of either local bind or remote
        -> IO AddrInfo
resolve host port = do
    let hints = defaultHints { addrFamily = AF_INET6, addrSocketType = Datagram
                             , addrFlags = [AI_PASSIVE] }
    addrInfos <- getAddrInfo (Just hints) host (show <$> port)
    -- take the first result; signal an empty result as an IOException
    -- instead of crashing on 'head'
    case addrInfos of
      addr:_ -> pure addr
      []     -> ioError . userError $ "resolve: no address information found for " <> show host

-- | create an unconnected UDP Datagram 'Socket' bound to the specified address
mkServerSocket :: HostAddress6 -> PortNumber -> IO Socket
mkServerSocket ip port = do
    sockAddr <- addrAddress <$> resolve (Just $ show . fromHostAddress6 $ ip) (Just port)
    -- close the socket again if configuring or binding it fails, so that no
    -- file descriptor is leaked
    bracketOnError (socket AF_INET6 Datagram defaultProtocol) close $ \sock -> do
        setSocketOption sock IPv6Only 1
        bind sock sockAddr
        pure sock

-- | create a UDP datagram socket, connected to a destination.
-- The socket gets an arbitrary free local port assigned.
mkSendSocket :: String                      -- ^ destination hostname or IP
             -> PortNumber                  -- ^ destination port
             -> IO Socket                   -- ^ a socket with an arbitrary source port
mkSendSocket dest destPort = do
    destAddr <- addrAddress <$> resolve (Just dest) (Just destPort)
    -- close the socket again if configuring or connecting it fails, so that no
    -- file descriptor is leaked
    bracketOnError (socket AF_INET6 Datagram defaultProtocol) close $ \sendSock -> do
        setSocketOption sendSock IPv6Only 1
        connect sendSock destAddr
        pure sendSock