2019-01-17 00:43:45 +08:00
|
|
|
//===- Traits.cpp - Common op traits shared by dialects -------------------===//
|
|
|
|
//
|
|
|
|
// Copyright 2019 The MLIR Authors.
|
|
|
|
//
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
//
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
//
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
|
|
|
// =============================================================================
|
|
|
|
|
|
|
|
#include "mlir/Dialect/Traits.h"
|
|
|
|
#include "mlir/IR/StandardTypes.h"
|
2019-02-11 06:14:08 +08:00
|
|
|
#include "llvm/Support/FormatVariadic.h"
|
2019-01-17 00:43:45 +08:00
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
|
2019-03-21 00:01:58 +08:00
|
|
|
bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
                                        ArrayRef<int64_t> shape2,
                                        SmallVectorImpl<int64_t> &resultShape) {
  // The two shapes are aligned at their trailing dimensions and compared
  // pairwise, moving toward the leading dimensions. A pair of static sizes is
  // compatible when the sizes are equal or when one of them is 1; the result
  // then takes the common size, or the size that is not 1. Pairs involving a
  // dynamic size (-1) are resolved as described inside the loop. Leading
  // dimensions that exist only in the higher-rank shape are carried over to
  // the result unchanged.
  //
  // Returns true with the broadcasted shape in `resultShape` on success.
  // Returns false with `resultShape` left empty when some pair of static
  // sizes is incompatible.
  resultShape.clear();

  // Seed the result with the higher-rank shape (shape2 when the ranks are
  // equal). The overlapping trailing dimensions are overwritten below.
  ArrayRef<int64_t> seed = shape1.size() > shape2.size() ? shape1 : shape2;
  resultShape.append(seed.begin(), seed.end());

  const size_t rank1 = shape1.size();
  const size_t rank2 = shape2.size();
  const size_t rankR = resultShape.size();
  const size_t overlap = rank1 < rank2 ? rank1 : rank2;

  // k counts dimensions from the back: k == 1 is the innermost dimension.
  for (size_t k = 1; k <= overlap; ++k) {
    const int64_t dim1 = shape1[rank1 - k];
    const int64_t dim2 = shape2[rank2 - k];
    int64_t &dimR = resultShape[rankR - k];

    if (dim1 != -1 && dim2 != -1) {
      // Both sizes are static.
      if (dim1 == dim2 || dim2 == 1) {
        dimR = dim1;
      } else if (dim1 == 1) {
        dimR = dim2;
      } else {
        // The sizes differ and neither is 1: this dimension cannot be
        // broadcast.
        resultShape.clear();
        return false;
      }
      continue;
    }

    // At least one size is dynamic. Follow TensorFlow behavior:
    // - If either size is greater than 1, assume the program is correct and
    //   that the other (dynamic) size will be broadcast to match it.
    // - If either size is 1, the other size is the output.
    // - Otherwise the result size is dynamic.
    if (dim1 > 1)
      dimR = dim1;
    else if (dim2 > 1)
      dimR = dim2;
    else if (dim1 == 1)
      dimR = dim2;
    else if (dim2 == 1)
      dimR = dim1;
    else
      dimR = -1;
  }

  return true;
}
|
|
|
|
|
2019-05-10 04:35:43 +08:00
|
|
|
/// Returns the shape of the given type. Scalars will be considered as having a
|
|
|
|
/// shape with zero dimensions.
|
|
|
|
static ArrayRef<int64_t> getShape(Type type) {
|
|
|
|
if (auto vtType = type.dyn_cast<VectorOrTensorType>())
|
|
|
|
return vtType.getShape();
|
|
|
|
return {};
|
|
|
|
}
|
|
|
|
|
2019-01-17 00:43:45 +08:00
|
|
|
/// Returns the result broadcast composition type from the two given types by
|
2019-02-08 04:56:12 +08:00
|
|
|
/// following NumPy broadcast semantics. Returned type may have dynamic shape if
|
|
|
|
/// either of the input types has dynamic shape. Returns null type if the two
|
|
|
|
/// given types are not broadcast-compatible.
|
2019-01-17 00:43:45 +08:00
|
|
|
Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
|
|
|
|
// Returns the scalar type out of the given type.
|
|
|
|
auto getScalarType = [](Type type) -> Type {
|
|
|
|
if (auto vtType = type.dyn_cast<VectorOrTensorType>())
|
|
|
|
return vtType.getElementType();
|
|
|
|
return type;
|
|
|
|
};
|
|
|
|
|
|
|
|
// Make sure underlying scalar type is the same.
|
|
|
|
auto scalarType = getScalarType(type1);
|
|
|
|
if (scalarType != getScalarType(type2))
|
|
|
|
return {};
|
|
|
|
|
2019-02-08 04:56:12 +08:00
|
|
|
// If one of the types is unranked tensor, then the other type shouldn't be
|
|
|
|
// vector and the result should have unranked tensor type.
|
|
|
|
if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
|
|
|
|
if (type1.isa<VectorType>() || type2.isa<VectorType>())
|
|
|
|
return {};
|
|
|
|
return UnrankedTensorType::get(scalarType);
|
|
|
|
}
|
|
|
|
|
2019-01-17 00:43:45 +08:00
|
|
|
// Returns the type kind if the given type is a vector or ranked tensor type.
|
|
|
|
// Returns llvm::None otherwise.
|
|
|
|
auto getCompositeTypeKind =
|
|
|
|
[](Type type) -> llvm::Optional<StandardTypes::Kind> {
|
|
|
|
if (type.isa<VectorType>() || type.isa<RankedTensorType>())
|
|
|
|
return static_cast<StandardTypes::Kind>(type.getKind());
|
|
|
|
return llvm::None;
|
|
|
|
};
|
|
|
|
|
|
|
|
// Make sure the composite type, if has, is consistent.
|
|
|
|
auto compositeKind1 = getCompositeTypeKind(type1);
|
|
|
|
auto compositeKind2 = getCompositeTypeKind(type2);
|
|
|
|
llvm::Optional<StandardTypes::Kind> resultCompositeKind;
|
|
|
|
|
|
|
|
if (compositeKind1 && compositeKind2) {
|
|
|
|
// Disallow mixing vector and tensor.
|
|
|
|
if (compositeKind1 != compositeKind2)
|
|
|
|
return {};
|
|
|
|
resultCompositeKind = compositeKind1;
|
|
|
|
} else if (compositeKind1) {
|
|
|
|
resultCompositeKind = compositeKind1;
|
|
|
|
} else if (compositeKind2) {
|
|
|
|
resultCompositeKind = compositeKind2;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Get the shape of each type.
|
2019-03-21 00:01:58 +08:00
|
|
|
SmallVector<int64_t, 4> resultShape;
|
|
|
|
if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
|
2019-03-12 02:36:04 +08:00
|
|
|
return {};
|
2019-01-17 00:43:45 +08:00
|
|
|
|
|
|
|
// Compose the final broadcasted type
|
|
|
|
if (resultCompositeKind == StandardTypes::Vector)
|
2019-03-21 00:01:58 +08:00
|
|
|
return VectorType::get(resultShape, scalarType);
|
2019-01-17 00:43:45 +08:00
|
|
|
if (resultCompositeKind == StandardTypes::RankedTensor)
|
2019-03-21 00:01:58 +08:00
|
|
|
return RankedTensorType::get(resultShape, scalarType);
|
2019-01-17 00:43:45 +08:00
|
|
|
return scalarType;
|
|
|
|
}
|
|
|
|
|
2019-05-10 04:35:43 +08:00
|
|
|
/// Returns true if the given types has both vector types and tensor types.
|
|
|
|
static bool hasBothVectorAndTensorType(ArrayRef<Type> types) {
|
|
|
|
return llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }) &&
|
|
|
|
llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); });
|
2019-03-12 02:36:20 +08:00
|
|
|
}
|
|
|
|
|
2019-04-03 04:09:34 +08:00
|
|
|
/// Verifies that `op` (exactly two operands, exactly one result) obeys the
/// broadcasting rules:
///   - vector and tensor types must not appear together among the operand
///     and result types;
///   - if either operand is an unranked tensor, the result must also be an
///     unranked tensor, and no further shape checks are made;
///   - otherwise the operand shapes must be broadcast-compatible and, unless
///     the result is an unranked tensor, the result shape must equal the
///     broadcasted shape exactly.
/// Element types are not compared here.
LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
  assert(op->getNumOperands() == 2 &&
         "only support broadcast check on two operands");
  assert(op->getNumResults() == 1 &&
         "only support broadcast check on one result");

  Type lhsType = op->getOperand(0)->getType();
  Type rhsType = op->getOperand(1)->getType();
  Type resultType = op->getResult(0)->getType();

  // Vectors and tensors are never broadcast against each other.
  if (hasBothVectorAndTensorType({lhsType, rhsType, resultType}))
    return op->emitError("cannot broadcast vector with tensor");

  // An unranked tensor operand may broadcast with a ranked or unranked
  // tensor, but the result must then be unranked too.
  const bool anyUnrankedOperand =
      lhsType.isa<UnrankedTensorType>() || rhsType.isa<UnrankedTensorType>();
  if (anyUnrankedOperand) {
    if (resultType.isa<UnrankedTensorType>())
      return success();
    return op->emitError(
        "broadcast unranked tensor should result in unranked tensor");
  }

  SmallVector<int64_t, 4> inferredShape;
  if (!util::getBroadcastedShape(getShape(lhsType), getShape(rhsType),
                                 inferredShape))
    return op->emitOpError("operands don't have broadcast-compatible shapes");

  // An unranked result accepts any inferred shape; a ranked one must match it.
  if (resultType.isa<UnrankedTensorType>())
    return success();
  if (llvm::makeArrayRef(inferredShape) == getShape(resultType))
    return success();

  return op->emitOpError() << "result type '" << resultType
                           << "' does not have the same shape as the one "
                              "computed from the operand types";
}
|