[ODS] Add support for FloatElementsAttr

This CL adds a new FloatElementsAttr definition to ODS for float
elements attributes of a certain type.

Tests are added to show both verification and how to use it in patterns.

PiperOrigin-RevId: 270455487
This commit is contained in:
Lei Zhang 2019-09-21 09:44:38 -07:00 committed by A. Unique TensorFlower
parent 33a3a91ba2
commit 8e4906362e
4 changed files with 99 additions and 0 deletions

View File

@ -924,6 +924,32 @@ class IntElementsAttr<int width> : ElementsAttrBase<
def I32ElementsAttr : IntElementsAttr<32>;
def I64ElementsAttr : IntElementsAttr<64>;
// A `width`-bit floating point elements attribute. The attribute should be
// ranked and has a shape as specified in `dims`.
class RankedFloatElementsAttr<int width, list<int> dims> : ElementsAttrBase<
CPred<"$_self.isa<DenseFPElementsAttr>() &&"
"$_self.cast<DenseFPElementsAttr>().getType()."
"getElementType().isF" # width # "() && "
// Check that this is ranked and has the specified shape.
"$_self.cast<DenseFPElementsAttr>().getType().hasRank() && "
"$_self.cast<DenseFPElementsAttr>().getType().getShape() == "
"llvm::ArrayRef<int64_t>({" # StrJoinInt<dims>.result # "})">,
width # "-bit float elements attribute of shape [" #
StrJoinInt<dims>.result # "]"> {
let storageType = [{ DenseFPElementsAttr }];
let returnType = [{ DenseFPElementsAttr }];
let constBuilderCall = "DenseElementsAttr::get("
"$_builder.getTensorType({" # StrJoinInt<dims>.result #
"}, $_builder.getF" # width # "Type()), "
"llvm::makeArrayRef($0)).cast<DenseFPElementsAttr>()";
let convertFromStorage = "$_self";
}
class RankedF32ElementsAttr<list<int> dims> : RankedFloatElementsAttr<32, dims>;
class RankedF64ElementsAttr<list<int> dims> : RankedFloatElementsAttr<64, dims>;
// Base class for array attributes.
class ArrayAttrBase<Pred condition, string description> :
Attr<condition, description> {

View File

@ -189,3 +189,41 @@ func @disallowed_case7_fail() {
%0 = "test.i64_enum_attr"() {attr = 5: i32} : () -> i32
return
}
// -----
//===----------------------------------------------------------------------===//
// Test FloatElementsAttr
//===----------------------------------------------------------------------===//
func @correct_type_pass() {
"test.float_elements_attr"() {
// CHECK: scalar_f32_attr = dense<5.000000e+00> : tensor<2xf32>
// CHECK: tensor_f64_attr = dense<6.000000e+00> : tensor<4x8xf64>
scalar_f32_attr = dense<5.0> : tensor<2xf32>,
tensor_f64_attr = dense<6.0> : tensor<4x8xf64>
} : () -> ()
return
}
// -----
func @wrong_element_type_pass() {
// expected-error @+1 {{failed to satisfy constraint: 32-bit float elements attribute of shape [2]}}
"test.float_elements_attr"() {
scalar_f32_attr = dense<5.0> : tensor<2xf64>,
tensor_f64_attr = dense<6.0> : tensor<4x8xf64>
} : () -> ()
return
}
// -----
func @correct_type_pass() {
// expected-error @+1 {{failed to satisfy constraint: 64-bit float elements attribute of shape [4, 8]}}
"test.float_elements_attr"() {
scalar_f32_attr = dense<5.0> : tensor<2xf32>,
tensor_f64_attr = dense<6.0> : tensor<4xf64>
} : () -> ()
return
}

View File

@ -162,6 +162,23 @@ def I64EnumAttrOp : TEST_Op<"i64_enum_attr"> {
let results = (outs I32:$val);
}
def FloatElementsAttrOp : TEST_Op<"float_elements_attr"> {
let arguments = (ins
RankedF32ElementsAttr<[2]>:$scalar_f32_attr,
RankedF64ElementsAttr<[4, 8]>:$tensor_f64_attr
);
}
// A pattern that updates dense<[3.0, 4.0]> to dense<[5.0, 6.0]>.
// This tests both matching and generating float elements attributes.
def UpdateFloatElementsAttr : Pat<
(FloatElementsAttrOp
ConstantAttr<RankedF32ElementsAttr<[2]>, "{3.0f, 4.0f}">:$f32attr,
$f64attr),
(FloatElementsAttrOp
ConstantAttr<RankedF32ElementsAttr<[2]>, "{5.0f, 6.0f}">:$f32attr,
$f64attr)>;
//===----------------------------------------------------------------------===//
// Test Regions
//===----------------------------------------------------------------------===//

View File

@ -163,6 +163,24 @@ func @rewrite_i32elementsattr() -> () {
return
}
// CHECK-LABEL: rewrite_f64elementsattr
func @rewrite_f64elementsattr() -> () {
"test.float_elements_attr"() {
// Should match
// CHECK: scalar_f32_attr = dense<[5.000000e+00, 6.000000e+00]> : tensor<2xf32>
scalar_f32_attr = dense<[3.0, 4.0]> : tensor<2xf32>,
tensor_f64_attr = dense<6.0> : tensor<4x8xf64>
} : () -> ()
"test.float_elements_attr"() {
// Should not match
// CHECK: scalar_f32_attr = dense<7.000000e+00> : tensor<2xf32>
scalar_f32_attr = dense<7.0> : tensor<2xf32>,
tensor_f64_attr = dense<3.0> : tensor<4x8xf64>
} : () -> ()
return
}
//===----------------------------------------------------------------------===//
// Test Multi-result Ops
//===----------------------------------------------------------------------===//