!26319 [lite]fix codedex

Merge pull request !26319 from 徐安越/master
This commit is contained in:
i-robot 2021-11-16 02:54:36 +00:00 committed by Gitee
commit 50bbdfef17
3 changed files with 21 additions and 12 deletions

View File

@ -21,7 +21,7 @@
extern "C" {
#endif
#define DEFAULT_GROUP_NAME_LEN 100
#define DEFAULT_GROUP_NAME_LEN 101
#define DEFAULT_GROUP_SIZE 2
static inline int get_rank(char *group) { return DEFAULT_GROUP_SIZE; }

View File

@ -30,6 +30,15 @@ OpParameter *PopulateAllGatherParameter(const void *prim) {
MS_LOG(ERROR) << "cast all_gather_primitive to value failed";
return nullptr;
}
auto group = value->group();
if (group == nullptr) {
MS_LOG(ERROR) << "attr group must be not a nullptr.";
return nullptr;
}
if (group->size() >= DEFAULT_GROUP_NAME_LEN) {
MS_LOG(ERROR) << "group name size error: " << value->group()->size() << ", which is larger than 100.";
return nullptr;
}
auto *param = static_cast<AllGatherParameter *>(malloc(sizeof(AllGatherParameter)));
if (param == nullptr) {
@ -38,12 +47,7 @@ OpParameter *PopulateAllGatherParameter(const void *prim) {
}
memset(param, 0, sizeof(AllGatherParameter));
if (value->group()->size() > DEFAULT_GROUP_NAME_LEN) {
MS_LOG(ERROR) << "group name size error: " << value->group()->size();
return nullptr;
}
memcpy(param->group_, value->group()->c_str(), value->group()->size());
memcpy(param->group_, group->c_str(), group->size());
param->op_parameter_.type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);
}

View File

@ -31,6 +31,15 @@ OpParameter *PopulateReduceScatterParameter(const void *prim) {
MS_LOG(ERROR) << "cast reduce_scatter_primitive to value failed";
return nullptr;
}
auto group = value->group();
if (group == nullptr) {
MS_LOG(ERROR) << "attr group must be not a nullptr.";
return nullptr;
}
if (group->size() >= DEFAULT_GROUP_NAME_LEN) {
MS_LOG(ERROR) << "group name size error: " << value->group()->size() << ", which is larger than 100.";
return nullptr;
}
auto *param = static_cast<ReduceScatterParameter *>(malloc(sizeof(ReduceScatterParameter)));
if (param == nullptr) {
@ -39,11 +48,7 @@ OpParameter *PopulateReduceScatterParameter(const void *prim) {
}
memset(param, 0, sizeof(ReduceScatterParameter));
if (value->group()->size() > DEFAULT_GROUP_NAME_LEN) {
MS_LOG(ERROR) << "group name size error: " << value->group()->size();
return nullptr;
}
memcpy(param->group_, value->group()->c_str(), value->group()->size());
memcpy(param->group_, group->c_str(), group->size());
param->mode_ = value->mode();
param->op_parameter_.type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);