forked from OSchip/llvm-project
Add builderCall to Type and add constant attr class.
With the builder to construct the type on the Type, the appropriate mlir::Type can be constructed where needed. Also add a constant attr class that has the attribute and value as members. PiperOrigin-RevId: 227564789
This commit is contained in:
parent
fa710c17f4
commit
3633becf8a
|
@ -24,11 +24,15 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Base class for all types.
|
||||
class Type;
|
||||
class Type {
|
||||
// The builder call to invoke (if specified) to construct the Type.
|
||||
code builderCall = ?;
|
||||
}
|
||||
|
||||
// Integer types.
|
||||
class I<int width> : Type {
|
||||
int bitwidth = width;
|
||||
let builderCall = "getIntegerType(" # bitwidth # ")";
|
||||
}
|
||||
def I1 : I<1>;
|
||||
def I32 : I<32>;
|
||||
|
@ -37,7 +41,9 @@ def I32 : I<32>;
|
|||
class F<int width> : Type {
|
||||
int bitwidth = width;
|
||||
}
|
||||
def F32 : F<32>;
|
||||
def F32 : F<32> {
|
||||
let builderCall = "getF32Type()";
|
||||
}
|
||||
|
||||
// Vector types.
|
||||
class Vector<Type t, list<int> dims> : Type {
|
||||
|
@ -108,10 +114,18 @@ class DerivedAttr<code ReturnType, code Body> : Attr<DerivedAttrBody> {
|
|||
code body = Body;
|
||||
}
|
||||
|
||||
// Derived attribute that returns a mlir::Type.
|
||||
class DerivedTypeAttr<code body> : DerivedAttr<"Type", body>;
|
||||
|
||||
// Represents a constant attribute of specific Attr type. The leaf class that
|
||||
// derives from this should additionally include a `value` member.
|
||||
class ConstantAttr<Attr attribute> {
|
||||
Attr attr = attribute;
|
||||
}
|
||||
|
||||
// The values for const F32 attributes are set as strings as floating point
|
||||
// values can't be provided directly in TableGen.
|
||||
class ConstF32Attr<string val> {
|
||||
Type type = F32;
|
||||
class ConstF32Attr<string val> : ConstantAttr<F32Attr> {
|
||||
string value = val;
|
||||
}
|
||||
|
||||
|
|
|
@ -85,45 +85,32 @@ private:
|
|||
} // end namespace
|
||||
|
||||
void Pattern::emitAttributeValue(Record *constAttr) {
|
||||
Record *type = constAttr->getValueAsDef("type");
|
||||
Record *attr = constAttr->getValueAsDef("attr");
|
||||
auto value = constAttr->getValue("value");
|
||||
Record *type = attr->getValueAsDef("type");
|
||||
auto storageType = attr->getValueAsString("storageType").trim();
|
||||
|
||||
// Construct the attribute based on `type`.
|
||||
// TODO(jpienaar): Generalize this to avoid hardcoding here.
|
||||
if (type->isSubClassOf("F")) {
|
||||
string mlirType;
|
||||
switch (type->getValueAsInt("bitwidth")) {
|
||||
case 32:
|
||||
mlirType = "Type::getF32(context)";
|
||||
break;
|
||||
default:
|
||||
PrintFatalError("unsupported floating point width");
|
||||
}
|
||||
// TODO(jpienaar): Verify the floating point constant here.
|
||||
os << formatv("FloatAttr::get({0}, {1})", mlirType,
|
||||
// For attributes stored as strings we do not need to query builder etc.
|
||||
if (storageType == "StringAttr") {
|
||||
os << formatv("rewriter.getStringAttr({0})",
|
||||
value->getValue()->getAsString());
|
||||
return;
|
||||
}
|
||||
|
||||
// Construct the attribute based on storage type and builder.
|
||||
if (auto b = type->getValue("builderCall")) {
|
||||
if (isa<UnsetInit>(b->getValue()))
|
||||
PrintFatalError(pattern->getLoc(),
|
||||
"no builder specified for " + type->getName());
|
||||
CodeInit *builder = cast<CodeInit>(b->getValue());
|
||||
// TODO(jpienaar): Verify the constants here
|
||||
os << formatv("{0}::get(rewriter.{1}, {2})", storageType,
|
||||
builder->getValue(),
|
||||
value->getValue()->getAsUnquotedString());
|
||||
return;
|
||||
}
|
||||
|
||||
// Fallback to the type of value.
|
||||
switch (value->getType()->getRecTyKind()) {
|
||||
case RecTy::IntRecTyKind:
|
||||
// TODO(jpienaar): This is using 64-bits for all the bitwidth of the
|
||||
// type could instead be queried. These are expected to be mostly used
|
||||
// for enums or constant indices and so no arithmetic operations are
|
||||
// expected on these.
|
||||
os << formatv("IntegerAttr::get(Type::getInteger(64, context), {0})",
|
||||
value->getValue()->getAsString());
|
||||
break;
|
||||
case RecTy::StringRecTyKind:
|
||||
os << formatv("StringAttr::get({0}, context)",
|
||||
value->getValue()->getAsString());
|
||||
break;
|
||||
default:
|
||||
PrintFatalError(pattern->getLoc(),
|
||||
Twine("unsupported/unimplemented value type for ") +
|
||||
value->getName());
|
||||
}
|
||||
PrintFatalError(pattern->getLoc(), "unable to emit attribute");
|
||||
}
|
||||
|
||||
void Pattern::collectBoundArguments(DagInit *tree) {
|
||||
|
@ -237,7 +224,6 @@ void Pattern::emit(StringRef rewriteName) {
|
|||
os << formatv(R"(
|
||||
void rewrite(OperationInst *op, std::unique_ptr<PatternState> state,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto* context = op->getContext(); (void)context;
|
||||
auto& s = *static_cast<MatchedState *>(state.get());
|
||||
rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
|
||||
resultOp.cppClassName());
|
||||
|
@ -273,7 +259,7 @@ void Pattern::emit(StringRef rewriteName) {
|
|||
(os << ",\n").indent(6);
|
||||
|
||||
// The argument in the result DAG pattern.
|
||||
std::string name = resultTree->getArgNameStr(i);
|
||||
auto name = resultOp.getArgName(i);
|
||||
auto defInit = dyn_cast<DefInit>(resultTree->getArg(i));
|
||||
auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
|
||||
if (!value)
|
||||
|
|
Loading…
Reference in New Issue