diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 25302e5ff06e..27681d37f177 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -767,7 +767,7 @@ static LogicalResult verifyShapeMatch(Type type1, Type type2) { } LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { - if (op->getNumOperands() == 0) + if (failed(verifyAtLeastNOperands(op, 1))) return failure(); auto type = op->getOperand(0)->getType(); @@ -779,7 +779,8 @@ LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { } LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { - if (op->getNumOperands() == 0 || op->getNumResults() == 0) + if (failed(verifyAtLeastNOperands(op, 1)) || + failed(verifyAtLeastNResults(op, 1))) return failure(); auto type = op->getOperand(0)->getType(); @@ -797,7 +798,7 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { } LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { - if (op->getNumOperands() == 0) + if (failed(verifyAtLeastNOperands(op, 1))) return failure(); auto type = op->getOperand(0)->getType().dyn_cast(); @@ -818,7 +819,8 @@ LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { LogicalResult OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { - if (op->getNumOperands() == 0 || op->getNumResults() == 0) + if (failed(verifyAtLeastNOperands(op, 1)) || + failed(verifyAtLeastNResults(op, 1))) return failure(); auto type = op->getResult(0)->getType().dyn_cast(); @@ -850,7 +852,8 @@ OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { } LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { - if (op->getNumOperands() == 0 || op->getNumResults() == 0) + if (failed(verifyAtLeastNOperands(op, 1)) || + failed(verifyAtLeastNResults(op, 1))) return failure(); auto type = op->getResult(0)->getType(); diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir index 40a4e963aa6f..dc8f6af57d72 100644 --- a/mlir/test/IR/traits.mlir +++ b/mlir/test/IR/traits.mlir @@ -45,6 +45,20 @@ func @failedSameOperandAndResultElementType(%t10x10 : tensor<10x10xf32>, %t1: te // ----- +func @failedSameOperandAndResultElementType() { + // expected-error@+1 {{expected 1 or more operands}} + %0 = "test.same_operand_and_result_type"() : () -> tensor<1xf32> +} + +// ----- + +func @failedSameOperandAndResultElementType(%t1: tensor<1xf32>) { + // expected-error@+1 {{expected 1 or more results}} + "test.same_operand_and_result_type"(%t1) : (tensor<1xf32>) -> () +} + +// ----- + // CHECK: succeededSameOperandShape func @succeededSameOperandShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %tr: tensor<*xf32>) { %0 = "test.same_operand_shape"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> (tensor<10x10xf32>) @@ -62,6 +76,13 @@ func @failedSameOperandShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>) { // ----- +func @failedSameOperandShape() { + // expected-error@+1 {{expected 1 or more operands}} + %0 = "test.same_operand_shape"() : () -> (tensor<1xf32>) +} + +// ----- + // CHECK: succeededSameOperandAndResultShape func @succeededSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %tr: tensor<*xf32>) { %0 = "test.same_operand_and_result_shape"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> @@ -79,6 +100,20 @@ func @failedSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1 // ----- +func @failedSameOperandAndResultShape() { + // expected-error@+1 {{expected 1 or more operands}} + %0 = "test.same_operand_and_result_shape"() : () -> (tensor<1xf32>) +} + +// ----- + +func @failedSameOperandAndResultShape(%t1: tensor<1xf32>) { + // expected-error@+1 {{expected 1 or more results}} + "test.same_operand_and_result_shape"(%t1) : (tensor<1xf32>) -> () +} + +// ----- + func @hasParent() { "some.op"() ({ // expected-error@+1 {{'test.child' op expects parent op 'test.parent'}} diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index e419b7ef3b12..944ce79a1822 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -219,19 +219,19 @@ def SameOperandElementTypeOp : TEST_Op<"same_operand_type", def SameOperandAndResultElementTypeOp : TEST_Op<"same_operand_and_result_type", [SameOperandsAndResultElementType]> { - let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y); - let results = (outs AnyVectorOrTensor:$res); + let arguments = (ins Variadic:$args); + let results = (outs Variadic:$res); } def SameOperandShapeOp : TEST_Op<"same_operand_shape", [SameOperandsShape]> { - let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y); + let arguments = (ins Variadic:$args); let results = (outs AnyVectorOrTensor:$res); } def SameOperandAndResultShapeOp : TEST_Op<"same_operand_and_result_shape", [SameOperandsAndResultShape]> { - let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y); - let results = (outs AnyVectorOrTensor:$res); + let arguments = (ins Variadic:$args); + let results = (outs Variadic:$res); } def ArgAndResHaveFixedElementTypesOp :