forked from OSchip/llvm-project
Refactor the DialectConversion process to clone each function and then operate in-place, as opposed to incrementally constructing a new function. This is crucial to allowing the use of non type-conversion patterns(normal RewritePatterns) as part of the conversion process.
The converter now works by inserting fake producer operations when replacing the results of an existing operation with values of a different, now legal, type. These fake operations are guaranteed to never escape the converter. -- PiperOrigin-RevId: 248969130
This commit is contained in:
parent
a23b728034
commit
6241cf132e
|
@ -38,7 +38,7 @@ class Value;
|
|||
|
||||
// Private implementation class.
|
||||
namespace impl {
|
||||
class FunctionConversion;
|
||||
class FunctionConverter;
|
||||
}
|
||||
|
||||
/// Base class for the dialect op conversion patterns. Specific conversions
|
||||
|
@ -126,7 +126,7 @@ private:
|
|||
///
|
||||
/// If the conversion fails, the module is not modified.
|
||||
class DialectConversion {
|
||||
friend class impl::FunctionConversion;
|
||||
friend class impl::FunctionConverter;
|
||||
|
||||
public:
|
||||
virtual ~DialectConversion() = default;
|
||||
|
|
|
@ -324,7 +324,7 @@ ViewOp mlir::linalg::SliceOp::getBaseViewOp() {
|
|||
}
|
||||
|
||||
ViewType mlir::linalg::SliceOp::getBaseViewType() {
|
||||
return getBaseViewOp().getType().cast<ViewType>();
|
||||
return getOperand(0)->getType().cast<ViewType>();
|
||||
}
|
||||
|
||||
SmallVector<Value *, 8> mlir::linalg::SliceOp::getRanges() {
|
||||
|
|
|
@ -14,10 +14,6 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements a generic pass for converting between MLIR dialects.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
|
@ -27,20 +23,89 @@
|
|||
#include "mlir/Transforms/Utils.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::impl;
|
||||
|
||||
namespace {
|
||||
/// This class provides a simple interface for generating fake producers during
|
||||
/// the conversion process. These fake producers are used when replacing the
|
||||
/// results of an operation with values of a new, legal, type. The producer
|
||||
/// provides a definition for the remaining uses of the old value while they
|
||||
/// await conversion.
|
||||
struct ProducerGenerator {
|
||||
ProducerGenerator(MLIRContext *ctx)
|
||||
: producerOpName(kProducerName, ctx), loc(UnknownLoc::get(ctx)) {}
|
||||
|
||||
/// Cleanup any generated conversion values. Returns failure if there are any
|
||||
/// dangling references to a producer operation, success otherwise.
|
||||
LogicalResult cleanupGeneratedOps() {
|
||||
for (auto *op : producerOps) {
|
||||
if (!op->use_empty()) {
|
||||
auto diag = op->getContext()->emitError(loc)
|
||||
<< "Converter did not convert all uses of replaced value "
|
||||
"with illegal type";
|
||||
for (auto *user : op->getResult(0)->getUsers())
|
||||
diag.attachNote(user->getLoc())
|
||||
<< "user was not converted : " << *user;
|
||||
return diag;
|
||||
}
|
||||
op->destroy();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Generate a producer value for 'oldValue'. These new producers replace all
|
||||
/// of the current uses of the original value, and record a mapping between
|
||||
/// for replacement with the 'newValue'.
|
||||
void generateAndReplace(Value *oldValue, Value *newValue,
|
||||
BlockAndValueMapping &mapping) {
|
||||
if (oldValue->use_empty())
|
||||
return;
|
||||
|
||||
// Otherwise, generate a new producer operation for the given value type.
|
||||
auto *producer = Operation::create(
|
||||
loc, producerOpName, llvm::None, oldValue->getType(), llvm::None,
|
||||
llvm::None, 0, false, oldValue->getContext());
|
||||
|
||||
// Replace the uses of the old value and record the mapping.
|
||||
oldValue->replaceAllUsesWith(producer->getResult(0));
|
||||
mapping.map(producer->getResult(0), newValue);
|
||||
producerOps.push_back(producer);
|
||||
}
|
||||
|
||||
/// This is an operation name for a fake operation that is inserted during the
|
||||
/// conversion process. Operations of this type are guaranteed to never escape
|
||||
/// the converter.
|
||||
static constexpr StringLiteral kProducerName = "__mlir_conversion.producer";
|
||||
OperationName producerOpName;
|
||||
|
||||
/// This is a collection of producer values that were generated during the
|
||||
/// conversion process.
|
||||
std::vector<Operation *> producerOps;
|
||||
|
||||
/// An instance of the unknown location that is used when generating
|
||||
/// producers.
|
||||
UnknownLoc loc;
|
||||
};
|
||||
|
||||
/// This class implements a pattern rewriter for DialectOpConversion patterns.
|
||||
/// It automatically performs remapping of replaced operation values.
|
||||
struct DialectConversionRewriter final : public PatternRewriter {
|
||||
DialectConversionRewriter(Function *fn) : PatternRewriter(fn) {}
|
||||
DialectConversionRewriter(Function *fn)
|
||||
: PatternRewriter(fn), tempGenerator(fn->getContext()) {}
|
||||
~DialectConversionRewriter() = default;
|
||||
|
||||
// Implement the hook for replacing an operation with new values.
|
||||
void replaceOp(Operation *op, ArrayRef<Value *> newValues,
|
||||
ArrayRef<Value *> valuesToRemoveIfDead) override {
|
||||
assert(newValues.size() == op->getNumResults());
|
||||
for (unsigned i = 0, e = newValues.size(); i < e; ++i)
|
||||
mapping.map(op->getResult(i), newValues[i]);
|
||||
for (unsigned i = 0, e = newValues.size(); i < e; ++i) {
|
||||
Value *result = op->getResult(i);
|
||||
if (result->getType() != newValues[i]->getType())
|
||||
tempGenerator.generateAndReplace(result, newValues[i], mapping);
|
||||
else
|
||||
result->replaceAllUsesWith(newValues[i]);
|
||||
}
|
||||
op->erase();
|
||||
}
|
||||
|
||||
// Implement the hook for creating operations, and make sure that newly
|
||||
|
@ -52,16 +117,16 @@ struct DialectConversionRewriter final : public PatternRewriter {
|
|||
void lookupValues(Operation::operand_range operands,
|
||||
SmallVectorImpl<Value *> &remapped) {
|
||||
remapped.reserve(llvm::size(operands));
|
||||
for (Value *operand : operands) {
|
||||
Value *value = mapping.lookupOrNull(operand);
|
||||
assert(value && "converting op before ops defining its operands");
|
||||
remapped.push_back(value);
|
||||
}
|
||||
for (Value *operand : operands)
|
||||
remapped.push_back(mapping.lookupOrDefault(operand));
|
||||
}
|
||||
|
||||
// Mapping between values(blocks) in the original function and in the new
|
||||
// function.
|
||||
BlockAndValueMapping mapping;
|
||||
|
||||
/// Utility used to create temporary producers operations.
|
||||
ProducerGenerator tempGenerator;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
|
@ -87,10 +152,7 @@ void DialectOpConversion::rewrite(Operation *op,
|
|||
SmallVector<ArrayRef<Value *>, 2> operandsPerDestination;
|
||||
unsigned firstSuccessorOperand = op->getSuccessorOperandIndex(0);
|
||||
for (unsigned i = 0, seen = 0, e = op->getNumSuccessors(); i < e; ++i) {
|
||||
// Lookup the successor.
|
||||
auto *successor = dialectRewriter.mapping.lookupOrNull(op->getSuccessor(i));
|
||||
assert(successor && "block was not remapped");
|
||||
destinations.push_back(successor);
|
||||
destinations.push_back(op->getSuccessor(i));
|
||||
|
||||
// Lookup the successors operands.
|
||||
unsigned n = op->getNumSuccessorOperands(i);
|
||||
|
@ -108,151 +170,160 @@ void DialectOpConversion::rewrite(Operation *op,
|
|||
|
||||
namespace mlir {
|
||||
namespace impl {
|
||||
// Implementation detail class of the DialectConversion pass. Performs
|
||||
// Implementation detail class of the DialectConversion utility. Performs
|
||||
// function-by-function conversions by creating new functions, filling them in
|
||||
// with converted blocks, updating the function attributes, and replacing the
|
||||
// old functions with the new ones in the module.
|
||||
class FunctionConversion {
|
||||
class FunctionConverter {
|
||||
public:
|
||||
// Constructs a FunctionConversion by storing the hooks.
|
||||
explicit FunctionConversion(DialectConversion *conversion, Function *func,
|
||||
RewritePatternMatcher &matcher)
|
||||
: dialectConversion(conversion), rewriter(func), matcher(matcher) {}
|
||||
// Constructs a FunctionConverter.
|
||||
explicit FunctionConverter(MLIRContext *ctx, DialectConversion *conversion,
|
||||
RewritePatternMatcher &matcher)
|
||||
: dialectConversion(conversion), matcher(matcher) {}
|
||||
|
||||
// Converts the current function to the dialect using hooks defined in
|
||||
// Converts the given function to the dialect using hooks defined in
|
||||
// `dialectConversion`. Returns the converted function or `nullptr` on error.
|
||||
Function *convertFunction();
|
||||
Function *convertFunction(Function *f);
|
||||
|
||||
// Converts the given region starting from the entry block and following the
|
||||
// block successors. Returns the converted region or `nullptr` on error.
|
||||
// block successors. Returns failure on error, success otherwise.
|
||||
template <typename RegionParent>
|
||||
std::unique_ptr<Region> convertRegion(MLIRContext *context, Region *region,
|
||||
RegionParent *parent);
|
||||
LogicalResult convertRegion(DialectConversionRewriter &rewriter,
|
||||
Region ®ion, RegionParent *parent);
|
||||
|
||||
// Converts a block by traversing its operations sequentially, looking for
|
||||
// the first pattern match and dispatching the operation conversion to
|
||||
// either `convertOp` or `convertOpWithSuccessors` depending on the presence
|
||||
// of successors. If there is no match, clones the operation.
|
||||
// Converts a block by traversing its operations sequentially, attempting to
|
||||
// match a pattern. If there is no match, recurses the operations regions if
|
||||
// it has any.
|
||||
//
|
||||
// After converting operations, traverses the successor blocks unless they
|
||||
// have been visited already as indicated in `visitedBlocks`.
|
||||
LogicalResult convertBlock(Block *block,
|
||||
llvm::DenseSet<Block *> &visitedBlocks);
|
||||
LogicalResult convertBlock(DialectConversionRewriter &rewriter, Block *block,
|
||||
DenseSet<Block *> &visitedBlocks);
|
||||
|
||||
// Pointer to a specific dialect pass.
|
||||
// Converts the type of the given block argument. Returns success if the
|
||||
// argument type could be successfully converted, failure otherwise.
|
||||
LogicalResult convertArgument(DialectConversionRewriter &rewriter,
|
||||
BlockArgument *arg, Location loc);
|
||||
|
||||
// Pointer to a specific dialect conversion info.
|
||||
DialectConversion *dialectConversion;
|
||||
|
||||
/// The writer used when rewriting operations.
|
||||
DialectConversionRewriter rewriter;
|
||||
|
||||
/// The matcher use when converting operations.
|
||||
/// The matcher to use when converting operations.
|
||||
RewritePatternMatcher &matcher;
|
||||
};
|
||||
} // end namespace impl
|
||||
} // end namespace mlir
|
||||
|
||||
LogicalResult
|
||||
impl::FunctionConversion::convertBlock(Block *block,
|
||||
llvm::DenseSet<Block *> &visitedBlocks) {
|
||||
FunctionConverter::convertArgument(DialectConversionRewriter &rewriter,
|
||||
BlockArgument *arg, Location loc) {
|
||||
auto convertedType = dialectConversion->convertType(arg->getType());
|
||||
if (!convertedType)
|
||||
return arg->getContext()->emitError(loc)
|
||||
<< "could not convert block argument of type : " << arg->getType();
|
||||
|
||||
// Generate a replacement value, with the new type, for this argument.
|
||||
if (convertedType != arg->getType()) {
|
||||
rewriter.tempGenerator.generateAndReplace(arg, arg, rewriter.mapping);
|
||||
arg->setType(convertedType);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
FunctionConverter::convertBlock(DialectConversionRewriter &rewriter,
|
||||
Block *block,
|
||||
DenseSet<Block *> &visitedBlocks) {
|
||||
// First, add the current block to the list of visited blocks.
|
||||
visitedBlocks.insert(block);
|
||||
// Setup the builder to the insert to the converted block.
|
||||
rewriter.setInsertionPointToStart(rewriter.mapping.lookupOrNull(block));
|
||||
|
||||
// Preserve the successors before rewriting the operations.
|
||||
SmallVector<Block *, 4> successors(block->getSuccessors());
|
||||
|
||||
// Iterate over ops and convert them.
|
||||
for (Operation &op : *block) {
|
||||
for (Operation &op : llvm::make_early_inc_range(*block)) {
|
||||
rewriter.setInsertionPoint(&op);
|
||||
if (matcher.matchAndRewrite(&op, rewriter))
|
||||
continue;
|
||||
|
||||
// If there is no conversion provided for the op, clone the op and convert
|
||||
// its regions, if any.
|
||||
auto *newOp = rewriter.cloneWithoutRegions(op, rewriter.mapping);
|
||||
for (int i = 0, e = op.getNumRegions(); i < e; ++i) {
|
||||
auto newRegion = convertRegion(op.getContext(), &op.getRegion(i), &op);
|
||||
newOp->getRegion(i).takeBody(*newRegion);
|
||||
}
|
||||
// If a rewrite wasn't matched, update any mapped operands in place.
|
||||
for (auto &operand : op.getOpOperands())
|
||||
if (auto *newOperand = rewriter.mapping.lookupOrNull(operand.get()))
|
||||
operand.set(newOperand);
|
||||
|
||||
// Traverse any held regions.
|
||||
for (auto ®ion : op.getRegions())
|
||||
if (!region.empty() && failed(convertRegion(rewriter, region, &op)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Recurse to children unless they have been already visited.
|
||||
for (Block *succ : block->getSuccessors()) {
|
||||
if (visitedBlocks.count(succ) != 0)
|
||||
// Recurse to children that haven't been visited.
|
||||
for (Block *succ : successors) {
|
||||
if (visitedBlocks.count(succ))
|
||||
continue;
|
||||
if (failed(convertBlock(succ, visitedBlocks)))
|
||||
if (failed(convertBlock(rewriter, succ, visitedBlocks)))
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename RegionParent>
|
||||
std::unique_ptr<Region>
|
||||
impl::FunctionConversion::convertRegion(MLIRContext *context, Region *region,
|
||||
RegionParent *parent) {
|
||||
assert(region && "expected a region");
|
||||
auto newRegion = llvm::make_unique<Region>(parent);
|
||||
if (region->empty())
|
||||
return newRegion;
|
||||
LogicalResult
|
||||
FunctionConverter::convertRegion(DialectConversionRewriter &rewriter,
|
||||
Region ®ion, RegionParent *parent) {
|
||||
assert(!region.empty() && "expected non-empty region");
|
||||
|
||||
auto emitError = [context](llvm::Twine f) -> std::unique_ptr<Region> {
|
||||
context->emitError(UnknownLoc::get(context), f.str());
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
// Create new blocks and convert their arguments.
|
||||
for (Block &block : *region) {
|
||||
auto *newBlock = new Block;
|
||||
newRegion->push_back(newBlock);
|
||||
rewriter.mapping.map(&block, newBlock);
|
||||
for (auto *arg : block.getArguments()) {
|
||||
auto convertedType = dialectConversion->convertType(arg->getType());
|
||||
if (!convertedType)
|
||||
return emitError("could not convert block argument type");
|
||||
newBlock->addArgument(convertedType);
|
||||
rewriter.mapping.map(arg, *newBlock->args_rbegin());
|
||||
}
|
||||
}
|
||||
// Create the arguments of each of the blocks in the region.
|
||||
for (Block &block : region)
|
||||
for (auto *arg : block.getArguments())
|
||||
if (failed(convertArgument(rewriter, arg, parent->getLoc())))
|
||||
return failure();
|
||||
|
||||
// Start a DFS-order traversal of the CFG to make sure defs are converted
|
||||
// before uses in dominated blocks.
|
||||
llvm::DenseSet<Block *> visitedBlocks;
|
||||
if (failed(convertBlock(®ion->front(), visitedBlocks)))
|
||||
return nullptr;
|
||||
if (failed(convertBlock(rewriter, ®ion.front(), visitedBlocks)))
|
||||
return failure();
|
||||
|
||||
// If some blocks are not reachable through successor chains, they should have
|
||||
// been removed by the DCE before this.
|
||||
if (visitedBlocks.size() != std::distance(region->begin(), region->end()))
|
||||
return emitError("unreachable blocks were not converted");
|
||||
return newRegion;
|
||||
if (visitedBlocks.size() != std::distance(region.begin(), region.end()))
|
||||
return parent->emitError("unreachable blocks were not converted");
|
||||
return success();
|
||||
}
|
||||
|
||||
Function *impl::FunctionConversion::convertFunction() {
|
||||
Function *f = rewriter.getFunction();
|
||||
MLIRContext *context = f->getContext();
|
||||
auto emitError = [context](llvm::Twine f) -> Function * {
|
||||
context->emitError(UnknownLoc::get(context), f.str());
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
// Create a new function with argument types and result types converted. Wrap
|
||||
// it into a unique_ptr to make sure it is cleaned up in case of error.
|
||||
Function *FunctionConverter::convertFunction(Function *f) {
|
||||
// Convert the function type using the dialect converter.
|
||||
SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
|
||||
Type newFunctionType = dialectConversion->convertFunctionSignatureType(
|
||||
f->getType(), f->getAllArgAttrs(), newFunctionArgAttrs);
|
||||
if (!newFunctionType)
|
||||
return emitError("could not convert function type");
|
||||
auto newFunction = llvm::make_unique<Function>(
|
||||
f->getLoc(), f->getName().strref(), newFunctionType.cast<FunctionType>(),
|
||||
f->getAttrs(), newFunctionArgAttrs);
|
||||
return f->emitError("could not convert function type"), nullptr;
|
||||
|
||||
// Return early if the function is external.
|
||||
if (f->isExternal())
|
||||
return newFunction.release();
|
||||
// Create a new function using the mapped function type and arg attributes.
|
||||
auto *newFunc = new Function(f->getLoc(), f->getName().strref(),
|
||||
newFunctionType.cast<FunctionType>(),
|
||||
f->getAttrs(), newFunctionArgAttrs);
|
||||
f->getModule()->getFunctions().push_back(newFunc);
|
||||
|
||||
auto newBody = convertRegion(context, &f->getBody(), f);
|
||||
if (!newBody)
|
||||
return emitError("could not convert function body");
|
||||
newFunction->getBody().takeBody(*newBody);
|
||||
// If this is not an external function, we need to convert the body.
|
||||
if (!f->isExternal()) {
|
||||
DialectConversionRewriter rewriter(f);
|
||||
f->getBody().cloneInto(&newFunc->getBody(), rewriter.mapping,
|
||||
f->getContext());
|
||||
rewriter.mapping.clear();
|
||||
if (failed(convertRegion(rewriter, newFunc->getBody(), &*newFunc))) {
|
||||
f->getModule()->getFunctions().pop_back();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return newFunction.release();
|
||||
// Cleanup any temp producer operations that were generated by the rewriter.
|
||||
if (failed(rewriter.tempGenerator.cleanupGeneratedOps())) {
|
||||
f->getModule()->getFunctions().pop_back();
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return newFunc;
|
||||
}
|
||||
|
||||
// Create a function type with arguments and results converted, and argument
|
||||
|
@ -297,38 +368,41 @@ LogicalResult DialectConversion::convert(Module *module) {
|
|||
initConverters(patterns, context);
|
||||
RewritePatternMatcher matcher(std::move(patterns));
|
||||
|
||||
// Convert the functions but don't add them to the module yet to avoid
|
||||
// converted functions to be converted again.
|
||||
SmallVector<Function *, 0> originalFuncs, convertedFuncs;
|
||||
DenseMap<Attribute, FunctionAttr> functionAttrRemapping;
|
||||
originalFuncs.reserve(module->getFunctions().size());
|
||||
for (auto &func : *module)
|
||||
originalFuncs.push_back(&func);
|
||||
convertedFuncs.reserve(module->getFunctions().size());
|
||||
for (auto *func : originalFuncs) {
|
||||
impl::FunctionConversion converter(this, func, matcher);
|
||||
Function *converted = converter.convertFunction();
|
||||
if (!converted)
|
||||
return failure();
|
||||
convertedFuncs.reserve(originalFuncs.size());
|
||||
|
||||
// Convert each function.
|
||||
FunctionConverter converter(context, this, matcher);
|
||||
for (auto *func : originalFuncs) {
|
||||
Function *converted = converter.convertFunction(func);
|
||||
if (!converted) {
|
||||
// Make sure to erase any previously converted functions.
|
||||
while (!convertedFuncs.empty())
|
||||
convertedFuncs.pop_back_val()->erase();
|
||||
return failure();
|
||||
}
|
||||
|
||||
convertedFuncs.push_back(converted);
|
||||
auto origFuncAttr = FunctionAttr::get(func);
|
||||
auto convertedFuncAttr = FunctionAttr::get(converted);
|
||||
convertedFuncs.push_back(converted);
|
||||
functionAttrRemapping.insert({origFuncAttr, convertedFuncAttr});
|
||||
}
|
||||
|
||||
// Remap function attributes in the converted functions (they are not yet in
|
||||
// the module). Original functions will disappear anyway so there is no
|
||||
// need to remap attributes in them.
|
||||
// Remap function attributes in the converted functions. Original functions
|
||||
// will disappear anyway so there is no need to remap attributes in them.
|
||||
for (const auto &funcPair : functionAttrRemapping)
|
||||
remapFunctionAttrs(*funcPair.getSecond().getValue(), functionAttrRemapping);
|
||||
|
||||
// Remove original functions from the module, then insert converted
|
||||
// functions. The order is important to avoid name collisions.
|
||||
for (auto &func : originalFuncs)
|
||||
func->erase();
|
||||
for (auto *func : convertedFuncs)
|
||||
module->getFunctions().push_back(func);
|
||||
// Remove the original functions from the module and update the names of the
|
||||
// converted functions.
|
||||
for (unsigned i = 0, e = originalFuncs.size(); i != e; ++i) {
|
||||
convertedFuncs[i]->takeName(*originalFuncs[i]);
|
||||
originalFuncs[i]->erase();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue