forked from OSchip/llvm-project
[ODS] Support variadic operand/result verification
This CL enables verification code generation for variadic operands and results. In verify(), we use fallback getter methods to access all the dynamic values belonging to one static variadic operand/result to reuse the value range calculation there. PiperOrigin-RevId: 252288219
This commit is contained in:
parent
7f108e60cc
commit
3812d956ea
|
@ -255,9 +255,7 @@ class TypeAlias<Type t, string description = t.description> :
|
|||
// class is used for supporting variadic operands/results. An op can declare no
|
||||
// more than one variadic operand/result, and that operand/result must be the
|
||||
// last one in the operand/result list.
|
||||
class Variadic<Type type, string descr = "">
|
||||
// TODO(b/132908002): support variadic type conditions
|
||||
: TypeConstraint<CPred<"true">, descr> {
|
||||
class Variadic<Type type> : TypeConstraint<type.predicate, type.description> {
|
||||
Type baseType = type;
|
||||
}
|
||||
|
||||
|
@ -907,6 +905,9 @@ def Terminator : NativeOpTrait<"IsTerminator">;
|
|||
def FirstAttrDerivedResultType :
|
||||
GenInternalOpTrait<"FirstAttrDerivedResultType">;
|
||||
|
||||
// TODO(antiagainst): Turn the following into normal traits and generate
|
||||
// verification for them.
|
||||
|
||||
// All variadic operands of the op have the same number of values.
|
||||
// A variadic operand contains an array of values whose array size is only
|
||||
// known at runtime. This trait requires all variadic operands of an op
|
||||
|
|
|
@ -203,7 +203,9 @@ def LLVM_PtrToIntOp
|
|||
// Call-related operations.
|
||||
def LLVM_CallOp : LLVM_Op<"call">,
|
||||
Arguments<(ins OptionalAttr<FunctionAttr>:$callee,
|
||||
Variadic<LLVM_Type>)>,
|
||||
// TODO(b/133216756): fix test failure and
|
||||
// change to LLVM_Type
|
||||
Variadic<AnyType>)>,
|
||||
Results<(outs Variadic<LLVM_Type>)>,
|
||||
LLVM_TwoBuilders<LLVM_OneResultOpBuilder,
|
||||
LLVM_ZeroResultOpBuilder> {
|
||||
|
|
|
@ -69,11 +69,12 @@ public:
|
|||
std::string getQualCppClassName() const;
|
||||
|
||||
using value_iterator = NamedTypeConstraint *;
|
||||
using value_range = llvm::iterator_range<value_iterator>;
|
||||
|
||||
// Op result iterators.
|
||||
value_iterator result_begin();
|
||||
value_iterator result_end();
|
||||
llvm::iterator_range<value_iterator> getResults();
|
||||
value_range getResults();
|
||||
|
||||
// Returns the number of results this op produces.
|
||||
int getNumResults() const;
|
||||
|
@ -110,7 +111,7 @@ public:
|
|||
// Op operand iterators.
|
||||
value_iterator operand_begin();
|
||||
value_iterator operand_end();
|
||||
llvm::iterator_range<value_iterator> getOperands();
|
||||
value_range getOperands();
|
||||
|
||||
int getNumOperands() const { return operands.size(); }
|
||||
NamedTypeConstraint &getOperand(int index) { return operands[index]; }
|
||||
|
|
|
@ -1595,12 +1595,6 @@ static LogicalResult verify(ExtractElementOp op) {
|
|||
if (op.getType() != aggregateType.getElementType())
|
||||
return op.emitOpError("result type must match element type of aggregate");
|
||||
|
||||
// TODO(b/132908002) This should be covered by the op specification in
|
||||
// tablegen, but for some reason it's not.
|
||||
for (auto *idx : op.getIndices())
|
||||
if (!idx->getType().isIndex())
|
||||
return op.emitOpError("index to extract_element must have 'index' type");
|
||||
|
||||
// Verify the # indices match if we have a ranked type.
|
||||
if (aggregateType.hasRank() &&
|
||||
aggregateType.getRank() != op.getNumOperands() - 1)
|
||||
|
|
|
@ -95,7 +95,7 @@ auto tblgen::Operator::result_begin() -> value_iterator {
|
|||
|
||||
auto tblgen::Operator::result_end() -> value_iterator { return results.end(); }
|
||||
|
||||
auto tblgen::Operator::getResults() -> llvm::iterator_range<value_iterator> {
|
||||
auto tblgen::Operator::getResults() -> value_range {
|
||||
return {result_begin(), result_end()};
|
||||
}
|
||||
|
||||
|
@ -205,7 +205,7 @@ auto tblgen::Operator::operand_begin() -> value_iterator {
|
|||
auto tblgen::Operator::operand_end() -> value_iterator {
|
||||
return operands.end();
|
||||
}
|
||||
auto tblgen::Operator::getOperands() -> llvm::iterator_range<value_iterator> {
|
||||
auto tblgen::Operator::getOperands() -> value_range {
|
||||
return {operand_begin(), operand_end()};
|
||||
}
|
||||
|
||||
|
|
|
@ -639,7 +639,7 @@ func @extract_element_no_indices(%v : vector<3xf32>) {
|
|||
// -----
|
||||
|
||||
func @extract_element_invalid_index_type(%v : vector<3xf32>, %i : i32) {
|
||||
// expected-error@+1 {{index to extract_element must have 'index' type}}
|
||||
// expected-error@+1 {{operand #1 must be index}}
|
||||
%0 = "std.extract_element"(%v, %i) : (vector<3xf32>, i32) -> f32
|
||||
return
|
||||
}
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
// RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test mixed normal and variadic operands
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @correct_variadic_operand(%arg0: tensor<f32>, %arg1: f32) {
|
||||
// CHECK: mixed_normal_variadic_operand
|
||||
"test.mixed_normal_variadic_operand"(%arg0, %arg0, %arg0, %arg0, %arg0) : (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @error_in_first_variadic_operand(%arg0: tensor<f32>, %arg1: f32) {
|
||||
// expected-error @+1 {{operand #0 must be tensor of any type}}
|
||||
"test.mixed_normal_variadic_operand"(%arg0, %arg1, %arg0, %arg0, %arg0) : (tensor<f32>, f32, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @error_in_normal_operand(%arg0: tensor<f32>, %arg1: f32) {
|
||||
// expected-error @+1 {{operand #1 must be tensor of any type}}
|
||||
"test.mixed_normal_variadic_operand"(%arg0, %arg0, %arg1, %arg0, %arg0) : (tensor<f32>, tensor<f32>, f32, tensor<f32>, tensor<f32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @error_in_second_variadic_operand(%arg0: tensor<f32>, %arg1: f32) {
|
||||
// expected-error @+1 {{operand #2 must be tensor of any type}}
|
||||
"test.mixed_normal_variadic_operand"(%arg0, %arg0, %arg0, %arg1, %arg0) : (tensor<f32>, tensor<f32>, tensor<f32>, f32, tensor<f32>) -> ()
|
||||
return
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
// RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test mixed normal and variadic results
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @correct_variadic_result() -> tensor<f32> {
|
||||
// CHECK: mixed_normal_variadic_result
|
||||
%0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>)
|
||||
return %0#4 : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @error_in_first_variadic_result() -> tensor<f32> {
|
||||
// expected-error @+1 {{result #0 must be tensor of any type}}
|
||||
%0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor<f32>, f32, tensor<f32>, tensor<f32>, tensor<f32>)
|
||||
return %0#4 : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @error_in_normal_result() -> tensor<f32> {
|
||||
// expected-error @+1 {{result #1 must be tensor of any type}}
|
||||
%0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor<f32>, tensor<f32>, f32, tensor<f32>, tensor<f32>)
|
||||
return %0#4 : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @error_in_second_variadic_result() -> tensor<f32> {
|
||||
// expected-error @+1 {{result #2 must be tensor of any type}}
|
||||
%0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor<f32>, tensor<f32>, tensor<f32>, f32, tensor<f32>)
|
||||
return %0#4 : tensor<f32>
|
||||
}
|
||||
|
|
@ -60,6 +60,31 @@ def NestedTupleOp : TEST_Op<"nested_tuple_32_bit"> {
|
|||
let results = (outs NestedTupleOf<[I32, F32]>);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Operands
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def MixedNormalVariadicOperandOp : TEST_Op<
|
||||
"mixed_normal_variadic_operand", [SameVariadicOperandSize]> {
|
||||
let arguments = (ins
|
||||
Variadic<AnyTensor>:$input1,
|
||||
AnyTensor:$input2,
|
||||
Variadic<AnyTensor>:$input3
|
||||
);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Results
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def MixedNormalVariadicResults : TEST_Op<
|
||||
"mixed_normal_variadic_result", [SameVariadicResultSize]> {
|
||||
let results = (outs
|
||||
Variadic<AnyTensor>:$output1,
|
||||
AnyTensor:$output2,
|
||||
Variadic<AnyTensor>:$output3
|
||||
);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Attributes
|
||||
|
|
|
@ -26,10 +26,6 @@ def OpA : NS_Op<"one_normal_operand_op", []> {
|
|||
// CHECK: assert(operands.size() == 1u && "mismatched number of parameters");
|
||||
// CHECK: tblgen_state->addOperands(operands);
|
||||
|
||||
// CHECK: LogicalResult OpA::verify() {
|
||||
// CHECK: if (!((this->getOperation()->getOperand(0)->getType().isInteger(32))))
|
||||
// CHECK-NEXT: return emitOpError("operand #0 must be 32-bit integer");
|
||||
|
||||
def OpB : NS_Op<"one_variadic_operand_op", []> {
|
||||
let arguments = (ins Variadic<I32>:$input);
|
||||
}
|
||||
|
@ -52,20 +48,6 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]>
|
|||
// CHECK-LABEL: ArrayRef<Value *> OpDOperandAdaptor::input3
|
||||
// CHECK-NEXT: return getODSOperands(2);
|
||||
|
||||
// TODO(b/134305899): Move to use TestDialect after fixing verification.
|
||||
|
||||
// CHECK-LABEL: Operation::operand_range OpD::getODSOperands(unsigned index)
|
||||
// CHECK-NEXT: bool isVariadic[] = {true, false, true};
|
||||
// CHECK-NEXT: int prevVariadicCount = 0;
|
||||
// CHECK-NEXT: for (int i = 0; i < index; ++i)
|
||||
// CHECK-NEXT: if (isVariadic[i]) ++prevVariadicCount;
|
||||
|
||||
// CHECK: int variadicSize = (getOperation()->getNumOperands() - 1) / 2;
|
||||
// CHECK: int offset = index + (variadicSize - 1) * prevVariadicCount;
|
||||
// CHECK-NEXT: int size = isVariadic[index] ? variadicSize : 1;
|
||||
|
||||
// CHECK: return {std::next(getOperation()->operand_begin(), offset), std::next(getOperation()->operand_begin(), offset + size)};
|
||||
|
||||
// CHECK-LABEL: Operation::operand_range OpD::input1
|
||||
// CHECK-NEXT: return getODSOperands(0);
|
||||
|
||||
|
|
|
@ -17,10 +17,6 @@ def OpA : NS_Op<"one_normal_result_op", []> {
|
|||
// CHECK: assert(resultTypes.size() == 1u && "mismatched number of return types");
|
||||
// CHECK-NEXT: tblgen_state->addTypes(resultTypes);
|
||||
|
||||
// CHECK-LABEL: LogicalResult OpA::verify()
|
||||
// CHECK: if (!((this->getOperation()->getResult(0)->getType().isInteger(32))))
|
||||
// CHECK-NEXT: return emitOpError("result #0 must be 32-bit integer");
|
||||
|
||||
def OpB : NS_Op<"same_input_output_type_op", [SameOperandsAndResultType]> {
|
||||
let arguments = (ins I32:$x);
|
||||
let results = (outs I32:$y);
|
||||
|
@ -90,20 +86,6 @@ def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]>
|
|||
let results = (outs Variadic<AnyTensor>:$output1, AnyTensor:$output2, Variadic<AnyTensor>:$output3);
|
||||
}
|
||||
|
||||
// TODO(b/134305899): Move to use TestDialect after fixing verification.
|
||||
|
||||
// CHECK-LABEL: Operation::result_range OpI::getODSResults(unsigned index)
|
||||
// CHECK-NEXT: bool isVariadic[] = {true, false, true};
|
||||
// CHECK-NEXT: int prevVariadicCount = 0;
|
||||
// CHECK-NEXT: for (int i = 0; i < index; ++i)
|
||||
// CHECK-NEXT: if (isVariadic[i]) ++prevVariadicCount;
|
||||
|
||||
// CHECK: int variadicSize = (getOperation()->getNumResults() - 1) / 2;
|
||||
// CHECK: int offset = index + (variadicSize - 1) * prevVariadicCount;
|
||||
// CHECK-NEXT: int size = isVariadic[index] ? variadicSize : 1;
|
||||
|
||||
// CHECK: return {std::next(getOperation()->result_begin(), offset), std::next(getOperation()->result_begin(), offset + size)};
|
||||
|
||||
// CHECK-LABEL: Operation::result_range OpI::output1
|
||||
// CHECK-NEXT: return getODSResults(0);
|
||||
|
||||
|
|
|
@ -16,7 +16,8 @@ def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> {
|
|||
}
|
||||
|
||||
// CHECK-LABEL: OpA::verify
|
||||
// CHECK: if (!((this->getOperation()->getOperand(0)->getType().isInteger(32) || this->getOperation()->getOperand(0)->getType().isF32())))
|
||||
// CHECK: for (Value *v : getODSOperands(0)) {
|
||||
// CHECK: if (!((v->getType().isInteger(32) || v->getType().isF32())))
|
||||
|
||||
def OpB : NS_Op<"op_for_And_PredOpTrait", [
|
||||
PredOpTrait<"both first and second holds",
|
||||
|
@ -103,4 +104,5 @@ def OpK : NS_Op<"op_for_AnyTensorOf", []> {
|
|||
}
|
||||
|
||||
// CHECK-LABEL: OpK::verify
|
||||
// CHECK: if (!(((this->getOperation()->getOperand(0)->getType().isa<TensorType>())) && (((this->getOperation()->getOperand(0)->getType().cast<ShapedType>().getElementType().isF32())) || ((this->getOperation()->getOperand(0)->getType().cast<ShapedType>().getElementType().isInteger(32))))))
|
||||
// CHECK: for (Value *v : getODSOperands(0)) {
|
||||
// CHECK: if (!(((v->getType().isa<TensorType>())) && (((v->getType().cast<ShapedType>().getElementType().isF32())) || ((v->getType().cast<ShapedType>().getElementType().isInteger(32))))))
|
||||
|
|
|
@ -448,6 +448,12 @@ private:
|
|||
// Generates verify method for the operation.
|
||||
void genVerifier();
|
||||
|
||||
// Generates verify statements for operands and results in the operation.
|
||||
// The generated code will be attached to `body`.
|
||||
void genOperandResultVerifier(OpMethodBody &body,
|
||||
Operator::value_range values,
|
||||
StringRef valueKind);
|
||||
|
||||
// Generates verify statements for regions in the operation.
|
||||
// The generated code will be attached to `body`.
|
||||
void genRegionVerifier(OpMethodBody &body);
|
||||
|
@ -1022,39 +1028,8 @@ void OpEmitter::genVerifier() {
|
|||
body << " }\n";
|
||||
}
|
||||
|
||||
// Emits verification code for an operand or result.
|
||||
auto verifyValue = [&](const tblgen::NamedTypeConstraint &value, int index,
|
||||
bool isOperand) -> void {
|
||||
// TODO: Handle variadic operand/result verification.
|
||||
if (value.isVariadic())
|
||||
return;
|
||||
|
||||
// TODO: Commonality between matchers could be extracted to have a more
|
||||
// concise code.
|
||||
if (value.hasPredicate()) {
|
||||
auto description = value.constraint.getDescription();
|
||||
body << " if (!("
|
||||
<< tgfmt(
|
||||
value.constraint.getConditionTemplate(),
|
||||
&verifyCtx.withSelf("this->getOperation()->get" +
|
||||
Twine(isOperand ? "Operand" : "Result") +
|
||||
"(" + Twine(index) + ")->getType()"))
|
||||
<< ")) {\n";
|
||||
body << " return emitOpError(\"" << (isOperand ? "operand" : "result")
|
||||
<< " #" << index
|
||||
<< (description.empty() ? " type precondition failed"
|
||||
: " must be " + Twine(description))
|
||||
<< "\");\n }\n";
|
||||
}
|
||||
};
|
||||
|
||||
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
|
||||
verifyValue(op.getOperand(i), i, /*isOperand=*/true);
|
||||
}
|
||||
|
||||
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
|
||||
verifyValue(op.getResult(i), i, /*isOperand=*/false);
|
||||
}
|
||||
genOperandResultVerifier(body, op.getOperands(), "operand");
|
||||
genOperandResultVerifier(body, op.getResults(), "result");
|
||||
|
||||
for (auto &trait : op.getTraits()) {
|
||||
if (auto t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
|
||||
|
@ -1073,6 +1048,37 @@ void OpEmitter::genVerifier() {
|
|||
body << " return mlir::success();\n";
|
||||
}
|
||||
|
||||
void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
|
||||
Operator::value_range values,
|
||||
StringRef valueKind) {
|
||||
FmtContext fctx;
|
||||
unsigned i = 0;
|
||||
for (auto &staticValue : values) {
|
||||
if (!staticValue.hasPredicate())
|
||||
continue;
|
||||
|
||||
// Emit a loop to check all the dynamic values in the pack.
|
||||
body << formatv(" for (Value *v : getODS{0}{1}s({2})) {{\n",
|
||||
// Capitalize the first letter to match the function name
|
||||
valueKind.substr(0, 1).upper(), valueKind.substr(1), i);
|
||||
|
||||
auto description = staticValue.constraint.getDescription();
|
||||
body << " (void)v;\n";
|
||||
body << " if (!("
|
||||
<< tgfmt(staticValue.constraint.getConditionTemplate(),
|
||||
&fctx.withSelf("v->getType()"))
|
||||
<< "))\n";
|
||||
body << " return emitOpError(\""
|
||||
// TODO(b/129706806): Use the name of the operand/result here
|
||||
<< valueKind << " #" << i
|
||||
<< (description.empty() ? " type precondition failed"
|
||||
: " must be " + Twine(description))
|
||||
<< "\");\n";
|
||||
body << " }\n";
|
||||
++i;
|
||||
}
|
||||
}
|
||||
|
||||
void OpEmitter::genRegionVerifier(OpMethodBody &body) {
|
||||
unsigned numRegions = op.getNumRegions();
|
||||
|
||||
|
|
Loading…
Reference in New Issue