Skip to content

Commit

Permalink
[PYTHON] Improve equal sugar (#564)
Browse files Browse the repository at this point in the history
* [PYTHON] Improve equal sugar

* fix comment
  • Loading branch information
tqchen authored Oct 17, 2017
1 parent 60510a4 commit 9a2f01a
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 13 deletions.
77 changes: 64 additions & 13 deletions python/tvm/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, register_node
from ._ffi.node import NodeBase, NodeGeneric, register_node
from . import make as _make
from . import _api_internal

Expand Down Expand Up @@ -89,10 +89,10 @@ def __le__(self, other):
return _make.LE(self, other)

def __eq__(self, other):
return self.equal(other)
return EqualOp(self, other)

def __ne__(self, other):
return _make.NE(self, other)
return NotEqualOp(self, other)

def __gt__(self, other):
return _make.GT(self, other)
Expand Down Expand Up @@ -138,12 +138,71 @@ def astype(self, dtype):
return _make.static_cast(dtype, self)


class EqualOp(NodeGeneric, ExprOp):
"""Deferred equal operator.
This is used to support sugar that a == b can either
mean NodeBase.same_as or NodeBase.equal.
Parameters
----------
a : Expr
Left operand.
b : Expr
Right operand.
"""
def __init__(self, a, b):
self.a = a
self.b = b

def __nonzero__(self):
return self.a.same_as(self.b)

def __bool__(self):
return self.__nonzero__()

def asnode(self):
"""Convert node."""
return _make.EQ(self.a, self.b)


class NotEqualOp(NodeGeneric, ExprOp):
"""Deferred NE operator.
This is used to support sugar that a != b can either
mean not NodeBase.same_as or make.NE.
Parameters
----------
a : Expr
Left operand.
b : Expr
Right operand.
"""
def __init__(self, a, b):
self.a = a
self.b = b

def __nonzero__(self):
return not self.a.same_as(self.b)

def __bool__(self):
return self.__nonzero__()

def asnode(self):
"""Convert node."""
return _make.NE(self.a, self.b)


class Expr(ExprOp, NodeBase):
"""Base class of all tvm Expressions"""
# In Python3, We have to explicity tell interpreter to retain __hash__ if we overide __eq__
# https://docs.python.org/3.1/reference/datamodel.html#object.__hash__
__hash__ = NodeBase.__hash__


class ConstExpr(Expr):
pass

Expand Down Expand Up @@ -215,19 +274,11 @@ class Max(BinaryOpExpr):

@register_node
class EQ(CmpExpr):
def __nonzero__(self):
return self.a.same_as(self.b)

def __bool__(self):
return self.__nonzero__()
pass

@register_node
class NE(CmpExpr):
def __nonzero__(self):
return not self.a.same_as(self.b)

def __bool__(self):
return self.__nonzero__()
pass

@register_node
class LT(CmpExpr):
Expand Down
2 changes: 2 additions & 0 deletions tests/python/unittest/test_ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_if():
A[0] = A[i] + 2

body = ib.get()
assert A == A
assert isinstance(body, tvm.stmt.For)
body = body.body
assert isinstance(body, tvm.stmt.IfThenElse)
Expand All @@ -42,6 +43,7 @@ def test_prefetch():
A = tvm.placeholder((10, 20), name="A")
ib = tvm.ir_builder.create()
n = tvm.var("n")

with ib.for_range(0, n, name="i") as i:
ib.emit(
tvm.make.Prefetch(
Expand Down

0 comments on commit 9a2f01a

Please sign in to comment.