forked from OSchip/llvm-project
[mlir][bufferize] OpOperands can have multiple aliasing OpResults
This makes getAliasingOpResult symmetric to getAliasingOpOperand. The previous implementation was confusing for users and implemented in such a way only because there are currently no bufferizable ops that have multiple aliasing OpResults. Differential Revision: https://reviews.llvm.org/D119259
This commit is contained in:
parent
22a1973dbe
commit
585a8a321c
mlir
include/mlir/Dialect/Bufferization/IR
lib/Dialect
Arithmetic/Transforms
Bufferization
Linalg/ComprehensiveBufferize
SCF/Transforms
Tensor/Transforms
Vector/Transforms
|
@ -180,9 +180,8 @@ public:
|
||||||
SmallVector<OpOperand *> getAliasingOpOperand(OpResult result) const;
|
SmallVector<OpOperand *> getAliasingOpOperand(OpResult result) const;
|
||||||
|
|
||||||
/// Determine which OpResult will alias with `opOperand` if the op is
|
/// Determine which OpResult will alias with `opOperand` if the op is
|
||||||
/// bufferized in place. Return an empty OpResult if the op is not
|
/// bufferized in place. Return an empty vector if the op is not bufferizable.
|
||||||
/// bufferizable.
|
SmallVector<OpResult> getAliasingOpResult(OpOperand &opOperand) const;
|
||||||
OpResult getAliasingOpResult(OpOperand &opOperand) const;
|
|
||||||
|
|
||||||
/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
|
/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
|
||||||
/// the op is not bufferizable.
|
/// the op is not bufferizable.
|
||||||
|
@ -396,9 +395,10 @@ struct AllocationHoistingBarrierOnly
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
SmallVector<OpResult>
|
||||||
|
getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
const BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return OpResult();
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||||
|
|
|
@ -124,7 +124,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
||||||
bufferized in-place. This method will never be called on OpOperands
|
bufferized in-place. This method will never be called on OpOperands
|
||||||
that do not have a tensor type.
|
that do not have a tensor type.
|
||||||
}],
|
}],
|
||||||
/*retType=*/"OpResult",
|
/*retType=*/"SmallVector<OpResult>",
|
||||||
/*methodName=*/"getAliasingOpResult",
|
/*methodName=*/"getAliasingOpResult",
|
||||||
/*args=*/(ins "OpOperand &":$opOperand,
|
/*args=*/(ins "OpOperand &":$opOperand,
|
||||||
"const BufferizationState &":$state),
|
"const BufferizationState &":$state),
|
||||||
|
@ -162,8 +162,10 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
||||||
for (OpOperand &opOperand : $_op.getOperation()->getOpOperands()) {
|
for (OpOperand &opOperand : $_op.getOperation()->getOpOperands()) {
|
||||||
if (!opOperand.get().getType().isa<TensorType>())
|
if (!opOperand.get().getType().isa<TensorType>())
|
||||||
continue;
|
continue;
|
||||||
if (bufferizableOp.getAliasingOpResult(opOperand, state) ==
|
SmallVector<OpResult> aliasingOpResults =
|
||||||
opResult)
|
bufferizableOp.getAliasingOpResult(opOperand, state);
|
||||||
|
if (llvm::find(aliasingOpResults, opResult)
|
||||||
|
!= aliasingOpResults.end())
|
||||||
result.push_back(&opOperand);
|
result.push_back(&opOperand);
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
|
@ -304,8 +306,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
||||||
cast<BufferizableOpInterface>(getOperation());
|
cast<BufferizableOpInterface>(getOperation());
|
||||||
return !bufferizableOp.bufferizesToMemoryRead(opOperand, state)
|
return !bufferizableOp.bufferizesToMemoryRead(opOperand, state)
|
||||||
&& !bufferizableOp.bufferizesToMemoryWrite(opOperand, state)
|
&& !bufferizableOp.bufferizesToMemoryWrite(opOperand, state)
|
||||||
&& static_cast<bool>(
|
&& !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
|
||||||
bufferizableOp.getAliasingOpResult(opOperand, state));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: The following two attributes should belong to the tensor dialect.
|
// TODO: The following two attributes should belong to the tensor dialect.
|
||||||
|
|
|
@ -211,9 +211,9 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(OpOperand &opOperand,
|
SmallVector<OpResult> getAliasingOpResult(
|
||||||
const BufferizationState &state) const {
|
OpOperand &opOperand, const BufferizationState &state) const {
|
||||||
return OpResult();
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(RewriterBase &rewriter,
|
LogicalResult bufferize(RewriterBase &rewriter,
|
||||||
|
|
|
@ -69,9 +69,10 @@ struct IndexCastOpInterface
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
SmallVector<OpResult>
|
||||||
|
getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
const BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return op->getResult(0);
|
return {op->getResult(0)};
|
||||||
}
|
}
|
||||||
|
|
||||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||||
|
@ -114,9 +115,10 @@ struct SelectOpInterface
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
SmallVector<OpResult>
|
||||||
|
getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
const BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return op->getOpResult(0) /*result*/;
|
return {op->getOpResult(0) /*result*/};
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<OpOperand *>
|
SmallVector<OpOperand *>
|
||||||
|
|
|
@ -87,12 +87,13 @@ BufferizationState::getAliasingOpOperand(OpResult result) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Determine which OpResult will alias with `opOperand` if the op is bufferized
|
/// Determine which OpResult will alias with `opOperand` if the op is bufferized
|
||||||
/// in place. Return an empty OpResult if the op is not bufferizable.
|
/// in place. Return an empty vector if the op is not bufferizable.
|
||||||
OpResult BufferizationState::getAliasingOpResult(OpOperand &opOperand) const {
|
SmallVector<OpResult>
|
||||||
|
BufferizationState::getAliasingOpResult(OpOperand &opOperand) const {
|
||||||
if (auto bufferizableOp =
|
if (auto bufferizableOp =
|
||||||
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
|
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
|
||||||
return bufferizableOp.getAliasingOpResult(opOperand, *this);
|
return bufferizableOp.getAliasingOpResult(opOperand, *this);
|
||||||
return OpResult();
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
|
/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
|
||||||
|
@ -144,7 +145,8 @@ bool BufferizationState::isValueRead(Value value) const {
|
||||||
OpOperand *uMaybeReading = workingSet.pop_back_val();
|
OpOperand *uMaybeReading = workingSet.pop_back_val();
|
||||||
// Skip over all ops that neither read nor write (but create an alias).
|
// Skip over all ops that neither read nor write (but create an alias).
|
||||||
if (bufferizesToAliasOnly(*uMaybeReading))
|
if (bufferizesToAliasOnly(*uMaybeReading))
|
||||||
for (OpOperand &use : getAliasingOpResult(*uMaybeReading).getUses())
|
for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
|
||||||
|
for (OpOperand &use : opResult.getUses())
|
||||||
workingSet.push_back(&use);
|
workingSet.push_back(&use);
|
||||||
if (bufferizesToMemoryRead(*uMaybeReading))
|
if (bufferizesToMemoryRead(*uMaybeReading))
|
||||||
return true;
|
return true;
|
||||||
|
@ -266,9 +268,10 @@ FailureOr<Value> BufferizationState::getBuffer(
|
||||||
}))
|
}))
|
||||||
return resultBuffer;
|
return resultBuffer;
|
||||||
// Do not copy if the copied data is never read.
|
// Do not copy if the copied data is never read.
|
||||||
OpResult aliasingOpResult = getAliasingOpResult(opOperand);
|
SmallVector<OpResult> aliasingOpResults = getAliasingOpResult(opOperand);
|
||||||
if (aliasingOpResult && !bufferizesToMemoryRead(opOperand) &&
|
if (!aliasingOpResults.empty() && !bufferizesToMemoryRead(opOperand) &&
|
||||||
!isValueRead(aliasingOpResult))
|
llvm::none_of(aliasingOpResults,
|
||||||
|
[&](OpResult opResult) { return isValueRead(opResult); }))
|
||||||
return resultBuffer;
|
return resultBuffer;
|
||||||
// Do not copy if this op does not read the data, but writes it.
|
// Do not copy if this op does not read the data, but writes it.
|
||||||
if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
|
if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
|
||||||
|
|
|
@ -140,7 +140,7 @@ bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const {
|
||||||
void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand,
|
void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand,
|
||||||
BufferizationState &state) {
|
BufferizationState &state) {
|
||||||
markInPlace(operand);
|
markInPlace(operand);
|
||||||
if (OpResult result = state.getAliasingOpResult(operand))
|
for (OpResult result : state.getAliasingOpResult(operand))
|
||||||
aliasInfo.unionSets(result, operand.get());
|
aliasInfo.unionSets(result, operand.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -196,7 +196,7 @@ AnalysisBufferizationState::AnalysisBufferizationState(
|
||||||
for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
|
for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
|
||||||
if (opOperand.get().getType().isa<TensorType>())
|
if (opOperand.get().getType().isa<TensorType>())
|
||||||
if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) {
|
if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) {
|
||||||
if (OpResult opResult =
|
for (OpResult opResult :
|
||||||
bufferizableOp.getAliasingOpResult(opOperand, *this))
|
bufferizableOp.getAliasingOpResult(opOperand, *this))
|
||||||
aliasInfo.unionAliasSets(opOperand.get(), opResult);
|
aliasInfo.unionAliasSets(opOperand.get(), opResult);
|
||||||
aliasInfo.markInPlace(opOperand);
|
aliasInfo.markInPlace(opOperand);
|
||||||
|
@ -404,7 +404,9 @@ static bool hasReadAfterWriteInterference(
|
||||||
|
|
||||||
// No conflict if the conflicting write and the last write are the same
|
// No conflict if the conflicting write and the last write are the same
|
||||||
// use.
|
// use.
|
||||||
if (state.getAliasingOpResult(*uConflictingWrite) == lastWrite)
|
SmallVector<OpResult> aliasingOpResult =
|
||||||
|
state.getAliasingOpResult(*uConflictingWrite);
|
||||||
|
if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == lastWrite)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
// All requirements are met. Conflict found!
|
// All requirements are met. Conflict found!
|
||||||
|
@ -477,7 +479,7 @@ static bool wouldCreateReadAfterWriteInterference(
|
||||||
DenseSet<OpOperand *> usesRead, usesWrite;
|
DenseSet<OpOperand *> usesRead, usesWrite;
|
||||||
getAliasingReads(usesRead, operand.get());
|
getAliasingReads(usesRead, operand.get());
|
||||||
getAliasingInplaceWrites(usesWrite, operand.get());
|
getAliasingInplaceWrites(usesWrite, operand.get());
|
||||||
if (OpResult result = state.getAliasingOpResult(operand)) {
|
for (OpResult result : state.getAliasingOpResult(operand)) {
|
||||||
getAliasingReads(usesRead, result);
|
getAliasingReads(usesRead, result);
|
||||||
getAliasingInplaceWrites(usesWrite, result);
|
getAliasingInplaceWrites(usesWrite, result);
|
||||||
}
|
}
|
||||||
|
@ -506,7 +508,7 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand,
|
||||||
bool hasWrite = aliasesInPlaceWrite(opOperand.get(), aliasInfo, state) ||
|
bool hasWrite = aliasesInPlaceWrite(opOperand.get(), aliasInfo, state) ||
|
||||||
state.bufferizesToMemoryWrite(opOperand);
|
state.bufferizesToMemoryWrite(opOperand);
|
||||||
|
|
||||||
if (OpResult opResult = state.getAliasingOpResult(opOperand))
|
for (OpResult opResult : state.getAliasingOpResult(opOperand))
|
||||||
hasWrite |= aliasesInPlaceWrite(opResult, aliasInfo, state);
|
hasWrite |= aliasesInPlaceWrite(opResult, aliasInfo, state);
|
||||||
|
|
||||||
return hasWrite;
|
return hasWrite;
|
||||||
|
|
|
@ -168,8 +168,7 @@ struct LinalgOpInterface
|
||||||
// Operand is written to if it has an aliasing OpResult. For more details,
|
// Operand is written to if it has an aliasing OpResult. For more details,
|
||||||
// see `computeAliasingPairs`.
|
// see `computeAliasingPairs`.
|
||||||
auto bufferizableOp = cast<BufferizableOpInterface>(op);
|
auto bufferizableOp = cast<BufferizableOpInterface>(op);
|
||||||
return static_cast<bool>(
|
return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
|
||||||
bufferizableOp.getAliasingOpResult(opOperand, state));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<OpOperand *>
|
SmallVector<OpOperand *>
|
||||||
|
@ -185,13 +184,16 @@ struct LinalgOpInterface
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
SmallVector<OpResult>
|
||||||
|
getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
const BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto genericOp = cast<linalg::LinalgOp>(op);
|
auto genericOp = cast<linalg::LinalgOp>(op);
|
||||||
|
|
||||||
// Aliasing OpOperand/OpResult pairs are computed by `computeAliasingPairs`.
|
// Aliasing OpOperand/OpResult pairs are computed by `computeAliasingPairs`.
|
||||||
DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
|
DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
|
||||||
return pairs[&opOperand];
|
if (!pairs.count(&opOperand))
|
||||||
|
return {};
|
||||||
|
return {pairs[&opOperand]};
|
||||||
}
|
}
|
||||||
|
|
||||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||||
|
@ -252,16 +254,19 @@ struct TiledLoopOpInterface
|
||||||
|
|
||||||
// Only operands with an aliasing OpResult (i.e., output operands) bufferize
|
// Only operands with an aliasing OpResult (i.e., output operands) bufferize
|
||||||
// to a memory write.
|
// to a memory write.
|
||||||
return static_cast<bool>(
|
return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
|
||||||
bufferizableOp.getAliasingOpResult(opOperand, state));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
SmallVector<OpResult>
|
||||||
|
getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
const BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
|
auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
|
||||||
|
|
||||||
// Output operands are tied to their corresponding OpResults.
|
// Output operands are tied to their corresponding OpResults.
|
||||||
return tiledLoopOp.getTiedOpResult(opOperand);
|
OpResult opResult = tiledLoopOp.getTiedOpResult(opOperand);
|
||||||
|
if (!opResult)
|
||||||
|
return {};
|
||||||
|
return {opResult};
|
||||||
}
|
}
|
||||||
|
|
||||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||||
|
@ -397,9 +402,10 @@ struct YieldOpInterface
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
SmallVector<OpResult>
|
||||||
|
getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
const BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return OpResult();
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
|
bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
|
||||||
|
|
|
@ -723,7 +723,8 @@ struct CallOpInterface
|
||||||
funcOp.getArgument(opOperand.getOperandNumber()));
|
funcOp.getArgument(opOperand.getOperandNumber()));
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
SmallVector<OpResult>
|
||||||
|
getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
const BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
CallOp callOp = cast<CallOp>(op);
|
CallOp callOp = cast<CallOp>(op);
|
||||||
FuncOp funcOp = getCalledFunction(callOp);
|
FuncOp funcOp = getCalledFunction(callOp);
|
||||||
|
@ -731,17 +732,15 @@ struct CallOpInterface
|
||||||
const ModuleBufferizationState &moduleState =
|
const ModuleBufferizationState &moduleState =
|
||||||
getModuleBufferizationState(state);
|
getModuleBufferizationState(state);
|
||||||
|
|
||||||
|
SmallVector<OpResult> result;
|
||||||
for (int64_t resultIdx = 0; resultIdx < callOp->getNumResults();
|
for (int64_t resultIdx = 0; resultIdx < callOp->getNumResults();
|
||||||
++resultIdx)
|
++resultIdx)
|
||||||
if (Optional<int64_t> maybeArgNumber =
|
if (Optional<int64_t> maybeArgNumber =
|
||||||
getEquivalentFuncArgIdx(funcOp, moduleState, resultIdx))
|
getEquivalentFuncArgIdx(funcOp, moduleState, resultIdx))
|
||||||
if (*maybeArgNumber == opOperand.getOperandNumber())
|
if (*maybeArgNumber == opOperand.getOperandNumber())
|
||||||
return callOp->getOpResult(resultIdx);
|
result.push_back(callOp->getOpResult(resultIdx));
|
||||||
|
|
||||||
// Note: Returning a non-equivalent tensor from a FuncOp is currently not
|
return result;
|
||||||
// supported an will fail bufferization. (Even if allow-return-memref, it
|
|
||||||
// will fail when the function is called.)
|
|
||||||
return OpResult();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<OpOperand *>
|
SmallVector<OpOperand *>
|
||||||
|
@ -916,9 +915,10 @@ struct ReturnOpInterface
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
SmallVector<OpResult>
|
||||||
|
getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
const BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return OpResult();
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
|
|
|
@ -278,12 +278,13 @@ struct ForOpInterface
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
SmallVector<OpResult>
|
||||||
|
getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
const BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
auto forOp = cast<scf::ForOp>(op);
|
auto forOp = cast<scf::ForOp>(op);
|
||||||
if (!opOperand.get().getType().isa<RankedTensorType>())
|
if (!opOperand.get().getType().isa<RankedTensorType>())
|
||||||
return OpResult();
|
return {};
|
||||||
return forOp.getResultForOpOperand(opOperand);
|
return {forOp.getResultForOpOperand(opOperand)};
|
||||||
}
|
}
|
||||||
|
|
||||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||||
|
@ -401,13 +402,14 @@ struct YieldOpInterface
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
SmallVector<OpResult>
|
||||||
|
getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
const BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
if (isa<scf::IfOp>(op->getParentOp()))
|
if (isa<scf::IfOp>(op->getParentOp()))
|
||||||
return op->getParentOp()->getResult(opOperand.getOperandNumber());
|
return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
|
||||||
if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
|
if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
|
||||||
return op->getParentOp()->getResult(opOperand.getOperandNumber());
|
return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
|
||||||
return OpResult();
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
|
bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
|
||||||
|
|
|
@ -35,9 +35,10 @@ struct CastOpInterface
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
SmallVector<OpResult>
|
||||||
|
getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
const BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return op->getResult(0);
|
return {op->getResult(0)};
|
||||||
}
|
}
|
||||||
|
|
||||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||||
|
@ -93,9 +94,10 @@ struct DimOpInterface
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
SmallVector<OpResult>
|
||||||
|
getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
const BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return OpResult();
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
|
@ -121,11 +123,12 @@ struct ExtractSliceOpInterface
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
SmallVector<OpResult>
|
||||||
|
getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
const BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return &opOperand == &op->getOpOperand(0) /*source*/
|
if (&opOperand == &op->getOpOperand(0) /*source*/)
|
||||||
? op->getResult(0)
|
return {op->getOpResult(0)};
|
||||||
: OpResult();
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||||
|
@ -207,9 +210,10 @@ struct ExtractOpInterface
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
SmallVector<OpResult>
|
||||||
|
getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
const BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return OpResult();
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
|
@ -371,11 +375,12 @@ struct InsertOpInterface
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
SmallVector<OpResult>
|
||||||
|
getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
const BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
|
assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
|
||||||
"expected dest OpOperand");
|
"expected dest OpOperand");
|
||||||
return op->getOpResult(0);
|
return {op->getOpResult(0)};
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<OpOperand *>
|
SmallVector<OpOperand *>
|
||||||
|
@ -451,11 +456,12 @@ struct InsertSliceOpInterface
|
||||||
return &opOperand == &op->getOpOperand(1) /*dest*/;
|
return &opOperand == &op->getOpOperand(1) /*dest*/;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
SmallVector<OpResult>
|
||||||
|
getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
const BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return &opOperand == &op->getOpOperand(1) /*dest*/
|
if (&opOperand == &op->getOpOperand(1) /*dest*/)
|
||||||
? op->getResult(0)
|
return {op->getResult(0)};
|
||||||
: OpResult();
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||||
|
@ -606,9 +612,10 @@ struct RankOpInterface
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
SmallVector<OpResult>
|
||||||
|
getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
const BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return OpResult();
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
|
|
|
@ -40,9 +40,10 @@ struct TransferReadOpInterface
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
SmallVector<OpResult>
|
||||||
|
getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
const BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
return OpResult();
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
|
@ -81,11 +82,12 @@ struct TransferWriteOpInterface
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
SmallVector<OpResult>
|
||||||
|
getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||||
const BufferizationState &state) const {
|
const BufferizationState &state) const {
|
||||||
assert(opOperand.get().getType().isa<TensorType>() &&
|
assert(opOperand.get().getType().isa<TensorType>() &&
|
||||||
"only tensor types expected");
|
"only tensor types expected");
|
||||||
return op->getOpResult(0);
|
return {op->getOpResult(0)};
|
||||||
}
|
}
|
||||||
|
|
||||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||||
|
|
Loading…
Reference in New Issue