Pull shape broadcast out as a stand-alone utility function
So that we can use this function to deduce broadcasted shapes elsewhere. Also added support for unknown dimensions by following TensorFlow behavior.

PiperOrigin-RevId: 237846065
commit 7972dcef84
parent 0cc212f2b7
@@ -36,6 +36,24 @@ bool verifyCompatibleOperandBroadcast(const Instruction *op);
} // namespace impl

namespace util {
/// Returns the result broadcasted shape from the two given shapes. Returns
/// llvm::None if the given two shapes are not broadcast compatible.
///
/// The rules for determining the result shape are:
///
/// Zip together the dimensions in the two given shapes by prepending 1s to the
/// shape with fewer dimensions. For each dimension pair, deduce the result
/// dimension according to the following order:
/// - If there are unknown dimensions, follow the TensorFlow behavior:
///   - If either dimension is greater than 1, we assume that the program is
///     correct, and the other dimension will be broadcast to match it.
///   - If either dimension is 1, the other dimension is the result.
///   - Otherwise, the result dimension is unknown.
/// - If one of the dimensions is 1, the other dimension is the result.
/// - If the two dimensions are equal, that is the result.
/// - Otherwise, the shapes are incompatible.
Optional<SmallVector<int64_t, 4>> getBroadcastedShape(ArrayRef<int64_t> shape1,
                                                      ArrayRef<int64_t> shape2);
/// Returns the result broadcast composition type from the two given types by
/// following NumPy broadcast semantics. Returned type may have dynamic shape if
/// either of the input types has dynamic shape. Returns null type if the two
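To make the zip-and-deduce rules above concrete, here is a minimal sketch of a call site for the new utility. This is hypothetical illustration code, not part of the commit; it assumes the declaration above is reachable through mlir/Dialect/Traits.h and reuses the 8x1x6x1 / 7x1x5 shape pair exercised by the existing broadcast tests.

#include "mlir/Dialect/Traits.h"
#include "llvm/Support/raw_ostream.h"

// Deduce the broadcasted shape of [8, 1, 6, 1] and [7, 1, 5]. The shorter
// shape is conceptually extended to [1, 7, 1, 5]; each trailing dimension
// pair then resolves by the rules documented above.
void printBroadcastedShape() {
  auto shape =
      mlir::OpTrait::util::getBroadcastedShape({8, 1, 6, 1}, {7, 1, 5});
  if (!shape) {
    llvm::errs() << "shapes are not broadcast compatible\n";
    return;
  }
  for (int64_t dim : *shape) // prints: 8 7 6 5
    llvm::outs() << dim << ' ';
  llvm::outs() << '\n';
}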
@@ -42,6 +42,61 @@ static bool isBroadcastableType(Type type) {
  return false;
}

Optional<SmallVector<int64_t, 4>>
OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
                                   ArrayRef<int64_t> shape2) {
  // To compute the result broadcasted shape, we compare operand shapes
  // element-wise: starting with the trailing dimensions, and working the
  // way backward. Two dimensions are compatible when
  // 1. they are equal, or
  // 2. one of them is 1
  // The result shape has the maximum among the two inputs at every
  // dimension index.

  SmallVector<int64_t, 4> resultShape;
  if (shape1.size() > shape2.size()) {
    std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape));
  } else {
    std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape));
  }

  auto i1 = shape1.rbegin(), e1 = shape1.rend();
  auto i2 = shape2.rbegin(), e2 = shape2.rend();
  auto iR = resultShape.rbegin();

  // Check that each dimension is consistent.
  for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
    if (*i1 == -1 || *i2 == -1) {
      // One or both dimensions are unknown. Follow TensorFlow behavior:
      // - If either dimension is greater than 1, we assume that the program is
      //   correct, and the other dimension will be broadcast to match it.
      // - If either dimension is 1, the other dimension is the result.
      if (*i1 > 1) {
        *iR = *i1;
      } else if (*i2 > 1) {
        *iR = *i2;
      } else if (*i1 == 1) {
        *iR = *i2;
      } else if (*i2 == 1) {
        *iR = *i1;
      } else {
        *iR = -1;
      }
    } else {
      if (*i1 == *i2 || *i2 == 1) {
        *iR = *i1;
      } else if (*i1 == 1) {
        *iR = *i2;
      } else {
        // This dimension of the two operand types is incompatible.
        return llvm::None;
      }
    }
  }

  return resultShape;
}

/// Returns the result broadcast composition type from the two given types by
/// following NumPy broadcast semantics. Returned type may have dynamic shape if
/// either of the input types has dynamic shape. Returns null type if the two
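The unknown-dimension branch is easiest to follow on the interleaved case from the new unit tests. A worked sketch (illustration only, assuming the same header as the sketch above; -1 is the unknown marker the implementation checks for):

#include <cassert>
#include "mlir/Dialect/Traits.h"

// Each trailing-aligned dimension pair and the rule it hits:
//   ( 1, -1) -> -1   neither side is > 1, and 1 defers to the other side
//   ( 2, -1) ->  2   2 > 1, so assume the unknown side broadcasts to match
//   (-1, -1) -> -1   both sides unknown
//   (-1,  4) ->  4   4 > 1
//   (-1,  1) -> -1   1 defers to the other (unknown) side
void unknownDimExample() {
  auto shape = mlir::OpTrait::util::getBroadcastedShape({1, 2, -1, -1, -1},
                                                        {-1, -1, -1, 4, 1});
  assert(shape && "compatible by the rules above");
  // *shape now holds [-1, 2, -1, 4, -1].
}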
@@ -104,45 +159,15 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
   };

-  // Get the shape of each type.
-  auto shape1 = getShape(type1);
-  auto shape2 = getShape(type2);
-
-  // To compute the result broadcasted shape, we compare operand shapes
-  // element-wise: starting with the trailing dimensions, and working the
-  // way backward. Two dimensions are compatible when
-  // 1. they are equal, or
-  // 2. one of them is 1
-  // The result shape has the maximum among the two inputs at every
-  // dimension index.
-
-  SmallVector<int64_t, 4> resultShape;
-  if (shape1.size() > shape2.size()) {
-    std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape));
-  } else {
-    std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape));
-  }
-
-  auto i1 = shape1.rbegin(), e1 = shape1.rend();
-  auto i2 = shape2.rbegin(), e2 = shape2.rend();
-  auto iR = resultShape.rbegin();
-
-  // Check each dimension is consistent.
-  for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
-    if (*i1 == *i2 || *i2 == 1) {
-      *iR = *i1;
-    } else if (*i1 == 1) {
-      *iR = *i2;
-    } else {
-      // This dimension of the two operand types is incompatible.
-      return {};
-    }
-  }
+  auto resultShape = getBroadcastedShape(getShape(type1), getShape(type2));
+  if (!resultShape)
+    return {};

   // Compose the final broadcasted type
   if (resultCompositeKind == StandardTypes::Vector)
-    return VectorType::get(resultShape, scalarType);
+    return VectorType::get(*resultShape, scalarType);
   if (resultCompositeKind == StandardTypes::RankedTensor)
-    return RankedTensorType::get(resultShape, scalarType);
+    return RankedTensorType::get(*resultShape, scalarType);
   return scalarType;
 }
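With the shape computation factored out, callers that deal only in ranked tensors can compose result types directly. Below is a hypothetical helper, not part of this commit, that mirrors the refactored flow above; it assumes both operands already share an element type and that the era-appropriate mlir/IR/StandardTypes.h provides RankedTensorType:

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/StandardTypes.h"

// Compose the broadcasted tensor type of two ranked tensor types.
mlir::Type broadcastRankedTensors(mlir::RankedTensorType lhs,
                                  mlir::RankedTensorType rhs) {
  auto shape =
      mlir::OpTrait::util::getBroadcastedShape(lhs.getShape(), rhs.getShape());
  if (!shape)
    return {}; // Null Type signals broadcast-incompatible shapes.
  return mlir::RankedTensorType::get(*shape, lhs.getElementType());
}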
@@ -136,10 +136,9 @@ func @broadcast_tensor_tensor_tensor(tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) ->

 // -----

-// Check incompatible operand types with unknown dimension
 func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<?xi32>) -> tensor<4x3x2xi32> {
 ^bb0(%arg0: tensor<4x3x2xi32>, %arg1: tensor<?xi32>):
-  // expected-error @+1 {{operands don't have broadcast-compatible types}}
+  // CHECK: %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function: "RELU6"} : (tensor<4x3x2xi32>, tensor<?xi32>) -> tensor<4x3x2xi32>
   %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function: "RELU6"} : (tensor<4x3x2xi32>, tensor<?xi32>) -> tensor<4x3x2xi32>
   return %0 : tensor<4x3x2xi32>
 }
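The switch from expected-error to CHECK follows directly from the new rules: the '?' dimension lowers to the unknown marker, and pairing it with a known dimension greater than 1 now succeeds. The same deduction through the utility, as a sketch (assuming '?' encodes as -1):

#include <cassert>
#include "mlir/Dialect/Traits.h"

// tensor<4x3x2xi32> vs. tensor<?xi32>: the trailing pair (2, -1) resolves to 2
// because 2 > 1; the leading 4 and 3 carry over from the longer shape.
void dynamicDimExample() {
  auto shape = mlir::OpTrait::util::getBroadcastedShape({4, 3, 2}, {-1});
  assert(shape && "now broadcast compatible");
  // *shape holds [4, 3, 2], matching the tfl.add result type in the test.
}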
@@ -0,0 +1,74 @@
//===- BroadcastShapeTest.cpp - broadcasting shape unit tests -------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "mlir/Dialect/Traits.h"
#include "gmock/gmock.h"

using namespace mlir::OpTrait::util;

using ::testing::ElementsAre;

TEST(BroadcastShapeTest, CompatibleScalarAndScalar) {
  auto result = getBroadcastedShape({}, {});
  ASSERT_TRUE(result.hasValue());
  EXPECT_TRUE(result->empty());
}

TEST(BroadcastShapeTest, Compatible0DAnd1DTensor) {
  auto result = getBroadcastedShape({}, {4});
  ASSERT_TRUE(result.hasValue());
  EXPECT_THAT(result.getValue(), ElementsAre(4));
}

TEST(BroadcastShapeTest, Compatible0DAnd3DTensor) {
  auto result = getBroadcastedShape({}, {3, 5, 4});
  ASSERT_TRUE(result.hasValue());
  EXPECT_THAT(result.getValue(), ElementsAre(3, 5, 4));
}

TEST(BroadcastShapeTest, CompatibleTensorAndTensor) {
  auto result = getBroadcastedShape({1, 7, 8, 9}, {8, 9});
  ASSERT_TRUE(result.hasValue());
  EXPECT_THAT(result.getValue(), ElementsAre(1, 7, 8, 9));
}

TEST(BroadcastShapeTest, InterleavingOnes) {
  auto result = getBroadcastedShape({8, 1, 2, 1, 4}, {5, 1, 7, 1});
  ASSERT_TRUE(result.hasValue());
  EXPECT_THAT(result.getValue(), ElementsAre(8, 5, 2, 7, 4));
}

TEST(BroadcastShapeTest, InterleavingUnknowns) {
  auto result = getBroadcastedShape({1, 2, -1, -1, -1}, {-1, -1, -1, 4, 1});
  ASSERT_TRUE(result.hasValue());
  EXPECT_THAT(result.getValue(), ElementsAre(-1, 2, -1, 4, -1));
}

TEST(BroadcastShapeTest, IncompatibleLowDim) {
  auto result = getBroadcastedShape({4, 3, 5, 5}, {3, 5, 4});
  EXPECT_FALSE(result.hasValue());
}

TEST(BroadcastShapeTest, IncompatibleMiddleDim) {
  auto result = getBroadcastedShape({4, 3, 5, 5}, {3, 7, 5});
  EXPECT_FALSE(result.hasValue());
}

TEST(BroadcastShapeTest, IncompatibleHighDim) {
  auto result = getBroadcastedShape({3, 5, 5}, {4, 5, 5});
  EXPECT_FALSE(result.hasValue());
}