forked from OSchip/llvm-project
Only forbid mixing tensor and vector when considering broadcasting behavior
The previous approach is too restrictive; we end up forbidding all dialect-specific types as element types. Changed to not consider element types entirely. -- PiperOrigin-RevId: 247486537
This commit is contained in:
parent
0e481bae68
commit
b0be00c746
|
@ -21,27 +21,6 @@
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
/// Returns true if the given `type` supports NumPy broadcast semantics.
|
|
||||||
/// Specifically, the given `type` must be integer type, floating point type,
|
|
||||||
/// vector type, or ranked tensor type from integer or floating point types.
|
|
||||||
static bool isBroadcastableType(Type type) {
|
|
||||||
switch (type.getKind()) {
|
|
||||||
case StandardTypes::BF16:
|
|
||||||
case StandardTypes::F16:
|
|
||||||
case StandardTypes::F32:
|
|
||||||
case StandardTypes::F64:
|
|
||||||
case StandardTypes::Integer:
|
|
||||||
case StandardTypes::Vector:
|
|
||||||
return true;
|
|
||||||
case StandardTypes::RankedTensor:
|
|
||||||
case StandardTypes::UnrankedTensor:
|
|
||||||
return type.cast<TensorType>().getElementType().isIntOrFloat();
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
|
bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
|
||||||
ArrayRef<int64_t> shape2,
|
ArrayRef<int64_t> shape2,
|
||||||
SmallVectorImpl<int64_t> &resultShape) {
|
SmallVectorImpl<int64_t> &resultShape) {
|
||||||
|
@ -98,15 +77,19 @@ bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the shape of the given type. Scalars will be considered as having a
|
||||||
|
/// shape with zero dimensions.
|
||||||
|
static ArrayRef<int64_t> getShape(Type type) {
|
||||||
|
if (auto vtType = type.dyn_cast<VectorOrTensorType>())
|
||||||
|
return vtType.getShape();
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns the result broadcast composition type from the two given types by
|
/// Returns the result broadcast composition type from the two given types by
|
||||||
/// following NumPy broadcast semantics. Returned type may have dynamic shape if
|
/// following NumPy broadcast semantics. Returned type may have dynamic shape if
|
||||||
/// either of the input types has dynamic shape. Returns null type if the two
|
/// either of the input types has dynamic shape. Returns null type if the two
|
||||||
/// given types are not broadcast-compatible.
|
/// given types are not broadcast-compatible.
|
||||||
Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
|
Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
|
||||||
// Make sure both types are able to participate in broadcasting.
|
|
||||||
if (!isBroadcastableType(type1) || !isBroadcastableType(type2))
|
|
||||||
return {};
|
|
||||||
|
|
||||||
// Returns the scalar type out of the given type.
|
// Returns the scalar type out of the given type.
|
||||||
auto getScalarType = [](Type type) -> Type {
|
auto getScalarType = [](Type type) -> Type {
|
||||||
if (auto vtType = type.dyn_cast<VectorOrTensorType>())
|
if (auto vtType = type.dyn_cast<VectorOrTensorType>())
|
||||||
|
@ -152,13 +135,6 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
|
||||||
resultCompositeKind = compositeKind2;
|
resultCompositeKind = compositeKind2;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the shape of the given type.
|
|
||||||
auto getShape = [](Type type) -> ArrayRef<int64_t> {
|
|
||||||
if (auto vtType = type.dyn_cast<VectorOrTensorType>())
|
|
||||||
return vtType.getShape();
|
|
||||||
return {};
|
|
||||||
};
|
|
||||||
|
|
||||||
// Get the shape of each type.
|
// Get the shape of each type.
|
||||||
SmallVector<int64_t, 4> resultShape;
|
SmallVector<int64_t, 4> resultShape;
|
||||||
if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
|
if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
|
||||||
|
@ -172,16 +148,10 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
|
||||||
return scalarType;
|
return scalarType;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns true if the two given types are both vectors or ranked tensors and
|
/// Returns true if the given types has both vector types and tensor types.
|
||||||
/// they have the same shape, regardless of element types.
|
static bool hasBothVectorAndTensorType(ArrayRef<Type> types) {
|
||||||
static bool isSameShapedVectorOrTensor(Type type1, Type type2) {
|
return llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }) &&
|
||||||
if (auto vType1 = type1.dyn_cast<RankedTensorType>())
|
llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); });
|
||||||
if (auto vType2 = type2.dyn_cast<RankedTensorType>())
|
|
||||||
return vType1.getShape() == vType2.getShape();
|
|
||||||
if (auto vType1 = type1.dyn_cast<VectorType>())
|
|
||||||
if (auto vType2 = type2.dyn_cast<VectorType>())
|
|
||||||
return vType1.getShape() == vType2.getShape();
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
|
LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
|
||||||
|
@ -194,19 +164,28 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
|
||||||
auto type2 = op->getOperand(1)->getType();
|
auto type2 = op->getOperand(1)->getType();
|
||||||
auto retType = op->getResult(0)->getType();
|
auto retType = op->getResult(0)->getType();
|
||||||
|
|
||||||
auto broadcastedType = util::getBroadcastedType(type1, type2);
|
// We forbid broadcasting vector and tensor.
|
||||||
|
if (hasBothVectorAndTensorType({type1, type2, retType}))
|
||||||
|
return op->emitError("cannot broadcast vector with tensor");
|
||||||
|
|
||||||
if (!broadcastedType)
|
// Broadcasting unranked tensor with ranked/unranked tensor is allowed but
|
||||||
return op->emitOpError("operands don't have broadcast-compatible types");
|
// the result should be unranked tensor.
|
||||||
|
if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
|
||||||
|
if (!retType.isa<UnrankedTensorType>())
|
||||||
|
return op->emitError(
|
||||||
|
"broadcast unranked tensor should result in unranked tensor");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
bool hasCompatRetType = (retType == broadcastedType) ||
|
SmallVector<int64_t, 4> resultShape;
|
||||||
retType.isa<UnrankedTensorType>() ||
|
if (!util::getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
|
||||||
isSameShapedVectorOrTensor(retType, broadcastedType);
|
return op->emitOpError("operands don't have broadcast-compatible shapes");
|
||||||
if (!hasCompatRetType)
|
|
||||||
return op->emitOpError()
|
if (!retType.isa<UnrankedTensorType>() &&
|
||||||
<< "result type '" << retType
|
llvm::makeArrayRef(resultShape) != getShape(retType))
|
||||||
<< "' does not have the same shape as the broadcasted type '"
|
return op->emitOpError() << "result type '" << retType
|
||||||
<< broadcastedType << "' computed from the operand types";
|
<< "' does not have the same shape as the one "
|
||||||
|
"computed from the operand types";
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue