Pull shape broadcast out as a stand-alone utility function

So that we can use this function to deduce broadcasted shapes elsewhere.

Also added support for unknown dimensions by following TensorFlow's behavior.

PiperOrigin-RevId: 237846065
commit 7972dcef84 (parent 0cc212f2b7)
Lei Zhang, 2019-03-11 11:36:04 -07:00, committed by jpienaar
4 changed files with 153 additions and 37 deletions


@@ -36,6 +36,24 @@ bool verifyCompatibleOperandBroadcast(const Instruction *op);
 } // namespace impl
 
 namespace util {
 
+/// Returns the broadcasted shape resulting from the two given shapes, or
+/// llvm::None if the two shapes are not broadcast-compatible.
+///
+/// The rules for determining the result shape are:
+///
+/// Zip together the dimensions in the two given shapes, prepending 1s to the
+/// shape with fewer dimensions. For each dimension pair, deduce the result
+/// dimension in the following order:
+/// - If either dimension is unknown, follow the TensorFlow behavior:
+///   - If either dimension is greater than 1, assume that the program is
+///     correct and that the other dimension will be broadcast to match it.
+///   - If either dimension is 1, the other dimension is the result.
+///   - Otherwise, the result dimension is unknown.
+/// - If one of the dimensions is 1, the other dimension is the result.
+/// - If the two dimensions are equal, that is the result.
+/// - Otherwise, the shapes are incompatible.
+Optional<SmallVector<int64_t, 4>> getBroadcastedShape(ArrayRef<int64_t> shape1,
+                                                      ArrayRef<int64_t> shape2);
+
 /// Returns the result broadcast composition type from the two given types by
 /// following NumPy broadcast semantics. Returned type may have dynamic shape if
 /// either of the input types has dynamic shape. Returns null type if the two
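A quick editorial illustration of these rules (this sketch is not part of the commit; the shapes are arbitrary):

  // {8, 1, 6, 1} vs {7, 1, 5}: the shorter shape is padded to {1, 7, 1, 5};
  // pairing trailing dimensions gives (1,5)->5, (6,1)->6, (1,7)->7, (8,1)->8.
  auto s1 = getBroadcastedShape({8, 1, 6, 1}, {7, 1, 5}); // -> {8, 7, 6, 5}

  // {-1, 3} vs {4, 1, 3}: the unknown dimension meets a 1 and so survives as
  // unknown; the padded leading 1 broadcasts to 4.
  auto s2 = getBroadcastedShape({-1, 3}, {4, 1, 3});      // -> {4, -1, 3}

  // {4, 3} vs {5, 3}: 4 and 5 are known, unequal, and neither is 1.
  auto s3 = getBroadcastedShape({4, 3}, {5, 3});          // -> llvm::None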


@@ -42,6 +42,61 @@ static bool isBroadcastableType(Type type) {
   return false;
 }
 
+Optional<SmallVector<int64_t, 4>>
+OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
+                                   ArrayRef<int64_t> shape2) {
+  // To compute the result broadcasted shape, we compare operand shapes
+  // element-wise: starting with the trailing dimensions and working backward.
+  // Two dimensions are compatible when
+  //   1. they are equal, or
+  //   2. one of them is 1.
+  // Seed the result with the longer of the two shapes; the loop below then
+  // overwrites its trailing dimensions with the deduced values.
+  SmallVector<int64_t, 4> resultShape;
+  if (shape1.size() > shape2.size()) {
+    std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape));
+  } else {
+    std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape));
+  }
+
+  auto i1 = shape1.rbegin(), e1 = shape1.rend();
+  auto i2 = shape2.rbegin(), e2 = shape2.rend();
+  auto iR = resultShape.rbegin();
+
+  // Check that each trailing dimension pair is consistent.
+  for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
+    if (*i1 == -1 || *i2 == -1) {
+      // One or both dimensions is unknown. Follow TensorFlow behavior:
+      // - If either dimension is greater than 1, assume that the program is
+      //   correct and that the other dimension will be broadcast to match it.
+      // - If either dimension is 1, the other dimension is the result.
+      // - Otherwise, the result dimension is unknown.
+      if (*i1 > 1) {
+        *iR = *i1;
+      } else if (*i2 > 1) {
+        *iR = *i2;
+      } else if (*i1 == 1) {
+        *iR = *i2;
+      } else if (*i2 == 1) {
+        *iR = *i1;
+      } else {
+        *iR = -1;
+      }
+    } else {
+      if (*i1 == *i2 || *i2 == 1) {
+        *iR = *i1;
+      } else if (*i1 == 1) {
+        *iR = *i2;
+      } else {
+        // This dimension of the two operand types is incompatible.
+        return llvm::None;
+      }
+    }
+  }
+
+  return resultShape;
+}
+
 /// Returns the result broadcast composition type from the two given types by
 /// following NumPy broadcast semantics. Returned type may have dynamic shape if
 /// either of the input types has dynamic shape. Returns null type if the two
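To make the unknown-dimension branch above concrete, here is a hedged sketch of each case (editorial, derived from the code above; not part of the commit):

  auto a = getBroadcastedShape({-1}, {5});  // -> {5}: assume the program is
                                            // correct and broadcast to 5
  auto b = getBroadcastedShape({-1}, {1});  // -> {-1}: the 1 broadcasts away;
                                            // the unknown side remains
  auto c = getBroadcastedShape({-1}, {-1}); // -> {-1}: nothing is known, so
                                            // the result stays unknown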
@@ -104,45 +159,15 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
   };
 
   // Get the shape of each type.
-  auto shape1 = getShape(type1);
-  auto shape2 = getShape(type2);
-
-  // To compute the result broadcasted shape, we compare operand shapes
-  // element-wise: starting with the trailing dimensions, and working the
-  // way backward. Two dimensions are compatible when
-  //   1. they are equal, or
-  //   2. one of them is 1
-  // The result shape has the maximum among the two inputs at every
-  // dimension index.
-  SmallVector<int64_t, 4> resultShape;
-  if (shape1.size() > shape2.size()) {
-    std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape));
-  } else {
-    std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape));
-  }
-
-  auto i1 = shape1.rbegin(), e1 = shape1.rend();
-  auto i2 = shape2.rbegin(), e2 = shape2.rend();
-  auto iR = resultShape.rbegin();
-
-  // Check each dimension is consistent.
-  for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
-    if (*i1 == *i2 || *i2 == 1) {
-      *iR = *i1;
-    } else if (*i1 == 1) {
-      *iR = *i2;
-    } else {
-      // This dimension of the two operand types is incompatible.
-      return {};
-    }
-  }
+  auto resultShape = getBroadcastedShape(getShape(type1), getShape(type2));
+  if (!resultShape)
+    return {};
 
   // Compose the final broadcasted type
   if (resultCompositeKind == StandardTypes::Vector)
-    return VectorType::get(resultShape, scalarType);
+    return VectorType::get(*resultShape, scalarType);
   if (resultCompositeKind == StandardTypes::RankedTensor)
-    return RankedTensorType::get(resultShape, scalarType);
+    return RankedTensorType::get(*resultShape, scalarType);
   return scalarType;
 }
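For context, a minimal sketch of calling the refactored entry point (editorial, not part of this diff; the MLIRContext/Builder setup here is assumed):

  mlir::MLIRContext context;
  mlir::Builder builder(&context);
  auto i32 = builder.getIntegerType(32);
  // tensor<4x3x2xi32> vs tensor<?xi32>: the trailing pair (2, ?) resolves to
  // 2 under the new rules, so this composes tensor<4x3x2xi32> rather than
  // returning a null Type.
  mlir::Type lhs = mlir::RankedTensorType::get({4, 3, 2}, i32);
  mlir::Type rhs = mlir::RankedTensorType::get({-1}, i32);
  mlir::Type result = mlir::OpTrait::util::getBroadcastedType(lhs, rhs);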


@@ -136,10 +136,9 @@ func @broadcast_tensor_tensor_tensor(tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) ->
 
 // -----
 
 // Check incompatible operand types with unknown dimension
 func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<?xi32>) -> tensor<4x3x2xi32> {
 ^bb0(%arg0: tensor<4x3x2xi32>, %arg1: tensor<?xi32>):
-  // expected-error @+1 {{operands don't have broadcast-compatible types}}
+  // CHECK: %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function: "RELU6"} : (tensor<4x3x2xi32>, tensor<?xi32>) -> tensor<4x3x2xi32>
   %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function: "RELU6"} : (tensor<4x3x2xi32>, tensor<?xi32>) -> tensor<4x3x2xi32>
   return %0 : tensor<4x3x2xi32>
 }
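The expected-error line goes away because these shapes are now broadcast-compatible: aligning from the trailing side pairs 2 with the unknown dimension, which resolves to 2 under the "assume the program is correct" rule, and the leading 4 and 3 carry over unchanged. In terms of the new utility (an editorial check, not part of the diff):

  auto s = getBroadcastedShape({4, 3, 2}, {-1}); // -> {4, 3, 2}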


@@ -0,0 +1,74 @@
+//===- BroadcastShapeTest.cpp - broadcasting shape unit tests -------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/Traits.h"
+#include "gmock/gmock.h"
+
+using namespace mlir::OpTrait::util;
+using ::testing::ElementsAre;
+
+TEST(BroadcastShapeTest, CompatibleScalarAndScalar) {
+  auto result = getBroadcastedShape({}, {});
+  ASSERT_TRUE(result.hasValue());
+  EXPECT_TRUE(result->empty());
+}
+
+TEST(BroadcastShapeTest, Compatible0DAnd1DTensor) {
+  auto result = getBroadcastedShape({}, {4});
+  ASSERT_TRUE(result.hasValue());
+  EXPECT_THAT(result.getValue(), ElementsAre(4));
+}
+
+TEST(BroadcastShapeTest, Compatible0DAnd3DTensor) {
+  auto result = getBroadcastedShape({}, {3, 5, 4});
+  ASSERT_TRUE(result.hasValue());
+  EXPECT_THAT(result.getValue(), ElementsAre(3, 5, 4));
+}
+
+TEST(BroadcastShapeTest, CompatibleTensorAndTensor) {
+  auto result = getBroadcastedShape({1, 7, 8, 9}, {8, 9});
+  ASSERT_TRUE(result.hasValue());
+  EXPECT_THAT(result.getValue(), ElementsAre(1, 7, 8, 9));
+}
+
+TEST(BroadcastShapeTest, InterleavingOnes) {
+  auto result = getBroadcastedShape({8, 1, 2, 1, 4}, {5, 1, 7, 1});
+  ASSERT_TRUE(result.hasValue());
+  EXPECT_THAT(result.getValue(), ElementsAre(8, 5, 2, 7, 4));
+}
+
+TEST(BroadcastShapeTest, InterleavingUnknowns) {
+  auto result = getBroadcastedShape({1, 2, -1, -1, -1}, {-1, -1, -1, 4, 1});
+  ASSERT_TRUE(result.hasValue());
+  EXPECT_THAT(result.getValue(), ElementsAre(-1, 2, -1, 4, -1));
+}
+
+TEST(BroadcastShapeTest, IncompatibleLowDim) {
+  auto result = getBroadcastedShape({4, 3, 5, 5}, {3, 5, 4});
+  EXPECT_FALSE(result.hasValue());
+}
+
+TEST(BroadcastShapeTest, IncompatibleMiddleDim) {
+  auto result = getBroadcastedShape({4, 3, 5, 5}, {3, 7, 5});
+  EXPECT_FALSE(result.hasValue());
+}
+
+TEST(BroadcastShapeTest, IncompatibleHighDim) {
+  auto result = getBroadcastedShape({3, 5, 5}, {4, 5, 5});
+  EXPECT_FALSE(result.hasValue());
+}
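One property these tests imply but do not spell out: a dimension pair involving an unknown can never by itself make the broadcast fail, because the unknown side always defers to a known size greater than 1. A hypothetical extra test pinning that down (editorial, not part of the commit):

  TEST(BroadcastShapeTest, UnknownDefersToKnown) {
    auto result = getBroadcastedShape({-1, 5}, {3, 1});
    ASSERT_TRUE(result.hasValue());
    EXPECT_THAT(result.getValue(), ElementsAre(3, 5));
  }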