[TableGen] Use deduced result types for build() of suitable ops

For ops with the SameOperandsAndResultType trait, we know that all result types
should be the same as the first operand's type. So we can generate a build()
method without requiring result types as parameters and also invoke this method
when constructing such ops during expanding rewrite patterns.

Similarly for ops have broadcast behavior, we can define build() method to use
the deduced type as the result type. So we can also calling into this build()
method when constructing ops in RewriterGen.

PiperOrigin-RevId: 233988307
This commit is contained in:
Lei Zhang 2019-02-14 10:54:50 -08:00 committed by jpienaar
parent f2c93f0995
commit eb3f8dcb93
5 changed files with 75 additions and 32 deletions

View File

@ -39,13 +39,6 @@ class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
let opName = mnemonic;
let builder = [{
static void build(Builder *builder, OperationState *result, Value *lhs,
Value *rhs) {
impl::buildBinaryOp(builder, result, lhs, rhs);
}
}];
let parser = [{
return impl::parseBinaryOp(parser, result);
}];

View File

@ -106,6 +106,11 @@ public:
// Returns the total number of arguments.
int getNumArgs() const { return getNumOperands() + getNumNativeAttributes(); }
// Returns true if this op has the given MLIR C++ `trait`.
// TODO: We should add a C++ wrapper class for TableGen OpTrait instead of
// requiring the raw MLIR trait here.
bool hasTrait(llvm::StringRef trait) const;
ArrayRef<llvm::SMLoc> getLoc() const;
// Query functions for the documentation of the operator.

View File

@ -93,6 +93,13 @@ StringRef tblgen::Operator::getArgName(int index) const {
return argumentValues->getArgName(index)->getValue();
}
bool tblgen::Operator::hasTrait(StringRef trait) const {
auto traits = def.getValueAsListOfStrings("traits");
if (std::find(traits.begin(), traits.end(), trait) != traits.end())
return true;
return false;
}
auto tblgen::Operator::attribute_begin() const -> attribute_iterator {
return attributes.begin();
}

View File

@ -125,6 +125,11 @@ private:
// Invokes the given function over all the namespaces of the class.
void mapOverClassNamespaces(function_ref<void(StringRef)> fn);
// Emits the build() method that takes each result-type/operand/attribute as
// a stand-alone parameter. Using the first operand's type as all result
// types if `hasResultType` is false.
void emitStandaloneParamBuilder(bool isAllSameType);
// The record corresponding to the op.
const Record &def;
@ -233,29 +238,13 @@ void OpEmitter::emitNamedOperands() {
}
}
void OpEmitter::emitBuilder() {
if (hasStringAttribute(def, "builder")) {
// If a custom builder is given then print that out instead.
auto builder = def.getValueAsString("builder");
if (!builder.empty())
os << builder << '\n';
}
// Generate default builders that requires all result type, operands, and
// attributes as parameters.
// We generate two builders here, one having a stand-alone parameter for
// each result type / operand / attribute, the other having an aggregated
// parameter for all result types / operands / attributes, to facilitate
// different call patterns.
// 1. Stand-alone parameters
void OpEmitter::emitStandaloneParamBuilder(bool isAllSameType) {
OUT(2) << "static void build(Builder* builder, OperationState* result";
auto numResults = op.getNumResults();
// Emit parameters for all return types
if (isAllSameType)
for (unsigned i = 0, e = numResults; i != e; ++i)
os << ", Type returnType" << i;
@ -279,9 +268,17 @@ void OpEmitter::emitBuilder() {
// Push all result types to the result
if (numResults > 0) {
OUT(4) << "result->addTypes({returnType0";
OUT(4) << "result->addTypes({";
if (isAllSameType) {
os << "returnType0";
for (unsigned i = 1; i != numResults; ++i)
os << ", returnType" << i;
} else {
auto resultType = formatv("{0}->getType()", getArgumentName(op, 0)).str();
os << resultType;
for (unsigned i = 1; i != numResults; ++i)
os << resultType;
}
os << "});\n\n";
}
@ -307,6 +304,37 @@ void OpEmitter::emitBuilder() {
OUT(4) << formatv("result->addAttribute(\"{0}\", {0});\n",
namedAttr.getName());
OUT(2) << "}\n";
}
void OpEmitter::emitBuilder() {
if (hasStringAttribute(def, "builder")) {
// If a custom builder is given then print that out.
auto builder = def.getValueAsString("builder");
if (!builder.empty())
os << builder << '\n';
}
auto numResults = op.getNumResults();
auto numOperands = op.getNumOperands();
bool hasVariadicOperand = op.hasVariadicOperand();
int numNonVariadicOperands = numOperands - int(hasVariadicOperand);
// Generate default builders that requires all result type, operands, and
// attributes as parameters.
// We generate three builders here:
// 1. one having a stand-alone parameter for each result type / operand /
// attribute, and
// 2. one having an aggregated parameter for all result types / operands /
// attributes, and
// 3. one having a stand-alone prameter for each operand and attribute,
// use the first operand's type as all result types
// to facilitate different call patterns.
// 1. Stand-alone parameters
emitStandaloneParamBuilder(/*isAllSameType=*/true);
// 2. Aggregated parameters
@ -336,6 +364,11 @@ void OpEmitter::emitBuilder() {
<< " result->addAttribute(pair.first, pair.second);\n"
<< " }\n";
}
// 3. Deduced result types
if (op.hasTrait("SameOperandsAndResultType"))
emitStandaloneParamBuilder(/*isAllSameType=*/false);
}
void OpEmitter::emitCanonicalizationPatterns() {

View File

@ -453,8 +453,13 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
// TODO: this is a hack to support various constant ops. We are assuming
// all of them have no operands and one attribute here. Figure out a better
// way to do this.
if (resultOp.getNumOperands() == 0 &&
resultOp.getNumNativeAttributes() == 1) {
bool isConstOp =
resultOp.getNumOperands() == 0 && resultOp.getNumNativeAttributes() == 1;
bool isSameValueType = resultOp.hasTrait("SameOperandsAndResultType");
bool isBroadcastable = resultOp.hasTrait("BroadcastableTwoOperandsOneResult");
if (isConstOp || isSameValueType || isBroadcastable) {
os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc", resultValue,
resultOp.getQualCppClassName());
} else {