refactor request sender ID spoof check to suit k-choices
- mostly refactored the checks into its own function - now additionally check the vserver number limit - refactoring to pass that limit to the checking function invocations - closes #74
This commit is contained in:
parent
b2b4fe3dd8
commit
bb0fb0919a
|
@ -66,6 +66,7 @@ import Data.Maybe (fromJust, fromMaybe, isJust,
|
|||
isNothing, mapMaybe, maybe)
|
||||
import qualified Data.Set as Set
|
||||
import Data.Time.Clock.POSIX
|
||||
import Data.Word (Word8)
|
||||
import Network.Socket hiding (recv, recvFrom, send,
|
||||
sendTo)
|
||||
import Network.Socket.ByteString
|
||||
|
@ -93,6 +94,7 @@ import Hash2Pub.FediChordTypes (CacheEntry (..),
|
|||
getKeyID, localCompare,
|
||||
rMapFromList, rMapLookupPred,
|
||||
rMapLookupSucc,
|
||||
hasValidNodeId,
|
||||
remainingLoadTarget,
|
||||
setPredecessors, setSuccessors)
|
||||
import Hash2Pub.ProtocolTypes
|
||||
|
@ -267,12 +269,13 @@ extractFirstPayload msgSet = foldr' (\msg plAcc ->
|
|||
-- | Dispatch incoming requests to the dedicated handling and response function, and enqueue
|
||||
-- the response to be sent.
|
||||
handleIncomingRequest :: Service s (RealNodeSTM s)
|
||||
=> LocalNodeStateSTM s -- ^ the handling node
|
||||
=> Word8 -- ^ maximum number of vservers, because of decision to @dropSpoofedIDs@ in here and not already in @fediMessageHandler@
|
||||
-> LocalNodeStateSTM s -- ^ the handling node
|
||||
-> TQueue (BS.ByteString, SockAddr) -- ^ send queue
|
||||
-> Set.Set FediChordMessage -- ^ all parts of the request to handle
|
||||
-> SockAddr -- ^ source address of the request
|
||||
-> IO ()
|
||||
handleIncomingRequest nsSTM sendQ msgSet sourceAddr = do
|
||||
handleIncomingRequest vsLimit nsSTM sendQ msgSet sourceAddr = do
|
||||
ns <- readTVarIO nsSTM
|
||||
-- add nodestate to cache
|
||||
now <- getPOSIXTime
|
||||
|
@ -287,12 +290,12 @@ handleIncomingRequest nsSTM sendQ msgSet sourceAddr = do
|
|||
)
|
||||
=<< (case action aPart of
|
||||
Ping -> Just <$> respondPing nsSTM msgSet
|
||||
Join -> dropSpoofedIDs sourceIP nsSTM msgSet respondJoin
|
||||
Join -> dropSpoofedIDs vsLimit sourceIP nsSTM msgSet respondJoin
|
||||
-- ToDo: figure out what happens if not joined
|
||||
QueryID -> Just <$> respondQueryID nsSTM msgSet
|
||||
-- only when joined
|
||||
Leave -> if vsIsJoined ns then dropSpoofedIDs sourceIP nsSTM msgSet respondLeave else pure Nothing
|
||||
Stabilise -> if vsIsJoined ns then dropSpoofedIDs sourceIP nsSTM msgSet respondStabilise else pure Nothing
|
||||
Leave -> if vsIsJoined ns then dropSpoofedIDs vsLimit sourceIP nsSTM msgSet respondLeave else pure Nothing
|
||||
Stabilise -> if vsIsJoined ns then dropSpoofedIDs vsLimit sourceIP nsSTM msgSet respondStabilise else pure Nothing
|
||||
QueryLoad -> if vsIsJoined ns then Just <$> respondQueryLoad nsSTM msgSet else pure Nothing
|
||||
)
|
||||
-- for single part request, response starts with part number 1. For multipart requests, response starts with part number n+1.
|
||||
|
@ -303,19 +306,18 @@ handleIncomingRequest nsSTM sendQ msgSet sourceAddr = do
|
|||
-- | Filter out requests with spoofed node IDs by recomputing the ID using
|
||||
-- the sender IP.
|
||||
-- For valid (non-spoofed) sender IDs, the passed responder function is invoked.
|
||||
dropSpoofedIDs :: HostAddress6 -- msg source address
|
||||
dropSpoofedIDs :: Word8 -- ^ maximum number of vservers per node
|
||||
-> HostAddress6 -- ^ msg source address
|
||||
-> LocalNodeStateSTM s
|
||||
-> Set.Set FediChordMessage -- message parts of the request
|
||||
-> (LocalNodeStateSTM s -> Set.Set FediChordMessage -> IO (Map.Map Integer BS.ByteString)) -- reponder function to be invoked for valid requests
|
||||
-> Set.Set FediChordMessage -- ^ message parts of the request
|
||||
-> (LocalNodeStateSTM s -> Set.Set FediChordMessage -> IO (Map.Map Integer BS.ByteString)) -- ^ reponder function to be invoked for valid requests
|
||||
-> IO (Maybe (Map.Map Integer BS.ByteString))
|
||||
dropSpoofedIDs addr nsSTM' msgSet' responder =
|
||||
dropSpoofedIDs limVs addr nsSTM' msgSet' responder =
|
||||
let
|
||||
aRequestPart = Set.elemAt 0 msgSet
|
||||
senderNs = sender aRequestPart
|
||||
givenSenderID = getNid senderNs
|
||||
recomputedID = genNodeID addr (getDomain senderNs) (getVServerID senderNs)
|
||||
in
|
||||
if recomputedID == givenSenderID
|
||||
if hasValidNodeId limVs senderNs addr
|
||||
then Just <$> responder nsSTM' msgSet'
|
||||
else pure Nothing
|
||||
|
||||
|
@ -779,8 +781,7 @@ requestPing ns target = do
|
|||
-- recompute ID for each received node and mark as verified in cache
|
||||
now <- getPOSIXTime
|
||||
forM_ responseVss (\vs ->
|
||||
let recomputedID = genNodeID peerAddr (getDomain vs) (getVServerID vs)
|
||||
in if recomputedID == getNid vs
|
||||
if hasValidNodeId (confKChoicesMaxVS nodeConf) vs peerAddr
|
||||
then atomically $ writeTQueue (cacheWriteQueue ns) $ addNodeAsVerifiedPure now vs
|
||||
else pure ()
|
||||
)
|
||||
|
|
|
@ -901,6 +901,12 @@ fediMessageHandler sendQ recvQ nodeSTM = do
|
|||
-- both of them fail
|
||||
concurrently_ (requestMapPurge (confResponsePurgeAge nodeConf) requestMap) $ forever $ do
|
||||
node <- readTVarIO nodeSTM
|
||||
-- Messages from invalid (spoofed) sender IDs could already be dropped here
|
||||
-- or in @dispatchVS@. But as the checking on each possible ID causes an
|
||||
-- overhead, it is only done for critical operations and the case
|
||||
-- differentiation is done in @handleIncomingRequest@. Thus the vserver
|
||||
-- number limit, required for this check, needs to be passed to that function.
|
||||
let handlerFunc = handleIncomingRequest $ confKChoicesMaxVS nodeConf
|
||||
-- wait for incoming messages
|
||||
(rawMsg, sourceAddr) <- atomically $ readTQueue recvQ
|
||||
let aMsg = deserialiseMessage rawMsg
|
||||
|
@ -915,7 +921,7 @@ fediMessageHandler sendQ recvQ nodeSTM = do
|
|||
Nothing -> pure ()
|
||||
-- if not a multipart message, handle immediately. Response is at the same time an ACK
|
||||
Just nsSTM | part aRequest == 1 && isFinalPart aRequest ->
|
||||
forkIO (handleIncomingRequest nsSTM sendQ (Set.singleton aRequest) sourceAddr) >> pure ()
|
||||
forkIO (handlerFunc nsSTM sendQ (Set.singleton aRequest) sourceAddr) >> pure ()
|
||||
-- otherwise collect all message parts first before handling the whole request
|
||||
Just nsSTM | otherwise -> do
|
||||
now <- getPOSIXTime
|
||||
|
@ -942,7 +948,7 @@ fediMessageHandler sendQ recvQ nodeSTM = do
|
|||
(RequestMapEntry theseParts mayMaxParts _) = fromJust $ Map.lookup thisKey newMapState
|
||||
numParts = Set.size theseParts
|
||||
if maybe False (numParts ==) (fromIntegral <$> mayMaxParts)
|
||||
then forkIO (handleIncomingRequest nsSTM sendQ theseParts sourceAddr) >> pure()
|
||||
then forkIO (handlerFunc nsSTM sendQ theseParts sourceAddr) >> pure()
|
||||
else pure()
|
||||
-- Responses should never arrive on the main server port, as they are always
|
||||
-- responses to requests sent from dedicated sockets on another port
|
||||
|
|
|
@ -57,6 +57,7 @@ module Hash2Pub.FediChordTypes
|
|||
, localCompare
|
||||
, genNodeID
|
||||
, genNodeIDBS
|
||||
, hasValidNodeId
|
||||
, genKeyID
|
||||
, genKeyIDBS
|
||||
, byteStringToUInteger
|
||||
|
@ -392,8 +393,8 @@ genNodeID :: HostAddress6 -- ^a node's IPv6 address
|
|||
genNodeID ip nodeDomain vs = NodeID . byteStringToUInteger $ genNodeIDBS ip nodeDomain vs
|
||||
|
||||
|
||||
isValidIdForNode :: Word8 -> RemoteNodeState -> HostAddress6 -> Bool
|
||||
isValidIdForNode numVs rns addr = getNid rns `elem` [genNodeID addr (getDomain rns) v | v <- [0..(numVs-1)] ]
|
||||
hasValidNodeId :: Word8 -> RemoteNodeState -> HostAddress6 -> Bool
|
||||
hasValidNodeId numVs rns addr = getVServerID rns < numVs && getNid rns == genNodeID addr (getDomain rns) (getVServerID rns)
|
||||
|
||||
|
||||
-- | generates a 256 bit long key identifier, represented as ByteString, for looking up its data on the DHT
|
||||
|
|
Loading…
Reference in a new issue