forked from OSchip/llvm-project
Use Status instead of bool in DialectConversion.
PiperOrigin-RevId: 237339277
This commit is contained in:
parent
f427bddd06
commit
10ddae6d88
|
@ -23,8 +23,8 @@
|
||||||
#define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
|
#define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
|
||||||
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "mlir/Support/Status.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
||||||
|
@ -152,10 +152,9 @@ class DialectConversion {
|
||||||
public:
|
public:
|
||||||
virtual ~DialectConversion() = default;
|
virtual ~DialectConversion() = default;
|
||||||
|
|
||||||
/// Run the converter on the provided module. This function returns
|
/// Run the converter on the provided module.
|
||||||
/// true if the module was unsuccessfully converted. Otherwise, it returns
|
LLVM_NODISCARD
|
||||||
/// false for success.
|
Status convert(Module *m);
|
||||||
bool convert(Module *m);
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
/// Derived classes must implement this hook to produce a set of conversion
|
/// Derived classes must implement this hook to produce a set of conversion
|
||||||
|
@ -168,7 +167,7 @@ protected:
|
||||||
/// block or function argument types or function result types. If the target
|
/// block or function argument types or function result types. If the target
|
||||||
/// dialect has support for custom first-class function types, convertType
|
/// dialect has support for custom first-class function types, convertType
|
||||||
/// should create those types for arguments of MLIR function type. It can be
|
/// should create those types for arguments of MLIR function type. It can be
|
||||||
/// used for values (constant, operands, resutls) of function type but not for
|
/// used for values (constant, operands, results) of function type but not for
|
||||||
/// the function signatures. For the latter, convertFunctionSignatureType is
|
/// the function signatures. For the latter, convertFunctionSignatureType is
|
||||||
/// used instead.
|
/// used instead.
|
||||||
///
|
///
|
||||||
|
|
|
@ -1078,7 +1078,7 @@ public:
|
||||||
void runOnModule() override {
|
void runOnModule() override {
|
||||||
Module *m = &getModule();
|
Module *m = &getModule();
|
||||||
uniqueSuccessorsWithArguments(m);
|
uniqueSuccessorsWithArguments(m);
|
||||||
if (DialectConversion::convert(m))
|
if (failed(DialectConversion::convert(m)))
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,7 @@ public:
|
||||||
// conversion patterns and to convert function and block argument types.
|
// conversion patterns and to convert function and block argument types.
|
||||||
// Converts the `module` in-place by replacing all existing functions with the
|
// Converts the `module` in-place by replacing all existing functions with the
|
||||||
// converted ones.
|
// converted ones.
|
||||||
static bool convert(DialectConversion *conversion, Module *module);
|
static Status convert(DialectConversion *conversion, Module *module);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Constructs a FunctionConversion by storing the hooks.
|
// Constructs a FunctionConversion by storing the hooks.
|
||||||
|
@ -61,14 +61,14 @@ private:
|
||||||
// from `valueRemapping` and the converted blocks from `blockRemapping`, and
|
// from `valueRemapping` and the converted blocks from `blockRemapping`, and
|
||||||
// passes them to `converter->rewriteTerminator` function defined in the
|
// passes them to `converter->rewriteTerminator` function defined in the
|
||||||
// pattern, together with `builder`.
|
// pattern, together with `builder`.
|
||||||
bool convertOpWithSuccessors(DialectOpConversion *converter, Instruction *op,
|
Status convertOpWithSuccessors(DialectOpConversion *converter,
|
||||||
FuncBuilder &builder);
|
Instruction *op, FuncBuilder &builder);
|
||||||
|
|
||||||
// Converts an operation without successors. Extracts the converted operands
|
// Converts an operation without successors. Extracts the converted operands
|
||||||
// from `valueRemapping` and passes them to the `converter->rewrite` function
|
// from `valueRemapping` and passes them to the `converter->rewrite` function
|
||||||
// defined in the pattern, together with `builder`.
|
// defined in the pattern, together with `builder`.
|
||||||
bool convertOp(DialectOpConversion *converter, Instruction *op,
|
Status convertOp(DialectOpConversion *converter, Instruction *op,
|
||||||
FuncBuilder &builder);
|
FuncBuilder &builder);
|
||||||
|
|
||||||
// Converts a block by traversing its instructions sequentially, looking for
|
// Converts a block by traversing its instructions sequentially, looking for
|
||||||
// the first pattern match and dispatching the instruction conversion to
|
// the first pattern match and dispatching the instruction conversion to
|
||||||
|
@ -77,10 +77,8 @@ private:
|
||||||
//
|
//
|
||||||
// After converting operations, traverses the successor blocks unless they
|
// After converting operations, traverses the successor blocks unless they
|
||||||
// have been visited already as indicated in `visitedBlocks`.
|
// have been visited already as indicated in `visitedBlocks`.
|
||||||
//
|
Status convertBlock(Block *block, FuncBuilder &builder,
|
||||||
// Return `true` on error.
|
llvm::DenseSet<Block *> &visitedBlocks);
|
||||||
bool convertBlock(Block *block, FuncBuilder &builder,
|
|
||||||
llvm::DenseSet<Block *> &visitedBlocks);
|
|
||||||
|
|
||||||
// Converts the module as follows.
|
// Converts the module as follows.
|
||||||
// 1. Call `convertFunction` on each function of the module and collect the
|
// 1. Call `convertFunction` on each function of the module and collect the
|
||||||
|
@ -88,7 +86,7 @@ private:
|
||||||
// 2. Remap all function attributes in the new functions to point to the new
|
// 2. Remap all function attributes in the new functions to point to the new
|
||||||
// functions instead of the old ones.
|
// functions instead of the old ones.
|
||||||
// 3. Replace old functions with the new in the module.
|
// 3. Replace old functions with the new in the module.
|
||||||
bool run(Module *m);
|
Status run(Module *m);
|
||||||
|
|
||||||
// Pointer to a specific dialect pass.
|
// Pointer to a specific dialect pass.
|
||||||
DialectConversion *dialectConversion;
|
DialectConversion *dialectConversion;
|
||||||
|
@ -116,7 +114,7 @@ SmallVector<Value *, 4> impl::FunctionConversion::lookupValues(
|
||||||
return remapped;
|
return remapped;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool impl::FunctionConversion::convertOpWithSuccessors(
|
Status impl::FunctionConversion::convertOpWithSuccessors(
|
||||||
DialectOpConversion *converter, Instruction *op, FuncBuilder &builder) {
|
DialectOpConversion *converter, Instruction *op, FuncBuilder &builder) {
|
||||||
SmallVector<Block *, 2> destinations;
|
SmallVector<Block *, 2> destinations;
|
||||||
destinations.reserve(op->getNumSuccessors());
|
destinations.reserve(op->getNumSuccessors());
|
||||||
|
@ -144,28 +142,29 @@ bool impl::FunctionConversion::convertOpWithSuccessors(
|
||||||
llvm::makeArrayRef(operands.data(),
|
llvm::makeArrayRef(operands.data(),
|
||||||
operands.data() + firstSuccessorOperand),
|
operands.data() + firstSuccessorOperand),
|
||||||
destinations, operandsPerDestination, builder);
|
destinations, operandsPerDestination, builder);
|
||||||
return false;
|
return Status::success();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool impl::FunctionConversion::convertOp(DialectOpConversion *converter,
|
Status impl::FunctionConversion::convertOp(DialectOpConversion *converter,
|
||||||
Instruction *op,
|
Instruction *op,
|
||||||
FuncBuilder &builder) {
|
FuncBuilder &builder) {
|
||||||
auto operands = lookupValues(op->getOperands());
|
auto operands = lookupValues(op->getOperands());
|
||||||
assert((!operands.empty() || op->getNumOperands() == 0) &&
|
assert((!operands.empty() || op->getNumOperands() == 0) &&
|
||||||
"converting op before ops defining its operands");
|
"converting op before ops defining its operands");
|
||||||
|
|
||||||
auto results = converter->rewrite(op, operands, builder);
|
auto results = converter->rewrite(op, operands, builder);
|
||||||
if (results.size() != op->getNumResults())
|
if (results.size() != op->getNumResults())
|
||||||
return op->emitError("rewriting produced a different number of results");
|
return (op->emitError("rewriting produced a different number of results"),
|
||||||
|
Status::failure());
|
||||||
|
|
||||||
for (unsigned i = 0, e = results.size(); i < e; ++i)
|
for (unsigned i = 0, e = results.size(); i < e; ++i)
|
||||||
mapping.map(op->getResult(i), results[i]);
|
mapping.map(op->getResult(i), results[i]);
|
||||||
return false;
|
return Status::success();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool impl::FunctionConversion::convertBlock(
|
Status
|
||||||
Block *block, FuncBuilder &builder,
|
impl::FunctionConversion::convertBlock(Block *block, FuncBuilder &builder,
|
||||||
llvm::DenseSet<Block *> &visitedBlocks) {
|
llvm::DenseSet<Block *> &visitedBlocks) {
|
||||||
// First, add the current block to the list of visited blocks.
|
// First, add the current block to the list of visited blocks.
|
||||||
visitedBlocks.insert(block);
|
visitedBlocks.insert(block);
|
||||||
// Setup the builder to the insert to the converted block.
|
// Setup the builder to the insert to the converted block.
|
||||||
|
@ -175,7 +174,7 @@ bool impl::FunctionConversion::convertBlock(
|
||||||
for (Instruction &inst : *block) {
|
for (Instruction &inst : *block) {
|
||||||
if (inst.getNumBlockLists() != 0) {
|
if (inst.getNumBlockLists() != 0) {
|
||||||
inst.emitError("unsupported region instruction");
|
inst.emitError("unsupported region instruction");
|
||||||
return true;
|
return Status::failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find the first matching conversion and apply it.
|
// Find the first matching conversion and apply it.
|
||||||
|
@ -185,10 +184,10 @@ bool impl::FunctionConversion::convertBlock(
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
if (inst.getNumSuccessors() != 0) {
|
if (inst.getNumSuccessors() != 0) {
|
||||||
if (convertOpWithSuccessors(conversion, &inst, builder))
|
if (failed(convertOpWithSuccessors(conversion, &inst, builder)))
|
||||||
return true;
|
return Status::failure();
|
||||||
} else if (convertOp(conversion, &inst, builder)) {
|
} else if (failed(convertOp(conversion, &inst, builder))) {
|
||||||
return true;
|
return Status::failure();
|
||||||
}
|
}
|
||||||
converted = true;
|
converted = true;
|
||||||
break;
|
break;
|
||||||
|
@ -202,10 +201,10 @@ bool impl::FunctionConversion::convertBlock(
|
||||||
for (Block *succ : block->getSuccessors()) {
|
for (Block *succ : block->getSuccessors()) {
|
||||||
if (visitedBlocks.count(succ) != 0)
|
if (visitedBlocks.count(succ) != 0)
|
||||||
continue;
|
continue;
|
||||||
if (convertBlock(succ, builder, visitedBlocks))
|
if (failed(convertBlock(succ, builder, visitedBlocks)))
|
||||||
return true;
|
return Status::failure();
|
||||||
}
|
}
|
||||||
return false;
|
return Status::success();
|
||||||
}
|
}
|
||||||
|
|
||||||
Function *impl::FunctionConversion::convertFunction(Function *f) {
|
Function *impl::FunctionConversion::convertFunction(Function *f) {
|
||||||
|
@ -250,7 +249,7 @@ Function *impl::FunctionConversion::convertFunction(Function *f) {
|
||||||
// Start a DFS-order traversal of the CFG to make sure defs are converted
|
// Start a DFS-order traversal of the CFG to make sure defs are converted
|
||||||
// before uses in dominated blocks.
|
// before uses in dominated blocks.
|
||||||
llvm::DenseSet<Block *> visitedBlocks;
|
llvm::DenseSet<Block *> visitedBlocks;
|
||||||
if (convertBlock(&f->front(), builder, visitedBlocks))
|
if (failed(convertBlock(&f->front(), builder, visitedBlocks)))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
// If some blocks are not reachable through successor chains, they should have
|
// If some blocks are not reachable through successor chains, they should have
|
||||||
|
@ -261,14 +260,14 @@ Function *impl::FunctionConversion::convertFunction(Function *f) {
|
||||||
return newFunction.release();
|
return newFunction.release();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool impl::FunctionConversion::convert(DialectConversion *conversion,
|
Status impl::FunctionConversion::convert(DialectConversion *conversion,
|
||||||
Module *module) {
|
Module *module) {
|
||||||
return impl::FunctionConversion(conversion).run(module);
|
return impl::FunctionConversion(conversion).run(module);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool impl::FunctionConversion::run(Module *module) {
|
Status impl::FunctionConversion::run(Module *module) {
|
||||||
if (!module)
|
if (!module)
|
||||||
return true;
|
return Status::failure();
|
||||||
|
|
||||||
MLIRContext *context = module->getContext();
|
MLIRContext *context = module->getContext();
|
||||||
conversions = dialectConversion->initConverters(context);
|
conversions = dialectConversion->initConverters(context);
|
||||||
|
@ -284,7 +283,7 @@ bool impl::FunctionConversion::run(Module *module) {
|
||||||
for (auto *func : originalFuncs) {
|
for (auto *func : originalFuncs) {
|
||||||
Function *converted = convertFunction(func);
|
Function *converted = convertFunction(func);
|
||||||
if (!converted)
|
if (!converted)
|
||||||
return true;
|
return Status::failure();
|
||||||
|
|
||||||
auto origFuncAttr = FunctionAttr::get(func, context);
|
auto origFuncAttr = FunctionAttr::get(func, context);
|
||||||
auto convertedFuncAttr = FunctionAttr::get(converted, context);
|
auto convertedFuncAttr = FunctionAttr::get(converted, context);
|
||||||
|
@ -306,7 +305,7 @@ bool impl::FunctionConversion::run(Module *module) {
|
||||||
for (auto *func : convertedFuncs)
|
for (auto *func : convertedFuncs)
|
||||||
module->getFunctions().push_back(func);
|
module->getFunctions().push_back(func);
|
||||||
|
|
||||||
return false;
|
return Status::success();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a function type with arguments and results converted, and argument
|
// Create a function type with arguments and results converted, and argument
|
||||||
|
@ -329,6 +328,6 @@ DialectConversion::convertFunctionSignatureType(
|
||||||
FunctionType::get(arguments, results, type.getContext()), argAttrs.vec());
|
FunctionType::get(arguments, results, type.getContext()), argAttrs.vec());
|
||||||
}
|
}
|
||||||
|
|
||||||
bool DialectConversion::convert(Module *m) {
|
Status DialectConversion::convert(Module *m) {
|
||||||
return impl::FunctionConversion::convert(this, m);
|
return impl::FunctionConversion::convert(this, m);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue