diff --git a/integrationTests/chainSimulator/relayedTx/relayedTx_test.go b/integrationTests/chainSimulator/relayedTx/relayedTx_test.go index b7b74a66944..b7426c63f5f 100644 --- a/integrationTests/chainSimulator/relayedTx/relayedTx_test.go +++ b/integrationTests/chainSimulator/relayedTx/relayedTx_test.go @@ -48,6 +48,8 @@ var ( ) func TestRelayedV3WithChainSimulator(t *testing.T) { + t.Run("sender == relayer move balance should consume fee", testRelayedV3RelayedBySenderMoveBalance()) + t.Run("receiver == relayer move balance should consume fee", testRelayedV3RelayedByReceiverMoveBalance()) t.Run("successful intra shard move balance", testRelayedV3MoveBalance(0, 0, false, false)) t.Run("successful intra shard guarded move balance", testRelayedV3MoveBalance(0, 0, false, true)) t.Run("successful intra shard move balance with extra gas", testRelayedV3MoveBalance(0, 0, true, false)) @@ -59,10 +61,12 @@ func TestRelayedV3WithChainSimulator(t *testing.T) { t.Run("intra shard move balance, invalid gas", testRelayedV3MoveInvalidGasLimit(0, 0)) t.Run("cross shard move balance, invalid gas", testRelayedV3MoveInvalidGasLimit(0, 1)) - t.Run("successful intra shard sc call with refunds, existing sender", testRelayedV3ScCall(0, 0, true)) - t.Run("successful intra shard sc call with refunds, new sender", testRelayedV3ScCall(0, 0, false)) - t.Run("successful cross shard sc call with refunds, existing sender", testRelayedV3ScCall(0, 1, true)) - t.Run("successful cross shard sc call with refunds, new sender", testRelayedV3ScCall(0, 1, false)) + t.Run("successful intra shard sc call with refunds, existing sender", testRelayedV3ScCall(0, 0, true, false)) + t.Run("successful intra shard sc call with refunds, existing sender, relayed by sender", testRelayedV3ScCall(0, 0, true, true)) + t.Run("successful intra shard sc call with refunds, new sender", testRelayedV3ScCall(0, 0, false, false)) + t.Run("successful cross shard sc call with refunds, existing sender", testRelayedV3ScCall(0, 1, true, false)) + t.Run("successful cross shard sc call with refunds, existing sender, relayed by sender", testRelayedV3ScCall(0, 1, true, true)) + t.Run("successful cross shard sc call with refunds, new sender", testRelayedV3ScCall(0, 1, false, false)) t.Run("intra shard sc call, invalid gas", testRelayedV3ScCallInvalidGasLimit(0, 0)) t.Run("cross shard sc call, invalid gas", testRelayedV3ScCallInvalidGasLimit(0, 1)) t.Run("intra shard sc call, invalid method", testRelayedV3ScCallInvalidMethod(0, 0)) @@ -279,6 +283,7 @@ func testRelayedV3ScCall( relayerShard uint32, ownerShard uint32, existingSenderWithBalance bool, + relayedBySender bool, ) func(t *testing.T) { return func(t *testing.T) { if testing.Short() { @@ -297,8 +302,13 @@ func testRelayedV3ScCall( initialBalance := big.NewInt(0).Mul(oneEGLD, big.NewInt(10)) relayer, err := cs.GenerateAndMintWalletAddress(relayerShard, initialBalance) require.NoError(t, err) + relayerInitialBalance := initialBalance sender, senderInitialBalance := prepareSender(t, cs, existingSenderWithBalance, relayerShard, initialBalance) + if relayedBySender { + relayer = sender + relayerInitialBalance = senderInitialBalance + } owner, err := cs.GenerateAndMintWalletAddress(ownerShard, initialBalance) require.NoError(t, err) @@ -336,12 +346,14 @@ func testRelayedV3ScCall( // check relayer balance relayerBalanceAfter := getBalance(t, cs, relayer) - relayerFee := big.NewInt(0).Sub(initialBalance, relayerBalanceAfter) + relayerFee := big.NewInt(0).Sub(relayerInitialBalance, relayerBalanceAfter) require.Equal(t, fee.String(), relayerFee.String()) - // check sender balance - senderBalanceAfter := getBalance(t, cs, sender) - require.Equal(t, senderInitialBalance.String(), senderBalanceAfter.String()) + // check sender balance, only if the tx was not relayed by sender + if !relayedBySender { + senderBalanceAfter := getBalance(t, cs, sender) + require.Equal(t, senderInitialBalance.String(), senderBalanceAfter.String()) + } // check owner balance _, feeDeploy, _ := computeTxGasAndFeeBasedOnRefund(resultDeploy, refundDeploy, false, false) @@ -357,6 +369,119 @@ func testRelayedV3ScCall( } } +func testRelayedV3RelayedBySenderMoveBalance() func(t *testing.T) { + return func(t *testing.T) { + if testing.Short() { + t.Skip("this is not a short test") + } + + providedActivationEpoch := uint32(1) + alterConfigsFunc := func(cfg *config.Configs) { + cfg.EpochConfig.EnableEpochs.FixRelayedBaseCostEnableEpoch = providedActivationEpoch + cfg.EpochConfig.EnableEpochs.RelayedTransactionsV3EnableEpoch = providedActivationEpoch + } + + cs := startChainSimulator(t, alterConfigsFunc) + defer cs.Close() + + initialBalance := big.NewInt(0).Mul(oneEGLD, big.NewInt(10)) + + sender, err := cs.GenerateAndMintWalletAddress(0, initialBalance) + require.NoError(t, err) + + // generate one block so the minting has effect + err = cs.GenerateBlocks(1) + require.NoError(t, err) + + senderNonce := uint64(0) + senderBalanceBefore := getBalance(t, cs, sender) + + gasLimit := minGasLimit * 2 + relayedTx := generateRelayedV3Transaction(sender.Bytes, senderNonce, sender.Bytes, sender.Bytes, big.NewInt(0), "", uint64(gasLimit)) + + result, err := cs.SendTxAndGenerateBlockTilTxIsExecuted(relayedTx, maxNumOfBlocksToGenerateWhenExecutingTx) + require.NoError(t, err) + + // check fee fields + initiallyPaidFee, fee, gasUsed := computeTxGasAndFeeBasedOnRefund(result, big.NewInt(0), true, false) + require.Equal(t, initiallyPaidFee.String(), result.InitiallyPaidFee) + require.Equal(t, fee.String(), result.Fee) + require.Equal(t, gasUsed, result.GasUsed) + + // check sender balance + expectedFee := core.SafeMul(uint64(gasLimit), uint64(minGasPrice)) + senderBalanceAfter := getBalance(t, cs, sender) + senderBalanceDiff := big.NewInt(0).Sub(senderBalanceBefore, senderBalanceAfter) + require.Equal(t, expectedFee.String(), senderBalanceDiff.String()) + + // check scrs, should be none + require.Zero(t, len(result.SmartContractResults)) + + // check intra shard logs, should be none + require.Nil(t, result.Logs) + } +} + +func testRelayedV3RelayedByReceiverMoveBalance() func(t *testing.T) { + return func(t *testing.T) { + if testing.Short() { + t.Skip("this is not a short test") + } + + providedActivationEpoch := uint32(1) + alterConfigsFunc := func(cfg *config.Configs) { + cfg.EpochConfig.EnableEpochs.FixRelayedBaseCostEnableEpoch = providedActivationEpoch + cfg.EpochConfig.EnableEpochs.RelayedTransactionsV3EnableEpoch = providedActivationEpoch + } + + cs := startChainSimulator(t, alterConfigsFunc) + defer cs.Close() + + initialBalance := big.NewInt(0).Mul(oneEGLD, big.NewInt(10)) + + sender, err := cs.GenerateAndMintWalletAddress(0, initialBalance) + require.NoError(t, err) + + receiver, err := cs.GenerateAndMintWalletAddress(0, initialBalance) + require.NoError(t, err) + + // generate one block so the minting has effect + err = cs.GenerateBlocks(1) + require.NoError(t, err) + + senderNonce := uint64(0) + receiverBalanceBefore := getBalance(t, cs, receiver) + + gasLimit := minGasLimit * 2 + relayedTx := generateRelayedV3Transaction(sender.Bytes, senderNonce, receiver.Bytes, receiver.Bytes, big.NewInt(0), "", uint64(gasLimit)) + + result, err := cs.SendTxAndGenerateBlockTilTxIsExecuted(relayedTx, maxNumOfBlocksToGenerateWhenExecutingTx) + require.NoError(t, err) + + // check fee fields + initiallyPaidFee, fee, gasUsed := computeTxGasAndFeeBasedOnRefund(result, big.NewInt(0), true, false) + require.Equal(t, initiallyPaidFee.String(), result.InitiallyPaidFee) + require.Equal(t, fee.String(), result.Fee) + require.Equal(t, gasUsed, result.GasUsed) + + // check sender balance + senderBalanceAfter := getBalance(t, cs, sender) + require.Equal(t, senderBalanceAfter.String(), initialBalance.String()) + + // check receiver balance + expectedFee := core.SafeMul(uint64(gasLimit), uint64(minGasPrice)) + receiverBalanceAfter := getBalance(t, cs, receiver) + receiverBalanceDiff := big.NewInt(0).Sub(receiverBalanceBefore, receiverBalanceAfter) + require.Equal(t, receiverBalanceDiff.String(), expectedFee.String()) + + // check scrs, should be none + require.Zero(t, len(result.SmartContractResults)) + + // check intra shard logs, should be none + require.Nil(t, result.Logs) + } +} + func prepareSender( t *testing.T, cs testsChainSimulator.ChainSimulator, diff --git a/process/smartContract/processorV2/processV2.go b/process/smartContract/processorV2/processV2.go index 31513ecea3b..e0d88916a52 100644 --- a/process/smartContract/processorV2/processV2.go +++ b/process/smartContract/processorV2/processV2.go @@ -1943,14 +1943,16 @@ func (sc *scProcessor) processSCPayment(tx data.TransactionHandler, acntSnd stat } fee := sc.economicsFee.ComputeTxFee(tx) - err = feePayer.SubFromBalance(fee) - if err != nil { - return err - } + if !check.IfNil(feePayer) { + err = feePayer.SubFromBalance(fee) + if err != nil { + return err + } - err = sc.saveAccount(feePayer) - if err != nil { - return err + err = sc.saveAccount(feePayer) + if err != nil { + return err + } } err = acntSnd.SubFromBalance(tx.GetValue()) @@ -1971,6 +1973,11 @@ func (sc *scProcessor) getFeePayer(tx data.TransactionHandler, acntSnd state.Use return acntSnd, nil } + relayerIsSender := bytes.Equal(relayedTx.GetRelayerAddr(), tx.GetSndAddr()) + if relayerIsSender { + return acntSnd, nil // do not load the same account twice + } + account, err := sc.getAccountFromAddress(relayedTx.GetRelayerAddr()) if err != nil { return nil, err diff --git a/process/transaction/baseProcess.go b/process/transaction/baseProcess.go index 73af11f5063..0433f8a0f50 100644 --- a/process/transaction/baseProcess.go +++ b/process/transaction/baseProcess.go @@ -239,10 +239,21 @@ func (txProc *baseTxProcessor) checkUserTxOfRelayedV3Values( func (txProc *baseTxProcessor) getFeePayer( tx *transaction.Transaction, - acntSnd state.UserAccountHandler, + senderAccount state.UserAccountHandler, + destinationAccount state.UserAccountHandler, ) (state.UserAccountHandler, bool, error) { if !common.IsRelayedTxV3(tx) { - return acntSnd, false, nil + return senderAccount, false, nil + } + + relayerIsSender := bytes.Equal(tx.RelayerAddr, tx.SndAddr) + if relayerIsSender { + return senderAccount, true, nil // do not load the same account twice + } + + relayerIsDestination := bytes.Equal(tx.RelayerAddr, tx.RcvAddr) + if relayerIsDestination { + return destinationAccount, true, nil // do not load the same account twice } acntRelayer, err := txProc.getAccountFromAddress(tx.RelayerAddr) diff --git a/process/transaction/shardProcess.go b/process/transaction/shardProcess.go index 115e0771534..a9b4d4d68b8 100644 --- a/process/transaction/shardProcess.go +++ b/process/transaction/shardProcess.go @@ -200,14 +200,14 @@ func (txProc *txProcessor) ProcessTransaction(tx *transaction.Transaction) (vmco err = txProc.checkTxValues(tx, acntSnd, acntDst, false) if err != nil { if errors.Is(err, process.ErrInsufficientFunds) { - receiptErr := txProc.executingFailedTransaction(tx, acntSnd, err) + receiptErr := txProc.executingFailedTransaction(tx, acntSnd, acntDst, err) if receiptErr != nil { return 0, receiptErr } } if errors.Is(err, process.ErrUserNameDoesNotMatch) && txProc.enableEpochsHandler.IsFlagEnabled(common.RelayedTransactionsFlag) { - receiptErr := txProc.executingFailedTransaction(tx, acntSnd, err) + receiptErr := txProc.executingFailedTransaction(tx, acntSnd, acntDst, err) if receiptErr != nil { return vmcommon.UserError, receiptErr } @@ -249,7 +249,7 @@ func (txProc *txProcessor) ProcessTransaction(tx *transaction.Transaction) (vmco return txProc.processRelayedTxV2(tx, acntSnd, acntDst) } - return vmcommon.UserError, txProc.executingFailedTransaction(tx, acntSnd, process.ErrWrongTransaction) + return vmcommon.UserError, txProc.executingFailedTransaction(tx, acntSnd, acntDst, process.ErrWrongTransaction) } func (txProc *txProcessor) executeAfterFailedMoveBalanceTransaction( @@ -296,13 +296,14 @@ func (txProc *txProcessor) executeAfterFailedMoveBalanceTransaction( func (txProc *txProcessor) executingFailedTransaction( tx *transaction.Transaction, acntSnd state.UserAccountHandler, + acntDst state.UserAccountHandler, txError error, ) error { if check.IfNil(acntSnd) { return nil } - feePayer, isRelayedV3, err := txProc.getFeePayer(tx, acntSnd) + feePayer, isRelayedV3, err := txProc.getFeePayer(tx, acntSnd, acntDst) if err != nil { return err } @@ -486,7 +487,7 @@ func (txProc *txProcessor) processMoveBalance( isUserTxOfRelayed bool, ) error { - feePayer, _, err := txProc.getFeePayer(tx, acntSrc) + feePayer, _, err := txProc.getFeePayer(tx, acntSrc, acntDst) if err != nil { return nil } @@ -735,18 +736,18 @@ func (txProc *txProcessor) processRelayedTxV2( relayerAcnt, acntDst state.UserAccountHandler, ) (vmcommon.ReturnCode, error) { if !txProc.enableEpochsHandler.IsFlagEnabled(common.RelayedTransactionsV2Flag) { - return vmcommon.UserError, txProc.executingFailedTransaction(tx, relayerAcnt, process.ErrRelayedTxV2Disabled) + return vmcommon.UserError, txProc.executingFailedTransaction(tx, relayerAcnt, acntDst, process.ErrRelayedTxV2Disabled) } if tx.GetValue().Cmp(big.NewInt(0)) != 0 { - return vmcommon.UserError, txProc.executingFailedTransaction(tx, relayerAcnt, process.ErrRelayedTxV2ZeroVal) + return vmcommon.UserError, txProc.executingFailedTransaction(tx, relayerAcnt, acntDst, process.ErrRelayedTxV2ZeroVal) } _, args, err := txProc.argsParser.ParseCallData(string(tx.GetData())) if err != nil { - return vmcommon.UserError, txProc.executingFailedTransaction(tx, relayerAcnt, err) + return vmcommon.UserError, txProc.executingFailedTransaction(tx, relayerAcnt, acntDst, err) } if len(args) != 4 { - return vmcommon.UserError, txProc.executingFailedTransaction(tx, relayerAcnt, process.ErrInvalidArguments) + return vmcommon.UserError, txProc.executingFailedTransaction(tx, relayerAcnt, acntDst, process.ErrInvalidArguments) } userTx := makeUserTxFromRelayedTxV2Args(args) @@ -767,31 +768,31 @@ func (txProc *txProcessor) processRelayedTx( return 0, err } if len(args) != 1 { - return vmcommon.UserError, txProc.executingFailedTransaction(tx, relayerAcnt, process.ErrInvalidArguments) + return vmcommon.UserError, txProc.executingFailedTransaction(tx, relayerAcnt, acntDst, process.ErrInvalidArguments) } if !txProc.enableEpochsHandler.IsFlagEnabled(common.RelayedTransactionsFlag) { - return vmcommon.UserError, txProc.executingFailedTransaction(tx, relayerAcnt, process.ErrRelayedTxDisabled) + return vmcommon.UserError, txProc.executingFailedTransaction(tx, relayerAcnt, acntDst, process.ErrRelayedTxDisabled) } userTx := &transaction.Transaction{} err = txProc.signMarshalizer.Unmarshal(userTx, args[0]) if err != nil { - return vmcommon.UserError, txProc.executingFailedTransaction(tx, relayerAcnt, err) + return vmcommon.UserError, txProc.executingFailedTransaction(tx, relayerAcnt, acntDst, err) } if !bytes.Equal(userTx.SndAddr, tx.RcvAddr) { - return vmcommon.UserError, txProc.executingFailedTransaction(tx, relayerAcnt, process.ErrRelayedTxBeneficiaryDoesNotMatchReceiver) + return vmcommon.UserError, txProc.executingFailedTransaction(tx, relayerAcnt, acntDst, process.ErrRelayedTxBeneficiaryDoesNotMatchReceiver) } if userTx.Value.Cmp(tx.Value) < 0 { - return vmcommon.UserError, txProc.executingFailedTransaction(tx, relayerAcnt, process.ErrRelayedTxValueHigherThenUserTxValue) + return vmcommon.UserError, txProc.executingFailedTransaction(tx, relayerAcnt, acntDst, process.ErrRelayedTxValueHigherThenUserTxValue) } if userTx.GasPrice != tx.GasPrice { - return vmcommon.UserError, txProc.executingFailedTransaction(tx, relayerAcnt, process.ErrRelayedGasPriceMissmatch) + return vmcommon.UserError, txProc.executingFailedTransaction(tx, relayerAcnt, acntDst, process.ErrRelayedGasPriceMissmatch) } remainingGasLimit := tx.GasLimit - txProc.economicsFee.ComputeGasLimit(tx) if userTx.GasLimit != remainingGasLimit { - return vmcommon.UserError, txProc.executingFailedTransaction(tx, relayerAcnt, process.ErrRelayedTxGasLimitMissmatch) + return vmcommon.UserError, txProc.executingFailedTransaction(tx, relayerAcnt, acntDst, process.ErrRelayedTxGasLimitMissmatch) } return txProc.finishExecutionOfRelayedTx(relayerAcnt, acntDst, tx, userTx)