Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prevent pad/unpad failures on large/small messages #547

Merged
merged 6 commits into from
Oct 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/Simplex/Messaging/Crypto.hs
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,8 @@ data CryptoError
| -- | message is larger that allowed padded length minus 2 (to prepend message length)
-- (or required un-padded length is larger than the message length)
CryptoLargeMsgError
| -- | padded message is shorter than 2 bytes
CryptoInvalidMsgError
| -- | failure parsing message header
CryptoHeaderError String
| -- | no sending chain key in ratchet state
Expand Down Expand Up @@ -802,18 +804,21 @@ decryptAEAD aesKey ivBytes ad msg (AuthTag authTag) = do
aead <- initAEAD @AES256 aesKey ivBytes
liftEither . unPad =<< maybeError AESDecryptError (AES.aeadSimpleDecrypt aead ad msg authTag)

maxMsgLen :: Int
maxMsgLen = 2 ^ (16 :: Int) - 3

pad :: ByteString -> Int -> Either CryptoError ByteString
pad msg paddedLen
| padLen >= 0 = Right $ encodeWord16 (fromIntegral len) <> msg <> B.replicate padLen '#'
| len <= maxMsgLen && padLen >= 0 = Right $ encodeWord16 (fromIntegral len) <> msg <> B.replicate padLen '#'
| otherwise = Left CryptoLargeMsgError
where
len = B.length msg
padLen = paddedLen - len - 2

unPad :: ByteString -> Either CryptoError ByteString
unPad padded
| B.length rest >= len = Right $ B.take len rest
| otherwise = Left CryptoLargeMsgError
| B.length lenWrd == 2 && B.length rest >= len = Right $ B.take len rest
| otherwise = Left CryptoInvalidMsgError
where
(lenWrd, rest) = B.splitAt 2 padded
len = fromIntegral $ decodeWord16 lenWrd
Expand Down
39 changes: 39 additions & 0 deletions tests/CoreTests/CryptoTests.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Redundant if" #-}
module CoreTests.CryptoTests (cryptoTests) where

import qualified Simplex.Messaging.Crypto as C
import Test.Hspec
import Test.Hspec.QuickCheck (modifyMaxSuccess)
import Test.QuickCheck
import qualified Data.Text as T
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy.Char8 as LB
import Data.Text.Encoding (encodeUtf8, decodeUtf8)

cryptoTests :: Spec
cryptoTests = modifyMaxSuccess (const 10000) $ do
describe "padding / unpadding" $ do
it "should pad / unpad string" . property $ \(s, paddedLen) ->
let b = encodeUtf8 $ T.pack s
len = B.length b
padded = C.pad b paddedLen
in if len < 2 ^ (16 :: Int) - 3 && len <= paddedLen - 2
then (fmap (T.unpack . decodeUtf8) . C.unPad =<< padded) == Right s
else padded == Left C.CryptoLargeMsgError
it "pad should fail on large string" $ do
C.pad "abc" 5 `shouldBe` Right "\000\003abc"
C.pad "abc" 4 `shouldBe` Left C.CryptoLargeMsgError
(C.unPad =<< C.pad (str 65533) 65535) `shouldBe` Right (str 65533)
C.pad (str 65534) 65536 `shouldBe` Left C.CryptoLargeMsgError
C.pad (str 65535) 65537 `shouldBe` Left C.CryptoLargeMsgError
it "unpad should fail on invalid string" $ do
C.unPad "\000\000" `shouldBe` Right ""
C.unPad "\000" `shouldBe` Left C.CryptoInvalidMsgError
C.unPad "" `shouldBe` Left C.CryptoInvalidMsgError
it "unpad should fail on shorter string" $ do
C.unPad "\000\003abc" `shouldBe` Right "abc"
C.unPad "\000\003ab" `shouldBe` Left C.CryptoInvalidMsgError
where
str n = LB.toStrict $ LB.take n $ LB.repeat 'a'
2 changes: 2 additions & 0 deletions tests/Test.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{-# LANGUAGE TypeApplications #-}

import AgentTests (agentTests)
import CoreTests.CryptoTests
import CoreTests.EncodingTests
import CoreTests.ProtocolErrorTests
import CoreTests.VersionRangeTests
Expand All @@ -22,6 +23,7 @@ main = do
describe "Encoding tests" encodingTests
describe "Protocol error tests" protocolErrorTests
describe "Version range" versionRangeTests
describe "Encryption tests" cryptoTests
describe "SMP server via TLS" $ serverTests (transport @TLS)
describe "SMP server via WebSockets" $ serverTests (transport @WS)
describe "Notifications server" $ ntfServerTests (transport @TLS)
Expand Down