Add builderCall to Type and add constant attr class.

With the builder to construct the type on the Type, the appropriate mlir::Type can be constructed where needed. Also add a constant attr class that has the attribute and value as members.

PiperOrigin-RevId: 227564789
This commit is contained in:
Jacques Pienaar 2019-01-02 12:43:52 -08:00 committed by jpienaar
parent fa710c17f4
commit 3633becf8a
2 changed files with 39 additions and 39 deletions

View File

@ -24,11 +24,15 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Base class for all types. // Base class for all types.
class Type; class Type {
// The builder call to invoke (if specified) to construct the Type.
code builderCall = ?;
}
// Integer types. // Integer types.
class I<int width> : Type { class I<int width> : Type {
int bitwidth = width; int bitwidth = width;
let builderCall = "getIntegerType(" # bitwidth # ")";
} }
def I1 : I<1>; def I1 : I<1>;
def I32 : I<32>; def I32 : I<32>;
@ -37,7 +41,9 @@ def I32 : I<32>;
class F<int width> : Type { class F<int width> : Type {
int bitwidth = width; int bitwidth = width;
} }
def F32 : F<32>; def F32 : F<32> {
let builderCall = "getF32Type()";
}
// Vector types. // Vector types.
class Vector<Type t, list<int> dims> : Type { class Vector<Type t, list<int> dims> : Type {
@ -108,10 +114,18 @@ class DerivedAttr<code ReturnType, code Body> : Attr<DerivedAttrBody> {
code body = Body; code body = Body;
} }
// Derived attribute that returns a mlir::Type.
class DerivedTypeAttr<code body> : DerivedAttr<"Type", body>;
// Represents a constant attribute of specific Attr type. The leaf class that
// derives from this should additionally include a `value` member.
class ConstantAttr<Attr attribute> {
Attr attr = attribute;
}
// The values for const F32 attributes are set as strings as floating point // The values for const F32 attributes are set as strings as floating point
// values can't be provided directly in TableGen. // values can't be provided directly in TableGen.
class ConstF32Attr<string val> { class ConstF32Attr<string val> : ConstantAttr<F32Attr> {
Type type = F32;
string value = val; string value = val;
} }

View File

@ -85,45 +85,32 @@ private:
} // end namespace } // end namespace
void Pattern::emitAttributeValue(Record *constAttr) { void Pattern::emitAttributeValue(Record *constAttr) {
Record *type = constAttr->getValueAsDef("type"); Record *attr = constAttr->getValueAsDef("attr");
auto value = constAttr->getValue("value"); auto value = constAttr->getValue("value");
Record *type = attr->getValueAsDef("type");
auto storageType = attr->getValueAsString("storageType").trim();
// Construct the attribute based on `type`. // For attributes stored as strings we do not need to query builder etc.
// TODO(jpienaar): Generalize this to avoid hardcoding here. if (storageType == "StringAttr") {
if (type->isSubClassOf("F")) { os << formatv("rewriter.getStringAttr({0})",
string mlirType; value->getValue()->getAsString());
switch (type->getValueAsInt("bitwidth")) { return;
case 32: }
mlirType = "Type::getF32(context)";
break; // Construct the attribute based on storage type and builder.
default: if (auto b = type->getValue("builderCall")) {
PrintFatalError("unsupported floating point width"); if (isa<UnsetInit>(b->getValue()))
} PrintFatalError(pattern->getLoc(),
// TODO(jpienaar): Verify the floating point constant here. "no builder specified for " + type->getName());
os << formatv("FloatAttr::get({0}, {1})", mlirType, CodeInit *builder = cast<CodeInit>(b->getValue());
// TODO(jpienaar): Verify the constants here
os << formatv("{0}::get(rewriter.{1}, {2})", storageType,
builder->getValue(),
value->getValue()->getAsUnquotedString()); value->getValue()->getAsUnquotedString());
return; return;
} }
// Fallback to the type of value. PrintFatalError(pattern->getLoc(), "unable to emit attribute");
switch (value->getType()->getRecTyKind()) {
case RecTy::IntRecTyKind:
// TODO(jpienaar): This is using 64-bits for all the bitwidth of the
// type could instead be queried. These are expected to be mostly used
// for enums or constant indices and so no arithmetic operations are
// expected on these.
os << formatv("IntegerAttr::get(Type::getInteger(64, context), {0})",
value->getValue()->getAsString());
break;
case RecTy::StringRecTyKind:
os << formatv("StringAttr::get({0}, context)",
value->getValue()->getAsString());
break;
default:
PrintFatalError(pattern->getLoc(),
Twine("unsupported/unimplemented value type for ") +
value->getName());
}
} }
void Pattern::collectBoundArguments(DagInit *tree) { void Pattern::collectBoundArguments(DagInit *tree) {
@ -237,7 +224,6 @@ void Pattern::emit(StringRef rewriteName) {
os << formatv(R"( os << formatv(R"(
void rewrite(OperationInst *op, std::unique_ptr<PatternState> state, void rewrite(OperationInst *op, std::unique_ptr<PatternState> state,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
auto* context = op->getContext(); (void)context;
auto& s = *static_cast<MatchedState *>(state.get()); auto& s = *static_cast<MatchedState *>(state.get());
rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())", rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
resultOp.cppClassName()); resultOp.cppClassName());
@ -273,7 +259,7 @@ void Pattern::emit(StringRef rewriteName) {
(os << ",\n").indent(6); (os << ",\n").indent(6);
// The argument in the result DAG pattern. // The argument in the result DAG pattern.
std::string name = resultTree->getArgNameStr(i); auto name = resultOp.getArgName(i);
auto defInit = dyn_cast<DefInit>(resultTree->getArg(i)); auto defInit = dyn_cast<DefInit>(resultTree->getArg(i));
auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr; auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
if (!value) if (!value)