forked from OSchip/llvm-project
401 lines
16 KiB
C++
401 lines
16 KiB
C++
//===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include <utility>
|
|
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
|
#include "llvm/ADT/SmallPtrSet.h"
|
|
|
|
using namespace mlir;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ControlFlowInterfaces
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
|
|
|
|
SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands)
|
|
: producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) {
|
|
}
|
|
|
|
SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
|
|
MutableOperandRange forwardedOperands)
|
|
: producedOperandCount(producedOperandCount),
|
|
forwardedOperands(std::move(forwardedOperands)) {}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// BranchOpInterface
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
|
|
/// successor if 'operandIndex' is within the range of 'operands', or None if
|
|
/// `operandIndex` isn't a successor operand index.
|
|
Optional<BlockArgument>
|
|
detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
|
|
unsigned operandIndex, Block *successor) {
|
|
OperandRange forwardedOperands = operands.getForwardedOperands();
|
|
// Check that the operands are valid.
|
|
if (forwardedOperands.empty())
|
|
return llvm::None;
|
|
|
|
// Check to ensure that this operand is within the range.
|
|
unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
|
|
if (operandIndex < operandsStart ||
|
|
operandIndex >= (operandsStart + forwardedOperands.size()))
|
|
return llvm::None;
|
|
|
|
// Index the successor.
|
|
unsigned argIndex =
|
|
operands.getProducedOperandCount() + operandIndex - operandsStart;
|
|
return successor->getArgument(argIndex);
|
|
}
|
|
|
|
/// Verify that the given operands match those of the given successor block.
|
|
LogicalResult
|
|
detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
|
|
const SuccessorOperands &operands) {
|
|
// Check the count.
|
|
unsigned operandCount = operands.size();
|
|
Block *destBB = op->getSuccessor(succNo);
|
|
if (operandCount != destBB->getNumArguments())
|
|
return op->emitError() << "branch has " << operandCount
|
|
<< " operands for successor #" << succNo
|
|
<< ", but target block has "
|
|
<< destBB->getNumArguments();
|
|
|
|
// Check the types.
|
|
for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
|
|
++i) {
|
|
if (!cast<BranchOpInterface>(op).areTypesCompatible(
|
|
operands[i].getType(), destBB->getArgument(i).getType()))
|
|
return op->emitError() << "type mismatch for bb argument #" << i
|
|
<< " of successor #" << succNo;
|
|
}
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// RegionBranchOpInterface
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Verify that types match along all region control flow edges originating from
|
|
/// `sourceNo` (region # if source is a region, llvm::None if source is parent
|
|
/// op). `getInputsTypesForRegion` is a function that returns the types of the
|
|
/// inputs that flow from `sourceIndex' to the given region, or llvm::None if
|
|
/// the exact type match verification is not necessary (e.g., if the Op verifies
|
|
/// the match itself).
|
|
static LogicalResult
|
|
verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo,
|
|
function_ref<Optional<TypeRange>(Optional<unsigned>)>
|
|
getInputsTypesForRegion) {
|
|
auto regionInterface = cast<RegionBranchOpInterface>(op);
|
|
|
|
SmallVector<RegionSuccessor, 2> successors;
|
|
regionInterface.getSuccessorRegions(sourceNo, successors);
|
|
|
|
for (RegionSuccessor &succ : successors) {
|
|
Optional<unsigned> succRegionNo;
|
|
if (!succ.isParent())
|
|
succRegionNo = succ.getSuccessor()->getRegionNumber();
|
|
|
|
auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & {
|
|
diag << "from ";
|
|
if (sourceNo)
|
|
diag << "Region #" << sourceNo.value();
|
|
else
|
|
diag << "parent operands";
|
|
|
|
diag << " to ";
|
|
if (succRegionNo)
|
|
diag << "Region #" << succRegionNo.value();
|
|
else
|
|
diag << "parent results";
|
|
return diag;
|
|
};
|
|
|
|
Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo);
|
|
if (!sourceTypes.has_value())
|
|
continue;
|
|
|
|
TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
|
|
if (sourceTypes->size() != succInputsTypes.size()) {
|
|
InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
|
|
return printEdgeName(diag) << ": source has " << sourceTypes->size()
|
|
<< " operands, but target successor needs "
|
|
<< succInputsTypes.size();
|
|
}
|
|
|
|
for (const auto &typesIdx :
|
|
llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
|
|
Type sourceType = std::get<0>(typesIdx.value());
|
|
Type inputType = std::get<1>(typesIdx.value());
|
|
if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
|
|
InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
|
|
return printEdgeName(diag)
|
|
<< ": source type #" << typesIdx.index() << " " << sourceType
|
|
<< " should match input type #" << typesIdx.index() << " "
|
|
<< inputType;
|
|
}
|
|
}
|
|
}
|
|
return success();
|
|
}
|
|
|
|
/// Verify that types match along control flow edges described the given op.
|
|
LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
|
|
auto regionInterface = cast<RegionBranchOpInterface>(op);
|
|
|
|
auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange {
|
|
return regionInterface.getSuccessorEntryOperands(regionNo).getTypes();
|
|
};
|
|
|
|
// Verify types along control flow edges originating from the parent.
|
|
if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent)))
|
|
return failure();
|
|
|
|
// RegionBranchOpInterface should not be implemented by Ops that do not have
|
|
// attached regions.
|
|
assert(op->getNumRegions() != 0);
|
|
|
|
auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
|
|
if (lhs.size() != rhs.size())
|
|
return false;
|
|
for (auto types : llvm::zip(lhs, rhs)) {
|
|
if (!regionInterface.areTypesCompatible(std::get<0>(types),
|
|
std::get<1>(types))) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
};
|
|
|
|
// Verify types along control flow edges originating from each region.
|
|
for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
|
|
Region ®ion = op->getRegion(regionNo);
|
|
|
|
// Since there can be multiple `ReturnLike` terminators or others
|
|
// implementing the `RegionBranchTerminatorOpInterface`, all should have the
|
|
// same operand types when passing them to the same region.
|
|
|
|
Optional<OperandRange> regionReturnOperands;
|
|
for (Block &block : region) {
|
|
Operation *terminator = block.getTerminator();
|
|
auto terminatorOperands =
|
|
getRegionBranchSuccessorOperands(terminator, regionNo);
|
|
if (!terminatorOperands)
|
|
continue;
|
|
|
|
if (!regionReturnOperands) {
|
|
regionReturnOperands = terminatorOperands;
|
|
continue;
|
|
}
|
|
|
|
// Found more than one ReturnLike terminator. Make sure the operand types
|
|
// match with the first one.
|
|
if (!areTypesCompatible(regionReturnOperands->getTypes(),
|
|
terminatorOperands->getTypes()))
|
|
return op->emitOpError("Region #")
|
|
<< regionNo
|
|
<< " operands mismatch between return-like terminators";
|
|
}
|
|
|
|
auto inputTypesFromRegion =
|
|
[&](Optional<unsigned> regionNo) -> Optional<TypeRange> {
|
|
// If there is no return-like terminator, the op itself should verify
|
|
// type consistency.
|
|
if (!regionReturnOperands)
|
|
return llvm::None;
|
|
|
|
// All successors get the same set of operand types.
|
|
return TypeRange(regionReturnOperands->getTypes());
|
|
};
|
|
|
|
if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion)))
|
|
return failure();
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Return `true` if region `r` is reachable from region `begin` according to
|
|
/// the RegionBranchOpInterface (by taking a branch).
|
|
static bool isRegionReachable(Region *begin, Region *r) {
|
|
assert(begin->getParentOp() == r->getParentOp() &&
|
|
"expected that both regions belong to the same op");
|
|
auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
|
|
SmallVector<bool> visited(op->getNumRegions(), false);
|
|
visited[begin->getRegionNumber()] = true;
|
|
|
|
// Retrieve all successors of the region and enqueue them in the worklist.
|
|
SmallVector<unsigned> worklist;
|
|
auto enqueueAllSuccessors = [&](unsigned index) {
|
|
SmallVector<RegionSuccessor> successors;
|
|
op.getSuccessorRegions(index, successors);
|
|
for (RegionSuccessor successor : successors)
|
|
if (!successor.isParent())
|
|
worklist.push_back(successor.getSuccessor()->getRegionNumber());
|
|
};
|
|
enqueueAllSuccessors(begin->getRegionNumber());
|
|
|
|
// Process all regions in the worklist via DFS.
|
|
while (!worklist.empty()) {
|
|
unsigned nextRegion = worklist.pop_back_val();
|
|
if (nextRegion == r->getRegionNumber())
|
|
return true;
|
|
if (visited[nextRegion])
|
|
continue;
|
|
visited[nextRegion] = true;
|
|
enqueueAllSuccessors(nextRegion);
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
/// Return `true` if `a` and `b` are in mutually exclusive regions.
|
|
///
|
|
/// 1. Find the first common of `a` and `b` (ancestor) that implements
|
|
/// RegionBranchOpInterface.
|
|
/// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
|
|
/// contained.
|
|
/// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
|
|
/// mutually exclusive if they are not reachable from each other as per
|
|
/// RegionBranchOpInterface::getSuccessorRegions.
|
|
bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
|
|
assert(a && "expected non-empty operation");
|
|
assert(b && "expected non-empty operation");
|
|
|
|
auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
|
|
while (branchOp) {
|
|
// Check if b is inside branchOp. (We already know that a is.)
|
|
if (!branchOp->isProperAncestor(b)) {
|
|
// Check next enclosing RegionBranchOpInterface.
|
|
branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
|
|
continue;
|
|
}
|
|
|
|
// b is contained in branchOp. Retrieve the regions in which `a` and `b`
|
|
// are contained.
|
|
Region *regionA = nullptr, *regionB = nullptr;
|
|
for (Region &r : branchOp->getRegions()) {
|
|
if (r.findAncestorOpInRegion(*a)) {
|
|
assert(!regionA && "already found a region for a");
|
|
regionA = &r;
|
|
}
|
|
if (r.findAncestorOpInRegion(*b)) {
|
|
assert(!regionB && "already found a region for b");
|
|
regionB = &r;
|
|
}
|
|
}
|
|
assert(regionA && regionB && "could not find region of op");
|
|
|
|
// `a` and `b` are in mutually exclusive regions if both regions are
|
|
// distinct and neither region is reachable from the other region.
|
|
return regionA != regionB && !isRegionReachable(regionA, regionB) &&
|
|
!isRegionReachable(regionB, regionA);
|
|
}
|
|
|
|
// Could not find a common RegionBranchOpInterface among a's and b's
|
|
// ancestors.
|
|
return false;
|
|
}
|
|
|
|
bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
|
|
Region *region = &getOperation()->getRegion(index);
|
|
return isRegionReachable(region, region);
|
|
}
|
|
|
|
void RegionBranchOpInterface::getSuccessorRegions(
|
|
Optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
|
|
unsigned numInputs = 0;
|
|
if (index) {
|
|
// If the predecessor is a region, get the number of operands from an
|
|
// exiting terminator in the region.
|
|
for (Block &block : getOperation()->getRegion(*index)) {
|
|
Operation *terminator = block.getTerminator();
|
|
if (getRegionBranchSuccessorOperands(terminator, *index)) {
|
|
numInputs = terminator->getNumOperands();
|
|
break;
|
|
}
|
|
}
|
|
} else {
|
|
// Otherwise, use the number of parent operation operands.
|
|
numInputs = getOperation()->getNumOperands();
|
|
}
|
|
SmallVector<Attribute, 2> operands(numInputs, nullptr);
|
|
getSuccessorRegions(index, operands, regions);
|
|
}
|
|
|
|
Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
|
|
while (Region *region = op->getParentRegion()) {
|
|
op = region->getParentOp();
|
|
if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
|
|
if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
|
|
return region;
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
Region *mlir::getEnclosingRepetitiveRegion(Value value) {
|
|
Region *region = value.getParentRegion();
|
|
while (region) {
|
|
Operation *op = region->getParentOp();
|
|
if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
|
|
if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
|
|
return region;
|
|
region = op->getParentRegion();
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// RegionBranchTerminatorOpInterface
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns true if the given operation is either annotated with the
|
|
/// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`.
|
|
bool mlir::isRegionReturnLike(Operation *operation) {
|
|
return dyn_cast<RegionBranchTerminatorOpInterface>(operation) ||
|
|
operation->hasTrait<OpTrait::ReturnLike>();
|
|
}
|
|
|
|
/// Returns the mutable operands that are passed to the region with the given
|
|
/// `regionIndex`. If the operation does not implement the
|
|
/// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the
|
|
/// result will be `llvm::None`. In all other cases, the resulting
|
|
/// `OperandRange` represents all operands that are passed to the specified
|
|
/// successor region. If `regionIndex` is `llvm::None`, all operands that are
|
|
/// passed to the parent operation will be returned.
|
|
Optional<MutableOperandRange>
|
|
mlir::getMutableRegionBranchSuccessorOperands(Operation *operation,
|
|
Optional<unsigned> regionIndex) {
|
|
// Try to query a RegionBranchTerminatorOpInterface to determine
|
|
// all successor operands that will be passed to the successor
|
|
// input arguments.
|
|
if (auto regionTerminatorInterface =
|
|
dyn_cast<RegionBranchTerminatorOpInterface>(operation))
|
|
return regionTerminatorInterface.getMutableSuccessorOperands(regionIndex);
|
|
|
|
// TODO: The ReturnLike trait should imply a default implementation of the
|
|
// RegionBranchTerminatorOpInterface. This would make this code significantly
|
|
// easier. Furthermore, this may even make this function obsolete.
|
|
if (operation->hasTrait<OpTrait::ReturnLike>())
|
|
return MutableOperandRange(operation);
|
|
return llvm::None;
|
|
}
|
|
|
|
/// Returns the read only operands that are passed to the region with the given
|
|
/// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more
|
|
/// information.
|
|
Optional<OperandRange>
|
|
mlir::getRegionBranchSuccessorOperands(Operation *operation,
|
|
Optional<unsigned> regionIndex) {
|
|
auto range = getMutableRegionBranchSuccessorOperands(operation, regionIndex);
|
|
return range ? Optional<OperandRange>(*range) : llvm::None;
|
|
}
|