forked from OSchip/llvm-project
[mlir] Move BufferizeDimOp to Tensor/Transforms/Bufferize.cpp
Differential Revision: https://reviews.llvm.org/D105256
This commit is contained in:
parent
06ac83fcac
commit
e895a670f8
|
@ -23,19 +23,6 @@
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
|
|
||||||
public:
|
|
||||||
using OpConversionPattern::OpConversionPattern;
|
|
||||||
LogicalResult
|
|
||||||
matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
|
||||||
tensor::DimOp::Adaptor adaptor(operands);
|
|
||||||
rewriter.replaceOpWithNewOp<memref::DimOp>(op, adaptor.source(),
|
|
||||||
adaptor.index());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class BufferizeIndexCastOp : public OpConversionPattern<IndexCastOp> {
|
class BufferizeIndexCastOp : public OpConversionPattern<IndexCastOp> {
|
||||||
public:
|
public:
|
||||||
using OpConversionPattern::OpConversionPattern;
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
@ -70,8 +57,8 @@ public:
|
||||||
|
|
||||||
void mlir::populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter,
|
void mlir::populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter,
|
||||||
RewritePatternSet &patterns) {
|
RewritePatternSet &patterns) {
|
||||||
patterns.add<BufferizeDimOp, BufferizeSelectOp, BufferizeIndexCastOp>(
|
patterns.add<BufferizeSelectOp, BufferizeIndexCastOp>(typeConverter,
|
||||||
typeConverter, patterns.getContext());
|
patterns.getContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
|
@ -35,6 +35,21 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
tensor::DimOp::Adaptor adaptor(operands);
|
||||||
|
rewriter.replaceOpWithNewOp<memref::DimOp>(op, adaptor.source(),
|
||||||
|
adaptor.index());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
|
class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
|
||||||
public:
|
public:
|
||||||
|
@ -139,8 +154,9 @@ public:
|
||||||
|
|
||||||
void mlir::populateTensorBufferizePatterns(
|
void mlir::populateTensorBufferizePatterns(
|
||||||
BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
|
BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||||
patterns.add<BufferizeCastOp, BufferizeExtractOp, BufferizeFromElementsOp,
|
patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
|
||||||
BufferizeGenerateOp>(typeConverter, patterns.getContext());
|
BufferizeFromElementsOp, BufferizeGenerateOp>(
|
||||||
|
typeConverter, patterns.getContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
|
@ -1,16 +1,5 @@
|
||||||
// RUN: mlir-opt %s -std-bufferize | FileCheck %s
|
// RUN: mlir-opt %s -std-bufferize | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func @dim(
|
|
||||||
// CHECK-SAME: %[[TENSOR:.*]]: tensor<f32>,
|
|
||||||
// CHECK-SAME: %[[INDEX:.*]]: index) -> index {
|
|
||||||
// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]] : memref<f32>
|
|
||||||
// CHECK: %[[EXTENT:.*]] = memref.dim %[[MEMREF]], %[[INDEX]] : memref<f32>
|
|
||||||
// CHECK: return %[[EXTENT]] : index
|
|
||||||
func @dim(%arg0: tensor<f32>, %arg1: index) -> index {
|
|
||||||
%0 = tensor.dim %arg0, %arg1 : tensor<f32>
|
|
||||||
return %0 : index
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @select(
|
// CHECK-LABEL: func @select(
|
||||||
// CHECK-SAME: %[[PRED:.*]]: i1,
|
// CHECK-SAME: %[[PRED:.*]]: i1,
|
||||||
// CHECK-SAME: %[[TRUE_VAL:.*]]: tensor<f32>,
|
// CHECK-SAME: %[[TRUE_VAL:.*]]: tensor<f32>,
|
||||||
|
|
|
@ -1,5 +1,16 @@
|
||||||
// RUN: mlir-opt %s -tensor-bufferize | FileCheck %s
|
// RUN: mlir-opt %s -tensor-bufferize | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @dim(
|
||||||
|
// CHECK-SAME: %[[TENSOR:.*]]: tensor<f32>,
|
||||||
|
// CHECK-SAME: %[[INDEX:.*]]: index) -> index {
|
||||||
|
// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]] : memref<f32>
|
||||||
|
// CHECK: %[[EXTENT:.*]] = memref.dim %[[MEMREF]], %[[INDEX]] : memref<f32>
|
||||||
|
// CHECK: return %[[EXTENT]] : index
|
||||||
|
func @dim(%arg0: tensor<f32>, %arg1: index) -> index {
|
||||||
|
%0 = tensor.dim %arg0, %arg1 : tensor<f32>
|
||||||
|
return %0 : index
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @tensor.cast(
|
// CHECK-LABEL: func @tensor.cast(
|
||||||
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
|
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
|
||||||
// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]]
|
// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]]
|
||||||
|
@ -67,7 +78,8 @@ func @tensor.from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> {
|
||||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||||
// CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
|
// CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
|
||||||
// CHECK: %[[ELEM:.*]] = tensor.dim %[[ARG]], %[[I]] : tensor<*xf32>
|
// CHECK: %[[CASTED:.*]] = memref.buffer_cast %[[ARG]] : memref<*xf32>
|
||||||
|
// CHECK: %[[ELEM:.*]] = memref.dim %[[CASTED]], %[[I]] : memref<*xf32>
|
||||||
// CHECK: store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex>
|
// CHECK: store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex>
|
||||||
// CHECK: scf.yield
|
// CHECK: scf.yield
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
|
|
Loading…
Reference in New Issue