forked from mindspore-Ecosystem/mindspore
!9220 [MS][LITE]add reduce_all mode
From: @YeFeng_24 Reviewed-by: @zhanghaibo5,@hangangqiang,@hangangqiang Signed-off-by:
This commit is contained in:
commit
af1022a587
|
@ -144,6 +144,29 @@ int IntReduceMin(int outer_size, int inner_size, int axis_size, const int *src_d
|
|||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ReduceAll(int outer_size, int inner_size, int axis_size, const bool *src_data, bool *dst_data, int tid,
|
||||
int thread_num) {
|
||||
if (src_data == NULL || dst_data == NULL) {
|
||||
return NNACL_NULL_PTR;
|
||||
}
|
||||
int i, j, k;
|
||||
for (j = tid; j < outer_size; j += thread_num) {
|
||||
const bool *outer_src = src_data + j * axis_size * inner_size;
|
||||
bool *outer_dst = dst_data + j * inner_size;
|
||||
for (k = 0; k < inner_size; k++) {
|
||||
const bool *inner_src = outer_src + k;
|
||||
bool *inner_dst = outer_dst + k;
|
||||
bool tmp = true;
|
||||
for (i = 0; i < axis_size; i++) {
|
||||
tmp = tmp && inner_src[i * inner_size];
|
||||
}
|
||||
*inner_dst = tmp;
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ReduceProd(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid,
|
||||
int thread_num) {
|
||||
if (src_data == NULL || dst_data == NULL) {
|
||||
|
|
|
@ -38,6 +38,8 @@ int IntReduceProd(int outer_size, int inner_size, int axis_size, const int *src_
|
|||
int thread_num);
|
||||
int ReduceSumSquare(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid,
|
||||
int thread_num);
|
||||
int ReduceAll(int outer_size, int inner_size, int axis_size, const bool *src_data, bool *dst_data, int tid,
|
||||
int thread_num);
|
||||
|
||||
#ifdef ENABLE_NNACL_INFER_SHAPE
|
||||
int ReduceInferShape(int **in_shape, size_t *dim_size, int *out_shape, int *in_format, int *out_format,
|
||||
|
|
|
@ -63,6 +63,7 @@ typedef enum LiteDataType {
|
|||
kDataTypeFloat,
|
||||
kDataTypeInt,
|
||||
kDataTypeInt8,
|
||||
KDataTypeBool,
|
||||
} LiteDataType;
|
||||
|
||||
typedef enum DataOrder {
|
||||
|
|
|
@ -765,7 +765,8 @@ enum ReduceMode : byte {
|
|||
ReduceProd = 3,
|
||||
ReduceSum = 4,
|
||||
ReduceSumSquare = 5,
|
||||
ReduceASum = 6
|
||||
ReduceASum = 6,
|
||||
ReduceAll = 7
|
||||
}
|
||||
|
||||
table Reduce {
|
||||
|
|
|
@ -67,6 +67,11 @@ int Reduce::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
|
|||
attr->mode = schema::ReduceMode_ReduceProd;
|
||||
} else if (prim.name() == "ReduceSumSquare") {
|
||||
attr->mode = schema::ReduceMode_ReduceSumSquare;
|
||||
} else if (prim.name() == "ReduceAll") {
|
||||
attr->mode = schema::ReduceMode_ReduceAll;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Not supported reduce mode: " << prim.name();
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
attr->keepDims = GetValue<bool>(prim.GetAttr("keep_dims"));
|
||||
|
|
|
@ -31,6 +31,7 @@ using mindspore::lite::RET_OK;
|
|||
using mindspore::schema::PrimitiveType_Mean;
|
||||
using mindspore::schema::PrimitiveType_Reduce;
|
||||
using mindspore::schema::ReduceMode;
|
||||
using mindspore::schema::ReduceMode_ReduceAll;
|
||||
using mindspore::schema::ReduceMode_ReduceASum;
|
||||
using mindspore::schema::ReduceMode_ReduceMax;
|
||||
using mindspore::schema::ReduceMode_ReduceMean;
|
||||
|
@ -78,6 +79,10 @@ int ReduceCPUKernel::Init() {
|
|||
reducer_ = ReduceSum;
|
||||
break;
|
||||
}
|
||||
case static_cast<int>(ReduceMode_ReduceAll): {
|
||||
bool_reducer_ = ReduceAll;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
MS_LOG(ERROR) << "Reduce unsupported reduce mode: " << mode_;
|
||||
return RET_ERROR;
|
||||
|
@ -96,6 +101,9 @@ int ReduceCPUKernel::CallReduceUnit(int task_id) {
|
|||
if (data_type_ == kDataTypeFloat) {
|
||||
ret = reducer_(outer_size_, inner_size_, axis_size_, static_cast<const float *>(src_data_),
|
||||
static_cast<float *>(dst_data_), task_id, context_->thread_num_);
|
||||
} else if (data_type_ == KDataTypeBool) {
|
||||
ret = bool_reducer_(outer_size_, inner_size_, axis_size_, static_cast<const bool *>(src_data_),
|
||||
static_cast<bool *>(dst_data_), task_id, context_->thread_num_);
|
||||
} else {
|
||||
ret = int_reducer_(outer_size_, inner_size_, axis_size_, static_cast<const int *>(src_data_),
|
||||
static_cast<int *>(dst_data_), task_id, context_->thread_num_);
|
||||
|
@ -117,6 +125,8 @@ int ReduceImpl(void *cdata, int task_id) {
|
|||
int ReduceCPUKernel::Run() {
|
||||
if (in_tensors().at(0)->data_type() == kNumberTypeFloat32) {
|
||||
data_type_ = kDataTypeFloat;
|
||||
} else if (in_tensors().at(0)->data_type() == kNumberTypeBool) {
|
||||
data_type_ = KDataTypeBool;
|
||||
} else {
|
||||
data_type_ = kDataTypeInt;
|
||||
}
|
||||
|
@ -202,6 +212,8 @@ int ReduceCPUKernel::MallocTmpBuffer() {
|
|||
void *buffer = nullptr;
|
||||
if (data_type_ == kDataTypeFloat) {
|
||||
buffer = context_->allocator->Malloc(size * sizeof(float));
|
||||
} else if (data_type_ == KDataTypeBool) {
|
||||
buffer = context_->allocator->Malloc(size * sizeof(bool));
|
||||
} else {
|
||||
buffer = context_->allocator->Malloc(size * sizeof(int));
|
||||
}
|
||||
|
|
|
@ -31,6 +31,8 @@ class ReduceCPUKernel : public ReduceBaseCPUKernel {
|
|||
float *dst_data, const int tid, const int thread_num);
|
||||
typedef int (*IntReducer)(const int outer_size, const int inner_size, const int axis_size, const int *src_data,
|
||||
int *dst_data, const int tid, const int thread_num);
|
||||
typedef int (*BoolReducer)(const int outer_size, const int inner_size, const int axis_size, const bool *src_data,
|
||||
bool *dst_data, const int tid, const int thread_num);
|
||||
|
||||
public:
|
||||
ReduceCPUKernel(OpParameter *param, const std::vector<lite::Tensor *> &inputs,
|
||||
|
@ -54,6 +56,7 @@ class ReduceCPUKernel : public ReduceBaseCPUKernel {
|
|||
private:
|
||||
ReduceParameter *reduce_param_;
|
||||
Reducer reducer_ = nullptr;
|
||||
BoolReducer bool_reducer_ = nullptr;
|
||||
IntReducer int_reducer_ = nullptr;
|
||||
std::vector<void *> data_buffers_;
|
||||
LiteDataType data_type_;
|
||||
|
|
|
@ -52,6 +52,8 @@ STATUS TFReduceParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
attr->mode = schema::ReduceMode_ReduceMean;
|
||||
} else if (tf_op.op() == "Prod") {
|
||||
attr->mode = schema::ReduceMode_ReduceProd;
|
||||
} else if (tf_op.op() == "All") {
|
||||
attr->mode = schema::ReduceMode_ReduceAll;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "unsupported reduce mode: " << tf_op.op();
|
||||
return RET_ERROR;
|
||||
|
@ -106,5 +108,6 @@ TFNodeRegistrar g_tfMaxParser("Max", new TFReduceParser());
|
|||
TFNodeRegistrar g_tfMinParser("Min", new TFReduceParser());
|
||||
TFNodeRegistrar g_tfMeanParser("Mean", new TFReduceParser());
|
||||
TFNodeRegistrar g_tfProdParser("Prod", new TFReduceParser());
|
||||
TFNodeRegistrar g_tfAllParser("All", new TFReduceParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue