!48126 fix broadcast check of batchmatmul

Merge pull request !48126 from zhoufeng/batchmatmul-broadcast-shape
i-robot 2023-01-29 06:21:10 +00:00 committed by Gitee
commit 8a77bcf86e
2 changed files with 82 additions and 60 deletions
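
In short: the old BatchMatMul shape inference only broadcast batch dimensions between inputs of equal rank, and for mixed ranks it simply copied the longer shape's batch dimensions, so valid inputs such as (3, 1, 5) x (1, 3, 5, 4) were rejected (that case sat in raise_set). The fix aligns batch dimensions from the right, NumPy-style, and factors the rank, multiplicability, and broadcast checks into dedicated helpers. A minimal usage sketch of a newly accepted case (illustrative, not part of the commit; it assumes MindSpore's public mindspore.ops.BatchMatMul API, with shapes taken from the added BatchMatMul_broadcast_5 test case):

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

# Batch dims [3] and [1, 3] broadcast right-aligned to [1, 3]; before this
# fix the shape check raised ValueError (see the removed BatchMatMul_4_Error
# entry at the bottom of the diff).
x = Tensor(np.ones([3, 1, 5]), ms.float32)
y = Tensor(np.ones([1, 3, 5, 4]), ms.float32)
out = ops.BatchMatMul()(x, y)
print(out.shape)  # expected: (1, 3, 1, 4)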


@@ -18,6 +18,7 @@
 #include <map>
 #include <string>
 #include <memory>
+#include <algorithm>
 #include <set>
 #include "ops/op_utils.h"
 #include "utils/check_convert_utils.h"
@@ -31,27 +32,23 @@ namespace mindspore {
 namespace ops {
 // batchmatmul
 namespace {
+constexpr size_t kMatSize = 2;
 void BatchMatMulMakeShape(ShapeVector *output, const ShapeVector xshp, const ShapeVector yshp, bool transpose_a,
                           bool transpose_b, size_t offset) {
   if (xshp.empty() || yshp.empty()) {
     return;
   }
-  if (xshp.size() != yshp.size()) {
-    ShapeVector broadcast_input = xshp.size() > yshp.size() ? xshp : yshp;
-    for (size_t i = 0; i < broadcast_input.size() - offset; i++) {
-      if (broadcast_input[i] < 0) {
-        output->push_back(abstract::Shape::kShapeDimAny);
-      } else {
-        output->push_back(broadcast_input[i]);
-      }
-    }
-  } else {
-    for (size_t i = 0; i < xshp.size() - offset; i++) {
-      if (xshp[i] < 0 || yshp[i] < 0) {
-        output->push_back(abstract::Shape::kShapeDimAny);
-      } else {
-        output->push_back(xshp[i] > yshp[i] ? xshp[i] : yshp[i]);
-      }
-    }
-  }
+  ShapeVector long_input = xshp.size() > yshp.size() ? xshp : yshp;
+  ShapeVector short_input = xshp.size() > yshp.size() ? yshp : xshp;
+  size_t size_diff = long_input.size() - short_input.size();
+  for (size_t i = 0; i < long_input.size() - offset; i++) {
+    if (long_input[i] < 0) {
+      output->push_back(abstract::Shape::kShapeDimAny);
+    } else if (i >= size_diff) {
+      output->push_back(long_input[i] > short_input[i - size_diff] ? long_input[i] : short_input[i - size_diff]);
+    } else {
+      output->push_back(long_input[i]);
+    }
+  }
   size_t x_offset = xshp.size() - offset;
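
The rewritten loop above merges the batch parts right-aligned: leading dimensions that only the longer shape has are copied through, overlapping dimensions take the larger extent, and any negative (dynamic) dimension becomes kShapeDimAny. The same logic as a plain-Python sketch (hypothetical helper, with -1 standing in for kShapeDimAny):

def merge_batch_dims(xshp, yshp, offset=2):
    # Mirror of the C++ loop: right-align the shapes and broadcast the
    # batch part (everything except the last `offset` matrix dims).
    long_in = xshp if len(xshp) > len(yshp) else yshp
    short_in = yshp if len(xshp) > len(yshp) else xshp
    size_diff = len(long_in) - len(short_in)
    out = []
    for i in range(len(long_in) - offset):
        if long_in[i] < 0:
            out.append(-1)  # dynamic dimension, unknown until runtime
        elif i >= size_diff:
            out.append(max(long_in[i], short_in[i - size_diff]))
        else:
            out.append(long_in[i])  # dim the shorter shape does not have
    return out

assert merge_batch_dims([2, 1, 1, 5], [1, 2, 5, 4]) == [2, 2]
assert merge_batch_dims([3, 1, 5], [1, 3, 5, 4]) == [1, 3]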
@@ -61,6 +58,53 @@ void BatchMatMulMakeShape(ShapeVector *output, const ShapeVector xshp, const ShapeVector yshp, bool transpose_a,
   return;
 }
 
+void CheckBatchMatmulInputSize(const std::string &op_name, const std::string &input_name, const ShapeVector &shape) {
+  constexpr size_t dim_limit = 2;
+  if (shape.size() < dim_limit) {
+    MS_EXCEPTION(ValueError) << "For '" << op_name << "', the input '" << input_name
+                             << "' must be a 2D or higher dimensional Tensor, but got " << shape.size() << "D shape "
+                             << shape;
+  }
+}
+
+void CheckBatchMatmulInputWhetherCanBeMul(const std::string &name, const ShapeVector &x_shape,
+                                          const ShapeVector &y_shape, bool transpose_a, bool transpose_b) {
+  ShapeVector x_mat_shape(x_shape.end() - kMatSize, x_shape.end());
+  ShapeVector y_mat_shape(y_shape.end() - kMatSize, y_shape.end());
+  int64_t x_col = x_mat_shape[static_cast<size_t>(!transpose_a)];
+  int64_t y_row = y_mat_shape[static_cast<size_t>(transpose_b)];
+  if (std::find(x_shape.begin(), x_shape.end(), -1) == x_shape.end() &&
+      std::find(y_shape.begin(), y_shape.end(), -1) == y_shape.end()) {
+    if (x_col != y_row) {
+      MS_EXCEPTION(ValueError) << "For " << name
+                               << ", the row of the input 'y' should be same as the col of the input 'x', with x shape "
+                               << x_shape << "(transpose_a=" << transpose_a << "), y shape " << y_shape
+                               << "(transpose_b=" << transpose_b << ")";
+    }
+  }
+}
+
+void CheckBatchMatmulInputWhetherCanBeBroadcast(const std::string &name, const ShapeVector &x_shape,
+                                                const ShapeVector &y_shape) {
+  ShapeVector x_batch(x_shape.begin(), x_shape.end() - kMatSize);
+  ShapeVector y_batch(y_shape.begin(), y_shape.end() - kMatSize);
+  if (x_batch == y_batch) {
+    return;
+  }
+
+  size_t min_size = std::min(x_batch.size(), y_batch.size());
+  for (size_t i = 0; i < min_size; ++i) {
+    auto x = *(x_batch.rbegin() + i);
+    auto y = *(y_batch.rbegin() + i);
+    if (x != 1 && y != 1 && x != y) {
+      MS_EXCEPTION(ValueError) << "For " << name
+                               << ", one of the input's batch dim must be equal to another input's peer batch dim, or "
+                                  "be equal to 1, or be empty, but are "
+                               << x << " and " << y << ", with x shape " << x_shape << ", y shape " << y_shape;
+    }
+  }
+}
+
 abstract::ShapePtr BatchMatmulInferShape(const PrimitivePtr &primitive,
                                          const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
@@ -83,58 +127,24 @@ abstract::ShapePtr BatchMatmulInferShape(const PrimitivePtr &primitive,
   if (IsDynamicRank(x_shp) || IsDynamicRank(y_shp)) {
     return std::make_shared<abstract::Shape>(ShapeVector({abstract::Shape::kShapeRankAny}));
   }
-  constexpr size_t x_dim_limit = 3;
-  constexpr size_t y_dim_limit = 2;
-  bool dynamic_shape = IsDynamic(x_shp) || IsDynamic(y_shp);
-  if (!dynamic_shape) {
-    if (x_shp.size() < x_dim_limit) {
-      MS_EXCEPTION(ValueError) << "For '" << prim_name
-                               << "', the input 'x' must be a 3D or higher dimensional Tensor, but got " << x_shp.size()
-                               << "D shape " << x_shp;
-    }
-    if (y_shp.size() < y_dim_limit) {
-      MS_EXCEPTION(ValueError) << "For '" << prim_name
-                               << "', the input 'y' must be a 2D or higher dimensional Tensor, but got " << y_shp.size()
-                               << "D shape " << y_shp;
-    }
-  }
-  constexpr size_t offset = 2;
-  std::vector<int> x_last(x_shp.end() - offset, x_shp.end());
-  std::vector<int> y_last(y_shp.end() - offset, y_shp.end());
   ValuePtr transpose_a_ptr = primitive->GetAttr("transpose_a");
   ValuePtr transpose_b_ptr = primitive->GetAttr("transpose_b");
   bool transpose_a = GetValue<bool>(transpose_a_ptr);
   bool transpose_b = GetValue<bool>(transpose_b_ptr);
-  int64_t x_col = x_last[static_cast<size_t>(!transpose_a)];
-  int64_t y_row = y_last[static_cast<size_t>(transpose_b)];
-  if (std::find(x_shp.begin(), x_shp.end(), -1) == x_shp.end() &&
-      std::find(y_shp.begin(), y_shp.end(), -1) == y_shp.end()) {
-    if (!dynamic_shape && x_col != y_row) {
-      MS_EXCEPTION(ValueError) << "For " << prim_name
-                               << ", the row of the input 'y' should be same as the col of the input 'x', with x shape "
-                               << x_shp << "(transpose_a=" << transpose_a << "), y shape " << y_shp
-                               << "(transpose_b=" << transpose_b << ")";
-    }
-  }
   (void)primitive->AddAttr("transpose_x1", transpose_a_ptr);
   (void)primitive->AddAttr("transpose_x2", transpose_b_ptr);
-  // Additional check for dynamic shape
-  // Last infer will be real shape values
+  bool dynamic_shape = IsDynamic(x_shp) || IsDynamic(y_shp);
   if (!dynamic_shape) {
-    size_t x_offset = x_shp.size() - offset;
-    size_t y_offset = y_shp.size() - offset;
-    auto x_c = x_shp[x_offset + (transpose_a ? 0 : 1)];
-    auto y_r = y_shp[y_offset + (transpose_b ? 1 : 0)];
-    if (x_c != y_r) {
-      MS_LOG(EXCEPTION) << "For '" << prim_name << "', x_col must be equal to y_row. But got x_col: " << x_c
-                        << ", y_row: " << y_r << ".";
-    }
+    CheckBatchMatmulInputSize(prim_name, "x", x_shp);
+    CheckBatchMatmulInputSize(prim_name, "y", y_shp);
+    CheckBatchMatmulInputWhetherCanBeMul(prim_name, x_shp, y_shp, transpose_a, transpose_b);
+    CheckBatchMatmulInputWhetherCanBeBroadcast(prim_name, x_shp, y_shp);
   }
   ShapeVector ret_shape;
-  BatchMatMulMakeShape(&ret_shape, x_shp, y_shp, transpose_a, transpose_b, offset);
+  BatchMatMulMakeShape(&ret_shape, x_shp, y_shp, transpose_a, transpose_b, kMatSize);
   return std::make_shared<abstract::Shape>(ret_shape);
 }
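
CheckBatchMatmulInputWhetherCanBeBroadcast compares the batch parts from the trailing end: peer dimensions are compatible when they are equal or when either is 1, and leading dimensions present on only one side never conflict. A plain-Python equivalent of the check (hypothetical helper for illustration):

def batch_dims_can_broadcast(x_shape, y_shape, mat_size=2):
    x_batch = x_shape[:-mat_size]
    y_batch = y_shape[:-mat_size]
    # Walk the batch dims right-aligned; extra leading dims are fine.
    for xd, yd in zip(reversed(x_batch), reversed(y_batch)):
        if xd != 1 and yd != 1 and xd != yd:
            return False
    return True

assert batch_dims_can_broadcast([3, 1, 5], [1, 3, 5, 4])   # now accepted
assert not batch_dims_can_broadcast([3, 1, 5], [2, 5, 4])  # still rejected: 3 vs 2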


@@ -180,6 +180,21 @@ test_case_check_ops = [
     ('BatchMatMul', {
         'block': NetBatchMatMul(),
         'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5])), Tensor(np.ones(shape=[3, 5, 4]))]}),
+    ('BatchMatMul_broadcast_1', {
+        'block': NetBatchMatMul(),
+        'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5])), Tensor(np.ones(shape=[5, 4]))]}),
+    ('BatchMatMul_broadcast_2', {
+        'block': NetBatchMatMul(),
+        'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5])), Tensor(np.ones(shape=[1, 5, 4]))]}),
+    ('BatchMatMul_broadcast_3', {
+        'block': NetBatchMatMul(),
+        'desc_inputs': [Tensor(np.ones(shape=[2, 1, 1, 5])), Tensor(np.ones(shape=[1, 2, 5, 4]))]}),
+    ('BatchMatMul_broadcast_4', {
+        'block': NetBatchMatMul(),
+        'desc_inputs': [Tensor(np.ones(shape=[2, 2, 1, 1, 5])), Tensor(np.ones(shape=[1, 2, 5, 4]))]}),
+    ('BatchMatMul_broadcast_5', {
+        'block': NetBatchMatMul(),
+        'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5])), Tensor(np.ones(shape=[1, 3, 5, 4]))]}),
 ]
 test_case_lists = [test_case_check_ops]
@@ -250,9 +265,6 @@ raise_set = [
         'block': (P.BatchMatMul(), {'exception': ValueError}),
         'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5])), Tensor(np.ones(shape=[3, 3, 4]))]}),
-    ('BatchMatMul_4_Error', {
-        'block': (P.BatchMatMul(), {'exception': ValueError}),
-        'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5])), Tensor(np.ones(shape=[1, 3, 5, 4]))]}),
     ('BatchMatMul_5_Error', {
         'block': (P.BatchMatMul(), {'exception': ValueError}),
         'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5])), Tensor(np.ones(shape=[2, 5, 4]))]}),
 ]
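
For reference, the expected output shapes of the new broadcast test cases follow the same batch-broadcast rule as numpy.matmul, which gives a quick cross-check:

import numpy as np

pairs = [
    ([3, 1, 5], [5, 4]),              # BatchMatMul_broadcast_1
    ([3, 1, 5], [1, 5, 4]),           # BatchMatMul_broadcast_2
    ([2, 1, 1, 5], [1, 2, 5, 4]),     # BatchMatMul_broadcast_3
    ([2, 2, 1, 1, 5], [1, 2, 5, 4]),  # BatchMatMul_broadcast_4
    ([3, 1, 5], [1, 3, 5, 4]),        # BatchMatMul_broadcast_5
]
for xs, ys in pairs:
    out = np.matmul(np.ones(xs), np.ones(ys))
    print(xs, '@', ys, '->', list(out.shape))
# [3, 1, 4], [3, 1, 4], [2, 2, 1, 4], [2, 2, 2, 1, 4], [1, 3, 1, 4]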