[mlir] Move BufferizeDimOp to Tensor/Transforms/Bufferize.cpp

Differential Revision: https://reviews.llvm.org/D105256
This commit is contained in:
Matthias Springer 2021-07-02 09:44:41 +09:00
parent 06ac83fcac
commit e895a670f8
4 changed files with 33 additions and 29 deletions

View File

@ -23,19 +23,6 @@
using namespace mlir;
namespace {
class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
tensor::DimOp::Adaptor adaptor(operands);
rewriter.replaceOpWithNewOp<memref::DimOp>(op, adaptor.source(),
adaptor.index());
return success();
}
};
class BufferizeIndexCastOp : public OpConversionPattern<IndexCastOp> {
public:
using OpConversionPattern::OpConversionPattern;
@ -70,8 +57,8 @@ public:
void mlir::populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<BufferizeDimOp, BufferizeSelectOp, BufferizeIndexCastOp>(
typeConverter, patterns.getContext());
patterns.add<BufferizeSelectOp, BufferizeIndexCastOp>(typeConverter,
patterns.getContext());
}
namespace {

View File

@ -35,6 +35,21 @@ public:
};
} // namespace
namespace {
class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
tensor::DimOp::Adaptor adaptor(operands);
rewriter.replaceOpWithNewOp<memref::DimOp>(op, adaptor.source(),
adaptor.index());
return success();
}
};
} // namespace
namespace {
class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
public:
@ -139,8 +154,9 @@ public:
void mlir::populateTensorBufferizePatterns(
BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<BufferizeCastOp, BufferizeExtractOp, BufferizeFromElementsOp,
BufferizeGenerateOp>(typeConverter, patterns.getContext());
patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
BufferizeFromElementsOp, BufferizeGenerateOp>(
typeConverter, patterns.getContext());
}
namespace {

View File

@ -1,16 +1,5 @@
// RUN: mlir-opt %s -std-bufferize | FileCheck %s
// CHECK-LABEL: func @dim(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<f32>,
// CHECK-SAME: %[[INDEX:.*]]: index) -> index {
// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]] : memref<f32>
// CHECK: %[[EXTENT:.*]] = memref.dim %[[MEMREF]], %[[INDEX]] : memref<f32>
// CHECK: return %[[EXTENT]] : index
func @dim(%arg0: tensor<f32>, %arg1: index) -> index {
%0 = tensor.dim %arg0, %arg1 : tensor<f32>
return %0 : index
}
// CHECK-LABEL: func @select(
// CHECK-SAME: %[[PRED:.*]]: i1,
// CHECK-SAME: %[[TRUE_VAL:.*]]: tensor<f32>,

View File

@ -1,5 +1,16 @@
// RUN: mlir-opt %s -tensor-bufferize | FileCheck %s
// CHECK-LABEL: func @dim(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<f32>,
// CHECK-SAME: %[[INDEX:.*]]: index) -> index {
// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]] : memref<f32>
// CHECK: %[[EXTENT:.*]] = memref.dim %[[MEMREF]], %[[INDEX]] : memref<f32>
// CHECK: return %[[EXTENT]] : index
func @dim(%arg0: tensor<f32>, %arg1: index) -> index {
%0 = tensor.dim %arg0, %arg1 : tensor<f32>
return %0 : index
}
// CHECK-LABEL: func @tensor.cast(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]]
@ -67,7 +78,8 @@ func @tensor.from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> {
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
// CHECK: %[[ELEM:.*]] = tensor.dim %[[ARG]], %[[I]] : tensor<*xf32>
// CHECK: %[[CASTED:.*]] = memref.buffer_cast %[[ARG]] : memref<*xf32>
// CHECK: %[[ELEM:.*]] = memref.dim %[[CASTED]], %[[I]] : memref<*xf32>
// CHECK: store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex>
// CHECK: scf.yield
// CHECK: }