forked from OSchip/llvm-project
[TableGen] Use deduced result types for build() of suitable ops
For ops with the SameOperandsAndResultType trait, we know that all result types should be the same as the first operand's type. So we can generate a build() method without requiring result types as parameters and also invoke this method when constructing such ops during expanding rewrite patterns. Similarly for ops have broadcast behavior, we can define build() method to use the deduced type as the result type. So we can also calling into this build() method when constructing ops in RewriterGen. PiperOrigin-RevId: 233988307
This commit is contained in:
parent
f2c93f0995
commit
eb3f8dcb93
|
@ -39,13 +39,6 @@ class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
|||
|
||||
let opName = mnemonic;
|
||||
|
||||
let builder = [{
|
||||
static void build(Builder *builder, OperationState *result, Value *lhs,
|
||||
Value *rhs) {
|
||||
impl::buildBinaryOp(builder, result, lhs, rhs);
|
||||
}
|
||||
}];
|
||||
|
||||
let parser = [{
|
||||
return impl::parseBinaryOp(parser, result);
|
||||
}];
|
||||
|
|
|
@ -106,6 +106,11 @@ public:
|
|||
// Returns the total number of arguments.
|
||||
int getNumArgs() const { return getNumOperands() + getNumNativeAttributes(); }
|
||||
|
||||
// Returns true if this op has the given MLIR C++ `trait`.
|
||||
// TODO: We should add a C++ wrapper class for TableGen OpTrait instead of
|
||||
// requiring the raw MLIR trait here.
|
||||
bool hasTrait(llvm::StringRef trait) const;
|
||||
|
||||
ArrayRef<llvm::SMLoc> getLoc() const;
|
||||
|
||||
// Query functions for the documentation of the operator.
|
||||
|
|
|
@ -93,6 +93,13 @@ StringRef tblgen::Operator::getArgName(int index) const {
|
|||
return argumentValues->getArgName(index)->getValue();
|
||||
}
|
||||
|
||||
bool tblgen::Operator::hasTrait(StringRef trait) const {
|
||||
auto traits = def.getValueAsListOfStrings("traits");
|
||||
if (std::find(traits.begin(), traits.end(), trait) != traits.end())
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto tblgen::Operator::attribute_begin() const -> attribute_iterator {
|
||||
return attributes.begin();
|
||||
}
|
||||
|
|
|
@ -125,6 +125,11 @@ private:
|
|||
// Invokes the given function over all the namespaces of the class.
|
||||
void mapOverClassNamespaces(function_ref<void(StringRef)> fn);
|
||||
|
||||
// Emits the build() method that takes each result-type/operand/attribute as
|
||||
// a stand-alone parameter. Using the first operand's type as all result
|
||||
// types if `hasResultType` is false.
|
||||
void emitStandaloneParamBuilder(bool isAllSameType);
|
||||
|
||||
// The record corresponding to the op.
|
||||
const Record &def;
|
||||
|
||||
|
@ -233,31 +238,15 @@ void OpEmitter::emitNamedOperands() {
|
|||
}
|
||||
}
|
||||
|
||||
void OpEmitter::emitBuilder() {
|
||||
if (hasStringAttribute(def, "builder")) {
|
||||
// If a custom builder is given then print that out instead.
|
||||
auto builder = def.getValueAsString("builder");
|
||||
if (!builder.empty())
|
||||
os << builder << '\n';
|
||||
}
|
||||
|
||||
// Generate default builders that requires all result type, operands, and
|
||||
// attributes as parameters.
|
||||
|
||||
// We generate two builders here, one having a stand-alone parameter for
|
||||
// each result type / operand / attribute, the other having an aggregated
|
||||
// parameter for all result types / operands / attributes, to facilitate
|
||||
// different call patterns.
|
||||
|
||||
// 1. Stand-alone parameters
|
||||
|
||||
void OpEmitter::emitStandaloneParamBuilder(bool isAllSameType) {
|
||||
OUT(2) << "static void build(Builder* builder, OperationState* result";
|
||||
|
||||
auto numResults = op.getNumResults();
|
||||
|
||||
// Emit parameters for all return types
|
||||
for (unsigned i = 0, e = numResults; i != e; ++i)
|
||||
os << ", Type returnType" << i;
|
||||
if (isAllSameType)
|
||||
for (unsigned i = 0, e = numResults; i != e; ++i)
|
||||
os << ", Type returnType" << i;
|
||||
|
||||
// Emit parameters for all operands
|
||||
for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
|
||||
|
@ -279,9 +268,17 @@ void OpEmitter::emitBuilder() {
|
|||
|
||||
// Push all result types to the result
|
||||
if (numResults > 0) {
|
||||
OUT(4) << "result->addTypes({returnType0";
|
||||
for (unsigned i = 1; i != numResults; ++i)
|
||||
os << ", returnType" << i;
|
||||
OUT(4) << "result->addTypes({";
|
||||
if (isAllSameType) {
|
||||
os << "returnType0";
|
||||
for (unsigned i = 1; i != numResults; ++i)
|
||||
os << ", returnType" << i;
|
||||
} else {
|
||||
auto resultType = formatv("{0}->getType()", getArgumentName(op, 0)).str();
|
||||
os << resultType;
|
||||
for (unsigned i = 1; i != numResults; ++i)
|
||||
os << resultType;
|
||||
}
|
||||
os << "});\n\n";
|
||||
}
|
||||
|
||||
|
@ -307,6 +304,37 @@ void OpEmitter::emitBuilder() {
|
|||
OUT(4) << formatv("result->addAttribute(\"{0}\", {0});\n",
|
||||
namedAttr.getName());
|
||||
OUT(2) << "}\n";
|
||||
}
|
||||
|
||||
void OpEmitter::emitBuilder() {
|
||||
if (hasStringAttribute(def, "builder")) {
|
||||
// If a custom builder is given then print that out.
|
||||
auto builder = def.getValueAsString("builder");
|
||||
if (!builder.empty())
|
||||
os << builder << '\n';
|
||||
}
|
||||
|
||||
auto numResults = op.getNumResults();
|
||||
|
||||
auto numOperands = op.getNumOperands();
|
||||
bool hasVariadicOperand = op.hasVariadicOperand();
|
||||
int numNonVariadicOperands = numOperands - int(hasVariadicOperand);
|
||||
|
||||
// Generate default builders that requires all result type, operands, and
|
||||
// attributes as parameters.
|
||||
|
||||
// We generate three builders here:
|
||||
// 1. one having a stand-alone parameter for each result type / operand /
|
||||
// attribute, and
|
||||
// 2. one having an aggregated parameter for all result types / operands /
|
||||
// attributes, and
|
||||
// 3. one having a stand-alone prameter for each operand and attribute,
|
||||
// use the first operand's type as all result types
|
||||
// to facilitate different call patterns.
|
||||
|
||||
// 1. Stand-alone parameters
|
||||
|
||||
emitStandaloneParamBuilder(/*isAllSameType=*/true);
|
||||
|
||||
// 2. Aggregated parameters
|
||||
|
||||
|
@ -336,6 +364,11 @@ void OpEmitter::emitBuilder() {
|
|||
<< " result->addAttribute(pair.first, pair.second);\n"
|
||||
<< " }\n";
|
||||
}
|
||||
|
||||
// 3. Deduced result types
|
||||
|
||||
if (op.hasTrait("SameOperandsAndResultType"))
|
||||
emitStandaloneParamBuilder(/*isAllSameType=*/false);
|
||||
}
|
||||
|
||||
void OpEmitter::emitCanonicalizationPatterns() {
|
||||
|
|
|
@ -453,8 +453,13 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
|
|||
// TODO: this is a hack to support various constant ops. We are assuming
|
||||
// all of them have no operands and one attribute here. Figure out a better
|
||||
// way to do this.
|
||||
if (resultOp.getNumOperands() == 0 &&
|
||||
resultOp.getNumNativeAttributes() == 1) {
|
||||
bool isConstOp =
|
||||
resultOp.getNumOperands() == 0 && resultOp.getNumNativeAttributes() == 1;
|
||||
|
||||
bool isSameValueType = resultOp.hasTrait("SameOperandsAndResultType");
|
||||
bool isBroadcastable = resultOp.hasTrait("BroadcastableTwoOperandsOneResult");
|
||||
|
||||
if (isConstOp || isSameValueType || isBroadcastable) {
|
||||
os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc", resultValue,
|
||||
resultOp.getQualCppClassName());
|
||||
} else {
|
||||
|
|
Loading…
Reference in New Issue