From 45725cbbd33ee90963c31c9ab1a8511d8a0d7f6a Mon Sep 17 00:00:00 2001 From: Zentrik Date: Tue, 4 Feb 2025 17:23:43 +0000 Subject: [PATCH] Add broadcasting support for Remainder op in torch mlir -> stable hlo conversion --- lib/Conversion/TorchToStablehlo/Basic.cpp | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index d6ba57a08a8f..bec3eb5d38cb 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -1966,24 +1966,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -// AtenRemainderTensorOp -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenRemainderTensorOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value lhs = adaptor.getSelf(); - Value rhs = adaptor.getOther(); - - auto resultType = - cast(getTypeConverter()->convertType(op.getType())); - lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, - resultType.getElementType()); - rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, - resultType.getElementType()); - rewriter.replaceOpWithNewOp(op, lhs, rhs); - return success(); -} - // AtenFmodTensorOp // torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b template <> @@ -2231,6 +2213,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorModeOp, chlo::BroadcastDivOp); INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarOp, chlo::BroadcastDivOp); INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarModeOp, chlo::BroadcastDivOp); + INSERT_BINARY_MULDIV_PATTERN(AtenRemainderTensorOp, chlo::BroadcastRemOp); INSERT_BINARY_MULDIV_PATTERN(AtenRemainderScalarOp, chlo::BroadcastRemOp); #undef INSERT_BINARY_MULDIV_PATTERN @@ -2310,7 +2293,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp); INSERT_ATENOP_PATTERN(AtenFillScalarOp); INSERT_ATENOP_PATTERN(AtenFlipOp); - INSERT_ATENOP_PATTERN(AtenRemainderTensorOp); INSERT_ATENOP_PATTERN(AtenFmodTensorOp); INSERT_ATENOP_PATTERN(AtenBitwiseLeftShiftTensorOp); INSERT_ATENOP_PATTERN(AtenBitwiseRightShiftTensorOp);