forked from OSchip/llvm-project
[mlir] Add verify method to adaptor
This allows verifying op-indepent attributes (e.g., attributes that do not require the op to have been created) before constructing an operation. These include checking whether required attributes are defined or constraints on attributes (such as I32 attribute). This is not perfect (e.g., if one had a disjunctive constraint where one part relied on the op and the other doesn't, then this would not try and extract the op independent from the op dependent). The next step is to move these out to a trait that could be verified earlier than in the generated method. The first use case is for inferring the return type while constructing the op. At that point you don't have an Operation yet and that ends up in one having to duplicate the same checks, e.g., verify that attribute A is defined before querying A in shape function which requires that duplication. Instead this allows one to invoke a method to verify all the traits and, if this is checked first during verification, then all other traits could use attributes knowing they have been verified. It is a little bit funny to have these on the adaptor, but I see the adaptor as a place to collect information about the op before the op is constructed (e.g., avoiding stringly typed accessors, verifying what is possible to verify before the op is constructed) while being cheap to use even with constructed op (so layer of indirection between the op constructed/being constructed). And from that point of view it made sense to me. Differential Revision: https://reviews.llvm.org/D80842
This commit is contained in:
parent
f57dd41562
commit
b0921f68e1
|
@ -626,7 +626,8 @@ let verifier = [{
|
||||||
```
|
```
|
||||||
|
|
||||||
Code placed in `verifier` will be called after the auto-generated verification
|
Code placed in `verifier` will be called after the auto-generated verification
|
||||||
code.
|
code. The order of trait verification excluding those of `verifier` should not
|
||||||
|
be relied upon.
|
||||||
|
|
||||||
### Declarative Assembly Format
|
### Declarative Assembly Format
|
||||||
|
|
||||||
|
|
|
@ -254,7 +254,7 @@ func @reduce_op_and_body(%arg0 : f32) {
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @reduce_invalid_op(%arg0 : f32) {
|
func @reduce_invalid_op(%arg0 : f32) {
|
||||||
// expected-error@+1 {{gpu.all_reduce' op attribute 'op' failed to satisfy constraint}}
|
// expected-error@+1 {{attribute 'op' failed to satisfy constraint}}
|
||||||
%res = "gpu.all_reduce"(%arg0) ({}) {op = "foo"} : (f32) -> (f32)
|
%res = "gpu.all_reduce"(%arg0) ({}) {op = "foo"} : (f32) -> (f32)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -321,14 +321,14 @@ func @reduce_incorrect_yield(%arg0 : f32) {
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @shuffle_mismatching_type(%arg0 : f32, %arg1 : i32, %arg2 : i32) {
|
func @shuffle_mismatching_type(%arg0 : f32, %arg1 : i32, %arg2 : i32) {
|
||||||
// expected-error@+1 {{'gpu.shuffle' op requires the same type for value operand and result}}
|
// expected-error@+1 {{requires the same type for value operand and result}}
|
||||||
%shfl, %pred = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "xor" } : (f32, i32, i32) -> (i32, i1)
|
%shfl, %pred = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "xor" } : (f32, i32, i32) -> (i32, i1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @shuffle_unsupported_type(%arg0 : index, %arg1 : i32, %arg2 : i32) {
|
func @shuffle_unsupported_type(%arg0 : index, %arg1 : i32, %arg2 : i32) {
|
||||||
// expected-error@+1 {{'gpu.shuffle' op requires value operand type to be f32 or i32}}
|
// expected-error@+1 {{requires value operand type to be f32 or i32}}
|
||||||
%shfl, %pred = gpu.shuffle %arg0, %arg1, %arg2 xor : index
|
%shfl, %pred = gpu.shuffle %arg0, %arg1, %arg2 xor : index
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -65,12 +65,12 @@ func @references() {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// expected-error @+1 {{op requires string attribute 'sym_name'}}
|
// expected-error @+1 {{requires string attribute 'sym_name'}}
|
||||||
"llvm.mlir.global"() ({}) {type = !llvm.i64, constant, value = 42 : i64} : () -> ()
|
"llvm.mlir.global"() ({}) {type = !llvm.i64, constant, value = 42 : i64} : () -> ()
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// expected-error @+1 {{op requires attribute 'type'}}
|
// expected-error @+1 {{requires attribute 'type'}}
|
||||||
"llvm.mlir.global"() ({}) {sym_name = "foo", constant, value = 42 : i64} : () -> ()
|
"llvm.mlir.global"() ({}) {sym_name = "foo", constant, value = 42 : i64} : () -> ()
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
|
@ -124,7 +124,7 @@ func @composite_extract_invalid_index_type_1() -> () {
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @composite_extract_invalid_index_type_2(%arg0 : !spv.array<4x!spv.array<4xf32>>) -> () {
|
func @composite_extract_invalid_index_type_2(%arg0 : !spv.array<4x!spv.array<4xf32>>) -> () {
|
||||||
// expected-error @+1 {{op attribute 'indices' failed to satisfy constraint: 32-bit integer array attribute}}
|
// expected-error @+1 {{attribute 'indices' failed to satisfy constraint: 32-bit integer array attribute}}
|
||||||
%0 = spv.CompositeExtract %arg0[1] : !spv.array<4x!spv.array<4xf32>>
|
%0 = spv.CompositeExtract %arg0[1] : !spv.array<4x!spv.array<4xf32>>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -1069,7 +1069,7 @@ func @reduce_elt_type_mismatch(%arg0: vector<16xf32>) -> i32 {
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @reduce_unsupported_attr(%arg0: vector<16xf32>) -> i32 {
|
func @reduce_unsupported_attr(%arg0: vector<16xf32>) -> i32 {
|
||||||
// expected-error@+1 {{'vector.reduction' op attribute 'kind' failed to satisfy constraint: string attribute}}
|
// expected-error@+1 {{attribute 'kind' failed to satisfy constraint: string attribute}}
|
||||||
%0 = vector.reduction 1234, %arg0 : vector<16xf32> into i32
|
%0 = vector.reduction 1234, %arg0 : vector<16xf32> into i32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -58,7 +58,7 @@ func @constant_wrong_type() {
|
||||||
func @affine_apply_no_map() {
|
func @affine_apply_no_map() {
|
||||||
^bb0:
|
^bb0:
|
||||||
%i = constant 0 : index
|
%i = constant 0 : index
|
||||||
%x = "affine.apply" (%i) { } : (index) -> (index) // expected-error {{'affine.apply' op requires attribute 'map'}}
|
%x = "affine.apply" (%i) { } : (index) -> (index) // expected-error {{requires attribute 'map'}}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1205,7 +1205,7 @@ func @assume_alignment(%0: memref<4x4xf16>) {
|
||||||
|
|
||||||
// 0 alignment value.
|
// 0 alignment value.
|
||||||
func @assume_alignment(%0: memref<4x4xf16>) {
|
func @assume_alignment(%0: memref<4x4xf16>) {
|
||||||
// expected-error@+1 {{'std.assume_alignment' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
|
// expected-error@+1 {{attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
|
||||||
std.assume_alignment %0, 0 : memref<4x4xf16>
|
std.assume_alignment %0, 0 : memref<4x4xf16>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,6 +30,20 @@ def AOp : NS_Op<"a_op", []> {
|
||||||
|
|
||||||
// DEF-LABEL: AOp definitions
|
// DEF-LABEL: AOp definitions
|
||||||
|
|
||||||
|
// Test verify method
|
||||||
|
// ---
|
||||||
|
|
||||||
|
// DEF: LogicalResult AOpOperandAdaptor::verify
|
||||||
|
// DEF: auto tblgen_aAttr = odsAttrs.get("aAttr");
|
||||||
|
// DEF-NEXT: if (!tblgen_aAttr) return emitError(loc, "'test.a_op' op ""requires attribute 'aAttr'");
|
||||||
|
// DEF: if (!((some-condition))) return emitError(loc, "'test.a_op' op ""attribute 'aAttr' failed to satisfy constraint: some attribute kind");
|
||||||
|
// DEF: auto tblgen_bAttr = odsAttrs.get("bAttr");
|
||||||
|
// DEF-NEXT: if (tblgen_bAttr) {
|
||||||
|
// DEF-NEXT: if (!((some-condition))) return emitError(loc, "'test.a_op' op ""attribute 'bAttr' failed to satisfy constraint: some attribute kind");
|
||||||
|
// DEF: auto tblgen_cAttr = odsAttrs.get("cAttr");
|
||||||
|
// DEF-NEXT: if (tblgen_cAttr) {
|
||||||
|
// DEF-NEXT: if (!((some-condition))) return emitError(loc, "'test.a_op' op ""attribute 'cAttr' failed to satisfy constraint: some attribute kind");
|
||||||
|
|
||||||
// Test getter methods
|
// Test getter methods
|
||||||
// ---
|
// ---
|
||||||
|
|
||||||
|
@ -80,20 +94,6 @@ def AOp : NS_Op<"a_op", []> {
|
||||||
// DEF: ArrayRef<NamedAttribute> attributes
|
// DEF: ArrayRef<NamedAttribute> attributes
|
||||||
// DEF: odsState.addAttributes(attributes);
|
// DEF: odsState.addAttributes(attributes);
|
||||||
|
|
||||||
// Test verify method
|
|
||||||
// ---
|
|
||||||
|
|
||||||
// DEF: AOp::verify()
|
|
||||||
// DEF: auto tblgen_aAttr = this->getAttr("aAttr");
|
|
||||||
// DEF-NEXT: if (!tblgen_aAttr) return emitOpError("requires attribute 'aAttr'");
|
|
||||||
// DEF: if (!((some-condition))) return emitOpError("attribute 'aAttr' failed to satisfy constraint: some attribute kind");
|
|
||||||
// DEF: auto tblgen_bAttr = this->getAttr("bAttr");
|
|
||||||
// DEF-NEXT: if (tblgen_bAttr) {
|
|
||||||
// DEF-NEXT: if (!((some-condition))) return emitOpError("attribute 'bAttr' failed to satisfy constraint: some attribute kind");
|
|
||||||
// DEF: auto tblgen_cAttr = this->getAttr("cAttr");
|
|
||||||
// DEF-NEXT: if (tblgen_cAttr) {
|
|
||||||
// DEF-NEXT: if (!((some-condition))) return emitOpError("attribute 'cAttr' failed to satisfy constraint: some attribute kind");
|
|
||||||
|
|
||||||
def SomeTypeAttr : TypeAttrBase<"SomeType", "some type attribute">;
|
def SomeTypeAttr : TypeAttrBase<"SomeType", "some type attribute">;
|
||||||
|
|
||||||
def BOp : NS_Op<"b_op", []> {
|
def BOp : NS_Op<"b_op", []> {
|
||||||
|
@ -114,6 +114,25 @@ def BOp : NS_Op<"b_op", []> {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Test common attribute kinds' constraints
|
||||||
|
// ---
|
||||||
|
|
||||||
|
// DEF-LABEL: BOpOperandAdaptor::verify
|
||||||
|
// DEF: if (!((true)))
|
||||||
|
// DEF: if (!((tblgen_bool_attr.isa<BoolAttr>())))
|
||||||
|
// DEF: if (!(((tblgen_i32_attr.isa<IntegerAttr>())) && ((tblgen_i32_attr.cast<IntegerAttr>().getType().isSignlessInteger(32)))))
|
||||||
|
// DEF: if (!(((tblgen_i64_attr.isa<IntegerAttr>())) && ((tblgen_i64_attr.cast<IntegerAttr>().getType().isSignlessInteger(64)))))
|
||||||
|
// DEF: if (!(((tblgen_f32_attr.isa<FloatAttr>())) && ((tblgen_f32_attr.cast<FloatAttr>().getType().isF32()))))
|
||||||
|
// DEF: if (!(((tblgen_f64_attr.isa<FloatAttr>())) && ((tblgen_f64_attr.cast<FloatAttr>().getType().isF64()))))
|
||||||
|
// DEF: if (!((tblgen_str_attr.isa<StringAttr>())))
|
||||||
|
// DEF: if (!((tblgen_elements_attr.isa<ElementsAttr>())))
|
||||||
|
// DEF: if (!((tblgen_function_attr.isa<FlatSymbolRefAttr>())))
|
||||||
|
// DEF: if (!(((tblgen_type_attr.isa<TypeAttr>())) && ((tblgen_type_attr.cast<TypeAttr>().getValue().isa<SomeType>()))))
|
||||||
|
// DEF: if (!((tblgen_array_attr.isa<ArrayAttr>())))
|
||||||
|
// DEF: if (!(((tblgen_some_attr_array.isa<ArrayAttr>())) && (llvm::all_of(tblgen_some_attr_array.cast<ArrayAttr>(), [](Attribute attr) { return (some-condition); }))))
|
||||||
|
// DEF: if (!(((tblgen_type_attr.isa<TypeAttr>())) && ((tblgen_type_attr.cast<TypeAttr>().getValue().isa<Type>()))))
|
||||||
|
|
||||||
// Test common attribute kind getters' return types
|
// Test common attribute kind getters' return types
|
||||||
// ---
|
// ---
|
||||||
|
|
||||||
|
@ -131,24 +150,6 @@ def BOp : NS_Op<"b_op", []> {
|
||||||
// DEF: ArrayAttr BOp::some_attr_array()
|
// DEF: ArrayAttr BOp::some_attr_array()
|
||||||
// DEF: Type BOp::type_attr()
|
// DEF: Type BOp::type_attr()
|
||||||
|
|
||||||
// Test common attribute kinds' constraints
|
|
||||||
// ---
|
|
||||||
|
|
||||||
// DEF-LABEL: BOp::verify
|
|
||||||
// DEF: if (!((true)))
|
|
||||||
// DEF: if (!((tblgen_bool_attr.isa<BoolAttr>())))
|
|
||||||
// DEF: if (!(((tblgen_i32_attr.isa<IntegerAttr>())) && ((tblgen_i32_attr.cast<IntegerAttr>().getType().isSignlessInteger(32)))))
|
|
||||||
// DEF: if (!(((tblgen_i64_attr.isa<IntegerAttr>())) && ((tblgen_i64_attr.cast<IntegerAttr>().getType().isSignlessInteger(64)))))
|
|
||||||
// DEF: if (!(((tblgen_f32_attr.isa<FloatAttr>())) && ((tblgen_f32_attr.cast<FloatAttr>().getType().isF32()))))
|
|
||||||
// DEF: if (!(((tblgen_f64_attr.isa<FloatAttr>())) && ((tblgen_f64_attr.cast<FloatAttr>().getType().isF64()))))
|
|
||||||
// DEF: if (!((tblgen_str_attr.isa<StringAttr>())))
|
|
||||||
// DEF: if (!((tblgen_elements_attr.isa<ElementsAttr>())))
|
|
||||||
// DEF: if (!((tblgen_function_attr.isa<FlatSymbolRefAttr>())))
|
|
||||||
// DEF: if (!(((tblgen_type_attr.isa<TypeAttr>())) && ((tblgen_type_attr.cast<TypeAttr>().getValue().isa<SomeType>()))))
|
|
||||||
// DEF: if (!((tblgen_array_attr.isa<ArrayAttr>())))
|
|
||||||
// DEF: if (!(((tblgen_some_attr_array.isa<ArrayAttr>())) && (llvm::all_of(tblgen_some_attr_array.cast<ArrayAttr>(), [](Attribute attr) { return (some-condition); }))))
|
|
||||||
// DEF: if (!(((tblgen_type_attr.isa<TypeAttr>())) && ((tblgen_type_attr.cast<TypeAttr>().getValue().isa<Type>()))))
|
|
||||||
|
|
||||||
// Test building constant values for array attribute kinds
|
// Test building constant values for array attribute kinds
|
||||||
// ---
|
// ---
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s
|
// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s --dump-input-on-failure
|
||||||
|
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
|
|
||||||
|
@ -32,41 +32,41 @@ def OpF : NS_Op<"op_for_int_min_val", []> {
|
||||||
let arguments = (ins Confined<I32Attr, [IntMinValue<10>]>:$attr);
|
let arguments = (ins Confined<I32Attr, [IntMinValue<10>]>:$attr);
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: OpF::verify()
|
// CHECK-LABEL: OpFOperandAdaptor::verify
|
||||||
// CHECK: (tblgen_attr.cast<IntegerAttr>().getInt() >= 10)
|
// CHECK: (tblgen_attr.cast<IntegerAttr>().getInt() >= 10)
|
||||||
// CHECK-SAME: return emitOpError("attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 10");
|
// CHECK-SAME: "attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 10"
|
||||||
|
|
||||||
def OpFX : NS_Op<"op_for_int_max_val", []> {
|
def OpFX : NS_Op<"op_for_int_max_val", []> {
|
||||||
let arguments = (ins Confined<I32Attr, [IntMaxValue<10>]>:$attr);
|
let arguments = (ins Confined<I32Attr, [IntMaxValue<10>]>:$attr);
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: OpFX::verify()
|
// CHECK-LABEL: OpFXOperandAdaptor::verify
|
||||||
// CHECK: (tblgen_attr.cast<IntegerAttr>().getInt() <= 10)
|
// CHECK: (tblgen_attr.cast<IntegerAttr>().getInt() <= 10)
|
||||||
// CHECK-SAME: return emitOpError("attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose maximum value is 10");
|
// CHECK-SAME: "attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose maximum value is 10"
|
||||||
|
|
||||||
def OpG : NS_Op<"op_for_arr_min_count", []> {
|
def OpG : NS_Op<"op_for_arr_min_count", []> {
|
||||||
let arguments = (ins Confined<ArrayAttr, [ArrayMinCount<8>]>:$attr);
|
let arguments = (ins Confined<ArrayAttr, [ArrayMinCount<8>]>:$attr);
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: OpG::verify()
|
// CHECK-LABEL: OpGOperandAdaptor::verify
|
||||||
// CHECK: (tblgen_attr.cast<ArrayAttr>().size() >= 8)
|
// CHECK: (tblgen_attr.cast<ArrayAttr>().size() >= 8)
|
||||||
// CHECK-SAME: return emitOpError("attribute 'attr' failed to satisfy constraint: array attribute with at least 8 elements");
|
// CHECK-SAME: "attribute 'attr' failed to satisfy constraint: array attribute with at least 8 elements"
|
||||||
|
|
||||||
def OpH : NS_Op<"op_for_arr_value_at_index", []> {
|
def OpH : NS_Op<"op_for_arr_value_at_index", []> {
|
||||||
let arguments = (ins Confined<ArrayAttr, [IntArrayNthElemEq<0, 8>]>:$attr);
|
let arguments = (ins Confined<ArrayAttr, [IntArrayNthElemEq<0, 8>]>:$attr);
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: OpH::verify()
|
// CHECK-LABEL: OpHOperandAdaptor::verify
|
||||||
// CHECK: (((tblgen_attr.cast<ArrayAttr>().size() > 0)) && ((tblgen_attr.cast<ArrayAttr>()[0].cast<IntegerAttr>().getInt() == 8)))))
|
// CHECK: (((tblgen_attr.cast<ArrayAttr>().size() > 0)) && ((tblgen_attr.cast<ArrayAttr>()[0].cast<IntegerAttr>().getInt() == 8)))))
|
||||||
// CHECK-SAME: return emitOpError("attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be 8");
|
// CHECK-SAME: "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be 8"
|
||||||
|
|
||||||
def OpI: NS_Op<"op_for_arr_min_value_at_index", []> {
|
def OpI: NS_Op<"op_for_arr_min_value_at_index", []> {
|
||||||
let arguments = (ins Confined<ArrayAttr, [IntArrayNthElemMinValue<0, 8>]>:$attr);
|
let arguments = (ins Confined<ArrayAttr, [IntArrayNthElemMinValue<0, 8>]>:$attr);
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: OpI::verify()
|
// CHECK-LABEL: OpIOperandAdaptor::verify
|
||||||
// CHECK: (((tblgen_attr.cast<ArrayAttr>().size() > 0)) && ((tblgen_attr.cast<ArrayAttr>()[0].cast<IntegerAttr>().getInt() >= 8)))))
|
// CHECK: (((tblgen_attr.cast<ArrayAttr>().size() > 0)) && ((tblgen_attr.cast<ArrayAttr>()[0].cast<IntegerAttr>().getInt() >= 8)))))
|
||||||
// CHECK-SAME: return emitOpError("attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at least 8");
|
// CHECK-SAME: "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at least 8"
|
||||||
|
|
||||||
def OpJ: NS_Op<"op_for_TCopVTEtAreSameAt", [
|
def OpJ: NS_Op<"op_for_TCopVTEtAreSameAt", [
|
||||||
PredOpTrait<"operands indexed at 0, 2, 3 should all have "
|
PredOpTrait<"operands indexed at 0, 2, 3 should all have "
|
||||||
|
@ -80,11 +80,11 @@ def OpJ: NS_Op<"op_for_TCopVTEtAreSameAt", [
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: OpJ::verify()
|
// CHECK-LABEL: OpJOperandAdaptor::verify
|
||||||
// CHECK: llvm::is_splat(llvm::map_range(
|
// CHECK: llvm::is_splat(llvm::map_range(
|
||||||
// CHECK-SAME: llvm::ArrayRef<unsigned>({0, 2, 3}),
|
// CHECK-SAME: llvm::ArrayRef<unsigned>({0, 2, 3}),
|
||||||
// CHECK-SAME: [this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); }))
|
// CHECK-SAME: [this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); }))
|
||||||
// CHECK: return emitOpError("failed to verify that operands indexed at 0, 2, 3 should all have the same type");
|
// CHECK: "failed to verify that operands indexed at 0, 2, 3 should all have the same type"
|
||||||
|
|
||||||
def OpK : NS_Op<"op_for_AnyTensorOf", []> {
|
def OpK : NS_Op<"op_for_AnyTensorOf", []> {
|
||||||
let arguments = (ins TensorOf<[F32, I32]>:$x);
|
let arguments = (ins TensorOf<[F32, I32]>:$x);
|
||||||
|
|
|
@ -321,6 +321,116 @@ private:
|
||||||
};
|
};
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
// Populate the format context `ctx` with substitutions of attributes, operands
|
||||||
|
// and results.
|
||||||
|
// - attrGet corresponds to the name of the function to call to get value of
|
||||||
|
// attribute (the generated function call returns an Attribute);
|
||||||
|
// - operandGet corresponds to the name of the function with which to retrieve
|
||||||
|
// an operand (the generaed function call returns an OperandRange);
|
||||||
|
// - reultGet corresponds to the name of the function to get an result (the
|
||||||
|
// generated function call returns a ValueRange);
|
||||||
|
static void populateSubstitutions(const Operator &op, const char *attrGet,
|
||||||
|
const char *operandGet, const char *resultGet,
|
||||||
|
FmtContext &ctx) {
|
||||||
|
// Populate substitutions for attributes and named operands.
|
||||||
|
for (const auto &namedAttr : op.getAttributes())
|
||||||
|
ctx.addSubst(namedAttr.name,
|
||||||
|
formatv("{0}(\"{1}\")", attrGet, namedAttr.name));
|
||||||
|
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
|
||||||
|
auto &value = op.getOperand(i);
|
||||||
|
if (value.name.empty())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (value.isVariadic())
|
||||||
|
ctx.addSubst(value.name, formatv("{0}({1})", operandGet, i));
|
||||||
|
else
|
||||||
|
ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", operandGet, i));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate substitutions for results.
|
||||||
|
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
|
||||||
|
auto &value = op.getResult(i);
|
||||||
|
if (value.name.empty())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (value.isVariadic())
|
||||||
|
ctx.addSubst(value.name, formatv("{0}({1})", resultGet, i));
|
||||||
|
else
|
||||||
|
ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", resultGet, i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate attribute verification. If emitVerificationRequiringOp is set then
|
||||||
|
// only verification for attributes whose value depend on op being known are
|
||||||
|
// emitted, else only verification that doesn't depend on the op being known are
|
||||||
|
// generated.
|
||||||
|
// - emitErrorPrefix is the prefix for the error emitting call which consists
|
||||||
|
// of the entire function call up to start of error message fragment;
|
||||||
|
// - emitVerificationRequiringOp specifies whether verification should be
|
||||||
|
// emitted for verification that require the op to exist;
|
||||||
|
static void genAttributeVerifier(const Operator &op, const char *attrGet,
|
||||||
|
const Twine &emitErrorPrefix,
|
||||||
|
bool emitVerificationRequiringOp,
|
||||||
|
FmtContext &ctx, OpMethodBody &body) {
|
||||||
|
for (const auto &namedAttr : op.getAttributes()) {
|
||||||
|
const auto &attr = namedAttr.attr;
|
||||||
|
if (attr.isDerivedAttr())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
auto attrName = namedAttr.name;
|
||||||
|
bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional();
|
||||||
|
auto attrPred = attr.getPredicate();
|
||||||
|
auto condition = attrPred.isNull() ? "" : attrPred.getCondition();
|
||||||
|
// There is a condition to emit only if the use of $_op and whether to
|
||||||
|
// emit verifications for op matches.
|
||||||
|
bool hasConditionToEmit = (!(condition.find("$_op") != StringRef::npos) ^
|
||||||
|
emitVerificationRequiringOp);
|
||||||
|
|
||||||
|
// Prefix with `tblgen_` to avoid hiding the attribute accessor.
|
||||||
|
auto varName = tblgenNamePrefix + attrName;
|
||||||
|
|
||||||
|
// If the attribute is
|
||||||
|
// 1. Required (not allowed missing) and not in op verification, or
|
||||||
|
// 2. Has a condition that will get verified
|
||||||
|
// then the variable will be used.
|
||||||
|
//
|
||||||
|
// Therefore, for optional attributes whose verification requires that an
|
||||||
|
// op already exists for verification/emitVerificationRequiringOp is set
|
||||||
|
// has nothing that can be verified here.
|
||||||
|
if ((allowMissingAttr || emitVerificationRequiringOp) &&
|
||||||
|
!hasConditionToEmit)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
body << formatv(" {\n auto {0} = {1}(\"{2}\");\n", varName, attrGet,
|
||||||
|
attrName);
|
||||||
|
|
||||||
|
if (!emitVerificationRequiringOp && !allowMissingAttr) {
|
||||||
|
body << " if (!" << varName << ") return " << emitErrorPrefix
|
||||||
|
<< "\"requires attribute '" << attrName << "'\");\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!hasConditionToEmit) {
|
||||||
|
body << " }\n";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (allowMissingAttr) {
|
||||||
|
// If the attribute has a default value, then only verify the predicate if
|
||||||
|
// set. This does effectively assume that the default value is valid.
|
||||||
|
// TODO: verify the debug value is valid (perhaps in debug mode only).
|
||||||
|
body << " if (" << varName << ") {\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
body << tgfmt(" if (!($0)) return $1\"attribute '$2' "
|
||||||
|
"failed to satisfy constraint: $3\");\n",
|
||||||
|
/*ctx=*/nullptr, tgfmt(condition, &ctx.withSelf(varName)),
|
||||||
|
emitErrorPrefix, attrName, attr.getDescription());
|
||||||
|
if (allowMissingAttr)
|
||||||
|
body << " }\n";
|
||||||
|
body << " }\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
OpEmitter::OpEmitter(const Operator &op)
|
OpEmitter::OpEmitter(const Operator &op)
|
||||||
: def(op.getDef()), op(op),
|
: def(op.getDef()), op(op),
|
||||||
opClass(op.getCppClassName(), op.getExtraClassDeclaration()) {
|
opClass(op.getCppClassName(), op.getExtraClassDeclaration()) {
|
||||||
|
@ -1512,110 +1622,27 @@ void OpEmitter::genPrinter() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void OpEmitter::genVerifier() {
|
void OpEmitter::genVerifier() {
|
||||||
auto valueInit = def.getValueInit("verifier");
|
|
||||||
CodeInit *codeInit = dyn_cast<CodeInit>(valueInit);
|
|
||||||
bool hasCustomVerify = codeInit && !codeInit->getValue().empty();
|
|
||||||
|
|
||||||
auto &method = opClass.newMethod("LogicalResult", "verify", /*params=*/"");
|
auto &method = opClass.newMethod("LogicalResult", "verify", /*params=*/"");
|
||||||
auto &body = method.body();
|
auto &body = method.body();
|
||||||
|
body << " if (failed(" << op.getAdaptorName()
|
||||||
|
<< "(*this).verify(this->getLoc()))) "
|
||||||
|
<< "return failure();\n";
|
||||||
|
|
||||||
const char *checkAttrSizedValueSegmentsCode = R"(
|
auto *valueInit = def.getValueInit("verifier");
|
||||||
{
|
CodeInit *codeInit = dyn_cast<CodeInit>(valueInit);
|
||||||
auto sizeAttr = getAttrOfType<DenseIntElementsAttr>("{0}");
|
bool hasCustomVerify = codeInit && !codeInit->getValue().empty();
|
||||||
auto numElements = sizeAttr.getType().cast<ShapedType>().getNumElements();
|
populateSubstitutions(op, "this->getAttr", "this->getODSOperands",
|
||||||
if (numElements != {1}) {{
|
"this->getODSResults", verifyCtx);
|
||||||
return emitOpError("'{0}' attribute for specifying {2} segments "
|
|
||||||
"must have {1} elements");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)";
|
|
||||||
|
|
||||||
// Verify a few traits first so that we can use
|
|
||||||
// getODSOperands()/getODSResults() in the rest of the verifier.
|
|
||||||
for (auto &trait : op.getTraits()) {
|
|
||||||
if (auto *t = dyn_cast<tblgen::NativeOpTrait>(&trait)) {
|
|
||||||
if (t->getTrait() == "OpTrait::AttrSizedOperandSegments") {
|
|
||||||
body << formatv(checkAttrSizedValueSegmentsCode,
|
|
||||||
"operand_segment_sizes", op.getNumOperands(),
|
|
||||||
"operand");
|
|
||||||
} else if (t->getTrait() == "OpTrait::AttrSizedResultSegments") {
|
|
||||||
body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes",
|
|
||||||
op.getNumResults(), "result");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Populate substitutions for attributes and named operands and results.
|
|
||||||
for (const auto &namedAttr : op.getAttributes())
|
|
||||||
verifyCtx.addSubst(namedAttr.name,
|
|
||||||
formatv("this->getAttr(\"{0}\")", namedAttr.name));
|
|
||||||
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
|
|
||||||
auto &value = op.getOperand(i);
|
|
||||||
if (value.name.empty())
|
|
||||||
continue;
|
|
||||||
|
|
||||||
if (value.isVariadic())
|
|
||||||
verifyCtx.addSubst(value.name, formatv("this->getODSOperands({0})", i));
|
|
||||||
else
|
|
||||||
verifyCtx.addSubst(value.name,
|
|
||||||
formatv("(*this->getODSOperands({0}).begin())", i));
|
|
||||||
}
|
|
||||||
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
|
|
||||||
auto &value = op.getResult(i);
|
|
||||||
if (value.name.empty())
|
|
||||||
continue;
|
|
||||||
|
|
||||||
if (value.isVariadic())
|
|
||||||
verifyCtx.addSubst(value.name, formatv("this->getODSResults({0})", i));
|
|
||||||
else
|
|
||||||
verifyCtx.addSubst(value.name,
|
|
||||||
formatv("(*this->getODSResults({0}).begin())", i));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the attributes have the correct type.
|
|
||||||
for (const auto &namedAttr : op.getAttributes()) {
|
|
||||||
const auto &attr = namedAttr.attr;
|
|
||||||
if (attr.isDerivedAttr())
|
|
||||||
continue;
|
|
||||||
|
|
||||||
auto attrName = namedAttr.name;
|
|
||||||
// Prefix with `tblgen_` to avoid hiding the attribute accessor.
|
|
||||||
auto varName = tblgenNamePrefix + attrName;
|
|
||||||
body << formatv(" auto {0} = this->getAttr(\"{1}\");\n", varName,
|
|
||||||
attrName);
|
|
||||||
|
|
||||||
bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional();
|
|
||||||
if (allowMissingAttr) {
|
|
||||||
// If the attribute has a default value, then only verify the predicate if
|
|
||||||
// set. This does effectively assume that the default value is valid.
|
|
||||||
// TODO: verify the debug value is valid (perhaps in debug mode only).
|
|
||||||
body << " if (" << varName << ") {\n";
|
|
||||||
} else {
|
|
||||||
body << " if (!" << varName
|
|
||||||
<< ") return emitOpError(\"requires attribute '" << attrName
|
|
||||||
<< "'\");\n {\n";
|
|
||||||
}
|
|
||||||
|
|
||||||
auto attrPred = attr.getPredicate();
|
|
||||||
if (!attrPred.isNull()) {
|
|
||||||
body << tgfmt(
|
|
||||||
" if (!($0)) return emitOpError(\"attribute '$1' "
|
|
||||||
"failed to satisfy constraint: $2\");\n",
|
|
||||||
/*ctx=*/nullptr,
|
|
||||||
tgfmt(attrPred.getCondition(), &verifyCtx.withSelf(varName)),
|
|
||||||
attrName, attr.getDescription());
|
|
||||||
}
|
|
||||||
|
|
||||||
body << " }\n";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
genAttributeVerifier(op, "this->getAttr", "emitOpError(",
|
||||||
|
/*emitVerificationRequiringOp=*/true, verifyCtx, body);
|
||||||
genOperandResultVerifier(body, op.getOperands(), "operand");
|
genOperandResultVerifier(body, op.getOperands(), "operand");
|
||||||
genOperandResultVerifier(body, op.getResults(), "result");
|
genOperandResultVerifier(body, op.getResults(), "result");
|
||||||
|
|
||||||
for (auto &trait : op.getTraits()) {
|
for (auto &trait : op.getTraits()) {
|
||||||
if (auto *t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
|
if (auto *t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
|
||||||
body << tgfmt(" if (!($0)) {\n "
|
body << tgfmt(" if (!($0))\n "
|
||||||
"return emitOpError(\"failed to verify that $1\");\n }\n",
|
"return emitOpError(\"failed to verify that $1\");\n",
|
||||||
&verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
|
&verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
|
||||||
t->getDescription());
|
t->getDescription());
|
||||||
}
|
}
|
||||||
|
@ -1890,12 +1917,17 @@ public:
|
||||||
private:
|
private:
|
||||||
explicit OpOperandAdaptorEmitter(const Operator &op);
|
explicit OpOperandAdaptorEmitter(const Operator &op);
|
||||||
|
|
||||||
|
// Add verification function. This generates a verify method for the adaptor
|
||||||
|
// which verifies all the op-independent attribute constraints.
|
||||||
|
void addVerification();
|
||||||
|
|
||||||
|
const Operator &op;
|
||||||
Class adaptor;
|
Class adaptor;
|
||||||
};
|
};
|
||||||
} // end namespace
|
} // end namespace
|
||||||
|
|
||||||
OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
|
OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
|
||||||
: adaptor(op.getAdaptorName()) {
|
: op(op), adaptor(op.getAdaptorName()) {
|
||||||
adaptor.newField("ValueRange", "odsOperands");
|
adaptor.newField("ValueRange", "odsOperands");
|
||||||
adaptor.newField("DictionaryAttr", "odsAttrs");
|
adaptor.newField("DictionaryAttr", "odsAttrs");
|
||||||
const auto *attrSizedOperands =
|
const auto *attrSizedOperands =
|
||||||
|
@ -1957,6 +1989,50 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
|
||||||
if (!attr.isDerivedAttr())
|
if (!attr.isDerivedAttr())
|
||||||
emitAttr(name, attr);
|
emitAttr(name, attr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add verification function.
|
||||||
|
addVerification();
|
||||||
|
}
|
||||||
|
|
||||||
|
void OpOperandAdaptorEmitter::addVerification() {
|
||||||
|
auto &method = adaptor.newMethod("LogicalResult", "verify",
|
||||||
|
/*params=*/"Location loc");
|
||||||
|
auto &body = method.body();
|
||||||
|
|
||||||
|
const char *checkAttrSizedValueSegmentsCode = R"(
|
||||||
|
{
|
||||||
|
auto sizeAttr = odsAttrs.get("{0}").cast<DenseIntElementsAttr>();
|
||||||
|
auto numElements = sizeAttr.getType().cast<ShapedType>().getNumElements();
|
||||||
|
if (numElements != {1})
|
||||||
|
return emitError(loc, "'{0}' attribute for specifying {2} segments "
|
||||||
|
"must have {1} elements");
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
// Verify a few traits first so that we can use
|
||||||
|
// getODSOperands()/getODSResults() in the rest of the verifier.
|
||||||
|
for (auto &trait : op.getTraits()) {
|
||||||
|
if (auto *t = dyn_cast<tblgen::NativeOpTrait>(&trait)) {
|
||||||
|
if (t->getTrait() == "OpTrait::AttrSizedOperandSegments") {
|
||||||
|
body << formatv(checkAttrSizedValueSegmentsCode,
|
||||||
|
"operand_segment_sizes", op.getNumOperands(),
|
||||||
|
"operand");
|
||||||
|
} else if (t->getTrait() == "OpTrait::AttrSizedResultSegments") {
|
||||||
|
body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes",
|
||||||
|
op.getNumResults(), "result");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
FmtContext verifyCtx;
|
||||||
|
populateSubstitutions(op, "odsAttrs.get", "getODSOperands",
|
||||||
|
"<no results should be genarated>", verifyCtx);
|
||||||
|
genAttributeVerifier(op, "odsAttrs.get",
|
||||||
|
Twine("emitError(loc, \"'") + op.getOperationName() +
|
||||||
|
"' op \"",
|
||||||
|
/*emitVerificationRequiringOp*/ false, verifyCtx, body);
|
||||||
|
|
||||||
|
body << " return success();";
|
||||||
}
|
}
|
||||||
|
|
||||||
void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
|
void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
|
||||||
|
|
Loading…
Reference in New Issue