!19385 [MS][LITE] allow reduce op has only one input when axes param is not given

Merge pull request !19385 from XianglongZeng/myms_new_3
This commit is contained in:
i-robot 2021-07-11 06:40:21 +00:00 committed by Gitee
commit 33a797daf9
6 changed files with 5 additions and 11 deletions

View File

@ -52,7 +52,7 @@ int ReduceOnSelectedAxes(const TensorC *input, size_t num_axes, const int *actua
int ReduceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1);
int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 1, 2, 1);
if (check_ret != NNACL_OK) {
return check_ret;
}
@ -70,11 +70,11 @@ int ReduceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
bool keep_dims = param->keep_dims_;
int out_shape[MAX_SHAPE_SIZE] = {0};
const size_t out_shape_size = 0;
// get axes from input tensor
const TensorC *axes_input = inputs[1];
if (axes_input->shape_size_ == 1 && axes_input->shape_[0] == 0) {
if (inputs_size == 1) {
return ReduceOnAllAxes(input, output, out_shape, out_shape_size, keep_dims);
}
// get axes from input tensor
const TensorC *axes_input = inputs[1];
int *axes = (int *)axes_input->data_;
if (axes == NULL) {
return NNACL_NULL_PTR;

View File

@ -59,8 +59,6 @@ int ReduceFp16CPUKernel::Init() {
return ReSize();
}
int ReduceFp16CPUKernel::ReSize() { return ReduceBaseCPUKernel::ReSize(); }
int ReduceFp16CPUKernel::CallReduceUnit(int task_id) {
auto ret =
reducer_(outer_size_, inner_size_, axis_size_, fp16_src_data_, fp16_dst_data_, task_id, op_parameter_->thread_num_);

View File

@ -36,7 +36,6 @@ class ReduceFp16CPUKernel : public ReduceBaseCPUKernel {
~ReduceFp16CPUKernel() = default;
int Init() override;
int ReSize() override;
int Run() override;
int CallReduceUnit(int task_id);

View File

@ -53,8 +53,6 @@ int ReduceCPUKernel::Init() {
return ReSize();
}
int ReduceCPUKernel::ReSize() { return ReduceBaseCPUKernel::ReSize(); }
int ReduceCPUKernel::CallReduceUnit(int task_id) {
if (data_type_ == kDataTypeFloat) {
if (!reducer_) {

View File

@ -54,7 +54,6 @@ class ReduceCPUKernel : public ReduceBaseCPUKernel {
}
int Init() override;
int ReSize() override;
int Run() override;
int CallReduceUnit(int task_id);

View File

@ -34,11 +34,11 @@ ops::PrimitiveC *OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph, con
for (int i = 0; i < size; ++i) {
axes.push_back(onnx_node_attr.ints(i));
}
prim->AddAttr("axes", MakeValue(axes));
} else if (attribute_name == "keepdims") {
prim->set_keep_dims(static_cast<bool>(onnx_node_attr.i()));
}
}
prim->AddAttr("axes", MakeValue(axes));
const auto &type = onnx_node.op_type();
if (type == "ReduceMean") {