forked from OSchip/llvm-project
Switch rewriters for relu, relu6, placeholder_input, softmax to patterns.
Add a constant F32 attribute for use with softmax legalization. PiperOrigin-RevId: 227241643
This commit is contained in:
parent
8ef2552df7
commit
bbe3f4d9f5
|
@ -108,6 +108,13 @@ class DerivedAttr<code ReturnType, code Body> : Attr<DerivedAttrBody> {
|
|||
code body = Body;
|
||||
}
|
||||
|
||||
// The values for const F32 attributes are set as strings as floating point
|
||||
// values can't be provided directly in TableGen.
|
||||
class ConstF32Attr<string val> {
|
||||
Type type = F32;
|
||||
string value = val;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Op Properties
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -68,7 +68,7 @@ private:
|
|||
void emitMatcher(DagInit *tree);
|
||||
|
||||
// Emits the value of constant attribute to `os`.
|
||||
void emitAttributeValue(RecordVal *value);
|
||||
void emitAttributeValue(Record *constAttr);
|
||||
|
||||
// Collect bound arguments.
|
||||
void collectBoundArguments(DagInit *tree);
|
||||
|
@ -84,7 +84,28 @@ private:
|
|||
};
|
||||
} // end namespace
|
||||
|
||||
void Pattern::emitAttributeValue(RecordVal *value) {
|
||||
void Pattern::emitAttributeValue(Record *constAttr) {
|
||||
Record *type = constAttr->getValueAsDef("type");
|
||||
auto value = constAttr->getValue("value");
|
||||
|
||||
// Construct the attribute based on `type`.
|
||||
// TODO(jpienaar): Generalize this to avoid hardcoding here.
|
||||
if (type->isSubClassOf("F")) {
|
||||
string mlirType;
|
||||
switch (type->getValueAsInt("bitwidth")) {
|
||||
case 32:
|
||||
mlirType = "Type::getF32(context)";
|
||||
break;
|
||||
default:
|
||||
PrintFatalError("unsupported floating point width");
|
||||
}
|
||||
// TODO(jpienaar): Verify the floating point constant here.
|
||||
os << formatv("FloatAttr::get({0}, {1})", mlirType,
|
||||
value->getValue()->getAsUnquotedString());
|
||||
return;
|
||||
}
|
||||
|
||||
// Fallback to the type of value.
|
||||
switch (value->getType()->getRecTyKind()) {
|
||||
case RecTy::IntRecTyKind:
|
||||
// TODO(jpienaar): This is using 64-bits for all the bitwidth of the
|
||||
|
@ -260,7 +281,6 @@ void Pattern::emit(StringRef rewriteName) {
|
|||
Twine("attribute '") + name +
|
||||
"' needs to be constant initialized");
|
||||
|
||||
// TODO: verify that it is an arg.
|
||||
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
|
||||
auto argument = resultOp.getArg(i);
|
||||
if (!argument.is<mlir::Operator::Attribute *>())
|
||||
|
@ -269,7 +289,7 @@ void Pattern::emit(StringRef rewriteName) {
|
|||
|
||||
if (!name.empty())
|
||||
os << "/*" << name << "=*/";
|
||||
emitAttributeValue(value);
|
||||
emitAttributeValue(defInit->getDef());
|
||||
// TODO(jpienaar): verify types
|
||||
}
|
||||
os << "\n );\n }\n};\n";
|
||||
|
|
Loading…
Reference in New Issue