forked from OSchip/llvm-project
Add a trait to set the result type by attribute
Before this CL, the result type of the pattern match results need to be as same as the first operand type, operand broadcast type or a generic tensor type. This CL adds a new trait to set the result type by attribute. For example, the TFL_ConstOp can use this to set the output type to its value attribute. PiperOrigin-RevId: 240441249
This commit is contained in:
parent
9ffdc930c0
commit
c489f50e6f
|
@ -524,6 +524,11 @@ class NativeOpTrait<string prop> : OpTrait {
|
|||
string trait = prop;
|
||||
}
|
||||
|
||||
// Specify a trait to control op definition generator internals.
|
||||
class OpGenInternalTrait<string prop> : OpTrait {
|
||||
string trait = prop;
|
||||
}
|
||||
|
||||
// Specify a trait by way of a predicate on the operation.
|
||||
class PredOpTrait<string d, Pred p> : OpTrait {
|
||||
string desc = d;
|
||||
|
@ -545,6 +550,12 @@ def SameValueType : NativeOpTrait<"SameOperandsAndResultType">;
|
|||
// op is a terminator
|
||||
def Terminator : NativeOpTrait<"IsTerminator">;
|
||||
|
||||
// op result type is derived from the first attribute. If the attribute is an
|
||||
// subclass of `TypeAttrBase`, its value is used, otherwise, the type of the
|
||||
// attribute content is used.
|
||||
def FirstAttrDerivedResultType :
|
||||
OpGenInternalTrait<"FirstAttrDerivedResultType">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Op definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -53,9 +53,13 @@ public:
|
|||
explicit Attribute(const llvm::DefInit *init);
|
||||
|
||||
// Returns true if this attribute is a derived attribute (i.e., a subclass
|
||||
// of `DrivedAttr`).
|
||||
// of `DerivedAttr`).
|
||||
bool isDerivedAttr() const;
|
||||
|
||||
// Returns true if this attribute is a type attribute (i.e., a subclass
|
||||
// of `TypeAttrBase`).
|
||||
bool isTypeAttr() const;
|
||||
|
||||
// Returns true if this attribute has storage type set.
|
||||
bool hasStorageType() const;
|
||||
|
||||
|
|
|
@ -43,6 +43,8 @@ public:
|
|||
Native,
|
||||
// OpTrait corresponding to predicate on operation.
|
||||
Pred,
|
||||
// OpTrait controlling op definition generator internals.
|
||||
Internal
|
||||
};
|
||||
|
||||
explicit OpTrait(Kind kind, const llvm::Record *def);
|
||||
|
@ -79,6 +81,17 @@ public:
|
|||
static bool classof(const OpTrait *t) { return t->getKind() == Kind::Pred; }
|
||||
};
|
||||
|
||||
// OpTrait controlling op definition generator internals.
|
||||
class InternalOpTrait : public OpTrait {
|
||||
public:
|
||||
// Returns the trait controlling op definition generator internals.
|
||||
StringRef getTrait() const;
|
||||
|
||||
static bool classof(const OpTrait *t) {
|
||||
return t->getKind() == Kind::Internal;
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace tblgen
|
||||
} // end namespace mlir
|
||||
|
||||
|
|
|
@ -55,6 +55,10 @@ bool tblgen::Attribute::isDerivedAttr() const {
|
|||
return def->isSubClassOf("DerivedAttr");
|
||||
}
|
||||
|
||||
bool tblgen::Attribute::isTypeAttr() const {
|
||||
return def->isSubClassOf("TypeAttrBase");
|
||||
}
|
||||
|
||||
bool tblgen::Attribute::hasStorageType() const {
|
||||
const auto *init = def->getValueInit("storageType");
|
||||
return !getValueAsString(init).empty();
|
||||
|
|
|
@ -32,6 +32,8 @@ mlir::tblgen::OpTrait mlir::tblgen::OpTrait::create(const llvm::Init *init) {
|
|||
auto def = cast<llvm::DefInit>(init)->getDef();
|
||||
if (def->isSubClassOf("PredOpTrait"))
|
||||
return OpTrait(Kind::Pred, def);
|
||||
if (def->isSubClassOf("OpGenInternalTrait"))
|
||||
return OpTrait(Kind::Internal, def);
|
||||
assert(def->isSubClassOf("NativeOpTrait"));
|
||||
return OpTrait(Kind::Native, def);
|
||||
}
|
||||
|
@ -43,6 +45,10 @@ llvm::StringRef mlir::tblgen::NativeOpTrait::getTrait() const {
|
|||
return def->getValueAsString("trait");
|
||||
}
|
||||
|
||||
llvm::StringRef mlir::tblgen::InternalOpTrait::getTrait() const {
|
||||
return def->getValueAsString("trait");
|
||||
}
|
||||
|
||||
std::string mlir::tblgen::PredOpTrait::getPredTemplate() const {
|
||||
auto pred = tblgen::Pred(def->getValueInit("pred"));
|
||||
return pred.getCondition();
|
||||
|
|
|
@ -101,9 +101,13 @@ StringRef tblgen::Operator::getArgName(int index) const {
|
|||
|
||||
bool tblgen::Operator::hasTrait(StringRef trait) const {
|
||||
for (auto t : getTraits()) {
|
||||
if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&t))
|
||||
if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&t)) {
|
||||
if (opTrait->getTrait() == trait)
|
||||
return true;
|
||||
} else if (auto opTrait = dyn_cast<tblgen::InternalOpTrait>(&t)) {
|
||||
if (opTrait->getTrait() == trait)
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -31,6 +31,25 @@ def ThreeResultOp : Op<"three_result_op", []> {
|
|||
// CHECK: void ThreeResultOp::build(Builder *builder, OperationState *result, Type x, Type resultType1, Type z)
|
||||
// CHECK: result->addTypes({x, resultType1, z});
|
||||
|
||||
def IntegerTypeAttr : TypeAttrBase<"IntegerType", "Integer type attribute">;
|
||||
def TypeAttrResultTypeOp : Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> {
|
||||
let arguments = (ins I32:$x, IntegerTypeAttr:$attr, F32Attr:$f32);
|
||||
let results = (outs Tensor:$y);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: TypeAttrResultTypeOp definitions
|
||||
// CHECK: void TypeAttrResultTypeOp::build(Builder *builder, OperationState *result, Value *x, TypeAttr attr, FloatAttr f32)
|
||||
// CHECK: result->addTypes({attr.getValue()});
|
||||
|
||||
def ValueAttrResultTypeOp : Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> {
|
||||
let arguments = (ins I32:$x, F32Attr:$attr);
|
||||
let results = (outs Tensor:$y);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ValueAttrResultTypeOp definitions
|
||||
// CHECK: void ValueAttrResultTypeOp::build(Builder *builder, OperationState *result, Value *x, FloatAttr attr)
|
||||
// CHECK: result->addTypes({attr.getType()});
|
||||
|
||||
def VariadicResultOp : Op<"variadic_op", []> {
|
||||
let results = (outs I32:$x, Variadic<I32>:$y);
|
||||
}
|
||||
|
|
|
@ -379,8 +379,10 @@ private:
|
|||
|
||||
// Generates the build() method that takes each result-type/operand/attribute
|
||||
// as a stand-alone parameter. Using the first operand's type as all result
|
||||
// types if `isAllSameType` is true.
|
||||
void genStandaloneParamBuilder(bool isAllSameType);
|
||||
// types if `useOperandType` is true. Using the first attribute's type as all
|
||||
// result types if `useAttrType` true. Don't set `useOperandType` and
|
||||
// `useAttrType` at the same time.
|
||||
void genStandaloneParamBuilder(bool useOperandType, bool useAttrType);
|
||||
|
||||
void genOpNameGetter();
|
||||
|
||||
|
@ -501,7 +503,14 @@ void OpEmitter::genNamedResultGetters() {
|
|||
}
|
||||
}
|
||||
|
||||
void OpEmitter::genStandaloneParamBuilder(bool isAllSameType) {
|
||||
void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
|
||||
bool useAttrType) {
|
||||
if (useOperandType && useAttrType) {
|
||||
PrintFatalError(def.getLoc(),
|
||||
"Op definition has both 'SameOperandsAndResultType' and "
|
||||
"'FirstAttrIsResultType' trait specified.");
|
||||
}
|
||||
|
||||
auto numResults = op.getNumResults();
|
||||
llvm::SmallVector<std::string, 4> resultNames;
|
||||
resultNames.reserve(numResults);
|
||||
|
@ -509,7 +518,7 @@ void OpEmitter::genStandaloneParamBuilder(bool isAllSameType) {
|
|||
std::string paramList = "Builder *builder, OperationState *result";
|
||||
|
||||
// Emit parameters for all return types
|
||||
if (!isAllSameType) {
|
||||
if (!useOperandType && !useAttrType) {
|
||||
for (unsigned i = 0; i != numResults; ++i) {
|
||||
std::string resultName = op.getResultName(i);
|
||||
if (resultName.empty())
|
||||
|
@ -556,7 +565,7 @@ void OpEmitter::genStandaloneParamBuilder(bool isAllSameType) {
|
|||
|
||||
// Push all result types to the result
|
||||
if (numResults > 0) {
|
||||
if (!isAllSameType) {
|
||||
if (!useOperandType && !useAttrType) {
|
||||
bool hasVariadicResult = op.hasVariadicResult();
|
||||
int numNonVariadicResults =
|
||||
numResults - static_cast<int>(hasVariadicResult);
|
||||
|
@ -573,10 +582,20 @@ void OpEmitter::genStandaloneParamBuilder(bool isAllSameType) {
|
|||
method.body() << " result->addTypes(" << resultNames.back() << ");\n";
|
||||
}
|
||||
} else {
|
||||
auto resultType = formatv("{0}->getType()", getArgumentName(op, 0)).str();
|
||||
std::string resultType;
|
||||
if (useAttrType) {
|
||||
const auto &namedAttr = op.getAttribute(0);
|
||||
if (namedAttr.attr.isTypeAttr()) {
|
||||
resultType = formatv("{0}.getValue()", namedAttr.name);
|
||||
} else {
|
||||
resultType = formatv("{0}.getType()", namedAttr.name);
|
||||
}
|
||||
} else {
|
||||
resultType = formatv("{0}->getType()", getArgumentName(op, 0)).str();
|
||||
}
|
||||
method.body() << " result->addTypes({" << resultType;
|
||||
for (unsigned i = 1; i != numResults; ++i)
|
||||
method.body() << resultType;
|
||||
method.body() << ", " << resultType;
|
||||
method.body() << "});\n\n";
|
||||
}
|
||||
}
|
||||
|
@ -657,7 +676,7 @@ void OpEmitter::genBuilder() {
|
|||
|
||||
// 1. Stand-alone parameters
|
||||
|
||||
genStandaloneParamBuilder(/*isAllSameType=*/false);
|
||||
genStandaloneParamBuilder(/*useOperandType=*/false, /*useAttrType=*/false);
|
||||
|
||||
// 2. Aggregated parameters
|
||||
|
||||
|
@ -695,8 +714,10 @@ void OpEmitter::genBuilder() {
|
|||
|
||||
// 3. Deduced result types
|
||||
|
||||
if (!op.hasVariadicResult() && op.hasTrait("SameOperandsAndResultType"))
|
||||
genStandaloneParamBuilder(/*isAllSameType=*/true);
|
||||
bool useOperandType = op.hasTrait("SameOperandsAndResultType");
|
||||
bool useAttrType = op.hasTrait("FirstAttrDerivedResultType");
|
||||
if (!op.hasVariadicResult() && (useOperandType || useAttrType))
|
||||
genStandaloneParamBuilder(useOperandType, useAttrType);
|
||||
}
|
||||
|
||||
void OpEmitter::genCanonicalizerDecls() {
|
||||
|
|
|
@ -550,8 +550,9 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
|
|||
|
||||
bool isSameValueType = resultOp.hasTrait("SameOperandsAndResultType");
|
||||
bool isBroadcastable = resultOp.hasTrait("BroadcastableTwoOperandsOneResult");
|
||||
bool useFirstAttr = resultOp.hasTrait("FirstAttrDerivedResultType");
|
||||
|
||||
if (isConstOp || isSameValueType || isBroadcastable) {
|
||||
if (isConstOp || isSameValueType || isBroadcastable || useFirstAttr) {
|
||||
os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc", resultValue,
|
||||
resultOp.getQualCppClassName());
|
||||
} else {
|
||||
|
|
Loading…
Reference in New Issue