diff --git a/src/Simplex/Messaging/Crypto.hs b/src/Simplex/Messaging/Crypto.hs index 8f2924dd0..c5d727aad 100644 --- a/src/Simplex/Messaging/Crypto.hs +++ b/src/Simplex/Messaging/Crypto.hs @@ -691,6 +691,8 @@ data CryptoError | -- | message is larger that allowed padded length minus 2 (to prepend message length) -- (or required un-padded length is larger than the message length) CryptoLargeMsgError + | -- | padded message is shorter than 2 bytes + CryptoInvalidMsgError | -- | failure parsing message header CryptoHeaderError String | -- | no sending chain key in ratchet state @@ -802,9 +804,12 @@ decryptAEAD aesKey ivBytes ad msg (AuthTag authTag) = do aead <- initAEAD @AES256 aesKey ivBytes liftEither . unPad =<< maybeError AESDecryptError (AES.aeadSimpleDecrypt aead ad msg authTag) +maxMsgLen :: Int +maxMsgLen = 2 ^ (16 :: Int) - 3 + pad :: ByteString -> Int -> Either CryptoError ByteString pad msg paddedLen - | padLen >= 0 = Right $ encodeWord16 (fromIntegral len) <> msg <> B.replicate padLen '#' + | len <= maxMsgLen && padLen >= 0 = Right $ encodeWord16 (fromIntegral len) <> msg <> B.replicate padLen '#' | otherwise = Left CryptoLargeMsgError where len = B.length msg @@ -812,8 +817,8 @@ pad msg paddedLen unPad :: ByteString -> Either CryptoError ByteString unPad padded - | B.length rest >= len = Right $ B.take len rest - | otherwise = Left CryptoLargeMsgError + | B.length lenWrd == 2 && B.length rest >= len = Right $ B.take len rest + | otherwise = Left CryptoInvalidMsgError where (lenWrd, rest) = B.splitAt 2 padded len = fromIntegral $ decodeWord16 lenWrd diff --git a/tests/CoreTests/CryptoTests.hs b/tests/CoreTests/CryptoTests.hs new file mode 100644 index 000000000..7913ae33b --- /dev/null +++ b/tests/CoreTests/CryptoTests.hs @@ -0,0 +1,39 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} +{-# HLINT ignore "Redundant if" #-} +module CoreTests.CryptoTests (cryptoTests) where + +import qualified Simplex.Messaging.Crypto as C +import Test.Hspec +import Test.Hspec.QuickCheck (modifyMaxSuccess) +import Test.QuickCheck +import qualified Data.Text as T +import qualified Data.ByteString.Char8 as B +import qualified Data.ByteString.Lazy.Char8 as LB +import Data.Text.Encoding (encodeUtf8, decodeUtf8) + +cryptoTests :: Spec +cryptoTests = modifyMaxSuccess (const 10000) $ do + describe "padding / unpadding" $ do + it "should pad / unpad string" . property $ \(s, paddedLen) -> + let b = encodeUtf8 $ T.pack s + len = B.length b + padded = C.pad b paddedLen + in if len < 2 ^ (16 :: Int) - 3 && len <= paddedLen - 2 + then (fmap (T.unpack . decodeUtf8) . C.unPad =<< padded) == Right s + else padded == Left C.CryptoLargeMsgError + it "pad should fail on large string" $ do + C.pad "abc" 5 `shouldBe` Right "\000\003abc" + C.pad "abc" 4 `shouldBe` Left C.CryptoLargeMsgError + (C.unPad =<< C.pad (str 65533) 65535) `shouldBe` Right (str 65533) + C.pad (str 65534) 65536 `shouldBe` Left C.CryptoLargeMsgError + C.pad (str 65535) 65537 `shouldBe` Left C.CryptoLargeMsgError + it "unpad should fail on invalid string" $ do + C.unPad "\000\000" `shouldBe` Right "" + C.unPad "\000" `shouldBe` Left C.CryptoInvalidMsgError + C.unPad "" `shouldBe` Left C.CryptoInvalidMsgError + it "unpad should fail on shorter string" $ do + C.unPad "\000\003abc" `shouldBe` Right "abc" + C.unPad "\000\003ab" `shouldBe` Left C.CryptoInvalidMsgError + where + str n = LB.toStrict $ LB.take n $ LB.repeat 'a' diff --git a/tests/Test.hs b/tests/Test.hs index d67277dd3..61e09a2b7 100644 --- a/tests/Test.hs +++ b/tests/Test.hs @@ -1,6 +1,7 @@ {-# LANGUAGE TypeApplications #-} import AgentTests (agentTests) +import CoreTests.CryptoTests import CoreTests.EncodingTests import CoreTests.ProtocolErrorTests import CoreTests.VersionRangeTests @@ -22,6 +23,7 @@ main = do describe "Encoding tests" encodingTests describe "Protocol error tests" protocolErrorTests describe "Version range" versionRangeTests + describe "Encryption tests" cryptoTests describe "SMP server via TLS" $ serverTests (transport @TLS) describe "SMP server via WebSockets" $ serverTests (transport @WS) describe "Notifications server" $ ntfServerTests (transport @TLS)