[ODS] Support variadic operand/result verification

This CL enables verification code generation for variadic operands and results.
In verify(), we use fallback getter methods to access all the dynamic values
belonging to one static variadic operand/result to reuse the value range
calculation there.

PiperOrigin-RevId: 252288219
This commit is contained in:
Lei Zhang 2019-06-09 07:00:09 -07:00 committed by Mehdi Amini
parent 7f108e60cc
commit 3812d956ea
13 changed files with 152 additions and 86 deletions

View File

@ -255,9 +255,7 @@ class TypeAlias<Type t, string description = t.description> :
// class is used for supporting variadic operands/results. An op can declare no
// more than one variadic operand/result, and that operand/result must be the
// last one in the operand/result list.
class Variadic<Type type, string descr = "">
// TODO(b/132908002): support variadic type conditions
: TypeConstraint<CPred<"true">, descr> {
class Variadic<Type type> : TypeConstraint<type.predicate, type.description> {
Type baseType = type;
}
@ -907,6 +905,9 @@ def Terminator : NativeOpTrait<"IsTerminator">;
def FirstAttrDerivedResultType :
GenInternalOpTrait<"FirstAttrDerivedResultType">;
// TODO(antiagainst): Turn the following into normal traits and generate
// verification for them.
// All variadic operands of the op have the same number of values.
// A variadic operand contains an array of values whose array size is only
// known at runtime. This trait requires all variadic operands of an op

View File

@ -203,7 +203,9 @@ def LLVM_PtrToIntOp
// Call-related operations.
def LLVM_CallOp : LLVM_Op<"call">,
Arguments<(ins OptionalAttr<FunctionAttr>:$callee,
Variadic<LLVM_Type>)>,
// TODO(b/133216756): fix test failure and
// change to LLVM_Type
Variadic<AnyType>)>,
Results<(outs Variadic<LLVM_Type>)>,
LLVM_TwoBuilders<LLVM_OneResultOpBuilder,
LLVM_ZeroResultOpBuilder> {

View File

@ -69,11 +69,12 @@ public:
std::string getQualCppClassName() const;
using value_iterator = NamedTypeConstraint *;
using value_range = llvm::iterator_range<value_iterator>;
// Op result iterators.
value_iterator result_begin();
value_iterator result_end();
llvm::iterator_range<value_iterator> getResults();
value_range getResults();
// Returns the number of results this op produces.
int getNumResults() const;
@ -110,7 +111,7 @@ public:
// Op operand iterators.
value_iterator operand_begin();
value_iterator operand_end();
llvm::iterator_range<value_iterator> getOperands();
value_range getOperands();
int getNumOperands() const { return operands.size(); }
NamedTypeConstraint &getOperand(int index) { return operands[index]; }

View File

@ -1595,12 +1595,6 @@ static LogicalResult verify(ExtractElementOp op) {
if (op.getType() != aggregateType.getElementType())
return op.emitOpError("result type must match element type of aggregate");
// TODO(b/132908002) This should be covered by the op specification in
// tablegen, but for some reason it's not.
for (auto *idx : op.getIndices())
if (!idx->getType().isIndex())
return op.emitOpError("index to extract_element must have 'index' type");
// Verify the # indices match if we have a ranked type.
if (aggregateType.hasRank() &&
aggregateType.getRank() != op.getNumOperands() - 1)

View File

@ -95,7 +95,7 @@ auto tblgen::Operator::result_begin() -> value_iterator {
auto tblgen::Operator::result_end() -> value_iterator { return results.end(); }
auto tblgen::Operator::getResults() -> llvm::iterator_range<value_iterator> {
auto tblgen::Operator::getResults() -> value_range {
return {result_begin(), result_end()};
}
@ -205,7 +205,7 @@ auto tblgen::Operator::operand_begin() -> value_iterator {
auto tblgen::Operator::operand_end() -> value_iterator {
return operands.end();
}
auto tblgen::Operator::getOperands() -> llvm::iterator_range<value_iterator> {
auto tblgen::Operator::getOperands() -> value_range {
return {operand_begin(), operand_end()};
}

View File

@ -639,7 +639,7 @@ func @extract_element_no_indices(%v : vector<3xf32>) {
// -----
func @extract_element_invalid_index_type(%v : vector<3xf32>, %i : i32) {
// expected-error@+1 {{index to extract_element must have 'index' type}}
// expected-error@+1 {{operand #1 must be index}}
%0 = "std.extract_element"(%v, %i) : (vector<3xf32>, i32) -> f32
return
}

35
mlir/test/IR/operand.mlir Normal file
View File

@ -0,0 +1,35 @@
// RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s
//===----------------------------------------------------------------------===//
// Test mixed normal and variadic operands
//===----------------------------------------------------------------------===//
func @correct_variadic_operand(%arg0: tensor<f32>, %arg1: f32) {
// CHECK: mixed_normal_variadic_operand
"test.mixed_normal_variadic_operand"(%arg0, %arg0, %arg0, %arg0, %arg0) : (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
return
}
// -----
func @error_in_first_variadic_operand(%arg0: tensor<f32>, %arg1: f32) {
// expected-error @+1 {{operand #0 must be tensor of any type}}
"test.mixed_normal_variadic_operand"(%arg0, %arg1, %arg0, %arg0, %arg0) : (tensor<f32>, f32, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
return
}
// -----
func @error_in_normal_operand(%arg0: tensor<f32>, %arg1: f32) {
// expected-error @+1 {{operand #1 must be tensor of any type}}
"test.mixed_normal_variadic_operand"(%arg0, %arg0, %arg1, %arg0, %arg0) : (tensor<f32>, tensor<f32>, f32, tensor<f32>, tensor<f32>) -> ()
return
}
// -----
func @error_in_second_variadic_operand(%arg0: tensor<f32>, %arg1: f32) {
// expected-error @+1 {{operand #2 must be tensor of any type}}
"test.mixed_normal_variadic_operand"(%arg0, %arg0, %arg0, %arg1, %arg0) : (tensor<f32>, tensor<f32>, tensor<f32>, f32, tensor<f32>) -> ()
return
}

36
mlir/test/IR/result.mlir Normal file
View File

@ -0,0 +1,36 @@
// RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s
//===----------------------------------------------------------------------===//
// Test mixed normal and variadic results
//===----------------------------------------------------------------------===//
func @correct_variadic_result() -> tensor<f32> {
// CHECK: mixed_normal_variadic_result
%0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>)
return %0#4 : tensor<f32>
}
// -----
func @error_in_first_variadic_result() -> tensor<f32> {
// expected-error @+1 {{result #0 must be tensor of any type}}
%0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor<f32>, f32, tensor<f32>, tensor<f32>, tensor<f32>)
return %0#4 : tensor<f32>
}
// -----
func @error_in_normal_result() -> tensor<f32> {
// expected-error @+1 {{result #1 must be tensor of any type}}
%0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor<f32>, tensor<f32>, f32, tensor<f32>, tensor<f32>)
return %0#4 : tensor<f32>
}
// -----
func @error_in_second_variadic_result() -> tensor<f32> {
// expected-error @+1 {{result #2 must be tensor of any type}}
%0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor<f32>, tensor<f32>, tensor<f32>, f32, tensor<f32>)
return %0#4 : tensor<f32>
}

View File

@ -60,6 +60,31 @@ def NestedTupleOp : TEST_Op<"nested_tuple_32_bit"> {
let results = (outs NestedTupleOf<[I32, F32]>);
}
//===----------------------------------------------------------------------===//
// Test Operands
//===----------------------------------------------------------------------===//
def MixedNormalVariadicOperandOp : TEST_Op<
"mixed_normal_variadic_operand", [SameVariadicOperandSize]> {
let arguments = (ins
Variadic<AnyTensor>:$input1,
AnyTensor:$input2,
Variadic<AnyTensor>:$input3
);
}
//===----------------------------------------------------------------------===//
// Test Results
//===----------------------------------------------------------------------===//
def MixedNormalVariadicResults : TEST_Op<
"mixed_normal_variadic_result", [SameVariadicResultSize]> {
let results = (outs
Variadic<AnyTensor>:$output1,
AnyTensor:$output2,
Variadic<AnyTensor>:$output3
);
}
//===----------------------------------------------------------------------===//
// Test Attributes

View File

@ -26,10 +26,6 @@ def OpA : NS_Op<"one_normal_operand_op", []> {
// CHECK: assert(operands.size() == 1u && "mismatched number of parameters");
// CHECK: tblgen_state->addOperands(operands);
// CHECK: LogicalResult OpA::verify() {
// CHECK: if (!((this->getOperation()->getOperand(0)->getType().isInteger(32))))
// CHECK-NEXT: return emitOpError("operand #0 must be 32-bit integer");
def OpB : NS_Op<"one_variadic_operand_op", []> {
let arguments = (ins Variadic<I32>:$input);
}
@ -52,20 +48,6 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]>
// CHECK-LABEL: ArrayRef<Value *> OpDOperandAdaptor::input3
// CHECK-NEXT: return getODSOperands(2);
// TODO(b/134305899): Move to use TestDialect after fixing verification.
// CHECK-LABEL: Operation::operand_range OpD::getODSOperands(unsigned index)
// CHECK-NEXT: bool isVariadic[] = {true, false, true};
// CHECK-NEXT: int prevVariadicCount = 0;
// CHECK-NEXT: for (int i = 0; i < index; ++i)
// CHECK-NEXT: if (isVariadic[i]) ++prevVariadicCount;
// CHECK: int variadicSize = (getOperation()->getNumOperands() - 1) / 2;
// CHECK: int offset = index + (variadicSize - 1) * prevVariadicCount;
// CHECK-NEXT: int size = isVariadic[index] ? variadicSize : 1;
// CHECK: return {std::next(getOperation()->operand_begin(), offset), std::next(getOperation()->operand_begin(), offset + size)};
// CHECK-LABEL: Operation::operand_range OpD::input1
// CHECK-NEXT: return getODSOperands(0);

View File

@ -17,10 +17,6 @@ def OpA : NS_Op<"one_normal_result_op", []> {
// CHECK: assert(resultTypes.size() == 1u && "mismatched number of return types");
// CHECK-NEXT: tblgen_state->addTypes(resultTypes);
// CHECK-LABEL: LogicalResult OpA::verify()
// CHECK: if (!((this->getOperation()->getResult(0)->getType().isInteger(32))))
// CHECK-NEXT: return emitOpError("result #0 must be 32-bit integer");
def OpB : NS_Op<"same_input_output_type_op", [SameOperandsAndResultType]> {
let arguments = (ins I32:$x);
let results = (outs I32:$y);
@ -90,20 +86,6 @@ def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]>
let results = (outs Variadic<AnyTensor>:$output1, AnyTensor:$output2, Variadic<AnyTensor>:$output3);
}
// TODO(b/134305899): Move to use TestDialect after fixing verification.
// CHECK-LABEL: Operation::result_range OpI::getODSResults(unsigned index)
// CHECK-NEXT: bool isVariadic[] = {true, false, true};
// CHECK-NEXT: int prevVariadicCount = 0;
// CHECK-NEXT: for (int i = 0; i < index; ++i)
// CHECK-NEXT: if (isVariadic[i]) ++prevVariadicCount;
// CHECK: int variadicSize = (getOperation()->getNumResults() - 1) / 2;
// CHECK: int offset = index + (variadicSize - 1) * prevVariadicCount;
// CHECK-NEXT: int size = isVariadic[index] ? variadicSize : 1;
// CHECK: return {std::next(getOperation()->result_begin(), offset), std::next(getOperation()->result_begin(), offset + size)};
// CHECK-LABEL: Operation::result_range OpI::output1
// CHECK-NEXT: return getODSResults(0);

View File

@ -16,7 +16,8 @@ def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> {
}
// CHECK-LABEL: OpA::verify
// CHECK: if (!((this->getOperation()->getOperand(0)->getType().isInteger(32) || this->getOperation()->getOperand(0)->getType().isF32())))
// CHECK: for (Value *v : getODSOperands(0)) {
// CHECK: if (!((v->getType().isInteger(32) || v->getType().isF32())))
def OpB : NS_Op<"op_for_And_PredOpTrait", [
PredOpTrait<"both first and second holds",
@ -103,4 +104,5 @@ def OpK : NS_Op<"op_for_AnyTensorOf", []> {
}
// CHECK-LABEL: OpK::verify
// CHECK: if (!(((this->getOperation()->getOperand(0)->getType().isa<TensorType>())) && (((this->getOperation()->getOperand(0)->getType().cast<ShapedType>().getElementType().isF32())) || ((this->getOperation()->getOperand(0)->getType().cast<ShapedType>().getElementType().isInteger(32))))))
// CHECK: for (Value *v : getODSOperands(0)) {
// CHECK: if (!(((v->getType().isa<TensorType>())) && (((v->getType().cast<ShapedType>().getElementType().isF32())) || ((v->getType().cast<ShapedType>().getElementType().isInteger(32))))))

View File

@ -448,6 +448,12 @@ private:
// Generates verify method for the operation.
void genVerifier();
// Generates verify statements for operands and results in the operation.
// The generated code will be attached to `body`.
void genOperandResultVerifier(OpMethodBody &body,
Operator::value_range values,
StringRef valueKind);
// Generates verify statements for regions in the operation.
// The generated code will be attached to `body`.
void genRegionVerifier(OpMethodBody &body);
@ -1022,39 +1028,8 @@ void OpEmitter::genVerifier() {
body << " }\n";
}
// Emits verification code for an operand or result.
auto verifyValue = [&](const tblgen::NamedTypeConstraint &value, int index,
bool isOperand) -> void {
// TODO: Handle variadic operand/result verification.
if (value.isVariadic())
return;
// TODO: Commonality between matchers could be extracted to have a more
// concise code.
if (value.hasPredicate()) {
auto description = value.constraint.getDescription();
body << " if (!("
<< tgfmt(
value.constraint.getConditionTemplate(),
&verifyCtx.withSelf("this->getOperation()->get" +
Twine(isOperand ? "Operand" : "Result") +
"(" + Twine(index) + ")->getType()"))
<< ")) {\n";
body << " return emitOpError(\"" << (isOperand ? "operand" : "result")
<< " #" << index
<< (description.empty() ? " type precondition failed"
: " must be " + Twine(description))
<< "\");\n }\n";
}
};
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
verifyValue(op.getOperand(i), i, /*isOperand=*/true);
}
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
verifyValue(op.getResult(i), i, /*isOperand=*/false);
}
genOperandResultVerifier(body, op.getOperands(), "operand");
genOperandResultVerifier(body, op.getResults(), "result");
for (auto &trait : op.getTraits()) {
if (auto t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
@ -1073,6 +1048,37 @@ void OpEmitter::genVerifier() {
body << " return mlir::success();\n";
}
void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
Operator::value_range values,
StringRef valueKind) {
FmtContext fctx;
unsigned i = 0;
for (auto &staticValue : values) {
if (!staticValue.hasPredicate())
continue;
// Emit a loop to check all the dynamic values in the pack.
body << formatv(" for (Value *v : getODS{0}{1}s({2})) {{\n",
// Capitalize the first letter to match the function name
valueKind.substr(0, 1).upper(), valueKind.substr(1), i);
auto description = staticValue.constraint.getDescription();
body << " (void)v;\n";
body << " if (!("
<< tgfmt(staticValue.constraint.getConditionTemplate(),
&fctx.withSelf("v->getType()"))
<< "))\n";
body << " return emitOpError(\""
// TODO(b/129706806): Use the name of the operand/result here
<< valueKind << " #" << i
<< (description.empty() ? " type precondition failed"
: " must be " + Twine(description))
<< "\");\n";
body << " }\n";
++i;
}
}
void OpEmitter::genRegionVerifier(OpMethodBody &body) {
unsigned numRegions = op.getNumRegions();