forked from OSchip/llvm-project
[ODS] Add support for FloatElementsAttr
This CL adds a new FloatElementsAttr definition to ODS for float elements attributes of a certain type. Tests are added to show both verification and how to use it in patterns. PiperOrigin-RevId: 270455487
This commit is contained in:
parent
33a3a91ba2
commit
8e4906362e
|
@ -924,6 +924,32 @@ class IntElementsAttr<int width> : ElementsAttrBase<
|
|||
def I32ElementsAttr : IntElementsAttr<32>;
|
||||
def I64ElementsAttr : IntElementsAttr<64>;
|
||||
|
||||
// A `width`-bit floating point elements attribute. The attribute should be
|
||||
// ranked and has a shape as specified in `dims`.
|
||||
class RankedFloatElementsAttr<int width, list<int> dims> : ElementsAttrBase<
|
||||
CPred<"$_self.isa<DenseFPElementsAttr>() &&"
|
||||
"$_self.cast<DenseFPElementsAttr>().getType()."
|
||||
"getElementType().isF" # width # "() && "
|
||||
// Check that this is ranked and has the specified shape.
|
||||
"$_self.cast<DenseFPElementsAttr>().getType().hasRank() && "
|
||||
"$_self.cast<DenseFPElementsAttr>().getType().getShape() == "
|
||||
"llvm::ArrayRef<int64_t>({" # StrJoinInt<dims>.result # "})">,
|
||||
width # "-bit float elements attribute of shape [" #
|
||||
StrJoinInt<dims>.result # "]"> {
|
||||
|
||||
let storageType = [{ DenseFPElementsAttr }];
|
||||
let returnType = [{ DenseFPElementsAttr }];
|
||||
|
||||
let constBuilderCall = "DenseElementsAttr::get("
|
||||
"$_builder.getTensorType({" # StrJoinInt<dims>.result #
|
||||
"}, $_builder.getF" # width # "Type()), "
|
||||
"llvm::makeArrayRef($0)).cast<DenseFPElementsAttr>()";
|
||||
let convertFromStorage = "$_self";
|
||||
}
|
||||
|
||||
class RankedF32ElementsAttr<list<int> dims> : RankedFloatElementsAttr<32, dims>;
|
||||
class RankedF64ElementsAttr<list<int> dims> : RankedFloatElementsAttr<64, dims>;
|
||||
|
||||
// Base class for array attributes.
|
||||
class ArrayAttrBase<Pred condition, string description> :
|
||||
Attr<condition, description> {
|
||||
|
|
|
@ -189,3 +189,41 @@ func @disallowed_case7_fail() {
|
|||
%0 = "test.i64_enum_attr"() {attr = 5: i32} : () -> i32
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test FloatElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @correct_type_pass() {
|
||||
"test.float_elements_attr"() {
|
||||
// CHECK: scalar_f32_attr = dense<5.000000e+00> : tensor<2xf32>
|
||||
// CHECK: tensor_f64_attr = dense<6.000000e+00> : tensor<4x8xf64>
|
||||
scalar_f32_attr = dense<5.0> : tensor<2xf32>,
|
||||
tensor_f64_attr = dense<6.0> : tensor<4x8xf64>
|
||||
} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @wrong_element_type_pass() {
|
||||
// expected-error @+1 {{failed to satisfy constraint: 32-bit float elements attribute of shape [2]}}
|
||||
"test.float_elements_attr"() {
|
||||
scalar_f32_attr = dense<5.0> : tensor<2xf64>,
|
||||
tensor_f64_attr = dense<6.0> : tensor<4x8xf64>
|
||||
} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @correct_type_pass() {
|
||||
// expected-error @+1 {{failed to satisfy constraint: 64-bit float elements attribute of shape [4, 8]}}
|
||||
"test.float_elements_attr"() {
|
||||
scalar_f32_attr = dense<5.0> : tensor<2xf32>,
|
||||
tensor_f64_attr = dense<6.0> : tensor<4xf64>
|
||||
} : () -> ()
|
||||
return
|
||||
}
|
||||
|
|
|
@ -162,6 +162,23 @@ def I64EnumAttrOp : TEST_Op<"i64_enum_attr"> {
|
|||
let results = (outs I32:$val);
|
||||
}
|
||||
|
||||
def FloatElementsAttrOp : TEST_Op<"float_elements_attr"> {
|
||||
let arguments = (ins
|
||||
RankedF32ElementsAttr<[2]>:$scalar_f32_attr,
|
||||
RankedF64ElementsAttr<[4, 8]>:$tensor_f64_attr
|
||||
);
|
||||
}
|
||||
|
||||
// A pattern that updates dense<[3.0, 4.0]> to dense<[5.0, 6.0]>.
|
||||
// This tests both matching and generating float elements attributes.
|
||||
def UpdateFloatElementsAttr : Pat<
|
||||
(FloatElementsAttrOp
|
||||
ConstantAttr<RankedF32ElementsAttr<[2]>, "{3.0f, 4.0f}">:$f32attr,
|
||||
$f64attr),
|
||||
(FloatElementsAttrOp
|
||||
ConstantAttr<RankedF32ElementsAttr<[2]>, "{5.0f, 6.0f}">:$f32attr,
|
||||
$f64attr)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Regions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -163,6 +163,24 @@ func @rewrite_i32elementsattr() -> () {
|
|||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: rewrite_f64elementsattr
|
||||
func @rewrite_f64elementsattr() -> () {
|
||||
"test.float_elements_attr"() {
|
||||
// Should match
|
||||
// CHECK: scalar_f32_attr = dense<[5.000000e+00, 6.000000e+00]> : tensor<2xf32>
|
||||
scalar_f32_attr = dense<[3.0, 4.0]> : tensor<2xf32>,
|
||||
tensor_f64_attr = dense<6.0> : tensor<4x8xf64>
|
||||
} : () -> ()
|
||||
|
||||
"test.float_elements_attr"() {
|
||||
// Should not match
|
||||
// CHECK: scalar_f32_attr = dense<7.000000e+00> : tensor<2xf32>
|
||||
scalar_f32_attr = dense<7.0> : tensor<2xf32>,
|
||||
tensor_f64_attr = dense<3.0> : tensor<4x8xf64>
|
||||
} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Multi-result Ops
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
Loading…
Reference in New Issue