[mlir] Move BufferizeDimOp to Tensor/Transforms/Bufferize.cpp

Differential Revision: https://reviews.llvm.org/D105256
Matthias Springer 2021-07-02 09:44:41 +09:00
parent 06ac83fcac
commit e895a670f8
4 changed files with 33 additions and 29 deletions
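With this move, tensor.dim is no longer bufferized by -std-bufferize / populateStdBufferizePatterns; the pattern now ships with -tensor-bufferize / populateTensorBufferizePatterns. As a rough illustration of what this means for a downstream pipeline, below is a minimal sketch of a custom bufferization pass. The pass name and boilerplate are hypothetical; only BufferizeTypeConverter, populateTensorBufferizePatterns, and the op names are taken from the code in this revision.

// Hypothetical downstream pass (not part of this commit) sketching where
// tensor.dim bufferization now comes from. Assumes the MLIR APIs of this
// revision: BufferizeTypeConverter, populateTensorBufferizePatterns, and the
// dialect conversion driver.
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Bufferize.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

namespace {
struct MyDimBufferizePass
    : public PassWrapper<MyDimBufferizePass, FunctionPass> {
  void runOnFunction() override {
    BufferizeTypeConverter typeConverter;
    RewritePatternSet patterns(&getContext());

    // After this commit, this single call also registers BufferizeDimOp.
    populateTensorBufferizePatterns(typeConverter, patterns);

    ConversionTarget target(getContext());
    target.addLegalDialect<memref::MemRefDialect>();
    // tensor.dim must be rewritten (to memref.dim) for the conversion to
    // succeed.
    target.addIllegalOp<tensor::DimOp>();

    if (failed(applyPartialConversion(getFunction(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

Before this revision, the same coverage had to come from populateStdBufferizePatterns instead.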


@@ -23,19 +23,6 @@
 using namespace mlir;
 
 namespace {
-class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    tensor::DimOp::Adaptor adaptor(operands);
-    rewriter.replaceOpWithNewOp<memref::DimOp>(op, adaptor.source(),
-                                               adaptor.index());
-    return success();
-  }
-};
-
 class BufferizeIndexCastOp : public OpConversionPattern<IndexCastOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -70,8 +57,8 @@ public:
 
 void mlir::populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter,
                                         RewritePatternSet &patterns) {
-  patterns.add<BufferizeDimOp, BufferizeSelectOp, BufferizeIndexCastOp>(
-      typeConverter, patterns.getContext());
+  patterns.add<BufferizeSelectOp, BufferizeIndexCastOp>(typeConverter,
+                                                        patterns.getContext());
 }
 
 namespace {


@@ -35,6 +35,21 @@ public:
 };
 } // namespace
 
+namespace {
+class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    tensor::DimOp::Adaptor adaptor(operands);
+    rewriter.replaceOpWithNewOp<memref::DimOp>(op, adaptor.source(),
+                                               adaptor.index());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
 public:
@@ -139,8 +154,9 @@ public:
 
 void mlir::populateTensorBufferizePatterns(
     BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
-  patterns.add<BufferizeCastOp, BufferizeExtractOp, BufferizeFromElementsOp,
-               BufferizeGenerateOp>(typeConverter, patterns.getContext());
+  patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
+               BufferizeFromElementsOp, BufferizeGenerateOp>(
+      typeConverter, patterns.getContext());
 }
 
 namespace {

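The two hunks above are the whole API-visible effect of the move: BufferizeDimOp itself is unchanged, but it is now appended by populateTensorBufferizePatterns rather than populateStdBufferizePatterns. Below is a hedged sketch of the adjustment an out-of-tree user of these entry points might need; the helper name is hypothetical, and only the two populate functions come from this diff.

#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Transforms/Bufferize.h"

// Hypothetical downstream helper (not from this commit) that collects the
// bufferization patterns a pipeline relies on.
void addMyBufferizePatterns(mlir::BufferizeTypeConverter &typeConverter,
                            mlir::RewritePatternSet &patterns) {
  // After this commit, the std patterns no longer cover tensor.dim ...
  mlir::populateStdBufferizePatterns(typeConverter, patterns);
  // ... so the tensor patterns must also be added to keep rewriting
  // tensor.dim to memref.dim.
  mlir::populateTensorBufferizePatterns(typeConverter, patterns);
}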

@@ -1,16 +1,5 @@
 // RUN: mlir-opt %s -std-bufferize | FileCheck %s
 
-// CHECK-LABEL: func @dim(
-// CHECK-SAME: %[[TENSOR:.*]]: tensor<f32>,
-// CHECK-SAME: %[[INDEX:.*]]: index) -> index {
-// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]] : memref<f32>
-// CHECK: %[[EXTENT:.*]] = memref.dim %[[MEMREF]], %[[INDEX]] : memref<f32>
-// CHECK: return %[[EXTENT]] : index
-func @dim(%arg0: tensor<f32>, %arg1: index) -> index {
-  %0 = tensor.dim %arg0, %arg1 : tensor<f32>
-  return %0 : index
-}
-
 // CHECK-LABEL: func @select(
 // CHECK-SAME: %[[PRED:.*]]: i1,
 // CHECK-SAME: %[[TRUE_VAL:.*]]: tensor<f32>,


@@ -1,5 +1,16 @@
 // RUN: mlir-opt %s -tensor-bufferize | FileCheck %s
 
+// CHECK-LABEL: func @dim(
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<f32>,
+// CHECK-SAME: %[[INDEX:.*]]: index) -> index {
+// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]] : memref<f32>
+// CHECK: %[[EXTENT:.*]] = memref.dim %[[MEMREF]], %[[INDEX]] : memref<f32>
+// CHECK: return %[[EXTENT]] : index
+func @dim(%arg0: tensor<f32>, %arg1: index) -> index {
+  %0 = tensor.dim %arg0, %arg1 : tensor<f32>
+  return %0 : index
+}
+
 // CHECK-LABEL: func @tensor.cast(
 // CHECK-SAME: %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
 // CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]]
@@ -67,7 +78,8 @@ func @tensor.from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> {
 // CHECK: %[[C0:.*]] = constant 0 : index
 // CHECK: %[[C1:.*]] = constant 1 : index
 // CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
-// CHECK: %[[ELEM:.*]] = tensor.dim %[[ARG]], %[[I]] : tensor<*xf32>
+// CHECK: %[[CASTED:.*]] = memref.buffer_cast %[[ARG]] : memref<*xf32>
+// CHECK: %[[ELEM:.*]] = memref.dim %[[CASTED]], %[[I]] : memref<*xf32>
 // CHECK: store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex>
 // CHECK: scf.yield
 // CHECK: }