forked from OSchip/llvm-project
Support OptionalAttr inside a StructAttr
Differential revision: https://reviews.llvm.org/D74768
This commit is contained in:
parent
33aa5dfe9c
commit
066a76a234
|
@ -135,8 +135,18 @@ static void emitFactoryDef(llvm::StringRef structName,
|
|||
fields.emplace_back({0}_id, {0});
|
||||
)";
|
||||
|
||||
const char *getFieldInfoOptional = R"(
|
||||
if ({0}) {
|
||||
auto {0}_id = mlir::Identifier::get("{0}", context);
|
||||
fields.emplace_back({0}_id, {0});
|
||||
}
|
||||
)";
|
||||
|
||||
for (auto field : fields) {
|
||||
os << llvm::formatv(getFieldInfo, field.getName());
|
||||
if (field.getType().isOptional())
|
||||
os << llvm::formatv(getFieldInfoOptional, field.getName());
|
||||
else
|
||||
os << llvm::formatv(getFieldInfo, field.getName());
|
||||
}
|
||||
|
||||
const char *getEndInfo = R"(
|
||||
|
@ -154,35 +164,46 @@ static void emitClassofDef(llvm::StringRef structName,
|
|||
bool {0}::classof(mlir::Attribute attr))";
|
||||
|
||||
const char *classofInfoHeader = R"(
|
||||
auto derived = attr.dyn_cast<mlir::DictionaryAttr>();
|
||||
if (!derived)
|
||||
return false;
|
||||
if (derived.size() != {0})
|
||||
return false;
|
||||
if (!attr)
|
||||
return false;
|
||||
auto derived = attr.dyn_cast<mlir::DictionaryAttr>();
|
||||
if (!derived)
|
||||
return false;
|
||||
int empty_optionals = 0;
|
||||
)";
|
||||
|
||||
os << llvm::formatv(classofInfo, structName) << " {";
|
||||
os << llvm::formatv(classofInfoHeader, fields.size());
|
||||
os << llvm::formatv(classofInfoHeader);
|
||||
|
||||
FmtContext fctx;
|
||||
const char *classofArgInfo = R"(
|
||||
auto {0} = derived.get("{0}");
|
||||
if (!{0} || !({1}))
|
||||
return false;
|
||||
)";
|
||||
const char *classofArgInfoOptional = R"(
|
||||
auto {0} = derived.get("{0}");
|
||||
if (!{0})
|
||||
++empty_optionals;
|
||||
else if (!({1}))
|
||||
return false;
|
||||
)";
|
||||
for (auto field : fields) {
|
||||
auto name = field.getName();
|
||||
auto type = field.getType();
|
||||
std::string condition =
|
||||
std::string(tgfmt(type.getConditionTemplate(), &fctx.withSelf(name)));
|
||||
os << llvm::formatv(classofArgInfo, name, condition);
|
||||
if (type.isOptional())
|
||||
os << llvm::formatv(classofArgInfoOptional, name, condition);
|
||||
else
|
||||
os << llvm::formatv(classofArgInfo, name, condition);
|
||||
}
|
||||
|
||||
const char *classofEndInfo = R"(
|
||||
return true;
|
||||
return derived.size() + empty_optionals == {0};
|
||||
}
|
||||
)";
|
||||
os << classofEndInfo;
|
||||
os << llvm::formatv(classofEndInfo, fields.size());
|
||||
}
|
||||
|
||||
static void
|
||||
|
@ -197,12 +218,25 @@ emitAccessorDef(llvm::StringRef structName,
|
|||
assert({1}.isa<{0}>() && "incorrect Attribute type found.");
|
||||
return {1}.cast<{0}>();
|
||||
}
|
||||
)";
|
||||
const char *fieldInfoOptional = R"(
|
||||
{0} {2}::{1}() const {
|
||||
auto derived = this->cast<mlir::DictionaryAttr>();
|
||||
auto {1} = derived.get("{1}");
|
||||
if (!{1})
|
||||
return nullptr;
|
||||
assert({1}.isa<{0}>() && "incorrect Attribute type found.");
|
||||
return {1}.cast<{0}>();
|
||||
}
|
||||
)";
|
||||
for (auto field : fields) {
|
||||
auto name = field.getName();
|
||||
auto type = field.getType();
|
||||
auto storage = type.getStorageType();
|
||||
os << llvm::formatv(fieldInfo, storage, name, structName);
|
||||
if (type.isOptional())
|
||||
os << llvm::formatv(fieldInfoOptional, storage, name, structName);
|
||||
else
|
||||
os << llvm::formatv(fieldInfo, storage, name, structName);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -33,8 +33,10 @@ static test::TestStruct getTestStruct(mlir::MLIRContext *context) {
|
|||
auto elementsType = mlir::RankedTensorType::get({2, 3}, integerType);
|
||||
auto elementsAttr =
|
||||
mlir::DenseIntElementsAttr::get(elementsType, {1, 2, 3, 4, 5, 6});
|
||||
auto optionalAttr = nullptr;
|
||||
|
||||
return test::TestStruct::get(integerAttr, floatAttr, elementsAttr, context);
|
||||
return test::TestStruct::get(integerAttr, floatAttr, elementsAttr,
|
||||
optionalAttr, context);
|
||||
}
|
||||
|
||||
// Validates that test::TestStruct::classof correctly identifies a valid
|
||||
|
@ -159,4 +161,10 @@ TEST(StructsGenTest, GetElements) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST(StructsGenTest, EmptyOptional) {
|
||||
mlir::MLIRContext context;
|
||||
auto structAttr = getTestStruct(&context);
|
||||
EXPECT_EQ(structAttr.sample_optional_integer(), nullptr);
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
|
|
@ -15,6 +15,8 @@ def Test_Dialect : Dialect {
|
|||
def Test_Struct : StructAttr<"TestStruct", Test_Dialect, [
|
||||
StructFieldAttr<"sample_integer", I32Attr>,
|
||||
StructFieldAttr<"sample_float", F32Attr>,
|
||||
StructFieldAttr<"sample_elements", I32ElementsAttr>] > {
|
||||
StructFieldAttr<"sample_elements", I32ElementsAttr>,
|
||||
StructFieldAttr<"sample_optional_integer",
|
||||
OptionalAttr<I32Attr>>] > {
|
||||
let description = "Structure for test data";
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue