From 78df24ca0d6e9ae33045a100335064aaa184a679 Mon Sep 17 00:00:00 2001 From: jianghui58 Date: Fri, 11 Sep 2020 19:26:19 +0800 Subject: [PATCH] fix pad gather strided_slice infeshape bug --- mindspore/lite/src/ops/gather.cc | 2 +- mindspore/lite/src/ops/pad.cc | 7 ++----- mindspore/lite/src/ops/strided_slice.cc | 4 +++- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/mindspore/lite/src/ops/gather.cc b/mindspore/lite/src/ops/gather.cc index ede7cc3ce9..69a29f8ff2 100644 --- a/mindspore/lite/src/ops/gather.cc +++ b/mindspore/lite/src/ops/gather.cc @@ -90,7 +90,7 @@ int Gather::InferShape(std::vector inputs_, std::vector outp } std::vector out_shape{in_shape}; out_shape.erase(out_shape.begin() + axis); - for (int i = 0; i < indices_rank; i++) { + for (int i = indices_rank - 1; i >= 0; --i) { out_shape.insert(out_shape.begin() + axis + i, indices_shape[i]); } output->set_shape(out_shape); diff --git a/mindspore/lite/src/ops/pad.cc b/mindspore/lite/src/ops/pad.cc index 0f0a5e6f53..6e80e16486 100644 --- a/mindspore/lite/src/ops/pad.cc +++ b/mindspore/lite/src/ops/pad.cc @@ -58,9 +58,6 @@ int Pad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl return RET_OK; } #endif -namespace { -const size_t kInputRank = 4; -} // namespace int Pad::InferShape(std::vector inputs, std::vector outputs) { MS_ASSERT(this->primitive_ != nullptr); if (this->primitive_ == nullptr) { @@ -84,9 +81,9 @@ int Pad::InferShape(std::vector inputs, std::vector outputs) } auto input_shape = input->shape(); std::vector output_shape; - MS_ASSERT(input->shape().size() <= kInputRank); + MS_ASSERT(input->shape().size() <= 4); for (size_t i = 0; i < input_shape.size(); i++) { - auto paddings_index = i + kInputRank - input_shape.size(); + auto paddings_index = i; auto shape = input_shape[i] + paddings[2 * paddings_index] + paddings[2 * paddings_index + 1]; output_shape.push_back(shape); } diff --git a/mindspore/lite/src/ops/strided_slice.cc b/mindspore/lite/src/ops/strided_slice.cc index bd56848c12..57c70b4c74 100644 --- a/mindspore/lite/src/ops/strided_slice.cc +++ b/mindspore/lite/src/ops/strided_slice.cc @@ -236,8 +236,10 @@ int StridedSlice::InferShape(std::vector inputs, std::vector(in_shape_.size()); i++) { if (i < ndim_ && new_axis_mask_.at(i)) { output_shape.at(i) = 1; - } else { + } else if (ends_.at(i) > 0) { output_shape.at(i) = (ends_.at(i) - begins_.at(i)) / strides_.at(i); + } else { + output_shape.at(i) = (input_shape.at(i) + ends_.at(i) - begins_.at(i)) % input_shape.at(i) / strides_.at(i); } }