forked from mindspore-Ecosystem/mindspore
!49863 BatchToSpace算子 动态shape Infer报错
Merge pull request !49863 from Erpim/0306
This commit is contained in:
commit
243a6a973c
|
@ -81,14 +81,17 @@ class BatchToSpaceInfer : public abstract::OpInferBase {
|
|||
auto shape_element = x_shape->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_element);
|
||||
auto input_shape = shape_element->shape();
|
||||
|
||||
if (mindspore::IsDynamicRank(shape_element->shape())) {
|
||||
return std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeRankAny});
|
||||
}
|
||||
|
||||
const size_t input_rank = 4;
|
||||
if (input_shape.size() != input_rank) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', rank of 'input_x' should be 4, but got "
|
||||
<< shape_element->shape().size();
|
||||
}
|
||||
if (mindspore::IsDynamicRank(shape_element->shape())) {
|
||||
return std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeRankAny});
|
||||
}
|
||||
|
||||
auto block_size = GetValue<int64_t>(primitive->GetAttr(kBlockSize));
|
||||
auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops));
|
||||
const size_t height_dim_index = 2;
|
||||
|
@ -97,6 +100,10 @@ class BatchToSpaceInfer : public abstract::OpInferBase {
|
|||
output_shape[i] = input_shape[i];
|
||||
}
|
||||
for (size_t i = height_dim_index; i < input_rank; i++) {
|
||||
if (input_shape[i] == abstract::Shape::kShapeDimAny) {
|
||||
output_shape[i] = input_shape[i];
|
||||
continue;
|
||||
}
|
||||
auto x_block_prod = input_shape[i] * block_size;
|
||||
auto crop_sum = crops[i - height_dim_index][0] + crops[i - height_dim_index][1];
|
||||
if (x_block_prod <= crop_sum) {
|
||||
|
@ -108,12 +115,15 @@ class BatchToSpaceInfer : public abstract::OpInferBase {
|
|||
output_shape[i] = x_block_prod - crop_sum;
|
||||
}
|
||||
auto block_size_prod = block_size * block_size;
|
||||
if (output_shape[0] != abstract::Shape::kShapeDimAny) {
|
||||
if (output_shape[0] % block_size_prod != 0) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the shape of output with index 0 must be divided exactly "
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name
|
||||
<< "', the shape of output with index 0 must be divided exactly "
|
||||
<< "by square of 'block_size', but got the shape of output: " << output_shape
|
||||
<< " and square of 'block_size': " << block_size_prod << ".";
|
||||
}
|
||||
output_shape[0] = output_shape[0] / block_size_prod;
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(output_shape);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue