forked from OSchip/llvm-project
[MLIR] Add argument related API to Region
- Arguments of the first block of a region are considered region arguments. - Add API on Region class to deal with these arguments directly instead of using the front() block. - Changed several instances of existing code that can use this API - Fixes https://bugs.llvm.org/show_bug.cgi?id=46535 Differential Revision: https://reviews.llvm.org/D83599
This commit is contained in:
parent
85bed2f381
commit
e2b716105b
|
@ -237,7 +237,7 @@ def GPU_GPUFuncOp : GPU_Op<"func", [HasParent<"GPUModuleOp">,
|
|||
/// the workgroup memory
|
||||
ArrayRef<BlockArgument> getWorkgroupAttributions() {
|
||||
auto begin =
|
||||
std::next(getBody().front().args_begin(), getType().getNumInputs());
|
||||
std::next(getBody().args_begin(), getType().getNumInputs());
|
||||
auto end = std::next(begin, getNumWorkgroupAttributions());
|
||||
return {begin, end};
|
||||
}
|
||||
|
@ -248,7 +248,7 @@ def GPU_GPUFuncOp : GPU_Op<"func", [HasParent<"GPUModuleOp">,
|
|||
|
||||
/// Returns the number of buffers located in the private memory.
|
||||
unsigned getNumPrivateAttributions() {
|
||||
return getBody().front().getNumArguments() - getType().getNumInputs() -
|
||||
return getBody().getNumArguments() - getType().getNumInputs() -
|
||||
getNumWorkgroupAttributions();
|
||||
}
|
||||
|
||||
|
@ -258,9 +258,9 @@ def GPU_GPUFuncOp : GPU_Op<"func", [HasParent<"GPUModuleOp">,
|
|||
// Buffers on the private memory always come after buffers on the workgroup
|
||||
// memory.
|
||||
auto begin =
|
||||
std::next(getBody().front().args_begin(),
|
||||
std::next(getBody().args_begin(),
|
||||
getType().getNumInputs() + getNumWorkgroupAttributions());
|
||||
return {begin, getBody().front().args_end()};
|
||||
return {begin, getBody().args_end()};
|
||||
}
|
||||
|
||||
/// Adds a new block argument that corresponds to buffers located in
|
||||
|
|
|
@ -583,7 +583,7 @@ def GenericAtomicRMWOp : Std_Op<"generic_atomic_rmw", [
|
|||
let extraClassDeclaration = [{
|
||||
// The value stored in memref[ivs].
|
||||
Value getCurrentValue() {
|
||||
return body().front().getArgument(0);
|
||||
return body().getArgument(0);
|
||||
}
|
||||
MemRefType getMemRefType() {
|
||||
return memref().getType().cast<MemRefType>();
|
||||
|
|
|
@ -216,15 +216,13 @@ public:
|
|||
}
|
||||
|
||||
/// Gets argument.
|
||||
BlockArgument getArgument(unsigned idx) {
|
||||
return getBlocks().front().getArgument(idx);
|
||||
}
|
||||
BlockArgument getArgument(unsigned idx) { return getBody().getArgument(idx); }
|
||||
|
||||
/// Support argument iteration.
|
||||
using args_iterator = Block::args_iterator;
|
||||
args_iterator args_begin() { return front().args_begin(); }
|
||||
args_iterator args_end() { return front().args_end(); }
|
||||
Block::BlockArgListType getArguments() { return front().getArguments(); }
|
||||
using args_iterator = Region::args_iterator;
|
||||
args_iterator args_begin() { return getBody().args_begin(); }
|
||||
args_iterator args_end() { return getBody().args_end(); }
|
||||
Block::BlockArgListType getArguments() { return getBody().getArguments(); }
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Argument Attributes
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
#include "mlir/IR/Block.h"
|
||||
|
||||
namespace mlir {
|
||||
class TypeRange;
|
||||
template <typename ValueRangeT>
|
||||
class ValueTypeRange;
|
||||
class BlockAndValueMapping;
|
||||
|
||||
/// This class contains a list of basic blocks and a link to the parent
|
||||
|
@ -62,6 +65,48 @@ public:
|
|||
return &Region::blocks;
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Argument Handling
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
// This is the list of arguments to the block.
|
||||
using BlockArgListType = MutableArrayRef<BlockArgument>;
|
||||
BlockArgListType getArguments() {
|
||||
return empty() ? BlockArgListType() : front().getArguments();
|
||||
}
|
||||
using args_iterator = BlockArgListType::iterator;
|
||||
using reverse_args_iterator = BlockArgListType::reverse_iterator;
|
||||
args_iterator args_begin() { return getArguments().begin(); }
|
||||
args_iterator args_end() { return getArguments().end(); }
|
||||
reverse_args_iterator args_rbegin() { return getArguments().rbegin(); }
|
||||
reverse_args_iterator args_rend() { return getArguments().rend(); }
|
||||
|
||||
bool args_empty() { return getArguments().empty(); }
|
||||
|
||||
/// Add one value to the argument list.
|
||||
BlockArgument addArgument(Type type) { return front().addArgument(type); }
|
||||
|
||||
/// Insert one value to the position in the argument list indicated by the
|
||||
/// given iterator. The existing arguments are shifted. The block is expected
|
||||
/// not to have predecessors.
|
||||
BlockArgument insertArgument(args_iterator it, Type type) {
|
||||
return front().insertArgument(it, type);
|
||||
}
|
||||
|
||||
/// Add one argument to the argument list for each type specified in the list.
|
||||
iterator_range<args_iterator> addArguments(TypeRange types);
|
||||
|
||||
/// Add one value to the argument list at the specified position.
|
||||
BlockArgument insertArgument(unsigned index, Type type) {
|
||||
return front().insertArgument(index, type);
|
||||
}
|
||||
|
||||
/// Erase the argument at 'index' and remove it from the argument list.
|
||||
void eraseArgument(unsigned index) { front().eraseArgument(index); }
|
||||
|
||||
unsigned getNumArguments() { return getArguments().size(); }
|
||||
BlockArgument getArgument(unsigned i) { return getArguments()[i]; }
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Operation list utilities
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
|
|
@ -417,8 +417,8 @@ static LogicalResult processParallelLoop(
|
|||
|
||||
if (isMappedToProcessor(processor)) {
|
||||
// Use the corresponding thread/grid index as replacement for the loop iv.
|
||||
Value operand = launchOp.body().front().getArgument(
|
||||
getLaunchOpArgumentNum(processor));
|
||||
Value operand =
|
||||
launchOp.body().getArgument(getLaunchOpArgumentNum(processor));
|
||||
// Take the indexmap and add the lower bound and step computations in.
|
||||
// This computes operand * step + lowerBound.
|
||||
// Use an affine map here so that it composes nicely with the provided
|
||||
|
|
|
@ -127,9 +127,9 @@ static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) {
|
|||
return allReduce.emitError(
|
||||
"expected either an op attribute or a non-empty body");
|
||||
if (!allReduce.body().empty()) {
|
||||
if (allReduce.body().front().getNumArguments() != 2)
|
||||
if (allReduce.body().getNumArguments() != 2)
|
||||
return allReduce.emitError("expected two region arguments");
|
||||
for (auto argument : allReduce.body().front().getArguments()) {
|
||||
for (auto argument : allReduce.body().getArguments()) {
|
||||
if (argument.getType() != allReduce.getType())
|
||||
return allReduce.emitError("incorrect region argument type");
|
||||
}
|
||||
|
@ -219,25 +219,25 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
|
|||
|
||||
KernelDim3 LaunchOp::getBlockIds() {
|
||||
assert(!body().empty() && "LaunchOp body must not be empty.");
|
||||
auto args = body().front().getArguments();
|
||||
auto args = body().getArguments();
|
||||
return KernelDim3{args[0], args[1], args[2]};
|
||||
}
|
||||
|
||||
KernelDim3 LaunchOp::getThreadIds() {
|
||||
assert(!body().empty() && "LaunchOp body must not be empty.");
|
||||
auto args = body().front().getArguments();
|
||||
auto args = body().getArguments();
|
||||
return KernelDim3{args[3], args[4], args[5]};
|
||||
}
|
||||
|
||||
KernelDim3 LaunchOp::getGridSize() {
|
||||
assert(!body().empty() && "LaunchOp body must not be empty.");
|
||||
auto args = body().front().getArguments();
|
||||
auto args = body().getArguments();
|
||||
return KernelDim3{args[6], args[7], args[8]};
|
||||
}
|
||||
|
||||
KernelDim3 LaunchOp::getBlockSize() {
|
||||
assert(!body().empty() && "LaunchOp body must not be empty.");
|
||||
auto args = body().getBlocks().front().getArguments();
|
||||
auto args = body().getArguments();
|
||||
return KernelDim3{args[9], args[10], args[11]};
|
||||
}
|
||||
|
||||
|
@ -254,8 +254,7 @@ static LogicalResult verify(LaunchOp op) {
|
|||
// sizes and transforms them into kNumConfigRegionAttributes region arguments
|
||||
// for block/thread identifiers and grid/block sizes.
|
||||
if (!op.body().empty()) {
|
||||
Block &entryBlock = op.body().front();
|
||||
if (entryBlock.getNumArguments() !=
|
||||
if (op.body().getNumArguments() !=
|
||||
LaunchOp::kNumConfigOperands + op.getNumOperands())
|
||||
return op.emitOpError("unexpected number of region arguments");
|
||||
}
|
||||
|
@ -463,8 +462,8 @@ BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type) {
|
|||
auto attrName = getNumWorkgroupAttributionsAttrName();
|
||||
auto attr = getAttrOfType<IntegerAttr>(attrName);
|
||||
setAttr(attrName, IntegerAttr::get(attr.getType(), attr.getValue() + 1));
|
||||
return getBody().front().insertArgument(
|
||||
getType().getNumInputs() + attr.getInt(), type);
|
||||
return getBody().insertArgument(getType().getNumInputs() + attr.getInt(),
|
||||
type);
|
||||
}
|
||||
|
||||
/// Adds a new block argument that corresponds to buffers located in
|
||||
|
@ -472,7 +471,7 @@ BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type) {
|
|||
BlockArgument GPUFuncOp::addPrivateAttribution(Type type) {
|
||||
// Buffers on the private memory always come after buffers on the workgroup
|
||||
// memory.
|
||||
return getBody().front().addArgument(type);
|
||||
return getBody().addArgument(type);
|
||||
}
|
||||
|
||||
void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
|
||||
|
|
|
@ -181,8 +181,8 @@ private:
|
|||
|
||||
// Insert accumulator body between split block.
|
||||
BlockAndValueMapping mapping;
|
||||
mapping.map(body.front().getArgument(0), lhs);
|
||||
mapping.map(body.front().getArgument(1), rhs);
|
||||
mapping.map(body.getArgument(0), lhs);
|
||||
mapping.map(body.getArgument(1), rhs);
|
||||
rewriter.cloneRegionBefore(body, *split->getParent(),
|
||||
split->getIterator(), mapping);
|
||||
|
||||
|
|
|
@ -1102,8 +1102,7 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
|
|||
unsigned argIndex,
|
||||
NamedAttribute attribute) {
|
||||
return verifyRegionAttribute(
|
||||
op->getLoc(),
|
||||
op->getRegion(regionIndex).front().getArgument(argIndex).getType(),
|
||||
op->getLoc(), op->getRegion(regionIndex).getArgument(argIndex).getType(),
|
||||
attribute);
|
||||
}
|
||||
|
||||
|
|
|
@ -525,22 +525,21 @@ void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
|
|||
|
||||
Region *bodyRegion = result.addRegion();
|
||||
bodyRegion->push_back(new Block());
|
||||
bodyRegion->front().addArgument(elementType);
|
||||
bodyRegion->addArgument(elementType);
|
||||
}
|
||||
}
|
||||
|
||||
static LogicalResult verify(GenericAtomicRMWOp op) {
|
||||
auto &block = op.body().front();
|
||||
if (block.getNumArguments() != 1)
|
||||
auto &body = op.body();
|
||||
if (body.getNumArguments() != 1)
|
||||
return op.emitOpError("expected single number of entry block arguments");
|
||||
|
||||
if (op.getResult().getType() != block.getArgument(0).getType())
|
||||
if (op.getResult().getType() != body.getArgument(0).getType())
|
||||
return op.emitOpError(
|
||||
"expected block argument of the same type result type");
|
||||
|
||||
bool hasSideEffects =
|
||||
op.body()
|
||||
.walk([&](Operation *nestedOp) {
|
||||
body.walk([&](Operation *nestedOp) {
|
||||
if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
|
||||
return WalkResult::advance();
|
||||
nestedOp->emitError("body of 'generic_atomic_rmw' should contain "
|
||||
|
|
|
@ -619,7 +619,7 @@ unsigned SSANameState::getBlockID(Block *block) {
|
|||
|
||||
void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) {
|
||||
assert(!region.empty() && "cannot shadow arguments of an empty region");
|
||||
assert(region.front().getNumArguments() == namesToUse.size() &&
|
||||
assert(region.getNumArguments() == namesToUse.size() &&
|
||||
"incorrect number of names passed in");
|
||||
assert(region.getParentOp()->isKnownIsolatedFromAbove() &&
|
||||
"only KnownIsolatedFromAbove ops can shadow names");
|
||||
|
@ -629,7 +629,7 @@ void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) {
|
|||
auto nameToUse = namesToUse[i];
|
||||
if (nameToUse == nullptr)
|
||||
continue;
|
||||
auto nameToReplace = region.front().getArgument(i);
|
||||
auto nameToReplace = region.getArgument(i);
|
||||
|
||||
nameStr.clear();
|
||||
llvm::raw_svector_ostream nameStream(nameStr);
|
||||
|
|
|
@ -238,7 +238,7 @@ void mlir::impl::printFunctionSignature(OpAsmPrinter &p, Operation *op,
|
|||
p << ", ";
|
||||
|
||||
if (!isExternal) {
|
||||
p.printOperand(body.front().getArgument(i));
|
||||
p.printOperand(body.getArgument(i));
|
||||
p << ": ";
|
||||
}
|
||||
|
||||
|
|
|
@ -1022,7 +1022,7 @@ LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) {
|
|||
if (region.empty())
|
||||
continue;
|
||||
|
||||
if (region.front().getNumArguments() != 0) {
|
||||
if (region.getNumArguments() != 0) {
|
||||
if (op->getNumRegions() > 1)
|
||||
return op->emitOpError("region #")
|
||||
<< region.getRegionNumber() << " should have no arguments";
|
||||
|
|
|
@ -33,6 +33,11 @@ Location Region::getLoc() {
|
|||
return container->getLoc();
|
||||
}
|
||||
|
||||
/// Add one argument to the argument list for each type specified in the list.
|
||||
iterator_range<Region::args_iterator> Region::addArguments(TypeRange types) {
|
||||
return front().addArguments(types);
|
||||
}
|
||||
|
||||
Region *Region::getParentRegion() {
|
||||
assert(container && "region is not attached to a container");
|
||||
return container->getParentRegion();
|
||||
|
|
|
@ -123,7 +123,7 @@ public:
|
|||
/// Build a lattice state with a given callable region, and a specified number
|
||||
/// of results to be initialized to the default lattice value (Unknown).
|
||||
CallableLatticeState(Region *callableRegion, unsigned numResults)
|
||||
: callableArguments(callableRegion->front().getArguments()),
|
||||
: callableArguments(callableRegion->getArguments()),
|
||||
resultLatticeValues(numResults) {}
|
||||
|
||||
/// Returns the arguments to the callable region.
|
||||
|
@ -403,7 +403,7 @@ void SCCPSolver::initializeSymbolCallables(Operation *op) {
|
|||
// If not all of the uses of this symbol are visible, we can't track the
|
||||
// state of the arguments.
|
||||
if (symbol.isPublic() || (!allUsesVisible && symbol.isNested()))
|
||||
markAllOverdefined(callableRegion->front().getArguments());
|
||||
markAllOverdefined(callableRegion->getArguments());
|
||||
}
|
||||
if (callableLatticeState.empty())
|
||||
return;
|
||||
|
|
|
@ -284,7 +284,7 @@ struct TestUndoBlockArgReplace : public ConversionPattern {
|
|||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto illegalOp =
|
||||
rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
|
||||
rewriter.replaceUsesOfBlockArgument(op->getRegion(0).front().getArgument(0),
|
||||
rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
|
||||
illegalOp);
|
||||
rewriter.updateRootInPlace(op, [] {});
|
||||
return success();
|
||||
|
|
Loading…
Reference in New Issue