From 8503ffbe3af8a7b9c30f668353054a3721263374 Mon Sep 17 00:00:00 2001
From: Christian Sigg <csigg@google.com>
Date: Tue, 1 Oct 2019 00:56:38 -0700
Subject: [PATCH] Add verification error message for ops that require at least
 one operand or result.

PiperOrigin-RevId: 272153634
---
 mlir/lib/IR/Operation.cpp            | 13 +++++++----
 mlir/test/IR/traits.mlir             | 35 ++++++++++++++++++++++++++++
 mlir/test/lib/TestDialect/TestOps.td | 10 ++++----
 3 files changed, 48 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 25302e5ff06e..27681d37f177 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -767,7 +767,7 @@ static LogicalResult verifyShapeMatch(Type type1, Type type2) {
 }
 
 LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) {
-  if (op->getNumOperands() == 0)
+  if (failed(verifyAtLeastNOperands(op, 1)))
     return failure();
 
   auto type = op->getOperand(0)->getType();
@@ -779,7 +779,8 @@ LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) {
 }
 
 LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {
-  if (op->getNumOperands() == 0 || op->getNumResults() == 0)
+  if (failed(verifyAtLeastNOperands(op, 1)) ||
+      failed(verifyAtLeastNResults(op, 1)))
     return failure();
 
   auto type = op->getOperand(0)->getType();
@@ -797,7 +798,7 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {
 }
 
 LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) {
-  if (op->getNumOperands() == 0)
+  if (failed(verifyAtLeastNOperands(op, 1)))
     return failure();
 
   auto type = op->getOperand(0)->getType().dyn_cast<ShapedType>();
@@ -818,7 +819,8 @@ LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) {
 
 LogicalResult
 OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) {
-  if (op->getNumOperands() == 0 || op->getNumResults() == 0)
+  if (failed(verifyAtLeastNOperands(op, 1)) ||
+      failed(verifyAtLeastNResults(op, 1)))
     return failure();
 
   auto type = op->getResult(0)->getType().dyn_cast<ShapedType>();
@@ -850,7 +852,8 @@ OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) {
 }
 
 LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
-  if (op->getNumOperands() == 0 || op->getNumResults() == 0)
+  if (failed(verifyAtLeastNOperands(op, 1)) ||
+      failed(verifyAtLeastNResults(op, 1)))
     return failure();
 
   auto type = op->getResult(0)->getType();
diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
index 40a4e963aa6f..dc8f6af57d72 100644
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -45,6 +45,20 @@ func @failedSameOperandAndResultElementType(%t10x10 : tensor<10x10xf32>, %t1: te
 
 // -----
 
+func @failedSameOperandAndResultElementType() {
+  // expected-error@+1 {{expected 1 or more operands}}
+  %0 = "test.same_operand_and_result_type"() : () -> tensor<1xf32>
+}
+
+// -----
+
+func @failedSameOperandAndResultElementType(%t1: tensor<1xf32>) {
+  // expected-error@+1 {{expected 1 or more results}}
+  "test.same_operand_and_result_type"(%t1) : (tensor<1xf32>) -> ()
+}
+
+// -----
+
 // CHECK: succeededSameOperandShape
 func @succeededSameOperandShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %tr: tensor<*xf32>) {
   %0 = "test.same_operand_shape"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> (tensor<10x10xf32>)
@@ -62,6 +76,13 @@ func @failedSameOperandShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>) {
 
 // -----
 
+func @failedSameOperandShape() {
+  // expected-error@+1 {{expected 1 or more operands}}
+  %0 = "test.same_operand_shape"() : () -> (tensor<1xf32>)
+}
+
+// -----
+
 // CHECK: succeededSameOperandAndResultShape
 func @succeededSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %tr: tensor<*xf32>) {
   %0 = "test.same_operand_and_result_shape"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
@@ -79,6 +100,20 @@ func @failedSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1
 
 // -----
 
+func @failedSameOperandAndResultShape() {
+  // expected-error@+1 {{expected 1 or more operands}}
+  %0 = "test.same_operand_and_result_shape"() : () -> (tensor<1xf32>)
+}
+
+// -----
+
+func @failedSameOperandAndResultShape(%t1: tensor<1xf32>) {
+  // expected-error@+1 {{expected 1 or more results}}
+  "test.same_operand_and_result_shape"(%t1) : (tensor<1xf32>) -> ()
+}
+
+// -----
+
 func @hasParent() {
   "some.op"() ({
    // expected-error@+1 {{'test.child' op expects parent op 'test.parent'}}
diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td
index e419b7ef3b12..944ce79a1822 100644
--- a/mlir/test/lib/TestDialect/TestOps.td
+++ b/mlir/test/lib/TestDialect/TestOps.td
@@ -219,19 +219,19 @@ def SameOperandElementTypeOp : TEST_Op<"same_operand_type",
 
 def SameOperandAndResultElementTypeOp : TEST_Op<"same_operand_and_result_type",
     [SameOperandsAndResultElementType]> {
-  let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y);
-  let results = (outs AnyVectorOrTensor:$res);
+  let arguments = (ins Variadic<AnyVectorOrTensor>:$args);
+  let results = (outs Variadic<AnyVectorOrTensor>:$res);
 }
 
 def SameOperandShapeOp : TEST_Op<"same_operand_shape", [SameOperandsShape]> {
-  let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y);
+  let arguments = (ins Variadic<AnyVectorOrTensor>:$args);
   let results = (outs AnyVectorOrTensor:$res);
 }
 
 def SameOperandAndResultShapeOp : TEST_Op<"same_operand_and_result_shape",
     [SameOperandsAndResultShape]> {
-  let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y);
-  let results = (outs AnyVectorOrTensor:$res);
+  let arguments = (ins Variadic<AnyVectorOrTensor>:$args);
+  let results = (outs Variadic<AnyVectorOrTensor>:$res);
 }
 
 def ArgAndResHaveFixedElementTypesOp :