diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 4026d061c62e..b52ae2830e0a 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -924,6 +924,32 @@ class IntElementsAttr : ElementsAttrBase< def I32ElementsAttr : IntElementsAttr<32>; def I64ElementsAttr : IntElementsAttr<64>; +// A `width`-bit floating point elements attribute. The attribute should be +// ranked and has a shape as specified in `dims`. +class RankedFloatElementsAttr dims> : ElementsAttrBase< + CPred<"$_self.isa() &&" + "$_self.cast().getType()." + "getElementType().isF" # width # "() && " + // Check that this is ranked and has the specified shape. + "$_self.cast().getType().hasRank() && " + "$_self.cast().getType().getShape() == " + "llvm::ArrayRef({" # StrJoinInt.result # "})">, + width # "-bit float elements attribute of shape [" # + StrJoinInt.result # "]"> { + + let storageType = [{ DenseFPElementsAttr }]; + let returnType = [{ DenseFPElementsAttr }]; + + let constBuilderCall = "DenseElementsAttr::get(" + "$_builder.getTensorType({" # StrJoinInt.result # + "}, $_builder.getF" # width # "Type()), " + "llvm::makeArrayRef($0)).cast()"; + let convertFromStorage = "$_self"; +} + +class RankedF32ElementsAttr dims> : RankedFloatElementsAttr<32, dims>; +class RankedF64ElementsAttr dims> : RankedFloatElementsAttr<64, dims>; + // Base class for array attributes. class ArrayAttrBase : Attr { diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir index a72a34155ffb..6482f434f784 100644 --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -189,3 +189,41 @@ func @disallowed_case7_fail() { %0 = "test.i64_enum_attr"() {attr = 5: i32} : () -> i32 return } + +// ----- + +//===----------------------------------------------------------------------===// +// Test FloatElementsAttr +//===----------------------------------------------------------------------===// + +func @correct_type_pass() { + "test.float_elements_attr"() { + // CHECK: scalar_f32_attr = dense<5.000000e+00> : tensor<2xf32> + // CHECK: tensor_f64_attr = dense<6.000000e+00> : tensor<4x8xf64> + scalar_f32_attr = dense<5.0> : tensor<2xf32>, + tensor_f64_attr = dense<6.0> : tensor<4x8xf64> + } : () -> () + return +} + +// ----- + +func @wrong_element_type_pass() { + // expected-error @+1 {{failed to satisfy constraint: 32-bit float elements attribute of shape [2]}} + "test.float_elements_attr"() { + scalar_f32_attr = dense<5.0> : tensor<2xf64>, + tensor_f64_attr = dense<6.0> : tensor<4x8xf64> + } : () -> () + return +} + +// ----- + +func @correct_type_pass() { + // expected-error @+1 {{failed to satisfy constraint: 64-bit float elements attribute of shape [4, 8]}} + "test.float_elements_attr"() { + scalar_f32_attr = dense<5.0> : tensor<2xf32>, + tensor_f64_attr = dense<6.0> : tensor<4xf64> + } : () -> () + return +} diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index 70a5f41b6f35..9b23962be40d 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -162,6 +162,23 @@ def I64EnumAttrOp : TEST_Op<"i64_enum_attr"> { let results = (outs I32:$val); } +def FloatElementsAttrOp : TEST_Op<"float_elements_attr"> { + let arguments = (ins + RankedF32ElementsAttr<[2]>:$scalar_f32_attr, + RankedF64ElementsAttr<[4, 8]>:$tensor_f64_attr + ); +} + +// A pattern that updates dense<[3.0, 4.0]> to dense<[5.0, 6.0]>. +// This tests both matching and generating float elements attributes. +def UpdateFloatElementsAttr : Pat< + (FloatElementsAttrOp + ConstantAttr, "{3.0f, 4.0f}">:$f32attr, + $f64attr), + (FloatElementsAttrOp + ConstantAttr, "{5.0f, 6.0f}">:$f32attr, + $f64attr)>; + //===----------------------------------------------------------------------===// // Test Regions //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir index ee5acf92a461..6a5253ab371d 100644 --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -163,6 +163,24 @@ func @rewrite_i32elementsattr() -> () { return } +// CHECK-LABEL: rewrite_f64elementsattr +func @rewrite_f64elementsattr() -> () { + "test.float_elements_attr"() { + // Should match + // CHECK: scalar_f32_attr = dense<[5.000000e+00, 6.000000e+00]> : tensor<2xf32> + scalar_f32_attr = dense<[3.0, 4.0]> : tensor<2xf32>, + tensor_f64_attr = dense<6.0> : tensor<4x8xf64> + } : () -> () + + "test.float_elements_attr"() { + // Should not match + // CHECK: scalar_f32_attr = dense<7.000000e+00> : tensor<2xf32> + scalar_f32_attr = dense<7.0> : tensor<2xf32>, + tensor_f64_attr = dense<3.0> : tensor<4x8xf64> + } : () -> () + return +} + //===----------------------------------------------------------------------===// // Test Multi-result Ops //===----------------------------------------------------------------------===//