forked from OSchip/llvm-project
[MLIR][Shape] Fix lowering of `shape.get_extent`
The declarative conversion patterns caused crashes in the asan configuration. The non-declarative implementation circumvents this. Differential Revision: https://reviews.llvm.org/D82797
This commit is contained in:
parent
fe08ab542b
commit
8577a090f5
|
@ -90,6 +90,29 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
|
||||
using OpConversionPattern<GetExtentOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
GetExtentOp::Adaptor transformed(operands);
|
||||
|
||||
// Derive shape extent directly from shape origin if possible.
|
||||
// This circumvents the necessity to materialize the shape in memory.
|
||||
if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
|
||||
rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
|
||||
transformed.dim());
|
||||
return success();
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<ExtractElementOp>(
|
||||
op, rewriter.getIndexType(), transformed.shape(),
|
||||
ValueRange{transformed.dim()});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
|
||||
public:
|
||||
using OpConversionPattern<shape::RankOp>::OpConversionPattern;
|
||||
|
@ -161,6 +184,7 @@ void mlir::populateShapeToStandardConversionPatterns(
|
|||
BinaryOpConversion<AddOp, AddIOp>,
|
||||
BinaryOpConversion<MulOp, MulIOp>,
|
||||
ConstSizeOpConverter,
|
||||
GetExtentOpConverter,
|
||||
RankOpConverter,
|
||||
ShapeOfOpConversion>(ctx);
|
||||
// clang-format on
|
||||
|
|
|
@ -19,20 +19,3 @@ def SizeToIndexOpConversion : Pat<
|
|||
(Shape_SizeToIndexOp $arg),
|
||||
(replaceWithValue $arg)>;
|
||||
|
||||
// Derive shape extent directly from shape origin if possible.
|
||||
// This circumvents the necessity to materialize the shape in memory.
|
||||
def GetExtentShapeOfConversion : Pat<
|
||||
(Shape_GetExtentOp (Shape_ShapeOfOp $arg), $idx),
|
||||
(Shape_IndexToSizeOp (DimOp $arg, (Shape_SizeToIndexOp $idx))),
|
||||
[],
|
||||
(addBenefit 10)>;
|
||||
def GetExtentFromExtentTensorConversion : Pattern<
|
||||
(Shape_GetExtentOp (Shape_FromExtentTensorOp $extents), $idx),
|
||||
[
|
||||
(Shape_SizeToIndexOp:$std_idx $idx),
|
||||
(ExtractElementOp:$std_result $extents, (NativeCodeCall<"ValueRange({$0})"> $std_idx)),
|
||||
(Shape_IndexToSizeOp $std_result)
|
||||
],
|
||||
[],
|
||||
(addBenefit 10)>;
|
||||
|
||||
|
|
Loading…
Reference in New Issue