Skip to content

Commit

Permalink
Implement user-defined type guard method
Browse files Browse the repository at this point in the history
This implements the minimal version of user-defined type guard method.
A user-defined type guard methos is a method defined by user that is
able to guarantee the type of the objects (receiver or arguments).

At present, Steep supports type guard methods provided by ruby-core (ex.
`#is_a?`, `#nil?`, and so on).  But we also have many kinds of
user-defined methods that are able to check the type of the objects.

Therefore user-defined type guard will help checking the type of these
applications by narrowing types.

This implementation uses an annotation to declare user-defined type
guard method.

```
class Example < Integer
  %a{guard:self is Integer}
  def integer?: () -> bool
end
```

For example, the above method `Example#integer?` is a user-defined
type guard method that narrows the Example object itself to an Integer
if the conditional branch passed.

```
example = Example.new
if example.integer?
  example  #=> Integer
end
```

In this PR, the predicate of type guards only supports "self is TYPE"
statement.  I have a plan to extend it:

* `%a{guard:self is arg}`
* `%a{guard:self is_a arg}`
* `%a{guard:self is TYPE_PARAM}`
* `%a{guard:arg is TYPE}`

Note: The compatibility of RBS syntax is the large reason of using
annotations.  I'm afraid that adding a new syntax to define it will
bring breaking change to the RBS, and difficult to use it on common
repository or generators (ex. gem_rbs_collection and rbs_rails).
  • Loading branch information
tk0miya committed Feb 23, 2025
1 parent 7a7fcc4 commit 6821365
Show file tree
Hide file tree
Showing 9 changed files with 325 additions and 8 deletions.
24 changes: 23 additions & 1 deletion lib/steep/ast/types/logic.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module Types
module Logic
class Base
extend SharedInstance

def subst(s)
self
end
Expand Down Expand Up @@ -53,6 +53,28 @@ class ArgEqualsReceiver < Base
class ArgIsAncestor < Base
end

class Guard < Base
PATTERN = /\Aguard:\s*(self)\s+(is)\s+(\S+)\s*\Z/

attr_reader :subject
attr_reader :operator
attr_reader :type

def initialize(subject:, operator:, type:)
@subject = subject
@operator = operator
@type = type
end

def ==(other)
super && subject == other.subject && operator == other.operator && type == other.type
end

def hash
self.class.hash ^ subject.hash ^ operator.hash ^ type.hash
end
end

class Env < Base
attr_reader :truthy, :falsy, :type

Expand Down
36 changes: 36 additions & 0 deletions lib/steep/interface/builder.rb
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ def singleton_shape(type_name)
overloads = method.defs.map do |type_def|
method_name = method_name_for(type_def, name)
method_type = factory.method_type(type_def.type)
method_type = replace_guard_method(definition, type_def, method_type)
method_type = replace_primitive_method(method_name, type_def, method_type)
method_type = replace_kernel_class(method_name, type_def, method_type) { AST::Builtin::Class.instance_type }
method_type = add_implicitly_returns_nil(type_def.annotations, method_type)
Expand Down Expand Up @@ -313,6 +314,7 @@ def object_shape(type_name)
overloads = method.defs.map do |type_def|
method_name = method_name_for(type_def, name)
method_type = factory.method_type(type_def.type)
method_type = replace_guard_method(definition, type_def, method_type)
method_type = replace_primitive_method(method_name, type_def, method_type)
if type_name.class?
method_type = replace_kernel_class(method_name, type_def, method_type) { AST::Types::Name::Singleton.new(name: type_name) }
Expand Down Expand Up @@ -725,6 +727,40 @@ def proc_shape(proc, proc_shape)
shape
end

def replace_guard_method(definition, method_def, method_type)
match = method_def.annotations.filter_map { AST::Types::Logic::Guard::PATTERN.match(_1.string) }.first
if match
subject = match[1] or raise
operator = match[2] or raise
type_name = match[3] or raise

type = RBS::Parser.parse_type(type_name)
raise "Unknown type: #{type_name}" unless type

context = context_from(definition.type_name)
type = type.map_type_name { factory.absolute_type_name(_1, context: context) or raise "Unknown type: #{_1}" }
guard = AST::Types::Logic::Guard.new(subject: subject, operator: operator, type: type)
definition.type_name
method_type.with(
type: method_type.type.with(return_type: guard)
)
else
method_type
end
rescue => exn
Steep.logger.error { exn.message }
method_type
end

def context_from(type_name)
if type_name.namespace == RBS::Namespace.root
[nil, type_name]
else
parent = context_from(type_name.namespace.to_type_name)
[parent, type_name]
end
end

def replace_primitive_method(method_name, method_def, method_type)
defined_in = method_def.defined_in
member = method_def.member
Expand Down
54 changes: 49 additions & 5 deletions lib/steep/type_inference/logic_type_interpreter.rb
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,42 @@ def evaluate_method_call(env:, type:, receiver:, arguments:)
truthy_result.update_type { FALSE }
]
end

when AST::Types::Logic::Guard
if receiver
receiver_type = factory.deep_expand_alias(typing.type_of(node: receiver)) || raise

# TODO: Expand the type params to the actual types
# TODO: Support argument types (ex. `self is arg1`)
# TODO: Support class' type param types (ex. `self is T`)
# TODO: Support method's type param types (ex. `self is T`)
# TODO: Support is_a operator
# TODO: Ensure the type exists

sub_type = factory.type(type.type)
if no_subtyping?(sub_type: sub_type, super_type: receiver_type)
typing.add_error(
Diagnostic::Ruby::UnexpectedError.new(node: receiver, error: Exception.new("#{receiver_type} is not a subtype of #{type.type}"))
)
return nil
end

truthy_type, falsy_type = type_guard_type_case_select(receiver_type, sub_type)
truthy_env, falsy_env = refine_node_type(
env: env,
node: receiver,
truthy_type: truthy_type || receiver_type,
falsy_type: falsy_type || UNTYPED
)

truthy_result = Result.new(type: TRUE, env: truthy_env, unreachable: false)
truthy_result.unreachable! unless truthy_type

falsy_result = Result.new(type: FALSE, env: falsy_env, unreachable: false)
falsy_result.unreachable! unless falsy_type

[truthy_result, falsy_result]
end
end
end

Expand Down Expand Up @@ -494,25 +530,33 @@ def literal_var_type_case_select(value_node, arg_type)
end
end

def type_case_select(type, klass)
truth_types, false_types = type_case_select0(type, klass)
def type_guard_type_case_select(type, guard_type)
truth_types, false_types = type_case_select0(type, guard_type)

[
truth_types.empty? ? nil : AST::Types::Union.build(types: truth_types),
false_types.empty? ? nil : AST::Types::Union.build(types: false_types)
]
end

def type_case_select0(type, klass)
def type_case_select(type, klass)
instance_type = factory.instance_type(klass)
truth_types, false_types = type_case_select0(type, instance_type)

[
truth_types.empty? ? nil : AST::Types::Union.build(types: truth_types),
false_types.empty? ? nil : AST::Types::Union.build(types: false_types)
]
end

def type_case_select0(type, instance_type)
case type
when AST::Types::Union
truthy_types = [] # :Array[AST::Types::t]
falsy_types = [] #: Array[AST::Types::t]

type.types.each do |ty|
truths, falses = type_case_select0(ty, klass)
truths, falses = type_case_select0(ty, instance_type)

if truths.empty?
falsy_types.push(ty)
Expand All @@ -529,7 +573,7 @@ def type_case_select0(type, klass)
if ty == type
[[type], [type]]
else
type_case_select0(ty, klass)
type_case_select0(ty, instance_type)
end

when AST::Types::Any, AST::Types::Top, AST::Types::Var
Expand Down
2 changes: 1 addition & 1 deletion sig/steep/ast/types.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ module Steep
| Intersection | Record | Tuple | Union
| Name::Alias | Name::Instance | Name::Interface | Name::Singleton
| Proc | Var
| Logic::Not | Logic::ReceiverIsNil | Logic::ReceiverIsNotNil | Logic::ReceiverIsArg | Logic::ArgIsReceiver | Logic::ArgEqualsReceiver | Logic::ArgIsAncestor | Logic::Env
| Logic::Not | Logic::ReceiverIsNil | Logic::ReceiverIsNotNil | Logic::ReceiverIsArg | Logic::ArgIsReceiver | Logic::ArgEqualsReceiver | Logic::ArgIsAncestor | Logic::Guard | Logic::Env

# Variables and special types that is subject for substitution
#
Expand Down
13 changes: 13 additions & 0 deletions sig/steep/ast/types/logic.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,19 @@ module Steep
class ArgIsAncestor < Base
end

# A type for type guard.
class Guard < Base
PATTERN: Regexp

attr_reader subject: String
attr_reader operator: String
attr_reader type: RBS::Types::t

def self.new: (subject: String, operator: String, type: RBS::Types::t) -> Guard

def initialize: (subject: String, operator: String, type: RBS::Types::t) -> void
end

# A type with truthy/falsy type environment.
class Env < Base
attr_reader truthy: TypeInference::TypeEnv
Expand Down
4 changes: 4 additions & 0 deletions sig/steep/interface/builder.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ module Steep

def method_name_for: (RBS::Definition::Method::TypeDef, Symbol name) -> method_name

def replace_guard_method: (RBS::Definition, RBS::Definition::Method::TypeDef, MethodType) -> MethodType

def context_from: (RBS::TypeName) -> RBS::Resolver::context

def replace_primitive_method: (method_name, RBS::Definition::Method::TypeDef, MethodType) -> MethodType

def replace_kernel_class: (method_name, RBS::Definition::Method::TypeDef, MethodType) { () -> AST::Types::t } -> MethodType
Expand Down
4 changes: 3 additions & 1 deletion sig/steep/type_inference/logic_type_interpreter.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,11 @@ module Steep
#
def literal_var_type_case_select: (Parser::AST::Node value_node, AST::Types::t arg_type) -> [Array[AST::Types::t], Array[AST::Types::t]]?

def type_guard_type_case_select: (AST::Types::t `type`, AST::Types::t guard_type) -> [AST::Types::t?, AST::Types::t?]

def type_case_select: (AST::Types::t `type`, RBS::TypeName klass) -> [AST::Types::t?, AST::Types::t?]

def type_case_select0: (AST::Types::t `type`, RBS::TypeName klass) -> [Array[AST::Types::t], Array[AST::Types::t]]
def type_case_select0: (AST::Types::t `type`, AST::Types::t instance_type) -> [Array[AST::Types::t], Array[AST::Types::t]]

def try_convert: (AST::Types::t, Symbol) -> AST::Types::t?

Expand Down
8 changes: 8 additions & 0 deletions sig/test/type_check_test.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ class TypeCheckTest < Minitest::Test

def test_type_narrowing__local_variable_safe_navigation_operator: () -> untyped

def test_type_guard__self_is_TYPE: () -> untyped

def test_type_guard__self_is_TYPE_no_subtyping: () -> untyped

def test_type_guard__self_is_TYPE_unknown: () -> untyped

def test_type_guard__self_is_TYPE_singleton: () -> untyped

def test_argument_error__unexpected_unexpected_positional_argument: () -> untyped

def test_type_assertion__type_error: () -> untyped
Expand Down
Loading

0 comments on commit 6821365

Please sign in to comment.