Use Status instead of bool in DialectConversion.

PiperOrigin-RevId: 237339277
This commit is contained in:
River Riddle 2019-03-07 15:33:48 -08:00 committed by jpienaar
parent f427bddd06
commit 10ddae6d88
3 changed files with 41 additions and 43 deletions

View File

@ -23,8 +23,8 @@
#define MLIR_TRANSFORMS_DIALECTCONVERSION_H_ #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "mlir/Support/Status.h"
namespace mlir { namespace mlir {
@ -152,10 +152,9 @@ class DialectConversion {
public: public:
virtual ~DialectConversion() = default; virtual ~DialectConversion() = default;
/// Run the converter on the provided module. This function returns /// Run the converter on the provided module.
/// true if the module was unsuccessfully converted. Otherwise, it returns LLVM_NODISCARD
/// false for success. Status convert(Module *m);
bool convert(Module *m);
protected: protected:
/// Derived classes must implement this hook to produce a set of conversion /// Derived classes must implement this hook to produce a set of conversion
@ -168,7 +167,7 @@ protected:
/// block or function argument types or function result types. If the target /// block or function argument types or function result types. If the target
/// dialect has support for custom first-class function types, convertType /// dialect has support for custom first-class function types, convertType
/// should create those types for arguments of MLIR function type. It can be /// should create those types for arguments of MLIR function type. It can be
/// used for values (constant, operands, resutls) of function type but not for /// used for values (constant, operands, results) of function type but not for
/// the function signatures. For the latter, convertFunctionSignatureType is /// the function signatures. For the latter, convertFunctionSignatureType is
/// used instead. /// used instead.
/// ///

View File

@ -1078,7 +1078,7 @@ public:
void runOnModule() override { void runOnModule() override {
Module *m = &getModule(); Module *m = &getModule();
uniqueSuccessorsWithArguments(m); uniqueSuccessorsWithArguments(m);
if (DialectConversion::convert(m)) if (failed(DialectConversion::convert(m)))
signalPassFailure(); signalPassFailure();
} }

View File

@ -40,7 +40,7 @@ public:
// conversion patterns and to convert function and block argument types. // conversion patterns and to convert function and block argument types.
// Converts the `module` in-place by replacing all existing functions with the // Converts the `module` in-place by replacing all existing functions with the
// converted ones. // converted ones.
static bool convert(DialectConversion *conversion, Module *module); static Status convert(DialectConversion *conversion, Module *module);
private: private:
// Constructs a FunctionConversion by storing the hooks. // Constructs a FunctionConversion by storing the hooks.
@ -61,14 +61,14 @@ private:
// from `valueRemapping` and the converted blocks from `blockRemapping`, and // from `valueRemapping` and the converted blocks from `blockRemapping`, and
// passes them to `converter->rewriteTerminator` function defined in the // passes them to `converter->rewriteTerminator` function defined in the
// pattern, together with `builder`. // pattern, together with `builder`.
bool convertOpWithSuccessors(DialectOpConversion *converter, Instruction *op, Status convertOpWithSuccessors(DialectOpConversion *converter,
FuncBuilder &builder); Instruction *op, FuncBuilder &builder);
// Converts an operation without successors. Extracts the converted operands // Converts an operation without successors. Extracts the converted operands
// from `valueRemapping` and passes them to the `converter->rewrite` function // from `valueRemapping` and passes them to the `converter->rewrite` function
// defined in the pattern, together with `builder`. // defined in the pattern, together with `builder`.
bool convertOp(DialectOpConversion *converter, Instruction *op, Status convertOp(DialectOpConversion *converter, Instruction *op,
FuncBuilder &builder); FuncBuilder &builder);
// Converts a block by traversing its instructions sequentially, looking for // Converts a block by traversing its instructions sequentially, looking for
// the first pattern match and dispatching the instruction conversion to // the first pattern match and dispatching the instruction conversion to
@ -77,10 +77,8 @@ private:
// //
// After converting operations, traverses the successor blocks unless they // After converting operations, traverses the successor blocks unless they
// have been visited already as indicated in `visitedBlocks`. // have been visited already as indicated in `visitedBlocks`.
// Status convertBlock(Block *block, FuncBuilder &builder,
// Return `true` on error. llvm::DenseSet<Block *> &visitedBlocks);
bool convertBlock(Block *block, FuncBuilder &builder,
llvm::DenseSet<Block *> &visitedBlocks);
// Converts the module as follows. // Converts the module as follows.
// 1. Call `convertFunction` on each function of the module and collect the // 1. Call `convertFunction` on each function of the module and collect the
@ -88,7 +86,7 @@ private:
// 2. Remap all function attributes in the new functions to point to the new // 2. Remap all function attributes in the new functions to point to the new
// functions instead of the old ones. // functions instead of the old ones.
// 3. Replace old functions with the new in the module. // 3. Replace old functions with the new in the module.
bool run(Module *m); Status run(Module *m);
// Pointer to a specific dialect pass. // Pointer to a specific dialect pass.
DialectConversion *dialectConversion; DialectConversion *dialectConversion;
@ -116,7 +114,7 @@ SmallVector<Value *, 4> impl::FunctionConversion::lookupValues(
return remapped; return remapped;
} }
bool impl::FunctionConversion::convertOpWithSuccessors( Status impl::FunctionConversion::convertOpWithSuccessors(
DialectOpConversion *converter, Instruction *op, FuncBuilder &builder) { DialectOpConversion *converter, Instruction *op, FuncBuilder &builder) {
SmallVector<Block *, 2> destinations; SmallVector<Block *, 2> destinations;
destinations.reserve(op->getNumSuccessors()); destinations.reserve(op->getNumSuccessors());
@ -144,28 +142,29 @@ bool impl::FunctionConversion::convertOpWithSuccessors(
llvm::makeArrayRef(operands.data(), llvm::makeArrayRef(operands.data(),
operands.data() + firstSuccessorOperand), operands.data() + firstSuccessorOperand),
destinations, operandsPerDestination, builder); destinations, operandsPerDestination, builder);
return false; return Status::success();
} }
bool impl::FunctionConversion::convertOp(DialectOpConversion *converter, Status impl::FunctionConversion::convertOp(DialectOpConversion *converter,
Instruction *op, Instruction *op,
FuncBuilder &builder) { FuncBuilder &builder) {
auto operands = lookupValues(op->getOperands()); auto operands = lookupValues(op->getOperands());
assert((!operands.empty() || op->getNumOperands() == 0) && assert((!operands.empty() || op->getNumOperands() == 0) &&
"converting op before ops defining its operands"); "converting op before ops defining its operands");
auto results = converter->rewrite(op, operands, builder); auto results = converter->rewrite(op, operands, builder);
if (results.size() != op->getNumResults()) if (results.size() != op->getNumResults())
return op->emitError("rewriting produced a different number of results"); return (op->emitError("rewriting produced a different number of results"),
Status::failure());
for (unsigned i = 0, e = results.size(); i < e; ++i) for (unsigned i = 0, e = results.size(); i < e; ++i)
mapping.map(op->getResult(i), results[i]); mapping.map(op->getResult(i), results[i]);
return false; return Status::success();
} }
bool impl::FunctionConversion::convertBlock( Status
Block *block, FuncBuilder &builder, impl::FunctionConversion::convertBlock(Block *block, FuncBuilder &builder,
llvm::DenseSet<Block *> &visitedBlocks) { llvm::DenseSet<Block *> &visitedBlocks) {
// First, add the current block to the list of visited blocks. // First, add the current block to the list of visited blocks.
visitedBlocks.insert(block); visitedBlocks.insert(block);
// Setup the builder to the insert to the converted block. // Setup the builder to the insert to the converted block.
@ -175,7 +174,7 @@ bool impl::FunctionConversion::convertBlock(
for (Instruction &inst : *block) { for (Instruction &inst : *block) {
if (inst.getNumBlockLists() != 0) { if (inst.getNumBlockLists() != 0) {
inst.emitError("unsupported region instruction"); inst.emitError("unsupported region instruction");
return true; return Status::failure();
} }
// Find the first matching conversion and apply it. // Find the first matching conversion and apply it.
@ -185,10 +184,10 @@ bool impl::FunctionConversion::convertBlock(
continue; continue;
if (inst.getNumSuccessors() != 0) { if (inst.getNumSuccessors() != 0) {
if (convertOpWithSuccessors(conversion, &inst, builder)) if (failed(convertOpWithSuccessors(conversion, &inst, builder)))
return true; return Status::failure();
} else if (convertOp(conversion, &inst, builder)) { } else if (failed(convertOp(conversion, &inst, builder))) {
return true; return Status::failure();
} }
converted = true; converted = true;
break; break;
@ -202,10 +201,10 @@ bool impl::FunctionConversion::convertBlock(
for (Block *succ : block->getSuccessors()) { for (Block *succ : block->getSuccessors()) {
if (visitedBlocks.count(succ) != 0) if (visitedBlocks.count(succ) != 0)
continue; continue;
if (convertBlock(succ, builder, visitedBlocks)) if (failed(convertBlock(succ, builder, visitedBlocks)))
return true; return Status::failure();
} }
return false; return Status::success();
} }
Function *impl::FunctionConversion::convertFunction(Function *f) { Function *impl::FunctionConversion::convertFunction(Function *f) {
@ -250,7 +249,7 @@ Function *impl::FunctionConversion::convertFunction(Function *f) {
// Start a DFS-order traversal of the CFG to make sure defs are converted // Start a DFS-order traversal of the CFG to make sure defs are converted
// before uses in dominated blocks. // before uses in dominated blocks.
llvm::DenseSet<Block *> visitedBlocks; llvm::DenseSet<Block *> visitedBlocks;
if (convertBlock(&f->front(), builder, visitedBlocks)) if (failed(convertBlock(&f->front(), builder, visitedBlocks)))
return nullptr; return nullptr;
// If some blocks are not reachable through successor chains, they should have // If some blocks are not reachable through successor chains, they should have
@ -261,14 +260,14 @@ Function *impl::FunctionConversion::convertFunction(Function *f) {
return newFunction.release(); return newFunction.release();
} }
bool impl::FunctionConversion::convert(DialectConversion *conversion, Status impl::FunctionConversion::convert(DialectConversion *conversion,
Module *module) { Module *module) {
return impl::FunctionConversion(conversion).run(module); return impl::FunctionConversion(conversion).run(module);
} }
bool impl::FunctionConversion::run(Module *module) { Status impl::FunctionConversion::run(Module *module) {
if (!module) if (!module)
return true; return Status::failure();
MLIRContext *context = module->getContext(); MLIRContext *context = module->getContext();
conversions = dialectConversion->initConverters(context); conversions = dialectConversion->initConverters(context);
@ -284,7 +283,7 @@ bool impl::FunctionConversion::run(Module *module) {
for (auto *func : originalFuncs) { for (auto *func : originalFuncs) {
Function *converted = convertFunction(func); Function *converted = convertFunction(func);
if (!converted) if (!converted)
return true; return Status::failure();
auto origFuncAttr = FunctionAttr::get(func, context); auto origFuncAttr = FunctionAttr::get(func, context);
auto convertedFuncAttr = FunctionAttr::get(converted, context); auto convertedFuncAttr = FunctionAttr::get(converted, context);
@ -306,7 +305,7 @@ bool impl::FunctionConversion::run(Module *module) {
for (auto *func : convertedFuncs) for (auto *func : convertedFuncs)
module->getFunctions().push_back(func); module->getFunctions().push_back(func);
return false; return Status::success();
} }
// Create a function type with arguments and results converted, and argument // Create a function type with arguments and results converted, and argument
@ -329,6 +328,6 @@ DialectConversion::convertFunctionSignatureType(
FunctionType::get(arguments, results, type.getContext()), argAttrs.vec()); FunctionType::get(arguments, results, type.getContext()), argAttrs.vec());
} }
bool DialectConversion::convert(Module *m) { Status DialectConversion::convert(Module *m) {
return impl::FunctionConversion::convert(this, m); return impl::FunctionConversion::convert(this, m);
} }