diff --git a/packages/transaction-controller/jest.config.js b/packages/transaction-controller/jest.config.js index c313d7948a4..0cc1ff36e40 100644 --- a/packages/transaction-controller/jest.config.js +++ b/packages/transaction-controller/jest.config.js @@ -18,7 +18,7 @@ module.exports = merge(baseConfig, { coverageThreshold: { global: { branches: 91.76, - functions: 93.52, + functions: 93.44, lines: 96.83, statements: 96.82, }, diff --git a/packages/transaction-controller/src/TransactionController.test.ts b/packages/transaction-controller/src/TransactionController.test.ts index c317b4e326e..5644f02c44d 100644 --- a/packages/transaction-controller/src/TransactionController.test.ts +++ b/packages/transaction-controller/src/TransactionController.test.ts @@ -46,6 +46,7 @@ import { MethodDataHelper } from './helpers/MethodDataHelper'; import { MultichainTrackingHelper } from './helpers/MultichainTrackingHelper'; import { PendingTransactionTracker } from './helpers/PendingTransactionTracker'; import { shouldResimulate } from './helpers/ResimulateHelper'; +import { ExtraTransactionsPublishHook } from './hooks/ExtraTransactionsPublishHook'; import type { AllowedActions, AllowedEvents, @@ -66,6 +67,7 @@ import type { GasFeeFlowResponse, SubmitHistoryEntry, InternalAccount, + PublishHook, } from './types'; import { GasFeeEstimateType, @@ -103,6 +105,8 @@ type UnrestrictedMessenger = Messenger< const MOCK_V1_UUID = '9b1deb4d-3b7d-4bad-9bdd-2b0d7b3dcb6d'; const TRANSACTION_HASH_MOCK = '0x123456'; +const DATA_MOCK = '0x12345678'; +const VALUE_MOCK = '0xabcd'; jest.mock('@metamask/eth-query'); jest.mock('./api/accounts-api'); @@ -114,6 +118,7 @@ jest.mock('./helpers/IncomingTransactionHelper'); jest.mock('./helpers/MethodDataHelper'); jest.mock('./helpers/MultichainTrackingHelper'); jest.mock('./helpers/PendingTransactionTracker'); +jest.mock('./hooks/ExtraTransactionsPublishHook'); jest.mock('./utils/batch'); jest.mock('./utils/gas'); jest.mock('./utils/gas-fees'); @@ -1578,6 +1583,7 @@ describe('TransactionController', () => { const expectedInitialSnapshot = { actionId: undefined, + batchId: undefined, chainId: expect.any(String), dappSuggestedGasFees: undefined, deviceConfirmedOn: undefined, @@ -2169,6 +2175,63 @@ describe('TransactionController', () => { ]); }); + it('uses extra transactions publish hook if batch transactions in metadata', async () => { + const { controller } = setupController({ + messengerOptions: { + addTransactionApprovalRequest: { + state: 'approved', + }, + }, + }); + + const publishHook: jest.MockedFn = jest.fn(); + + publishHook.mockResolvedValueOnce({ + transactionHash: TRANSACTION_HASH_MOCK, + }); + + const extraTransactionsPublishHook = jest.mocked( + ExtraTransactionsPublishHook, + ); + + extraTransactionsPublishHook.mockReturnValue({ + getHook: () => publishHook, + } as unknown as ExtraTransactionsPublishHook); + + const { result, transactionMeta } = await controller.addTransaction( + { + from: ACCOUNT_MOCK, + to: ACCOUNT_MOCK, + }, + { + networkClientId: NETWORK_CLIENT_ID_MOCK, + }, + ); + + controller.updateBatchTransactions({ + transactionId: transactionMeta.id, + batchTransactions: [ + { data: DATA_MOCK, to: ACCOUNT_2_MOCK, value: VALUE_MOCK }, + ], + }); + + result.catch(() => { + // Intentionally empty + }); + + await flushPromises(); + + expect(ExtraTransactionsPublishHook).toHaveBeenCalledTimes(1); + expect(ExtraTransactionsPublishHook).toHaveBeenCalledWith({ + addTransactionBatch: expect.any(Function), + transactions: [ + { data: DATA_MOCK, to: ACCOUNT_2_MOCK, value: VALUE_MOCK }, + ], + }); + + expect(publishHook).toHaveBeenCalledTimes(1); + }); + describe('fails', () => { /** * Test template to assert adding and submitting a transaction fails. @@ -5088,6 +5151,42 @@ describe('TransactionController', () => { expect.any(String), ); }); + + it('supports publish hook override per call', async () => { + const publishHookController = jest.fn(); + + const publishHookCall = jest.fn().mockResolvedValueOnce({ + transactionHash: TRANSACTION_HASH_MOCK, + }); + + const { controller } = setupController({ + options: { + hooks: { + publish: publishHookController, + }, + }, + messengerOptions: { + addTransactionApprovalRequest: { + state: 'approved', + }, + }, + }); + + jest.spyOn(mockEthQuery, 'sendRawTransaction'); + + const { result } = await controller.addTransaction(paramsMock, { + networkClientId: NETWORK_CLIENT_ID_MOCK, + publishHook: publishHookCall, + }); + + await result; + + expect(controller.state.transactions[0].hash).toBe(TRANSACTION_HASH_MOCK); + + expect(publishHookCall).toHaveBeenCalledTimes(1); + expect(publishHookController).not.toHaveBeenCalled(); + expect(mockEthQuery.sendRawTransaction).not.toHaveBeenCalled(); + }); }); describe('updateSecurityAlertResponse', () => { diff --git a/packages/transaction-controller/src/TransactionController.ts b/packages/transaction-controller/src/TransactionController.ts index a3a5b46650e..e54621fd435 100644 --- a/packages/transaction-controller/src/TransactionController.ts +++ b/packages/transaction-controller/src/TransactionController.ts @@ -46,7 +46,6 @@ import type { RemoteFeatureFlagControllerGetStateAction } from '@metamask/remote import { errorCodes, rpcErrors, providerErrors } from '@metamask/rpc-errors'; import type { Hex } from '@metamask/utils'; import { add0x, hexToNumber } from '@metamask/utils'; -import { Mutex } from 'async-mutex'; // This package purposefully relies on Node's EventEmitter module. // eslint-disable-next-line import-x/no-nodejs-modules import { EventEmitter } from 'events'; @@ -75,6 +74,7 @@ import { hasSimulationDataChanged, shouldResimulate, } from './helpers/ResimulateHelper'; +import { ExtraTransactionsPublishHook } from './hooks/ExtraTransactionsPublishHook'; import { projectLogger as log } from './logger'; import type { DappSuggestedGasFees, @@ -97,6 +97,8 @@ import type { TransactionBatchRequest, TransactionBatchResult, BatchTransactionParams, + PublishHook, + PublishBatchHook, } from './types'; import { TransactionEnvelopeType, @@ -364,6 +366,7 @@ export type TransactionControllerOptions = { publish?: ( transactionMeta: TransactionMeta, ) => Promise<{ transactionHash: string }>; + publishBatch?: PublishBatchHook; }; }; @@ -631,8 +634,6 @@ export class TransactionController extends BaseController< readonly #methodDataHelper: MethodDataHelper; - private readonly mutex = new Mutex(); - private readonly gasFeeFlows: GasFeeFlow[]; private readonly getSavedGasFees: (chainId: Hex) => SavedGasFees | undefined; @@ -672,6 +673,8 @@ export class TransactionController extends BaseController< readonly #pendingTransactionOptions: PendingTransactionOptions; + readonly #publishBatchHook?: PublishBatchHook; + private readonly signAbortCallbacks: Map void> = new Map(); readonly #trace: TraceCallback; @@ -831,6 +834,7 @@ export class TransactionController extends BaseController< this.securityProviderRequest = securityProviderRequest; this.#incomingTransactionOptions = incomingTransactions; this.#pendingTransactionOptions = pendingTransactions; + this.#publishBatchHook = hooks?.publishBatch; this.#transactionHistoryLimit = transactionHistoryLimit; this.sign = sign; this.#testGasFeeFlows = testGasFeeFlows === true; @@ -997,8 +1001,12 @@ export class TransactionController extends BaseController< getChainId: this.#getChainId.bind(this), getEthQuery: (networkClientId) => this.#getEthQuery({ networkClientId }), getInternalAccounts: this.#getInternalAccounts.bind(this), + getTransaction: (transactionId) => + this.getTransactionOrThrow(transactionId), messenger: this.messagingSystem, + publishBatchHook: this.#publishBatchHook, request, + updateTransaction: this.#updateTransactionInternal.bind(this), }); } @@ -1024,10 +1032,12 @@ export class TransactionController extends BaseController< * @param txParams - Standard parameters for an Ethereum transaction. * @param options - Additional options to control how the transaction is added. * @param options.actionId - Unique ID to prevent duplicate requests. + * @param options.batchId - ID of the batch this transaction belongs to. * @param options.deviceConfirmedOn - An enum to indicate what device confirmed the transaction. * @param options.method - RPC method that requested the transaction. * @param options.nestedTransactions - Params for any nested transactions encoded in the data. * @param options.origin - The origin of the transaction request, such as a dApp hostname. + * @param options.publishHook - Custom logic to publish the transaction. * @param options.requireApproval - Whether the transaction requires approval by the user, defaults to true unless explicitly disabled. * @param options.securityAlertResponse - Response from security validator. * @param options.sendFlowHistory - The sendFlowHistory entries to add. @@ -1043,11 +1053,13 @@ export class TransactionController extends BaseController< txParams: TransactionParams, options: { actionId?: string; + batchId?: string; deviceConfirmedOn?: WalletDevice; method?: string; nestedTransactions?: BatchTransactionParams[]; networkClientId: NetworkClientId; origin?: string; + publishHook?: PublishHook; requireApproval?: boolean | undefined; securityAlertResponse?: SecurityAlertResponse; sendFlowHistory?: SendFlowHistoryEntry[]; @@ -1063,11 +1075,13 @@ export class TransactionController extends BaseController< const { actionId, + batchId, deviceConfirmedOn, method, nestedTransactions, networkClientId, origin, + publishHook, requireApproval, securityAlertResponse, sendFlowHistory, @@ -1129,6 +1143,7 @@ export class TransactionController extends BaseController< : { // Add actionId to txMeta to check if same actionId is seen again actionId, + batchId, chainId, dappSuggestedGasFees, deviceConfirmedOn, @@ -1214,9 +1229,10 @@ export class TransactionController extends BaseController< return { result: this.processApproval(addedTransactionMeta, { + actionId, isExisting: Boolean(existingTransactionMeta), + publishHook, requireApproval, - actionId, traceContext, }), transactionMeta: addedTransactionMeta, @@ -2351,6 +2367,34 @@ export class TransactionController extends BaseController< this.signAbortCallbacks.delete(transactionId); } + /** + * Update the batch transactions associated with a transaction. + * These transactions will be submitted with the main transaction as a batch. + * + * @param request - The request object. + * @param request.transactionId - The ID of the transaction to update. + * @param request.batchTransactions - The new batch transactions. + */ + updateBatchTransactions({ + transactionId, + batchTransactions, + }: { + transactionId: string; + batchTransactions: BatchTransactionParams[]; + }) { + log('Updating batch transactions', { transactionId, batchTransactions }); + + this.#updateTransactionInternal( + { + transactionId, + note: 'TransactionController#updateBatchTransactions - Batch transactions updated', + }, + (transactionMeta) => { + transactionMeta.batchTransactions = batchTransactions; + }, + ); + } + private addMetadata(transactionMeta: TransactionMeta) { validateTxParams(transactionMeta.txParams); this.update((state) => { @@ -2438,22 +2482,25 @@ export class TransactionController extends BaseController< private async processApproval( transactionMeta: TransactionMeta, { + actionId, isExisting = false, + publishHook, requireApproval, shouldShowRequest = true, - actionId, traceContext, }: { + actionId?: string; isExisting?: boolean; + publishHook?: PublishHook; requireApproval?: boolean | undefined; shouldShowRequest?: boolean; - actionId?: string; traceContext?: TraceContext; }, ): Promise { const transactionId = transactionMeta.id; let resultCallbacks: AcceptResultCallbacks | undefined; const { meta, isCompleted } = this.isTransactionCompleted(transactionId); + const finishedPromise = isCompleted ? Promise.resolve(meta) : this.waitForTransactionFinished(transactionId); @@ -2500,6 +2547,7 @@ export class TransactionController extends BaseController< const approvalResult = await this.approveTransaction( transactionId, traceContext, + publishHook, ); if ( approvalResult === ApprovalState.SkippedViaBeforePublishHook && @@ -2544,7 +2592,7 @@ export class TransactionController extends BaseController< switch (finalMeta?.status) { case TransactionStatus.failed: resultCallbacks?.error(finalMeta.error); - throw rpcErrors.internal(finalMeta.error.message); + throw rpcErrors.internal(finalMeta.error.stack); case TransactionStatus.submitted: resultCallbacks?.success(); @@ -2570,17 +2618,21 @@ export class TransactionController extends BaseController< * * @param transactionId - The ID of the transaction to approve. * @param traceContext - The parent context for any new traces. + * @param publishHookOverride - Custom logic to publish the transaction. * @returns The state of the approval. */ private async approveTransaction( transactionId: string, traceContext?: unknown, + publishHookOverride?: PublishHook, ) { - const cleanupTasks = new Array<() => void>(); - cleanupTasks.push(await this.mutex.acquire()); + let clearApprovingTransactionId: (() => void) | undefined; + let clearNonceLock: (() => void) | undefined; let transactionMeta = this.getTransactionOrThrow(transactionId); + log('Approving transaction', transactionMeta); + try { if (!this.sign) { this.failTransaction( @@ -2597,10 +2649,11 @@ export class TransactionController extends BaseController< log('Skipping approval as signing in progress', transactionId); return ApprovalState.NotApproved; } + this.approvingTransactionIds.add(transactionId); - cleanupTasks.push(() => - this.approvingTransactionIds.delete(transactionId), - ); + + clearApprovingTransactionId = () => + this.approvingTransactionIds.delete(transactionId); const [nonce, releaseNonce] = await getNextNonce( transactionMeta, @@ -2611,8 +2664,7 @@ export class TransactionController extends BaseController< ), ); - // must set transaction to submitted/failed before releasing lock - releaseNonce && cleanupTasks.push(releaseNonce); + clearNonceLock = releaseNonce; transactionMeta = this.#updateTransactionInternal( { @@ -2673,10 +2725,26 @@ export class TransactionController extends BaseController< let hash: string | undefined; + clearNonceLock?.(); + clearNonceLock = undefined; + + if (transactionMeta.batchTransactions?.length) { + log('Found batch transactions', transactionMeta.batchTransactions); + + const extraTransactionsPublishHook = new ExtraTransactionsPublishHook({ + addTransactionBatch: this.addTransactionBatch.bind(this), + transactions: transactionMeta.batchTransactions, + }); + + publishHookOverride = extraTransactionsPublishHook.getHook(); + } + await this.#trace( { name: 'Publish', parentContext: traceContext }, async () => { - ({ transactionHash: hash } = await this.publish( + const publishHook = publishHookOverride ?? this.publish; + + ({ transactionHash: hash } = await publishHook( transactionMeta, rawTx, )); @@ -2726,7 +2794,8 @@ export class TransactionController extends BaseController< this.failTransaction(transactionMeta, error); return ApprovalState.NotApproved; } finally { - cleanupTasks.forEach((task) => task()); + clearApprovingTransactionId?.(); + clearNonceLock?.(); } } @@ -3312,14 +3381,14 @@ export class TransactionController extends BaseController< } private getNonceTrackerTransactions( - status: TransactionStatus, + statuses: TransactionStatus[], address: string, chainId: string, ) { return getAndFormatTransactionsForNonceTracker( chainId, address, - status, + statuses, this.state.transactions, ); } @@ -3394,7 +3463,7 @@ export class TransactionController extends BaseController< ), getConfirmedTransactions: this.getNonceTrackerTransactions.bind( this, - TransactionStatus.confirmed, + [TransactionStatus.confirmed], chainId, ), }); @@ -3492,7 +3561,11 @@ export class TransactionController extends BaseController< #getNonceTrackerPendingTransactions(chainId: string, address: string) { const standardPendingTransactions = this.getNonceTrackerTransactions( - TransactionStatus.submitted, + [ + TransactionStatus.approved, + TransactionStatus.signed, + TransactionStatus.submitted, + ], address, chainId, ); diff --git a/packages/transaction-controller/src/hooks/CollectPublishHook.test.ts b/packages/transaction-controller/src/hooks/CollectPublishHook.test.ts new file mode 100644 index 00000000000..fcafa02dace --- /dev/null +++ b/packages/transaction-controller/src/hooks/CollectPublishHook.test.ts @@ -0,0 +1,111 @@ +import { CollectPublishHook } from './CollectPublishHook'; +import type { TransactionMeta } from '..'; +import { flushPromises } from '../../../../tests/helpers'; + +const SIGNED_TX_MOCK = '0x123'; +const SIGNED_TX_2_MOCK = '0x456'; +const TRANSACTION_HASH_MOCK = '0x789'; +const TRANSACTION_HASH_2_MOCK = '0xabc'; +const ERROR_MESSAGE_MOCK = 'Test error'; + +const TRANSACTION_META_MOCK = { + id: '123-456', +} as TransactionMeta; + +describe('CollectPublishHook', () => { + describe('getHook', () => { + it('returns function that resolves ready promise', async () => { + const collectHook = new CollectPublishHook(2); + const publishHook = collectHook.getHook(); + + publishHook(TRANSACTION_META_MOCK, SIGNED_TX_MOCK).catch(() => { + // Intentionally empty + }); + + publishHook(TRANSACTION_META_MOCK, SIGNED_TX_2_MOCK).catch(() => { + // Intentionally empty + }); + + await flushPromises(); + + const result = await collectHook.ready(); + + expect(result.signedTransactions).toStrictEqual([ + SIGNED_TX_MOCK, + SIGNED_TX_2_MOCK, + ]); + }); + }); + + describe('success', () => { + it('resolves all publish promises', async () => { + const collectHook = new CollectPublishHook(2); + const publishHook = collectHook.getHook(); + + const publishPromise1 = publishHook( + TRANSACTION_META_MOCK, + SIGNED_TX_MOCK, + ); + + const publishPromise2 = publishHook( + TRANSACTION_META_MOCK, + SIGNED_TX_2_MOCK, + ); + + collectHook.success([TRANSACTION_HASH_MOCK, TRANSACTION_HASH_2_MOCK]); + + const result1 = await publishPromise1; + const result2 = await publishPromise2; + + expect(result1.transactionHash).toBe(TRANSACTION_HASH_MOCK); + expect(result2.transactionHash).toBe(TRANSACTION_HASH_2_MOCK); + }); + + it('throws if transaction hash count does not match hook call count', () => { + const collectHook = new CollectPublishHook(2); + const publishHook = collectHook.getHook(); + + publishHook(TRANSACTION_META_MOCK, SIGNED_TX_MOCK).catch(() => { + // Intentionally empty + }); + + publishHook(TRANSACTION_META_MOCK, SIGNED_TX_2_MOCK).catch(() => { + // Intentionally empty + }); + + expect(() => { + collectHook.success([TRANSACTION_HASH_MOCK]); + }).toThrow('Transaction hash count mismatch'); + }); + }); + + describe('error', () => { + it('rejects all publish promises', async () => { + const collectHook = new CollectPublishHook(2); + const publishHook = collectHook.getHook(); + + const publishPromise1 = publishHook( + TRANSACTION_META_MOCK, + SIGNED_TX_MOCK, + ); + + const publishPromise2 = publishHook( + TRANSACTION_META_MOCK, + SIGNED_TX_2_MOCK, + ); + + publishPromise1.catch(() => { + // Intentionally empty + }); + + publishPromise2.catch(() => { + // Intentionally empty + }); + + collectHook.error(new Error(ERROR_MESSAGE_MOCK)); + + await expect(publishPromise1).rejects.toThrow(ERROR_MESSAGE_MOCK); + await expect(publishPromise2).rejects.toThrow(ERROR_MESSAGE_MOCK); + }); + }); +}); diff --git a/packages/transaction-controller/src/hooks/CollectPublishHook.ts b/packages/transaction-controller/src/hooks/CollectPublishHook.ts new file mode 100644 index 00000000000..3e84f98fd8a --- /dev/null +++ b/packages/transaction-controller/src/hooks/CollectPublishHook.ts @@ -0,0 +1,97 @@ +import type { DeferredPromise, Hex } from '@metamask/utils'; +import { createDeferredPromise, createModuleLogger } from '@metamask/utils'; + +import { projectLogger } from '../logger'; +import type { PublishHook, PublishHookResult, TransactionMeta } from '../types'; + +const log = createModuleLogger(projectLogger, 'collect-publish-hook'); + +export type CollectPublishHookResult = { + signedTransactions: Hex[]; +}; + +/** + * Custom publish logic that collects multiple signed transactions until a specific number is reached. + * Used by batch transactions to publish multiple transactions at once. + */ +export class CollectPublishHook { + readonly #publishPromises: DeferredPromise[]; + + readonly #signedTransactions: Hex[]; + + readonly #transactionCount: number; + + readonly #readyPromise: DeferredPromise; + + constructor(transactionCount: number) { + this.#publishPromises = []; + this.#readyPromise = createDeferredPromise(); + this.#signedTransactions = []; + this.#transactionCount = transactionCount; + } + + /** + * @returns The publish hook function to be passed to `addTransaction`. + */ + getHook(): PublishHook { + return this.#hook.bind(this); + } + + /** + * @returns A promise that resolves when all transactions are signed. + */ + ready(): Promise { + return this.#readyPromise.promise; + } + + /** + * Resolve all publish promises with the provided transaction hashes. + * + * @param transactionHashes - The transaction hashes to pass to the original publish promises. + */ + success(transactionHashes: Hex[]) { + log('Success', { transactionHashes }); + + if (transactionHashes.length !== this.#transactionCount) { + throw new Error('Transaction hash count mismatch'); + } + + for (let i = 0; i < this.#publishPromises.length; i++) { + const publishPromise = this.#publishPromises[i]; + const transactionHash = transactionHashes[i]; + + publishPromise.resolve({ transactionHash }); + } + } + + error(error: unknown) { + log('Error', { error }); + + for (const publishPromise of this.#publishPromises) { + publishPromise.reject(error); + } + } + + #hook( + transactionMeta: TransactionMeta, + signedTx: string, + ): Promise { + this.#signedTransactions.push(signedTx as Hex); + + log('Processing transaction', { transactionMeta, signedTx }); + + const publishPromise = createDeferredPromise(); + + this.#publishPromises.push(publishPromise); + + if (this.#signedTransactions.length === this.#transactionCount) { + log('All transactions signed'); + + this.#readyPromise.resolve({ + signedTransactions: this.#signedTransactions, + }); + } + + return publishPromise.promise; + } +} diff --git a/packages/transaction-controller/src/hooks/ExtraTransactionsPublishHook.test.ts b/packages/transaction-controller/src/hooks/ExtraTransactionsPublishHook.test.ts new file mode 100644 index 00000000000..05e7c32c394 --- /dev/null +++ b/packages/transaction-controller/src/hooks/ExtraTransactionsPublishHook.test.ts @@ -0,0 +1,140 @@ +import { ExtraTransactionsPublishHook } from './ExtraTransactionsPublishHook'; +import type { + BatchTransactionParams, + TransactionController, + TransactionMeta, +} from '..'; + +const SIGNED_TRANSACTION_MOCK = '0xffe'; +const TRANSACTION_HASH_MOCK = '0xeee'; + +const BATCH_TRANSACTION_PARAMS_MOCK: BatchTransactionParams = { + data: '0x123', + to: '0x456', + value: '0x789', +}; + +const BATCH_TRANSACTION_PARAMS_2_MOCK: BatchTransactionParams = { + data: '0x321', + to: '0x654', + value: '0x987', +}; + +const TRANSACTION_META_MOCK = { + id: '123-456', + networkClientId: 'testNetworkClientId', + txParams: { + from: '0xaab', + data: '0xabc', + to: '0xdef', + value: '0xfed', + }, +} as TransactionMeta; + +describe('ExtraTransactionsPublishHook', () => { + it('creates batch transaction', async () => { + const addTransactionBatch: jest.MockedFn< + TransactionController['addTransactionBatch'] + > = jest.fn(); + + const hookInstance = new ExtraTransactionsPublishHook({ + addTransactionBatch, + transactions: [ + BATCH_TRANSACTION_PARAMS_MOCK, + BATCH_TRANSACTION_PARAMS_2_MOCK, + ], + }); + + const hook = hookInstance.getHook(); + + hook(TRANSACTION_META_MOCK, SIGNED_TRANSACTION_MOCK).catch(() => { + // Intentionally empty + }); + + expect(addTransactionBatch).toHaveBeenCalledTimes(1); + expect(addTransactionBatch).toHaveBeenCalledWith({ + from: TRANSACTION_META_MOCK.txParams.from, + networkClientId: TRANSACTION_META_MOCK.networkClientId, + transactions: [ + { + existingTransaction: { + id: TRANSACTION_META_MOCK.id, + onPublish: expect.any(Function), + signedTransaction: SIGNED_TRANSACTION_MOCK, + }, + params: { + data: TRANSACTION_META_MOCK.txParams.data, + to: TRANSACTION_META_MOCK.txParams.to, + value: TRANSACTION_META_MOCK.txParams.value, + }, + }, + { + params: BATCH_TRANSACTION_PARAMS_MOCK, + }, + { + params: BATCH_TRANSACTION_PARAMS_2_MOCK, + }, + ], + useHook: true, + }); + }); + + it('resolves when onPublish callback is called', async () => { + const addTransactionBatch: jest.MockedFn< + TransactionController['addTransactionBatch'] + > = jest.fn(); + + const hookInstance = new ExtraTransactionsPublishHook({ + addTransactionBatch, + transactions: [ + BATCH_TRANSACTION_PARAMS_MOCK, + BATCH_TRANSACTION_PARAMS_2_MOCK, + ], + }); + + const hook = hookInstance.getHook(); + + const hookPromise = hook( + TRANSACTION_META_MOCK, + SIGNED_TRANSACTION_MOCK, + ).catch(() => { + // Intentionally empty + }); + + const onPublish = + addTransactionBatch.mock.calls[0][0].transactions[0].existingTransaction + ?.onPublish; + + onPublish?.({ transactionHash: TRANSACTION_HASH_MOCK }); + + expect(await hookPromise).toStrictEqual({ + transactionHash: TRANSACTION_HASH_MOCK, + }); + }); + + it('rejects if addTransactionBatch throws', async () => { + const addTransactionBatch: jest.MockedFn< + TransactionController['addTransactionBatch'] + > = jest.fn().mockImplementation(() => { + throw new Error('Test error'); + }); + + const hookInstance = new ExtraTransactionsPublishHook({ + addTransactionBatch, + transactions: [ + BATCH_TRANSACTION_PARAMS_MOCK, + BATCH_TRANSACTION_PARAMS_2_MOCK, + ], + }); + + const hook = hookInstance.getHook(); + + const hookPromise = hook(TRANSACTION_META_MOCK, SIGNED_TRANSACTION_MOCK); + + hookPromise.catch(() => { + // Intentionally empty + }); + + await expect(hookPromise).rejects.toThrow('Test error'); + }); +}); diff --git a/packages/transaction-controller/src/hooks/ExtraTransactionsPublishHook.ts b/packages/transaction-controller/src/hooks/ExtraTransactionsPublishHook.ts new file mode 100644 index 00000000000..32bc3e53f86 --- /dev/null +++ b/packages/transaction-controller/src/hooks/ExtraTransactionsPublishHook.ts @@ -0,0 +1,107 @@ +import { + createDeferredPromise, + createModuleLogger, + type Hex, +} from '@metamask/utils'; + +import type { TransactionController } from '..'; +import { projectLogger } from '../logger'; +import type { + BatchTransactionParams, + PublishHook, + PublishHookResult, + TransactionBatchSingleRequest, + TransactionMeta, +} from '../types'; + +const log = createModuleLogger( + projectLogger, + 'extra-transactions-publish-hook', +); + +/** + * Custom publish logic that also publishes additional transactions in an batch. + * Requires the batch to be successful to resolve. + */ +export class ExtraTransactionsPublishHook { + readonly #addTransactionBatch: TransactionController['addTransactionBatch']; + + readonly #transactions: BatchTransactionParams[]; + + constructor({ + addTransactionBatch, + transactions, + }: { + addTransactionBatch: TransactionController['addTransactionBatch']; + transactions: BatchTransactionParams[]; + }) { + this.#addTransactionBatch = addTransactionBatch; + this.#transactions = transactions; + } + + /** + * @returns The publish hook function. + */ + getHook(): PublishHook { + return this.#hook.bind(this); + } + + async #hook( + transactionMeta: TransactionMeta, + signedTx: string, + ): Promise { + log('Publishing transaction as batch', { transactionMeta, signedTx }); + + const { id, networkClientId, txParams } = transactionMeta; + const from = txParams.from as Hex; + const to = txParams.to as Hex | undefined; + const data = txParams.data as Hex | undefined; + const value = txParams.value as Hex | undefined; + const signedTransaction = signedTx as Hex; + const resultPromise = createDeferredPromise(); + + const onPublish = ({ transactionHash }: { transactionHash?: string }) => { + resultPromise.resolve({ transactionHash }); + }; + + const firstParams: BatchTransactionParams = { + data, + to, + value, + }; + + const firstTransaction: TransactionBatchSingleRequest = { + existingTransaction: { + id, + onPublish, + signedTransaction, + }, + params: firstParams, + }; + + const extraTransactions: TransactionBatchSingleRequest[] = + this.#transactions.map((transaction) => ({ + params: transaction, + })); + + const transactions: TransactionBatchSingleRequest[] = [ + firstTransaction, + ...extraTransactions, + ]; + + log('Adding transaction batch', { + from, + networkClientId, + transactions, + }); + + await this.#addTransactionBatch({ + from, + networkClientId, + transactions, + useHook: true, + }); + + return resultPromise.promise; + } +} diff --git a/packages/transaction-controller/src/index.ts b/packages/transaction-controller/src/index.ts index 9b83ae44011..0892b5b40b1 100644 --- a/packages/transaction-controller/src/index.ts +++ b/packages/transaction-controller/src/index.ts @@ -44,6 +44,11 @@ export type { InferTransactionTypeResult, LegacyGasFeeEstimates, Log, + PublishBatchHook, + PublishBatchHookRequest, + PublishBatchHookResult, + PublishHook, + PublishHookResult, SavedGasFees, SecurityAlertResponse, SecurityProviderRequest, diff --git a/packages/transaction-controller/src/types.ts b/packages/transaction-controller/src/types.ts index c5bff26a152..ff77dacbb93 100644 --- a/packages/transaction-controller/src/types.ts +++ b/packages/transaction-controller/src/types.ts @@ -56,6 +56,16 @@ type TransactionMetaBase = { */ baseFeePerGas?: Hex; + /** + * ID of the batch this transaction belongs to. + */ + batchId?: string; + + /** + * Additional transactions that must also be submitted in a batch. + */ + batchTransactions?: BatchTransactionParams[]; + /** * Number of the block where the transaction has been included. */ @@ -1422,6 +1432,21 @@ export type BatchTransactionParams = { * Specification for a single transaction within a batch request. */ export type TransactionBatchSingleRequest = { + /** Data if the transaction already exists. */ + existingTransaction?: { + /** ID of the existing transaction. */ + id: string; + + /** Optional callback to be invoked once the transaction is published. */ + onPublish?: (request: { + /** Hash of the transaction on the network. */ + transactionHash?: string; + }) => void; + + /** Signed transaction data. */ + signedTransaction: Hex; + }; + /** Parameters of the single transaction. */ params: BatchTransactionParams; }; @@ -1445,6 +1470,12 @@ export type TransactionBatchRequest = { /** Transactions to be submitted as part of the batch. */ transactions: TransactionBatchSingleRequest[]; + + /** + * Whether to use the publish batch hook to submit the batch. + * Defaults to false. + */ + useHook?: boolean; }; /** @@ -1454,3 +1485,68 @@ export type TransactionBatchResult = { /** ID of the batch to locate related transactions. */ batchId: string; }; + +/** + * Data returned from custom logic to publish a transaction. + */ +export type PublishHookResult = { + /** + * The hash of the transaction on the network. + */ + transactionHash?: string; +}; + +/** + * Custom logic to publish a transaction. + * + * @param transactionMeta - The metadata of the transaction to publish. + * @param signedTx - The signed transaction data to publish. + * @returns The result of the publish operation. + */ +export type PublishHook = ( + transactionMeta: TransactionMeta, + signedTx: string, +) => Promise; + +/** Single transaction in a publish batch hook request. */ +export type PublishBatchHookTransaction = { + /** ID of the transaction. */ + id?: string; + + /** Parameters of the nested transaction. */ + params: BatchTransactionParams; + + /** Signed transaction data to publish. */ + signedTx: Hex; +}; + +/** + * Data required to call a publish batch hook. + */ +export type PublishBatchHookRequest = { + /** Address of the account to submit the transaction batch. */ + from: Hex; + + /** ID of the network client associated with the transaction batch. */ + networkClientId: string; + + /** Nested transactions to be submitted as part of the batch. */ + transactions: PublishBatchHookTransaction[]; +}; + +/** Result of calling a publish batch hook. */ +export type PublishBatchHookResult = + | { + /** Result data for each transaction in the batch. */ + results: { + /** Hash of the transaction on the network. */ + transactionHash: Hex; + }[]; + } + | undefined; + +/** Custom logic to publish a transaction batch. */ +export type PublishBatchHook = ( + /** Data required to call the hook. */ + request: PublishBatchHookRequest, +) => Promise; diff --git a/packages/transaction-controller/src/utils/batch.test.ts b/packages/transaction-controller/src/utils/batch.test.ts index 13b0ff676ac..95cf8aec089 100644 --- a/packages/transaction-controller/src/utils/batch.test.ts +++ b/packages/transaction-controller/src/utils/batch.test.ts @@ -16,6 +16,8 @@ import { type TransactionControllerMessenger, type TransactionMeta, } from '..'; +import { flushPromises } from '../../../../tests/helpers'; +import type { PublishBatchHook } from '../types'; jest.mock('./eip7702'); jest.mock('./feature-flags'); @@ -39,9 +41,22 @@ const NETWORK_CLIENT_ID_MOCK = 'testNetworkClientId'; const BATCH_ID_MOCK = 'testBatchId'; const GET_ETH_QUERY_MOCK = jest.fn(); const GET_INTERNAL_ACCOUNTS_MOCK = jest.fn().mockReturnValue([]); +const TRANSACTION_ID_MOCK = 'testTransactionId'; +const TRANSACTION_ID_2_MOCK = 'testTransactionId2'; +const TRANSACTION_HASH_MOCK = '0x123'; +const TRANSACTION_HASH_2_MOCK = '0x456'; +const TRANSACTION_SIGNATURE_MOCK = '0xabc'; +const TRANSACTION_SIGNATURE_2_MOCK = '0xdef'; +const ERROR_MESSAGE_MOCK = 'Test error'; const TRANSACTION_META_MOCK = { id: BATCH_ID_MOCK, + txParams: { + from: FROM_MOCK, + to: TO_MOCK, + data: DATA_MOCK, + value: VALUE_MOCK, + }, } as TransactionMeta; describe('Batch Utils', () => { @@ -70,18 +85,24 @@ describe('Batch Utils', () => { AddBatchTransactionOptions['getChainId'] >; + let updateTransactionMock: jest.MockedFn< + AddBatchTransactionOptions['updateTransaction'] + >; + let request: AddBatchTransactionOptions; beforeEach(() => { jest.resetAllMocks(); addTransactionMock = jest.fn(); getChainIdMock = jest.fn(); + updateTransactionMock = jest.fn(); request = { addTransaction: addTransactionMock, getChainId: getChainIdMock, getEthQuery: GET_ETH_QUERY_MOCK, getInternalAccounts: GET_INTERNAL_ACCOUNTS_MOCK, + getTransaction: jest.fn(), messenger: MESSENGER_MOCK, request: { from: FROM_MOCK, @@ -104,161 +125,691 @@ describe('Batch Utils', () => { }, ], }, + updateTransaction: updateTransactionMock, }; }); - it('adds generated EIP-7702 transaction', async () => { - doesChainSupportEIP7702Mock.mockReturnValueOnce(true); + describe('with EIP-7702', () => { + it('adds generated EIP-7702 transaction', async () => { + doesChainSupportEIP7702Mock.mockReturnValueOnce(true); - isAccountUpgradedToEIP7702Mock.mockResolvedValueOnce({ - delegationAddress: undefined, - isSupported: true, - }); + isAccountUpgradedToEIP7702Mock.mockResolvedValueOnce({ + delegationAddress: undefined, + isSupported: true, + }); - addTransactionMock.mockResolvedValueOnce({ - transactionMeta: TRANSACTION_META_MOCK, - result: Promise.resolve(''), + addTransactionMock.mockResolvedValueOnce({ + transactionMeta: TRANSACTION_META_MOCK, + result: Promise.resolve(''), + }); + + generateEIP7702BatchTransactionMock.mockReturnValueOnce({ + to: TO_MOCK, + data: DATA_MOCK, + value: VALUE_MOCK, + }); + + await addTransactionBatch(request); + + expect(addTransactionMock).toHaveBeenCalledTimes(1); + expect(addTransactionMock).toHaveBeenCalledWith( + { + from: FROM_MOCK, + to: TO_MOCK, + data: DATA_MOCK, + value: VALUE_MOCK, + }, + expect.objectContaining({ + networkClientId: NETWORK_CLIENT_ID_MOCK, + requireApproval: true, + }), + ); }); - generateEIP7702BatchTransactionMock.mockReturnValueOnce({ - to: TO_MOCK, - data: DATA_MOCK, - value: VALUE_MOCK, + it('uses type 4 transaction if not upgraded', async () => { + doesChainSupportEIP7702Mock.mockReturnValueOnce(true); + + isAccountUpgradedToEIP7702Mock.mockResolvedValueOnce({ + delegationAddress: undefined, + isSupported: false, + }); + + addTransactionMock.mockResolvedValueOnce({ + transactionMeta: TRANSACTION_META_MOCK, + result: Promise.resolve(''), + }); + + generateEIP7702BatchTransactionMock.mockReturnValueOnce({ + to: TO_MOCK, + data: DATA_MOCK, + value: VALUE_MOCK, + }); + + getEIP7702UpgradeContractAddressMock.mockReturnValueOnce( + CONTRACT_ADDRESS_MOCK, + ); + + await addTransactionBatch(request); + + expect(addTransactionMock).toHaveBeenCalledTimes(1); + expect(addTransactionMock).toHaveBeenCalledWith( + { + from: FROM_MOCK, + to: TO_MOCK, + data: DATA_MOCK, + value: VALUE_MOCK, + type: TransactionEnvelopeType.setCode, + authorizationList: [{ address: CONTRACT_ADDRESS_MOCK }], + }, + expect.objectContaining({ + networkClientId: NETWORK_CLIENT_ID_MOCK, + requireApproval: true, + }), + ); }); - await addTransactionBatch(request); + it('passes nested transactions to add transaction', async () => { + doesChainSupportEIP7702Mock.mockReturnValueOnce(true); - expect(addTransactionMock).toHaveBeenCalledTimes(1); - expect(addTransactionMock).toHaveBeenCalledWith( - { - from: FROM_MOCK, + isAccountUpgradedToEIP7702Mock.mockResolvedValueOnce({ + delegationAddress: undefined, + isSupported: true, + }); + + addTransactionMock.mockResolvedValueOnce({ + transactionMeta: TRANSACTION_META_MOCK, + result: Promise.resolve(''), + }); + + generateEIP7702BatchTransactionMock.mockReturnValueOnce({ to: TO_MOCK, data: DATA_MOCK, value: VALUE_MOCK, - }, - expect.objectContaining({ - networkClientId: NETWORK_CLIENT_ID_MOCK, - requireApproval: true, - }), - ); - }); + }); - it('uses type 4 transaction if not upgraded', async () => { - doesChainSupportEIP7702Mock.mockReturnValueOnce(true); + await addTransactionBatch(request); - isAccountUpgradedToEIP7702Mock.mockResolvedValueOnce({ - delegationAddress: undefined, - isSupported: false, + expect(addTransactionMock).toHaveBeenCalledTimes(1); + expect(addTransactionMock).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + nestedTransactions: [ + { + to: TO_MOCK, + data: DATA_MOCK, + value: VALUE_MOCK, + }, + { + to: TO_MOCK, + data: DATA_MOCK, + value: VALUE_MOCK, + }, + ], + }), + ); }); - addTransactionMock.mockResolvedValueOnce({ - transactionMeta: TRANSACTION_META_MOCK, - result: Promise.resolve(''), + it('throws if chain not supported', async () => { + doesChainSupportEIP7702Mock.mockReturnValueOnce(false); + + await expect(addTransactionBatch(request)).rejects.toThrow( + rpcErrors.internal('Chain does not support EIP-7702'), + ); }); - generateEIP7702BatchTransactionMock.mockReturnValueOnce({ - to: TO_MOCK, - data: DATA_MOCK, - value: VALUE_MOCK, + it('throws if account upgraded to unsupported contract', async () => { + doesChainSupportEIP7702Mock.mockReturnValueOnce(true); + isAccountUpgradedToEIP7702Mock.mockResolvedValueOnce({ + delegationAddress: CONTRACT_ADDRESS_MOCK, + isSupported: false, + }); + + await expect(addTransactionBatch(request)).rejects.toThrow( + rpcErrors.internal('Account upgraded to unsupported contract'), + ); }); - getEIP7702UpgradeContractAddressMock.mockReturnValueOnce( - CONTRACT_ADDRESS_MOCK, - ); + it('throws if account not upgraded and no upgrade address', async () => { + doesChainSupportEIP7702Mock.mockReturnValueOnce(true); - await addTransactionBatch(request); + isAccountUpgradedToEIP7702Mock.mockResolvedValueOnce({ + delegationAddress: undefined, + isSupported: false, + }); - expect(addTransactionMock).toHaveBeenCalledTimes(1); - expect(addTransactionMock).toHaveBeenCalledWith( - { - from: FROM_MOCK, - to: TO_MOCK, - data: DATA_MOCK, - value: VALUE_MOCK, - type: TransactionEnvelopeType.setCode, - authorizationList: [{ address: CONTRACT_ADDRESS_MOCK }], - }, - expect.objectContaining({ - networkClientId: NETWORK_CLIENT_ID_MOCK, - requireApproval: true, - }), - ); + getEIP7702UpgradeContractAddressMock.mockReturnValueOnce(undefined); + + await expect(addTransactionBatch(request)).rejects.toThrow( + rpcErrors.internal('Upgrade contract address not found'), + ); + }); }); - it('passes nested transactions to add transaction', async () => { - doesChainSupportEIP7702Mock.mockReturnValueOnce(true); + describe('with publish batch hook', () => { + it('adds each nested transaction', async () => { + const publishBatchHook = jest.fn(); - isAccountUpgradedToEIP7702Mock.mockResolvedValueOnce({ - delegationAddress: undefined, - isSupported: true, + addTransactionMock.mockResolvedValueOnce({ + transactionMeta: TRANSACTION_META_MOCK, + result: Promise.resolve(''), + }); + + addTransactionBatch({ + ...request, + publishBatchHook, + request: { ...request.request, useHook: true }, + }).catch(() => { + // Intentionally empty + }); + + await flushPromises(); + + expect(addTransactionMock).toHaveBeenCalledTimes(2); + expect(addTransactionMock).toHaveBeenCalledWith( + { + data: DATA_MOCK, + from: FROM_MOCK, + to: TO_MOCK, + value: VALUE_MOCK, + }, + { + batchId: expect.any(String), + networkClientId: NETWORK_CLIENT_ID_MOCK, + publishHook: expect.any(Function), + requireApproval: false, + }, + ); }); - addTransactionMock.mockResolvedValueOnce({ - transactionMeta: TRANSACTION_META_MOCK, - result: Promise.resolve(''), + it('calls publish batch hook', async () => { + const publishBatchHook: jest.MockedFn = jest.fn(); + + addTransactionMock + .mockResolvedValueOnce({ + transactionMeta: { + ...TRANSACTION_META_MOCK, + id: TRANSACTION_ID_MOCK, + }, + result: Promise.resolve(''), + }) + .mockResolvedValueOnce({ + transactionMeta: { + ...TRANSACTION_META_MOCK, + id: TRANSACTION_ID_2_MOCK, + }, + result: Promise.resolve(''), + }); + + publishBatchHook.mockResolvedValue({ + results: [ + { + transactionHash: TRANSACTION_HASH_MOCK, + }, + { + transactionHash: TRANSACTION_HASH_2_MOCK, + }, + ], + }); + + addTransactionBatch({ + ...request, + publishBatchHook, + request: { ...request.request, useHook: true }, + }).catch(() => { + // Intentionally empty + }); + + await flushPromises(); + + const publishHooks = addTransactionMock.mock.calls.map( + ([, options]) => options.publishHook, + ); + + publishHooks[0]?.( + TRANSACTION_META_MOCK, + TRANSACTION_SIGNATURE_MOCK, + ).catch(() => { + // Intentionally empty + }); + + publishHooks[1]?.( + TRANSACTION_META_MOCK, + TRANSACTION_SIGNATURE_2_MOCK, + ).catch(() => { + // Intentionally empty + }); + + await flushPromises(); + + expect(publishBatchHook).toHaveBeenCalledTimes(1); + expect(publishBatchHook).toHaveBeenCalledWith({ + from: FROM_MOCK, + networkClientId: NETWORK_CLIENT_ID_MOCK, + transactions: [ + { + id: TRANSACTION_ID_MOCK, + params: { data: DATA_MOCK, to: TO_MOCK, value: VALUE_MOCK }, + signedTx: TRANSACTION_SIGNATURE_MOCK, + }, + { + id: TRANSACTION_ID_2_MOCK, + params: { data: DATA_MOCK, to: TO_MOCK, value: VALUE_MOCK }, + signedTx: TRANSACTION_SIGNATURE_2_MOCK, + }, + ], + }); + }); + + it('resolves individual publish hooks with transaction hashes from publish batch hook', async () => { + const publishBatchHook: jest.MockedFn = jest.fn(); + + addTransactionMock + .mockResolvedValueOnce({ + transactionMeta: { + ...TRANSACTION_META_MOCK, + id: TRANSACTION_ID_MOCK, + }, + result: Promise.resolve(''), + }) + .mockResolvedValueOnce({ + transactionMeta: { + ...TRANSACTION_META_MOCK, + id: TRANSACTION_ID_2_MOCK, + }, + result: Promise.resolve(''), + }); + + publishBatchHook.mockResolvedValue({ + results: [ + { + transactionHash: TRANSACTION_HASH_MOCK, + }, + { + transactionHash: TRANSACTION_HASH_2_MOCK, + }, + ], + }); + + addTransactionBatch({ + ...request, + publishBatchHook, + request: { ...request.request, useHook: true }, + }).catch(() => { + // Intentionally empty + }); + + await flushPromises(); + + const publishHooks = addTransactionMock.mock.calls.map( + ([, options]) => options.publishHook, + ); + + const publishHookPromise1 = publishHooks[0]?.( + TRANSACTION_META_MOCK, + TRANSACTION_SIGNATURE_MOCK, + ).catch(() => { + // Intentionally empty + }); + + const publishHookPromise2 = publishHooks[1]?.( + TRANSACTION_META_MOCK, + TRANSACTION_SIGNATURE_2_MOCK, + ).catch(() => { + // Intentionally empty + }); + + await flushPromises(); + + expect(await publishHookPromise1).toStrictEqual({ + transactionHash: TRANSACTION_HASH_MOCK, + }); + + expect(await publishHookPromise2).toStrictEqual({ + transactionHash: TRANSACTION_HASH_2_MOCK, + }); }); - generateEIP7702BatchTransactionMock.mockReturnValueOnce({ - to: TO_MOCK, - data: DATA_MOCK, - value: VALUE_MOCK, + it('handles existing transactions', async () => { + const publishBatchHook: jest.MockedFn = jest.fn(); + const onPublish = jest.fn(); + + addTransactionMock + .mockResolvedValueOnce({ + transactionMeta: { + ...TRANSACTION_META_MOCK, + id: TRANSACTION_ID_MOCK, + }, + result: Promise.resolve(''), + }) + .mockResolvedValueOnce({ + transactionMeta: { + ...TRANSACTION_META_MOCK, + id: TRANSACTION_ID_2_MOCK, + }, + result: Promise.resolve(''), + }); + + publishBatchHook.mockResolvedValue({ + results: [ + { + transactionHash: TRANSACTION_HASH_MOCK, + }, + { + transactionHash: TRANSACTION_HASH_2_MOCK, + }, + ], + }); + + addTransactionBatch({ + ...request, + publishBatchHook, + request: { + ...request.request, + transactions: [ + { + ...request.request.transactions[0], + existingTransaction: { + id: TRANSACTION_ID_2_MOCK, + onPublish, + signedTransaction: TRANSACTION_SIGNATURE_2_MOCK, + }, + }, + request.request.transactions[1], + ], + useHook: true, + }, + }).catch(() => { + // Intentionally empty + }); + + await flushPromises(); + + const publishHooks = addTransactionMock.mock.calls.map( + ([, options]) => options.publishHook, + ); + + publishHooks[0]?.( + TRANSACTION_META_MOCK, + TRANSACTION_SIGNATURE_MOCK, + ).catch(() => { + // Intentionally empty + }); + + await flushPromises(); + + expect(addTransactionMock).toHaveBeenCalledTimes(1); + + expect(publishBatchHook).toHaveBeenCalledTimes(1); + expect(publishBatchHook).toHaveBeenCalledWith({ + from: FROM_MOCK, + networkClientId: NETWORK_CLIENT_ID_MOCK, + transactions: [ + { + id: TRANSACTION_ID_2_MOCK, + params: { data: DATA_MOCK, to: TO_MOCK, value: VALUE_MOCK }, + signedTx: TRANSACTION_SIGNATURE_2_MOCK, + }, + { + id: TRANSACTION_ID_MOCK, + params: { data: DATA_MOCK, to: TO_MOCK, value: VALUE_MOCK }, + signedTx: TRANSACTION_SIGNATURE_MOCK, + }, + ], + }); + + expect(onPublish).toHaveBeenCalledTimes(1); + expect(onPublish).toHaveBeenCalledWith({ + transactionHash: TRANSACTION_HASH_MOCK, + }); }); - await addTransactionBatch(request); + it('adds batch ID to existing transaction', async () => { + const publishBatchHook: jest.MockedFn = jest.fn(); + const onPublish = jest.fn(); + const existingTransactionMock = {}; - expect(addTransactionMock).toHaveBeenCalledTimes(1); - expect(addTransactionMock).toHaveBeenCalledWith( - expect.any(Object), - expect.objectContaining({ - nestedTransactions: [ + addTransactionMock + .mockResolvedValueOnce({ + transactionMeta: { + ...TRANSACTION_META_MOCK, + id: TRANSACTION_ID_MOCK, + }, + result: Promise.resolve(''), + }) + .mockResolvedValueOnce({ + transactionMeta: { + ...TRANSACTION_META_MOCK, + id: TRANSACTION_ID_2_MOCK, + }, + result: Promise.resolve(''), + }); + + updateTransactionMock.mockImplementation((_id, update) => { + update(existingTransactionMock as TransactionMeta); + }); + + publishBatchHook.mockResolvedValue({ + results: [ { - to: TO_MOCK, - data: DATA_MOCK, - value: VALUE_MOCK, + transactionHash: TRANSACTION_HASH_MOCK, }, { - to: TO_MOCK, - data: DATA_MOCK, - value: VALUE_MOCK, + transactionHash: TRANSACTION_HASH_2_MOCK, }, ], - }), - ); - }); + }); - it('throws if chain not supported', async () => { - doesChainSupportEIP7702Mock.mockReturnValueOnce(false); + addTransactionBatch({ + ...request, + publishBatchHook, + request: { + ...request.request, + transactions: [ + { + ...request.request.transactions[0], + existingTransaction: { + id: TRANSACTION_ID_2_MOCK, + onPublish, + signedTransaction: TRANSACTION_SIGNATURE_2_MOCK, + }, + }, + request.request.transactions[1], + ], + useHook: true, + }, + }).catch(() => { + // Intentionally empty + }); - await expect(addTransactionBatch(request)).rejects.toThrow( - rpcErrors.internal('Chain does not support EIP-7702'), - ); - }); + await flushPromises(); - it('throws if account upgraded to unsupported contract', async () => { - doesChainSupportEIP7702Mock.mockReturnValueOnce(true); - isAccountUpgradedToEIP7702Mock.mockResolvedValueOnce({ - delegationAddress: CONTRACT_ADDRESS_MOCK, - isSupported: false, + const publishHooks = addTransactionMock.mock.calls.map( + ([, options]) => options.publishHook, + ); + + publishHooks[0]?.( + TRANSACTION_META_MOCK, + TRANSACTION_SIGNATURE_MOCK, + ).catch(() => { + // Intentionally empty + }); + + await flushPromises(); + + expect(updateTransactionMock).toHaveBeenCalledTimes(1); + expect(existingTransactionMock).toStrictEqual({ + batchId: expect.any(String), + }); }); - await expect(addTransactionBatch(request)).rejects.toThrow( - rpcErrors.internal('Account upgraded to unsupported contract'), - ); - }); + it('throws if publish batch hook does not return result', async () => { + const publishBatchHook: jest.MockedFn = jest.fn(); - it('throws if account not upgraded and no upgrade address', async () => { - doesChainSupportEIP7702Mock.mockReturnValueOnce(true); + addTransactionMock + .mockResolvedValueOnce({ + transactionMeta: { + ...TRANSACTION_META_MOCK, + id: TRANSACTION_ID_MOCK, + }, + result: Promise.resolve(''), + }) + .mockResolvedValueOnce({ + transactionMeta: { + ...TRANSACTION_META_MOCK, + id: TRANSACTION_ID_2_MOCK, + }, + result: Promise.resolve(''), + }); - isAccountUpgradedToEIP7702Mock.mockResolvedValueOnce({ - delegationAddress: undefined, - isSupported: false, + publishBatchHook.mockResolvedValue(undefined); + + const resultPromise = addTransactionBatch({ + ...request, + publishBatchHook, + request: { ...request.request, useHook: true }, + }); + + resultPromise.catch(() => { + // Intentionally empty + }); + + await flushPromises(); + + const publishHooks = addTransactionMock.mock.calls.map( + ([, options]) => options.publishHook, + ); + + publishHooks[0]?.( + TRANSACTION_META_MOCK, + TRANSACTION_SIGNATURE_MOCK, + ).catch(() => { + // Intentionally empty + }); + + publishHooks[1]?.( + TRANSACTION_META_MOCK, + TRANSACTION_SIGNATURE_2_MOCK, + ).catch(() => { + // Intentionally empty + }); + + await flushPromises(); + + await expect(resultPromise).rejects.toThrow( + 'Publish batch hook did not return a result', + ); }); - getEIP7702UpgradeContractAddressMock.mockReturnValueOnce(undefined); + it('throws if no publish batch hook', async () => { + await expect( + addTransactionBatch({ + ...request, + request: { ...request.request, useHook: true }, + }), + ).rejects.toThrow(rpcErrors.internal('No publish batch hook provided')); + }); - await expect(addTransactionBatch(request)).rejects.toThrow( - rpcErrors.internal('Upgrade contract address not found'), - ); + it('rejects individual publish hooks if batch hook throws', async () => { + const publishBatchHook: jest.MockedFn = jest.fn(); + + addTransactionMock + .mockResolvedValueOnce({ + transactionMeta: { + ...TRANSACTION_META_MOCK, + id: TRANSACTION_ID_MOCK, + }, + result: Promise.resolve(''), + }) + .mockResolvedValueOnce({ + transactionMeta: { + ...TRANSACTION_META_MOCK, + id: TRANSACTION_ID_2_MOCK, + }, + result: Promise.resolve(''), + }); + + publishBatchHook.mockImplementationOnce(() => { + throw new Error(ERROR_MESSAGE_MOCK); + }); + + addTransactionBatch({ + ...request, + publishBatchHook, + request: { ...request.request, useHook: true }, + }).catch(() => { + // Intentionally empty + }); + + await flushPromises(); + + const publishHooks = addTransactionMock.mock.calls.map( + ([, options]) => options.publishHook, + ); + + const publishHookPromise1 = publishHooks[0]?.( + TRANSACTION_META_MOCK, + TRANSACTION_SIGNATURE_MOCK, + ); + + publishHookPromise1?.catch(() => { + // Intentionally empty + }); + + const publishHookPromise2 = publishHooks[1]?.( + TRANSACTION_META_MOCK, + TRANSACTION_SIGNATURE_2_MOCK, + ); + + publishHookPromise2?.catch(() => { + // Intentionally empty + }); + + await flushPromises(); + + await expect(publishHookPromise1).rejects.toThrow(ERROR_MESSAGE_MOCK); + await expect(publishHookPromise2).rejects.toThrow(ERROR_MESSAGE_MOCK); + }); + + it('rejects individual publish hooks if add transaction throws', async () => { + const publishBatchHook: jest.MockedFn = jest.fn(); + + addTransactionMock + .mockResolvedValueOnce({ + transactionMeta: { + ...TRANSACTION_META_MOCK, + id: TRANSACTION_ID_MOCK, + }, + result: Promise.resolve(''), + }) + .mockImplementationOnce(() => { + throw new Error(ERROR_MESSAGE_MOCK); + }); + + addTransactionBatch({ + ...request, + publishBatchHook, + request: { ...request.request, useHook: true }, + }).catch(() => { + // Intentionally empty + }); + + const publishHooks = addTransactionMock.mock.calls.map( + ([, options]) => options.publishHook, + ); + + const publishHookPromise1 = publishHooks[0]?.( + TRANSACTION_META_MOCK, + TRANSACTION_SIGNATURE_MOCK, + ); + + publishHookPromise1?.catch(() => { + // Intentionally empty + }); + + await flushPromises(); + + await expect(publishHookPromise1).rejects.toThrow(ERROR_MESSAGE_MOCK); + }); }); it('validates request', async () => { diff --git a/packages/transaction-controller/src/utils/batch.ts b/packages/transaction-controller/src/utils/batch.ts index 1b6d1870145..b60ea630e60 100644 --- a/packages/transaction-controller/src/utils/batch.ts +++ b/packages/transaction-controller/src/utils/batch.ts @@ -2,6 +2,7 @@ import type EthQuery from '@metamask/eth-query'; import { rpcErrors } from '@metamask/rpc-errors'; import type { Hex } from '@metamask/utils'; import { createModuleLogger } from '@metamask/utils'; +import { v4 as uuid } from 'uuid'; import { doesChainSupportEIP7702, @@ -13,8 +14,20 @@ import { getEIP7702UpgradeContractAddress, } from './feature-flags'; import { validateBatchRequest } from './validation'; -import type { TransactionController, TransactionControllerMessenger } from '..'; +import type { + BatchTransactionParams, + TransactionController, + TransactionControllerMessenger, + TransactionMeta, +} from '..'; +import { CollectPublishHook } from '../hooks/CollectPublishHook'; import { projectLogger } from '../logger'; +import type { + PublishBatchHook, + PublishBatchHookTransaction, + PublishHook, + TransactionBatchSingleRequest, +} from '../types'; import { TransactionEnvelopeType, type TransactionBatchRequest, @@ -28,8 +41,14 @@ type AddTransactionBatchRequest = { getChainId: (networkClientId: string) => Hex; getEthQuery: (networkClientId: string) => EthQuery; getInternalAccounts: () => Hex[]; + getTransaction: (id: string) => TransactionMeta; messenger: TransactionControllerMessenger; + publishBatchHook?: PublishBatchHook; request: TransactionBatchRequest; + updateTransaction: ( + options: { transactionId: string }, + callback: (transactionMeta: TransactionMeta) => void, + ) => void; }; type IsAtomicBatchSupportedRequest = { @@ -62,10 +81,15 @@ export async function addTransactionBatch( request: userRequest, }); - const { from, networkClientId, requireApproval, transactions } = userRequest; + const { from, networkClientId, requireApproval, transactions, useHook } = + userRequest; log('Adding', userRequest); + if (useHook) { + return await addTransactionBatchWithHook(request); + } + const chainId = getChainId(networkClientId); const ethQuery = request.getEthQuery(networkClientId); const isChainSupported = doesChainSupportEIP7702(chainId, messenger); @@ -163,3 +187,179 @@ export async function isAtomicBatchSupported( return chainIds; } + +/** + * Process a batch transaction using a publish batch hook. + * + * @param request - The request object including the user request and necessary callbacks. + * @returns The batch result object including the batch ID. + */ +async function addTransactionBatchWithHook( + request: AddTransactionBatchRequest, +): Promise { + const { publishBatchHook, request: userRequest } = request; + + const { + from, + networkClientId, + transactions: nestedTransactions, + } = userRequest; + + log('Adding transaction batch using hook', userRequest); + + if (!publishBatchHook) { + log('No publish batch hook provided'); + throw new Error('No publish batch hook provided'); + } + + const batchId = uuid(); + const transactionCount = nestedTransactions.length; + const collectHook = new CollectPublishHook(transactionCount); + const publishHook = collectHook.getHook(); + const hookTransactions: Omit[] = []; + + try { + for (const nestedTransaction of nestedTransactions) { + const hookTransaction = await processTransactionWithHook( + batchId, + nestedTransaction, + publishHook, + request, + ); + + hookTransactions.push(hookTransaction); + } + + const { signedTransactions } = await collectHook.ready(); + + const transactions = hookTransactions.map((transaction, index) => ({ + ...transaction, + signedTx: signedTransactions[index], + })); + + log('Calling publish batch hook', { from, networkClientId, transactions }); + + const result = await publishBatchHook({ + from, + networkClientId, + transactions, + }); + + log('Publish batch hook result', result); + + if (!result) { + throw new Error('Publish batch hook did not return a result'); + } + + const transactionHashes = result.results.map( + ({ transactionHash }) => transactionHash, + ); + + collectHook.success(transactionHashes); + + log('Completed batch transaction with hook', transactionHashes); + + return { + batchId, + }; + } catch (error) { + log('Publish batch hook failed', error); + + collectHook.error(error); + + throw error; + } +} + +/** + * Process a single transaction with a publish batch hook. + * + * @param batchId - ID of the transaction batch. + * @param nestedTransaction - The nested transaction request. + * @param publishHook - The publish hook to use for each transaction. + * @param request - The request object including the user request and necessary callbacks. + * @returns The single transaction request to be processed by the publish batch hook. + */ +async function processTransactionWithHook( + batchId: string, + nestedTransaction: TransactionBatchSingleRequest, + publishHook: PublishHook, + request: AddTransactionBatchRequest, +) { + const { existingTransaction, params } = nestedTransaction; + + const { + addTransaction, + getTransaction, + request: userRequest, + updateTransaction, + } = request; + + const { from, networkClientId } = userRequest; + + if (existingTransaction) { + const { id, onPublish, signedTransaction } = existingTransaction; + const transactionMeta = getTransaction(id); + + const data = params.data as Hex | undefined; + const to = params.to as Hex | undefined; + const value = params.value as Hex | undefined; + + const existingParams: BatchTransactionParams = { + data, + to, + value, + }; + + updateTransaction({ transactionId: id }, (_transactionMeta) => { + _transactionMeta.batchId = batchId; + }); + + publishHook(transactionMeta, signedTransaction) + .then(onPublish) + .catch(() => { + // Intentionally empty + }); + + log('Processed existing transaction with hook', { + id, + params: existingParams, + }); + + return { + id, + params: existingParams, + }; + } + + const { transactionMeta } = await addTransaction( + { + ...params, + from, + }, + { + batchId, + networkClientId, + publishHook, + requireApproval: false, + }, + ); + + const { id, txParams } = transactionMeta; + const data = txParams.data as Hex | undefined; + const to = txParams.to as Hex | undefined; + const value = txParams.value as Hex | undefined; + + const newParams: BatchTransactionParams = { + data, + to, + value, + }; + + log('Processed new transaction with hook', { id, params: newParams }); + + return { + id, + params: newParams, + }; +} diff --git a/packages/transaction-controller/src/utils/nonce.test.ts b/packages/transaction-controller/src/utils/nonce.test.ts index b15cafe3427..d4a2dc4405f 100644 --- a/packages/transaction-controller/src/utils/nonce.test.ts +++ b/packages/transaction-controller/src/utils/nonce.test.ts @@ -168,7 +168,7 @@ describe('nonce', () => { const result = getAndFormatTransactionsForNonceTracker( '0x2', fromAddress, - TransactionStatus.confirmed, + [TransactionStatus.confirmed], inputTransactions, ); diff --git a/packages/transaction-controller/src/utils/nonce.ts b/packages/transaction-controller/src/utils/nonce.ts index b95a73a1682..bda28bf22f7 100644 --- a/packages/transaction-controller/src/utils/nonce.ts +++ b/packages/transaction-controller/src/utils/nonce.ts @@ -51,14 +51,14 @@ export async function getNextNonce( * * @param currentChainId - Chain ID of the current network. * @param fromAddress - Address of the account from which the transactions to filter from are sent. - * @param transactionStatus - Status of the transactions for which to filter. + * @param transactionStatuses - Status of the transactions for which to filter. * @param transactions - Array of transactionMeta objects that have been prefiltered. * @returns Array of transactions formatted for the nonce tracker. */ export function getAndFormatTransactionsForNonceTracker( currentChainId: string, fromAddress: string, - transactionStatus: TransactionStatus, + transactionStatuses: TransactionStatus[], transactions: TransactionMeta[], ): NonceTrackerTransaction[] { return transactions @@ -67,7 +67,7 @@ export function getAndFormatTransactionsForNonceTracker( !isTransfer && !isUserOperation && chainId === currentChainId && - status === transactionStatus && + transactionStatuses.includes(status) && from.toLowerCase() === fromAddress.toLowerCase(), ) .map(({ status, txParams: { from, gas, value, nonce } }) => {