From 6b6d2bd196f84a2ba88933b13b5890d4686b1f29 Mon Sep 17 00:00:00 2001 From: xuanyue Date: Mon, 15 Nov 2021 20:26:11 +0800 Subject: [PATCH] fix codedex --- .../cpu/nnacl/communication_func.h | 2 +- mindspore/lite/src/ops/populate/all_gather.cc | 16 ++++++++++------ .../lite/src/ops/populate/reduce_scatter.cc | 15 ++++++++++----- 3 files changed, 21 insertions(+), 12 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/communication_func.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/communication_func.h index 37ad240586b..5b723b73eac 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/communication_func.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/communication_func.h @@ -21,7 +21,7 @@ extern "C" { #endif -#define DEFAULT_GROUP_NAME_LEN 100 +#define DEFAULT_GROUP_NAME_LEN 101 #define DEFAULT_GROUP_SIZE 2 static inline int get_rank(char *group) { return DEFAULT_GROUP_SIZE; } diff --git a/mindspore/lite/src/ops/populate/all_gather.cc b/mindspore/lite/src/ops/populate/all_gather.cc index 2c5cdd33018..ca3d5e11c9e 100644 --- a/mindspore/lite/src/ops/populate/all_gather.cc +++ b/mindspore/lite/src/ops/populate/all_gather.cc @@ -30,6 +30,15 @@ OpParameter *PopulateAllGatherParameter(const void *prim) { MS_LOG(ERROR) << "cast all_gather_primitive to value failed"; return nullptr; } + auto group = value->group(); + if (group == nullptr) { + MS_LOG(ERROR) << "attr group must be not a nullptr."; + return nullptr; + } + if (group->size() >= DEFAULT_GROUP_NAME_LEN) { + MS_LOG(ERROR) << "group name size error: " << value->group()->size() << ", which is larger than 100."; + return nullptr; + } auto *param = static_cast(malloc(sizeof(AllGatherParameter))); if (param == nullptr) { @@ -38,12 +47,7 @@ OpParameter *PopulateAllGatherParameter(const void *prim) { } memset(param, 0, sizeof(AllGatherParameter)); - if (value->group()->size() > DEFAULT_GROUP_NAME_LEN) { - MS_LOG(ERROR) << "group name size error: " << value->group()->size(); - return nullptr; - } - - memcpy(param->group_, value->group()->c_str(), value->group()->size()); + memcpy(param->group_, group->c_str(), group->size()); param->op_parameter_.type_ = primitive->value_type(); return reinterpret_cast(param); } diff --git a/mindspore/lite/src/ops/populate/reduce_scatter.cc b/mindspore/lite/src/ops/populate/reduce_scatter.cc index 70bc489ff9d..e785b33c6e8 100644 --- a/mindspore/lite/src/ops/populate/reduce_scatter.cc +++ b/mindspore/lite/src/ops/populate/reduce_scatter.cc @@ -31,6 +31,15 @@ OpParameter *PopulateReduceScatterParameter(const void *prim) { MS_LOG(ERROR) << "cast reduce_scatter_primitive to value failed"; return nullptr; } + auto group = value->group(); + if (group == nullptr) { + MS_LOG(ERROR) << "attr group must be not a nullptr."; + return nullptr; + } + if (group->size() >= DEFAULT_GROUP_NAME_LEN) { + MS_LOG(ERROR) << "group name size error: " << value->group()->size() << ", which is larger than 100."; + return nullptr; + } auto *param = static_cast(malloc(sizeof(ReduceScatterParameter))); if (param == nullptr) { @@ -39,11 +48,7 @@ OpParameter *PopulateReduceScatterParameter(const void *prim) { } memset(param, 0, sizeof(ReduceScatterParameter)); - if (value->group()->size() > DEFAULT_GROUP_NAME_LEN) { - MS_LOG(ERROR) << "group name size error: " << value->group()->size(); - return nullptr; - } - memcpy(param->group_, value->group()->c_str(), value->group()->size()); + memcpy(param->group_, group->c_str(), group->size()); param->mode_ = value->mode(); param->op_parameter_.type_ = primitive->value_type(); return reinterpret_cast(param);