
[Torch] Add copy_ operator #7513

Closed · wants to merge 1 commit

Conversation

apivovarov
Contributor

Add copy_ operator to PyTorch frontend.

Operator description:
https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_

Related discussions:
https://discuss.tvm.apache.org/t/copy-opertor-in-relay/9212

@masahi
Member

masahi commented Feb 24, 2021

This has come up many times; we probably do not want to support this due to the lack of in-place op support in Relay. See #7231

@apivovarov
Contributor Author

apivovarov commented Feb 24, 2021

@masahi I tried to add the following to def _run_jit_passes(graph):

torch._C._jit_pass_onnx_prepare_inplace_ops_for_onnx(graph)
torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)

Yes, it replaces copy_ with expand_as + index_put. But the index_put operator gets an empty indices array in that case.

Test Module:

import torch
import numpy as np

class MyCopy(torch.nn.Module):
    def __init__(self, shape):
        super(MyCopy, self).__init__()
        self.shape = shape
        
    def forward(self, values):
        A = torch.zeros(self.shape)
        B = A.copy_(values)
        return B


MP = MyCopy((2,4))
a = torch.tensor([0, 1, 2, 6])
MP(a)

traced_MP = torch.jit.trace(MP, (a))
traced_MP.graph
graph(%self : __torch__.MyCopy,
      %values : Long(4, strides=[1], requires_grad=0, device=cpu)):
  %4 : int = prim::Constant[value=2]() # <stdin>:7:0
  %5 : int = prim::Constant[value=4]() # <stdin>:7:0
  %6 : int[] = prim::ListConstruct(%4, %5)
  %7 : int = prim::Constant[value=6]() # <stdin>:7:0
  %8 : None = prim::Constant()
  %9 : Device = prim::Constant[value="cpu"]() # <stdin>:7:0
  %10 : bool = prim::Constant[value=0]() # <stdin>:7:0
  %A : Float(2, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::zeros(%6, %7, %8, %9, %10) # <stdin>:7:0
  %12 : bool = prim::Constant[value=0]()
  %13 : Float(2, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::copy_(%A, %values, %12) # <stdin>:8:0
  return (%13)

After the JIT passes:

graph = traced_MP.graph.copy()
torch._C._jit_pass_onnx_function_substitution(graph)
torch._C._jit_pass_onnx_prepare_inplace_ops_for_onnx(graph)
torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
graph
graph(%self : __torch__.MyCopy,
      %values.1 : Long(4, strides=[1], requires_grad=0, device=cpu)):
  %2 : int = prim::Constant[value=2]() # <stdin>:7:0
  %3 : int = prim::Constant[value=4]() # <stdin>:7:0
  %4 : int[] = prim::ListConstruct(%2, %3)
  %5 : int = prim::Constant[value=6]() # <stdin>:7:0
  %6 : None = prim::Constant()
  %7 : Device = prim::Constant[value="cpu"]() # <stdin>:7:0
  %8 : bool = prim::Constant[value=0]() # <stdin>:7:0
  %A : Float(2, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::zeros(%4, %5, %6, %7, %8) # <stdin>:7:0
  %10 : bool = prim::Constant[value=0]()
  %values : Long(4, strides=[1], requires_grad=0, device=cpu) = aten::expand_as(%values.1, %A) # <stdin>:8:0
  %15 : Tensor?[] = prim::ListConstruct()
  %16 : Float(2, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::index_put(%A, %15, %values, %10)
  return (%16)

As you can see, %15 is an empty list.

As a result we get a TVM error because the index_put indices list is empty:

import tvm
from tvm import relay
ctx = tvm.cpu(0)
target = 'llvm'

shape_list = [("input0", [4,]),]
mod, params = relay.frontend.from_pytorch(traced_MP, shape_list)

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/pivovaa/workspace/tvm/python/tvm/relay/frontend/pytorch.py", line 3186, in from_pytorch
    ret = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)[0]
  File "/Users/pivovaa/workspace/tvm/python/tvm/relay/frontend/pytorch.py", line 2605, in convert_operators
    relay_out = relay_op(
  File "/Users/pivovaa/workspace/tvm/python/tvm/relay/frontend/pytorch.py", line 2059, in index_put
    index_tensor = _op.stack(indices, axis=0)
  File "/Users/pivovaa/workspace/tvm/python/tvm/relay/op/tensor.py", line 1124, in stack
    raise ValueError("relay.stack requires data to be non-empty.")
ValueError: relay.stack requires data to be non-empty.
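
For reference, a minimal sketch (not TVM's actual index_put converter) of how the empty-indices case could be special-cased: with no indices, index_put overwrites the whole destination, so the result is just the already-expanded source broadcast to the destination's shape. The helper name, arguments, and dtype handling below are assumptions for illustration only.

from tvm import relay

def convert_index_put_like(dst, indices, src, dst_dtype="float32"):
    # Hypothetical converter helper, sketched under the assumption that
    # `dst`, `src` are Relay expressions and `indices` is a Python list.
    if len(indices) == 0:
        # No indices means the whole destination is overwritten, so no
        # scatter is needed: broadcast the source to dst's shape and cast.
        return relay.cast(relay.broadcast_to_like(src, dst), dst_dtype)
    # Non-empty indices would need a real scatter-based lowering (omitted).
    raise NotImplementedError("scatter-based path not sketched here")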

@apivovarov
Contributor Author

@masahi This PR uses relay.stack to implement the copy_ operator. Why is it a bad idea to do it this way? I added a couple of unit tests; they pass on CPU and GPU.

@masahi
Member

masahi commented Feb 25, 2021

The point is, no matter how we try to support copy_, there seems to be an edge-case PyTorch usage that prevents a correct translation to Relay. See the discussion and examples in the PR linked above.

Rather than adding copy_ support that sometimes works and sometimes doesn't, the current consensus seems to be that we shouldn't support this op. The best way forward is to remove that in-place assignment idiom from your models. Don't say you cannot change the model because of customer issues; that's not our problem.

@apivovarov
Contributor Author

I think we already have 16 in-place operators ending with _ in the PyTorch frontend, e.g. add_:

a = torch.tensor([1, 2, 3, 4])
b = torch.tensor([1, 1, 2, 2])
a.add_(b)
a
tensor([2, 3, 5, 6])

Is copy_ different from them?

@masahi
Member

masahi commented Feb 25, 2021

Yeah, this is a tricky issue. Some in-place ops are benign (like relu_), and preventing all of them will likely lead to too many otherwise OK models being rejected. I don't have a good solution for distinguishing which in-place ops are safe and which are not. This is an ongoing problem.

cc @t-vi our favorite topic again
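
To illustrate the distinction, here is a small eager-mode sketch (an editorial example, not from the PR): a benign in-place op whose effect is fully visible through its return value, versus a copy_ into a view, where the observable effect is on a tensor that never appears as the op's output.

import torch

# Benign: every later use can read the returned value y, so an out-of-place
# translation that rewires uses of y is faithful.
x = torch.randn(4)
y = x.relu_()

# Problematic: copy_ writes through a view, so the visible effect is on
# `base`, which is not the output of copy_. A pure dataflow translation that
# only tracks copy_'s return value silently drops this write.
base = torch.zeros(2, 4)
view = base[0]
view.copy_(torch.ones(4))
print(base)  # row 0 is now all ones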

@masahi
Member

masahi commented Feb 25, 2021

@apivovarov If you really need to support this conversion and you are sure your conversion is correct, you can use a custom convert map. See

def test_custom_conversion_map():
It basically lets you register a custom conversion from user code. This is probably the best solution (or workaround) for now.
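
For completeness, a minimal sketch of that workaround using the custom_convert_map argument of relay.frontend.from_pytorch. The converter body below (broadcasting the source to the destination's shape) is only an assumption about how copy_ might be approximated out-of-place for the simple module shown earlier; it is not this PR's implementation.

from tvm import relay

def convert_copy_(inputs, input_types):
    # inputs[0] is the destination, inputs[1] the source (assumption based
    # on aten::copy_(self, src, non_blocking)).
    dst, src = inputs[0], inputs[1]
    # Out-of-place approximation: take the source's values with the
    # destination's shape. Dtype handling is omitted for brevity.
    return relay.broadcast_to_like(src, dst)

custom_map = {"aten::copy_": convert_copy_}

# traced_MP and shape_list as defined earlier in this thread.
mod, params = relay.frontend.from_pytorch(
    traced_MP, shape_list, custom_convert_map=custom_map
)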

@apivovarov
Contributor Author

What if we add a log warning: the in-place copy_ operator was replaced with an out-of-place operator; the compiled model might not work as expected; it is strongly recommended to remove the copy_ operator from the model graph.

This default copy_ implementation would allow users to compile the model and see whether it works as expected. If not, they really do need to remove the copy_ op from the graph.
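
A minimal sketch of such a warning, assuming the Python standard logging module (the frontend may use its own logger):

import logging

logging.warning(
    "aten::copy_ was converted to an out-of-place equivalent; the compiled "
    "model may not match PyTorch exactly. It is strongly recommended to "
    "remove copy_ from the model graph."
)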

@masahi
Member

masahi commented Feb 25, 2021

I generally don't like a half-baked solution, but if it allows more models to run on TVM, that doesn't sound too bad. If people agree with this, I think we can go with it: @t-vi @siju-samuel @kevinthesun @codeislife99 @alexwong

@apivovarov If you want to pursue this path, please revisit the index_put -> scatter_nd solution. This is what the PyTorch ONNX converter does to convert copy_. Your solution probably doesn't work for dynamic-shape inputs, but copying dynamic tensors should be conceptually no harder than the static case.
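
To make the suggestion concrete, here is a small eager-mode sketch (assuming static shapes, and not TVM's actual lowering) showing that an in-place copy into a slice can be expressed as an out-of-place index_put over explicit destination indices, which is roughly the shape of the index_put -> scatter_nd path:

import torch

# Emulate dst[:, 0:1].copy_(src) without mutating dst.
dst = torch.zeros(3, 4)
src = torch.tensor([[1.0], [2.0], [3.0]])

rows = torch.arange(3).reshape(3, 1)        # destination row indices
cols = torch.zeros(3, 1, dtype=torch.long)  # destination column indices
out = dst.index_put((rows, cols), src)      # out-of-place scatter

# `out` equals dst with its first column replaced by src; dst is unchanged.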

@apivovarov
Contributor Author

apivovarov commented Feb 26, 2021

I prepared a test model which uses both the destination tensor and the copy_ output tensor after the copy_ operator.
The PyTorch and TVM outputs are the same.
What other test models can we try?

import torch
import tvm
from tvm import relay
import numpy as np
        
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
    def forward(self, values):
        scale = values.shape[0]
        A = torch.zeros(values.shape)
        B = torch.stack([A] * scale)
        V1 = B + 1
        C = B.copy_(values)
        V2 = B + 2
        V3 = C + 3
        D = V1 + V2 + V3
        return D
        

net = Net()
a = torch.tensor([0, 1, 2, 6])
net(a)

traced_net = torch.jit.trace(net, (a))

ctx = tvm.cpu(0)
target = 'llvm'

shape_list = [("input0", [4,]),]
mod, params = relay.frontend.from_pytorch(traced_net, shape_list)

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)

func = mod["main"]
intrp = relay.create_executor("graph", ctx=ctx, target=target)
ff = intrp.evaluate(func)
ff([0, 1, 2, 6])

<tvm.nd.NDArray shape=(4, 4), cpu(0)>
array([[ 6.,  8., 10., 18.],
       [ 6.,  8., 10., 18.],
       [ 6.,  8., 10., 18.],
       [ 6.,  8., 10., 18.]], dtype=float32)

Relay graph:

print(func)
fn (%input0: Tensor[(4), int64]) {
  %0 = full(0, shape=[4], dtype="float32");
  %1 = (%0, %0, %0, %0);
  %2 = stack(%1);
  %3 = add(%2, 1f);
  %4 = cast(%input0, dtype="float32");
  %5 = (%4, %4, %4, %4);
  %6 = stack(%5);
  %7 = add(%6, 2f);
  %8 = add(%3, %7);
  %9 = add(%6, 3f);
  add(%8, %9)
}

Torch graph:

print(traced_net.graph)
graph(%self : __torch__.Net,
      %values : Long(4, strides=[1], requires_grad=0, device=cpu)):
  %7 : int = prim::Constant[value=0]() # <stdin>:6:0
  %8 : int = aten::size(%values, %7) # <stdin>:6:0
  %9 : Long(device=cpu) = prim::NumToTensor(%8)
  %10 : int = aten::Int(%9)
  %11 : int[] = prim::ListConstruct(%10)
  %12 : int = prim::Constant[value=6]() # <stdin>:6:0
  %13 : None = prim::Constant()
  %14 : Device = prim::Constant[value="cpu"]() # <stdin>:6:0
  %15 : bool = prim::Constant[value=0]() # <stdin>:6:0
  %A : Float(4, strides=[1], requires_grad=0, device=cpu) = aten::zeros(%11, %12, %13, %14, %15) # <stdin>:6:0
  %17 : Tensor[] = prim::ListConstruct(%A, %A, %A, %A)
  %18 : int = prim::Constant[value=0]() # <stdin>:7:0
  %B.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::stack(%17, %18) # <stdin>:7:0
  %20 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() # <stdin>:8:0
  %21 : int = prim::Constant[value=1]() # <stdin>:8:0
  %V1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%B.1, %20, %21) # <stdin>:8:0
  %23 : bool = prim::Constant[value=0]()
  %B : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::copy_(%B.1, %values, %23) # <stdin>:9:0
  %25 : Long(requires_grad=0, device=cpu) = prim::Constant[value={2}]() # <stdin>:10:0
  %26 : int = prim::Constant[value=1]() # <stdin>:10:0
  %V2 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%B, %25, %26) # <stdin>:10:0
  %28 : Long(requires_grad=0, device=cpu) = prim::Constant[value={3}]() # <stdin>:11:0
  %29 : int = prim::Constant[value=1]() # <stdin>:11:0
  %V3 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%B, %28, %29) # <stdin>:11:0
  %31 : int = prim::Constant[value=1]() # <stdin>:12:0
  %32 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%V1, %V2, %31) # <stdin>:12:0
  %33 : int = prim::Constant[value=1]() # <stdin>:12:0
  %34 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%32, %V3, %33) # <stdin>:12:0
  return (%34)

@masahi
Member

masahi commented Feb 26, 2021

People never use copy_ directly. It would be better to make the test more realistic. Also see my point on dynamic input support.

@apivovarov
Contributor Author

apivovarov commented Feb 26, 2021

What would be an example of a more realistic Net with copy_?
https://github.com/facebookresearch/detectron2/search?q=copy_%28

@masahi
Member

masahi commented Feb 26, 2021

@apivovarov
Contributor Author

What Net could represent that case? Can you make a test Net to demonstrate it?

@codeislife99
Contributor

Here is another example:

x = torch.tensor([1, 2])
a = torch.tensor([2, 3])
x[0] = a[0]

@codeislife99
Contributor

We should include such test cases in this PR.

@codeislife99
Contributor

Can you also address the test case in #7231? #7231 (comment)

@apivovarov
Contributor Author

Net:

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
    def forward(self, a,b):
        a[0] = b[0]
        c = a + b
        return c

Torch graph:

graph(%self : __torch__.Net,
      %a : Long(4, strides=[1], requires_grad=0, device=cpu),
      %b : Long(4, strides=[1], requires_grad=0, device=cpu)):
  %5 : int = prim::Constant[value=0]() # <stdin>:5:0
  %6 : int = prim::Constant[value=0]() # <stdin>:5:0
  %7 : Long(requires_grad=0, device=cpu) = aten::select(%b, %5, %6) # <stdin>:5:0
  %8 : int = prim::Constant[value=0]() # <stdin>:5:0
  %9 : int = prim::Constant[value=0]() # <stdin>:5:0
  %10 : Long(requires_grad=0, device=cpu) = aten::select(%a, %8, %9) # <stdin>:5:0
  %11 : bool = prim::Constant[value=0]()
  %12 : Long(requires_grad=0, device=cpu) = aten::copy_(%10, %7, %11) # <stdin>:5:0
  %13 : int = prim::Constant[value=1]() # <stdin>:6:0
  %14 : Long(4, strides=[1], requires_grad=0, device=cpu) = aten::add(%a, %b, %13) # <stdin>:6:0
  return (%14)

The Relay graph has only the add operator:

fn (%input0: Tensor[(4), int64], %input1: Tensor[(4), int64]) {
  add(%input0, %input1)
}

a[0] = b[0] was not included in the Relay graph.
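
For comparison, a sketch of an out-of-place rewrite of the same Net that traces to pure ops and therefore converts correctly (an illustration of the "remove the in-place idiom" advice, not part of this PR):

import torch

class NetNoInplace(torch.nn.Module):
    def forward(self, a, b):
        # Same values as `a` after `a[0] = b[0]`, built without mutation.
        a2 = torch.cat([b[0:1], a[1:]])
        return a2 + b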

@apivovarov
Contributor Author

apivovarov commented Feb 27, 2021

It is probably better to rewrite the following 4 lines in the model itself:
Detectron2 - box2box_transform.apply_deltas:111

The copy_ operator ends up on a dead-end branch of the graph. As we saw before, TVM will remove such branches.

  %pred_boxes.1 : Float(182400, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::zeros_like(%deltas.2, %987, %988, %989, %990, %1001), scope: __module.torch_model.proposal_generator # /Users/pivovaa/workspace/detectron2/detectron2/modeling/box_regression.py:110:0
...
  %1405 : Float(182400, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::slice(%pred_boxes.1, %988, %988, %979, %992), scope: __module.torch_model.proposal_generator # /Users/pivovaa/workspace/detectron2/detectron2/modeling/box_regression.py:111:0
  %1406 : Float(182400, 1, strides=[4, 4], requires_grad=0, device=cpu) = aten::slice(%1405, %992, %988, %979, %986), scope: __module.torch_model.proposal_generator # /Users/pivovaa/workspace/detectron2/detectron2/modeling/box_regression.py:111:0
  %1407 : Float(182400, 1, strides=[4, 4], requires_grad=0, device=cpu) = aten::copy_(%1406, %1404, %990), scope: __module.torch_model.proposal_generator # /Users/pivovaa/workspace/detectron2/detectron2/modeling/box_regression.py:111:0
...
  %1423 : int[] = prim::ListConstruct(%1339, %991, %1340), scope: __module.torch_model.proposal_generator
  %proposals_i.1 : Float(1, 182400, 4, strides=[729600, 4, 1], requires_grad=0, device=cpu) = aten::view(%pred_boxes.1, %1423), scope: __module.torch_model.proposal_generator # /Users/pivovaa/workspace/detectron2/detectron2/modeling/proposal_generator/rpn.py:491:0

rpn.py:491
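
Illustration only: one way to build a tensor like pred_boxes without writing into slices of a pre-allocated zeros tensor (the pattern that traces to aten::copy_). The names x1, y1, x2, y2 are hypothetical stand-ins for the per-coordinate predictions; this is not the actual Detectron2 patch.

import torch

def assemble_boxes(x1, y1, x2, y2):
    # Each input has shape (N, k); interleave them along the last axis so
    # the result has shape (N, 4 * k), matching pred_boxes' layout.
    return torch.stack([x1, y1, x2, y2], dim=2).flatten(start_dim=1)

x1 = torch.randn(5, 1)
y1 = torch.randn(5, 1)
x2 = torch.randn(5, 1)
y2 = torch.randn(5, 1)
pred_boxes = assemble_boxes(x1, y1, x2, y2)  # shape (5, 4), no copy_ needed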

@apivovarov apivovarov closed this Feb 27, 2021
@apivovarov
Contributor Author

The copy_ operator was removed from the Detectron2 model:
facebookresearch/detectron2#2685
