diff --git a/src/Hash2Pub/ASN1Coding.hs b/src/Hash2Pub/ASN1Coding.hs index c2a5cc4..65f5e21 100644 --- a/src/Hash2Pub/ASN1Coding.hs +++ b/src/Hash2Pub/ASN1Coding.hs @@ -206,7 +206,7 @@ encodeNodeState ns = [ , OctetString (ipAddrAsBS $ getIpAddr ns) , IntVal (toInteger . getDhtPort $ ns) , IntVal (toInteger . getServicePort $ ns) - , IntVal (getVServerID ns) + , IntVal (toInteger $ getVServerID ns) , End Sequence ] @@ -370,7 +370,7 @@ parseNodeState = onNextContainer Sequence $ do , domain = domain' , dhtPort = dhtPort' , servicePort = servicePort' - , vServerID = vServer' + , vServerID = fromInteger vServer' , ipAddr = ip' } diff --git a/src/Hash2Pub/DHTProtocol.hs b/src/Hash2Pub/DHTProtocol.hs index a51b117..50a2ec8 100644 --- a/src/Hash2Pub/DHTProtocol.hs +++ b/src/Hash2Pub/DHTProtocol.hs @@ -66,6 +66,7 @@ import Data.Maybe (fromJust, fromMaybe, isJust, isNothing, mapMaybe, maybe) import qualified Data.Set as Set import Data.Time.Clock.POSIX +import Data.Word (Word8) import Network.Socket hiding (recv, recvFrom, send, sendTo) import Network.Socket.ByteString @@ -93,6 +94,7 @@ import Hash2Pub.FediChordTypes (CacheEntry (..), getKeyID, localCompare, rMapFromList, rMapLookupPred, rMapLookupSucc, + hasValidNodeId, remainingLoadTarget, setPredecessors, setSuccessors) import Hash2Pub.ProtocolTypes @@ -267,12 +269,13 @@ extractFirstPayload msgSet = foldr' (\msg plAcc -> -- | Dispatch incoming requests to the dedicated handling and response function, and enqueue -- the response to be sent. handleIncomingRequest :: Service s (RealNodeSTM s) - => LocalNodeStateSTM s -- ^ the handling node + => Word8 -- ^ maximum number of vservers, because of decision to @dropSpoofedIDs@ in here and not already in @fediMessageHandler@ + -> LocalNodeStateSTM s -- ^ the handling node -> TQueue (BS.ByteString, SockAddr) -- ^ send queue -> Set.Set FediChordMessage -- ^ all parts of the request to handle -> SockAddr -- ^ source address of the request -> IO () -handleIncomingRequest nsSTM sendQ msgSet sourceAddr = do +handleIncomingRequest vsLimit nsSTM sendQ msgSet sourceAddr = do ns <- readTVarIO nsSTM -- add nodestate to cache now <- getPOSIXTime @@ -287,12 +290,12 @@ handleIncomingRequest nsSTM sendQ msgSet sourceAddr = do ) =<< (case action aPart of Ping -> Just <$> respondPing nsSTM msgSet - Join -> dropSpoofedIDs sourceIP nsSTM msgSet respondJoin + Join -> dropSpoofedIDs vsLimit sourceIP nsSTM msgSet respondJoin -- ToDo: figure out what happens if not joined QueryID -> Just <$> respondQueryID nsSTM msgSet -- only when joined - Leave -> if vsIsJoined ns then dropSpoofedIDs sourceIP nsSTM msgSet respondLeave else pure Nothing - Stabilise -> if vsIsJoined ns then dropSpoofedIDs sourceIP nsSTM msgSet respondStabilise else pure Nothing + Leave -> if vsIsJoined ns then dropSpoofedIDs vsLimit sourceIP nsSTM msgSet respondLeave else pure Nothing + Stabilise -> if vsIsJoined ns then dropSpoofedIDs vsLimit sourceIP nsSTM msgSet respondStabilise else pure Nothing QueryLoad -> if vsIsJoined ns then Just <$> respondQueryLoad nsSTM msgSet else pure Nothing ) -- for single part request, response starts with part number 1. For multipart requests, response starts with part number n+1. @@ -303,19 +306,18 @@ handleIncomingRequest nsSTM sendQ msgSet sourceAddr = do -- | Filter out requests with spoofed node IDs by recomputing the ID using -- the sender IP. -- For valid (non-spoofed) sender IDs, the passed responder function is invoked. - dropSpoofedIDs :: HostAddress6 -- msg source address + dropSpoofedIDs :: Word8 -- ^ maximum number of vservers per node + -> HostAddress6 -- ^ msg source address -> LocalNodeStateSTM s - -> Set.Set FediChordMessage -- message parts of the request - -> (LocalNodeStateSTM s -> Set.Set FediChordMessage -> IO (Map.Map Integer BS.ByteString)) -- reponder function to be invoked for valid requests + -> Set.Set FediChordMessage -- ^ message parts of the request + -> (LocalNodeStateSTM s -> Set.Set FediChordMessage -> IO (Map.Map Integer BS.ByteString)) -- ^ reponder function to be invoked for valid requests -> IO (Maybe (Map.Map Integer BS.ByteString)) - dropSpoofedIDs addr nsSTM' msgSet' responder = + dropSpoofedIDs limVs addr nsSTM' msgSet' responder = let aRequestPart = Set.elemAt 0 msgSet senderNs = sender aRequestPart - givenSenderID = getNid senderNs - recomputedID = genNodeID addr (getDomain senderNs) (fromInteger $ getVServerID senderNs) in - if recomputedID == givenSenderID + if hasValidNodeId limVs senderNs addr then Just <$> responder nsSTM' msgSet' else pure Nothing @@ -779,10 +781,9 @@ requestPing ns target = do -- recompute ID for each received node and mark as verified in cache now <- getPOSIXTime forM_ responseVss (\vs -> - let recomputedID = genNodeID peerAddr (getDomain vs) (fromInteger $ getVServerID vs) - in if recomputedID == getNid vs - then atomically $ writeTQueue (cacheWriteQueue ns) $ addNodeAsVerifiedPure now vs - else pure () + if hasValidNodeId (confKChoicesMaxVS nodeConf) vs peerAddr + then atomically $ writeTQueue (cacheWriteQueue ns) $ addNodeAsVerifiedPure now vs + else pure () ) pure $ if null responseVss then Left "no active vServer IDs returned, ignoring node" diff --git a/src/Hash2Pub/FediChord.hs b/src/Hash2Pub/FediChord.hs index d4a94a6..488c92d 100644 --- a/src/Hash2Pub/FediChord.hs +++ b/src/Hash2Pub/FediChord.hs @@ -151,7 +151,7 @@ fediChordInit initConf serviceRunner = do -- this function. fediChordJoinNewVs :: (MonadError String m, MonadIO m, Service s (RealNodeSTM s)) => RealNodeSTM s -- ^ parent real node - -> Integer -- ^ vserver ID + -> Word8 -- ^ vserver ID -> RemoteNodeState -- ^ target node to join on -> m (NodeID, LocalNodeStateSTM s) -- ^ on success: (vserver ID, TVar of vserver) fediChordJoinNewVs nodeSTM vsId target = do @@ -164,7 +164,7 @@ fediChordJoinNewVs nodeSTM vsId target = do -- | initialises the 'NodeState' for this local node. -- Separated from 'fediChordInit' to be usable in tests. -nodeStateInit :: Service s (RealNodeSTM s) => RealNodeSTM s -> Integer -> IO (LocalNodeState s) +nodeStateInit :: Service s (RealNodeSTM s) => RealNodeSTM s -> Word8 -> IO (LocalNodeState s) nodeStateInit realNodeSTM vsID' = do realNode <- readTVarIO realNodeSTM let @@ -173,7 +173,7 @@ nodeStateInit realNodeSTM vsID' = do containedState = RemoteNodeState { domain = confDomain conf , ipAddr = confIP conf - , nid = genNodeID (confIP conf) (confDomain conf) $ fromInteger vsID + , nid = genNodeID (confIP conf) (confDomain conf) vsID , dhtPort = toEnum $ confDhtPort conf , servicePort = getListeningPortFromService $ nodeService realNode , vServerID = vsID @@ -257,7 +257,7 @@ kChoicesVsJoin queryVsSTM bootstrapNode capacity activeVss nodeSTM remainingTarg activeVsSet = HMap.keysSet activeVss -- tuples of node IDs and vserver IDs, because vserver IDs are needed for -- LocalNodeState creation - nonJoinedIDs = filter (not . flip HSet.member activeVsSet . fst) [ (genNodeID (confIP conf) (confDomain conf) (fromInteger v), v) | v <- [0..pred (confKChoicesMaxVS conf)]] + nonJoinedIDs = filter (not . flip HSet.member activeVsSet . fst) [ (genNodeID (confIP conf) (confDomain conf) v, v) | v <- [0..pred (confKChoicesMaxVS conf)]] queryVs <- liftIO $ readTVarIO queryVsSTM -- query load of all possible segments @@ -411,7 +411,7 @@ bootstrapQueryId nsSTM (bootstrapHost, bootstrapPort) targetID = do SockAddrInet6 _ _ bootstrapIP _ -> pure bootstrapIP _ -> throwError $ "Expected an IPv6 address, but got " <> show bootstrapAddr let possibleJoinIDs = - [ genNodeID bootstrapIP bootstrapHost (fromInteger v) | v <- [0..pred ( + [ genNodeID bootstrapIP bootstrapHost v | v <- [0..pred ( if confEnableKChoices nodeConf then confKChoicesMaxVS nodeConf else 1)]] tryQuery ns srcAddr nodeConf possibleJoinIDs where @@ -901,6 +901,12 @@ fediMessageHandler sendQ recvQ nodeSTM = do -- both of them fail concurrently_ (requestMapPurge (confResponsePurgeAge nodeConf) requestMap) $ forever $ do node <- readTVarIO nodeSTM + -- Messages from invalid (spoofed) sender IDs could already be dropped here + -- or in @dispatchVS@. But as the checking on each possible ID causes an + -- overhead, it is only done for critical operations and the case + -- differentiation is done in @handleIncomingRequest@. Thus the vserver + -- number limit, required for this check, needs to be passed to that function. + let handlerFunc = handleIncomingRequest $ confKChoicesMaxVS nodeConf -- wait for incoming messages (rawMsg, sourceAddr) <- atomically $ readTQueue recvQ let aMsg = deserialiseMessage rawMsg @@ -915,7 +921,7 @@ fediMessageHandler sendQ recvQ nodeSTM = do Nothing -> pure () -- if not a multipart message, handle immediately. Response is at the same time an ACK Just nsSTM | part aRequest == 1 && isFinalPart aRequest -> - forkIO (handleIncomingRequest nsSTM sendQ (Set.singleton aRequest) sourceAddr) >> pure () + forkIO (handlerFunc nsSTM sendQ (Set.singleton aRequest) sourceAddr) >> pure () -- otherwise collect all message parts first before handling the whole request Just nsSTM | otherwise -> do now <- getPOSIXTime @@ -942,7 +948,7 @@ fediMessageHandler sendQ recvQ nodeSTM = do (RequestMapEntry theseParts mayMaxParts _) = fromJust $ Map.lookup thisKey newMapState numParts = Set.size theseParts if maybe False (numParts ==) (fromIntegral <$> mayMaxParts) - then forkIO (handleIncomingRequest nsSTM sendQ theseParts sourceAddr) >> pure() + then forkIO (handlerFunc nsSTM sendQ theseParts sourceAddr) >> pure() else pure() -- Responses should never arrive on the main server port, as they are always -- responses to requests sent from dedicated sockets on another port diff --git a/src/Hash2Pub/FediChordTypes.hs b/src/Hash2Pub/FediChordTypes.hs index a20c156..2b9dbad 100644 --- a/src/Hash2Pub/FediChordTypes.hs +++ b/src/Hash2Pub/FediChordTypes.hs @@ -57,6 +57,7 @@ module Hash2Pub.FediChordTypes , localCompare , genNodeID , genNodeIDBS + , hasValidNodeId , genKeyID , genKeyIDBS , byteStringToUInteger @@ -190,7 +191,7 @@ data RemoteNodeState = RemoteNodeState -- ^ port of the DHT itself , servicePort :: PortNumber -- ^ port of the service provided on top of the DHT - , vServerID :: Integer + , vServerID :: Word8 -- ^ ID of this vserver } deriving (Show, Eq) @@ -235,14 +236,14 @@ class NodeState a where getIpAddr :: a -> HostAddress6 getDhtPort :: a -> PortNumber getServicePort :: a -> PortNumber - getVServerID :: a -> Integer + getVServerID :: a -> Word8 -- setters for common properties setNid :: NodeID -> a -> a setDomain :: String -> a -> a setIpAddr :: HostAddress6 -> a -> a setDhtPort :: PortNumber -> a -> a setServicePort :: PortNumber -> a -> a - setVServerID :: Integer -> a -> a + setVServerID :: Word8 -> a -> a toRemoteNodeState :: a -> RemoteNodeState instance NodeState RemoteNodeState where @@ -391,6 +392,11 @@ genNodeID :: HostAddress6 -- ^a node's IPv6 address -> NodeID -- ^the generated @NodeID@ genNodeID ip nodeDomain vs = NodeID . byteStringToUInteger $ genNodeIDBS ip nodeDomain vs + +hasValidNodeId :: Word8 -> RemoteNodeState -> HostAddress6 -> Bool +hasValidNodeId numVs rns addr = getVServerID rns < numVs && getNid rns == genNodeID addr (getDomain rns) (getVServerID rns) + + -- | generates a 256 bit long key identifier, represented as ByteString, for looking up its data on the DHT genKeyIDBS :: String -- ^the key string -> BS.ByteString -- ^the key ID represented as a @ByteString@ @@ -451,7 +457,7 @@ data FediChordConf = FediChordConf -- ^ fraction of capacity above which a node considers itself overloaded , confKChoicesUnderload :: Double -- ^ fraction of capacity below which a node considers itself underloaded - , confKChoicesMaxVS :: Integer + , confKChoicesMaxVS :: Word8 -- ^ upper limit of vserver index κ } deriving (Show, Eq)