forked from OSchip/llvm-project
[mlir][bufferize][NFC] Change signature of allocateTensorForShapedValue
Add a failure return value and a bufferization options argument. This keeps a subsequent change smaller. Differential Revision: https://reviews.llvm.org/D128278
This commit is contained in:
parent
f5d781d627
commit
45b995cda4
|
@ -472,9 +472,10 @@ private:
|
|||
/// Create an AllocTensorOp for the given shaped value (memref or tensor).
|
||||
/// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
|
||||
/// undefined contents is allocated.
|
||||
Value allocateTensorForShapedValue(OpBuilder &b, Location loc,
|
||||
Value shapedValue, bool escape,
|
||||
bool copy = true);
|
||||
FailureOr<Value>
|
||||
allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue,
|
||||
bool escape, const BufferizationOptions &options,
|
||||
bool copy = true);
|
||||
|
||||
/// Lookup the buffer for the given value. If the value was not bufferized
|
||||
/// yet, wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp,
|
||||
|
|
|
@ -46,9 +46,9 @@ constexpr const ::llvm::StringLiteral
|
|||
/// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
|
||||
/// shaped value is copied. Otherwise, a tensor with undefined contents is
|
||||
/// allocated.
|
||||
Value bufferization::allocateTensorForShapedValue(OpBuilder &b, Location loc,
|
||||
Value shapedValue,
|
||||
bool escape, bool copy) {
|
||||
FailureOr<Value> bufferization::allocateTensorForShapedValue(
|
||||
OpBuilder &b, Location loc, Value shapedValue, bool escape,
|
||||
const BufferizationOptions &options, bool copy) {
|
||||
Value tensor;
|
||||
if (shapedValue.getType().isa<RankedTensorType>()) {
|
||||
tensor = shapedValue;
|
||||
|
@ -88,7 +88,7 @@ Value bufferization::allocateTensorForShapedValue(OpBuilder &b, Location loc,
|
|||
copy ? tensor : Value());
|
||||
allocTensorOp->setAttr(BufferizationDialect::kEscapeAttrName,
|
||||
b.getBoolArrayAttr({escape}));
|
||||
return allocTensorOp;
|
||||
return allocTensorOp.getResult();
|
||||
}
|
||||
|
||||
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
|
||||
|
@ -147,26 +147,30 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
|
|||
// Insert copies of OpOperands.
|
||||
rewriter.setInsertionPoint(op);
|
||||
for (OpOperand *opOperand : outOfPlaceOpOperands) {
|
||||
Value copy = allocateTensorForShapedValue(
|
||||
FailureOr<Value> copy = allocateTensorForShapedValue(
|
||||
rewriter, op->getLoc(), opOperand->get(),
|
||||
escapingOpOperandCopies.contains(opOperand),
|
||||
escapingOpOperandCopies.contains(opOperand), state.getOptions(),
|
||||
copiedOpOperands.contains(opOperand));
|
||||
rewriter.updateRootInPlace(op, [&]() { opOperand->set(copy); });
|
||||
if (failed(copy))
|
||||
return failure();
|
||||
rewriter.updateRootInPlace(op, [&]() { opOperand->set(*copy); });
|
||||
}
|
||||
|
||||
// Insert copies of OpResults.
|
||||
rewriter.setInsertionPointAfter(op);
|
||||
for (OpResult opResult : outOfPlaceOpResults) {
|
||||
Value copy =
|
||||
allocateTensorForShapedValue(rewriter, op->getLoc(), opResult,
|
||||
escapingOpResultCopies.contains(opResult),
|
||||
copiedOpResults.count(opResult));
|
||||
FailureOr<Value> copy = allocateTensorForShapedValue(
|
||||
rewriter, op->getLoc(), opResult,
|
||||
escapingOpResultCopies.contains(opResult), state.getOptions(),
|
||||
copiedOpResults.count(opResult));
|
||||
if (failed(copy))
|
||||
return failure();
|
||||
SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
|
||||
opResult.getUses(), [](OpOperand &use) { return &use; }));
|
||||
for (OpOperand *use : uses) {
|
||||
// Do not update the alloc_tensor op that we just created.
|
||||
if (use->getOwner() != copy.getDefiningOp())
|
||||
rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(copy); });
|
||||
if (use->getOwner() != copy->getDefiningOp())
|
||||
rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(*copy); });
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -458,9 +458,12 @@ struct ForOpInterface
|
|||
yieldValues.push_back(value);
|
||||
continue;
|
||||
}
|
||||
Value alloc = allocateTensorForShapedValue(rewriter, yieldOp.getLoc(),
|
||||
value, /*escape=*/true);
|
||||
yieldValues.push_back(alloc);
|
||||
FailureOr<Value> alloc =
|
||||
allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
|
||||
/*escape=*/true, state.getOptions());
|
||||
if (failed(alloc))
|
||||
return failure();
|
||||
yieldValues.push_back(*alloc);
|
||||
}
|
||||
|
||||
rewriter.updateRootInPlace(
|
||||
|
@ -669,9 +672,12 @@ struct WhileOpInterface
|
|||
beforeYieldValues.push_back(value);
|
||||
continue;
|
||||
}
|
||||
Value alloc = allocateTensorForShapedValue(rewriter, conditionOp.getLoc(),
|
||||
value, /*escape=*/true);
|
||||
beforeYieldValues.push_back(alloc);
|
||||
FailureOr<Value> alloc =
|
||||
allocateTensorForShapedValue(rewriter, conditionOp.getLoc(), value,
|
||||
/*escape=*/true, state.getOptions());
|
||||
if (failed(alloc))
|
||||
return failure();
|
||||
beforeYieldValues.push_back(*alloc);
|
||||
}
|
||||
rewriter.updateRootInPlace(conditionOp, [&]() {
|
||||
conditionOp.getArgsMutable().assign(beforeYieldValues);
|
||||
|
@ -687,9 +693,12 @@ struct WhileOpInterface
|
|||
afterYieldValues.push_back(value);
|
||||
continue;
|
||||
}
|
||||
Value alloc = allocateTensorForShapedValue(rewriter, yieldOp.getLoc(),
|
||||
value, /*escape=*/true);
|
||||
afterYieldValues.push_back(alloc);
|
||||
FailureOr<Value> alloc =
|
||||
allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
|
||||
/*escape=*/true, state.getOptions());
|
||||
if (failed(alloc))
|
||||
return failure();
|
||||
afterYieldValues.push_back(*alloc);
|
||||
}
|
||||
rewriter.updateRootInPlace(yieldOp, [&]() {
|
||||
yieldOp.getResultsMutable().assign(afterYieldValues);
|
||||
|
@ -972,13 +981,15 @@ struct ForeachThreadOpInterface
|
|||
|
||||
// Insert tensor allocation.
|
||||
bool isYielded = state.isTensorYielded(opResult);
|
||||
Value alloc = allocateTensorForShapedValue(rewriter, op->getLoc(),
|
||||
destOperands.front()->get(),
|
||||
/*escape=*/isYielded);
|
||||
FailureOr<Value> alloc = allocateTensorForShapedValue(
|
||||
rewriter, op->getLoc(), destOperands.front()->get(),
|
||||
/*escape=*/isYielded, state.getOptions());
|
||||
if (failed(alloc))
|
||||
return failure();
|
||||
|
||||
// Update terminator operand.
|
||||
rewriter.updateRootInPlace(destOperands.front()->getOwner(),
|
||||
[&]() { destOperands.front()->set(alloc); });
|
||||
[&]() { destOperands.front()->set(*alloc); });
|
||||
}
|
||||
|
||||
return success();
|
||||
|
|
|
@ -154,15 +154,17 @@ struct CollapseShapeOpInterface
|
|||
if (!canBeCollapsed) {
|
||||
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
|
||||
AnalysisState analysisState(options);
|
||||
Value tensorAlloc = allocateTensorForShapedValue(
|
||||
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
|
||||
rewriter, op->getLoc(), collapseShapeOp.getSrc(),
|
||||
analysisState.isTensorYielded(collapseShapeOp.getResult()));
|
||||
analysisState.isTensorYielded(collapseShapeOp.getResult()), options);
|
||||
if (failed(tensorAlloc))
|
||||
return failure();
|
||||
auto memrefType =
|
||||
MemRefType::get(collapseShapeOp.getSrcType().getShape(),
|
||||
collapseShapeOp.getSrcType().getElementType(),
|
||||
AffineMap(), bufferType.getMemorySpaceAsInt());
|
||||
buffer = rewriter.create<bufferization::ToMemrefOp>(
|
||||
op->getLoc(), memrefType, tensorAlloc);
|
||||
op->getLoc(), memrefType, *tensorAlloc);
|
||||
}
|
||||
|
||||
// Result type is inferred by the builder.
|
||||
|
@ -383,14 +385,16 @@ struct FromElementsOpInterface
|
|||
auto shape = tensorType.getShape();
|
||||
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
|
||||
AnalysisState analysisState(options);
|
||||
Value tensorAlloc = allocateTensorForShapedValue(
|
||||
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
|
||||
rewriter, loc, fromElementsOp.getResult(),
|
||||
analysisState.isTensorYielded(fromElementsOp.getResult()),
|
||||
analysisState.isTensorYielded(fromElementsOp.getResult()), options,
|
||||
/*copy=*/false);
|
||||
if (failed(tensorAlloc))
|
||||
return failure();
|
||||
auto memrefType =
|
||||
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
||||
Value buffer = rewriter.create<bufferization::ToMemrefOp>(
|
||||
op->getLoc(), memrefType, tensorAlloc);
|
||||
op->getLoc(), memrefType, *tensorAlloc);
|
||||
|
||||
// Case: tensor<0xelem_type>.
|
||||
if (fromElementsOp.getElements().empty()) {
|
||||
|
@ -436,14 +440,16 @@ struct GenerateOpInterface
|
|||
Location loc = op->getLoc();
|
||||
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
|
||||
AnalysisState analysisState(options);
|
||||
Value tensorAlloc = allocateTensorForShapedValue(
|
||||
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
|
||||
rewriter, loc, generateOp.getResult(),
|
||||
analysisState.isTensorYielded(generateOp.getResult()),
|
||||
analysisState.isTensorYielded(generateOp.getResult()), options,
|
||||
/*copy=*/false);
|
||||
if (failed(tensorAlloc))
|
||||
return failure();
|
||||
auto memrefType =
|
||||
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
||||
Value buffer = rewriter.create<bufferization::ToMemrefOp>(
|
||||
op->getLoc(), memrefType, tensorAlloc);
|
||||
op->getLoc(), memrefType, *tensorAlloc);
|
||||
|
||||
// Collect loop bounds.
|
||||
int64_t rank = memrefType.getRank();
|
||||
|
|
Loading…
Reference in New Issue