Only forbid mixing tensor and vector when considering broadcasting behavior

The previous approach is too restrictive; we end up forbidding all dialect-specific
types as element types. Changed to not consider element types at all.

--

PiperOrigin-RevId: 247486537
Author: Lei Zhang (2019-05-09 13:35:43 -07:00); committed by Mehdi Amini
parent 0e481bae68
commit b0be00c746
1 changed file with 32 additions and 53 deletions


@@ -21,27 +21,6 @@
 using namespace mlir;
 
-/// Returns true if the given `type` supports NumPy broadcast semantics.
-/// Specifically, the given `type` must be integer type, floating point type,
-/// vector type, or ranked tensor type from integer or floating point types.
-static bool isBroadcastableType(Type type) {
-  switch (type.getKind()) {
-  case StandardTypes::BF16:
-  case StandardTypes::F16:
-  case StandardTypes::F32:
-  case StandardTypes::F64:
-  case StandardTypes::Integer:
-  case StandardTypes::Vector:
-    return true;
-  case StandardTypes::RankedTensor:
-  case StandardTypes::UnrankedTensor:
-    return type.cast<TensorType>().getElementType().isIntOrFloat();
-  default:
-    break;
-  }
-  return false;
-}
-
 bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
                                         ArrayRef<int64_t> shape2,
                                         SmallVectorImpl<int64_t> &resultShape) {
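
The predicate removed above is also why broadcasting only ever needs dimension
lists: the shape computation is independent of element types. For reference,
here is a minimal standalone sketch of the NumPy trailing-dimension rule that
getBroadcastedShape implements (plain C++, not the MLIR implementation; it
also ignores the dynamic dimensions the real utility handles):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Walk both shapes from the trailing dimension; a dimension of 1
    // stretches to match the other side, anything else must agree exactly.
    static bool broadcastShapes(const std::vector<int64_t> &a,
                                const std::vector<int64_t> &b,
                                std::vector<int64_t> &out) {
      size_t rank = std::max(a.size(), b.size());
      out.assign(rank, 1);
      for (size_t i = 0; i < rank; ++i) {
        int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
        int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
        if (da != db && da != 1 && db != 1)
          return false; // e.g. {2,3} vs {4,3}: 2 and 4 are incompatible
        out[rank - 1 - i] = std::max(da, db);
      }
      return true;
    }

    // broadcastShapes({1, 3}, {4, 1}, r) -> true,  r == {4, 3}
    // broadcastShapes({2, 3}, {4, 3}, r) -> false
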
@@ -98,15 +77,19 @@ bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
   return true;
 }
 
+/// Returns the shape of the given type. Scalars will be considered as having a
+/// shape with zero dimensions.
+static ArrayRef<int64_t> getShape(Type type) {
+  if (auto vtType = type.dyn_cast<VectorOrTensorType>())
+    return vtType.getShape();
+  return {};
+}
+
 /// Returns the result broadcast composition type from the two given types by
 /// following NumPy broadcast semantics. Returned type may have dynamic shape if
 /// either of the input types has dynamic shape. Returns null type if the two
 /// given types are not broadcast-compatible.
 Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
-  // Make sure both types are able to participate in broadcasting.
-  if (!isBroadcastableType(type1) || !isBroadcastableType(type2))
-    return {};
-
   // Returns the scalar type out of the given type.
   auto getScalarType = [](Type type) -> Type {
     if (auto vtType = type.dyn_cast<VectorOrTensorType>())
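
The getShape helper added above is what lets scalars participate uniformly: a
scalar contributes a rank-0 shape, which broadcasts against anything. A hedged
illustration of the values involved, assuming the Builder type helpers of this
vintage (the setup code is hypothetical, not part of the diff):

    MLIRContext context;
    Builder builder(&context);
    Type scalar = builder.getF32Type();                   // f32
    Type vec = builder.getVectorType({4}, scalar);        // vector<4xf32>
    Type tensor = builder.getTensorType({-1, 3}, scalar); // tensor<?x3xf32>
    // getShape(scalar) == {}       (zero dimensions)
    // getShape(vec)    == {4}
    // getShape(tensor) == {-1, 3}  (dynamic dims encoded as -1)
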
@@ -152,13 +135,6 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
     resultCompositeKind = compositeKind2;
   }
 
-  // Returns the shape of the given type.
-  auto getShape = [](Type type) -> ArrayRef<int64_t> {
-    if (auto vtType = type.dyn_cast<VectorOrTensorType>())
-      return vtType.getShape();
-    return {};
-  };
-
   // Get the shape of each type.
   SmallVector<int64_t, 4> resultShape;
   if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
@@ -172,16 +148,10 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
   return scalarType;
 }
 
-/// Returns true if the two given types are both vectors or ranked tensors and
-/// they have the same shape, regardless of element types.
-static bool isSameShapedVectorOrTensor(Type type1, Type type2) {
-  if (auto vType1 = type1.dyn_cast<RankedTensorType>())
-    if (auto vType2 = type2.dyn_cast<RankedTensorType>())
-      return vType1.getShape() == vType2.getShape();
-  if (auto vType1 = type1.dyn_cast<VectorType>())
-    if (auto vType2 = type2.dyn_cast<VectorType>())
-      return vType1.getShape() == vType2.getShape();
-  return false;
+/// Returns true if the given types have both vector types and tensor types.
+static bool hasBothVectorAndTensorType(ArrayRef<Type> types) {
+  return llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }) &&
+         llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); });
 }
 
 LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
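
The replacement helper asks a deliberately weaker question than the one it
removes: not whether two composite types share a shape, but only whether
vector and tensor types are being mixed at all across operands and result. A
hypothetical call, reusing the builder setup sketched earlier:

    Type vec4 = builder.getVectorType({4}, builder.getF32Type());
    Type ten4 = builder.getTensorType({4}, builder.getF32Type());
    bool mixed = hasBothVectorAndTensorType({vec4, ten4, ten4}); // true
    bool pure = hasBothVectorAndTensorType({ten4, ten4, ten4});  // false
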
@@ -194,19 +164,28 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
   auto type2 = op->getOperand(1)->getType();
   auto retType = op->getResult(0)->getType();
 
-  auto broadcastedType = util::getBroadcastedType(type1, type2);
+  // We forbid broadcasting vector and tensor.
+  if (hasBothVectorAndTensorType({type1, type2, retType}))
+    return op->emitError("cannot broadcast vector with tensor");
 
-  if (!broadcastedType)
-    return op->emitOpError("operands don't have broadcast-compatible types");
+  // Broadcasting unranked tensor with ranked/unranked tensor is allowed but
+  // the result should be unranked tensor.
+  if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
+    if (!retType.isa<UnrankedTensorType>())
+      return op->emitError(
+          "broadcast unranked tensor should result in unranked tensor");
+    return success();
+  }
 
-  bool hasCompatRetType = (retType == broadcastedType) ||
-                          retType.isa<UnrankedTensorType>() ||
-                          isSameShapedVectorOrTensor(retType, broadcastedType);
-  if (!hasCompatRetType)
-    return op->emitOpError()
-           << "result type '" << retType
-           << "' does not have the same shape as the broadcasted type '"
-           << broadcastedType << "' computed from the operand types";
+  SmallVector<int64_t, 4> resultShape;
+  if (!util::getBroadcastedShape(getShape(type1), getShape(type2),
+                                 resultShape))
+    return op->emitOpError("operands don't have broadcast-compatible shapes");
+
+  if (!retType.isa<UnrankedTensorType>() &&
+      llvm::makeArrayRef(resultShape) != getShape(retType))
+    return op->emitOpError() << "result type '" << retType
+                             << "' does not have the same shape as the one "
+                                "computed from the operand types";
 
   return success();
 }
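
Taken together, the rewritten verifier proceeds in three steps: reject any
vector/tensor mix outright, give unranked operands a fast path that only
demands an unranked result, and otherwise compare the broadcasted operand
shape against the result shape. A walk-through on made-up ops ("x.add" is
illustrative, not a real op; the quoted diagnostics are the ones emitted
above):

    %r = "x.add"(%a, %b) : (vector<4xf32>, tensor<4xf32>) -> tensor<4xf32>
      // error: cannot broadcast vector with tensor
    %r = "x.add"(%a, %b) : (tensor<*xf32>, tensor<4xf32>) -> tensor<4xf32>
      // error: broadcast unranked tensor should result in unranked tensor
    %r = "x.add"(%a, %b) : (tensor<2x3xf32>, tensor<4x3xf32>) -> tensor<4x3xf32>
      // error: operands don't have broadcast-compatible shapes
    %r = "x.add"(%a, %b) : (tensor<1x3xf32>, tensor<4x1xf32>) -> tensor<4x3xf32>
      // ok: broadcasted shape {4, 3} matches the result shape
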