forked from mindspore-Ecosystem/mindspore
!49863 BatchToSpace算子 动态shape Infer报错
Merge pull request !49863 from Erpim/0306
This commit is contained in:
commit
243a6a973c
|
@ -81,14 +81,17 @@ class BatchToSpaceInfer : public abstract::OpInferBase {
|
||||||
auto shape_element = x_shape->cast<abstract::ShapePtr>();
|
auto shape_element = x_shape->cast<abstract::ShapePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(shape_element);
|
MS_EXCEPTION_IF_NULL(shape_element);
|
||||||
auto input_shape = shape_element->shape();
|
auto input_shape = shape_element->shape();
|
||||||
|
|
||||||
|
if (mindspore::IsDynamicRank(shape_element->shape())) {
|
||||||
|
return std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeRankAny});
|
||||||
|
}
|
||||||
|
|
||||||
const size_t input_rank = 4;
|
const size_t input_rank = 4;
|
||||||
if (input_shape.size() != input_rank) {
|
if (input_shape.size() != input_rank) {
|
||||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', rank of 'input_x' should be 4, but got "
|
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', rank of 'input_x' should be 4, but got "
|
||||||
<< shape_element->shape().size();
|
<< shape_element->shape().size();
|
||||||
}
|
}
|
||||||
if (mindspore::IsDynamicRank(shape_element->shape())) {
|
|
||||||
return std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeRankAny});
|
|
||||||
}
|
|
||||||
auto block_size = GetValue<int64_t>(primitive->GetAttr(kBlockSize));
|
auto block_size = GetValue<int64_t>(primitive->GetAttr(kBlockSize));
|
||||||
auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops));
|
auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops));
|
||||||
const size_t height_dim_index = 2;
|
const size_t height_dim_index = 2;
|
||||||
|
@ -97,6 +100,10 @@ class BatchToSpaceInfer : public abstract::OpInferBase {
|
||||||
output_shape[i] = input_shape[i];
|
output_shape[i] = input_shape[i];
|
||||||
}
|
}
|
||||||
for (size_t i = height_dim_index; i < input_rank; i++) {
|
for (size_t i = height_dim_index; i < input_rank; i++) {
|
||||||
|
if (input_shape[i] == abstract::Shape::kShapeDimAny) {
|
||||||
|
output_shape[i] = input_shape[i];
|
||||||
|
continue;
|
||||||
|
}
|
||||||
auto x_block_prod = input_shape[i] * block_size;
|
auto x_block_prod = input_shape[i] * block_size;
|
||||||
auto crop_sum = crops[i - height_dim_index][0] + crops[i - height_dim_index][1];
|
auto crop_sum = crops[i - height_dim_index][0] + crops[i - height_dim_index][1];
|
||||||
if (x_block_prod <= crop_sum) {
|
if (x_block_prod <= crop_sum) {
|
||||||
|
@ -108,12 +115,15 @@ class BatchToSpaceInfer : public abstract::OpInferBase {
|
||||||
output_shape[i] = x_block_prod - crop_sum;
|
output_shape[i] = x_block_prod - crop_sum;
|
||||||
}
|
}
|
||||||
auto block_size_prod = block_size * block_size;
|
auto block_size_prod = block_size * block_size;
|
||||||
if (output_shape[0] % block_size_prod != 0) {
|
if (output_shape[0] != abstract::Shape::kShapeDimAny) {
|
||||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the shape of output with index 0 must be divided exactly "
|
if (output_shape[0] % block_size_prod != 0) {
|
||||||
<< "by square of 'block_size', but got the shape of output: " << output_shape
|
MS_EXCEPTION(ValueError) << "For '" << prim_name
|
||||||
<< " and square of 'block_size': " << block_size_prod << ".";
|
<< "', the shape of output with index 0 must be divided exactly "
|
||||||
|
<< "by square of 'block_size', but got the shape of output: " << output_shape
|
||||||
|
<< " and square of 'block_size': " << block_size_prod << ".";
|
||||||
|
}
|
||||||
|
output_shape[0] = output_shape[0] / block_size_prod;
|
||||||
}
|
}
|
||||||
output_shape[0] = output_shape[0] / block_size_prod;
|
|
||||||
return std::make_shared<abstract::Shape>(output_shape);
|
return std::make_shared<abstract::Shape>(output_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue