diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc index 56bf708f5e06..77060d3f053b 100644 --- a/src/relax/op/op_common.cc +++ b/src/relax/op/op_common.cc @@ -35,17 +35,16 @@ Array GetCallArgs(const Call& call) { return args; } -void CheckNumArguments(const Call& call, const BlockBuilder& ctx) { +void CheckNumArguments(const Call& call) { Op op = Downcast(call->op); int expected_input = op->arguments.size(); if (static_cast(call->args.size()) != expected_input) { - ctx->ReportFatal(Diagnostic::Error(call) - << "Operator " << op << " expects " << expected_input << " arguments" - << ", but was called with " << call->args.size() << " arguments"); + LOG(FATAL) << "Operator " << op << " expects " << expected_input << " arguments" + << ", but was called with " << call->args.size() << " arguments"; } } -TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, const BlockBuilder& ctx) { +TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg) { Op op = Downcast(call->op); ICHECK_EQ(op->arguments.size(), call->args.size()) @@ -59,24 +58,19 @@ TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, const if (auto tensor_sinfo = sinfo.as()) { return tensor_sinfo.value(); } else { - ctx->ReportFatal(Diagnostic::Error(call) - << "Operator " << op << " requires argument " << i_arg << " (" - << op->arguments[i_arg]->name << ") to be a tensor. " - << "However, the argument " << arg << " is instead of type " << sinfo); - // Unreachable, but [[noreturn]] attribute on virtual function - // `ReportFatal` is insufficient to silence -Wreturn-type, as - // child class might not be [[noreturn]]. - return TensorStructInfo(); + LOG(FATAL) << "Operator " << op << " requires argument " << i_arg << " (" + << op->arguments[i_arg]->name << ") to be a tensor. " + << "However, the argument " << arg << " is instead of type " << sinfo; } } -Array GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx) { - CheckNumArguments(call, ctx); +Array GetInputTensorStructInfo(const Call& call) { + CheckNumArguments(call); Op op = Downcast(call->op); Array input_tensor_sinfo; for (size_t i = 0; i < call->args.size(); ++i) { - input_tensor_sinfo.push_back(GetInputTensorStructInfo(call, i, ctx)); + input_tensor_sinfo.push_back(GetInputTensorStructInfo(call, i)); } return input_tensor_sinfo; } diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index ed6725e27012..395d122bd3e3 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -44,6 +44,16 @@ namespace relax { /************ Op input struct info getter ************/ +/*! + * \brief Check that the operator has the correct number of arguments + * + * Verify that the number of arguments matches the expected number for + * the operator. + * + * \param call The context Call to the operator. + */ +void CheckNumArguments(const Call& call); + /*! * \brief Check that the operator has * @@ -54,7 +64,17 @@ namespace relax { * * \param ctx The error reporting context. */ -void CheckNumArguments(const Call& call, const BlockBuilder& ctx); +inline void CheckNumArguments(const Call& call, const BlockBuilder& ctx) { + CheckNumArguments(call); +} + +/*! + * \brief Get the tensor struct info of the operator input. + * \param call The context Call to the operator. + * \param i_arg The index of the argument to check + * \return The tensor struct info of the argument + */ +TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg); /*! * \brief Get the tensor struct info of the operator input. @@ -63,7 +83,19 @@ void CheckNumArguments(const Call& call, const BlockBuilder& ctx); * \param ctx The error reporting context. * \return The tensor struct info of the argument */ -TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, const BlockBuilder& ctx); +inline TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, + const BlockBuilder& ctx) { + return GetInputTensorStructInfo(call, i_arg); +} + +/*! + * \brief Get the tensor struct info of the operator input. + * \param call The context Call to the operator. + * \return The tensor struct info of each input. + * \note This function require every input to be Tensor. The number of call arguments is required + * to match the number of inputs of the op being called. + */ +Array GetInputTensorStructInfo(const Call& call); /*! * \brief Get the tensor struct info of the operator input. @@ -73,7 +105,20 @@ TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, const * \note This function require every input to be Tensor. The number of call arguments is required * to match the number of inputs of the op being called. */ -Array GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx); +inline Array GetInputTensorStructInfo(const Call& call, const BlockBuilder& ctx) { + return GetInputTensorStructInfo(call); +} + +/*! + * \brief Get the tensor struct info of the unary operator input. + * \param call The context Call to the operator. + * \return The tensor struct info of the unary operator input. + * \throw Throw exception if the number of input is not one, or the struct info of the input is not + * a tensor struct info. + */ +inline TensorStructInfo GetUnaryInputTensorStructInfo(const Call& call) { + return GetInputTensorStructInfo(call)[0]; +} /*! * \brief Get the tensor struct info of the unary operator input. @@ -84,7 +129,7 @@ Array GetInputTensorStructInfo(const Call& call, const BlockBu * a tensor struct info. */ inline TensorStructInfo GetUnaryInputTensorStructInfo(const Call& call, const BlockBuilder& ctx) { - return GetInputTensorStructInfo(call, ctx)[0]; + return GetUnaryInputTensorStructInfo(call); } /*! @@ -101,22 +146,19 @@ Array GetTensorStructInfoFromTuple(const Call& call, const Blo namespace detail { /*! \brief Implementation helper for GetArgStructInfo */ template -ArgType GetArgStructInfoByIndex(const Call& call, const Op& op, const BlockBuilder& ctx, - size_t index) { +ArgType GetArgStructInfoByIndex(const Call& call, const Op& op, size_t index) { if (!call->args[index]->struct_info_.defined()) { - ctx->ReportFatal(Diagnostic::Error(call) - << op << " op should have arguments with defined StructInfo. " - << "However, args[" << index << "] has undefined struct info."); + LOG(FATAL) << "Operator " << op << " should have arguments with defined StructInfo. " + << "However, args[" << index << "] has undefined struct info."; } auto sinfo = GetStructInfo(call->args[index]); auto typed_sinfo = sinfo.as(); if (!typed_sinfo.defined()) { - ctx->ReportFatal(Diagnostic::Error(call) - << op << " requires that args[" << index << "] be a " - << ArgType::ContainerType::_type_key << ", but was instead " << sinfo - << " of type " << sinfo->GetTypeKey()); + LOG(FATAL) << "Operator " << op << " requires that args[" << index << "] be a " + << ArgType::ContainerType::_type_key << ", but was instead " << sinfo << " of type " + << sinfo->GetTypeKey(); } return typed_sinfo.value(); @@ -125,9 +167,8 @@ ArgType GetArgStructInfoByIndex(const Call& call, const Op& op, const BlockBuild /*! \brief Implementation helper for GetArgStructInfo */ template std::tuple GetArgStructInfoHelper(const Call& call, const Op& op, - const BlockBuilder& ctx, std::index_sequence) { - return std::tuple{GetArgStructInfoByIndex(call, op, ctx, Indices)...}; + return std::tuple{GetArgStructInfoByIndex(call, op, Indices)...}; } } // namespace detail @@ -136,12 +177,11 @@ std::tuple GetArgStructInfoHelper(const Call& call, const Op& op, * * \tparam ArgTypes The expected types of arguments, in the order they appear. * \param call The context Call to the operator. - * \param ctx The error reporting context. * \return The tensor struct infos of tuple input. * \throw Throw exception if input expression is not a tuple. */ template -std::tuple GetArgStructInfo(const Call& call, const BlockBuilder& ctx) { +std::tuple GetArgStructInfo(const Call& call) { Op op = Downcast(call->op); size_t n_input = op->arguments.size(); @@ -154,7 +194,21 @@ std::tuple GetArgStructInfo(const Call& call, const BlockBuilder& c << "but GetArgStructInfo was given " << sizeof...(ArgTypes) << " template arguments."; return detail::GetArgStructInfoHelper( - call, op, ctx, std::make_index_sequence()); + call, op, std::make_index_sequence()); +} + +/*! + * \brief Get all arg struct infos as expected types + * + * \tparam ArgTypes The expected types of arguments, in the order they appear. + * \param call The context Call to the operator. + * \param ctx The error reporting context. + * \return The tensor struct infos of tuple input. + * \throw Throw exception if input expression is not a tuple. + */ +template +std::tuple GetArgStructInfo(const Call& call, const BlockBuilder& ctx) { + return GetArgStructInfo(call); } /************ Op registration macro ************/ diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index 7aca1470aee4..59493f849fc1 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -36,43 +36,20 @@ namespace relax { TVM_REGISTER_NODE_TYPE(InitAttrs); /* relax.full */ -Expr full(Variant> shape, Expr fill_value, DataType dtype) { - Expr shape_in_expr{nullptr}; - if (const auto* expr = shape.as()) { - shape_in_expr = GetRef(expr); - } else if (const auto* _array = shape.as()) { - shape_in_expr = ShapeExpr(GetRef>(_array)); - } else { - LOG(FATAL) << "Full only expects the input shape to be either an Expr or an Array of PrimExpr. " - "However, the given one is " - << shape->GetTypeKey(); - } - - ObjectPtr attrs = make_object(); - attrs->dtype = dtype; - - static const Op& op = Op::Get("relax.full"); - return Call(op, {std::move(shape_in_expr), std::move(fill_value)}, Attrs(attrs), {}); -} - -TVM_REGISTER_GLOBAL("relax.op.full").set_body_typed(full); - -StructInfo InferStructInfoFull(const Call& call, const BlockBuilder& ctx) { +StructInfo InferStructInfoFull(const Call& call) { if (call->args.size() != 2) { - ctx->ReportFatal(Diagnostic::Error(call) << "Full op should have 2 arguments"); + LOG(FATAL) << "Full op should have 2 arguments"; } const auto* shape_sinfo = GetStructInfoAs(call->args[0]); const auto* fill_value_sinfo = GetStructInfoAs(call->args[1]); if (shape_sinfo == nullptr) { - ctx->ReportFatal(Diagnostic::Error(call) - << "Full requires the input shape to be a Shape. However, the given one is " - << call->args[0]->struct_info_->GetTypeKey()); + LOG(FATAL) << "Full requires the input shape to be a Shape. However, the given one is " + << call->args[0]->struct_info_->GetTypeKey(); } if (fill_value_sinfo == nullptr || fill_value_sinfo->ndim != 0) { - ctx->ReportFatal( - Diagnostic::Error(call) + LOG(FATAL) << "Full requires the input fill value to be zero rank Tensor. However, the given one is " - << call->args[1]->struct_info_); + << call->args[1]->struct_info_; } const auto* attrs = call->attrs.as(); @@ -80,33 +57,56 @@ StructInfo InferStructInfoFull(const Call& call, const BlockBuilder& ctx) { return TensorStructInfo(/*shape=*/call->args[0], out_dtype, fill_value_sinfo->vdevice); } +Expr full(Variant> shape, Expr fill_value, DataType dtype) { + Expr shape_in_expr = [&]() -> Expr { + if (const auto* expr = shape.as()) { + return GetRef(expr); + } else if (const auto* _array = shape.as()) { + return ShapeExpr(GetRef>(_array)); + } else { + LOG(FATAL) + << "Full only expects the input shape to be either an Expr or an Array of PrimExpr. " + "However, the given one is " + << shape->GetTypeKey(); + } + }(); + + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + + static const Op& op = Op::Get("relax.full"); + auto call = Call(op, {shape_in_expr, fill_value}, Attrs(attrs), {}); + + if (fill_value->struct_info_.defined()) { + UpdateStructInfo(call, InferStructInfoFull(call)); + } + + return call; +} + +TVM_REGISTER_GLOBAL("relax.op.full").set_body_typed(full); + TVM_REGISTER_OP("relax.full") .set_attrs_type() .set_num_inputs(2) .add_argument("shape", "Shape", "The shape of the created tensor.") .add_argument("fill_value", "Tensor", "The scalar tensor, denoting the value to fill.") - .set_attr("FInferStructInfo", InferStructInfoFull) + .set_attr("FInferStructInfo", + [](const Call& call, const BlockBuilder&) -> StructInfo { + return InferStructInfoFull(call); + }) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", Bool(true)); /* relax.full_like */ -Expr full_like(Expr x, Expr fill_value, DataType dtype) { - ObjectPtr attrs = make_object(); - attrs->dtype = dtype; - static const Op& op = Op::Get("relax.full_like"); - return Call(op, {std::move(x), std::move(fill_value)}, Attrs(attrs), {}); -} - -TVM_REGISTER_GLOBAL("relax.op.full_like").set_body_typed(full_like); - -StructInfo InferStructInfoFullLike(const Call& call, const BlockBuilder& ctx) { - Array input_sinfo = GetInputTensorStructInfo(call, ctx); +StructInfo InferStructInfoFullLike(const Call& call) { + Array input_sinfo = GetInputTensorStructInfo(call); TensorStructInfo data_sinfo = input_sinfo[0]; TensorStructInfo fill_value_sinfo = input_sinfo[1]; if (fill_value_sinfo->ndim != 0) { - ctx->ReportFatal(Diagnostic::Error(call) << "FullLike requires the input fill value to be zero " - "rank Tensor. However, the given one has ndim" - << fill_value_sinfo->ndim); + LOG(FATAL) << "FullLike requires the input fill value to be zero " + "rank Tensor. However, the given one has ndim" + << fill_value_sinfo->ndim; } const auto* attrs = call->attrs.as(); @@ -119,35 +119,50 @@ StructInfo InferStructInfoFullLike(const Call& call, const BlockBuilder& ctx) { } } +Expr full_like(Expr x, Expr fill_value, DataType dtype) { + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + static const Op& op = Op::Get("relax.full_like"); + auto call = Call(op, {x, fill_value}, Attrs(attrs), {}); + + if (x->struct_info_.defined() && fill_value->struct_info_.defined()) { + UpdateStructInfo(call, InferStructInfoFullLike(call)); + } + + return call; +} + +TVM_REGISTER_GLOBAL("relax.op.full_like").set_body_typed(full_like); + TVM_REGISTER_OP("relax.full_like") .set_attrs_type() .set_num_inputs(2) .add_argument("x", "Tensor", "The input tensor.") .add_argument("fill_value", "Tensor", "The scalar value to fill.") - .set_attr("FInferStructInfo", InferStructInfoFullLike) + .set_attr("FInferStructInfo", + [](const Call& call, const BlockBuilder&) -> StructInfo { + return InferStructInfoFullLike(call); + }) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", Bool(true)); // Structure info inference for ones and zeros -StructInfo InferStructInfoOnesZeros(const Call& call, const BlockBuilder& ctx) { - if (call->args.size() != 1) { - ctx->ReportFatal(Diagnostic::Error(call) << "Ones/Zeros should have 1 argument"); - } +StructInfo InferStructInfoOnesZeros(const Call& call) { + CheckNumArguments(call); const auto* shape_sinfo = GetStructInfoAs(call->args[0]); if (shape_sinfo == nullptr) { - ctx->ReportFatal( - Diagnostic::Error(call) - << "Ones/Zeros requires the input shape to be a Shape. However, the given one is " - << call->args[0]->struct_info_->GetTypeKey()); + LOG(FATAL) << "Operator " << call->op << " requires the input shape to be a Shape. " + << "However, the argument " << call->args[0] << " is of type " + << call->args[0]->struct_info_; } const auto* attrs = call->attrs.as(); return TensorStructInfo(/*shape=*/call->args[0], attrs->dtype); } // Structure info inference for ones_like and zeros_like -StructInfo InferStructInfoOnesLikeZerosLike(const Call& call, const BlockBuilder& ctx) { - TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); +StructInfo InferStructInfoOnesLikeZerosLike(const Call& call) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call); const auto* attrs = call->attrs.as(); if (attrs->dtype.is_void()) { return data_sinfo; @@ -165,14 +180,26 @@ Expr ones(Expr shape, DataType dtype) { attrs->dtype = dtype; static const Op& op = Op::Get("relax.ones"); - return Call(op, {std::move(shape)}, Attrs(attrs), {}); + Call call(op, {shape}, Attrs(attrs), {}); + + if (shape->struct_info_.defined()) { + UpdateStructInfo(call, InferStructInfoOnesZeros(call)); + } + + return call; } Expr ones_like(Expr x, DataType dtype) { ObjectPtr attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.ones_like"); - return Call(op, {std::move(x)}, Attrs(attrs), {}); + Call call(op, {x}, Attrs(attrs), {}); + + if (x->struct_info_.defined()) { + UpdateStructInfo(call, InferStructInfoOnesLikeZerosLike(call)); + } + + return call; } TVM_REGISTER_GLOBAL("relax.op.ones").set_body_typed(ones); @@ -182,7 +209,10 @@ TVM_REGISTER_OP("relax.ones") .set_attrs_type() .set_num_inputs(1) .add_argument("shape", "Shape", "The shape of the created tensor.") - .set_attr("FInferStructInfo", InferStructInfoOnesZeros) + .set_attr("FInferStructInfo", + [](const Call& call, const BlockBuilder&) -> StructInfo { + return InferStructInfoOnesZeros(call); + }) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", Bool(true)); @@ -190,7 +220,10 @@ TVM_REGISTER_OP("relax.ones_like") .set_attrs_type() .set_num_inputs(1) .add_argument("x", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoOnesLikeZerosLike) + .set_attr("FInferStructInfo", + [](const Call& call, const BlockBuilder&) -> StructInfo { + return InferStructInfoOnesLikeZerosLike(call); + }) .set_attr("FPurity", Bool(true)); /* relax.zeros & relax.zeros_like */ @@ -200,14 +233,26 @@ Expr zeros(Expr shape, DataType dtype) { attrs->dtype = dtype; static const Op& op = Op::Get("relax.zeros"); - return Call(op, {std::move(shape)}, Attrs(attrs), {}); + Call call(op, {shape}, Attrs(attrs), {}); + + if (shape->struct_info_.defined()) { + UpdateStructInfo(call, InferStructInfoOnesZeros(call)); + } + + return call; } Expr zeros_like(Expr x, DataType dtype) { ObjectPtr attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.zeros_like"); - return Call(op, {std::move(x)}, Attrs(attrs), {}); + Call call(op, {x}, Attrs(attrs), {}); + + if (x->struct_info_.defined()) { + UpdateStructInfo(call, InferStructInfoOnesLikeZerosLike(call)); + } + + return call; } TVM_REGISTER_GLOBAL("relax.op.zeros").set_body_typed(zeros); @@ -217,7 +262,10 @@ TVM_REGISTER_OP("relax.zeros") .set_attrs_type() .set_num_inputs(1) .add_argument("shape", "Shape", "The shape of the created tensor.") - .set_attr("FInferStructInfo", InferStructInfoOnesZeros) + .set_attr("FInferStructInfo", + [](const Call& call, const BlockBuilder&) -> StructInfo { + return InferStructInfoOnesZeros(call); + }) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", Bool(true)); @@ -225,32 +273,25 @@ TVM_REGISTER_OP("relax.zeros_like") .set_attrs_type() .set_num_inputs(1) .add_argument("x", "Tensor", "The input tensor.") - .set_attr("FInferStructInfo", InferStructInfoOnesLikeZerosLike) + .set_attr("FInferStructInfo", + [](const Call& call, const BlockBuilder&) -> StructInfo { + return InferStructInfoOnesLikeZerosLike(call); + }) .set_attr("FPurity", Bool(true)); /* relax.arange */ -Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype) { - ObjectPtr attrs = make_object(); - attrs->dtype = dtype; - static const Op& op = Op::Get("relax.arange"); - return Call(op, {std::move(start), std::move(stop), std::move(step)}, Attrs(attrs), {}); -} - -TVM_REGISTER_GLOBAL("relax.op.arange").set_body_typed(arange); - -StructInfo InferStructInfoArange(const Call& call, const BlockBuilder& ctx) { +StructInfo InferStructInfoArange(const Call& call) { if (call->args.size() != 3) { - ctx->ReportFatal( - Diagnostic::Error(call) - << "Arange should have 3 arguments, which are `start`, `end` and `step`, but got " - << call->args.size() << " arguments"); + LOG(FATAL) << "Operator " << call->op + << " expects 3 arguments, which are `start`, `end` and `step`, " + << "but received " << call->args.size() << " arguments"; } // TODO(Siyuan): Support indirect prim_values - auto get_prim_value = [&ctx](const Expr& expr, std::string key) { + auto get_prim_value = [&](const Expr& expr, std::string key) { if (!expr->IsInstance()) { - ctx->ReportFatal(Diagnostic::Error(expr) - << "Arange expects the `" << key << "` to be a PrimValue, but got " - << expr->GetTypeKey()); + LOG(FATAL) << "Operator" << call->op << " expects the `" << key + << "` parameter to be a PrimValue, " + << "but argument " << expr << " was of type " << expr->GetTypeKey(); } return expr.as()->value; }; @@ -270,29 +311,69 @@ StructInfo InferStructInfoArange(const Call& call, const BlockBuilder& ctx) { return TensorStructInfo(ShapeExpr({num_elem}), dtype); } +Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype) { + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + static const Op& op = Op::Get("relax.arange"); + Call call(op, {std::move(start), std::move(stop), std::move(step)}, Attrs(attrs), {}); + + UpdateStructInfo(call, InferStructInfoArange(call)); + + return call; +} + +TVM_REGISTER_GLOBAL("relax.op.arange").set_body_typed(arange); + TVM_REGISTER_OP("relax.arange") .set_attrs_type() .set_num_inputs(3) .add_argument("start", "PrimValue", "The starting value for the set of points.") .add_argument("end", "PrimValue", "The ending value for the set of points.") .add_argument("step", "PrimValue", "The gap between each pair of adjacent points.") - .set_attr("FInferStructInfo", InferStructInfoArange) + .set_attr("FInferStructInfo", + [](const Call& call, const BlockBuilder&) -> StructInfo { + return InferStructInfoArange(call); + }) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", Bool(true)); /* relax.tril & relax.triu */ TVM_REGISTER_NODE_TYPE(TriluAttrs); +StructInfo InferStructInfoTrilTriu(const Call& call) { + auto [data_sinfo, offset] = GetArgStructInfo(call); + + if (!data_sinfo->IsUnknownNdim() && data_sinfo->ndim < 2) { + LOG(FATAL) << "Operator " << call->op + << " expects an input tensor with at least two dimensions. " + << "However, the argument " << call->args[0] << " has type " << data_sinfo + << " with " << data_sinfo->ndim << " dimension(s)."; + } + return data_sinfo; +} + Expr tril(Expr x, Expr k) { static const Op& op = Op::Get("relax.tril"); - return Call(op, {x, k}); + Call call(op, {x, k}); + + if (x->struct_info_.defined() && k->struct_info_.defined()) { + UpdateStructInfo(call, InferStructInfoTrilTriu(call)); + } + + return call; } Expr tril(Expr x, int k) { return tril(x, relax::PrimValue::Int64(k)); } Expr triu(Expr x, Expr k) { static const Op& op = Op::Get("relax.triu"); - return Call(op, {x, k}); + Call call(op, {x, k}); + + if (x->struct_info_.defined() && k->struct_info_.defined()) { + UpdateStructInfo(call, InferStructInfoTrilTriu(call)); + } + + return call; } Expr triu(Expr x, int k) { return triu(x, relax::PrimValue::Int64(k)); } @@ -300,30 +381,24 @@ Expr triu(Expr x, int k) { return triu(x, relax::PrimValue::Int64(k)); } TVM_REGISTER_GLOBAL("relax.op.tril").set_body_typed(static_cast(tril)); TVM_REGISTER_GLOBAL("relax.op.triu").set_body_typed(static_cast(triu)); -StructInfo InferStructInfoTrilTriu(const Call& call, const BlockBuilder& ctx) { - auto [data_sinfo, offset] = GetArgStructInfo(call, ctx); - - if (!data_sinfo->IsUnknownNdim() && data_sinfo->ndim < 2) { - ctx->ReportFatal(Diagnostic::Error(call) << call->op - << " requires the input tensor to have at least two " - "dimensions. However, the given input has " - << data_sinfo->ndim << " dimension(s)."); - } - return data_sinfo; -} - TVM_REGISTER_OP("relax.tril") .set_num_inputs(2) .add_argument("x", "Tensor", "The input tensor.") .add_argument("k", "PrimValue", "The offset of the diagonal.") - .set_attr("FInferStructInfo", InferStructInfoTrilTriu) + .set_attr("FInferStructInfo", + [](const Call& call, const BlockBuilder&) -> StructInfo { + return InferStructInfoTrilTriu(call); + }) .set_attr("FPurity", Bool(true)); TVM_REGISTER_OP("relax.triu") .set_num_inputs(2) .add_argument("x", "Tensor", "The input tensor.") .add_argument("k", "PrimValue", "The offset of the diagonal.") - .set_attr("FInferStructInfo", InferStructInfoTrilTriu) + .set_attr("FInferStructInfo", + [](const Call& call, const BlockBuilder&) -> StructInfo { + return InferStructInfoTrilTriu(call); + }) .set_attr("FPurity", Bool(true)); } // namespace relax diff --git a/tests/python/relax/test_bind_params.py b/tests/python/relax/test_bind_params.py index bed44c4a6ac2..36510e53c394 100644 --- a/tests/python/relax/test_bind_params.py +++ b/tests/python/relax/test_bind_params.py @@ -157,5 +157,31 @@ def before(A: R.Tensor([16], dtype="float32")): before.bind_params({"unknown_var_name": np.arange(16).astype("float32")}) +def test_bind_computed_value(): + @R.function(private=True) + def before( + state: R.Tensor(["batch_size", 16], "float16"), + weights: R.Tensor([16, 16], "float16"), + bias: R.Tensor([16], "float16"), + ): + state = R.matmul(state, weights) + state = R.add(state, bias) + return state + + @R.function(private=True) + def expected( + state: R.Tensor(["batch_size", 16], "float16"), + weights: R.Tensor([16, 16], "float16"), + ): + state = R.matmul(state, weights) + bias = R.zeros([16], "float16") + state = R.add(state, bias) + return state + + after = before.bind_params({"bias": R.zeros([16], "float16")}) + + tvm.ir.assert_structural_equal(expected, after) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_op_create.py b/tests/python/relax/test_op_create.py index 1e895169f620..244b9fbb2c8f 100644 --- a/tests/python/relax/test_op_create.py +++ b/tests/python/relax/test_op_create.py @@ -16,6 +16,8 @@ # under the License. import pytest +from typing import Union + import tvm import tvm.testing from tvm import TVMError, relax, tir @@ -23,28 +25,43 @@ from tvm.script import relax as R from tvm.script import tir as T +from tvm.script.parser.relax.entry import StructInfoProxy + def test_op_correctness(): x = relax.Var("x", R.Tensor((3, 4, 5), "float32")) fill_value = relax.Var("fill_value", R.Tensor((), "float32")) - assert relax.op.full((2, 3), fill_value).op == Op.get("relax.full") - assert relax.op.full_like(x, fill_value).op == Op.get("relax.full_like") - assert relax.op.ones((2, 3), "float32").op == Op.get("relax.ones") - assert relax.op.ones_like(x).op == Op.get("relax.ones_like") - assert relax.op.zeros((2, 3), "float32").op == Op.get("relax.zeros") - assert relax.op.zeros_like(x).op == Op.get("relax.zeros_like") - assert relax.op.arange(3, 4, 1, "float32").op == Op.get("relax.arange") - assert relax.op.tril(x).op == Op.get("relax.tril") - assert relax.op.triu(x).op == Op.get("relax.triu") + assert R.full((2, 3), fill_value).op == Op.get("relax.full") + assert R.full_like(x, fill_value).op == Op.get("relax.full_like") + assert R.ones((2, 3), "float32").op == Op.get("relax.ones") + assert R.ones_like(x).op == Op.get("relax.ones_like") + assert R.zeros((2, 3), "float32").op == Op.get("relax.zeros") + assert R.zeros_like(x).op == Op.get("relax.zeros_like") + assert R.arange(3, 4, 1, "float32").op == Op.get("relax.arange") + assert R.tril(x).op == Op.get("relax.tril") + assert R.triu(x).op == Op.get("relax.triu") + + +def _get_inference_checker(bb: relax.BlockBuilder, normalize_before_check: bool = True): + def _check(call: relax.Call, expected_sinfo: Union[relax.StructInfo, StructInfoProxy]): + if isinstance(expected_sinfo, StructInfoProxy): + expected_sinfo = expected_sinfo.as_struct_info() + + if normalize_before_check: + call = bb.normalize(call) + + tvm.ir.assert_structural_equal(call.struct_info, expected_sinfo) + + return _check -def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): - ret = bb.normalize(call) - tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) +normalize_before_check = tvm.testing.parameter(True, False) -def test_full_infer_struct_info(): +def test_full_infer_struct_info(normalize_before_check): bb = relax.BlockBuilder() + inference_checker = _get_inference_checker(bb, normalize_before_check=normalize_before_check) + vdev0 = VDevice("llvm") v0 = relax.Var("v", R.Tensor((), "float32")) v1 = relax.Var("v", R.Tensor("float32", ndim=0)) @@ -56,153 +73,133 @@ def test_full_infer_struct_info(): s2 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) s3 = relax.Var("s", relax.ShapeStructInfo()) - _check_inference( - bb, relax.op.full((2, 3), v0, "float16"), relax.TensorStructInfo((2, 3), "float16") - ) - _check_inference(bb, relax.op.full((2, 3), v0), relax.TensorStructInfo((2, 3), "float32")) - _check_inference( - bb, relax.op.full(s0, v0, "float16"), relax.TensorStructInfo((2, 3), "float16") - ) - _check_inference(bb, relax.op.full(s0, v0), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.full(s0, v4), relax.TensorStructInfo((2, 3), "float32", vdev0)) - _check_inference(bb, relax.op.full(s1, v0, "float16"), relax.TensorStructInfo(s1, "float16")) - _check_inference(bb, relax.op.full(s1, v0), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.full(s2, v0, "float16"), relax.TensorStructInfo(s2, "float16")) - _check_inference(bb, relax.op.full(s2, v0), relax.TensorStructInfo(s2, "float32")) - _check_inference(bb, relax.op.full(s3, v0, "float16"), relax.TensorStructInfo(s3, "float16")) - _check_inference(bb, relax.op.full(s3, v0), relax.TensorStructInfo(s3, "float32")) - _check_inference( - bb, relax.op.full((2, 3), v1, "float16"), relax.TensorStructInfo((2, 3), "float16") - ) - _check_inference(bb, relax.op.full((2, 3), v1), relax.TensorStructInfo((2, 3), "float32")) - _check_inference( - bb, relax.op.full(s0, v1, "float16"), relax.TensorStructInfo((2, 3), "float16") - ) - _check_inference(bb, relax.op.full(s0, v1), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.full(s1, v1, "float16"), relax.TensorStructInfo(s1, "float16")) - _check_inference(bb, relax.op.full(s1, v1), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.full(s2, v1, "float16"), relax.TensorStructInfo(s2, "float16")) - _check_inference(bb, relax.op.full(s2, v1), relax.TensorStructInfo(s2, "float32")) - _check_inference(bb, relax.op.full(s3, v1, "float16"), relax.TensorStructInfo(s3, "float16")) - _check_inference(bb, relax.op.full(s3, v1), relax.TensorStructInfo(s3, "float32")) - _check_inference( - bb, relax.op.full((2, 3), v2, "float16"), relax.TensorStructInfo((2, 3), "float16") - ) - _check_inference(bb, relax.op.full((2, 3), v2), relax.TensorStructInfo((2, 3), dtype="")) - _check_inference( - bb, relax.op.full(s0, v2, "float16"), relax.TensorStructInfo((2, 3), "float16") - ) - _check_inference(bb, relax.op.full(s0, v2), relax.TensorStructInfo((2, 3), dtype="")) - _check_inference(bb, relax.op.full(s1, v2, "float16"), relax.TensorStructInfo(s1, "float16")) - _check_inference(bb, relax.op.full(s1, v2), relax.TensorStructInfo(s1, dtype="")) - _check_inference(bb, relax.op.full(s2, v2, "float16"), relax.TensorStructInfo(s2, "float16")) - _check_inference(bb, relax.op.full(s2, v2), relax.TensorStructInfo(s2, dtype="")) - _check_inference(bb, relax.op.full(s3, v2, "float16"), relax.TensorStructInfo(s3, "float16")) - _check_inference(bb, relax.op.full(s3, v2), relax.TensorStructInfo(s3, dtype="")) - _check_inference( - bb, relax.op.full((2, 3), v3, "float16"), relax.TensorStructInfo((2, 3), "float16") - ) - _check_inference(bb, relax.op.full((2, 3), v3), relax.TensorStructInfo((2, 3), dtype="")) - _check_inference( - bb, relax.op.full(s0, v3, "float16"), relax.TensorStructInfo((2, 3), "float16") - ) - _check_inference(bb, relax.op.full(s0, v3), relax.TensorStructInfo((2, 3), dtype="")) - _check_inference(bb, relax.op.full(s1, v3, "float16"), relax.TensorStructInfo(s1, "float16")) - _check_inference( - bb, - relax.op.full( + inference_checker(R.full((2, 3), v0, "float16"), R.Tensor((2, 3), "float16")) + inference_checker(R.full((2, 3), v0), R.Tensor((2, 3), "float32")) + inference_checker(R.full(s0, v0, "float16"), R.Tensor((2, 3), "float16")) + inference_checker(R.full(s0, v0), R.Tensor((2, 3), "float32")) + inference_checker(R.full(s0, v4), R.Tensor((2, 3), "float32", vdev0)) + inference_checker(R.full(s1, v0, "float16"), R.Tensor(s1, "float16")) + inference_checker(R.full(s1, v0), R.Tensor(s1, "float32")) + inference_checker(R.full(s2, v0, "float16"), R.Tensor(s2, "float16")) + inference_checker(R.full(s2, v0), R.Tensor(s2, "float32")) + inference_checker(R.full(s3, v0, "float16"), R.Tensor(s3, "float16")) + inference_checker(R.full(s3, v0), R.Tensor(s3, "float32")) + inference_checker(R.full((2, 3), v1, "float16"), R.Tensor((2, 3), "float16")) + inference_checker(R.full((2, 3), v1), R.Tensor((2, 3), "float32")) + inference_checker(R.full(s0, v1, "float16"), R.Tensor((2, 3), "float16")) + inference_checker(R.full(s0, v1), R.Tensor((2, 3), "float32")) + inference_checker(R.full(s1, v1, "float16"), R.Tensor(s1, "float16")) + inference_checker(R.full(s1, v1), R.Tensor(s1, "float32")) + inference_checker(R.full(s2, v1, "float16"), R.Tensor(s2, "float16")) + inference_checker(R.full(s2, v1), R.Tensor(s2, "float32")) + inference_checker(R.full(s3, v1, "float16"), R.Tensor(s3, "float16")) + inference_checker(R.full(s3, v1), R.Tensor(s3, "float32")) + inference_checker(R.full((2, 3), v2, "float16"), R.Tensor((2, 3), "float16")) + inference_checker(R.full((2, 3), v2), R.Tensor((2, 3), dtype="")) + inference_checker(R.full(s0, v2, "float16"), R.Tensor((2, 3), "float16")) + inference_checker(R.full(s0, v2), R.Tensor((2, 3), dtype="")) + inference_checker(R.full(s1, v2, "float16"), R.Tensor(s1, "float16")) + inference_checker(R.full(s1, v2), R.Tensor(s1, dtype="")) + inference_checker(R.full(s2, v2, "float16"), R.Tensor(s2, "float16")) + inference_checker(R.full(s2, v2), R.Tensor(s2, dtype="")) + inference_checker(R.full(s3, v2, "float16"), R.Tensor(s3, "float16")) + inference_checker(R.full(s3, v2), R.Tensor(s3, dtype="")) + inference_checker(R.full((2, 3), v3, "float16"), R.Tensor((2, 3), "float16")) + inference_checker(R.full((2, 3), v3), R.Tensor((2, 3), dtype="")) + inference_checker(R.full(s0, v3, "float16"), R.Tensor((2, 3), "float16")) + inference_checker(R.full(s0, v3), R.Tensor((2, 3), dtype="")) + inference_checker(R.full(s1, v3, "float16"), R.Tensor(s1, "float16")) + inference_checker( + R.full( s1, v3, ), - relax.TensorStructInfo(s1, dtype=""), + R.Tensor(s1, dtype=""), ) - _check_inference(bb, relax.op.full(s2, v3, "float16"), relax.TensorStructInfo(s2, "float16")) - _check_inference( - bb, - relax.op.full( + inference_checker(R.full(s2, v3, "float16"), R.Tensor(s2, "float16")) + inference_checker( + R.full( s2, v3, ), - relax.TensorStructInfo(s2, dtype=""), + R.Tensor(s2, dtype=""), ) - _check_inference(bb, relax.op.full(s3, v3, "float16"), relax.TensorStructInfo(s3, "float16")) - _check_inference(bb, relax.op.full(s3, v3), relax.TensorStructInfo(s3, dtype="")) + inference_checker(R.full(s3, v3, "float16"), R.Tensor(s3, "float16")) + inference_checker(R.full(s3, v3), R.Tensor(s3, dtype="")) -def test_full_infer_struct_info_shape_symbolic(): +def test_full_infer_struct_info_shape_symbolic(normalize_before_check): bb = relax.BlockBuilder() + inference_checker = _get_inference_checker(bb, normalize_before_check=normalize_before_check) + a = tir.Var("a", "int64") v = relax.Var("v", R.Tensor((), "float32")) s0 = relax.ShapeExpr((a, 3)) s1 = relax.Var("s", relax.ShapeStructInfo((a, 3))) - _check_inference( - bb, relax.op.full((a, 3), v, "float16"), relax.TensorStructInfo((a, 3), "float16") - ) - _check_inference(bb, relax.op.full((a, 3), v), relax.TensorStructInfo((a, 3), "float32")) - _check_inference(bb, relax.op.full(s0, v, "float16"), relax.TensorStructInfo((a, 3), "float16")) - _check_inference(bb, relax.op.full(s0, v), relax.TensorStructInfo((a, 3), "float32")) - _check_inference(bb, relax.op.full(s1, v, "float16"), relax.TensorStructInfo(s1, "float16")) - _check_inference(bb, relax.op.full(s1, v), relax.TensorStructInfo(s1, "float32")) + inference_checker(R.full((a, 3), v, "float16"), R.Tensor((a, 3), "float16")) + inference_checker(R.full((a, 3), v), R.Tensor((a, 3), "float32")) + inference_checker(R.full(s0, v, "float16"), R.Tensor((a, 3), "float16")) + inference_checker(R.full(s0, v), R.Tensor((a, 3), "float32")) + inference_checker(R.full(s1, v, "float16"), R.Tensor(s1, "float16")) + inference_checker(R.full(s1, v), R.Tensor(s1, "float32")) -def test_full_infer_struct_info_shape_var(): +def test_full_infer_struct_info_shape_var(normalize_before_check): bb = relax.BlockBuilder() + inference_checker = _get_inference_checker(bb, normalize_before_check=normalize_before_check) + s0 = relax.Var("s", relax.ShapeStructInfo(())) s1 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) - v0 = relax.Var("v", relax.TensorStructInfo(s0, "float32")) - v1 = relax.Var("v", relax.TensorStructInfo(s1, "float32")) + v0 = relax.Var("v", R.Tensor(s0, "float32")) + v1 = relax.Var("v", R.Tensor(s1, "float32")) - _check_inference( - bb, relax.op.full((2, 3), v0, "float16"), relax.TensorStructInfo((2, 3), "float16") - ) - _check_inference( - bb, relax.op.full((2, 3), v1, "float16"), relax.TensorStructInfo((2, 3), "float16") - ) + inference_checker(R.full((2, 3), v0, "float16"), R.Tensor((2, 3), "float16")) + inference_checker(R.full((2, 3), v1, "float16"), R.Tensor((2, 3), "float16")) -def test_full_infer_struct_info_more_input_dtype(): +def test_full_infer_struct_info_more_input_dtype(normalize_before_check): bb = relax.BlockBuilder() + inference_checker = _get_inference_checker(bb, normalize_before_check=normalize_before_check) + v0 = relax.Var("v", R.Tensor((), "float16")) v1 = relax.Var("v", R.Tensor((), "int8")) v2 = relax.Var("v", R.Tensor((), "int32")) - _check_inference( - bb, relax.op.full((2, 3), v0, "float32"), relax.TensorStructInfo((2, 3), "float32") - ) - _check_inference(bb, relax.op.full((2, 3), v0), relax.TensorStructInfo((2, 3), "float16")) - _check_inference( - bb, relax.op.full((2, 3), v1, "int32"), relax.TensorStructInfo((2, 3), "int32") - ) - _check_inference(bb, relax.op.full((2, 3), v1), relax.TensorStructInfo((2, 3), "int8")) - _check_inference(bb, relax.op.full((2, 3), v2, "int8"), relax.TensorStructInfo((2, 3), "int8")) - _check_inference(bb, relax.op.full((2, 3), v2), relax.TensorStructInfo((2, 3), "int32")) + inference_checker(R.full((2, 3), v0, "float32"), R.Tensor((2, 3), "float32")) + inference_checker(R.full((2, 3), v0), R.Tensor((2, 3), "float16")) + inference_checker(R.full((2, 3), v1, "int32"), R.Tensor((2, 3), "int32")) + inference_checker(R.full((2, 3), v1), R.Tensor((2, 3), "int8")) + inference_checker(R.full((2, 3), v2, "int8"), R.Tensor((2, 3), "int8")) + inference_checker(R.full((2, 3), v2), R.Tensor((2, 3), "int32")) -def test_full_infer_struct_info_fill_value_not_scalar_tensor(): +def test_full_infer_struct_info_fill_value_not_scalar_tensor(normalize_before_check): bb = relax.BlockBuilder() + inference_checker = _get_inference_checker(bb, normalize_before_check=normalize_before_check) + s0 = relax.Var("s", relax.ShapeStructInfo((1,))) s1 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) s2 = relax.Var("s", relax.ShapeStructInfo()) v0 = relax.Var("v", R.Tensor((1,), "float32")) v1 = relax.Var("v", R.Tensor("float32", ndim=1)) v2 = relax.Var("v", R.Tensor("float32")) - v3 = relax.Var("v", relax.TensorStructInfo(s0, "float32")) - v4 = relax.Var("v", relax.TensorStructInfo(s1, "float32")) - v5 = relax.Var("v", relax.TensorStructInfo(s2, "float32")) + v3 = relax.Var("v", R.Tensor(s0, "float32")) + v4 = relax.Var("v", R.Tensor(s1, "float32")) + v5 = relax.Var("v", R.Tensor(s2, "float32")) with pytest.raises(TVMError): - bb.normalize(relax.op.full((2, 3), v0)) + bb.normalize(R.full((2, 3), v0)) with pytest.raises(TVMError): - bb.normalize(relax.op.full((2, 3), v1)) + bb.normalize(R.full((2, 3), v1)) with pytest.raises(TVMError): - bb.normalize(relax.op.full((2, 3), v2)) + bb.normalize(R.full((2, 3), v2)) with pytest.raises(TVMError): - bb.normalize(relax.op.full((2, 3), v3)) + bb.normalize(R.full((2, 3), v3)) with pytest.raises(TVMError): - bb.normalize(relax.op.full((2, 3), v4)) + bb.normalize(R.full((2, 3), v4)) with pytest.raises(TVMError): - bb.normalize(relax.op.full((2, 3), v5)) + bb.normalize(R.full((2, 3), v5)) def test_full_shape_not_tuple(): @@ -210,9 +207,9 @@ def test_full_shape_not_tuple(): v = relax.Var("v", R.Tensor((), "float32")) with pytest.raises(TVMError): - relax.op.full(4, v) + R.full(4, v) with pytest.raises(TVMError): - relax.op.full(m, v) + R.full(m, v) def test_full_infer_struct_info_wrong_input_type(): @@ -223,15 +220,17 @@ def test_full_infer_struct_info_wrong_input_type(): s = relax.Var("s", R.Tensor((2, 3))) with pytest.raises(TVMError): - bb.normalize(relax.op.full(s, v0)) + bb.normalize(R.full(s, v0)) with pytest.raises(TVMError): - bb.normalize(relax.op.full((2, 3), v1)) + bb.normalize(R.full((2, 3), v1)) with pytest.raises(TVMError): - bb.normalize(relax.op.full((2, 3), v2)) + bb.normalize(R.full((2, 3), v2)) -def test_full_like_infer_struct_info(): +def test_full_like_infer_struct_info(normalize_before_check): bb = relax.BlockBuilder() + inference_checker = _get_inference_checker(bb, normalize_before_check=normalize_before_check) + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=2)) x2 = relax.Var("x", R.Tensor("float32")) @@ -243,113 +242,103 @@ def test_full_like_infer_struct_info(): v2 = relax.Var("v", R.Tensor(())) v3 = relax.Var("v", R.Tensor(ndim=0)) - _check_inference(bb, relax.op.full_like(x0, v0), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.full_like(x0, v1), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.full_like(x0, v2), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.full_like(x0, v3), relax.TensorStructInfo((2, 3), "float32")) - _check_inference( - bb, relax.op.full_like(x1, v0), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference( - bb, relax.op.full_like(x1, v1), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference( - bb, relax.op.full_like(x1, v2), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference( - bb, relax.op.full_like(x1, v3), relax.TensorStructInfo(dtype="float32", ndim=2) - ) - _check_inference(bb, relax.op.full_like(x2, v0), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.full_like(x2, v1), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.full_like(x2, v2), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.full_like(x2, v3), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.full_like(x3, v0), relax.TensorStructInfo((2, 3), dtype="")) - _check_inference(bb, relax.op.full_like(x3, v1), relax.TensorStructInfo((2, 3), dtype="")) - _check_inference(bb, relax.op.full_like(x3, v2), relax.TensorStructInfo((2, 3), dtype="")) - _check_inference(bb, relax.op.full_like(x3, v3), relax.TensorStructInfo((2, 3), dtype="")) - _check_inference(bb, relax.op.full_like(x4, v0), relax.TensorStructInfo(dtype="", ndim=2)) - _check_inference(bb, relax.op.full_like(x4, v1), relax.TensorStructInfo(dtype="", ndim=2)) - _check_inference(bb, relax.op.full_like(x4, v2), relax.TensorStructInfo(dtype="", ndim=2)) - _check_inference(bb, relax.op.full_like(x4, v3), relax.TensorStructInfo(dtype="", ndim=2)) - _check_inference(bb, relax.op.full_like(x5, v0), relax.TensorStructInfo(dtype="")) - _check_inference(bb, relax.op.full_like(x5, v1), relax.TensorStructInfo(dtype="")) - _check_inference(bb, relax.op.full_like(x5, v2), relax.TensorStructInfo(dtype="")) - _check_inference(bb, relax.op.full_like(x5, v3), relax.TensorStructInfo(dtype="")) - _check_inference( - bb, relax.op.full_like(x0, v0, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") - ) - _check_inference( - bb, relax.op.full_like(x0, v2, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") - ) - _check_inference( - bb, relax.op.full_like(x3, v0, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") - ) - _check_inference( - bb, relax.op.full_like(x3, v2, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") - ) - - -def test_full_like_infer_struct_info_shape_symbolic(): + inference_checker(R.full_like(x0, v0), R.Tensor((2, 3), "float32")) + inference_checker(R.full_like(x0, v1), R.Tensor((2, 3), "float32")) + inference_checker(R.full_like(x0, v2), R.Tensor((2, 3), "float32")) + inference_checker(R.full_like(x0, v3), R.Tensor((2, 3), "float32")) + inference_checker(R.full_like(x1, v0), R.Tensor(dtype="float32", ndim=2)) + inference_checker(R.full_like(x1, v1), R.Tensor(dtype="float32", ndim=2)) + inference_checker(R.full_like(x1, v2), R.Tensor(dtype="float32", ndim=2)) + inference_checker(R.full_like(x1, v3), R.Tensor(dtype="float32", ndim=2)) + inference_checker(R.full_like(x2, v0), R.Tensor(dtype="float32")) + inference_checker(R.full_like(x2, v1), R.Tensor(dtype="float32")) + inference_checker(R.full_like(x2, v2), R.Tensor(dtype="float32")) + inference_checker(R.full_like(x2, v3), R.Tensor(dtype="float32")) + inference_checker(R.full_like(x3, v0), R.Tensor((2, 3), dtype="")) + inference_checker(R.full_like(x3, v1), R.Tensor((2, 3), dtype="")) + inference_checker(R.full_like(x3, v2), R.Tensor((2, 3), dtype="")) + inference_checker(R.full_like(x3, v3), R.Tensor((2, 3), dtype="")) + inference_checker(R.full_like(x4, v0), R.Tensor(dtype="", ndim=2)) + inference_checker(R.full_like(x4, v1), R.Tensor(dtype="", ndim=2)) + inference_checker(R.full_like(x4, v2), R.Tensor(dtype="", ndim=2)) + inference_checker(R.full_like(x4, v3), R.Tensor(dtype="", ndim=2)) + inference_checker(R.full_like(x5, v0), R.Tensor(dtype="")) + inference_checker(R.full_like(x5, v1), R.Tensor(dtype="")) + inference_checker(R.full_like(x5, v2), R.Tensor(dtype="")) + inference_checker(R.full_like(x5, v3), R.Tensor(dtype="")) + inference_checker(R.full_like(x0, v0, dtype="float16"), R.Tensor((2, 3), "float16")) + inference_checker(R.full_like(x0, v2, dtype="float16"), R.Tensor((2, 3), "float16")) + inference_checker(R.full_like(x3, v0, dtype="float16"), R.Tensor((2, 3), "float16")) + inference_checker(R.full_like(x3, v2, dtype="float16"), R.Tensor((2, 3), "float16")) + + +def test_full_like_infer_struct_info_shape_symbolic(normalize_before_check): bb = relax.BlockBuilder() + inference_checker = _get_inference_checker(bb, normalize_before_check=normalize_before_check) + m = tir.Var("m", "int64") n = tir.Var("n", "int64") x0 = relax.Var("x", R.Tensor((m, n), "float32")) x1 = relax.Var("x", R.Tensor((m, n))) v = relax.Var("v", R.Tensor((), "float16")) - _check_inference(bb, relax.op.full_like(x0, v), relax.TensorStructInfo((m, n), "float32")) - _check_inference(bb, relax.op.full_like(x1, v), relax.TensorStructInfo((m, n), dtype="")) + inference_checker(R.full_like(x0, v), R.Tensor((m, n), "float32")) + inference_checker(R.full_like(x1, v), R.Tensor((m, n), dtype="")) -def test_full_like_infer_struct_info_shape_var(): +def test_full_like_infer_struct_info_shape_var(normalize_before_check): bb = relax.BlockBuilder() + inference_checker = _get_inference_checker(bb, normalize_before_check=normalize_before_check) + vdev0 = VDevice("llvm") s0 = relax.Var("s", relax.ShapeStructInfo((2, 3))) s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) s2 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x0 = relax.Var("x", R.Tensor(s0, "float32")) + x1 = relax.Var("x", R.Tensor(s1, "float32")) + x2 = relax.Var("x", R.Tensor(s2, "float32")) x3 = relax.Var("x", R.Tensor((2, 3), "float32")) x4 = relax.Var("x", R.Tensor((2, 3), "float32", vdev0)) sv0 = relax.Var("sv", relax.ShapeStructInfo(())) sv1 = relax.Var("sv", relax.ShapeStructInfo(ndim=0)) - v0 = relax.Var("v", relax.TensorStructInfo(sv0, "float16")) - v1 = relax.Var("v", relax.TensorStructInfo(sv1, "float16")) + v0 = relax.Var("v", R.Tensor(sv0, "float16")) + v1 = relax.Var("v", R.Tensor(sv1, "float16")) v2 = relax.Var("v", R.Tensor((), "float16")) - v3 = relax.Var("v", relax.TensorStructInfo(sv1, "float16", vdev0)) - - _check_inference(bb, relax.op.full_like(x0, v0), relax.TensorStructInfo(s0, "float32")) - _check_inference(bb, relax.op.full_like(x0, v1), relax.TensorStructInfo(s0, "float32")) - _check_inference(bb, relax.op.full_like(x0, v2), relax.TensorStructInfo(s0, "float32")) - _check_inference(bb, relax.op.full_like(x1, v0), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.full_like(x1, v1), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.full_like(x1, v2), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.full_like(x2, v0), relax.TensorStructInfo(s2, "float32")) - _check_inference(bb, relax.op.full_like(x2, v1), relax.TensorStructInfo(s2, "float32")) - _check_inference(bb, relax.op.full_like(x2, v2), relax.TensorStructInfo(s2, "float32")) - _check_inference(bb, relax.op.full_like(x3, v0), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.full_like(x3, v1), relax.TensorStructInfo((2, 3), "float32")) - _check_inference( - bb, relax.op.full_like(x4, v3), relax.TensorStructInfo((2, 3), "float32", vdev0) - ) - - -def test_full_like_infer_struct_info_more_input_dtype(): + v3 = relax.Var("v", R.Tensor(sv1, "float16", vdev0)) + + inference_checker(R.full_like(x0, v0), R.Tensor(s0, "float32")) + inference_checker(R.full_like(x0, v1), R.Tensor(s0, "float32")) + inference_checker(R.full_like(x0, v2), R.Tensor(s0, "float32")) + inference_checker(R.full_like(x1, v0), R.Tensor(s1, "float32")) + inference_checker(R.full_like(x1, v1), R.Tensor(s1, "float32")) + inference_checker(R.full_like(x1, v2), R.Tensor(s1, "float32")) + inference_checker(R.full_like(x2, v0), R.Tensor(s2, "float32")) + inference_checker(R.full_like(x2, v1), R.Tensor(s2, "float32")) + inference_checker(R.full_like(x2, v2), R.Tensor(s2, "float32")) + inference_checker(R.full_like(x3, v0), R.Tensor((2, 3), "float32")) + inference_checker(R.full_like(x3, v1), R.Tensor((2, 3), "float32")) + inference_checker(R.full_like(x4, v3), R.Tensor((2, 3), "float32", vdev0)) + + +def test_full_like_infer_struct_info_more_input_dtype(normalize_before_check): bb = relax.BlockBuilder() + inference_checker = _get_inference_checker(bb, normalize_before_check=normalize_before_check) + x0 = relax.Var("x", R.Tensor((2, 3), "float16")) x1 = relax.Var("x", R.Tensor((2, 3), "int8")) v0 = relax.Var("v", R.Tensor((), "int32")) v1 = relax.Var("v", R.Tensor((), "float64")) - _check_inference(bb, relax.op.full_like(x0, v0), relax.TensorStructInfo((2, 3), "float16")) - _check_inference(bb, relax.op.full_like(x0, v1), relax.TensorStructInfo((2, 3), "float16")) - _check_inference(bb, relax.op.full_like(x1, v0), relax.TensorStructInfo((2, 3), "int8")) - _check_inference(bb, relax.op.full_like(x1, v1), relax.TensorStructInfo((2, 3), "int8")) + inference_checker(R.full_like(x0, v0), R.Tensor((2, 3), "float16")) + inference_checker(R.full_like(x0, v1), R.Tensor((2, 3), "float16")) + inference_checker(R.full_like(x1, v0), R.Tensor((2, 3), "int8")) + inference_checker(R.full_like(x1, v1), R.Tensor((2, 3), "int8")) -def test_full_like_infer_struct_info_fill_value_not_scalar_tensor(): +def test_full_like_infer_struct_info_fill_value_not_scalar_tensor(normalize_before_check): bb = relax.BlockBuilder() + inference_checker = _get_inference_checker(bb, normalize_before_check=normalize_before_check) + x = relax.Var("x", R.Tensor((2, 3), "float32")) s0 = relax.Var("s", relax.ShapeStructInfo((1,))) s1 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) @@ -357,26 +346,28 @@ def test_full_like_infer_struct_info_fill_value_not_scalar_tensor(): v0 = relax.Var("v", R.Tensor((1,), "float32")) v1 = relax.Var("v", R.Tensor("float32", ndim=1)) v2 = relax.Var("v", R.Tensor("float32")) - v3 = relax.Var("v", relax.TensorStructInfo(s0, "float32")) - v4 = relax.Var("v", relax.TensorStructInfo(s1, "float32")) - v5 = relax.Var("v", relax.TensorStructInfo(s2, "float32")) + v3 = relax.Var("v", R.Tensor(s0, "float32")) + v4 = relax.Var("v", R.Tensor(s1, "float32")) + v5 = relax.Var("v", R.Tensor(s2, "float32")) with pytest.raises(TVMError): - bb.normalize(relax.op.full_like(x, v0)) + bb.normalize(R.full_like(x, v0)) with pytest.raises(TVMError): - bb.normalize(relax.op.full_like(x, v1)) + bb.normalize(R.full_like(x, v1)) with pytest.raises(TVMError): - bb.normalize(relax.op.full_like(x, v2)) + bb.normalize(R.full_like(x, v2)) with pytest.raises(TVMError): - bb.normalize(relax.op.full_like(x, v3)) + bb.normalize(R.full_like(x, v3)) with pytest.raises(TVMError): - bb.normalize(relax.op.full_like(x, v4)) + bb.normalize(R.full_like(x, v4)) with pytest.raises(TVMError): - bb.normalize(relax.op.full_like(x, v5)) + bb.normalize(R.full_like(x, v5)) -def test_full_like_infer_struct_info_wrong_input_type(): +def test_full_like_infer_struct_info_wrong_input_type(normalize_before_check): bb = relax.BlockBuilder() + inference_checker = _get_inference_checker(bb, normalize_before_check=normalize_before_check) + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((), "float32"))) x2 = relax.Var("x", R.Tensor((2, 3))) @@ -384,86 +375,82 @@ def test_full_like_infer_struct_info_wrong_input_type(): v1 = relax.Var("v", relax.ShapeStructInfo(())) with pytest.raises(TVMError): - bb.normalize(relax.op.full_like(x0, v0)) + bb.normalize(R.full_like(x0, v0)) with pytest.raises(TVMError): - bb.normalize(relax.op.full_like(x1, v0)) + bb.normalize(R.full_like(x1, v0)) with pytest.raises(TVMError): - bb.normalize(relax.op.full_like(x2, v1)) + bb.normalize(R.full_like(x2, v1)) -def test_ones_zeros_infer_struct_info(): +def test_ones_zeros_infer_struct_info(normalize_before_check): bb = relax.BlockBuilder() + inference_checker = _get_inference_checker(bb, normalize_before_check=normalize_before_check) + s0 = relax.ShapeExpr((2, 3)) s1 = relax.Var("s", relax.ShapeStructInfo((2, 3))) s2 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) s3 = relax.Var("s", relax.ShapeStructInfo()) - _check_inference( - bb, relax.op.ones((2, 3), "float32"), relax.TensorStructInfo((2, 3), "float32") - ) - _check_inference(bb, relax.op.ones(s0, "float32"), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.ones(s1, "float32"), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.ones(s2, "float32"), relax.TensorStructInfo(s2, "float32")) - _check_inference(bb, relax.op.ones(s3, "float32"), relax.TensorStructInfo(s3, "float32")) - _check_inference( - bb, relax.op.zeros((2, 3), "float32"), relax.TensorStructInfo((2, 3), "float32") - ) - _check_inference(bb, relax.op.zeros(s0, "float32"), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.zeros(s1, "float32"), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.zeros(s2, "float32"), relax.TensorStructInfo(s2, "float32")) - _check_inference(bb, relax.op.zeros(s3, "float32"), relax.TensorStructInfo(s3, "float32")) + inference_checker(R.ones((2, 3), "float32"), R.Tensor((2, 3), "float32")) + inference_checker(R.ones(s0, "float32"), R.Tensor((2, 3), "float32")) + inference_checker(R.ones(s1, "float32"), R.Tensor(s1, "float32")) + inference_checker(R.ones(s2, "float32"), R.Tensor(s2, "float32")) + inference_checker(R.ones(s3, "float32"), R.Tensor(s3, "float32")) + inference_checker(R.zeros((2, 3), "float32"), R.Tensor((2, 3), "float32")) + inference_checker(R.zeros(s0, "float32"), R.Tensor((2, 3), "float32")) + inference_checker(R.zeros(s1, "float32"), R.Tensor(s1, "float32")) + inference_checker(R.zeros(s2, "float32"), R.Tensor(s2, "float32")) + inference_checker(R.zeros(s3, "float32"), R.Tensor(s3, "float32")) -def test_ones_zeros_infer_struct_info_shape_symbolic(): +def test_ones_zeros_infer_struct_info_shape_symbolic(normalize_before_check): bb = relax.BlockBuilder() + inference_checker = _get_inference_checker(bb, normalize_before_check=normalize_before_check) m = tir.Var("m", "int64") n = tir.Var("n", "int64") s0 = relax.ShapeExpr((m, n)) s1 = relax.Var("s", relax.ShapeStructInfo((m, n))) - _check_inference( - bb, relax.op.ones((m, n), "float32"), relax.TensorStructInfo((m, n), "float32") - ) - _check_inference(bb, relax.op.ones(s0, "float32"), relax.TensorStructInfo((m, n), "float32")) - _check_inference(bb, relax.op.ones(s1, "float32"), relax.TensorStructInfo(s1, "float32")) - _check_inference( - bb, relax.op.zeros((m, n), "float32"), relax.TensorStructInfo((m, n), "float32") - ) - _check_inference(bb, relax.op.zeros(s0, "float32"), relax.TensorStructInfo((m, n), "float32")) - _check_inference(bb, relax.op.zeros(s1, "float32"), relax.TensorStructInfo(s1, "float32")) + inference_checker(R.ones((m, n), "float32"), R.Tensor((m, n), "float32")) + inference_checker(R.ones(s0, "float32"), R.Tensor((m, n), "float32")) + inference_checker(R.ones(s1, "float32"), R.Tensor(s1, "float32")) + inference_checker(R.zeros((m, n), "float32"), R.Tensor((m, n), "float32")) + inference_checker(R.zeros(s0, "float32"), R.Tensor((m, n), "float32")) + inference_checker(R.zeros(s1, "float32"), R.Tensor(s1, "float32")) -def test_ones_zeros_infer_struct_info_more_input_dtype(): +def test_ones_zeros_infer_struct_info_more_input_dtype(normalize_before_check): bb = relax.BlockBuilder() + inference_checker = _get_inference_checker(bb, normalize_before_check=normalize_before_check) s0 = relax.ShapeExpr((2, 3)) s1 = relax.Var("s", relax.ShapeStructInfo((2, 3))) s2 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) s3 = relax.Var("s", relax.ShapeStructInfo()) - _check_inference(bb, relax.op.ones(s0, "float16"), relax.TensorStructInfo((2, 3), "float16")) - _check_inference(bb, relax.op.ones(s1, "int8"), relax.TensorStructInfo(s1, "int8")) - _check_inference(bb, relax.op.zeros(s2, "int32"), relax.TensorStructInfo(s2, "int32")) - _check_inference(bb, relax.op.zeros(s3, "float64"), relax.TensorStructInfo(s3, "float64")) + inference_checker(R.ones(s0, "float16"), R.Tensor((2, 3), "float16")) + inference_checker(R.ones(s1, "int8"), R.Tensor(s1, "int8")) + inference_checker(R.zeros(s2, "int32"), R.Tensor(s2, "int32")) + inference_checker(R.zeros(s3, "float64"), R.Tensor(s3, "float64")) def test_ones_zeros_shape_not_tuple(): m = tir.Var("m", "int64") with pytest.raises(TVMError): - relax.op.ones(10, "float32") + R.ones(10, "float32") with pytest.raises(TVMError): - relax.op.zeros(m, "float32") + R.zeros(m, "float32") def test_ones_zeros_wrong_dtype(): with pytest.raises(TypeError): - relax.op.ones((2, 3)) + R.ones((2, 3)) with pytest.raises(TVMError): - relax.op.ones((2, 3), "") + R.ones((2, 3), "") with pytest.raises(TypeError): - relax.op.zeros((2, 3)) + R.zeros((2, 3)) with pytest.raises(TVMError): - relax.op.zeros((2, 3), "") + R.zeros((2, 3), "") def test_ones_zeros_infer_struct_info_wrong_input_type(): @@ -472,13 +459,14 @@ def test_ones_zeros_infer_struct_info_wrong_input_type(): s1 = relax.Var("s", relax.FuncStructInfo([], R.Tensor((2, 3)))) with pytest.raises(TVMError): - bb.normalize(relax.op.ones(s0, "float32")) + bb.normalize(R.ones(s0, "float32")) with pytest.raises(TVMError): - bb.normalize(relax.op.zeros(s1, "float32")) + bb.normalize(R.zeros(s1, "float32")) -def test_ones_like_zeros_like_infer_struct_info(): +def test_ones_like_zeros_like_infer_struct_info(normalize_before_check): bb = relax.BlockBuilder() + inference_checker = _get_inference_checker(bb, normalize_before_check=normalize_before_check) x0 = relax.Var("x", R.Tensor((2, 3), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=2)) x2 = relax.Var("x", R.Tensor("float32")) @@ -486,52 +474,51 @@ def test_ones_like_zeros_like_infer_struct_info(): x4 = relax.Var("x", R.Tensor(ndim=2)) x5 = relax.Var("x", R.Tensor()) - _check_inference(bb, relax.op.ones_like(x0), relax.TensorStructInfo((2, 3), "float32")) - _check_inference(bb, relax.op.zeros_like(x1), relax.TensorStructInfo(dtype="float32", ndim=2)) - _check_inference(bb, relax.op.ones_like(x2), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.zeros_like(x3), relax.TensorStructInfo((2, 3), dtype="")) - _check_inference(bb, relax.op.ones_like(x4), relax.TensorStructInfo(dtype="", ndim=2)) - _check_inference(bb, relax.op.zeros_like(x5), relax.TensorStructInfo(dtype="")) - _check_inference( - bb, relax.op.ones_like(x0, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") - ) - _check_inference( - bb, relax.op.zeros_like(x3, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") - ) + inference_checker(R.ones_like(x0), R.Tensor((2, 3), "float32")) + inference_checker(R.zeros_like(x1), R.Tensor(dtype="float32", ndim=2)) + inference_checker(R.ones_like(x2), R.Tensor(dtype="float32")) + inference_checker(R.zeros_like(x3), R.Tensor((2, 3), dtype="")) + inference_checker(R.ones_like(x4), R.Tensor(dtype="", ndim=2)) + inference_checker(R.zeros_like(x5), R.Tensor(dtype="")) + inference_checker(R.ones_like(x0, dtype="float16"), R.Tensor((2, 3), "float16")) + inference_checker(R.zeros_like(x3, dtype="float16"), R.Tensor((2, 3), "float16")) -def test_ones_like_zeros_like_infer_struct_info_shape_symbolic(): +def test_ones_like_zeros_like_infer_struct_info_shape_symbolic(normalize_before_check): bb = relax.BlockBuilder() + inference_checker = _get_inference_checker(bb, normalize_before_check=normalize_before_check) m = tir.Var("m", "int64") n = tir.Var("n", "int64") x0 = relax.Var("x", R.Tensor((m, n), "float32")) x1 = relax.Var("x", R.Tensor((m, n))) - _check_inference(bb, relax.op.ones_like(x0), relax.TensorStructInfo((m, n), "float32")) - _check_inference(bb, relax.op.zeros_like(x1), relax.TensorStructInfo((m, n), dtype="")) + inference_checker(R.ones_like(x0), R.Tensor((m, n), "float32")) + inference_checker(R.zeros_like(x1), R.Tensor((m, n), dtype="")) -def test_ones_like_zeros_like_infer_struct_info_shape_var(): +def test_ones_like_zeros_like_infer_struct_info_shape_var(normalize_before_check): bb = relax.BlockBuilder() + inference_checker = _get_inference_checker(bb, normalize_before_check=normalize_before_check) s0 = relax.Var("s", relax.ShapeStructInfo((2, 3))) s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) s2 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x0 = relax.Var("x", R.Tensor(s0, "float32")) + x1 = relax.Var("x", R.Tensor(s1, "float32")) + x2 = relax.Var("x", R.Tensor(s2, "float32")) - _check_inference(bb, relax.op.ones_like(x0), relax.TensorStructInfo(s0, "float32")) - _check_inference(bb, relax.op.zeros_like(x1), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.zeros_like(x2), relax.TensorStructInfo(s2, "float32")) + inference_checker(R.ones_like(x0), R.Tensor(s0, "float32")) + inference_checker(R.zeros_like(x1), R.Tensor(s1, "float32")) + inference_checker(R.zeros_like(x2), R.Tensor(s2, "float32")) -def test_ones_like_zeros_like_infer_struct_info_more_input_dtype(): +def test_ones_like_zeros_like_infer_struct_info_more_input_dtype(normalize_before_check): bb = relax.BlockBuilder() + inference_checker = _get_inference_checker(bb, normalize_before_check=normalize_before_check) x0 = relax.Var("x", R.Tensor((2, 3), "float64")) x1 = relax.Var("x", R.Tensor((2, 3), "int8")) - _check_inference(bb, relax.op.ones_like(x0), relax.TensorStructInfo((2, 3), "float64")) - _check_inference(bb, relax.op.zeros_like(x1), relax.TensorStructInfo((2, 3), "int8")) + inference_checker(R.ones_like(x0), R.Tensor((2, 3), "float64")) + inference_checker(R.zeros_like(x1), R.Tensor((2, 3), "int8")) def test_ones_like_zeros_like_infer_struct_info_wrong_input_type(): @@ -540,80 +527,74 @@ def test_ones_like_zeros_like_infer_struct_info_wrong_input_type(): x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) with pytest.raises(TVMError): - bb.normalize(relax.op.ones_like(x0)) + bb.normalize(R.ones_like(x0)) with pytest.raises(TVMError): - bb.normalize(relax.op.zeros_like(x1)) + bb.normalize(R.zeros_like(x1)) -def test_arange_infer_struct_info(): +def test_arange_infer_struct_info(normalize_before_check): bb = relax.BlockBuilder() + inference_checker = _get_inference_checker(bb, normalize_before_check=normalize_before_check) - _check_inference(bb, relax.op.arange(10), relax.TensorStructInfo((10,), "int64")) - _check_inference(bb, relax.op.arange(1, 10), relax.TensorStructInfo((9,), "int64")) - _check_inference(bb, relax.op.arange(0, 10, 2), relax.TensorStructInfo((5,), "int64")) - _check_inference(bb, relax.op.arange(1, 10, 2), relax.TensorStructInfo((5,), "int64")) + inference_checker(R.arange(10), R.Tensor((10,), "int64")) + inference_checker(R.arange(1, 10), R.Tensor((9,), "int64")) + inference_checker(R.arange(0, 10, 2), R.Tensor((5,), "int64")) + inference_checker(R.arange(1, 10, 2), R.Tensor((5,), "int64")) - _check_inference(bb, relax.op.arange(10.0), relax.TensorStructInfo((10,), "float32")) - _check_inference(bb, relax.op.arange(1.0, 10), relax.TensorStructInfo((9,), "float32")) - _check_inference(bb, relax.op.arange(0, 20, 2.5), relax.TensorStructInfo((8,), "float32")) - _check_inference(bb, relax.op.arange(1, 10, 2.3), relax.TensorStructInfo((4,), "float32")) + inference_checker(R.arange(10.0), R.Tensor((10,), "float32")) + inference_checker(R.arange(1.0, 10), R.Tensor((9,), "float32")) + inference_checker(R.arange(0, 20, 2.5), R.Tensor((8,), "float32")) + inference_checker(R.arange(1, 10, 2.3), R.Tensor((4,), "float32")) -def test_arange_infer_struct_info_shape_var(): +def test_arange_infer_struct_info_shape_var(normalize_before_check): bb = relax.BlockBuilder() + inference_checker = _get_inference_checker(bb, normalize_before_check=normalize_before_check) start = tir.Var("start", "int64") stop = tir.Var("stop", "int64") step = tir.Var("step", "int64") - _check_inference(bb, relax.op.arange(stop), relax.TensorStructInfo((stop,), "int64")) - _check_inference(bb, relax.op.arange(1, stop), relax.TensorStructInfo((stop - 1,), "int64")) - _check_inference( - bb, relax.op.arange(start, stop), relax.TensorStructInfo((stop - start,), "int64") - ) - _check_inference( - bb, - relax.op.arange(start, stop, 2), - relax.TensorStructInfo(((stop + 1 - start) // 2,), "int64"), + inference_checker(R.arange(stop), R.Tensor((stop,), "int64")) + inference_checker(R.arange(1, stop), R.Tensor((stop - 1,), "int64")) + inference_checker(R.arange(start, stop), R.Tensor((stop - start,), "int64")) + inference_checker( + R.arange(start, stop, 2), + R.Tensor(((stop + 1 - start) // 2,), "int64"), ) - _check_inference( - bb, - relax.op.arange(start, stop, step), - relax.TensorStructInfo(((stop + step - start - 1) // step,), "int64"), + inference_checker( + R.arange(start, stop, step), + R.Tensor(((stop + step - start - 1) // step,), "int64"), ) start = tir.Var("start", "float32") stop = tir.Var("stop", "float32") step = tir.Var("step", "float32") - _check_inference( - bb, - relax.op.arange(stop), - relax.TensorStructInfo((T.cast(T.ceil(stop), "int64"),), "float32"), + inference_checker( + R.arange(stop), + R.Tensor((T.cast(T.ceil(stop), "int64"),), "float32"), ) - _check_inference( - bb, - relax.op.arange(1, stop), - relax.TensorStructInfo((T.cast(T.ceil(stop - 1.0), "int64"),), "float32"), + inference_checker( + R.arange(1, stop), + R.Tensor((T.cast(T.ceil(stop - 1.0), "int64"),), "float32"), ) - _check_inference( - bb, - relax.op.arange(start, stop), - relax.TensorStructInfo((T.cast(T.ceil(stop - start), "int64"),), "float32"), + inference_checker( + R.arange(start, stop), + R.Tensor((T.cast(T.ceil(stop - start), "int64"),), "float32"), ) - _check_inference( - bb, - relax.op.arange(start, stop, 2), - relax.TensorStructInfo((T.cast(T.ceil((stop - start) * 0.5), "int64"),), "float32"), + inference_checker( + R.arange(start, stop, 2), + R.Tensor((T.cast(T.ceil((stop - start) * 0.5), "int64"),), "float32"), ) - _check_inference( - bb, - relax.op.arange(start, stop, step), - relax.TensorStructInfo((T.cast(T.ceil((stop - start) / step), "int64"),), "float32"), + inference_checker( + R.arange(start, stop, step), + R.Tensor((T.cast(T.ceil((stop - start) / step), "int64"),), "float32"), ) -def test_tril_triu_infer_struct_info(): +def test_tril_triu_infer_struct_info(normalize_before_check): bb = relax.BlockBuilder() + inference_checker = _get_inference_checker(bb, normalize_before_check=normalize_before_check) vdev0 = VDevice("llvm") x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=3)) @@ -623,19 +604,20 @@ def test_tril_triu_infer_struct_info(): x5 = relax.Var("x", R.Tensor()) x6 = relax.Var("x", R.Tensor((2, 3, 4), "float32", vdev0)) - _check_inference(bb, relax.op.tril(x0, k=1), relax.TensorStructInfo((2, 3, 4), "float32")) - _check_inference(bb, relax.op.triu(x0, k=0), relax.TensorStructInfo((2, 3, 4), "float32")) - _check_inference(bb, relax.op.tril(x0), relax.TensorStructInfo((2, 3, 4), "float32")) - _check_inference(bb, relax.op.triu(x1), relax.TensorStructInfo(dtype="float32", ndim=3)) - _check_inference(bb, relax.op.tril(x2), relax.TensorStructInfo(dtype="float32")) - _check_inference(bb, relax.op.triu(x3), relax.TensorStructInfo((2, 3, 4), dtype="")) - _check_inference(bb, relax.op.tril(x4), relax.TensorStructInfo(dtype="", ndim=3)) - _check_inference(bb, relax.op.triu(x5), relax.TensorStructInfo(dtype="")) - _check_inference(bb, relax.op.tril(x6), relax.TensorStructInfo((2, 3, 4), "float32", vdev0)) + inference_checker(R.tril(x0, k=1), R.Tensor((2, 3, 4), "float32")) + inference_checker(R.triu(x0, k=0), R.Tensor((2, 3, 4), "float32")) + inference_checker(R.tril(x0), R.Tensor((2, 3, 4), "float32")) + inference_checker(R.triu(x1), R.Tensor(dtype="float32", ndim=3)) + inference_checker(R.tril(x2), R.Tensor(dtype="float32")) + inference_checker(R.triu(x3), R.Tensor((2, 3, 4), dtype="")) + inference_checker(R.tril(x4), R.Tensor(dtype="", ndim=3)) + inference_checker(R.triu(x5), R.Tensor(dtype="")) + inference_checker(R.tril(x6), R.Tensor((2, 3, 4), "float32", vdev0)) -def test_tril_triu_infer_struct_info_shape_symbolic(): +def test_tril_triu_infer_struct_info_shape_symbolic(normalize_before_check): bb = relax.BlockBuilder() + inference_checker = _get_inference_checker(bb, normalize_before_check=normalize_before_check) vdev0 = VDevice("llvm") a = tir.Var("a", "int64") b = tir.Var("b", "int64") @@ -646,40 +628,42 @@ def test_tril_triu_infer_struct_info_shape_symbolic(): x3 = relax.Var("x", R.Tensor((16, 32, 64))) # Dynamic tensor, static offset - _check_inference(bb, relax.op.tril(x0), relax.TensorStructInfo((a, b, c), "float32")) - _check_inference(bb, relax.op.triu(x1), relax.TensorStructInfo((a, b, c), dtype="")) - _check_inference(bb, relax.op.tril(x2), relax.TensorStructInfo((a, b, c), "float32", vdev0)) + inference_checker(R.tril(x0), R.Tensor((a, b, c), "float32")) + inference_checker(R.triu(x1), R.Tensor((a, b, c), dtype="")) + inference_checker(R.tril(x2), R.Tensor((a, b, c), "float32", vdev0)) # Static tensor, dynamic offset - _check_inference(bb, relax.op.tril(x3, a), relax.TensorStructInfo((16, 32, 64), dtype="")) + inference_checker(R.tril(x3, a), R.Tensor((16, 32, 64), dtype="")) # Dynamic tensor, dynamic offset - _check_inference(bb, relax.op.tril(x0, a), relax.TensorStructInfo((a, b, c), "float32")) + inference_checker(R.tril(x0, a), R.Tensor((a, b, c), "float32")) -def test_tril_triu_infer_struct_info_shape_var(): +def test_tril_triu_infer_struct_info_shape_var(normalize_before_check): bb = relax.BlockBuilder() + inference_checker = _get_inference_checker(bb, normalize_before_check=normalize_before_check) s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) s2 = relax.Var("s", relax.ShapeStructInfo()) - x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x0 = relax.Var("x", R.Tensor(s0, "float32")) + x1 = relax.Var("x", R.Tensor(s1, "float32")) + x2 = relax.Var("x", R.Tensor(s2, "float32")) - _check_inference(bb, relax.op.tril(x0), relax.TensorStructInfo(s0, "float32")) - _check_inference(bb, relax.op.triu(x1), relax.TensorStructInfo(s1, "float32")) - _check_inference(bb, relax.op.tril(x2), relax.TensorStructInfo(s2, "float32")) + inference_checker(R.tril(x0), R.Tensor(s0, "float32")) + inference_checker(R.triu(x1), R.Tensor(s1, "float32")) + inference_checker(R.tril(x2), R.Tensor(s2, "float32")) -def test_tril_triu_infer_struct_info_more_input_dtype(): +def test_tril_triu_infer_struct_info_more_input_dtype(normalize_before_check): bb = relax.BlockBuilder() + inference_checker = _get_inference_checker(bb, normalize_before_check=normalize_before_check) x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) x2 = relax.Var("x", R.Tensor((2, 3, 4), "int32")) - _check_inference(bb, relax.op.triu(x0), relax.TensorStructInfo((2, 3, 4), "float16")) - _check_inference(bb, relax.op.tril(x1), relax.TensorStructInfo((2, 3, 4), "int8")) - _check_inference(bb, relax.op.triu(x2), relax.TensorStructInfo((2, 3, 4), "int32")) + inference_checker(R.triu(x0), R.Tensor((2, 3, 4), "float16")) + inference_checker(R.tril(x1), R.Tensor((2, 3, 4), "int8")) + inference_checker(R.triu(x2), R.Tensor((2, 3, 4), "int32")) def test_tril_triu_infer_struct_info_less_than_two_ndim(): @@ -692,27 +676,27 @@ def test_tril_triu_infer_struct_info_less_than_two_ndim(): x1 = relax.Var("x", R.Tensor((), "float32")) x2 = relax.Var("x", R.Tensor("float32", ndim=1)) x3 = relax.Var("x", R.Tensor("float32", ndim=0)) - x4 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) - x5 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - x6 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) - x7 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) + x4 = relax.Var("x", R.Tensor(s0, "float32")) + x5 = relax.Var("x", R.Tensor(s1, "float32")) + x6 = relax.Var("x", R.Tensor(s2, "float32")) + x7 = relax.Var("x", R.Tensor(s3, "float32")) with pytest.raises(TVMError): - bb.normalize(relax.op.tril(x0)) + bb.normalize(R.tril(x0)) with pytest.raises(TVMError): - bb.normalize(relax.op.triu(x1)) + bb.normalize(R.triu(x1)) with pytest.raises(TVMError): - bb.normalize(relax.op.tril(x2)) + bb.normalize(R.tril(x2)) with pytest.raises(TVMError): - bb.normalize(relax.op.triu(x3)) + bb.normalize(R.triu(x3)) with pytest.raises(TVMError): - bb.normalize(relax.op.tril(x4)) + bb.normalize(R.tril(x4)) with pytest.raises(TVMError): - bb.normalize(relax.op.triu(x5)) + bb.normalize(R.triu(x5)) with pytest.raises(TVMError): - bb.normalize(relax.op.tril(x6)) + bb.normalize(R.tril(x6)) with pytest.raises(TVMError): - bb.normalize(relax.op.triu(x7)) + bb.normalize(R.triu(x7)) def test_tril_triu_infer_struct_info_wrong_input_type(): @@ -721,9 +705,9 @@ def test_tril_triu_infer_struct_info_wrong_input_type(): x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) with pytest.raises(TVMError): - bb.normalize(relax.op.tril(x0)) + bb.normalize(R.tril(x0)) with pytest.raises(TVMError): - bb.normalize(relax.op.triu(x1)) + bb.normalize(R.triu(x1)) if __name__ == "__main__": diff --git a/tests/python/relax/test_transform_legalize_ops_create_datatype.py b/tests/python/relax/test_transform_legalize_ops_create_datatype.py index 7b2b2d2e7644..0180ff9c62f1 100644 --- a/tests/python/relax/test_transform_legalize_ops_create_datatype.py +++ b/tests/python/relax/test_transform_legalize_ops_create_datatype.py @@ -160,14 +160,14 @@ def test_full_like(): @tvm.script.ir_module class FullLike: @R.function - def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "float32"): + def main(x: R.Tensor((2, 3), "float32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "float32"): gv: R.Tensor((2, 3), "float32") = R.full_like(x, v) return gv @tvm.script.ir_module class Expected: @R.function - def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "float32"): + def main(x: R.Tensor((2, 3), "float32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "float32"): gv = R.call_tir(Expected.full, (v,), R.Tensor((2, 3), dtype="float32")) return gv @@ -191,14 +191,14 @@ def test_full_like_constant_scalar_fill_value(): @tvm.script.ir_module class FullLike: @R.function - def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "float32"): + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): gv: R.Tensor((2, 3), "float32") = R.full_like(x, R.const(-5, "float32")) return gv @tvm.script.ir_module class Expected: @R.function - def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "float32"): + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): gv = R.call_tir(Expected.full, R.tuple(), R.Tensor((2, 3), dtype="float32")) return gv @@ -217,7 +217,7 @@ def full(T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")): tvm.ir.assert_structural_equal(mod, Expected) -def test_full_like_different_dtype(): +def test_full_like_different_explicit_dtype(): # fmt: off @tvm.script.ir_module class FullLike: @@ -248,12 +248,43 @@ def full(rxplaceholder: T.Buffer((), "float32"), T_full: T.Buffer((T.int64(2), T tvm.ir.assert_structural_equal(mod, Expected) +def test_full_like_different_inferred_dtype(): + # fmt: off + @tvm.script.ir_module + class FullLike: + @R.function + def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")): + gv = R.full_like(x, v) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "int32"): + gv = R.call_tir(Expected.full, (v,), R.Tensor((2, 3), dtype="int32")) + return gv + + @T.prim_func(private=True) + def full(rxplaceholder: T.Buffer((), "float32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[()]) + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = T.Cast("int32", rxplaceholder[()]) + # fmt: on + + mod = LegalizeOps()(FullLike) + tvm.ir.assert_structural_equal(mod, Expected) + + def test_full_like_symbolic(): # fmt: off @tvm.script.ir_module class FullLike: @R.function - def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "float32"): + def main(x: R.Tensor(("m", "n"), "float32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "float32"): m = T.int64() n = T.int64() gv: R.Tensor((m, n), "float32") = R.full_like(x, v) @@ -262,7 +293,7 @@ def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tens @tvm.script.ir_module class Expected: @R.function - def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "float32"): + def main(x: R.Tensor(("m", "n"), "float32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "float32"): m = T.int64() n = T.int64() gv = R.call_tir(Expected.full, (v,), R.Tensor((m, n), dtype="float32"))