forked from OSchip/llvm-project
[mlir] Add linalg.fill bufferization conversion
`BufferizeAnyLinalgOp` fails because `FillOp` is not a `LinalgGenericOp`; the pattern fails while reading the operand-sizes attribute. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D98671
This commit is contained in:
parent
1c740b29fa
commit
32a744ab20
|
@ -32,8 +32,7 @@ static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
|
|||
}
|
||||
|
||||
static LogicalResult
|
||||
allocateBuffersForResults(Location loc, LinalgOp linalgOp,
|
||||
linalg::GenericOpAdaptor &adaptor,
|
||||
allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
|
||||
SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
|
||||
// Lazily compute loopRanges.
|
||||
SmallVector<Range, 4> loopRanges;
|
||||
|
@ -52,7 +51,7 @@ allocateBuffersForResults(Location loc, LinalgOp linalgOp,
|
|||
}
|
||||
auto tensorShape = tensorType.getShape();
|
||||
auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());
|
||||
Value resultTensor = adaptor.outputs()[resultIndex];
|
||||
Value resultTensor = outputs[resultIndex];
|
||||
|
||||
// Clone output buffers whose value is actually used.
|
||||
if (linalgOp.payloadUsesValueFromOutputOperandIndex(resultIndex)) {
|
||||
|
@ -138,8 +137,7 @@ static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
|
|||
|
||||
namespace {
|
||||
|
||||
/// Generic conversion pattern that matches any LinalgOp. This avoids template
|
||||
/// instantiating one pattern for each LinalgOp.
|
||||
/// Conversion pattern that replaces `linalg.init_tensor` with allocation.
|
||||
class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
|
||||
public:
|
||||
using OpConversionPattern<InitTensorOp>::OpConversionPattern;
|
||||
|
@ -155,6 +153,26 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// Conversion pattern that bufferizes `linalg.fill` operation.
|
||||
class BufferizeFillOp : public OpConversionPattern<FillOp> {
|
||||
public:
|
||||
using OpConversionPattern<FillOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(FillOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
linalg::FillOpAdaptor adaptor(operands, op->getAttrDictionary());
|
||||
if (!op.output().getType().isa<TensorType>())
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"operand must be of a tensor type");
|
||||
|
||||
rewriter.create<FillOp>(op.getLoc(), adaptor.output(), adaptor.value());
|
||||
rewriter.replaceOp(op, adaptor.output());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Generic conversion pattern that matches any LinalgOp. This avoids template
|
||||
/// instantiating one pattern for each LinalgOp.
|
||||
class BufferizeAnyLinalgOp : public ConversionPattern {
|
||||
|
@ -178,7 +196,7 @@ public:
|
|||
Location loc = linalgOp.getLoc();
|
||||
SmallVector<Value, 2> newOutputBuffers;
|
||||
|
||||
if (failed(allocateBuffersForResults(loc, linalgOp, adaptor,
|
||||
if (failed(allocateBuffersForResults(loc, linalgOp, adaptor.outputs(),
|
||||
newOutputBuffers, rewriter))) {
|
||||
linalgOp.emitOpError()
|
||||
<< "Failed to allocate buffers for tensor results.";
|
||||
|
@ -325,6 +343,7 @@ void mlir::linalg::populateLinalgBufferizePatterns(
|
|||
// TODO: Drop this once tensor constants work in standard.
|
||||
// clang-format off
|
||||
patterns.insert<
|
||||
BufferizeFillOp,
|
||||
BufferizeInitTensorOp,
|
||||
SubTensorOpConverter,
|
||||
SubTensorInsertOpConverter
|
||||
|
|
|
@ -265,3 +265,16 @@ func @bufferize_subtensor_insert(%t : tensor<?x?xf32>, %st0 : tensor<2x3xf32>, %
|
|||
return %t0, %t1: tensor<?x?xf32>, tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Bufferization of a tensor-typed linalg.fill: the tensor operand is
// converted to a memref (tensor_to_memref), filled in place, and the
// result is read back with tensor_load.
// CHECK-LABEL: func @bufferize_fill(
// CHECK-SAME:    %[[IN:.*]]: tensor<?xf32>
func @bufferize_fill(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %c0 = constant 0.0 : f32
  // CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[IN]] : memref<?xf32>
  // CHECK: linalg.fill(%[[MEMREF]], %cst) : memref<?xf32>, f32
  // CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF]] : memref<?xf32>
  // CHECK: return %[[TENSOR]]
  %0 = linalg.fill(%arg0, %c0) : tensor<?xf32>, f32 -> tensor<?xf32>
  return %0 : tensor<?xf32>
}
|
||||
|
|
Loading…
Reference in New Issue