forked from OSchip/llvm-project
[mlir] Add a pattern to bufferize linalg.tensor_reshape.
Differential Revision: https://reviews.llvm.org/D102089
This commit is contained in:
parent
21db1e3b01
commit
a3f22d020b
|
@ -149,6 +149,23 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// Conversion pattern that replaces `linalg.tensor_reshape` with
|
||||
/// `linalg.reshape`.
|
||||
class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
|
||||
public:
|
||||
using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(TensorReshapeOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
linalg::TensorReshapeOpAdaptor adaptor(operands, op->getAttrDictionary());
|
||||
rewriter.replaceOpWithNewOp<linalg::ReshapeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
|
||||
adaptor.src(), adaptor.reassociation());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Conversion pattern that bufferizes `linalg.fill` operation.
|
||||
class BufferizeFillOp : public OpConversionPattern<FillOp> {
|
||||
public:
|
||||
|
@ -336,6 +353,7 @@ void mlir::linalg::populateLinalgBufferizePatterns(
|
|||
BufferizeAnyLinalgOp,
|
||||
BufferizeFillOp,
|
||||
BufferizeInitTensorOp,
|
||||
BufferizeTensorReshapeOp,
|
||||
SubTensorOpConverter,
|
||||
SubTensorInsertOpConverter
|
||||
>(typeConverter, patterns.getContext());
|
||||
|
|
|
@ -278,3 +278,18 @@ func @bufferize_fill(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
|||
%0 = linalg.fill(%arg0, %c0) : tensor<?xf32>, f32 -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @bufferize_tensor_reshape(
|
||||
// CHECK-SAME: %[[IN:.*]]: tensor<4x5xf32>
|
||||
func @bufferize_tensor_reshape(%arg0: tensor<4x5xf32>) -> tensor<20xf32> {
|
||||
%out = linalg.tensor_reshape %arg0 [[0, 1]] :
|
||||
tensor<4x5xf32> into tensor<20xf32>
|
||||
return %out : tensor<20xf32>
|
||||
}
|
||||
// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[IN]] : memref<4x5xf32>
|
||||
// CHECK: %[[RESHAPE:.*]] = linalg.reshape %[[MEMREF]] {{\[}}[0, 1]]
|
||||
// CHECK-SAME: : memref<4x5xf32> into memref<20xf32>
|
||||
// CHECK: %[[TENSOR:.*]] = memref.tensor_load %[[RESHAPE]] : memref<20xf32>
|
||||
// CHECK: return %[[TENSOR]]
|
||||
|
|
Loading…
Reference in New Issue