refactor request sender ID spoof check to suit k-choices

- mostly refactored the checks into its own function
- now additionally check the vserver number limit
- refactoring to pass that limit to the checking function invocations
- closes #74
This commit is contained in:
Trolli Schmittlauch 2020-09-29 02:59:42 +02:00
parent b2b4fe3dd8
commit bb0fb0919a
3 changed files with 28 additions and 20 deletions

View file

@ -66,6 +66,7 @@ import Data.Maybe (fromJust, fromMaybe, isJust,
isNothing, mapMaybe, maybe) isNothing, mapMaybe, maybe)
import qualified Data.Set as Set import qualified Data.Set as Set
import Data.Time.Clock.POSIX import Data.Time.Clock.POSIX
import Data.Word (Word8)
import Network.Socket hiding (recv, recvFrom, send, import Network.Socket hiding (recv, recvFrom, send,
sendTo) sendTo)
import Network.Socket.ByteString import Network.Socket.ByteString
@ -93,6 +94,7 @@ import Hash2Pub.FediChordTypes (CacheEntry (..),
getKeyID, localCompare, getKeyID, localCompare,
rMapFromList, rMapLookupPred, rMapFromList, rMapLookupPred,
rMapLookupSucc, rMapLookupSucc,
hasValidNodeId,
remainingLoadTarget, remainingLoadTarget,
setPredecessors, setSuccessors) setPredecessors, setSuccessors)
import Hash2Pub.ProtocolTypes import Hash2Pub.ProtocolTypes
@ -267,12 +269,13 @@ extractFirstPayload msgSet = foldr' (\msg plAcc ->
-- | Dispatch incoming requests to the dedicated handling and response function, and enqueue -- | Dispatch incoming requests to the dedicated handling and response function, and enqueue
-- the response to be sent. -- the response to be sent.
handleIncomingRequest :: Service s (RealNodeSTM s) handleIncomingRequest :: Service s (RealNodeSTM s)
=> LocalNodeStateSTM s -- ^ the handling node => Word8 -- ^ maximum number of vservers, because of decision to @dropSpoofedIDs@ in here and not already in @fediMessageHandler@
-> LocalNodeStateSTM s -- ^ the handling node
-> TQueue (BS.ByteString, SockAddr) -- ^ send queue -> TQueue (BS.ByteString, SockAddr) -- ^ send queue
-> Set.Set FediChordMessage -- ^ all parts of the request to handle -> Set.Set FediChordMessage -- ^ all parts of the request to handle
-> SockAddr -- ^ source address of the request -> SockAddr -- ^ source address of the request
-> IO () -> IO ()
handleIncomingRequest nsSTM sendQ msgSet sourceAddr = do handleIncomingRequest vsLimit nsSTM sendQ msgSet sourceAddr = do
ns <- readTVarIO nsSTM ns <- readTVarIO nsSTM
-- add nodestate to cache -- add nodestate to cache
now <- getPOSIXTime now <- getPOSIXTime
@ -287,12 +290,12 @@ handleIncomingRequest nsSTM sendQ msgSet sourceAddr = do
) )
=<< (case action aPart of =<< (case action aPart of
Ping -> Just <$> respondPing nsSTM msgSet Ping -> Just <$> respondPing nsSTM msgSet
Join -> dropSpoofedIDs sourceIP nsSTM msgSet respondJoin Join -> dropSpoofedIDs vsLimit sourceIP nsSTM msgSet respondJoin
-- ToDo: figure out what happens if not joined -- ToDo: figure out what happens if not joined
QueryID -> Just <$> respondQueryID nsSTM msgSet QueryID -> Just <$> respondQueryID nsSTM msgSet
-- only when joined -- only when joined
Leave -> if vsIsJoined ns then dropSpoofedIDs sourceIP nsSTM msgSet respondLeave else pure Nothing Leave -> if vsIsJoined ns then dropSpoofedIDs vsLimit sourceIP nsSTM msgSet respondLeave else pure Nothing
Stabilise -> if vsIsJoined ns then dropSpoofedIDs sourceIP nsSTM msgSet respondStabilise else pure Nothing Stabilise -> if vsIsJoined ns then dropSpoofedIDs vsLimit sourceIP nsSTM msgSet respondStabilise else pure Nothing
QueryLoad -> if vsIsJoined ns then Just <$> respondQueryLoad nsSTM msgSet else pure Nothing QueryLoad -> if vsIsJoined ns then Just <$> respondQueryLoad nsSTM msgSet else pure Nothing
) )
-- for single part request, response starts with part number 1. For multipart requests, response starts with part number n+1. -- for single part request, response starts with part number 1. For multipart requests, response starts with part number n+1.
@ -303,19 +306,18 @@ handleIncomingRequest nsSTM sendQ msgSet sourceAddr = do
-- | Filter out requests with spoofed node IDs by recomputing the ID using -- | Filter out requests with spoofed node IDs by recomputing the ID using
-- the sender IP. -- the sender IP.
-- For valid (non-spoofed) sender IDs, the passed responder function is invoked. -- For valid (non-spoofed) sender IDs, the passed responder function is invoked.
dropSpoofedIDs :: HostAddress6 -- msg source address dropSpoofedIDs :: Word8 -- ^ maximum number of vservers per node
-> HostAddress6 -- ^ msg source address
-> LocalNodeStateSTM s -> LocalNodeStateSTM s
-> Set.Set FediChordMessage -- message parts of the request -> Set.Set FediChordMessage -- ^ message parts of the request
-> (LocalNodeStateSTM s -> Set.Set FediChordMessage -> IO (Map.Map Integer BS.ByteString)) -- reponder function to be invoked for valid requests -> (LocalNodeStateSTM s -> Set.Set FediChordMessage -> IO (Map.Map Integer BS.ByteString)) -- ^ reponder function to be invoked for valid requests
-> IO (Maybe (Map.Map Integer BS.ByteString)) -> IO (Maybe (Map.Map Integer BS.ByteString))
dropSpoofedIDs addr nsSTM' msgSet' responder = dropSpoofedIDs limVs addr nsSTM' msgSet' responder =
let let
aRequestPart = Set.elemAt 0 msgSet aRequestPart = Set.elemAt 0 msgSet
senderNs = sender aRequestPart senderNs = sender aRequestPart
givenSenderID = getNid senderNs
recomputedID = genNodeID addr (getDomain senderNs) (getVServerID senderNs)
in in
if recomputedID == givenSenderID if hasValidNodeId limVs senderNs addr
then Just <$> responder nsSTM' msgSet' then Just <$> responder nsSTM' msgSet'
else pure Nothing else pure Nothing
@ -779,8 +781,7 @@ requestPing ns target = do
-- recompute ID for each received node and mark as verified in cache -- recompute ID for each received node and mark as verified in cache
now <- getPOSIXTime now <- getPOSIXTime
forM_ responseVss (\vs -> forM_ responseVss (\vs ->
let recomputedID = genNodeID peerAddr (getDomain vs) (getVServerID vs) if hasValidNodeId (confKChoicesMaxVS nodeConf) vs peerAddr
in if recomputedID == getNid vs
then atomically $ writeTQueue (cacheWriteQueue ns) $ addNodeAsVerifiedPure now vs then atomically $ writeTQueue (cacheWriteQueue ns) $ addNodeAsVerifiedPure now vs
else pure () else pure ()
) )

View file

@ -901,6 +901,12 @@ fediMessageHandler sendQ recvQ nodeSTM = do
-- both of them fail -- both of them fail
concurrently_ (requestMapPurge (confResponsePurgeAge nodeConf) requestMap) $ forever $ do concurrently_ (requestMapPurge (confResponsePurgeAge nodeConf) requestMap) $ forever $ do
node <- readTVarIO nodeSTM node <- readTVarIO nodeSTM
-- Messages from invalid (spoofed) sender IDs could already be dropped here
-- or in @dispatchVS@. But as the checking on each possible ID causes an
-- overhead, it is only done for critical operations and the case
-- differentiation is done in @handleIncomingRequest@. Thus the vserver
-- number limit, required for this check, needs to be passed to that function.
let handlerFunc = handleIncomingRequest $ confKChoicesMaxVS nodeConf
-- wait for incoming messages -- wait for incoming messages
(rawMsg, sourceAddr) <- atomically $ readTQueue recvQ (rawMsg, sourceAddr) <- atomically $ readTQueue recvQ
let aMsg = deserialiseMessage rawMsg let aMsg = deserialiseMessage rawMsg
@ -915,7 +921,7 @@ fediMessageHandler sendQ recvQ nodeSTM = do
Nothing -> pure () Nothing -> pure ()
-- if not a multipart message, handle immediately. Response is at the same time an ACK -- if not a multipart message, handle immediately. Response is at the same time an ACK
Just nsSTM | part aRequest == 1 && isFinalPart aRequest -> Just nsSTM | part aRequest == 1 && isFinalPart aRequest ->
forkIO (handleIncomingRequest nsSTM sendQ (Set.singleton aRequest) sourceAddr) >> pure () forkIO (handlerFunc nsSTM sendQ (Set.singleton aRequest) sourceAddr) >> pure ()
-- otherwise collect all message parts first before handling the whole request -- otherwise collect all message parts first before handling the whole request
Just nsSTM | otherwise -> do Just nsSTM | otherwise -> do
now <- getPOSIXTime now <- getPOSIXTime
@ -942,7 +948,7 @@ fediMessageHandler sendQ recvQ nodeSTM = do
(RequestMapEntry theseParts mayMaxParts _) = fromJust $ Map.lookup thisKey newMapState (RequestMapEntry theseParts mayMaxParts _) = fromJust $ Map.lookup thisKey newMapState
numParts = Set.size theseParts numParts = Set.size theseParts
if maybe False (numParts ==) (fromIntegral <$> mayMaxParts) if maybe False (numParts ==) (fromIntegral <$> mayMaxParts)
then forkIO (handleIncomingRequest nsSTM sendQ theseParts sourceAddr) >> pure() then forkIO (handlerFunc nsSTM sendQ theseParts sourceAddr) >> pure()
else pure() else pure()
-- Responses should never arrive on the main server port, as they are always -- Responses should never arrive on the main server port, as they are always
-- responses to requests sent from dedicated sockets on another port -- responses to requests sent from dedicated sockets on another port

View file

@ -57,6 +57,7 @@ module Hash2Pub.FediChordTypes
, localCompare , localCompare
, genNodeID , genNodeID
, genNodeIDBS , genNodeIDBS
, hasValidNodeId
, genKeyID , genKeyID
, genKeyIDBS , genKeyIDBS
, byteStringToUInteger , byteStringToUInteger
@ -392,8 +393,8 @@ genNodeID :: HostAddress6 -- ^a node's IPv6 address
genNodeID ip nodeDomain vs = NodeID . byteStringToUInteger $ genNodeIDBS ip nodeDomain vs genNodeID ip nodeDomain vs = NodeID . byteStringToUInteger $ genNodeIDBS ip nodeDomain vs
isValidIdForNode :: Word8 -> RemoteNodeState -> HostAddress6 -> Bool hasValidNodeId :: Word8 -> RemoteNodeState -> HostAddress6 -> Bool
isValidIdForNode numVs rns addr = getNid rns `elem` [genNodeID addr (getDomain rns) v | v <- [0..(numVs-1)] ] hasValidNodeId numVs rns addr = getVServerID rns < numVs && getNid rns == genNodeID addr (getDomain rns) (getVServerID rns)
-- | generates a 256 bit long key identifier, represented as ByteString, for looking up its data on the DHT -- | generates a 256 bit long key identifier, represented as ByteString, for looking up its data on the DHT