!49863 BatchToSpace算子 动态shape Infer报错

Merge pull request !49863 from Erpim/0306
This commit is contained in:
i-robot 2023-03-08 01:52:24 +00:00 committed by Gitee
commit 243a6a973c
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 18 additions and 8 deletions

View File

@ -81,14 +81,17 @@ class BatchToSpaceInfer : public abstract::OpInferBase {
auto shape_element = x_shape->cast<abstract::ShapePtr>(); auto shape_element = x_shape->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_element); MS_EXCEPTION_IF_NULL(shape_element);
auto input_shape = shape_element->shape(); auto input_shape = shape_element->shape();
if (mindspore::IsDynamicRank(shape_element->shape())) {
return std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeRankAny});
}
const size_t input_rank = 4; const size_t input_rank = 4;
if (input_shape.size() != input_rank) { if (input_shape.size() != input_rank) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', rank of 'input_x' should be 4, but got " MS_EXCEPTION(ValueError) << "For '" << prim_name << "', rank of 'input_x' should be 4, but got "
<< shape_element->shape().size(); << shape_element->shape().size();
} }
if (mindspore::IsDynamicRank(shape_element->shape())) {
return std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeRankAny});
}
auto block_size = GetValue<int64_t>(primitive->GetAttr(kBlockSize)); auto block_size = GetValue<int64_t>(primitive->GetAttr(kBlockSize));
auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops)); auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops));
const size_t height_dim_index = 2; const size_t height_dim_index = 2;
@ -97,6 +100,10 @@ class BatchToSpaceInfer : public abstract::OpInferBase {
output_shape[i] = input_shape[i]; output_shape[i] = input_shape[i];
} }
for (size_t i = height_dim_index; i < input_rank; i++) { for (size_t i = height_dim_index; i < input_rank; i++) {
if (input_shape[i] == abstract::Shape::kShapeDimAny) {
output_shape[i] = input_shape[i];
continue;
}
auto x_block_prod = input_shape[i] * block_size; auto x_block_prod = input_shape[i] * block_size;
auto crop_sum = crops[i - height_dim_index][0] + crops[i - height_dim_index][1]; auto crop_sum = crops[i - height_dim_index][0] + crops[i - height_dim_index][1];
if (x_block_prod <= crop_sum) { if (x_block_prod <= crop_sum) {
@ -108,12 +115,15 @@ class BatchToSpaceInfer : public abstract::OpInferBase {
output_shape[i] = x_block_prod - crop_sum; output_shape[i] = x_block_prod - crop_sum;
} }
auto block_size_prod = block_size * block_size; auto block_size_prod = block_size * block_size;
if (output_shape[0] % block_size_prod != 0) { if (output_shape[0] != abstract::Shape::kShapeDimAny) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the shape of output with index 0 must be divided exactly " if (output_shape[0] % block_size_prod != 0) {
<< "by square of 'block_size', but got the shape of output: " << output_shape MS_EXCEPTION(ValueError) << "For '" << prim_name
<< " and square of 'block_size': " << block_size_prod << "."; << "', the shape of output with index 0 must be divided exactly "
<< "by square of 'block_size', but got the shape of output: " << output_shape
<< " and square of 'block_size': " << block_size_prod << ".";
}
output_shape[0] = output_shape[0] / block_size_prod;
} }
output_shape[0] = output_shape[0] / block_size_prod;
return std::make_shared<abstract::Shape>(output_shape); return std::make_shared<abstract::Shape>(output_shape);
} }