forked from OSchip/llvm-project
[mlir] Add a missing pattern to bufferize tensor.rank.
Differential Revision: https://reviews.llvm.org/D115745
This commit is contained in:
parent
74d1fc742a
commit
a82a19c137
|
@ -24,8 +24,7 @@
|
|||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
|
||||
public:
|
||||
struct BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
|
||||
|
@ -36,11 +35,8 @@ public:
|
|||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
|
||||
public:
|
||||
struct BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
|
||||
|
@ -50,11 +46,8 @@ public:
|
|||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
|
||||
public:
|
||||
struct BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor,
|
||||
|
@ -64,10 +57,8 @@ public:
|
|||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class BufferizeFromElementsOp
|
||||
struct BufferizeFromElementsOp
|
||||
: public OpConversionPattern<tensor::FromElementsOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
@ -88,11 +79,8 @@ public:
|
|||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
|
||||
public:
|
||||
struct BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
|
@ -150,44 +138,51 @@ public:
|
|||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct BufferizeRankOp : public OpConversionPattern<tensor::RankOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(tensor::RankOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<memref::RankOp>(op, op.getType(),
|
||||
adaptor.tensor());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
|
||||
void runOnFunction() override {
|
||||
auto *context = &getContext();
|
||||
bufferization::BufferizeTypeConverter typeConverter;
|
||||
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<scf::SCFDialect, memref::MemRefDialect>();
|
||||
target.addDynamicallyLegalDialect<arith::ArithmeticDialect,
|
||||
StandardOpsDialect>(
|
||||
[&](Operation *op) { return typeConverter.isLegal(op); });
|
||||
target.addLegalOp<CallOp, ReturnOp>();
|
||||
target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
|
||||
tensor::FromElementsOp, tensor::GenerateOp>();
|
||||
bufferization::populateBufferizeMaterializationLegality(target);
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
populateTensorBufferizePatterns(typeConverter, patterns);
|
||||
if (failed(
|
||||
applyPartialConversion(getFunction(), target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::populateTensorBufferizePatterns(
|
||||
bufferization::BufferizeTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
|
||||
BufferizeFromElementsOp, BufferizeGenerateOp>(
|
||||
BufferizeFromElementsOp, BufferizeGenerateOp, BufferizeRankOp>(
|
||||
typeConverter, patterns.getContext());
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
|
||||
void runOnFunction() override {
|
||||
auto *context = &getContext();
|
||||
bufferization::BufferizeTypeConverter typeConverter;
|
||||
RewritePatternSet patterns(context);
|
||||
ConversionTarget target(*context);
|
||||
|
||||
bufferization::populateBufferizeMaterializationLegality(target);
|
||||
|
||||
populateTensorBufferizePatterns(typeConverter, patterns);
|
||||
target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
|
||||
tensor::FromElementsOp, tensor::GenerateOp>();
|
||||
target.addLegalDialect<memref::MemRefDialect>();
|
||||
target.addDynamicallyLegalDialect<arith::ArithmeticDialect,
|
||||
StandardOpsDialect>(
|
||||
[&](Operation *op) { return typeConverter.isLegal(op); });
|
||||
target.addLegalOp<CallOp>();
|
||||
target.addLegalOp<ReturnOp>();
|
||||
target.addLegalDialect<scf::SCFDialect>();
|
||||
|
||||
if (failed(
|
||||
applyPartialConversion(getFunction(), target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> mlir::createTensorBufferizePass() {
|
||||
return std::make_unique<TensorBufferizePass>();
|
||||
}
|
||||
|
|
|
@ -11,6 +11,15 @@ func @dim(%arg0: tensor<f32>, %arg1: index) -> index {
|
|||
return %0 : index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @rank(
|
||||
// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>) -> index {
|
||||
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]]
|
||||
// CHECK: %[[EXTENT:.*]] = memref.rank %[[MEMREF]] : memref<*xf32>
|
||||
func @rank(%arg0: tensor<*xf32>) -> index {
|
||||
%0 = tensor.rank %arg0 : tensor<*xf32>
|
||||
return %0 : index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @tensor.cast(
|
||||
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
|
||||
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]]
|
||||
|
|
Loading…
Reference in New Issue