forked from mindspore-Ecosystem/mindspore
commit
50bbdfef17
|
@ -21,7 +21,7 @@
|
|||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define DEFAULT_GROUP_NAME_LEN 100
|
||||
#define DEFAULT_GROUP_NAME_LEN 101
|
||||
#define DEFAULT_GROUP_SIZE 2
|
||||
|
||||
static inline int get_rank(char *group) { return DEFAULT_GROUP_SIZE; }
|
||||
|
|
|
@ -30,6 +30,15 @@ OpParameter *PopulateAllGatherParameter(const void *prim) {
|
|||
MS_LOG(ERROR) << "cast all_gather_primitive to value failed";
|
||||
return nullptr;
|
||||
}
|
||||
auto group = value->group();
|
||||
if (group == nullptr) {
|
||||
MS_LOG(ERROR) << "attr group must be not a nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
if (group->size() >= DEFAULT_GROUP_NAME_LEN) {
|
||||
MS_LOG(ERROR) << "group name size error: " << value->group()->size() << ", which is larger than 100.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto *param = static_cast<AllGatherParameter *>(malloc(sizeof(AllGatherParameter)));
|
||||
if (param == nullptr) {
|
||||
|
@ -38,12 +47,7 @@ OpParameter *PopulateAllGatherParameter(const void *prim) {
|
|||
}
|
||||
memset(param, 0, sizeof(AllGatherParameter));
|
||||
|
||||
if (value->group()->size() > DEFAULT_GROUP_NAME_LEN) {
|
||||
MS_LOG(ERROR) << "group name size error: " << value->group()->size();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
memcpy(param->group_, value->group()->c_str(), value->group()->size());
|
||||
memcpy(param->group_, group->c_str(), group->size());
|
||||
param->op_parameter_.type_ = primitive->value_type();
|
||||
return reinterpret_cast<OpParameter *>(param);
|
||||
}
|
||||
|
|
|
@ -31,6 +31,15 @@ OpParameter *PopulateReduceScatterParameter(const void *prim) {
|
|||
MS_LOG(ERROR) << "cast reduce_scatter_primitive to value failed";
|
||||
return nullptr;
|
||||
}
|
||||
auto group = value->group();
|
||||
if (group == nullptr) {
|
||||
MS_LOG(ERROR) << "attr group must be not a nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
if (group->size() >= DEFAULT_GROUP_NAME_LEN) {
|
||||
MS_LOG(ERROR) << "group name size error: " << value->group()->size() << ", which is larger than 100.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto *param = static_cast<ReduceScatterParameter *>(malloc(sizeof(ReduceScatterParameter)));
|
||||
if (param == nullptr) {
|
||||
|
@ -39,11 +48,7 @@ OpParameter *PopulateReduceScatterParameter(const void *prim) {
|
|||
}
|
||||
memset(param, 0, sizeof(ReduceScatterParameter));
|
||||
|
||||
if (value->group()->size() > DEFAULT_GROUP_NAME_LEN) {
|
||||
MS_LOG(ERROR) << "group name size error: " << value->group()->size();
|
||||
return nullptr;
|
||||
}
|
||||
memcpy(param->group_, value->group()->c_str(), value->group()->size());
|
||||
memcpy(param->group_, group->c_str(), group->size());
|
||||
param->mode_ = value->mode();
|
||||
param->op_parameter_.type_ = primitive->value_type();
|
||||
return reinterpret_cast<OpParameter *>(param);
|
||||
|
|
Loading…
Reference in New Issue