!6339 fix strided slice infershape error when stride is negative

Merge pull request !6339 from zhaozhenlong/lite/issue/stride_neg_infershape
This commit is contained in:
mindspore-ci-bot 2020-09-17 09:25:14 +08:00 committed by Gitee
commit 4d6bbd1218
2 changed files with 16 additions and 4 deletions

View File

@ -226,6 +226,17 @@ void StridedSlice::ApplyEndMask() {
}
}
void StridedSlice::TransIndexToPositive() {
for (int i = 0; i < static_cast<int>(begins_.size()); ++i) {
if (begins_.at(i) < 0) {
begins_.at(i) += in_shape_.at(i);
}
if (ends_.at(i) < 0) {
ends_.at(i) += in_shape_.at(i);
}
}
}
int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
MS_ASSERT(this->primitive_ != nullptr);
if (outputs.size() != kStridedSliceOutputNum) {
@ -266,7 +277,7 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
return RET_INFER_ERR;
}
ndim_ = begin_tensor->ElementsNum();
for (int i=0; i< ndim_; ++i) {
for (int i = 0; i < ndim_; ++i) {
in_shape_.emplace_back(input_shape.at(i));
begins_.emplace_back(begin_data[i]);
ends_.emplace_back(end_data[i]);
@ -297,13 +308,13 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
output_shape.clear();
output_shape.resize(in_shape_.size());
TransIndexToPositive();
for (int i = 0; i < static_cast<int>(in_shape_.size()); i++) {
if (i < ndim_ && new_axis_mask_.at(i)) {
output_shape.at(i) = 1;
} else if (ends_.at(i) > 0) {
output_shape.at(i) = (ends_.at(i) - begins_.at(i)) / strides_.at(i);
} else {
output_shape.at(i) = (input_shape.at(i) + ends_.at(i) - begins_.at(i)) % input_shape.at(i) / strides_.at(i);
output_shape.at(i) = (ends_.at(i) - begins_.at(i)) / strides_.at(i);
}
}

View File

@ -80,6 +80,7 @@ class StridedSlice : public PrimitiveC {
std::vector<bool> ellipsis_mask_;
std::vector<bool> new_axis_mask_;
std::vector<bool> shrink_axis_mask_;
void TransIndexToPositive();
};
} // namespace lite
} // namespace mindspore