diff --git a/packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts b/packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts index 03d4bb312..77a6feba7 100644 --- a/packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts +++ b/packages/safe-ds-lang/src/language/typing/safe-ds-type-computer.ts @@ -125,11 +125,11 @@ export class SafeDsTypeComputer { private readonly typeChecker: SafeDsTypeChecker; /** - * Contains all lambda parameters that are currently being computed. When computing the types of lambda parameters, - * they must only access the type of the containing lambda, if they are not contained in this set themselves. - * Otherwise, this would cause endless recursion. + * Contains all calls for which we currently compute substitutions. This prevents endless recursion, since the + * substitutions of a call depend on the inferred types of their arguments, which may be lambdas. The inferred type + * of a lambda in turn depends on the substitutions of the call it is passed to. */ - private readonly incompleteLambdaParameters = new Set(); + private readonly incompleteCalls = new Set(); private readonly nodeTypeCache: WorkspaceCache; constructor(services: SafeDsServices) { @@ -301,18 +301,21 @@ export class SafeDsTypeComputer { // Lambda passed as argument if (isSdsArgument(containerOfLambda)) { - // Lookup parameter type in lambda unless the lambda is being computed. These contain the correct - // substitutions for type parameters. - if (!this.incompleteLambdaParameters.has(node)) { - return this.computeType(containingCallable); - } - const parameter = this.nodeMapper.argumentToParameter(containerOfLambda); if (!parameter) { return UnknownType; } - return this.computeType(parameter); + let result = this.computeType(parameter); + + // Substitute type parameters + const call = AstUtils.getContainerOfType(containerOfLambda, isSdsCall); + if (call) { + const substitutions = this.computeSubstitutionsForCall(call, containerOfLambda.$containerIndex); + result = result.substituteTypeParameters(substitutions); + } + + return result; } // Lambda passed as default value @@ -569,29 +572,16 @@ export class SafeDsTypeComputer { } private computeTypeOfLambda(node: SdsLambda): Type { - // Remember lambda parameters const parameters = getParameters(node); - parameters.forEach((it) => { - this.incompleteLambdaParameters.add(it); - }); - const parameterEntries = parameters.map((it) => new NamedTupleEntry(it, it.name, this.computeType(it))); const resultEntries = this.buildLambdaResultEntries(node); - const unsubstitutedType = this.factory.createCallableType( + return this.factory.createCallableType( node, undefined, this.factory.createNamedTupleType(...parameterEntries), this.factory.createNamedTupleType(...resultEntries), ); - const substitutions = this.computeSubstitutionsForLambda(node, unsubstitutedType); - - // Forget lambda parameters - parameters.forEach((it) => { - this.incompleteLambdaParameters.delete(it); - }); - - return unsubstitutedType.substituteTypeParameters(substitutions); } private buildLambdaResultEntries(node: SdsLambda): NamedTupleEntry[] { @@ -843,16 +833,17 @@ export class SafeDsTypeComputer { /** * Computes substitutions for the type parameters of a callable in the context of a call. * - * @param node The call to compute substitutions for. + * @param node + * The call to compute substitutions for. + * @param argumentEndIndex + * The index of the first argument that should not be considered for the computation. If not specified, all + * arguments are considered. + * * @returns The computed substitutions for the type parameters of the callable. */ - computeSubstitutionsForCall(node: SdsAbstractCall): TypeParameterSubstitutions { - return this.doComputeSubstitutionsForCall(node); - } - - private doComputeSubstitutionsForCall( + computeSubstitutionsForCall( node: SdsAbstractCall, - precomputedArgumentTypes?: Map, + argumentEndIndex: number | undefined = undefined, ): TypeParameterSubstitutions { // Compute substitutions for member access const substitutionsFromReceiver = @@ -860,6 +851,14 @@ export class SafeDsTypeComputer { ? this.computeSubstitutionsForMemberAccess(node.receiver) : NO_SUBSTITUTIONS; + // Check if the call is already being computed + if (this.incompleteCalls.has(node)) { + return substitutionsFromReceiver; + } + + // Remember call + this.incompleteCalls.add(node); + // Compute substitutions for arguments const callable = this.nodeMapper.callToCallable(node); const typeParameters = getTypeParameters(callable); @@ -868,17 +867,12 @@ export class SafeDsTypeComputer { } const parameters = getParameters(callable); - const args = getArguments(node); + const args = getArguments(node).slice(0, argumentEndIndex); const parametersToArguments = this.nodeMapper.parametersToArguments(parameters, args); const parameterTypesToArgumentTypes: [Type, Type][] = parameters.map((parameter) => { const argument = parametersToArguments.get(parameter); - return [ - this.computeType(parameter.type), - // Use precomputed argument types (lambdas) if available. This prevents infinite recursion. - precomputedArgumentTypes?.get(argument?.value) ?? - this.computeType(argument?.value ?? parameter.defaultValue), - ]; + return [this.computeType(parameter.type), this.computeType(argument?.value ?? parameter.defaultValue)]; }); const substitutionsFromArguments = this.computeSubstitutionsForArguments( @@ -886,6 +880,9 @@ export class SafeDsTypeComputer { parameterTypesToArgumentTypes, ); + // Forget call + this.incompleteCalls.delete(node); + return new Map([...substitutionsFromReceiver, ...substitutionsFromArguments]); } @@ -918,22 +915,6 @@ export class SafeDsTypeComputer { return this.computeSubstitutionsForArguments(ownTypeParameters, ownTypesToOverriddenTypes); } - private computeSubstitutionsForLambda(node: SdsLambda, unsubstitutedType: Type): TypeParameterSubstitutions { - const containerOfLambda = node.$container; - if (!isSdsArgument(containerOfLambda)) { - return NO_SUBSTITUTIONS; - } - - const containingCall = AstUtils.getContainerOfType(containerOfLambda, isSdsCall); - if (!containingCall) { - /* c8 ignore next 2 */ - return NO_SUBSTITUTIONS; - } - - const precomputedArgumentTypes = new Map([[node, unsubstitutedType]]); - return this.doComputeSubstitutionsForCall(containingCall, precomputedArgumentTypes); - } - private computeSubstitutionsForMemberAccess(node: SdsMemberAccess): TypeParameterSubstitutions { const receiverType = this.computeType(node.receiver); if (receiverType instanceof ClassType) { diff --git a/packages/safe-ds-lang/src/language/validation/types.ts b/packages/safe-ds-lang/src/language/validation/types.ts index caf9ee2a0..8aa31ed21 100644 --- a/packages/safe-ds-lang/src/language/validation/types.ts +++ b/packages/safe-ds-lang/src/language/validation/types.ts @@ -51,7 +51,7 @@ export const argumentTypesMustMatchParameterTypes = (services: SafeDsServices) = return; } - const argumentType = typeComputer.computeType(argument).substituteTypeParameters(substitutions); + const argumentType = typeComputer.computeType(argument); const parameterType = typeComputer.computeType(parameter).substituteTypeParameters(substitutions); if (!typeChecker.isSubtypeOf(argumentType, parameterType, { ignoreParameterNames: true })) { diff --git a/packages/safe-ds-lang/tests/helpers/nodeFinder.test.ts b/packages/safe-ds-lang/tests/helpers/nodeFinder.test.ts index 6b646a6ea..6208ae3b1 100644 --- a/packages/safe-ds-lang/tests/helpers/nodeFinder.test.ts +++ b/packages/safe-ds-lang/tests/helpers/nodeFinder.test.ts @@ -65,14 +65,14 @@ describe('getNodeOfType', async () => { it('should throw if no node is found', async () => { const code = ''; - expect(async () => { + await expect(async () => { await getNodeOfType(services, code, isSdsClass); }).rejects.toThrowErrorMatchingSnapshot(); }); it('should throw if not enough nodes are found', async () => { const code = `class C`; - expect(async () => { + await expect(async () => { await getNodeOfType(services, code, isSdsClass, 1); }).rejects.toThrowErrorMatchingSnapshot(); }); diff --git a/packages/safe-ds-lang/tests/resources/typing/declarations/parameters/of block lambdas/that are passed as arguments/with type parameters.sdsdev b/packages/safe-ds-lang/tests/resources/typing/declarations/parameters/of block lambdas/that are passed as arguments/with type parameters.sdsdev index 337e11215..4a0511290 100644 --- a/packages/safe-ds-lang/tests/resources/typing/declarations/parameters/of block lambdas/that are passed as arguments/with type parameters.sdsdev +++ b/packages/safe-ds-lang/tests/resources/typing/declarations/parameters/of block lambdas/that are passed as arguments/with type parameters.sdsdev @@ -22,6 +22,6 @@ segment mySegment() { // $TEST$ serialization literal<1> myFunction(1, (»p«) {}); - // $TEST$ serialization literal<""> + // $TEST$ serialization Nothing myFunction2((»p«) -> ""); } diff --git a/packages/safe-ds-lang/tests/resources/typing/declarations/parameters/of expression lambdas/that are passed as arguments/with type parameters.sdsdev b/packages/safe-ds-lang/tests/resources/typing/declarations/parameters/of expression lambdas/that are passed as arguments/with type parameters.sdsdev index e1cd35145..6c87f561b 100644 --- a/packages/safe-ds-lang/tests/resources/typing/declarations/parameters/of expression lambdas/that are passed as arguments/with type parameters.sdsdev +++ b/packages/safe-ds-lang/tests/resources/typing/declarations/parameters/of expression lambdas/that are passed as arguments/with type parameters.sdsdev @@ -22,6 +22,6 @@ segment mySegment() { // $TEST$ serialization literal<1> myFunction(1, (»p«) -> ""); - // $TEST$ serialization literal<""> + // $TEST$ serialization Nothing myFunction2((»p«) -> ""); } diff --git a/packages/safe-ds-lang/tests/resources/typing/expressions/block lambdas/that are passed as arguments/with type parameters.sdsdev b/packages/safe-ds-lang/tests/resources/typing/expressions/block lambdas/that are passed as arguments/with type parameters.sdsdev index f3e8b27b2..631b02920 100644 --- a/packages/safe-ds-lang/tests/resources/typing/expressions/block lambdas/that are passed as arguments/with type parameters.sdsdev +++ b/packages/safe-ds-lang/tests/resources/typing/expressions/block lambdas/that are passed as arguments/with type parameters.sdsdev @@ -8,7 +8,9 @@ class MyClass(param: T) sub MySuperclass { @Pure fun myMethod(callback: (p: T) -> ()) } -@Pure fun myFunction(p: T, id: (p: T) -> (r: T)) +@Pure fun myFunction1(p: T, id: (p: T) -> (r: T)) +@Pure fun myFunction2(id: (p: T) -> (r: T)) +@Pure fun myFunction3(producer: () -> (r: T), consumer: (p: T) -> ()) segment mySegment() { // $TEST$ serialization (p: literal<1>) -> (r: literal<1>) @@ -22,7 +24,19 @@ segment mySegment() { }«); // $TEST$ serialization (p: literal<1>) -> (r: literal<1>) - myFunction(1, »(p) { + myFunction1(1, »(p) { yield r = p; }«); + + // $TEST$ serialization (p: Nothing) -> (r: literal<1>) + myFunction2(»(p) { + yield r = 1; + }«); + + // $TEST$ serialization () -> (r: literal<1>) + // $TEST$ serialization (p: literal<1>) -> () + myFunction3( + »() { yield r = 1; }«, + »(p) {}«, + ); } diff --git a/packages/safe-ds-lang/tests/resources/typing/expressions/expression lambdas/that are passed as arguments/with type parameters.sdsdev b/packages/safe-ds-lang/tests/resources/typing/expressions/expression lambdas/that are passed as arguments/with type parameters.sdsdev index 4cebf017e..24e7148fb 100644 --- a/packages/safe-ds-lang/tests/resources/typing/expressions/expression lambdas/that are passed as arguments/with type parameters.sdsdev +++ b/packages/safe-ds-lang/tests/resources/typing/expressions/expression lambdas/that are passed as arguments/with type parameters.sdsdev @@ -8,15 +8,16 @@ class MyClass(param: T) sub MySuperclass { @Pure fun myMethod(callback: (p: T) -> ()) } -@Pure fun myFunction(p: T, id: (p: T) -> (r: T)) -segment mySegment() { - // $TEST$ serialization (p: literal<1>) -> (result: literal<1>) - MyClass(1).myMethod(»(p) -> p«); - - // $TEST$ serialization (p: literal<1>) -> (result: literal<1>) - MyClass(1).myInheritedMethod(»(p) -> p«); +@Pure fun myFunction1(p: T, id: (p: T) -> (r: T)) +@Pure fun myFunction2(id: (p: T) -> (r: T)) +@Pure fun myFunction3(producer: () -> (r: T), consumer: (p: T) -> ()) +segment mySegment() { + // $TEST$ serialization () -> (result: literal<1>) // $TEST$ serialization (p: literal<1>) -> (result: literal<1>) - myFunction(1, »(p) -> p«); + myFunction3( + »() -> 1«, + »(p) -> 1«, + ); } diff --git a/packages/safe-ds-lang/tests/resources/validation/types/checking/arguments/with type parameters.sdsdev b/packages/safe-ds-lang/tests/resources/validation/types/checking/arguments/with type parameters.sdsdev index c1bb33da0..38b4d211b 100644 --- a/packages/safe-ds-lang/tests/resources/validation/types/checking/arguments/with type parameters.sdsdev +++ b/packages/safe-ds-lang/tests/resources/validation/types/checking/arguments/with type parameters.sdsdev @@ -64,7 +64,7 @@ segment mySegment( // $TEST$ no error r"Expected type .* but got .*\." f(»(p) -> p«); - // $TEST$ no error r"Expected type .* but got .*\." + // $TEST$ error "Expected type '(p: literal<1>) -> (r: literal<1>)' but got '(p: Nothing) -> (result: literal<1>)'." f(»(p) -> 1«); // $TEST$ no error r"Expected type .* but got .*\."