forked from OSchip/llvm-project
[mlir] Move BufferizeDimOp to Tensor/Transforms/Bufferize.cpp
Differential Revision: https://reviews.llvm.org/D105256
This commit is contained in:
parent
06ac83fcac
commit
e895a670f8
|
@ -23,19 +23,6 @@
|
|||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
tensor::DimOp::Adaptor adaptor(operands);
|
||||
rewriter.replaceOpWithNewOp<memref::DimOp>(op, adaptor.source(),
|
||||
adaptor.index());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class BufferizeIndexCastOp : public OpConversionPattern<IndexCastOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
@ -70,8 +57,8 @@ public:
|
|||
|
||||
void mlir::populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<BufferizeDimOp, BufferizeSelectOp, BufferizeIndexCastOp>(
|
||||
typeConverter, patterns.getContext());
|
||||
patterns.add<BufferizeSelectOp, BufferizeIndexCastOp>(typeConverter,
|
||||
patterns.getContext());
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
|
|
@ -35,6 +35,21 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
tensor::DimOp::Adaptor adaptor(operands);
|
||||
rewriter.replaceOpWithNewOp<memref::DimOp>(op, adaptor.source(),
|
||||
adaptor.index());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
|
||||
public:
|
||||
|
@ -139,8 +154,9 @@ public:
|
|||
|
||||
void mlir::populateTensorBufferizePatterns(
|
||||
BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
patterns.add<BufferizeCastOp, BufferizeExtractOp, BufferizeFromElementsOp,
|
||||
BufferizeGenerateOp>(typeConverter, patterns.getContext());
|
||||
patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
|
||||
BufferizeFromElementsOp, BufferizeGenerateOp>(
|
||||
typeConverter, patterns.getContext());
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
|
|
@ -1,16 +1,5 @@
|
|||
// RUN: mlir-opt %s -std-bufferize | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @dim(
|
||||
// CHECK-SAME: %[[TENSOR:.*]]: tensor<f32>,
|
||||
// CHECK-SAME: %[[INDEX:.*]]: index) -> index {
|
||||
// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]] : memref<f32>
|
||||
// CHECK: %[[EXTENT:.*]] = memref.dim %[[MEMREF]], %[[INDEX]] : memref<f32>
|
||||
// CHECK: return %[[EXTENT]] : index
|
||||
func @dim(%arg0: tensor<f32>, %arg1: index) -> index {
|
||||
%0 = tensor.dim %arg0, %arg1 : tensor<f32>
|
||||
return %0 : index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @select(
|
||||
// CHECK-SAME: %[[PRED:.*]]: i1,
|
||||
// CHECK-SAME: %[[TRUE_VAL:.*]]: tensor<f32>,
|
||||
|
|
|
@ -1,5 +1,16 @@
|
|||
// RUN: mlir-opt %s -tensor-bufferize | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @dim(
|
||||
// CHECK-SAME: %[[TENSOR:.*]]: tensor<f32>,
|
||||
// CHECK-SAME: %[[INDEX:.*]]: index) -> index {
|
||||
// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]] : memref<f32>
|
||||
// CHECK: %[[EXTENT:.*]] = memref.dim %[[MEMREF]], %[[INDEX]] : memref<f32>
|
||||
// CHECK: return %[[EXTENT]] : index
|
||||
func @dim(%arg0: tensor<f32>, %arg1: index) -> index {
|
||||
%0 = tensor.dim %arg0, %arg1 : tensor<f32>
|
||||
return %0 : index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @tensor.cast(
|
||||
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
|
||||
// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]]
|
||||
|
@ -67,7 +78,8 @@ func @tensor.from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> {
|
|||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
|
||||
// CHECK: %[[ELEM:.*]] = tensor.dim %[[ARG]], %[[I]] : tensor<*xf32>
|
||||
// CHECK: %[[CASTED:.*]] = memref.buffer_cast %[[ARG]] : memref<*xf32>
|
||||
// CHECK: %[[ELEM:.*]] = memref.dim %[[CASTED]], %[[I]] : memref<*xf32>
|
||||
// CHECK: store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex>
|
||||
// CHECK: scf.yield
|
||||
// CHECK: }
|
||||
|
|
Loading…
Reference in New Issue