[mlir][Linalg] Add support for vector.transfer ops to comprehensive bufferization (2/n).

Differential revision: https://reviews.llvm.org/D102395
Nicolas Vasilache 2021-05-13 20:57:57 +00:00
parent 1e01a8919f
commit bebf5d56bf
2 changed files with 101 additions and 5 deletions
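
In short, this patch teaches comprehensive bufferization to handle vector.transfer_read and vector.transfer_write ops whose shaped operand is a tensor: the tensor operand is paired with the op's result for the inplace analysis, and the op is then rewritten to operate on the corresponding buffer. A minimal hand-written sketch of the intended rewrite for the inplaceable write case (illustrative only, not verbatim pass output; the value names are made up):

// Tensor form before bufferization:
//   %r = vector.transfer_write %vec, %A[%c0] : vector<4xf32>, tensor<?xf32>
//
// Approximate buffer form when the write may reuse %A's buffer (no allocation):
//   %buf = memref.buffer_cast %A : memref<?xf32, #map>
//   vector.transfer_write %vec, %buf[%c0] : vector<4xf32>, memref<?xf32, #map>
//   %t = memref.tensor_load %buf : memref<?xf32, #map>  // expected to fold away with its uses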


@@ -82,8 +82,8 @@
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/BufferUtils.h"
@@ -128,16 +128,25 @@ OpResult getMatchingOpResult(LinalgOp linalgOp, OpOperand &opOperand) {
return linalgOp->getResult(outputOperandIndex - numOutputBuffers);
}
OpResult getMatchingOpResult(VectorTransferOpInterface op,
OpOperand &opOperand) {
if (opOperand.get() != op.source() ||
!op.source().getType().isa<TensorType>())
return OpResult();
return op->getResult(0);
}
/// Determine which results may be reused inplace by the bufferization
/// patterns of `bufferizeFuncOpInternals`.
/// The inplace analysis uses this information along with interfering read
/// analysis to determine which op results reuse the same buffer as some
/// operand.
OpResult getMatchingOpResult(OpOperand &opOperand) {
OpResult res =
llvm::TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
.Case([&](LinalgOp op) { return getMatchingOpResult(op, opOperand); })
.Default([&](Operation *op) { return OpResult(); });
OpResult res = llvm::TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
.Case<LinalgOp, VectorTransferOpInterface>([&](auto op) {
return getMatchingOpResult(op, opOperand);
})
.Default([&](Operation *op) { return OpResult(); });
return res;
}
@@ -708,6 +717,54 @@ static LogicalResult convertReturnOp(OpBuilder &b, ReturnOp returnOp,
return success();
}
static LogicalResult convertTransferOp(OpBuilder &b,
VectorTransferOpInterface op,
BlockAndValueMapping &bvm) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
Location loc = op.getLoc();
if (op.getShapedType().isa<MemRefType>())
return failure();
LLVM_DEBUG(DBGS() << "convert: " << *op << "\n");
/// transfer_read from buffer
if (auto readOp = dyn_cast<vector::TransferReadOp>(op.getOperation())) {
readOp.sourceMutable().assign(lookup(bvm, op.source()));
return success();
}
auto inPlace = getInPlace(op->getResult(0));
auto writeOp = cast<vector::TransferWriteOp>(op.getOperation());
// If transfer_write is not inPlace, allocate a new buffer.
Value newInputBuffer;
if (inPlace != InPlaceSpec::True) {
newInputBuffer =
createNewAllocDeallocPairForShapedValue(b, loc, writeOp.result());
b.setInsertionPointAfter(newInputBuffer.getDefiningOp());
map(bvm, writeOp.result(), newInputBuffer);
} else {
// InPlace write will result in memref.tensor_load(x) which must
// canonicalize away with one of its uses.
newInputBuffer = lookup(bvm, writeOp.source());
}
// Create a new transfer_write on the buffer that doesn't have a return value.
// Leave the previous transfer_write behind as dead code, as it still has uses
// at this point.
b.create<vector::TransferWriteOp>(
loc, writeOp.vector(), newInputBuffer, writeOp.indices(),
writeOp.permutation_map(),
writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr());
map(bvm, op->getResult(0), newInputBuffer);
return success();
}
static void inPlaceAnalysisFuncOpInternals(FuncOp funcOp,
const DominanceInfo &domInfo) {
assert(funcOp && funcOp->getNumRegions() > 0 && !funcOp.body().empty() &&
@@ -733,6 +790,9 @@ static LogicalResult bufferizeFuncOpInternals(
.Case([&](memref::DimOp op) { return convertDimOp(b, op, bvm); })
.Case([&](LinalgOp op) { return convertAnyLinalgOp(b, op, bvm); })
.Case([&](ReturnOp op) { return convertReturnOp(b, op, bvm); })
.Case([&](VectorTransferOpInterface op) {
return convertTransferOp(b, op, bvm);
})
.Default([&](Operation *op) {
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
if (llvm::any_of(op->getOperandTypes(), isaTensor) ||

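To recap the first file: getMatchingOpResult now pairs the tensor source of a vector.transfer op with its single result, and convertTransferOp rewrites the op onto a buffer, allocating a fresh one when the inplace analysis rejects reuse. A rough sketch of both steps on a transfer_write (hand-written; %d and %alloc are illustrative names):

// %r = vector.transfer_write %vec, %A[%c0] : vector<4xf32>, tensor<?xf32>
//   - getMatchingOpResult pairs operand %A with result %r.
//   - If %r is marked inPlace, the new write targets lookup(bvm, %A) directly.
//   - Otherwise a fresh buffer is allocated and written instead, roughly:
//       %d = memref.dim %A, %c0 : tensor<?xf32>
//       %alloc = memref.alloc(%d) : memref<?xf32>
//       vector.transfer_write %vec, %alloc[%c0] : vector<4xf32>, memref<?xf32>
//     Later uses of %r map to %alloc; the original tensor transfer_write is left
//     behind as dead code until its remaining uses are rewritten.

The test file below exercises both cases.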

@@ -81,3 +81,39 @@ func @not_inplace(%A : tensor<?x?xf32> {linalg.inplaceable = true}) -> tensor<?x
-> tensor<?x?xf32>
return %r: tensor<?x?xf32>
}
// -----
// CHECK-LABEL: func @vec_inplace
func @vec_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}, %vec : vector<4xf32>)
-> tensor<?xf32>
{
%c0 = constant 0 : index
// CHECK-NOT: alloc
%r = vector.transfer_write %vec, %A[%c0] : vector<4xf32>, tensor<?xf32>
return %r: tensor<?xf32>
}
// -----
// CHECK-LABEL: func @vec_not_inplace
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: tensor<?xf32> {linalg.inplaceable = true}
func @vec_not_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}, %vec : vector<4xf32>)
-> (tensor<?xf32>, tensor<?xf32>)
{
%c0 = constant 0 : index
%c1 = constant 1 : index
// CHECK: %[[BUFFER_CAST:.*]] = memref.buffer_cast %[[A]] : memref<?xf32, #[[$map_1d_dyn]]>
/// Cross-op multiple uses of %A: the first vector.transfer, which has interfering reads, must alloc.
// CHECK: %[[ALLOC:.*]] = memref.alloc
// CHECK-NEXT: vector.transfer_write {{.*}}, %[[ALLOC]]
%r0 = vector.transfer_write %vec, %A[%c0] : vector<4xf32>, tensor<?xf32>
/// The second vector.transfer has no interfering reads and can reuse the buffer.
// CHECK-NOT: alloc
// CHECK-NEXT: vector.transfer_write {{.*}}, %[[BUFFER_CAST]]
%r1 = vector.transfer_write %vec, %A[%c1] : vector<4xf32>, tensor<?xf32>
return %r0, %r1: tensor<?xf32>, tensor<?xf32>
}
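
For completeness, a rough sketch of the bufferized form that the CHECK lines in @vec_not_inplace describe (hand-written, not verbatim pass output; the trailing tensor_load/dealloc bookkeeping and the exact layout map are assumptions):

// %0 = memref.buffer_cast %A : memref<?xf32, #map_1d_dyn>
// %d = memref.dim %A, %c0 : tensor<?xf32>
// %alloc = memref.alloc(%d) : memref<?xf32>
// vector.transfer_write %vec, %alloc[%c0] : vector<4xf32>, memref<?xf32>            // %r0: new buffer
// vector.transfer_write %vec, %0[%c1] : vector<4xf32>, memref<?xf32, #map_1d_dyn>   // %r1: reuses %A's buffer
// %t0 = memref.tensor_load %alloc : memref<?xf32>
// %t1 = memref.tensor_load %0 : memref<?xf32, #map_1d_dyn>
// memref.dealloc %alloc : memref<?xf32>
// return %t0, %t1 : tensor<?xf32>, tensor<?xf32>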