
[PyTorch] [Frontend] Add support for 'aten::new_zeros' & 'aten::copy_' #9375

Closed

Conversation

@hgt312 (Contributor) commented Oct 26, 2021

  • add support for aten::new_zeros & aten::copy_
  • view of a constant scalar

This PR enables conversion of the 'bart_base' model from transformers.
@comaniac
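
For reference, a minimal illustration of what the two ops do in plain PyTorch (an illustrative snippet, not part of this PR):

```python
import torch

x = torch.arange(6.0).reshape(2, 3)

# aten::new_zeros: a zeros tensor with the same dtype/device as x
z = x.new_zeros(x.shape)

# aten::copy_: copy x's values into z in place
z.copy_(x)
```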

@comaniac (Contributor) left a comment


LGTM. cc @masahi

@masahi (Member) commented Oct 26, 2021

Please add tests.

Also, we need to discuss aten::copy_. There have been many requests to support this op and attempts to support it.
#6472
#6049
#7231
#7513

aten::copy_ is used purely for in-place mutation, which we don't support at all. The naive conversion might work for some cases, but since we cannot support 100% of use cases correctly, we've so far decided not to support it. Have you verified that the output from bart-base is correct after converting to TVM?

@comaniac (Contributor)

Thanks for pointing this out. @hgt312 should have verified that the output of BART-base is correct with this change. Meanwhile, as you mentioned, this is not 100% semantically equivalent; maybe we could add a warning telling users to check correctness if their PyTorch models include copy_?

Also cc @yzhliu

@masahi (Member) commented Oct 27, 2021

Yes we should definitely add a warning at least.

A safer alternative, which I'm more comfortable with, is to make use of the custom convert map (custom_convert_map: a dictionary of str to Relay op). You can add a converter for copy_ and pass it to from_pytorch. If that's acceptable for your use cases, I think this is the most reasonable compromise.
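
A hedged sketch of that workaround (the naive converter below simply returns the source operand, which ignores in-place semantics and is only correct in limited cases; model and input names are placeholders):

```python
import torch
from tvm import relay

# Naive converter: treat `dst.copy_(src)` as just `src`.
# NOT correct in general -- it ignores the in-place mutation of `dst`.
def copy_converter(inputs, input_types):
    # aten::copy_(self, src, non_blocking): return the source operand
    return inputs[1]

scripted = torch.jit.trace(model, example_input)  # placeholder model/input
mod, params = relay.frontend.from_pytorch(
    scripted,
    [("input0", tuple(example_input.shape))],
    custom_convert_map={"aten::copy_": copy_converter},
)
```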

@comaniac (Contributor)

Good suggestion. We should definitely do so for our use cases in particular, although I guess the same issue will keep coming up whenever someone tries to work on BART lol

@masahi (Member) commented Oct 27, 2021

Yeah, users need to acknowledge that TVM cannot represent all PT models. It is better to reject such models early than to return corrupted models.

Since this request comes up so often, how about we do the following:

  • Add a new option like allow_inplace_copy, which is False by default. Add a converter for copy_ only if this option is True.
  • If allow_inplace_copy is False but we hit a copy_ op, emit a helpful error message saying "You can try the allow_inplace_copy option, but conversion is not guaranteed to be correct" (see the sketch after this list).
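
A rough sketch of how that gating might look in the frontend (allow_inplace_copy is the name proposed here, not an existing from_pytorch parameter):

```python
def get_convert_map(allow_inplace_copy=False):
    convert_map = {
        # ... existing converters ...
    }
    if allow_inplace_copy:
        # Expose the unsafe converter only on explicit opt-in.
        convert_map["aten::copy_"] = lambda inputs, input_types: inputs[1]
    return convert_map

def report_missing_op(op_name, allow_inplace_copy=False):
    if op_name == "aten::copy_" and not allow_inplace_copy:
        raise NotImplementedError(
            "aten::copy_ is not supported. You can try the allow_inplace_copy "
            "option, but conversion is not guaranteed to be correct."
        )
    raise NotImplementedError("{} is not supported.".format(op_name))
```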

@hgt312 (Contributor, Author) commented Oct 27, 2021

@comaniac @masahi I found that the output will not be correct for patterns like a[...] = b, as in the previous issues.

In BART, the copy_ comes from the following function; the function as a whole is not in-place.

def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids
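
For comparison, the same logic can be written without any in-place ops (an illustrative rewrite, not code from BART or this PR):

```python
import torch

def shift_tokens_right_functional(input_ids, pad_token_id, decoder_start_token_id):
    # Prepend the decoder start token and drop the last column,
    # instead of writing into a pre-allocated buffer.
    start = torch.full_like(input_ids[:, :1], decoder_start_token_id)
    shifted = torch.cat([start, input_ids[:, :-1]], dim=1)
    # Replace possible -100 label values with pad_token_id.
    return torch.where(shifted == -100, torch.full_like(shifted, pad_token_id), shifted)
```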

Also, I found that after pytorch/pytorch#52063 (torch version >= 1.9), we can use torch._C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, None) to remove all the aten::copy_ ops; the corresponding part then looks like:

  %69 : Tensor = onnx::Placeholder[name="index_put_"](%62) # <ipython-input-1-662caefe3c7e>:8:0
    block0():
      %70 : Float(3, 3, strides=[3, 1], requires_grad=0, device=cpu) = aten::slice(%shifted_input_ids, %42, %43, %44, %45) # <ipython-input-1-662caefe3c7e>:8:0
      %71 : Float(3, strides=[3], requires_grad=0, device=cpu) = aten::select(%70, %47, %48) # <ipython-input-1-662caefe3c7e>:8:0
      %72 : Float(3, strides=[3], requires_grad=0, device=cpu) = aten::index_put_(%71, %66, %67, %57) # <ipython-input-1-662caefe3c7e>:8:0
      -> (%72)

and the subgraph can be converted to ONNX's index_put.
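
A minimal sketch of invoking that pass on the scripted function above (torch._C._jit_pass_onnx_remove_inplace_ops_for_onnx is an internal API, assumed here to behave as described for torch >= 1.9):

```python
import torch

scripted = torch.jit.script(shift_tokens_right)  # the BART helper above
graph = scripted.graph
torch._C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, None)
print(graph)  # in-place ops now appear as onnx::Placeholder blocks
```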

Maybe the torch->onnx path will work for these models?

@hgt312 (Contributor, Author) commented Oct 27, 2021

Closed now.

@hgt312 hgt312 closed this Oct 27, 2021
@masahi (Member) commented Oct 27, 2021

I've tried the index_put approach before; it was available before PT 1.9. See the discussion in #7231 (comment). My impression back then was that ONNX's approach also doesn't cover 100% of cases. Our PT frontend already supports converting index_put, so you can try it for your model.
