From 8577a090f5f04e18d72bb2dd387e60082e4da0ca Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Tue, 30 Jun 2020 08:33:49 +0000 Subject: [PATCH] [MLIR][Shape] Fix lowering of `shape.get_extent` The declarative conversion patterns caused crashes in the asan configuration. The non-declarative implementation circumvents this. Differential Revision: https://reviews.llvm.org/D82797 --- .../ShapeToStandard/ShapeToStandard.cpp | 24 +++++++++++++++++++ .../ShapeToStandardPatterns.td | 17 ------------- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp index 5fd9be0bd73a..7ebcb397349d 100644 --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -90,6 +90,29 @@ public: } }; +class GetExtentOpConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(GetExtentOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + GetExtentOp::Adaptor transformed(operands); + + // Derive shape extent directly from shape origin if possible. + // This circumvents the necessity to materialize the shape in memory. + if (auto shapeOfOp = op.shape().getDefiningOp()) { + rewriter.replaceOpWithNewOp(op, shapeOfOp.arg(), + transformed.dim()); + return success(); + } + + rewriter.replaceOpWithNewOp( + op, rewriter.getIndexType(), transformed.shape(), + ValueRange{transformed.dim()}); + return success(); + } +}; + class RankOpConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -161,6 +184,7 @@ void mlir::populateShapeToStandardConversionPatterns( BinaryOpConversion, BinaryOpConversion, ConstSizeOpConverter, + GetExtentOpConverter, RankOpConverter, ShapeOfOpConversion>(ctx); // clang-format on diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td index 154cf6a9e1f7..a1335487f5ab 100644 --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td @@ -19,20 +19,3 @@ def SizeToIndexOpConversion : Pat< (Shape_SizeToIndexOp $arg), (replaceWithValue $arg)>; -// Derive shape extent directly from shape origin if possible. -// This circumvents the necessity to materialize the shape in memory. -def GetExtentShapeOfConversion : Pat< - (Shape_GetExtentOp (Shape_ShapeOfOp $arg), $idx), - (Shape_IndexToSizeOp (DimOp $arg, (Shape_SizeToIndexOp $idx))), - [], - (addBenefit 10)>; -def GetExtentFromExtentTensorConversion : Pattern< - (Shape_GetExtentOp (Shape_FromExtentTensorOp $extents), $idx), - [ - (Shape_SizeToIndexOp:$std_idx $idx), - (ExtractElementOp:$std_result $extents, (NativeCodeCall<"ValueRange({$0})"> $std_idx)), - (Shape_IndexToSizeOp $std_result) - ], - [], - (addBenefit 10)>; -