forked from mindspore-Ecosystem/mindspore
!6339 fix strided slice infershape error when stride is negative
Merge pull request !6339 from zhaozhenlong/lite/issue/stride_neg_infershape
This commit is contained in:
commit
4d6bbd1218
|
@ -226,6 +226,17 @@ void StridedSlice::ApplyEndMask() {
|
|||
}
|
||||
}
|
||||
|
||||
void StridedSlice::TransIndexToPositive() {
|
||||
for (int i = 0; i < static_cast<int>(begins_.size()); ++i) {
|
||||
if (begins_.at(i) < 0) {
|
||||
begins_.at(i) += in_shape_.at(i);
|
||||
}
|
||||
if (ends_.at(i) < 0) {
|
||||
ends_.at(i) += in_shape_.at(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
if (outputs.size() != kStridedSliceOutputNum) {
|
||||
|
@ -266,7 +277,7 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
|
|||
return RET_INFER_ERR;
|
||||
}
|
||||
ndim_ = begin_tensor->ElementsNum();
|
||||
for (int i=0; i< ndim_; ++i) {
|
||||
for (int i = 0; i < ndim_; ++i) {
|
||||
in_shape_.emplace_back(input_shape.at(i));
|
||||
begins_.emplace_back(begin_data[i]);
|
||||
ends_.emplace_back(end_data[i]);
|
||||
|
@ -297,13 +308,13 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
|
|||
|
||||
output_shape.clear();
|
||||
output_shape.resize(in_shape_.size());
|
||||
|
||||
TransIndexToPositive();
|
||||
for (int i = 0; i < static_cast<int>(in_shape_.size()); i++) {
|
||||
if (i < ndim_ && new_axis_mask_.at(i)) {
|
||||
output_shape.at(i) = 1;
|
||||
} else if (ends_.at(i) > 0) {
|
||||
output_shape.at(i) = (ends_.at(i) - begins_.at(i)) / strides_.at(i);
|
||||
} else {
|
||||
output_shape.at(i) = (input_shape.at(i) + ends_.at(i) - begins_.at(i)) % input_shape.at(i) / strides_.at(i);
|
||||
output_shape.at(i) = (ends_.at(i) - begins_.at(i)) / strides_.at(i);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -80,6 +80,7 @@ class StridedSlice : public PrimitiveC {
|
|||
std::vector<bool> ellipsis_mask_;
|
||||
std::vector<bool> new_axis_mask_;
|
||||
std::vector<bool> shrink_axis_mask_;
|
||||
void TransIndexToPositive();
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue