Skip to content

Commit

Permalink
feat: Type check lists and comprehensions (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch authored Jan 16, 2024
1 parent e1e4da3 commit e22fde7
Show file tree
Hide file tree
Showing 14 changed files with 347 additions and 27 deletions.
21 changes: 20 additions & 1 deletion guppy/cfg/bb.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing_extensions import Self

from guppy.ast_util import AstNode, name_nodes_in_ast
from guppy.nodes import NestedFunctionDef, PyExpr
from guppy.nodes import DesugaredListComp, NestedFunctionDef, PyExpr

if TYPE_CHECKING:
from guppy.cfg.cfg import BaseCFG
Expand Down Expand Up @@ -119,6 +119,25 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
for name in name_nodes_in_ast(node.target):
self.stats.assigned[name.id] = node

def visit_DesugaredListComp(self, node: DesugaredListComp) -> None:
# Names bound in the comprehension are only available inside, so we shouldn't
# update `self.stats` with assignments
inner_visitor = VariableVisitor(self.bb)
inner_stats = inner_visitor.stats

# The generators are evaluated left to right
for gen in node.generators:
inner_visitor.visit(gen.iter_assign)
inner_visitor.visit(gen.hasnext_assign)
inner_visitor.visit(gen.next_assign)
for cond in gen.ifs:
inner_visitor.visit(cond)
inner_visitor.visit(node.elt)

self.stats.used |= {
x: n for x, n in inner_stats.used.items() if x not in self.stats.assigned
}

def visit_PyExpr(self, node: PyExpr) -> None:
# Don't look into `py(...)` expressions
pass
Expand Down
4 changes: 2 additions & 2 deletions guppy/checker/cfg_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from guppy.ast_util import line_col
from guppy.cfg.bb import BB
from guppy.cfg.cfg import CFG, BaseCFG
from guppy.checker.core import Context, Globals, Variable
from guppy.checker.core import Context, Globals, Locals, Variable
from guppy.checker.expr_checker import ExprSynthesizer, to_bool
from guppy.checker.stmt_checker import StmtChecker
from guppy.error import GuppyError
Expand Down Expand Up @@ -127,7 +127,7 @@ def check_bb(
raise GuppyError(f"Variable `{x}` is not defined", use)

# Check the basic block
ctx = Context(globals, {v.name: v for v in inputs})
ctx = Context(globals, Locals({v.name: v for v in inputs}))
checked_stmts = StmtChecker(ctx, bb, return_ty).check_stmts(bb.statements)

# If we branch, we also have to check the branch predicate
Expand Down
44 changes: 42 additions & 2 deletions guppy/checker/core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import ast
import copy
import itertools
from abc import ABC, abstractmethod
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Any, NamedTuple

Expand Down Expand Up @@ -110,8 +113,45 @@ def __ior__(self, other: "Globals") -> "Globals": # noqa: PYI034
return self


# Local variable mapping
Locals = dict[str, Variable]
@dataclass
class Locals:
"""Scoped mapping from names to variables"""

vars: dict[str, Variable]
parent_scope: "Locals | None" = None

def __getitem__(self, item: str) -> Variable:
if item not in self.vars and self.parent_scope:
return self.parent_scope[item]

return self.vars[item]

def __setitem__(self, key: str, value: Variable) -> None:
self.vars[key] = value

def __iter__(self) -> Iterator[str]:
parent_iter = iter(self.parent_scope) if self.parent_scope else iter(())
return itertools.chain(iter(self.vars), parent_iter)

def __contains__(self, item: str) -> bool:
return (item in self.vars) or (
self.parent_scope is not None and item in self.parent_scope
)

def __copy__(self) -> "Locals":
# Make a copy of the var map so that mutating the copy doesn't
# mutate our variable mapping
return Locals(self.vars.copy(), copy.copy(self.parent_scope))

def keys(self) -> set[str]:
parent_keys = self.parent_scope.keys() if self.parent_scope else set()
return parent_keys | self.vars.keys()

def items(self) -> Iterable[tuple[str, Variable]]:
parent_items = (
iter(self.parent_scope.items()) if self.parent_scope else iter(())
)
return itertools.chain(self.vars.items(), parent_items)


class Context(NamedTuple):
Expand Down
Loading

0 comments on commit e22fde7

Please sign in to comment.