[mlir] Add a missing pattern to bufferize tensor.rank.

Differential Revision: https://reviews.llvm.org/D115745
This commit is contained in:
Alexander Belyaev 2021-12-14 19:58:40 +01:00
parent 74d1fc742a
commit a82a19c137
2 changed files with 50 additions and 46 deletions

View File

@@ -24,8 +24,7 @@
using namespace mlir;
namespace {
class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
public:
struct BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
@ -36,11 +35,8 @@ public:
return success();
}
};
} // namespace
namespace {
class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
public:
struct BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
@ -50,11 +46,8 @@ public:
return success();
}
};
} // namespace
namespace {
class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
public:
struct BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor,
@ -64,10 +57,8 @@ public:
return success();
}
};
} // namespace
namespace {
class BufferizeFromElementsOp
struct BufferizeFromElementsOp
: public OpConversionPattern<tensor::FromElementsOp> {
public:
using OpConversionPattern::OpConversionPattern;
@ -88,11 +79,8 @@ public:
return success();
}
};
} // namespace
namespace {
class BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
public:
struct BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@ -150,44 +138,51 @@ public:
return success();
}
};
/// Bufferizes tensor.rank by replacing it with memref.rank on the
/// already-bufferized operand supplied by the adaptor.
struct BufferizeRankOp : public OpConversionPattern<tensor::RankOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The adaptor hands us the memref-typed replacement of the tensor operand.
    auto bufferizedOperand = adaptor.tensor();
    rewriter.replaceOpWithNewOp<memref::RankOp>(op, op.getType(),
                                                bufferizedOperand);
    return success();
  }
};
/// Function pass that converts `tensor` dialect ops operating on tensors into
/// their memref counterparts via a partial dialect conversion.
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
  void runOnFunction() override {
    auto *ctx = &getContext();
    bufferization::BufferizeTypeConverter converter;

    // Describe what the IR must look like after conversion.
    ConversionTarget target(*ctx);
    bufferization::populateBufferizeMaterializationLegality(target);
    target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
                        tensor::FromElementsOp, tensor::GenerateOp>();
    target.addLegalDialect<scf::SCFDialect, memref::MemRefDialect>();
    target.addLegalOp<CallOp, ReturnOp>();
    // Arith/std ops are legal only once their operand/result types are
    // converted.
    target.addDynamicallyLegalDialect<arith::ArithmeticDialect,
                                      StandardOpsDialect>(
        [&](Operation *op) { return converter.isLegal(op); });

    // Gather the bufferization patterns and run the conversion.
    RewritePatternSet patterns(ctx);
    populateTensorBufferizePatterns(converter, patterns);
    if (failed(
            applyPartialConversion(getFunction(), target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
/// Registers the conversion patterns that bufferize `tensor` dialect ops
/// (cast, dim, extract, from_elements, generate, rank) into `patterns`.
///
/// Fix: the flattened diff left both the pre-change and post-change pattern
/// lists in the `patterns.add<...>` call, producing a malformed expression;
/// keep only the updated list, which includes BufferizeRankOp.
void mlir::populateTensorBufferizePatterns(
    bufferization::BufferizeTypeConverter &typeConverter,
    RewritePatternSet &patterns) {
  patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
               BufferizeFromElementsOp, BufferizeGenerateOp, BufferizeRankOp>(
      typeConverter, patterns.getContext());
}
namespace {
/// Bufferization pass for the `tensor` dialect: rewrites tensor-typed ops in
/// the current function to operate on memrefs instead.
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
  void runOnFunction() override {
    auto *context = &getContext();
    bufferization::BufferizeTypeConverter typeConverter;

    // Set up the conversion target: tensor ops become illegal, memref/scf
    // stay legal, and arith/std ops are legal once type-converted.
    ConversionTarget target(*context);
    bufferization::populateBufferizeMaterializationLegality(target);
    target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
                        tensor::FromElementsOp, tensor::GenerateOp>();
    target.addLegalDialect<memref::MemRefDialect, scf::SCFDialect>();
    target.addLegalOp<CallOp, ReturnOp>();
    target.addDynamicallyLegalDialect<arith::ArithmeticDialect,
                                      StandardOpsDialect>(
        [&](Operation *op) { return typeConverter.isLegal(op); });

    // Collect the patterns and apply a partial conversion over the function.
    RewritePatternSet patterns(context);
    populateTensorBufferizePatterns(typeConverter, patterns);
    if (failed(
            applyPartialConversion(getFunction(), target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
/// Factory for the tensor bufferization pass.
std::unique_ptr<Pass> mlir::createTensorBufferizePass() {
  auto pass = std::make_unique<TensorBufferizePass>();
  return pass;
}

View File

@@ -11,6 +11,15 @@ func @dim(%arg0: tensor<f32>, %arg1: index) -> index {
return %0 : index
}
// Bufferization of tensor.rank: the unranked tensor operand is materialized
// as a memref, and memref.rank supplies the (dynamic) rank.
// CHECK-LABEL: func @rank(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>) -> index {
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]]
// CHECK: %[[EXTENT:.*]] = memref.rank %[[MEMREF]] : memref<*xf32>
func @rank(%arg0: tensor<*xf32>) -> index {
%0 = tensor.rank %arg0 : tensor<*xf32>
return %0 : index
}
// CHECK-LABEL: func @tensor.cast(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]]