[MLIR] Add argument related API to Region

- Arguments of the first block of a region are considered region arguments.
- Add API on Region class to deal with these arguments directly instead of
  using the front() block.
- Changed several instances of existing code that can use this API
- Fixes https://bugs.llvm.org/show_bug.cgi?id=46535

Differential Revision: https://reviews.llvm.org/D83599
This commit is contained in:
Rahul Joshi 2020-07-10 17:07:29 -07:00
parent 85bed2f381
commit e2b716105b
15 changed files with 87 additions and 42 deletions

View File

@ -237,7 +237,7 @@ def GPU_GPUFuncOp : GPU_Op<"func", [HasParent<"GPUModuleOp">,
/// the workgroup memory
ArrayRef<BlockArgument> getWorkgroupAttributions() {
auto begin =
std::next(getBody().front().args_begin(), getType().getNumInputs());
std::next(getBody().args_begin(), getType().getNumInputs());
auto end = std::next(begin, getNumWorkgroupAttributions());
return {begin, end};
}
@ -248,7 +248,7 @@ def GPU_GPUFuncOp : GPU_Op<"func", [HasParent<"GPUModuleOp">,
/// Returns the number of buffers located in the private memory.
unsigned getNumPrivateAttributions() {
return getBody().front().getNumArguments() - getType().getNumInputs() -
return getBody().getNumArguments() - getType().getNumInputs() -
getNumWorkgroupAttributions();
}
@ -258,9 +258,9 @@ def GPU_GPUFuncOp : GPU_Op<"func", [HasParent<"GPUModuleOp">,
// Buffers on the private memory always come after buffers on the workgroup
// memory.
auto begin =
std::next(getBody().front().args_begin(),
std::next(getBody().args_begin(),
getType().getNumInputs() + getNumWorkgroupAttributions());
return {begin, getBody().front().args_end()};
return {begin, getBody().args_end()};
}
/// Adds a new block argument that corresponds to buffers located in

View File

@ -583,7 +583,7 @@ def GenericAtomicRMWOp : Std_Op<"generic_atomic_rmw", [
let extraClassDeclaration = [{
// The value stored in memref[ivs].
Value getCurrentValue() {
return body().front().getArgument(0);
return body().getArgument(0);
}
MemRefType getMemRefType() {
return memref().getType().cast<MemRefType>();

View File

@ -216,15 +216,13 @@ public:
}
/// Gets argument.
BlockArgument getArgument(unsigned idx) {
return getBlocks().front().getArgument(idx);
}
BlockArgument getArgument(unsigned idx) { return getBody().getArgument(idx); }
/// Support argument iteration.
using args_iterator = Block::args_iterator;
args_iterator args_begin() { return front().args_begin(); }
args_iterator args_end() { return front().args_end(); }
Block::BlockArgListType getArguments() { return front().getArguments(); }
using args_iterator = Region::args_iterator;
args_iterator args_begin() { return getBody().args_begin(); }
args_iterator args_end() { return getBody().args_end(); }
Block::BlockArgListType getArguments() { return getBody().getArguments(); }
//===--------------------------------------------------------------------===//
// Argument Attributes

View File

@ -16,6 +16,9 @@
#include "mlir/IR/Block.h"
namespace mlir {
class TypeRange;
template <typename ValueRangeT>
class ValueTypeRange;
class BlockAndValueMapping;
/// This class contains a list of basic blocks and a link to the parent
@ -62,6 +65,48 @@ public:
return &Region::blocks;
}
//===--------------------------------------------------------------------===//
// Argument Handling
//===--------------------------------------------------------------------===//
// This is the list of arguments to the block.
using BlockArgListType = MutableArrayRef<BlockArgument>;
BlockArgListType getArguments() {
return empty() ? BlockArgListType() : front().getArguments();
}
using args_iterator = BlockArgListType::iterator;
using reverse_args_iterator = BlockArgListType::reverse_iterator;
args_iterator args_begin() { return getArguments().begin(); }
args_iterator args_end() { return getArguments().end(); }
reverse_args_iterator args_rbegin() { return getArguments().rbegin(); }
reverse_args_iterator args_rend() { return getArguments().rend(); }
bool args_empty() { return getArguments().empty(); }
/// Add one value to the argument list.
BlockArgument addArgument(Type type) { return front().addArgument(type); }
/// Insert one value to the position in the argument list indicated by the
/// given iterator. The existing arguments are shifted. The block is expected
/// not to have predecessors.
BlockArgument insertArgument(args_iterator it, Type type) {
return front().insertArgument(it, type);
}
/// Add one argument to the argument list for each type specified in the list.
iterator_range<args_iterator> addArguments(TypeRange types);
/// Add one value to the argument list at the specified position.
BlockArgument insertArgument(unsigned index, Type type) {
return front().insertArgument(index, type);
}
/// Erase the argument at 'index' and remove it from the argument list.
void eraseArgument(unsigned index) { front().eraseArgument(index); }
unsigned getNumArguments() { return getArguments().size(); }
BlockArgument getArgument(unsigned i) { return getArguments()[i]; }
//===--------------------------------------------------------------------===//
// Operation list utilities
//===--------------------------------------------------------------------===//

View File

@ -417,8 +417,8 @@ static LogicalResult processParallelLoop(
if (isMappedToProcessor(processor)) {
// Use the corresponding thread/grid index as replacement for the loop iv.
Value operand = launchOp.body().front().getArgument(
getLaunchOpArgumentNum(processor));
Value operand =
launchOp.body().getArgument(getLaunchOpArgumentNum(processor));
// Take the indexmap and add the lower bound and step computations in.
// This computes operand * step + lowerBound.
// Use an affine map here so that it composes nicely with the provided

View File

@ -127,9 +127,9 @@ static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) {
return allReduce.emitError(
"expected either an op attribute or a non-empty body");
if (!allReduce.body().empty()) {
if (allReduce.body().front().getNumArguments() != 2)
if (allReduce.body().getNumArguments() != 2)
return allReduce.emitError("expected two region arguments");
for (auto argument : allReduce.body().front().getArguments()) {
for (auto argument : allReduce.body().getArguments()) {
if (argument.getType() != allReduce.getType())
return allReduce.emitError("incorrect region argument type");
}
@ -219,25 +219,25 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
KernelDim3 LaunchOp::getBlockIds() {
assert(!body().empty() && "LaunchOp body must not be empty.");
auto args = body().front().getArguments();
auto args = body().getArguments();
return KernelDim3{args[0], args[1], args[2]};
}
KernelDim3 LaunchOp::getThreadIds() {
assert(!body().empty() && "LaunchOp body must not be empty.");
auto args = body().front().getArguments();
auto args = body().getArguments();
return KernelDim3{args[3], args[4], args[5]};
}
KernelDim3 LaunchOp::getGridSize() {
assert(!body().empty() && "LaunchOp body must not be empty.");
auto args = body().front().getArguments();
auto args = body().getArguments();
return KernelDim3{args[6], args[7], args[8]};
}
KernelDim3 LaunchOp::getBlockSize() {
assert(!body().empty() && "LaunchOp body must not be empty.");
auto args = body().getBlocks().front().getArguments();
auto args = body().getArguments();
return KernelDim3{args[9], args[10], args[11]};
}
@ -254,8 +254,7 @@ static LogicalResult verify(LaunchOp op) {
// sizes and transforms them into kNumConfigRegionAttributes region arguments
// for block/thread identifiers and grid/block sizes.
if (!op.body().empty()) {
Block &entryBlock = op.body().front();
if (entryBlock.getNumArguments() !=
if (op.body().getNumArguments() !=
LaunchOp::kNumConfigOperands + op.getNumOperands())
return op.emitOpError("unexpected number of region arguments");
}
@ -463,8 +462,8 @@ BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type) {
auto attrName = getNumWorkgroupAttributionsAttrName();
auto attr = getAttrOfType<IntegerAttr>(attrName);
setAttr(attrName, IntegerAttr::get(attr.getType(), attr.getValue() + 1));
return getBody().front().insertArgument(
getType().getNumInputs() + attr.getInt(), type);
return getBody().insertArgument(getType().getNumInputs() + attr.getInt(),
type);
}
/// Adds a new block argument that corresponds to buffers located in
@ -472,7 +471,7 @@ BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type) {
BlockArgument GPUFuncOp::addPrivateAttribution(Type type) {
// Buffers on the private memory always come after buffers on the workgroup
// memory.
return getBody().front().addArgument(type);
return getBody().addArgument(type);
}
void GPUFuncOp::build(OpBuilder &builder, OperationState &result,

View File

@ -181,8 +181,8 @@ private:
// Insert accumulator body between split block.
BlockAndValueMapping mapping;
mapping.map(body.front().getArgument(0), lhs);
mapping.map(body.front().getArgument(1), rhs);
mapping.map(body.getArgument(0), lhs);
mapping.map(body.getArgument(1), rhs);
rewriter.cloneRegionBefore(body, *split->getParent(),
split->getIterator(), mapping);

View File

@ -1102,8 +1102,7 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
unsigned argIndex,
NamedAttribute attribute) {
return verifyRegionAttribute(
op->getLoc(),
op->getRegion(regionIndex).front().getArgument(argIndex).getType(),
op->getLoc(), op->getRegion(regionIndex).getArgument(argIndex).getType(),
attribute);
}

View File

@ -525,22 +525,21 @@ void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
Region *bodyRegion = result.addRegion();
bodyRegion->push_back(new Block());
bodyRegion->front().addArgument(elementType);
bodyRegion->addArgument(elementType);
}
}
static LogicalResult verify(GenericAtomicRMWOp op) {
auto &block = op.body().front();
if (block.getNumArguments() != 1)
auto &body = op.body();
if (body.getNumArguments() != 1)
return op.emitOpError("expected single number of entry block arguments");
if (op.getResult().getType() != block.getArgument(0).getType())
if (op.getResult().getType() != body.getArgument(0).getType())
return op.emitOpError(
"expected block argument of the same type result type");
bool hasSideEffects =
op.body()
.walk([&](Operation *nestedOp) {
body.walk([&](Operation *nestedOp) {
if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
return WalkResult::advance();
nestedOp->emitError("body of 'generic_atomic_rmw' should contain "

View File

@ -619,7 +619,7 @@ unsigned SSANameState::getBlockID(Block *block) {
void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
assert(!region.empty() && "cannot shadow arguments of an empty region");
assert(region.front().getNumArguments() == namesToUse.size() &&
assert(region.getNumArguments() == namesToUse.size() &&
"incorrect number of names passed in");
assert(region.getParentOp()->isKnownIsolatedFromAbove() &&
"only KnownIsolatedFromAbove ops can shadow names");
@ -629,7 +629,7 @@ void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
auto nameToUse = namesToUse[i];
if (nameToUse == nullptr)
continue;
auto nameToReplace = region.front().getArgument(i);
auto nameToReplace = region.getArgument(i);
nameStr.clear();
llvm::raw_svector_ostream nameStream(nameStr);

View File

@ -238,7 +238,7 @@ void mlir::impl::printFunctionSignature(OpAsmPrinter &p, Operation *op,
p << ", ";
if (!isExternal) {
p.printOperand(body.front().getArgument(i));
p.printOperand(body.getArgument(i));
p << ": ";
}

View File

@ -1022,7 +1022,7 @@ LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) {
if (region.empty())
continue;
if (region.front().getNumArguments() != 0) {
if (region.getNumArguments() != 0) {
if (op->getNumRegions() > 1)
return op->emitOpError("region #")
<< region.getRegionNumber() << " should have no arguments";

View File

@ -33,6 +33,11 @@ Location Region::getLoc() {
return container->getLoc();
}
/// Add one argument to the argument list for each type specified in the list.
iterator_range<Region::args_iterator> Region::addArguments(TypeRange types) {
return front().addArguments(types);
}
Region *Region::getParentRegion() {
assert(container && "region is not attached to a container");
return container->getParentRegion();

View File

@ -123,7 +123,7 @@ public:
/// Build a lattice state with a given callable region, and a specified number
/// of results to be initialized to the default lattice value (Unknown).
CallableLatticeState(Region *callableRegion, unsigned numResults)
: callableArguments(callableRegion->front().getArguments()),
: callableArguments(callableRegion->getArguments()),
resultLatticeValues(numResults) {}
/// Returns the arguments to the callable region.
@ -403,7 +403,7 @@ void SCCPSolver::initializeSymbolCallables(Operation *op) {
// If not all of the uses of this symbol are visible, we can't track the
// state of the arguments.
if (symbol.isPublic() || (!allUsesVisible && symbol.isNested()))
markAllOverdefined(callableRegion->front().getArguments());
markAllOverdefined(callableRegion->getArguments());
}
if (callableLatticeState.empty())
return;

View File

@ -284,7 +284,7 @@ struct TestUndoBlockArgReplace : public ConversionPattern {
ConversionPatternRewriter &rewriter) const final {
auto illegalOp =
rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
rewriter.replaceUsesOfBlockArgument(op->getRegion(0).front().getArgument(0),
rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
illegalOp);
rewriter.updateRootInPlace(op, [] {});
return success();