Skip to content

Commit

Permalink
feat: Allow explicit application of type arguments (#821)
Browse files Browse the repository at this point in the history
Closes #770
  • Loading branch information
mark-koch authored Feb 26, 2025
1 parent 3632ec6 commit 8f90c04
Show file tree
Hide file tree
Showing 22 changed files with 396 additions and 6 deletions.
31 changes: 26 additions & 5 deletions guppylang/checker/errors/type_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,19 @@ class MissingReturnValueError(Error):
ty: Type


@dataclass(frozen=True)
class TypeApplyNotGenericError(Error):
title: ClassVar[str] = "Not generic"
span_label: ClassVar[str] = (
"{thing} is not generic, so no type parameters can be provided"
)
func_name: str | None

@property
def thing(self) -> str:
return f"`{self.func_name}`" if self.func_name else "This function"


@dataclass(frozen=True)
class NotCallableError(Error):
title: ClassVar[str] = "Not callable"
Expand All @@ -166,28 +179,36 @@ class NotCallableError(Error):
@dataclass(frozen=True)
class WrongNumberOfArgsError(Error):
title: ClassVar[str] = "" # Custom implementation in `rendered_title`
span_label: ClassVar[str] = "Expected {expected} function arguments, got `{actual}`"
expected: int
actual: int
detailed: bool = True
is_type_apply: bool = False

@property
def rendered_title(self) -> str:
return (
"Not enough arguments"
f"Not enough {self.argument_kind}s"
if self.expected > self.actual
else "Too many arguments"
else f"Too many {self.argument_kind}s"
)

@property
def argument_kind(self) -> str:
return "type argument" if self.is_type_apply else "argument"

@property
def rendered_span_label(self) -> str:
if not self.detailed:
return f"Expected {self.expected}, got {self.actual}"
diff = self.expected - self.actual
if diff < 0:
msg = "Unexpected arguments" if diff < -1 else "Unexpected argument"
msg = f"Unexpected {self.argument_kind}"
if diff < -1:
msg += "s"
else:
msg = "Missing arguments" if diff > 1 else "Missing argument"
msg = f"Missing {self.argument_kind}"
if diff > 1:
msg += "s"
return f"{msg} (expected {self.expected}, got {self.actual})"

@dataclass(frozen=True)
Expand Down
39 changes: 39 additions & 0 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
ModuleMemberNotFoundError,
NonLinearInstantiateError,
NotCallableError,
TypeApplyNotGenericError,
TypeInferenceError,
TypeMismatchError,
UnaryOperatorNotDefinedError,
Expand Down Expand Up @@ -119,6 +120,7 @@
string_type,
)
from guppylang.tys.param import ConstParam, TypeParam
from guppylang.tys.parsing import arg_from_ast
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import (
ExistentialTypeVar,
Expand Down Expand Up @@ -603,6 +605,11 @@ def visit_Compare(self, node: ast.Compare) -> tuple[ast.expr, Type]:

def visit_Subscript(self, node: ast.Subscript) -> tuple[ast.expr, Type]:
node.value, ty = self.synthesize(node.value)
# Special case for subscripts on functions: Those are type applications
if isinstance(ty, FunctionType):
inst = check_type_apply(ty, node, self.ctx)
return instantiate_poly(node.value, ty, inst), ty.instantiate(inst)
# Otherwise, it's a regular __getitem__ subscript
item_expr, item_ty = self.synthesize(node.slice)
# Give the item a unique name so we can refer to it later in case we also want
# to compile a call to `__setitem__`
Expand Down Expand Up @@ -850,6 +857,38 @@ def try_coerce_to(
return None


def check_type_apply(ty: FunctionType, node: ast.Subscript, ctx: Context) -> Inst:
"""Checks a `f[T1, T2, ...]` type application of a generic function."""
func = node.value
arg_exprs = (
node.slice.elts
if isinstance(node.slice, ast.Tuple) and len(node.slice.elts) > 0
else [node.slice]
)
globals = ctx.globals

if not ty.parametrized:
func_name = globals[func.def_id].name if isinstance(func, GlobalName) else None
raise GuppyError(TypeApplyNotGenericError(node, func_name))

exp, act = len(ty.params), len(arg_exprs)
assert exp > 0
assert act > 0
if exp != act:
if exp < act:
span = Span(to_span(arg_exprs[exp]).start, to_span(arg_exprs[-1]).end)
else:
span = Span(to_span(arg_exprs[-1]).end, to_span(node).end)
err = WrongNumberOfArgsError(span, exp, act, detailed=True, is_type_apply=True)
err.add_sub_diagnostic(WrongNumberOfArgsError.SignatureHint(None, ty))
raise GuppyError(err)

return [
param.check_arg(arg_from_ast(arg_expr, globals, ctx.generic_params), arg_expr)
for arg_expr, param in zip(arg_exprs, ty.params, strict=True)
]


def check_num_args(
exp: int, act: int, node: AstNode, sig: FunctionType | None = None
) -> None:
Expand Down
3 changes: 2 additions & 1 deletion guppylang/tys/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from guppylang.cfg.builder import is_py_expression
from guppylang.checker.core import Context, Globals, Locals
from guppylang.checker.errors.generic import ExpectedError
from guppylang.checker.expr_checker import eval_py_expr
from guppylang.definition.common import Definition
from guppylang.definition.module import ModuleDef
from guppylang.definition.parameter import ParamDef
Expand Down Expand Up @@ -92,6 +91,8 @@ def arg_from_ast(

# Py-expressions can also be used to specify static numbers
if py_expr := is_py_expression(node):
from guppylang.checker.expr_checker import eval_py_expr

v = eval_py_expr(py_expr, Context(globals, Locals({}), {}))
if isinstance(v, int):
nat_ty = NumericType(NumericType.Kind.Nat)
Expand Down
8 changes: 8 additions & 0 deletions tests/error/poly_errors/arg_mismatch6.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Error: Type mismatch (at $FILE:17:13)
|
15 | @guppy(module)
16 | def main(x: float) -> None:
17 | foo[int](x)
| ^ Expected argument of type `int`, got `float`

Guppy compilation failed due to 1 previous error
20 changes: 20 additions & 0 deletions tests/error/poly_errors/arg_mismatch6.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule


module = GuppyModule("test")

T = guppy.type_var("T", module=module)


@guppy.declare(module)
def foo(x: T) -> None:
...


@guppy(module)
def main(x: float) -> None:
foo[int](x)


module.compile()
9 changes: 9 additions & 0 deletions tests/error/poly_errors/arg_mismatch7.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Error: Type mismatch (at $FILE:17:12)
|
15 | @guppy(module)
16 | def main(xs: array[int, 42]) -> None:
17 | foo[43](xs)
| ^^ Expected argument of type `array[int, 43]`, got `array[int,
| 42]`

Guppy compilation failed due to 1 previous error
20 changes: 20 additions & 0 deletions tests/error/poly_errors/arg_mismatch7.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.builtins import array

module = GuppyModule("test")

n = guppy.nat_var("n", module=module)


@guppy.declare(module)
def foo(x: array[int, n]) -> None:
...


@guppy(module)
def main(xs: array[int, 42]) -> None:
foo[43](xs)


module.compile()
8 changes: 8 additions & 0 deletions tests/error/poly_errors/arg_mismatch8.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Error: Type mismatch (at $FILE:18:6)
|
16 | def main(x: float) -> None:
17 | f = foo[int]
18 | f(x)
| ^ Expected argument of type `int`, got `float`

Guppy compilation failed due to 1 previous error
21 changes: 21 additions & 0 deletions tests/error/poly_errors/arg_mismatch8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule


module = GuppyModule("test")

T = guppy.type_var("T", module=module)


@guppy.declare(module)
def foo(x: T) -> None:
...


@guppy(module)
def main(x: float) -> None:
f = foo[int]
f(x)


module.compile()
8 changes: 8 additions & 0 deletions tests/error/poly_errors/arg_mismatch9.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Error: Type mismatch (at $FILE:18:23)
|
16 | @guppy(module)
17 | def main(x: int, y: float) -> None:
18 | foo[float, int](x, y)
| ^ Expected argument of type `int`, got `float`

Guppy compilation failed due to 1 previous error
21 changes: 21 additions & 0 deletions tests/error/poly_errors/arg_mismatch9.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.builtins import array

module = GuppyModule("test")

S = guppy.type_var("S", module=module)
T = guppy.type_var("T", module=module)


@guppy.declare(module)
def foo(x: S, y: T) -> None:
...


@guppy(module)
def main(x: int, y: float) -> None:
foo[float, int](x, y)


module.compile()
Empty file removed tests/error/poly_errors/define.err
Empty file.
10 changes: 10 additions & 0 deletions tests/error/poly_errors/type_apply_not_enough.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Error: Not enough type arguments (at $FILE:19:16)
|
17 | @guppy(module)
18 | def main() -> None:
19 | foo[int, int]
| ^ Missing type argument (expected 3, got 2)

Note: Function signature is `forall S, T, U. (S, T, U) -> None`

Guppy compilation failed due to 1 previous error
22 changes: 22 additions & 0 deletions tests/error/poly_errors/type_apply_not_enough.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.builtins import array

module = GuppyModule("test")

S = guppy.type_var("S", module=module)
T = guppy.type_var("T", module=module)
U = guppy.type_var("U", module=module)


@guppy.declare(module)
def foo(x: S, y: T, z: U) -> None:
...


@guppy(module)
def main() -> None:
foo[int, int]


module.compile()
8 changes: 8 additions & 0 deletions tests/error/poly_errors/type_apply_not_generic1.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Error: Not generic (at $FILE:19:4)
|
17 | @guppy(module)
18 | def main() -> None:
19 | foo[int](0)
| ^^^^^^^^ `foo` is not generic, so no type parameters can be provided

Guppy compilation failed due to 1 previous error
22 changes: 22 additions & 0 deletions tests/error/poly_errors/type_apply_not_generic1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.builtins import array

module = GuppyModule("test")

S = guppy.type_var("S", module=module)
T = guppy.type_var("T", module=module)
U = guppy.type_var("U", module=module)


@guppy.declare(module)
def foo(x: int) -> None:
...


@guppy(module)
def main() -> None:
foo[int](0)


module.compile()
9 changes: 9 additions & 0 deletions tests/error/poly_errors/type_apply_not_generic2.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Error: Not generic (at $FILE:20:4)
|
18 | def main() -> None:
19 | f = foo
20 | f[int](0)
| ^^^^^^ This function is not generic, so no type parameters can be
| provided

Guppy compilation failed due to 1 previous error
23 changes: 23 additions & 0 deletions tests/error/poly_errors/type_apply_not_generic2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.builtins import array

module = GuppyModule("test")

S = guppy.type_var("S", module=module)
T = guppy.type_var("T", module=module)
U = guppy.type_var("U", module=module)


@guppy.declare(module)
def foo(x: int) -> None:
...


@guppy(module)
def main() -> None:
f = foo
f[int](0)


module.compile()
10 changes: 10 additions & 0 deletions tests/error/poly_errors/type_apply_too_many.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Error: Too many type arguments (at $FILE:17:13)
|
15 | @guppy(module)
16 | def main() -> None:
17 | foo[int, float, bool]
| ^^^^^^^^^^^ Unexpected type arguments (expected 1, got 3)

Note: Function signature is `forall T. T -> None`

Guppy compilation failed due to 1 previous error
Loading

0 comments on commit 8f90c04

Please sign in to comment.