forked from OSchip/llvm-project
Add builderCall to Type and add constant attr class.
With the builder to construct the type on the Type, the appropriate mlir::Type can be constructed where needed. Also add a constant attr class that has the attribute and value as members. PiperOrigin-RevId: 227564789
This commit is contained in:
parent
fa710c17f4
commit
3633becf8a
|
@ -24,11 +24,15 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// Base class for all types.
|
// Base class for all types.
|
||||||
class Type;
|
class Type {
|
||||||
|
// The builder call to invoke (if specified) to construct the Type.
|
||||||
|
code builderCall = ?;
|
||||||
|
}
|
||||||
|
|
||||||
// Integer types.
|
// Integer types.
|
||||||
class I<int width> : Type {
|
class I<int width> : Type {
|
||||||
int bitwidth = width;
|
int bitwidth = width;
|
||||||
|
let builderCall = "getIntegerType(" # bitwidth # ")";
|
||||||
}
|
}
|
||||||
def I1 : I<1>;
|
def I1 : I<1>;
|
||||||
def I32 : I<32>;
|
def I32 : I<32>;
|
||||||
|
@ -37,7 +41,9 @@ def I32 : I<32>;
|
||||||
class F<int width> : Type {
|
class F<int width> : Type {
|
||||||
int bitwidth = width;
|
int bitwidth = width;
|
||||||
}
|
}
|
||||||
def F32 : F<32>;
|
def F32 : F<32> {
|
||||||
|
let builderCall = "getF32Type()";
|
||||||
|
}
|
||||||
|
|
||||||
// Vector types.
|
// Vector types.
|
||||||
class Vector<Type t, list<int> dims> : Type {
|
class Vector<Type t, list<int> dims> : Type {
|
||||||
|
@ -108,10 +114,18 @@ class DerivedAttr<code ReturnType, code Body> : Attr<DerivedAttrBody> {
|
||||||
code body = Body;
|
code body = Body;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Derived attribute that returns a mlir::Type.
|
||||||
|
class DerivedTypeAttr<code body> : DerivedAttr<"Type", body>;
|
||||||
|
|
||||||
|
// Represents a constant attribute of specific Attr type. The leaf class that
|
||||||
|
// derives from this should additionally include a `value` member.
|
||||||
|
class ConstantAttr<Attr attribute> {
|
||||||
|
Attr attr = attribute;
|
||||||
|
}
|
||||||
|
|
||||||
// The values for const F32 attributes are set as strings as floating point
|
// The values for const F32 attributes are set as strings as floating point
|
||||||
// values can't be provided directly in TableGen.
|
// values can't be provided directly in TableGen.
|
||||||
class ConstF32Attr<string val> {
|
class ConstF32Attr<string val> : ConstantAttr<F32Attr> {
|
||||||
Type type = F32;
|
|
||||||
string value = val;
|
string value = val;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -85,45 +85,32 @@ private:
|
||||||
} // end namespace
|
} // end namespace
|
||||||
|
|
||||||
void Pattern::emitAttributeValue(Record *constAttr) {
|
void Pattern::emitAttributeValue(Record *constAttr) {
|
||||||
Record *type = constAttr->getValueAsDef("type");
|
Record *attr = constAttr->getValueAsDef("attr");
|
||||||
auto value = constAttr->getValue("value");
|
auto value = constAttr->getValue("value");
|
||||||
|
Record *type = attr->getValueAsDef("type");
|
||||||
|
auto storageType = attr->getValueAsString("storageType").trim();
|
||||||
|
|
||||||
// Construct the attribute based on `type`.
|
// For attributes stored as strings we do not need to query builder etc.
|
||||||
// TODO(jpienaar): Generalize this to avoid hardcoding here.
|
if (storageType == "StringAttr") {
|
||||||
if (type->isSubClassOf("F")) {
|
os << formatv("rewriter.getStringAttr({0})",
|
||||||
string mlirType;
|
value->getValue()->getAsString());
|
||||||
switch (type->getValueAsInt("bitwidth")) {
|
return;
|
||||||
case 32:
|
}
|
||||||
mlirType = "Type::getF32(context)";
|
|
||||||
break;
|
// Construct the attribute based on storage type and builder.
|
||||||
default:
|
if (auto b = type->getValue("builderCall")) {
|
||||||
PrintFatalError("unsupported floating point width");
|
if (isa<UnsetInit>(b->getValue()))
|
||||||
}
|
PrintFatalError(pattern->getLoc(),
|
||||||
// TODO(jpienaar): Verify the floating point constant here.
|
"no builder specified for " + type->getName());
|
||||||
os << formatv("FloatAttr::get({0}, {1})", mlirType,
|
CodeInit *builder = cast<CodeInit>(b->getValue());
|
||||||
|
// TODO(jpienaar): Verify the constants here
|
||||||
|
os << formatv("{0}::get(rewriter.{1}, {2})", storageType,
|
||||||
|
builder->getValue(),
|
||||||
value->getValue()->getAsUnquotedString());
|
value->getValue()->getAsUnquotedString());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback to the type of value.
|
PrintFatalError(pattern->getLoc(), "unable to emit attribute");
|
||||||
switch (value->getType()->getRecTyKind()) {
|
|
||||||
case RecTy::IntRecTyKind:
|
|
||||||
// TODO(jpienaar): This is using 64-bits for all the bitwidth of the
|
|
||||||
// type could instead be queried. These are expected to be mostly used
|
|
||||||
// for enums or constant indices and so no arithmetic operations are
|
|
||||||
// expected on these.
|
|
||||||
os << formatv("IntegerAttr::get(Type::getInteger(64, context), {0})",
|
|
||||||
value->getValue()->getAsString());
|
|
||||||
break;
|
|
||||||
case RecTy::StringRecTyKind:
|
|
||||||
os << formatv("StringAttr::get({0}, context)",
|
|
||||||
value->getValue()->getAsString());
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
PrintFatalError(pattern->getLoc(),
|
|
||||||
Twine("unsupported/unimplemented value type for ") +
|
|
||||||
value->getName());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Pattern::collectBoundArguments(DagInit *tree) {
|
void Pattern::collectBoundArguments(DagInit *tree) {
|
||||||
|
@ -237,7 +224,6 @@ void Pattern::emit(StringRef rewriteName) {
|
||||||
os << formatv(R"(
|
os << formatv(R"(
|
||||||
void rewrite(OperationInst *op, std::unique_ptr<PatternState> state,
|
void rewrite(OperationInst *op, std::unique_ptr<PatternState> state,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto* context = op->getContext(); (void)context;
|
|
||||||
auto& s = *static_cast<MatchedState *>(state.get());
|
auto& s = *static_cast<MatchedState *>(state.get());
|
||||||
rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
|
rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
|
||||||
resultOp.cppClassName());
|
resultOp.cppClassName());
|
||||||
|
@ -273,7 +259,7 @@ void Pattern::emit(StringRef rewriteName) {
|
||||||
(os << ",\n").indent(6);
|
(os << ",\n").indent(6);
|
||||||
|
|
||||||
// The argument in the result DAG pattern.
|
// The argument in the result DAG pattern.
|
||||||
std::string name = resultTree->getArgNameStr(i);
|
auto name = resultOp.getArgName(i);
|
||||||
auto defInit = dyn_cast<DefInit>(resultTree->getArg(i));
|
auto defInit = dyn_cast<DefInit>(resultTree->getArg(i));
|
||||||
auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
|
auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
|
||||||
if (!value)
|
if (!value)
|
||||||
|
|
Loading…
Reference in New Issue