forked from mindspore-Ecosystem/mindspore
!6723 fix bug of cpu op arithmetic & softmax
Merge pull request !6723 from 陶云浩/lite
This commit is contained in:
commit
62f08d9dda
|
@ -28,20 +28,20 @@ typedef struct ArithmeticParameter {
|
|||
bool broadcasting_;
|
||||
size_t ndim_;
|
||||
int activation_type_;
|
||||
int in_shape0_[5];
|
||||
int in_shape0_[10];
|
||||
int in_elements_num0_;
|
||||
int in_shape1_[5];
|
||||
int in_shape1_[10];
|
||||
int in_elements_num1_;
|
||||
|
||||
int out_shape_[5];
|
||||
int out_shape_[10];
|
||||
int out_elements_num_;
|
||||
|
||||
int in_strides0_[5];
|
||||
int in_strides1_[5];
|
||||
int out_strides_[5];
|
||||
int in_strides0_[10];
|
||||
int in_strides1_[10];
|
||||
int out_strides_[10];
|
||||
|
||||
int multiples0_[5];
|
||||
int multiples1_[5];
|
||||
int multiples0_[10];
|
||||
int multiples1_[10];
|
||||
} ArithmeticParameter;
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -24,7 +24,7 @@ typedef struct SoftmaxParameter {
|
|||
int32_t axis_;
|
||||
int element_size_;
|
||||
int n_dim_;
|
||||
int input_shape_[4];
|
||||
int input_shape_[5];
|
||||
} SoftmaxParameter;
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_SOFTMAX_PARAMETER_H_
|
||||
|
|
|
@ -46,9 +46,14 @@ int Arithmetic::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite
|
|||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
in_shape0_.resize(5);
|
||||
in_shape1_.resize(5);
|
||||
out_shape_.resize(5);
|
||||
if (input_shape0.size() > 10 || input_shape1.size() > 10) {
|
||||
int wrong_dim = input_shape0.size() > input_shape1.size() ? input_shape0.size() : input_shape1.size();
|
||||
MS_LOG(ERROR) << "Not support input dim: " << wrong_dim << ", The input dim must be less than 10";
|
||||
return RET_ERROR;
|
||||
}
|
||||
in_shape0_.resize(10);
|
||||
in_shape1_.resize(10);
|
||||
out_shape_.resize(10);
|
||||
|
||||
ndim_ = input_shape0.size();
|
||||
if (input_shape0.size() < input_shape1.size()) {
|
||||
|
|
|
@ -82,6 +82,10 @@ int SoftMax::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out
|
|||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
if (input->shape().size() > 5) {
|
||||
MS_LOG(ERROR) << "Softmax input dim must be less than 5, get " << input->shape().size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
output->set_shape(input->shape());
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue