Skip to content

Commit

Permalink
fix min_inputs in num_inputs_check
Browse files Browse the repository at this point in the history
  • Loading branch information
awayzjj committed Mar 20, 2024
1 parent 0789b36 commit 4d7cc13
Showing 1 changed file with 49 additions and 49 deletions.
98 changes: 49 additions & 49 deletions src/frontends/pytorch/src/op/bucketize.cpp
Original file line number Diff line number Diff line change
@@ -1,50 +1,50 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/bucketize.hpp"

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/logical_or.hpp"
#include "openvino/op/multiply.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_bucketize(const NodeContext& context) {
num_inputs_check(context, 4, 5);
auto input = context.get_input(0);
auto boundaries = context.get_input(1);

element::Type output_type = ov::element::i64;
if (!context.input_is_none(2) && context.const_input<bool>(2)) {
output_type = ov::element::i32;
}

bool with_right_bound = true;
if (!context.input_is_none(3)) {
with_right_bound = !context.const_input<bool>(3);
}

auto bucketize =
context.mark_node(std::make_shared<v3::Bucketize>(input, boundaries, output_type, with_right_bound));

if (!context.input_is_none(4)) {
context.mutate_input(4, bucketize);
}

return {bucketize};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/bucketize.hpp"

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/logical_or.hpp"
#include "openvino/op/multiply.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_bucketize(const NodeContext& context) {
num_inputs_check(context, 2, 5);
auto input = context.get_input(0);
auto boundaries = context.get_input(1);

element::Type output_type = ov::element::i64;
if (!context.input_is_none(2) && context.const_input<bool>(2)) {
output_type = ov::element::i32;
}

bool with_right_bound = true;
if (!context.input_is_none(3)) {
with_right_bound = !context.const_input<bool>(3);
}

auto bucketize =
context.mark_node(std::make_shared<v3::Bucketize>(input, boundaries, output_type, with_right_bound));

if (!context.input_is_none(4)) {
context.mutate_input(4, bucketize);
}

return {bucketize};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

0 comments on commit 4d7cc13

Please sign in to comment.