forked from mindspore-Ecosystem/mindspore
!19385 [MS][LITE] allow reduce op has only one input when axes param is not given
Merge pull request !19385 from XianglongZeng/myms_new_3
This commit is contained in:
commit
33a797daf9
|
@ -52,7 +52,7 @@ int ReduceOnSelectedAxes(const TensorC *input, size_t num_axes, const int *actua
|
||||||
|
|
||||||
int ReduceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
|
int ReduceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
|
||||||
OpParameter *parameter) {
|
OpParameter *parameter) {
|
||||||
int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1);
|
int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 1, 2, 1);
|
||||||
if (check_ret != NNACL_OK) {
|
if (check_ret != NNACL_OK) {
|
||||||
return check_ret;
|
return check_ret;
|
||||||
}
|
}
|
||||||
|
@ -70,11 +70,11 @@ int ReduceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
|
||||||
bool keep_dims = param->keep_dims_;
|
bool keep_dims = param->keep_dims_;
|
||||||
int out_shape[MAX_SHAPE_SIZE] = {0};
|
int out_shape[MAX_SHAPE_SIZE] = {0};
|
||||||
const size_t out_shape_size = 0;
|
const size_t out_shape_size = 0;
|
||||||
// get axes from input tensor
|
if (inputs_size == 1) {
|
||||||
const TensorC *axes_input = inputs[1];
|
|
||||||
if (axes_input->shape_size_ == 1 && axes_input->shape_[0] == 0) {
|
|
||||||
return ReduceOnAllAxes(input, output, out_shape, out_shape_size, keep_dims);
|
return ReduceOnAllAxes(input, output, out_shape, out_shape_size, keep_dims);
|
||||||
}
|
}
|
||||||
|
// get axes from input tensor
|
||||||
|
const TensorC *axes_input = inputs[1];
|
||||||
int *axes = (int *)axes_input->data_;
|
int *axes = (int *)axes_input->data_;
|
||||||
if (axes == NULL) {
|
if (axes == NULL) {
|
||||||
return NNACL_NULL_PTR;
|
return NNACL_NULL_PTR;
|
||||||
|
|
|
@ -59,8 +59,6 @@ int ReduceFp16CPUKernel::Init() {
|
||||||
return ReSize();
|
return ReSize();
|
||||||
}
|
}
|
||||||
|
|
||||||
int ReduceFp16CPUKernel::ReSize() { return ReduceBaseCPUKernel::ReSize(); }
|
|
||||||
|
|
||||||
int ReduceFp16CPUKernel::CallReduceUnit(int task_id) {
|
int ReduceFp16CPUKernel::CallReduceUnit(int task_id) {
|
||||||
auto ret =
|
auto ret =
|
||||||
reducer_(outer_size_, inner_size_, axis_size_, fp16_src_data_, fp16_dst_data_, task_id, op_parameter_->thread_num_);
|
reducer_(outer_size_, inner_size_, axis_size_, fp16_src_data_, fp16_dst_data_, task_id, op_parameter_->thread_num_);
|
||||||
|
|
|
@ -36,7 +36,6 @@ class ReduceFp16CPUKernel : public ReduceBaseCPUKernel {
|
||||||
~ReduceFp16CPUKernel() = default;
|
~ReduceFp16CPUKernel() = default;
|
||||||
|
|
||||||
int Init() override;
|
int Init() override;
|
||||||
int ReSize() override;
|
|
||||||
int Run() override;
|
int Run() override;
|
||||||
int CallReduceUnit(int task_id);
|
int CallReduceUnit(int task_id);
|
||||||
|
|
||||||
|
|
|
@ -53,8 +53,6 @@ int ReduceCPUKernel::Init() {
|
||||||
return ReSize();
|
return ReSize();
|
||||||
}
|
}
|
||||||
|
|
||||||
int ReduceCPUKernel::ReSize() { return ReduceBaseCPUKernel::ReSize(); }
|
|
||||||
|
|
||||||
int ReduceCPUKernel::CallReduceUnit(int task_id) {
|
int ReduceCPUKernel::CallReduceUnit(int task_id) {
|
||||||
if (data_type_ == kDataTypeFloat) {
|
if (data_type_ == kDataTypeFloat) {
|
||||||
if (!reducer_) {
|
if (!reducer_) {
|
||||||
|
|
|
@ -54,7 +54,6 @@ class ReduceCPUKernel : public ReduceBaseCPUKernel {
|
||||||
}
|
}
|
||||||
|
|
||||||
int Init() override;
|
int Init() override;
|
||||||
int ReSize() override;
|
|
||||||
int Run() override;
|
int Run() override;
|
||||||
int CallReduceUnit(int task_id);
|
int CallReduceUnit(int task_id);
|
||||||
|
|
||||||
|
|
|
@ -34,11 +34,11 @@ ops::PrimitiveC *OnnxReduceParser::Parse(const onnx::GraphProto &onnx_graph, con
|
||||||
for (int i = 0; i < size; ++i) {
|
for (int i = 0; i < size; ++i) {
|
||||||
axes.push_back(onnx_node_attr.ints(i));
|
axes.push_back(onnx_node_attr.ints(i));
|
||||||
}
|
}
|
||||||
|
prim->AddAttr("axes", MakeValue(axes));
|
||||||
} else if (attribute_name == "keepdims") {
|
} else if (attribute_name == "keepdims") {
|
||||||
prim->set_keep_dims(static_cast<bool>(onnx_node_attr.i()));
|
prim->set_keep_dims(static_cast<bool>(onnx_node_attr.i()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
prim->AddAttr("axes", MakeValue(axes));
|
|
||||||
|
|
||||||
const auto &type = onnx_node.op_type();
|
const auto &type = onnx_node.op_type();
|
||||||
if (type == "ReduceMean") {
|
if (type == "ReduceMean") {
|
||||||
|
|
Loading…
Reference in New Issue