fix pad gather strided_slice infeshape bug

This commit is contained in:
jianghui58 2020-09-11 19:26:19 +08:00
parent a26fdb83ee
commit 78df24ca0d
3 changed files with 6 additions and 7 deletions

View File

@ -90,7 +90,7 @@ int Gather::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
} }
std::vector<int> out_shape{in_shape}; std::vector<int> out_shape{in_shape};
out_shape.erase(out_shape.begin() + axis); out_shape.erase(out_shape.begin() + axis);
for (int i = 0; i < indices_rank; i++) { for (int i = indices_rank - 1; i >= 0; --i) {
out_shape.insert(out_shape.begin() + axis + i, indices_shape[i]); out_shape.insert(out_shape.begin() + axis + i, indices_shape[i]);
} }
output->set_shape(out_shape); output->set_shape(out_shape);

View File

@ -58,9 +58,6 @@ int Pad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl
return RET_OK; return RET_OK;
} }
#endif #endif
namespace {
const size_t kInputRank = 4;
} // namespace
int Pad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) { int Pad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
MS_ASSERT(this->primitive_ != nullptr); MS_ASSERT(this->primitive_ != nullptr);
if (this->primitive_ == nullptr) { if (this->primitive_ == nullptr) {
@ -84,9 +81,9 @@ int Pad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs)
} }
auto input_shape = input->shape(); auto input_shape = input->shape();
std::vector<int> output_shape; std::vector<int> output_shape;
MS_ASSERT(input->shape().size() <= kInputRank); MS_ASSERT(input->shape().size() <= 4);
for (size_t i = 0; i < input_shape.size(); i++) { for (size_t i = 0; i < input_shape.size(); i++) {
auto paddings_index = i + kInputRank - input_shape.size(); auto paddings_index = i;
auto shape = input_shape[i] + paddings[2 * paddings_index] + paddings[2 * paddings_index + 1]; auto shape = input_shape[i] + paddings[2 * paddings_index] + paddings[2 * paddings_index + 1];
output_shape.push_back(shape); output_shape.push_back(shape);
} }

View File

@ -236,8 +236,10 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
for (int i = 0; i < static_cast<int>(in_shape_.size()); i++) { for (int i = 0; i < static_cast<int>(in_shape_.size()); i++) {
if (i < ndim_ && new_axis_mask_.at(i)) { if (i < ndim_ && new_axis_mask_.at(i)) {
output_shape.at(i) = 1; output_shape.at(i) = 1;
} else { } else if (ends_.at(i) > 0) {
output_shape.at(i) = (ends_.at(i) - begins_.at(i)) / strides_.at(i); output_shape.at(i) = (ends_.at(i) - begins_.at(i)) / strides_.at(i);
} else {
output_shape.at(i) = (input_shape.at(i) + ends_.at(i) - begins_.at(i)) % input_shape.at(i) / strides_.at(i);
} }
} }