Skip to content

Commit

Permalink
Merge pull request #872 from MatrixAI/feature-agent-stop-termination
Browse files Browse the repository at this point in the history
Adding cancellation to background handlers to prevent agent from being held open
  • Loading branch information
aryanjassal authored Feb 11, 2025
2 parents 7dcd635 + 584af7f commit 5b16a53
Show file tree
Hide file tree
Showing 11 changed files with 255 additions and 80 deletions.
38 changes: 25 additions & 13 deletions src/discovery/Discovery.ts
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class Discovery {
this.dispatchEvent(
new discoveryEvents.EventDiscoveryVertexProcessed({
detail: {
vertex,
vertex: vertex,
parent: parent ?? undefined,
},
}),
Expand Down Expand Up @@ -190,7 +190,7 @@ class Discovery {
this.dispatchEvent(
new discoveryEvents.EventDiscoveryVertexFailed({
detail: {
vertex,
vertex: vertex,
parent: parent ?? undefined,
message: e.message,
code: e.code,
Expand All @@ -206,9 +206,13 @@ class Discovery {
/**
* This handler is run periodically to check if nodes are ready to be rediscovered
*/
protected checkRediscoveryHandler: TaskHandler = async () => {
protected checkRediscoveryHandler: TaskHandler = async (
ctx: ContextTimed,
) => {
await this.checkRediscovery(
Date.now() - this.rediscoverVertexThresholdTime,
undefined,
ctx,
);
await this.taskManager.scheduleTask({
handlerId: this.checkRediscoveryHandlerId,
Expand Down Expand Up @@ -407,18 +411,18 @@ class Discovery {
const [type, id] = vertexId;
switch (type) {
case 'node':
return await this.processNode(id, ctx, lastProcessedCutoffTime);
return await this.processNode(id, lastProcessedCutoffTime, ctx);
case 'identity':
return await this.processIdentity(id, ctx, lastProcessedCutoffTime);
return await this.processIdentity(id, lastProcessedCutoffTime, ctx);
default:
never(`type must be either "node" or "identity" got "${type}"`);
}
}

protected async processNode(
nodeId: NodeId,
lastProcessedCutoffTime: number | undefined,
ctx: ContextTimed,
lastProcessedCutoffTime?: number,
) {
// If the vertex we've found is our own node, we simply get our own chain
const processedTime = Date.now();
Expand Down Expand Up @@ -456,7 +460,6 @@ class Discovery {
}
// Iterate over each of the claims in the chain (already verified).
for (const signedClaim of Object.values(vertexChainData)) {
if (ctx.signal.aborted) throw ctx.signal.reason;
switch (signedClaim.payload.typ) {
case 'ClaimLinkNode':
await this.processClaimLinkNode(
Expand All @@ -469,8 +472,8 @@ class Discovery {
await this.processClaimLinkIdentity(
signedClaim as SignedClaim<ClaimLinkIdentity>,
nodeId,
ctx,
lastProcessedCutoffTime,
ctx,
);
break;
default:
Expand Down Expand Up @@ -553,8 +556,8 @@ class Discovery {
protected async processClaimLinkIdentity(
signedClaim: SignedClaim<ClaimLinkIdentity>,
nodeId: NodeId,
ctx: ContextTimed,
lastProcessedCutoffTime = Date.now() - this.rediscoverSkipTime,
ctx: ContextTimed,
): Promise<void> {
// Checking the claim is valid
const publicKey = keysUtils.publicKeyFromNodeId(nodeId);
Expand Down Expand Up @@ -655,8 +658,8 @@ class Discovery {

protected async processIdentity(
id: ProviderIdentityId,
ctx: ContextTimed,
lastProcessedCutoffTime = Date.now() - this.rediscoverSkipTime,
ctx: ContextTimed,
) {
// If the next vertex is an identity, perform a social discovery
// Firstly get the identity info of this identity
Expand Down Expand Up @@ -789,7 +792,7 @@ class Discovery {
parent?: GestaltId,
ignoreActive: boolean = false,
tran?: DBTransaction,
) {
): Promise<void> {
if (tran == null) {
return this.db.withTransactionF((tran) =>
this.scheduleDiscoveryForVertex(
Expand Down Expand Up @@ -852,7 +855,7 @@ class Discovery {
],
lazy: true,
deadline: this.discoverVertexTimeoutTime,
delay,
delay: delay,
},
tran,
);
Expand Down Expand Up @@ -1034,10 +1037,17 @@ class Discovery {
public async checkRediscovery(
lastProcessedCutoffTime: number,
tran?: DBTransaction,
ctx?: Partial<ContextTimedInput>,
): Promise<void>;
@timedCancellable(true)
public async checkRediscovery(
lastProcessedCutoffTime: number,
tran: DBTransaction | undefined,
@context ctx: ContextTimed,
): Promise<void> {
if (tran == null) {
return this.db.withTransactionF((tran) =>
this.checkRediscovery(lastProcessedCutoffTime, tran),
this.checkRediscovery(lastProcessedCutoffTime, tran, ctx),
);
}

Expand All @@ -1055,6 +1065,7 @@ class Discovery {
},
tran,
)) {
ctx.signal.throwIfAborted();
gestaltIds.push([
gestaltsUtils.encodeGestaltId(gestaltId),
lastProcessedTime,
Expand Down Expand Up @@ -1091,6 +1102,7 @@ class Discovery {
[this.constructor.name, this.discoverVertexHandlerId, gestaltIdEncoded],
tran,
)) {
ctx.signal.throwIfAborted();
if (taskExisting == null) {
taskExisting = task;
continue;
Expand Down
100 changes: 84 additions & 16 deletions src/nodes/NodeConnectionManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ import {
status,
} from '@matrixai/async-init/dist/StartStop';
import { AbstractEvent, EventAll } from '@matrixai/events';
import { context, timedCancellable } from '@matrixai/contexts/dist/decorators';
import {
context,
timed,
timedCancellable,
} from '@matrixai/contexts/dist/decorators';
import { Semaphore } from '@matrixai/async-locks';
import { PromiseCancellable } from '@matrixai/async-cancellable';
import NodeConnection from './NodeConnection';
Expand Down Expand Up @@ -768,13 +772,15 @@ class NodeConnectionManager {
* itself is such that we can pass targetNodeId as a parameter (as opposed to
* an acquire function with no parameters).
* @param targetNodeId Id of target node to communicate with
* @param ctx Timed context, forwarded to `isAuthenticatedP` while waiting for the connection to authenticate
* @returns ResourceAcquire Resource API for use in with contexts
*/
public acquireConnection(
targetNodeId: NodeId,
ctx: ContextTimed,
): ResourceAcquire<NodeConnection> {
return async () => {
await this.isAuthenticatedP(targetNodeId);
await this.isAuthenticatedP(targetNodeId, ctx);
return await this.acquireConnectionInternal(targetNodeId)();
};
}
Expand All @@ -785,14 +791,22 @@ class NodeConnectionManager {
* doesn't exist.
* For use with a normal arrow function.
* @param targetNodeId Id of target node to communicate with
* @param ctx Timed context, forwarded to `acquireConnection` when acquiring the connection
* @param f Function to handle communication
*/
public async withConnF<T>(
targetNodeId: NodeId,
ctx: Partial<ContextTimedInput> | undefined,
f: (conn: NodeConnection) => Promise<T>,
): Promise<T>;
@timedCancellable(true)
public async withConnF<T>(
targetNodeId: NodeId,
@context ctx: ContextTimed,
f: (conn: NodeConnection) => Promise<T>,
): Promise<T> {
return await withF(
[this.acquireConnection(targetNodeId)],
[this.acquireConnection(targetNodeId, ctx)],
async ([conn]) => {
return await f(conn);
},
Expand All @@ -805,14 +819,22 @@ class NodeConnectionManager {
* doesn't exist.
* for use with a generator function
* @param targetNodeId Id of target node to communicate with
* @param ctx Timed context, forwarded to `acquireConnection` when acquiring the connection
* @param g Generator function to handle communication
*/
public withConnG<T, TReturn, TNext>(
targetNodeId: NodeId,
ctx: Partial<ContextTimedInput> | undefined,
g: (conn: NodeConnection) => AsyncGenerator<T, TReturn, TNext>,
): AsyncGenerator<T, TReturn, TNext>;
@ready(new nodesErrors.ErrorNodeConnectionManagerNotRunning())
@timed()
public async *withConnG<T, TReturn, TNext>(
targetNodeId: NodeId,
@context ctx: ContextTimed,
g: (conn: NodeConnection) => AsyncGenerator<T, TReturn, TNext>,
): AsyncGenerator<T, TReturn, TNext> {
const acquire = this.acquireConnection(targetNodeId);
const acquire = this.acquireConnection(targetNodeId, ctx);
const [release, conn] = await acquire();
let caughtError: Error | undefined;
try {
Expand Down Expand Up @@ -975,6 +997,7 @@ class NodeConnectionManager {
}
const { host, port } = await this.withConnF(
nodeIdSignaller,
ctx,
async (conn) => {
const client = conn.getClient();
const nodeIdSource = this.keyRing.getNodeId();
Expand Down Expand Up @@ -1440,8 +1463,27 @@ class NodeConnectionManager {
* @param targetNodeId - NodeId of the node that needs to initiate hole punching.
* @param address - Address the target needs to punch to.
* @param requestSignature - `base64url` encoded signature
* @param ctx Timed context, forwarded to `withConnF` when relaying the signal to the target node
*/
public async handleNodesConnectionSignalInitial(
sourceNodeId: NodeId,
targetNodeId: NodeId,
address: {
host: Host;
port: Port;
},
requestSignature: string,
ctx?: Partial<ContextTimedInput>,
): Promise<{
host: Host;
port: Port;
}>;
@ready(new nodesErrors.ErrorNodeManagerNotRunning())
@timedCancellable(
true,
(nodeConnectionManager: NodeConnectionManager) =>
nodeConnectionManager.connectionConnectTimeoutTime,
)
public async handleNodesConnectionSignalInitial(
sourceNodeId: NodeId,
targetNodeId: NodeId,
Expand All @@ -1450,6 +1492,7 @@ class NodeConnectionManager {
port: Port;
},
requestSignature: string,
@context ctx: ContextTimed,
): Promise<{
host: Host;
port: Port;
Expand Down Expand Up @@ -1479,16 +1522,20 @@ class NodeConnectionManager {
this.keyRing.keyPair,
data,
);
const connectionSignalP = this.withConnF(targetNodeId, async (conn) => {
const client = conn.getClient();
await client.methods.nodesConnectionSignalFinal({
sourceNodeIdEncoded: nodesUtils.encodeNodeId(sourceNodeId),
targetNodeIdEncoded: nodesUtils.encodeNodeId(targetNodeId),
address,
requestSignature: requestSignature,
relaySignature: relaySignature.toString('base64url'),
});
})
const connectionSignalP = this.withConnF(
targetNodeId,
ctx,
async (conn) => {
const client = conn.getClient();
await client.methods.nodesConnectionSignalFinal({
sourceNodeIdEncoded: nodesUtils.encodeNodeId(sourceNodeId),
targetNodeIdEncoded: nodesUtils.encodeNodeId(targetNodeId),
address: address,
requestSignature: requestSignature,
relaySignature: relaySignature.toString('base64url'),
});
},
)
// Ignore results and failures, they are expected to happen and are allowed
.then(
() => {},
Expand Down Expand Up @@ -1745,19 +1792,40 @@ class NodeConnectionManager {
* Returns a promise that resolves once the connection has authenticated,
* otherwise it rejects with the authentication failure
* @param nodeId
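* @param ctx Timed context; throws immediately if already aborted, and rejects with `ctx.signal.reason` if the signal aborts while waiting for authentication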
*/
public async isAuthenticatedP(nodeId: NodeId): Promise<void> {
public async isAuthenticatedP(
nodeId: NodeId,
ctx?: Partial<ContextTimedInput>,
): Promise<void>;
@timedCancellable(
true,
(nodeConnectionManager: NodeConnectionManager) =>
nodeConnectionManager.connectionConnectTimeoutTime,
)
public async isAuthenticatedP(
nodeId: NodeId,
@context ctx: ContextTimed,
): Promise<void> {
ctx.signal.throwIfAborted();
const targetNodeIdString = nodeId.toString() as NodeIdString;
const connectionsEntry = this.connections.get(targetNodeIdString);
if (connectionsEntry == null) {
throw new nodesErrors.ErrorNodeConnectionManagerConnectionNotFound();
}
const { p: abortP, rejectP: rejectAbortP } = utils.promise<never>();
const abortHandler = () => {
rejectAbortP(ctx.signal.reason);
};
ctx.signal.addEventListener('abort', abortHandler, { once: true });
try {
return await connectionsEntry.authenticatedP;
return await Promise.race([connectionsEntry.authenticatedP, abortP]);
} catch (e) {
// Capture the stacktrace here since knowing where we're waiting for authentication is more useful
Error.captureStackTrace(e);
throw e;
} finally {
ctx.signal.removeEventListener('abort', abortHandler);
}
}

Expand Down
Loading

0 comments on commit 5b16a53

Please sign in to comment.