forked from OSSInnovation/mindspore
fix pad gather strided_slice infeshape bug
This commit is contained in:
parent
a26fdb83ee
commit
78df24ca0d
|
@ -90,7 +90,7 @@ int Gather::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
|
|||
}
|
||||
std::vector<int> out_shape{in_shape};
|
||||
out_shape.erase(out_shape.begin() + axis);
|
||||
for (int i = 0; i < indices_rank; i++) {
|
||||
for (int i = indices_rank - 1; i >= 0; --i) {
|
||||
out_shape.insert(out_shape.begin() + axis + i, indices_shape[i]);
|
||||
}
|
||||
output->set_shape(out_shape);
|
||||
|
|
|
@ -58,9 +58,6 @@ int Pad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl
|
|||
return RET_OK;
|
||||
}
|
||||
#endif
|
||||
namespace {
|
||||
const size_t kInputRank = 4;
|
||||
} // namespace
|
||||
int Pad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
if (this->primitive_ == nullptr) {
|
||||
|
@ -84,9 +81,9 @@ int Pad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs)
|
|||
}
|
||||
auto input_shape = input->shape();
|
||||
std::vector<int> output_shape;
|
||||
MS_ASSERT(input->shape().size() <= kInputRank);
|
||||
MS_ASSERT(input->shape().size() <= 4);
|
||||
for (size_t i = 0; i < input_shape.size(); i++) {
|
||||
auto paddings_index = i + kInputRank - input_shape.size();
|
||||
auto paddings_index = i;
|
||||
auto shape = input_shape[i] + paddings[2 * paddings_index] + paddings[2 * paddings_index + 1];
|
||||
output_shape.push_back(shape);
|
||||
}
|
||||
|
|
|
@ -236,8 +236,10 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
|
|||
for (int i = 0; i < static_cast<int>(in_shape_.size()); i++) {
|
||||
if (i < ndim_ && new_axis_mask_.at(i)) {
|
||||
output_shape.at(i) = 1;
|
||||
} else {
|
||||
} else if (ends_.at(i) > 0) {
|
||||
output_shape.at(i) = (ends_.at(i) - begins_.at(i)) / strides_.at(i);
|
||||
} else {
|
||||
output_shape.at(i) = (input_shape.at(i) + ends_.at(i) - begins_.at(i)) % input_shape.at(i) / strides_.at(i);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue