Add a trait to set the result type by attribute

Before this CL, the result type of the pattern match results need to be as same
as the first operand type, operand broadcast type or a generic tensor type.
This CL adds a new trait to set the result type by attribute. For example, the
TFL_ConstOp can use this to set the output type to its value attribute.

PiperOrigin-RevId: 240441249
This commit is contained in:
Feng Liu 2019-03-26 15:31:15 -07:00 committed by jpienaar
parent 9ffdc930c0
commit c489f50e6f
9 changed files with 96 additions and 13 deletions

View File

@ -524,6 +524,11 @@ class NativeOpTrait<string prop> : OpTrait {
string trait = prop;
}
// Specify a trait to control op definition generator internals.
class OpGenInternalTrait<string prop> : OpTrait {
string trait = prop;
}
// Specify a trait by way of a predicate on the operation.
class PredOpTrait<string d, Pred p> : OpTrait {
string desc = d;
@ -545,6 +550,12 @@ def SameValueType : NativeOpTrait<"SameOperandsAndResultType">;
// op is a terminator
def Terminator : NativeOpTrait<"IsTerminator">;
// op result type is derived from the first attribute. If the attribute is an
// subclass of `TypeAttrBase`, its value is used, otherwise, the type of the
// attribute content is used.
def FirstAttrDerivedResultType :
OpGenInternalTrait<"FirstAttrDerivedResultType">;
//===----------------------------------------------------------------------===//
// Op definitions
//===----------------------------------------------------------------------===//

View File

@ -53,9 +53,13 @@ public:
explicit Attribute(const llvm::DefInit *init);
// Returns true if this attribute is a derived attribute (i.e., a subclass
// of `DrivedAttr`).
// of `DerivedAttr`).
bool isDerivedAttr() const;
// Returns true if this attribute is a type attribute (i.e., a subclass
// of `TypeAttrBase`).
bool isTypeAttr() const;
// Returns true if this attribute has storage type set.
bool hasStorageType() const;

View File

@ -43,6 +43,8 @@ public:
Native,
// OpTrait corresponding to predicate on operation.
Pred,
// OpTrait controlling op definition generator internals.
Internal
};
explicit OpTrait(Kind kind, const llvm::Record *def);
@ -79,6 +81,17 @@ public:
static bool classof(const OpTrait *t) { return t->getKind() == Kind::Pred; }
};
// OpTrait controlling op definition generator internals.
class InternalOpTrait : public OpTrait {
public:
// Returns the trait controlling op definition generator internals.
StringRef getTrait() const;
static bool classof(const OpTrait *t) {
return t->getKind() == Kind::Internal;
}
};
} // end namespace tblgen
} // end namespace mlir

View File

@ -55,6 +55,10 @@ bool tblgen::Attribute::isDerivedAttr() const {
return def->isSubClassOf("DerivedAttr");
}
bool tblgen::Attribute::isTypeAttr() const {
return def->isSubClassOf("TypeAttrBase");
}
bool tblgen::Attribute::hasStorageType() const {
const auto *init = def->getValueInit("storageType");
return !getValueAsString(init).empty();

View File

@ -32,6 +32,8 @@ mlir::tblgen::OpTrait mlir::tblgen::OpTrait::create(const llvm::Init *init) {
auto def = cast<llvm::DefInit>(init)->getDef();
if (def->isSubClassOf("PredOpTrait"))
return OpTrait(Kind::Pred, def);
if (def->isSubClassOf("OpGenInternalTrait"))
return OpTrait(Kind::Internal, def);
assert(def->isSubClassOf("NativeOpTrait"));
return OpTrait(Kind::Native, def);
}
@ -43,6 +45,10 @@ llvm::StringRef mlir::tblgen::NativeOpTrait::getTrait() const {
return def->getValueAsString("trait");
}
llvm::StringRef mlir::tblgen::InternalOpTrait::getTrait() const {
return def->getValueAsString("trait");
}
std::string mlir::tblgen::PredOpTrait::getPredTemplate() const {
auto pred = tblgen::Pred(def->getValueInit("pred"));
return pred.getCondition();

View File

@ -101,9 +101,13 @@ StringRef tblgen::Operator::getArgName(int index) const {
bool tblgen::Operator::hasTrait(StringRef trait) const {
for (auto t : getTraits()) {
if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&t))
if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&t)) {
if (opTrait->getTrait() == trait)
return true;
} else if (auto opTrait = dyn_cast<tblgen::InternalOpTrait>(&t)) {
if (opTrait->getTrait() == trait)
return true;
}
}
return false;
}

View File

@ -31,6 +31,25 @@ def ThreeResultOp : Op<"three_result_op", []> {
// CHECK: void ThreeResultOp::build(Builder *builder, OperationState *result, Type x, Type resultType1, Type z)
// CHECK: result->addTypes({x, resultType1, z});
def IntegerTypeAttr : TypeAttrBase<"IntegerType", "Integer type attribute">;
def TypeAttrResultTypeOp : Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> {
let arguments = (ins I32:$x, IntegerTypeAttr:$attr, F32Attr:$f32);
let results = (outs Tensor:$y);
}
// CHECK-LABEL: TypeAttrResultTypeOp definitions
// CHECK: void TypeAttrResultTypeOp::build(Builder *builder, OperationState *result, Value *x, TypeAttr attr, FloatAttr f32)
// CHECK: result->addTypes({attr.getValue()});
def ValueAttrResultTypeOp : Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> {
let arguments = (ins I32:$x, F32Attr:$attr);
let results = (outs Tensor:$y);
}
// CHECK-LABEL: ValueAttrResultTypeOp definitions
// CHECK: void ValueAttrResultTypeOp::build(Builder *builder, OperationState *result, Value *x, FloatAttr attr)
// CHECK: result->addTypes({attr.getType()});
def VariadicResultOp : Op<"variadic_op", []> {
let results = (outs I32:$x, Variadic<I32>:$y);
}

View File

@ -379,8 +379,10 @@ private:
// Generates the build() method that takes each result-type/operand/attribute
// as a stand-alone parameter. Using the first operand's type as all result
// types if `isAllSameType` is true.
void genStandaloneParamBuilder(bool isAllSameType);
// types if `useOperandType` is true. Using the first attribute's type as all
// result types if `useAttrType` true. Don't set `useOperandType` and
// `useAttrType` at the same time.
void genStandaloneParamBuilder(bool useOperandType, bool useAttrType);
void genOpNameGetter();
@ -501,7 +503,14 @@ void OpEmitter::genNamedResultGetters() {
}
}
void OpEmitter::genStandaloneParamBuilder(bool isAllSameType) {
void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
bool useAttrType) {
if (useOperandType && useAttrType) {
PrintFatalError(def.getLoc(),
"Op definition has both 'SameOperandsAndResultType' and "
"'FirstAttrIsResultType' trait specified.");
}
auto numResults = op.getNumResults();
llvm::SmallVector<std::string, 4> resultNames;
resultNames.reserve(numResults);
@ -509,7 +518,7 @@ void OpEmitter::genStandaloneParamBuilder(bool isAllSameType) {
std::string paramList = "Builder *builder, OperationState *result";
// Emit parameters for all return types
if (!isAllSameType) {
if (!useOperandType && !useAttrType) {
for (unsigned i = 0; i != numResults; ++i) {
std::string resultName = op.getResultName(i);
if (resultName.empty())
@ -556,7 +565,7 @@ void OpEmitter::genStandaloneParamBuilder(bool isAllSameType) {
// Push all result types to the result
if (numResults > 0) {
if (!isAllSameType) {
if (!useOperandType && !useAttrType) {
bool hasVariadicResult = op.hasVariadicResult();
int numNonVariadicResults =
numResults - static_cast<int>(hasVariadicResult);
@ -573,10 +582,20 @@ void OpEmitter::genStandaloneParamBuilder(bool isAllSameType) {
method.body() << " result->addTypes(" << resultNames.back() << ");\n";
}
} else {
auto resultType = formatv("{0}->getType()", getArgumentName(op, 0)).str();
std::string resultType;
if (useAttrType) {
const auto &namedAttr = op.getAttribute(0);
if (namedAttr.attr.isTypeAttr()) {
resultType = formatv("{0}.getValue()", namedAttr.name);
} else {
resultType = formatv("{0}.getType()", namedAttr.name);
}
} else {
resultType = formatv("{0}->getType()", getArgumentName(op, 0)).str();
}
method.body() << " result->addTypes({" << resultType;
for (unsigned i = 1; i != numResults; ++i)
method.body() << resultType;
method.body() << ", " << resultType;
method.body() << "});\n\n";
}
}
@ -657,7 +676,7 @@ void OpEmitter::genBuilder() {
// 1. Stand-alone parameters
genStandaloneParamBuilder(/*isAllSameType=*/false);
genStandaloneParamBuilder(/*useOperandType=*/false, /*useAttrType=*/false);
// 2. Aggregated parameters
@ -695,8 +714,10 @@ void OpEmitter::genBuilder() {
// 3. Deduced result types
if (!op.hasVariadicResult() && op.hasTrait("SameOperandsAndResultType"))
genStandaloneParamBuilder(/*isAllSameType=*/true);
bool useOperandType = op.hasTrait("SameOperandsAndResultType");
bool useAttrType = op.hasTrait("FirstAttrDerivedResultType");
if (!op.hasVariadicResult() && (useOperandType || useAttrType))
genStandaloneParamBuilder(useOperandType, useAttrType);
}
void OpEmitter::genCanonicalizerDecls() {

View File

@ -550,8 +550,9 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
bool isSameValueType = resultOp.hasTrait("SameOperandsAndResultType");
bool isBroadcastable = resultOp.hasTrait("BroadcastableTwoOperandsOneResult");
bool useFirstAttr = resultOp.hasTrait("FirstAttrDerivedResultType");
if (isConstOp || isSameValueType || isBroadcastable) {
if (isConstOp || isSameValueType || isBroadcastable || useFirstAttr) {
os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc", resultValue,
resultOp.getQualCppClassName());
} else {