diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 4cbf4ee37b6f..a3f8ad538750 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -132,6 +132,12 @@ public: IntegerAttr getI32IntegerAttr(int32_t value); IntegerAttr getI64IntegerAttr(int64_t value); + ArrayAttr getI32ArrayAttr(ArrayRef values); + ArrayAttr getI64ArrayAttr(ArrayRef values); + ArrayAttr getF32ArrayAttr(ArrayRef values); + ArrayAttr getF64ArrayAttr(ArrayRef values); + ArrayAttr getStrArrayAttr(ArrayRef values); + // Affine expressions and affine maps. AffineExpr getAffineDimExpr(unsigned position); AffineExpr getAffineSymbolExpr(unsigned position); diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 791e99cf69c4..602d50cc4089 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -550,11 +550,22 @@ class TypedArrayAttrBase: ArrayAttrBase< } def I32ArrayAttr : TypedArrayAttrBase; + "32-bit integer array attribute"> { + let constBuilderCall = "{0}.getI32ArrayAttr({1})"; +} def I64ArrayAttr : TypedArrayAttrBase; -def F32ArrayAttr : TypedArrayAttrBase; -def StrArrayAttr : TypedArrayAttrBase; + "64-bit integer array attribute"> { + let constBuilderCall = "{0}.getI64ArrayAttr({1})"; +} +def F32ArrayAttr : TypedArrayAttrBase { + let constBuilderCall = "{0}.getF32ArrayAttr({1})"; +} +def F64ArrayAttr : TypedArrayAttrBase { + let constBuilderCall = "{0}.getF64ArrayAttr({1})"; +} +def StrArrayAttr : TypedArrayAttrBase { + let constBuilderCall = "{0}.getStrArrayAttr({1})"; +} // Attributes containing functions. def FunctionAttr : Attr()">, diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index a0d9367fa5f1..962fa345c662 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -23,6 +23,7 @@ #include "mlir/IR/Location.h" #include "mlir/IR/Module.h" #include "mlir/IR/StandardTypes.h" +#include "mlir/Support/Functional.h" using namespace mlir; Builder::Builder(Module *module) : context(module->getContext()) {} @@ -202,6 +203,36 @@ ElementsAttr Builder::getOpaqueElementsAttr(Dialect *dialect, return OpaqueElementsAttr::get(dialect, type, bytes); } +ArrayAttr Builder::getI32ArrayAttr(ArrayRef values) { + auto attrs = functional::map( + [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }, values); + return getArrayAttr(attrs); +} + +ArrayAttr Builder::getI64ArrayAttr(ArrayRef values) { + auto attrs = functional::map( + [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); }, values); + return getArrayAttr(attrs); +} + +ArrayAttr Builder::getF32ArrayAttr(ArrayRef values) { + auto attrs = functional::map( + [this](float v) -> Attribute { return getF32FloatAttr(v); }, values); + return getArrayAttr(attrs); +} + +ArrayAttr Builder::getF64ArrayAttr(ArrayRef values) { + auto attrs = functional::map( + [this](double v) -> Attribute { return getF64FloatAttr(v); }, values); + return getArrayAttr(attrs); +} + +ArrayAttr Builder::getStrArrayAttr(ArrayRef values) { + auto attrs = functional::map( + [this](StringRef v) -> Attribute { return getStringAttr(v); }, values); + return getArrayAttr(attrs); +} + Attribute Builder::getZeroAttr(Type type) { switch (type.getKind()) { case StandardTypes::F32: diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp index 26c22041f526..3791d2cb3cb7 100644 --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -103,9 +103,15 @@ bool tblgen::Attribute::isOptional() const { std::string tblgen::Attribute::getDefaultValueTemplate() const { assert(isConstBuildable() && "requiers constBuilderCall"); - const auto *init = def->getValueInit("defaultValue"); + StringRef defaultValue = getValueAsString(def->getValueInit("defaultValue")); + // TODO(antiagainst): This is a temporary hack to support array initializers + // because '{' is the special marker for placeholders for formatv. Remove this + // after switching to our own formatting utility and $-placeholders. + bool needsEscape = + defaultValue.startswith("{") && !defaultValue.startswith("{{"); + return llvm::formatv(getConstBuilderTemplate().str().c_str(), "{0}", - getValueAsString(init)); + needsEscape ? "{" + defaultValue : defaultValue); } StringRef tblgen::Attribute::getTableGenDefName() const { diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td index 17c88dcecff2..4b8a8818e6c3 100644 --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -9,6 +9,9 @@ def SomeAttr : Attr, "some attribute kind"> { let constBuilderCall = "some-const-builder-call({0}, {1})"; } +// Test required, optional, default-valued attributes +// --- + def AOp : Op<"a_op", []> { let arguments = (ins SomeAttr:$aAttr, @@ -100,6 +103,28 @@ def BOp : Op<"b_op", []> { // CHECK: if (!((tblgen_array_attr.isa()))) // CHECK: if (!(((tblgen_some_attr_array.isa())) && (llvm::all_of(tblgen_some_attr_array.cast(), [](Attribute attr) { return (some-condition); })))) +// Test building constant values for array attribute kinds +// --- + +def COp : Op<"c_op", []> { + let arguments = (ins + DefaultValuedAttr:$i32_array_attr, + DefaultValuedAttr:$i64_array_attr, + DefaultValuedAttr:$f32_array_attr, + DefaultValuedAttr:$f64_array_attr, + DefaultValuedAttr:$str_array_attr + ); +} + +// CHECK-LABEL: COp definitions +// CHECK: mlir::Builder(this->getContext()).getI32ArrayAttr({1, 2}) +// CHECK: mlir::Builder(this->getContext()).getI64ArrayAttr({3, 4}) +// CHECK: mlir::Builder(this->getContext()).getF32ArrayAttr({5.f, 6.f}) +// CHECK: mlir::Builder(this->getContext()).getF64ArrayAttr({7., 8.}) +// CHECK: mlir::Builder(this->getContext()).getStrArrayAttr({"a", "b"}) + +// Test mixing operands and attributes in arbitrary order +// --- def MixOperandsAndAttrs : Op<"mix_operands_and_attrs", []> { let arguments = (ins F32Attr:$attr, F32:$operand, F32Attr:$otherAttr, F32:$otherArg);