Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix some higher-order behavior of TypeGuard and TypeIs #719

Merged
merged 5 commits into from
Feb 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased

- Fix some higher-order behavior of `TypeGuard` and `TypeIs` (#719)
- Add support for `TypeIs` from PEP 742 (#718)
- More PEP 695 support: generic classes and functions. Scoping rules
are not yet fully implemented. (#703)
Expand Down
5 changes: 5 additions & 0 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1986,6 +1986,11 @@ def visit_FunctionDef(self, node: FunctionDefNode) -> Value:
expected_return = info.return_annotation | KnownValue(NotImplemented)
else:
expected_return = info.return_annotation
if isinstance(expected_return, AnnotatedValue):
expected_return, _ = unannotate_value(expected_return, TypeIsExtension)
expected_return, _ = unannotate_value(
expected_return, TypeGuardExtension
)

with self.asynq_checker.set_func_name(
node.name,
Expand Down
51 changes: 16 additions & 35 deletions pyanalyze/test_typeis.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,8 @@ def main(a: int):

@assert_passes()
def testTypeIsHigherOrder(self):
import collections.abc
from typing import Callable, TypeVar, Iterable, List
from typing_extensions import TypeIs
from pyanalyze.value import assert_is_value, GenericValue, AnyValue, AnySource
from typing_extensions import TypeIs, assert_type

T = TypeVar("T")
R = TypeVar("R")
Expand All @@ -143,13 +141,7 @@ def is_float(a: object) -> TypeIs[float]:
def capybara() -> None:
a: List[object] = ["a", 0, 0.0]
b = filter(is_float, a)
# TODO should be Iterable[float]
assert_is_value(
b,
GenericValue(
collections.abc.Iterable, [AnyValue(AnySource.generic_argument)]
),
)
assert_type(b, Iterable[float])

@assert_passes()
def testTypeIsMethod(self):
Expand Down Expand Up @@ -277,10 +269,8 @@ def main1(a: object) -> None:

@assert_passes()
def testTypeIsOverload(self):
import collections.abc
from typing import Callable, Iterable, Iterator, List, Optional, TypeVar
from typing_extensions import TypeIs, overload
from pyanalyze.value import assert_is_value, GenericValue, AnyValue, AnySource
from typing_extensions import TypeIs, overload, assert_type

T = TypeVar("T")
R = TypeVar("R")
Expand All @@ -302,21 +292,13 @@ def is_int_typeguard(a: object) -> TypeIs[int]:
def is_int_bool(a: object) -> bool:
return False

iter_any = GenericValue(
collections.abc.Iterator, [AnyValue(AnySource.generic_argument)]
)

def main(a: List[Optional[int]]) -> None:
bb = filter(lambda x: x is not None, a)
# TODO Iterator[Optional[int]]
assert_is_value(bb, iter_any)
# Also, if you replace 'bool' with 'Any' in the second overload, bb is Iterator[Any]
assert_type(bb, Iterator[Optional[int]])
cc = filter(is_int_typeguard, a)
# TODO Iterator[int]
assert_is_value(cc, iter_any)
assert_type(cc, Iterator[int])
dd = filter(is_int_bool, a)
# TODO Iterator[Optional[int]]
assert_is_value(dd, iter_any)
assert_type(dd, Iterator[Optional[int]])

@assert_passes()
def testTypeIsDecorated(self):
Expand Down Expand Up @@ -345,7 +327,7 @@ def is_float(self, a: object) -> TypeIs[float]:
return False

class D(C):
def is_float(self, a: object) -> bool: # TODO: incompatible_override
def is_float(self, a: object) -> bool: # E: incompatible_override
return False

@assert_passes()
Expand Down Expand Up @@ -576,10 +558,10 @@ def with_bool(o: object) -> bool:
return False

accepts_typeguard(with_typeguard)
accepts_typeguard(with_bool) # TODO error
accepts_typeguard(with_bool) # E: incompatible_argument

different_typeguard(with_typeguard) # TODO error
different_typeguard(with_bool) # TODO error
different_typeguard(with_typeguard) # E: incompatible_argument
different_typeguard(with_bool) # E: incompatible_argument

@assert_passes()
def testTypeIsAsGenericFunctionArg(self):
Expand All @@ -602,7 +584,7 @@ def with_bool(o: object) -> bool:

accepts_typeguard(with_bool_typeguard)
accepts_typeguard(with_str_typeguard)
accepts_typeguard(with_bool) # TODO error
accepts_typeguard(with_bool) # E: incompatible_argument

@assert_passes()
def testTypeIsAsOverloadedFunctionArg(self):
Expand Down Expand Up @@ -662,9 +644,9 @@ def with_typeguard_b(o: object) -> TypeIs[B]:
def with_typeguard_c(o: object) -> TypeIs[C]:
return False

accepts_typeguard(with_typeguard_a) # TODO error
accepts_typeguard(with_typeguard_a) # E: incompatible_argument
accepts_typeguard(with_typeguard_b)
accepts_typeguard(with_typeguard_c) # TODO error
accepts_typeguard(with_typeguard_c) # E: incompatible_argument

@assert_passes()
def testTypeIsWithIdentityGeneric(self):
Expand Down Expand Up @@ -786,7 +768,7 @@ def typeguard(x: object, y: str) -> TypeIs[str]: ...
@overload
def typeguard(x: object, y: int) -> TypeIs[int]: ...

def typeguard(x: object, y: Union[int, str]) -> Union[TypeIs[int], TypeIs[str]]:
def typeguard(x: object, y: Union[int, str]) -> bool:
return False

def capybara(x: object) -> None:
Expand All @@ -805,7 +787,7 @@ def capybara(x: object) -> None:
@assert_passes()
def testGenericAliasWithTypeIs(self):
from typing import Callable, List, TypeVar
from typing_extensions import TypeIs
from typing_extensions import TypeIs, assert_type

T = TypeVar("T")
A = Callable[[object], TypeIs[List[T]]]
Expand All @@ -817,8 +799,7 @@ def test(f: A[T]) -> T:
raise NotImplementedError

def capybara() -> None:
pass
# TODO: assert_type(test(foo), List[str])
assert_type(test(foo), str)

@assert_passes()
def testNoCrashOnDunderCallTypeIs(self):
Expand Down
67 changes: 63 additions & 4 deletions pyanalyze/value.py
Original file line number Diff line number Diff line change
Expand Up @@ -1853,6 +1853,12 @@ def substitute_typevars(self, typevars: TypeVarMap) -> "Extension":
def walk_values(self) -> Iterable[Value]:
return []

def can_assign(self, value: Value, ctx: CanAssignContext) -> CanAssign:
return {}

def can_be_assigned(self, value: Value, ctx: CanAssignContext) -> CanAssign:
return {}


@dataclass(frozen=True)
class CustomCheckExtension(Extension):
Expand All @@ -1868,6 +1874,12 @@ def substitute_typevars(self, typevars: TypeVarMap) -> "Extension":
def walk_values(self) -> Iterable[Value]:
yield from self.custom_check.walk_values()

def can_assign(self, value: Value, ctx: CanAssignContext) -> CanAssign:
return self.custom_check.can_assign(value, ctx)

def can_be_assigned(self, value: Value, ctx: CanAssignContext) -> CanAssign:
return self.custom_check.can_be_assigned(value, ctx)


@dataclass(frozen=True)
class ParameterTypeGuardExtension(Extension):
Expand Down Expand Up @@ -1928,6 +1940,26 @@ def substitute_typevars(self, typevars: TypeVarMap) -> Extension:
def walk_values(self) -> Iterable[Value]:
yield from self.guarded_type.walk_values()

def can_assign(self, value: Value, ctx: CanAssignContext) -> CanAssign:
can_assign_maps = []
if isinstance(value, AnnotatedValue):
for ext in value.get_metadata_of_type(Extension):
if isinstance(ext, TypeIsExtension):
return CanAssignError("TypeGuard is not compatible with TypeIs")
elif isinstance(ext, TypeGuardExtension):
# TypeGuard is covariant
left_can_assign = self.guarded_type.can_assign(
ext.guarded_type, ctx
)
if isinstance(left_can_assign, CanAssignError):
return CanAssignError(
"Incompatible types in TypeIs", children=[left_can_assign]
)
can_assign_maps.append(left_can_assign)
if not can_assign_maps:
return CanAssignError(f"{value} is not a TypeGuard")
return unify_bounds_maps(can_assign_maps)


@dataclass(frozen=True)
class TypeIsExtension(Extension):
Expand All @@ -1947,6 +1979,33 @@ def substitute_typevars(self, typevars: TypeVarMap) -> Extension:
def walk_values(self) -> Iterable[Value]:
yield from self.guarded_type.walk_values()

def can_assign(self, value: Value, ctx: CanAssignContext) -> CanAssign:
can_assign_maps = []
if isinstance(value, AnnotatedValue):
for ext in value.get_metadata_of_type(Extension):
if isinstance(ext, TypeGuardExtension):
return CanAssignError("TypeGuard is not compatible with TypeIs")
elif isinstance(ext, TypeIsExtension):
# TypeIs is invariant
left_can_assign = self.guarded_type.can_assign(
ext.guarded_type, ctx
)
if isinstance(left_can_assign, CanAssignError):
return CanAssignError(
"Incompatible types in TypeIs", children=[left_can_assign]
)
right_can_assign = ext.guarded_type.can_assign(
self.guarded_type, ctx
)
if isinstance(right_can_assign, CanAssignError):
return CanAssignError(
"Incompatible types in TypeIs", children=[right_can_assign]
)
can_assign_maps += [left_can_assign, right_can_assign]
if not can_assign_maps:
return CanAssignError(f"{value} is not a TypeIs")
return unify_bounds_maps(can_assign_maps)


@dataclass(frozen=True)
class HasAttrGuardExtension(Extension):
Expand Down Expand Up @@ -2120,8 +2179,8 @@ def can_assign(self, other: Value, ctx: CanAssignContext) -> CanAssign:
if isinstance(can_assign, CanAssignError):
return can_assign
bounds_maps = [can_assign]
for custom_check in self.get_metadata_of_type(CustomCheckExtension):
custom_can_assign = custom_check.custom_check.can_assign(other, ctx)
for ext in self.get_metadata_of_type(Extension):
custom_can_assign = ext.can_assign(other, ctx)
if isinstance(custom_can_assign, CanAssignError):
return custom_can_assign
bounds_maps.append(custom_can_assign)
Expand All @@ -2132,8 +2191,8 @@ def can_be_assigned(self, other: Value, ctx: CanAssignContext) -> CanAssign:
if isinstance(can_assign, CanAssignError):
return can_assign
bounds_maps = [can_assign]
for custom_check in self.get_metadata_of_type(CustomCheckExtension):
custom_can_assign = custom_check.custom_check.can_be_assigned(other, ctx)
for ext in self.get_metadata_of_type(Extension):
custom_can_assign = ext.can_be_assigned(other, ctx)
if isinstance(custom_can_assign, CanAssignError):
return custom_can_assign
bounds_maps.append(custom_can_assign)
Expand Down
Loading