forked from OSchip/llvm-project
Only forbid mixing tensor and vector when considering broadcasting behavior
The previous approach is too restrictive; we end up forbidding all dialect-specific types as element types. Changed to not consider element types entirely. -- PiperOrigin-RevId: 247486537
This commit is contained in:
parent
0e481bae68
commit
b0be00c746
|
@ -21,27 +21,6 @@
|
|||
|
||||
using namespace mlir;
|
||||
|
||||
/// Returns true if the given `type` supports NumPy broadcast semantics.
|
||||
/// Specifically, the given `type` must be integer type, floating point type,
|
||||
/// vector type, or ranked tensor type from integer or floating point types.
|
||||
static bool isBroadcastableType(Type type) {
|
||||
switch (type.getKind()) {
|
||||
case StandardTypes::BF16:
|
||||
case StandardTypes::F16:
|
||||
case StandardTypes::F32:
|
||||
case StandardTypes::F64:
|
||||
case StandardTypes::Integer:
|
||||
case StandardTypes::Vector:
|
||||
return true;
|
||||
case StandardTypes::RankedTensor:
|
||||
case StandardTypes::UnrankedTensor:
|
||||
return type.cast<TensorType>().getElementType().isIntOrFloat();
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
|
||||
ArrayRef<int64_t> shape2,
|
||||
SmallVectorImpl<int64_t> &resultShape) {
|
||||
|
@ -98,15 +77,19 @@ bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
|
|||
return true;
|
||||
}
|
||||
|
||||
/// Returns the shape of the given type. Scalars will be considered as having a
|
||||
/// shape with zero dimensions.
|
||||
static ArrayRef<int64_t> getShape(Type type) {
|
||||
if (auto vtType = type.dyn_cast<VectorOrTensorType>())
|
||||
return vtType.getShape();
|
||||
return {};
|
||||
}
|
||||
|
||||
/// Returns the result broadcast composition type from the two given types by
|
||||
/// following NumPy broadcast semantics. Returned type may have dynamic shape if
|
||||
/// either of the input types has dynamic shape. Returns null type if the two
|
||||
/// given types are not broadcast-compatible.
|
||||
Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
|
||||
// Make sure both types are able to participate in broadcasting.
|
||||
if (!isBroadcastableType(type1) || !isBroadcastableType(type2))
|
||||
return {};
|
||||
|
||||
// Returns the scalar type out of the given type.
|
||||
auto getScalarType = [](Type type) -> Type {
|
||||
if (auto vtType = type.dyn_cast<VectorOrTensorType>())
|
||||
|
@ -152,13 +135,6 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
|
|||
resultCompositeKind = compositeKind2;
|
||||
}
|
||||
|
||||
// Returns the shape of the given type.
|
||||
auto getShape = [](Type type) -> ArrayRef<int64_t> {
|
||||
if (auto vtType = type.dyn_cast<VectorOrTensorType>())
|
||||
return vtType.getShape();
|
||||
return {};
|
||||
};
|
||||
|
||||
// Get the shape of each type.
|
||||
SmallVector<int64_t, 4> resultShape;
|
||||
if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
|
||||
|
@ -172,16 +148,10 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
|
|||
return scalarType;
|
||||
}
|
||||
|
||||
/// Returns true if the two given types are both vectors or ranked tensors and
|
||||
/// they have the same shape, regardless of element types.
|
||||
static bool isSameShapedVectorOrTensor(Type type1, Type type2) {
|
||||
if (auto vType1 = type1.dyn_cast<RankedTensorType>())
|
||||
if (auto vType2 = type2.dyn_cast<RankedTensorType>())
|
||||
return vType1.getShape() == vType2.getShape();
|
||||
if (auto vType1 = type1.dyn_cast<VectorType>())
|
||||
if (auto vType2 = type2.dyn_cast<VectorType>())
|
||||
return vType1.getShape() == vType2.getShape();
|
||||
return false;
|
||||
/// Returns true if the given types has both vector types and tensor types.
|
||||
static bool hasBothVectorAndTensorType(ArrayRef<Type> types) {
|
||||
return llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }) &&
|
||||
llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); });
|
||||
}
|
||||
|
||||
LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
|
||||
|
@ -194,19 +164,28 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
|
|||
auto type2 = op->getOperand(1)->getType();
|
||||
auto retType = op->getResult(0)->getType();
|
||||
|
||||
auto broadcastedType = util::getBroadcastedType(type1, type2);
|
||||
// We forbid broadcasting vector and tensor.
|
||||
if (hasBothVectorAndTensorType({type1, type2, retType}))
|
||||
return op->emitError("cannot broadcast vector with tensor");
|
||||
|
||||
if (!broadcastedType)
|
||||
return op->emitOpError("operands don't have broadcast-compatible types");
|
||||
// Broadcasting unranked tensor with ranked/unranked tensor is allowed but
|
||||
// the result should be unranked tensor.
|
||||
if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
|
||||
if (!retType.isa<UnrankedTensorType>())
|
||||
return op->emitError(
|
||||
"broadcast unranked tensor should result in unranked tensor");
|
||||
return success();
|
||||
}
|
||||
|
||||
bool hasCompatRetType = (retType == broadcastedType) ||
|
||||
retType.isa<UnrankedTensorType>() ||
|
||||
isSameShapedVectorOrTensor(retType, broadcastedType);
|
||||
if (!hasCompatRetType)
|
||||
return op->emitOpError()
|
||||
<< "result type '" << retType
|
||||
<< "' does not have the same shape as the broadcasted type '"
|
||||
<< broadcastedType << "' computed from the operand types";
|
||||
SmallVector<int64_t, 4> resultShape;
|
||||
if (!util::getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
|
||||
return op->emitOpError("operands don't have broadcast-compatible shapes");
|
||||
|
||||
if (!retType.isa<UnrankedTensorType>() &&
|
||||
llvm::makeArrayRef(resultShape) != getShape(retType))
|
||||
return op->emitOpError() << "result type '" << retType
|
||||
<< "' does not have the same shape as the one "
|
||||
"computed from the operand types";
|
||||
|
||||
return success();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue