Commit: lint
altanh committed Oct 27, 2020
1 parent f76e2fa commit f8d2248
Showing 1 changed file with 30 additions and 11 deletions.
41 changes: 30 additions & 11 deletions python/tvm/relay/op/_tensor_grad.py
@@ -43,7 +43,7 @@
     equal,
     shape_of,
     log,
-    concatenate
+    concatenate,
 )
 from .transform import (
     broadcast_to_like,
@@ -62,7 +62,7 @@
     split,
     squeeze,
     strided_set,
-    arange
+    arange,
 )


@@ -677,20 +677,24 @@ def cross_entropy_with_logits_grad(orig, grad):

@register_gradient("take")
def take_grad(orig, grad):
"""
Returns the gradient of take.
"""

def make_scalar_tensor(v):
if isinstance(v, int):
v = const(v, dtype='int32')
v = const(v, dtype="int32")
return reshape(v, (1,))

# TODO(@altanh): we currently assume indices are in range
data, indices = orig.args
axis, mode = orig.attrs.axis, orig.attrs.mode
axis = orig.attrs.axis
zero, one = map(make_scalar_tensor, [0, 1])
data_grad = zeros_like(data)
try:
data_shape = data.checked_type.concrete_shape
except TypeError:
raise OpError('currently take_grad only supports data with concrete shape')
except TypeError as ty_err:
raise OpError("currently take_grad only supports data with concrete shape") from ty_err
if axis is None:
axis = 0
data_grad = reshape(data_grad, (-1,))
@@ -710,7 +714,7 @@ def make_scalar_tensor(v):
     elif len(indices.checked_type.shape) == 1:
         num_indices = take(shape_of(indices), zero, axis=0)
     else:
-        raise OpError('take_grad only supports scalar or 1D indices')
+        raise OpError("take_grad only supports scalar or 1D indices")
 
     def loop_cond(data_grad, i):
         return squeeze(less(i, num_indices))
@@ -731,8 +735,8 @@ def loop_body(data_grad, i):
         return (next_data_grad, i + one)
 
     loop_vars = [
-        Var('data_grad', type_annotation=TensorType(data_shape, data.checked_type.dtype)),
-        Var('i', type_annotation=TensorType((1,), 'int32')),
+        Var("data_grad", type_annotation=TensorType(data_shape, data.checked_type.dtype)),
+        Var("i", type_annotation=TensorType((1,), "int32")),
     ]
 
     loop = while_loop(loop_cond, loop_vars, loop_body)
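
The while_loop above scatters each slice of the incoming gradient back into a zeros tensor shaped like data, one index at a time. A minimal NumPy sketch of the same computation, assuming in-range 1D indices (take_grad_np and its arguments are illustrative names, not code from this commit):

import numpy as np

def take_grad_np(data, indices, grad, axis=0):
    # Gradient of y = np.take(data, indices, axis=axis): scatter each
    # slice of grad back to the position it was taken from.
    data_grad = np.zeros_like(data)
    for i, idx in enumerate(indices):
        dst = [slice(None)] * data.ndim
        dst[axis] = idx
        data_grad[tuple(dst)] += np.take(grad, i, axis=axis)  # += accumulates repeated indices
    return data_grad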
@@ -747,11 +751,17 @@ def loop_body(data_grad, i):

@register_gradient("contrib_reverse_reshape")
def reverse_reshape_grad(orig, grad):
"""
Returns the gradient of reverse_reshape (same as reshape).
"""
return [reshape_like(grad, orig.args[0])]


@register_gradient("stack")
def stack_grad(orig, grad):
"""
Returns grad split across stacked inputs.
"""
stack_axis = int(orig.attrs.axis)
sections = len(orig.args[0].checked_type.fields)
splits = split(grad, sections, stack_axis)
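
Since stack joins its inputs along a new axis, its gradient just splits grad back along that axis, one piece per stacked input; each piece still carries the unit axis, which a squeeze removes. A rough NumPy analogue (stack_grad_np is an illustrative name):

import numpy as np

def stack_grad_np(grad, num_inputs, axis=0):
    # Gradient of y = np.stack(xs, axis=axis): split grad along the
    # stacked axis and drop that axis, one gradient per input.
    pieces = np.split(grad, num_inputs, axis=axis)
    return [np.squeeze(p, axis=axis) for p in pieces]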
@@ -761,13 +771,19 @@ def stack_grad(orig, grad):

@register_gradient("squeeze")
def squeeze_grad(orig, grad):
"""
Returns grad expanded to input size.
"""
# this should work, can't use expand_dims since we lose
# squeeze information when axis=None
return [reshape_like(grad, orig.args[0])]


@register_gradient("expand_dims")
def expand_dims_grad(orig, grad):
"""
Returns grad squeezed on expanded dims.
"""
axis = int(orig.attrs.axis)
for _ in range(orig.attrs.num_newaxis):
grad = squeeze(grad, axis=[axis])
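
The comment in squeeze_grad is the crux: with axis=None, squeeze drops every unit axis, and the gradient alone no longer records where they were, so the backward pass rebuilds the shape with reshape_like instead of expand_dims. A small NumPy illustration of that information loss (variable names are illustrative):

import numpy as np

x = np.zeros((1, 3, 1))
y = np.squeeze(x)                  # shape (3,): both unit axes are gone
grad_y = np.ones_like(y)           # shape (3,): no trace of where they were

# expand_dims would need the dropped axis positions, which grad_y lacks;
# reshaping to the input's shape recovers them, like reshape_like above.
grad_x = grad_y.reshape(x.shape)   # back to (1, 3, 1)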
@@ -776,12 +792,15 @@ def expand_dims_grad(orig, grad):

@register_gradient("arange")
def arange_grad(orig, grad):
"""
Returns the gradient of arange.
"""
start, stop, step = orig.args
length = take(shape_of(orig), const(0, dtype='int32'), axis=0)
length = take(shape_of(orig), const(0, dtype="int32"), axis=0)

grad_start = cast_like(_sum(grad), start)
grad_stop = zeros_like(stop)
grad_step = cast_like(arange(length, dtype='int32'), grad) * grad
grad_step = cast_like(arange(length, dtype="int32"), grad) * grad
grad_step = cast_like(_sum(grad_step), step)

return [grad_start, grad_stop, grad_step]
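
The math behind arange_grad: for n output elements, out[i] = start + i * step, so each element has derivative 1 with respect to start and i with respect to step, while stop only determines n and gets a zero gradient. A NumPy check of the same reductions (values are illustrative):

import numpy as np

start, stop, step = 1.0, 6.0, 2.0
out = np.arange(start, stop, step)              # [1., 3., 5.]
grad = np.ones_like(out)                        # upstream gradient

grad_start = grad.sum()                         # d out[i] / d start == 1
grad_step = (grad * np.arange(len(out))).sum()  # d out[i] / d step == i
grad_stop = 0.0                                 # stop only sets the length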
