From 88f07a736bbc3f0062d7d8f4032f0b54aff5c018 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Wed, 5 Oct 2022 10:40:58 -0700 Subject: [PATCH] [mlir] Make UnitAttr's default val in unwrapped builder UnitAttr is optional but unwrapped builders require it. Make Change onstructing from bool as required for when not set at moment (for UnitAttr nothing needs to be constructed, this is true for others here too and can be addressed together). Differential Revision: https://reviews.llvm.org/D135058 --- .../SparseTensor/IR/SparseTensorOps.td | 3 --- mlir/include/mlir/IR/OpBase.td | 5 ++-- .../SparseTensor/IR/SparseTensorDialect.cpp | 5 ---- mlir/test/mlir-tblgen/op-attribute.td | 5 ++++ mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 26 +++++++++++-------- 5 files changed, 23 insertions(+), 21 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index 4d1f23719ee0..549f6c83441a 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -460,9 +460,6 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>, let assemblyFormat = "(`stable` $stable^)? $n" "`,`$xs (`jointly` $ys^)? attr-dict" "`:` type($xs) (`jointly` type($ys)^)?"; - let builders = [ - OpBuilder<(ins "Value":$n, "ValueRange":$xs, "ValueRange":$ys)> - ]; let hasVerifier = 1; } diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index f35cd3afe7b4..27e55ad07960 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1246,9 +1246,10 @@ class TypeAttrOf // "true" if the attribute is present and "false" otherwise. def UnitAttr : Attr()">, "unit attribute"> { let storageType = [{ ::mlir::UnitAttr }]; - let constBuilderCall = "$_builder.getUnitAttr()"; + let constBuilderCall = "(($0) ? $_builder.getUnitAttr() : nullptr)"; let convertFromStorage = "$_self != nullptr"; let returnType = "bool"; + let defaultValue = "false"; let valueType = NoneType; let isOptional = 1; } @@ -1575,7 +1576,7 @@ class ConstantAttr : AttrConstraint< class ConstF32Attr : ConstantAttr; def ConstBoolAttrFalse : ConstantAttr; def ConstBoolAttrTrue : ConstantAttr; -def ConstUnitAttr : ConstantAttr; +def ConstUnitAttr : ConstantAttr; // Constant string-based attribute. Wraps the desired string in escaped quotes. class ConstantStrAttr diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index fd1eb93d2e88..3f5b9c663774 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -706,11 +706,6 @@ LogicalResult SelectOp::verify() { return success(); } -void SortOp::build(OpBuilder &odsBuilder, OperationState &odsState, Value n, - ValueRange xs, ValueRange ys) { - build(odsBuilder, odsState, n, xs, ys, /*stable=*/false); -} - LogicalResult SortOp::verify() { if (getXs().empty()) return emitError("need at least one xs buffer."); diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td index b25adfb4c426..e6cc49dfb049 100644 --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -488,6 +488,11 @@ def UnitAttrOp : NS_Op<"unit_attr_op", []> { // DEF-NEXT: (*this)->removeAttr(getAttrAttrName()); // DEF: build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, /*optional*/::mlir::UnitAttr attr) +// DEF: build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, /*optional*/bool attr) + +// DECL-LABEL: UnitAttrOp declarations +// DECL-NOT: declarations +// DECL: build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, /*optional*/bool attr = false) // Test elementAttr field of TypedArrayAttr. diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 084aac2ee294..546c71e99052 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1635,9 +1635,9 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() { } void OpEmitter::genPopulateDefaultAttributes() { - // All done if no attributes have default values. + // All done if no attributes, except optional ones, have default values. if (llvm::all_of(op.getAttributes(), [](const NamedAttribute &named) { - return !named.attr.hasDefaultValue(); + return !named.attr.hasDefaultValue() || named.attr.isOptional(); })) return; @@ -1667,8 +1667,8 @@ void OpEmitter::genPopulateDefaultAttributes() { fctx.withBuilder(odsBuilder); std::string defaultValue = std::string( tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); - body.indent() << formatv(" attributes.append(attrNames[{0}], {1});\n", - index, defaultValue); + body.indent() << formatv("attributes.append(attrNames[{0}], {1});\n", index, + defaultValue); body.unindent() << "}\n"; } } @@ -2143,12 +2143,16 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder( if (attr.isDerivedAttr() || inferredAttributes.contains(namedAttr.name)) continue; - // TODO(jpienaar): The wrapping of optional is different for default or not, - // so don't unwrap for default ones that would fail below. - bool emitNotNullCheck = (attr.isOptional() && !attr.hasDefaultValue()) || - (attr.hasDefaultValue() && !isRawValueAttr); + // TODO: The wrapping of optional is different for default or not, so don't + // unwrap for default ones that would fail below. + bool emitNotNullCheck = + (attr.isOptional() && !attr.hasDefaultValue()) || + (attr.hasDefaultValue() && !isRawValueAttr) || + // TODO: UnitAttr is optional, not wrapped, but needs to be guarded as + // the constant materialization is only for true case. + (isRawValueAttr && attr.getAttrDefName() == "UnitAttr"); if (emitNotNullCheck) - body << formatv(" if ({0}) ", namedAttr.name) << "{\n"; + body.indent() << formatv("if ({0}) ", namedAttr.name) << "{\n"; if (isRawValueAttr && canUseUnwrappedRawValue(attr)) { // If this is a raw value, then we need to wrap it in an Attribute @@ -2175,7 +2179,7 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder( namedAttr.name); } if (emitNotNullCheck) - body << " }\n"; + body.unindent() << " }\n"; } // Create the correct number of regions. @@ -2966,7 +2970,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( // call. This should be set instead. std::string defaultValue = std::string( tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); - body << " if (!attr)\n attr = " << defaultValue << ";\n"; + body << "if (!attr)\n attr = " << defaultValue << ";\n"; } body << "return attr;\n"; };