Add utility to accept any tensor type.

--

PiperOrigin-RevId: 247264423
This commit is contained in:
MLIR Team 2019-05-08 12:18:19 -07:00 committed by Mehdi Amini
parent a1b24a0e08
commit b4684e229b
2 changed files with 12 additions and 0 deletions

View File

@ -375,6 +375,11 @@ def F16Tensor : TypedTensor<F16>;
def F32Tensor : TypedTensor<F32>;
def F64Tensor : TypedTensor<F64>;
// Any tensor type whose element type is from the given
// `allowedTypes` list
class AnyTensorOf<list<Type> allowedTypes, string elementDescription = ""> :
TypedTensor<AnyTypeOf<allowedTypes, elementDescription>>;
// This represents a generic tuple without any constraints on elemental type,
// ranks, or size. As Tuples can contain tensors, vectors, or scalar values
// there is not only a single elemental type.

View File

@ -110,3 +110,10 @@ def OpJ: NS_Op<"op_for_TCopVTEtAreSameAt", [
// CHECK-SAME: [this](unsigned i) { return this->getOperand(i)->getType().cast<VectorOrTensorType>().getElementType(); },
// CHECK-SAME: llvm::ArrayRef<unsigned>({0, 2, 3})))
// CHECK: return emitOpError("failed to verify that operands indexed at 0, 2, 3 should all have the same type");
def OpK : NS_Op<"op_for_AnyTensorOf", []> {
let arguments = (ins AnyTensorOf<[F32, I32]>:$x);
}
// CHECK-LABEL: OpK::verify
// CHECK: if (!(((this->getOperation()->getOperand(0)->getType().isa<TensorType>())) && (((this->getOperation()->getOperand(0)->getType().cast<TensorType>().getElementType().isF32())) || ((this->getOperation()->getOperand(0)->getType().cast<TensorType>().getElementType().isInteger(32))))))