Switch rewriters for relu, relu6, placeholder_input, softmax to patterns.

Add a constant F32 attribute for use with softmax legalization.

PiperOrigin-RevId: 227241643
This commit is contained in:
Jacques Pienaar 2018-12-29 14:34:06 -08:00 committed by jpienaar
parent 8ef2552df7
commit bbe3f4d9f5
2 changed files with 31 additions and 4 deletions

View File

@ -108,6 +108,13 @@ class DerivedAttr<code ReturnType, code Body> : Attr<DerivedAttrBody> {
code body = Body;
}
// The values for const F32 attributes are set as strings as floating point
// values can't be provided directly in TableGen.
class ConstF32Attr<string val> {
Type type = F32;
string value = val;
}
//===----------------------------------------------------------------------===//
// Op Properties
//===----------------------------------------------------------------------===//

View File

@ -68,7 +68,7 @@ private:
void emitMatcher(DagInit *tree);
// Emits the value of constant attribute to `os`.
void emitAttributeValue(RecordVal *value);
void emitAttributeValue(Record *constAttr);
// Collect bound arguments.
void collectBoundArguments(DagInit *tree);
@ -84,7 +84,28 @@ private:
};
} // end namespace
void Pattern::emitAttributeValue(RecordVal *value) {
void Pattern::emitAttributeValue(Record *constAttr) {
Record *type = constAttr->getValueAsDef("type");
auto value = constAttr->getValue("value");
// Construct the attribute based on `type`.
// TODO(jpienaar): Generalize this to avoid hardcoding here.
if (type->isSubClassOf("F")) {
string mlirType;
switch (type->getValueAsInt("bitwidth")) {
case 32:
mlirType = "Type::getF32(context)";
break;
default:
PrintFatalError("unsupported floating point width");
}
// TODO(jpienaar): Verify the floating point constant here.
os << formatv("FloatAttr::get({0}, {1})", mlirType,
value->getValue()->getAsUnquotedString());
return;
}
// Fallback to the type of value.
switch (value->getType()->getRecTyKind()) {
case RecTy::IntRecTyKind:
// TODO(jpienaar): This is using 64-bits for all the bitwidth of the
@ -260,7 +281,6 @@ void Pattern::emit(StringRef rewriteName) {
Twine("attribute '") + name +
"' needs to be constant initialized");
// TODO: verify that it is an arg.
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
auto argument = resultOp.getArg(i);
if (!argument.is<mlir::Operator::Attribute *>())
@ -269,7 +289,7 @@ void Pattern::emit(StringRef rewriteName) {
if (!name.empty())
os << "/*" << name << "=*/";
emitAttributeValue(value);
emitAttributeValue(defInit->getDef());
// TODO(jpienaar): verify types
}
os << "\n );\n }\n};\n";