From 6241cf132e9fb775a3d5aa7b5406f0de9beb582c Mon Sep 17 00:00:00 2001 From: River Riddle Date: Sun, 19 May 2019 17:56:32 -0700 Subject: [PATCH] Refactor the DialectConversion process to clone each function and then operate in-place, as opposed to incrementally constructing a new function. This is crucial to allowing the use of non type-conversion patterns(normal RewritePatterns) as part of the conversion process. The converter now works by inserting fake producer operations when replacing the results of an existing operation with values of a different, now legal, type. These fake operations are guaranteed to never escape the converter. -- PiperOrigin-RevId: 248969130 --- .../mlir/Transforms/DialectConversion.h | 4 +- mlir/lib/Linalg/IR/LinalgOps.cpp | 2 +- mlir/lib/Transforms/DialectConversion.cpp | 322 +++++++++++------- 3 files changed, 201 insertions(+), 127 deletions(-) diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 0790e45bfc07..48376e5cb4a8 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -38,7 +38,7 @@ class Value; // Private implementation class. namespace impl { -class FunctionConversion; +class FunctionConverter; } /// Base class for the dialect op conversion patterns. Specific conversions @@ -126,7 +126,7 @@ private: /// /// If the conversion fails, the module is not modified. class DialectConversion { - friend class impl::FunctionConversion; + friend class impl::FunctionConverter; public: virtual ~DialectConversion() = default; diff --git a/mlir/lib/Linalg/IR/LinalgOps.cpp b/mlir/lib/Linalg/IR/LinalgOps.cpp index c0d1856d901f..6a8f289d379f 100644 --- a/mlir/lib/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Linalg/IR/LinalgOps.cpp @@ -324,7 +324,7 @@ ViewOp mlir::linalg::SliceOp::getBaseViewOp() { } ViewType mlir::linalg::SliceOp::getBaseViewType() { - return getBaseViewOp().getType().cast(); + return getOperand(0)->getType().cast(); } SmallVector mlir::linalg::SliceOp::getRanges() { diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 2d718e264f30..3b163230716d 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -14,10 +14,6 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -// -// This file implements a generic pass for converting between MLIR dialects. -// -//===----------------------------------------------------------------------===// #include "mlir/Transforms/DialectConversion.h" #include "mlir/IR/BlockAndValueMapping.h" @@ -27,20 +23,89 @@ #include "mlir/Transforms/Utils.h" using namespace mlir; +using namespace mlir::impl; namespace { +/// This class provides a simple interface for generating fake producers during +/// the conversion process. These fake producers are used when replacing the +/// results of an operation with values of a new, legal, type. The producer +/// provides a definition for the remaining uses of the old value while they +/// await conversion. +struct ProducerGenerator { + ProducerGenerator(MLIRContext *ctx) + : producerOpName(kProducerName, ctx), loc(UnknownLoc::get(ctx)) {} + + /// Cleanup any generated conversion values. Returns failure if there are any + /// dangling references to a producer operation, success otherwise. + LogicalResult cleanupGeneratedOps() { + for (auto *op : producerOps) { + if (!op->use_empty()) { + auto diag = op->getContext()->emitError(loc) + << "Converter did not convert all uses of replaced value " + "with illegal type"; + for (auto *user : op->getResult(0)->getUsers()) + diag.attachNote(user->getLoc()) + << "user was not converted : " << *user; + return diag; + } + op->destroy(); + } + return success(); + } + + /// Generate a producer value for 'oldValue'. These new producers replace all + /// of the current uses of the original value, and record a mapping between + /// for replacement with the 'newValue'. + void generateAndReplace(Value *oldValue, Value *newValue, + BlockAndValueMapping &mapping) { + if (oldValue->use_empty()) + return; + + // Otherwise, generate a new producer operation for the given value type. + auto *producer = Operation::create( + loc, producerOpName, llvm::None, oldValue->getType(), llvm::None, + llvm::None, 0, false, oldValue->getContext()); + + // Replace the uses of the old value and record the mapping. + oldValue->replaceAllUsesWith(producer->getResult(0)); + mapping.map(producer->getResult(0), newValue); + producerOps.push_back(producer); + } + + /// This is an operation name for a fake operation that is inserted during the + /// conversion process. Operations of this type are guaranteed to never escape + /// the converter. + static constexpr StringLiteral kProducerName = "__mlir_conversion.producer"; + OperationName producerOpName; + + /// This is a collection of producer values that were generated during the + /// conversion process. + std::vector producerOps; + + /// An instance of the unknown location that is used when generating + /// producers. + UnknownLoc loc; +}; + /// This class implements a pattern rewriter for DialectOpConversion patterns. /// It automatically performs remapping of replaced operation values. struct DialectConversionRewriter final : public PatternRewriter { - DialectConversionRewriter(Function *fn) : PatternRewriter(fn) {} + DialectConversionRewriter(Function *fn) + : PatternRewriter(fn), tempGenerator(fn->getContext()) {} ~DialectConversionRewriter() = default; // Implement the hook for replacing an operation with new values. void replaceOp(Operation *op, ArrayRef newValues, ArrayRef valuesToRemoveIfDead) override { assert(newValues.size() == op->getNumResults()); - for (unsigned i = 0, e = newValues.size(); i < e; ++i) - mapping.map(op->getResult(i), newValues[i]); + for (unsigned i = 0, e = newValues.size(); i < e; ++i) { + Value *result = op->getResult(i); + if (result->getType() != newValues[i]->getType()) + tempGenerator.generateAndReplace(result, newValues[i], mapping); + else + result->replaceAllUsesWith(newValues[i]); + } + op->erase(); } // Implement the hook for creating operations, and make sure that newly @@ -52,16 +117,16 @@ struct DialectConversionRewriter final : public PatternRewriter { void lookupValues(Operation::operand_range operands, SmallVectorImpl &remapped) { remapped.reserve(llvm::size(operands)); - for (Value *operand : operands) { - Value *value = mapping.lookupOrNull(operand); - assert(value && "converting op before ops defining its operands"); - remapped.push_back(value); - } + for (Value *operand : operands) + remapped.push_back(mapping.lookupOrDefault(operand)); } // Mapping between values(blocks) in the original function and in the new // function. BlockAndValueMapping mapping; + + /// Utility used to create temporary producers operations. + ProducerGenerator tempGenerator; }; } // end anonymous namespace @@ -87,10 +152,7 @@ void DialectOpConversion::rewrite(Operation *op, SmallVector, 2> operandsPerDestination; unsigned firstSuccessorOperand = op->getSuccessorOperandIndex(0); for (unsigned i = 0, seen = 0, e = op->getNumSuccessors(); i < e; ++i) { - // Lookup the successor. - auto *successor = dialectRewriter.mapping.lookupOrNull(op->getSuccessor(i)); - assert(successor && "block was not remapped"); - destinations.push_back(successor); + destinations.push_back(op->getSuccessor(i)); // Lookup the successors operands. unsigned n = op->getNumSuccessorOperands(i); @@ -108,151 +170,160 @@ void DialectOpConversion::rewrite(Operation *op, namespace mlir { namespace impl { -// Implementation detail class of the DialectConversion pass. Performs +// Implementation detail class of the DialectConversion utility. Performs // function-by-function conversions by creating new functions, filling them in // with converted blocks, updating the function attributes, and replacing the // old functions with the new ones in the module. -class FunctionConversion { +class FunctionConverter { public: - // Constructs a FunctionConversion by storing the hooks. - explicit FunctionConversion(DialectConversion *conversion, Function *func, - RewritePatternMatcher &matcher) - : dialectConversion(conversion), rewriter(func), matcher(matcher) {} + // Constructs a FunctionConverter. + explicit FunctionConverter(MLIRContext *ctx, DialectConversion *conversion, + RewritePatternMatcher &matcher) + : dialectConversion(conversion), matcher(matcher) {} - // Converts the current function to the dialect using hooks defined in + // Converts the given function to the dialect using hooks defined in // `dialectConversion`. Returns the converted function or `nullptr` on error. - Function *convertFunction(); + Function *convertFunction(Function *f); // Converts the given region starting from the entry block and following the - // block successors. Returns the converted region or `nullptr` on error. + // block successors. Returns failure on error, success otherwise. template - std::unique_ptr convertRegion(MLIRContext *context, Region *region, - RegionParent *parent); + LogicalResult convertRegion(DialectConversionRewriter &rewriter, + Region ®ion, RegionParent *parent); - // Converts a block by traversing its operations sequentially, looking for - // the first pattern match and dispatching the operation conversion to - // either `convertOp` or `convertOpWithSuccessors` depending on the presence - // of successors. If there is no match, clones the operation. + // Converts a block by traversing its operations sequentially, attempting to + // match a pattern. If there is no match, recurses the operations regions if + // it has any. // // After converting operations, traverses the successor blocks unless they // have been visited already as indicated in `visitedBlocks`. - LogicalResult convertBlock(Block *block, - llvm::DenseSet &visitedBlocks); + LogicalResult convertBlock(DialectConversionRewriter &rewriter, Block *block, + DenseSet &visitedBlocks); - // Pointer to a specific dialect pass. + // Converts the type of the given block argument. Returns success if the + // argument type could be successfully converted, failure otherwise. + LogicalResult convertArgument(DialectConversionRewriter &rewriter, + BlockArgument *arg, Location loc); + + // Pointer to a specific dialect conversion info. DialectConversion *dialectConversion; - /// The writer used when rewriting operations. - DialectConversionRewriter rewriter; - - /// The matcher use when converting operations. + /// The matcher to use when converting operations. RewritePatternMatcher &matcher; }; } // end namespace impl } // end namespace mlir LogicalResult -impl::FunctionConversion::convertBlock(Block *block, - llvm::DenseSet &visitedBlocks) { +FunctionConverter::convertArgument(DialectConversionRewriter &rewriter, + BlockArgument *arg, Location loc) { + auto convertedType = dialectConversion->convertType(arg->getType()); + if (!convertedType) + return arg->getContext()->emitError(loc) + << "could not convert block argument of type : " << arg->getType(); + + // Generate a replacement value, with the new type, for this argument. + if (convertedType != arg->getType()) { + rewriter.tempGenerator.generateAndReplace(arg, arg, rewriter.mapping); + arg->setType(convertedType); + } + return success(); +} + +LogicalResult +FunctionConverter::convertBlock(DialectConversionRewriter &rewriter, + Block *block, + DenseSet &visitedBlocks) { // First, add the current block to the list of visited blocks. visitedBlocks.insert(block); - // Setup the builder to the insert to the converted block. - rewriter.setInsertionPointToStart(rewriter.mapping.lookupOrNull(block)); + + // Preserve the successors before rewriting the operations. + SmallVector successors(block->getSuccessors()); // Iterate over ops and convert them. - for (Operation &op : *block) { + for (Operation &op : llvm::make_early_inc_range(*block)) { + rewriter.setInsertionPoint(&op); if (matcher.matchAndRewrite(&op, rewriter)) continue; - // If there is no conversion provided for the op, clone the op and convert - // its regions, if any. - auto *newOp = rewriter.cloneWithoutRegions(op, rewriter.mapping); - for (int i = 0, e = op.getNumRegions(); i < e; ++i) { - auto newRegion = convertRegion(op.getContext(), &op.getRegion(i), &op); - newOp->getRegion(i).takeBody(*newRegion); - } + // If a rewrite wasn't matched, update any mapped operands in place. + for (auto &operand : op.getOpOperands()) + if (auto *newOperand = rewriter.mapping.lookupOrNull(operand.get())) + operand.set(newOperand); + + // Traverse any held regions. + for (auto ®ion : op.getRegions()) + if (!region.empty() && failed(convertRegion(rewriter, region, &op))) + return failure(); } - // Recurse to children unless they have been already visited. - for (Block *succ : block->getSuccessors()) { - if (visitedBlocks.count(succ) != 0) + // Recurse to children that haven't been visited. + for (Block *succ : successors) { + if (visitedBlocks.count(succ)) continue; - if (failed(convertBlock(succ, visitedBlocks))) + if (failed(convertBlock(rewriter, succ, visitedBlocks))) return failure(); } return success(); } template -std::unique_ptr -impl::FunctionConversion::convertRegion(MLIRContext *context, Region *region, - RegionParent *parent) { - assert(region && "expected a region"); - auto newRegion = llvm::make_unique(parent); - if (region->empty()) - return newRegion; +LogicalResult +FunctionConverter::convertRegion(DialectConversionRewriter &rewriter, + Region ®ion, RegionParent *parent) { + assert(!region.empty() && "expected non-empty region"); - auto emitError = [context](llvm::Twine f) -> std::unique_ptr { - context->emitError(UnknownLoc::get(context), f.str()); - return nullptr; - }; - - // Create new blocks and convert their arguments. - for (Block &block : *region) { - auto *newBlock = new Block; - newRegion->push_back(newBlock); - rewriter.mapping.map(&block, newBlock); - for (auto *arg : block.getArguments()) { - auto convertedType = dialectConversion->convertType(arg->getType()); - if (!convertedType) - return emitError("could not convert block argument type"); - newBlock->addArgument(convertedType); - rewriter.mapping.map(arg, *newBlock->args_rbegin()); - } - } + // Create the arguments of each of the blocks in the region. + for (Block &block : region) + for (auto *arg : block.getArguments()) + if (failed(convertArgument(rewriter, arg, parent->getLoc()))) + return failure(); // Start a DFS-order traversal of the CFG to make sure defs are converted // before uses in dominated blocks. llvm::DenseSet visitedBlocks; - if (failed(convertBlock(®ion->front(), visitedBlocks))) - return nullptr; + if (failed(convertBlock(rewriter, ®ion.front(), visitedBlocks))) + return failure(); // If some blocks are not reachable through successor chains, they should have // been removed by the DCE before this. - if (visitedBlocks.size() != std::distance(region->begin(), region->end())) - return emitError("unreachable blocks were not converted"); - return newRegion; + if (visitedBlocks.size() != std::distance(region.begin(), region.end())) + return parent->emitError("unreachable blocks were not converted"); + return success(); } -Function *impl::FunctionConversion::convertFunction() { - Function *f = rewriter.getFunction(); - MLIRContext *context = f->getContext(); - auto emitError = [context](llvm::Twine f) -> Function * { - context->emitError(UnknownLoc::get(context), f.str()); - return nullptr; - }; - - // Create a new function with argument types and result types converted. Wrap - // it into a unique_ptr to make sure it is cleaned up in case of error. +Function *FunctionConverter::convertFunction(Function *f) { + // Convert the function type using the dialect converter. SmallVector newFunctionArgAttrs; Type newFunctionType = dialectConversion->convertFunctionSignatureType( f->getType(), f->getAllArgAttrs(), newFunctionArgAttrs); if (!newFunctionType) - return emitError("could not convert function type"); - auto newFunction = llvm::make_unique( - f->getLoc(), f->getName().strref(), newFunctionType.cast(), - f->getAttrs(), newFunctionArgAttrs); + return f->emitError("could not convert function type"), nullptr; - // Return early if the function is external. - if (f->isExternal()) - return newFunction.release(); + // Create a new function using the mapped function type and arg attributes. + auto *newFunc = new Function(f->getLoc(), f->getName().strref(), + newFunctionType.cast(), + f->getAttrs(), newFunctionArgAttrs); + f->getModule()->getFunctions().push_back(newFunc); - auto newBody = convertRegion(context, &f->getBody(), f); - if (!newBody) - return emitError("could not convert function body"); - newFunction->getBody().takeBody(*newBody); + // If this is not an external function, we need to convert the body. + if (!f->isExternal()) { + DialectConversionRewriter rewriter(f); + f->getBody().cloneInto(&newFunc->getBody(), rewriter.mapping, + f->getContext()); + rewriter.mapping.clear(); + if (failed(convertRegion(rewriter, newFunc->getBody(), &*newFunc))) { + f->getModule()->getFunctions().pop_back(); + return nullptr; + } - return newFunction.release(); + // Cleanup any temp producer operations that were generated by the rewriter. + if (failed(rewriter.tempGenerator.cleanupGeneratedOps())) { + f->getModule()->getFunctions().pop_back(); + return nullptr; + } + } + return newFunc; } // Create a function type with arguments and results converted, and argument @@ -297,38 +368,41 @@ LogicalResult DialectConversion::convert(Module *module) { initConverters(patterns, context); RewritePatternMatcher matcher(std::move(patterns)); - // Convert the functions but don't add them to the module yet to avoid - // converted functions to be converted again. SmallVector originalFuncs, convertedFuncs; DenseMap functionAttrRemapping; originalFuncs.reserve(module->getFunctions().size()); for (auto &func : *module) originalFuncs.push_back(&func); - convertedFuncs.reserve(module->getFunctions().size()); - for (auto *func : originalFuncs) { - impl::FunctionConversion converter(this, func, matcher); - Function *converted = converter.convertFunction(); - if (!converted) - return failure(); + convertedFuncs.reserve(originalFuncs.size()); + // Convert each function. + FunctionConverter converter(context, this, matcher); + for (auto *func : originalFuncs) { + Function *converted = converter.convertFunction(func); + if (!converted) { + // Make sure to erase any previously converted functions. + while (!convertedFuncs.empty()) + convertedFuncs.pop_back_val()->erase(); + return failure(); + } + + convertedFuncs.push_back(converted); auto origFuncAttr = FunctionAttr::get(func); auto convertedFuncAttr = FunctionAttr::get(converted); - convertedFuncs.push_back(converted); functionAttrRemapping.insert({origFuncAttr, convertedFuncAttr}); } - // Remap function attributes in the converted functions (they are not yet in - // the module). Original functions will disappear anyway so there is no - // need to remap attributes in them. + // Remap function attributes in the converted functions. Original functions + // will disappear anyway so there is no need to remap attributes in them. for (const auto &funcPair : functionAttrRemapping) remapFunctionAttrs(*funcPair.getSecond().getValue(), functionAttrRemapping); - // Remove original functions from the module, then insert converted - // functions. The order is important to avoid name collisions. - for (auto &func : originalFuncs) - func->erase(); - for (auto *func : convertedFuncs) - module->getFunctions().push_back(func); + // Remove the original functions from the module and update the names of the + // converted functions. + for (unsigned i = 0, e = originalFuncs.size(); i != e; ++i) { + convertedFuncs[i]->takeName(*originalFuncs[i]); + originalFuncs[i]->erase(); + } return success(); }