forked from OSchip/llvm-project
ConvertToCFG: properly remap nested function attributes.
Array attributes can nested and function attributes can appear anywhere at that level. They should be remapped to point to the generated CFGFunction after ML-to-CFG conversion, similarly to plain function attributes. Extract the nested attribute remapping functionality from the Parser to Utils. Extract out the remapping function for individual Functions from the module remapping function. Use these new functions in the ML-to-CFG conversion pass and in the parser. PiperOrigin-RevId: 221510997
This commit is contained in:
parent
cb40633969
commit
d030433443
|
@ -20,11 +20,12 @@
|
|||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/DenseMapInfo.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
||||
namespace mlir {
|
||||
class AffineMap;
|
||||
class Function;
|
||||
class FunctionAttr;
|
||||
class FunctionType;
|
||||
class IntegerSet;
|
||||
class MLIRContext;
|
||||
|
@ -110,6 +111,14 @@ public:
|
|||
/// Return true if this field is, or contains, a function attribute.
|
||||
bool isOrContainsFunction() const;
|
||||
|
||||
/// Replace a function attribute or function attributes nested in an array
|
||||
/// attribute with another function attribute as defined by the provided
|
||||
/// remapping table. Return the original attribute if it (or any of nested
|
||||
/// attributes) is not present in the table.
|
||||
Attribute remapFunctionAttrs(
|
||||
const llvm::DenseMap<Attribute, FunctionAttr> &remappingTable,
|
||||
MLIRContext *context) const;
|
||||
|
||||
/// Print the attribute.
|
||||
void print(raw_ostream &os) const;
|
||||
void dump() const;
|
||||
|
|
|
@ -28,13 +28,16 @@
|
|||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class CFGFunction;
|
||||
class ForStmt;
|
||||
class FuncBuilder;
|
||||
class Location;
|
||||
class MLValue;
|
||||
class Module;
|
||||
class OperationStmt;
|
||||
class SSAValue;
|
||||
|
||||
|
@ -101,6 +104,21 @@ void forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp);
|
|||
/// Returns false if the folding happens for at least one bound, true otherwise.
|
||||
bool constantFoldBounds(ForStmt *forStmt);
|
||||
|
||||
/// Replaces (potentially nested) function attributes in the operation "op"
|
||||
/// with those specified in "remappingTable".
|
||||
void remapFunctionAttrs(
|
||||
Operation &op, const DenseMap<Attribute, FunctionAttr> &remappingTable);
|
||||
|
||||
/// Replaces (potentially nested) function attributes all operations of the
|
||||
/// Function "fn" with those specified in "remappingTable".
|
||||
void remapFunctionAttrs(
|
||||
Function &fn, const DenseMap<Attribute, FunctionAttr> &remappingTable);
|
||||
|
||||
/// Replaces (potentially nested) function attributes in the entire module
|
||||
/// with those specified in "remappingTable". Ignores external functions.
|
||||
void remapFunctionAttrs(
|
||||
Module &module, const DenseMap<Attribute, FunctionAttr> &remappingTable);
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TRANSFORMS_UTILS_H
|
||||
|
|
|
@ -31,6 +31,39 @@ bool Attribute::isOrContainsFunction() const {
|
|||
return attr->isOrContainsFunctionCache;
|
||||
}
|
||||
|
||||
// Given an attribute that could refer to a function attribute in the remapping
|
||||
// table, walk it and rewrite it to use the mapped function. If it doesn't
|
||||
// refer to anything in the table, then it is returned unmodified.
|
||||
Attribute Attribute::remapFunctionAttrs(
|
||||
const llvm::DenseMap<Attribute, FunctionAttr> &remappingTable,
|
||||
MLIRContext *context) const {
|
||||
// Most attributes are trivially unrelated to function attributes, skip them
|
||||
// rapidly.
|
||||
if (!isOrContainsFunction())
|
||||
return *this;
|
||||
|
||||
// If we have a function attribute, remap it.
|
||||
if (auto fnAttr = this->dyn_cast<FunctionAttr>()) {
|
||||
auto it = remappingTable.find(fnAttr);
|
||||
return it != remappingTable.end() ? it->second : *this;
|
||||
}
|
||||
|
||||
// Otherwise, we must have an array attribute, remap the elements.
|
||||
auto arrayAttr = this->cast<ArrayAttr>();
|
||||
SmallVector<Attribute, 8> remappedElts;
|
||||
bool anyChange = false;
|
||||
for (auto elt : arrayAttr.getValue()) {
|
||||
auto newElt = elt.remapFunctionAttrs(remappingTable, context);
|
||||
remappedElts.push_back(newElt);
|
||||
anyChange |= (elt != newElt);
|
||||
}
|
||||
|
||||
if (!anyChange)
|
||||
return *this;
|
||||
|
||||
return ArrayAttr::get(remappedElts, context);
|
||||
}
|
||||
|
||||
BoolAttr::BoolAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||
|
||||
bool BoolAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
|
||||
|
|
|
@ -35,6 +35,7 @@
|
|||
#include "mlir/IR/StmtVisitor.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "mlir/Transforms/Utils.h"
|
||||
#include "llvm/ADT/APInt.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/Support/MemoryBuffer.h"
|
||||
|
@ -3411,59 +3412,6 @@ ParseResult ModuleParser::parseMLFunc() {
|
|||
return parser.parseFunctionBody();
|
||||
}
|
||||
|
||||
/// Given an attribute that could refer to a function attribute in the
|
||||
/// remapping table, walk it and rewrite it to use the mapped function. If it
|
||||
/// doesn't refer to anything in the table, then it is returned unmodified.
|
||||
static Attribute
|
||||
remapFunctionAttrs(Attribute input,
|
||||
DenseMap<Attribute, FunctionAttr> &remappingTable,
|
||||
MLIRContext *context) {
|
||||
// Most attributes are trivially unrelated to function attributes, skip them
|
||||
// rapidly.
|
||||
if (!input.isOrContainsFunction())
|
||||
return input;
|
||||
|
||||
// If we have a function attribute, remap it.
|
||||
if (auto fnAttr = input.dyn_cast<FunctionAttr>()) {
|
||||
auto it = remappingTable.find(fnAttr);
|
||||
return it != remappingTable.end() ? it->second : input;
|
||||
}
|
||||
|
||||
// Otherwise, we must have an array attribute, remap the elements.
|
||||
auto arrayAttr = input.cast<ArrayAttr>();
|
||||
SmallVector<Attribute, 8> remappedElts;
|
||||
bool anyChange = false;
|
||||
for (auto elt : arrayAttr.getValue()) {
|
||||
auto newElt = remapFunctionAttrs(elt, remappingTable, context);
|
||||
remappedElts.push_back(newElt);
|
||||
anyChange |= (elt != newElt);
|
||||
}
|
||||
|
||||
if (!anyChange)
|
||||
return input;
|
||||
|
||||
return ArrayAttr::get(remappedElts, context);
|
||||
}
|
||||
|
||||
/// Remap function attributes to resolve forward references to their actual
|
||||
/// definition.
|
||||
static void remapFunctionAttrsInOperation(
|
||||
Operation *op, DenseMap<Attribute, FunctionAttr> &remappingTable) {
|
||||
for (auto attr : op->getAttrs()) {
|
||||
// Do the remapping, if we got the same thing back, then it must contain
|
||||
// functions that aren't getting remapped.
|
||||
auto newVal =
|
||||
remapFunctionAttrs(attr.second, remappingTable, op->getContext());
|
||||
if (newVal == attr.second)
|
||||
continue;
|
||||
|
||||
// Otherwise, replace the existing attribute with the new one. It is safe
|
||||
// to mutate the attribute list while we walk it because underlying
|
||||
// attribute lists are uniqued and immortal.
|
||||
op->setAttr(attr.first, newVal);
|
||||
}
|
||||
}
|
||||
|
||||
/// Finish the end of module parsing - when the result is valid, do final
|
||||
/// checking.
|
||||
ParseResult ModuleParser::finalizeModule() {
|
||||
|
@ -3491,32 +3439,7 @@ ParseResult ModuleParser::finalizeModule() {
|
|||
|
||||
// Otherwise, walk the entire module replacing uses of one attribute set
|
||||
// with the correct ones.
|
||||
for (auto &fn : *getModule()) {
|
||||
if (auto *cfgFn = dyn_cast<CFGFunction>(&fn)) {
|
||||
for (auto &bb : *cfgFn) {
|
||||
for (auto &inst : bb) {
|
||||
remapFunctionAttrsInOperation(&inst, remappingTable);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise, look at MLFunctions. We ignore ExtFunctions.
|
||||
auto *mlFn = dyn_cast<MLFunction>(&fn);
|
||||
if (!mlFn)
|
||||
continue;
|
||||
|
||||
struct MLFnWalker : public StmtWalker<MLFnWalker> {
|
||||
MLFnWalker(DenseMap<Attribute, FunctionAttr> &remappingTable)
|
||||
: remappingTable(remappingTable) {}
|
||||
void visitOperationStmt(OperationStmt *opStmt) {
|
||||
remapFunctionAttrsInOperation(opStmt, remappingTable);
|
||||
}
|
||||
|
||||
DenseMap<Attribute, FunctionAttr> &remappingTable;
|
||||
};
|
||||
|
||||
MLFnWalker(remappingTable).walk(mlFn);
|
||||
}
|
||||
remapFunctionAttrs(*getModule(), remappingTable);
|
||||
|
||||
// Now that all references to the forward definition placeholders are
|
||||
// resolved, we can deallocate the placeholders.
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include "mlir/StandardOps/StandardOps.h"
|
||||
#include "mlir/Support/Functional.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/Utils.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
using namespace mlir;
|
||||
|
@ -380,21 +381,29 @@ CFGFunction *ModuleConverter::convert(MLFunction *mlFunc) {
|
|||
// removed anyway. However, it is necessary to replace the references in the
|
||||
// converted CFGFunctions that have not been added to the module yet.
|
||||
void ModuleConverter::replaceReferences() {
|
||||
for (Function &fn : *module) {
|
||||
switch (fn.getKind()) {
|
||||
case Function::Kind::CFGFunc:
|
||||
replaceReferences(&cast<CFGFunction>(fn));
|
||||
break;
|
||||
case Function::Kind::MLFunc:
|
||||
// ML functions must have been converted already and will be removed.
|
||||
break;
|
||||
case Function::Kind::ExtFunc:
|
||||
// nothing to do for external functions
|
||||
break;
|
||||
}
|
||||
// Build the remapping between function attributes pointing to ML functions
|
||||
// and the newly created function attributes pointing to the converted CFG
|
||||
// functions.
|
||||
llvm::DenseMap<Attribute, FunctionAttr> remappingTable;
|
||||
for (const Function &fn : *module) {
|
||||
const auto *mlFunc = dyn_cast<MLFunction>(&fn);
|
||||
if (!mlFunc)
|
||||
continue;
|
||||
CFGFunction *convertedFunc = generatedFuncs.lookup(mlFunc);
|
||||
assert(convertedFunc && "ML function was not converted");
|
||||
|
||||
MLIRContext *context = module->getContext();
|
||||
auto mlFuncAttr = FunctionAttr::get(mlFunc, context);
|
||||
auto cfgFuncAttr = FunctionAttr::get(convertedFunc, module->getContext());
|
||||
remappingTable.insert({mlFuncAttr, cfgFuncAttr});
|
||||
}
|
||||
|
||||
// Remap in existing functions.
|
||||
remapFunctionAttrs(*module, remappingTable);
|
||||
|
||||
// Remap in generated functions.
|
||||
for (auto pair : generatedFuncs) {
|
||||
replaceReferences(pair.second);
|
||||
remapFunctionAttrs(*pair.second, remappingTable);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -413,27 +422,6 @@ static inline void replaceMLFunctionAttr(
|
|||
op.setAttr(name, b.getFunctionAttr(cfgFunc));
|
||||
}
|
||||
|
||||
// Replace references to MLFunctions with the references to the converted
|
||||
// CFGFunctions. References to MLFunctions can potentially appear in any
|
||||
// function attribute (in particular, they are known to appear in the "callee"
|
||||
// attribute of a direct call and the "value" attribute of a constant). Replace
|
||||
// the values of these attributes to point to the converted functions.
|
||||
void ModuleConverter::replaceReferences(CFGFunction *func) {
|
||||
for (auto &bb : *func) {
|
||||
for (auto &inst : bb) {
|
||||
for (auto &attr : inst.getAttrs()) {
|
||||
// TODO(zinenko): handle nested attributes, e.g. array attributes
|
||||
// containing functions.
|
||||
auto functionAttr = attr.second.dyn_cast<FunctionAttr>();
|
||||
if (!functionAttr)
|
||||
continue;
|
||||
replaceMLFunctionAttr(inst, attr.first, functionAttr.getValue(),
|
||||
generatedFuncs);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// The CFG and ML functions have the same name. First, erase the MLFunction.
|
||||
// Then insert the CFGFunction at the same place.
|
||||
void ModuleConverter::replaceFunctions() {
|
||||
|
|
|
@ -25,6 +25,8 @@
|
|||
#include "mlir/Analysis/AffineAnalysis.h"
|
||||
#include "mlir/Analysis/AffineStructures.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/StmtVisitor.h"
|
||||
#include "mlir/StandardOps/StandardOps.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
@ -394,3 +396,57 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) {
|
|||
ret &= foldLowerOrUpperBound(/*lower=*/false);
|
||||
return ret;
|
||||
}
|
||||
|
||||
void mlir::remapFunctionAttrs(
|
||||
Operation &op, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
|
||||
for (auto attr : op.getAttrs()) {
|
||||
// Do the remapping, if we got the same thing back, then it must contain
|
||||
// functions that aren't getting remapped.
|
||||
auto newVal =
|
||||
attr.second.remapFunctionAttrs(remappingTable, op.getContext());
|
||||
if (newVal == attr.second)
|
||||
continue;
|
||||
|
||||
// Otherwise, replace the existing attribute with the new one. It is safe
|
||||
// to mutate the attribute list while we walk it because underlying
|
||||
// attribute lists are uniqued and immortal.
|
||||
op.setAttr(attr.first, newVal);
|
||||
}
|
||||
}
|
||||
|
||||
void mlir::remapFunctionAttrs(
|
||||
Function &fn, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
|
||||
// Look at all instructions in a CFGFunction.
|
||||
if (auto *cfgFn = dyn_cast<CFGFunction>(&fn)) {
|
||||
for (auto &bb : *cfgFn) {
|
||||
for (auto &inst : bb) {
|
||||
remapFunctionAttrs(inst, remappingTable);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise, look at MLFunctions. We ignore ExtFunctions.
|
||||
auto *mlFn = dyn_cast<MLFunction>(&fn);
|
||||
if (!mlFn)
|
||||
return;
|
||||
|
||||
struct MLFnWalker : public StmtWalker<MLFnWalker> {
|
||||
MLFnWalker(const DenseMap<Attribute, FunctionAttr> &remappingTable)
|
||||
: remappingTable(remappingTable) {}
|
||||
void visitOperationStmt(OperationStmt *opStmt) {
|
||||
remapFunctionAttrs(*opStmt, remappingTable);
|
||||
}
|
||||
|
||||
const DenseMap<Attribute, FunctionAttr> &remappingTable;
|
||||
};
|
||||
|
||||
MLFnWalker(remappingTable).walk(mlFn);
|
||||
}
|
||||
|
||||
void mlir::remapFunctionAttrs(
|
||||
Module &module, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
|
||||
for (auto &fn : module) {
|
||||
remapFunctionAttrs(fn, remappingTable);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -53,6 +53,18 @@ bb0:
|
|||
return
|
||||
}
|
||||
|
||||
cfgfunc @nested_attributes() {
|
||||
bb0:
|
||||
%0 = constant 0 : index
|
||||
// CHECK: call @body(%c0) {attr1: [@simple_loop : () -> (), @simple_loop : () -> ()]} : (index) -> ()
|
||||
call @body(%0) {attr1: [@simple_loop : () -> (), @simple_loop : () -> ()]} : (index) -> ()
|
||||
// Note: the {{\[}} construct is necessary to prevent FileCheck from
|
||||
// interpreting [[ as the start of its variable in the pattern below.
|
||||
// CHECK: call @body(%c0) {attr2: {{\[}}{{\[}}{{\[}}@simple_loop : () -> ()]], [@simple_loop : () -> ()]]} : (index) -> ()
|
||||
call @body(%0) {attr2: [[[@simple_loop : () -> ()]], [@simple_loop : () -> ()]]} : (index) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: cfgfunc @ml_caller() {
|
||||
mlfunc @ml_caller() {
|
||||
// Direct calls inside ML functions are renamed if asked (given that the
|
||||
|
|
Loading…
Reference in New Issue