!6723 fix bug of cpu op arithmetic & softmax

Merge pull request !6723 from 陶云浩/lite
This commit is contained in:
mindspore-ci-bot 2020-09-23 19:12:39 +08:00 committed by Gitee
commit 62f08d9dda
4 changed files with 21 additions and 12 deletions

View File

@ -28,20 +28,20 @@ typedef struct ArithmeticParameter {
bool broadcasting_;
size_t ndim_;
int activation_type_;
int in_shape0_[5];
int in_shape0_[10];
int in_elements_num0_;
int in_shape1_[5];
int in_shape1_[10];
int in_elements_num1_;
int out_shape_[5];
int out_shape_[10];
int out_elements_num_;
int in_strides0_[5];
int in_strides1_[5];
int out_strides_[5];
int in_strides0_[10];
int in_strides1_[10];
int out_strides_[10];
int multiples0_[5];
int multiples1_[5];
int multiples0_[10];
int multiples1_[10];
} ArithmeticParameter;
#ifdef __cplusplus

View File

@ -24,7 +24,7 @@ typedef struct SoftmaxParameter {
int32_t axis_;
int element_size_;
int n_dim_;
int input_shape_[4];
int input_shape_[5];
} SoftmaxParameter;
#endif // MINDSPORE_LITE_NNACL_SOFTMAX_PARAMETER_H_

View File

@ -46,9 +46,14 @@ int Arithmetic::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite
if (!GetInferFlag()) {
return RET_OK;
}
in_shape0_.resize(5);
in_shape1_.resize(5);
out_shape_.resize(5);
if (input_shape0.size() > 10 || input_shape1.size() > 10) {
int wrong_dim = input_shape0.size() > input_shape1.size() ? input_shape0.size() : input_shape1.size();
MS_LOG(ERROR) << "Not support input dim: " << wrong_dim << ", The input dim must be less than 10";
return RET_ERROR;
}
in_shape0_.resize(10);
in_shape1_.resize(10);
out_shape_.resize(10);
ndim_ = input_shape0.size();
if (input_shape0.size() < input_shape1.size()) {

View File

@ -82,6 +82,10 @@ int SoftMax::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out
if (!GetInferFlag()) {
return RET_OK;
}
if (input->shape().size() > 5) {
MS_LOG(ERROR) << "Softmax input dim must be less than 5, get " << input->shape().size();
return RET_ERROR;
}
output->set_shape(input->shape());
return RET_OK;
}