!49752 fix_opt_shard_and_pipeline_allreduce_hang

Merge pull request !49752 from yao_yf/fix_opt_shard_and_pipeline_allreduce_hang
This commit is contained in:
huangxinjing 2023-03-06 02:41:54 +00:00 committed by Gitee
commit a04b337d78
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
34 changed files with 576 additions and 110 deletions

View File

@ -283,3 +283,4 @@
{"op_name": "QuantDTypeCast", "inputs": [{"index": 0, "name": "x", "param_type": "required"},{"index": 1, "name": "scales", "param_type": "required"},{"index": 2, "name": "zps", "param_type": "required"},{"index": 3, "name": "mean_corrs", "param_type": "required"},{"index": 4, "name": "var_corr", "param_type": "required"}], "outputs": [{"index": 0, "name": "y", "param_type": "required"}], "attr": [{"name": "src_t", "type": "int"},{"name": "dst_t", "type": "int"},{"name": "axis", "type": "int"}], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]],[["int8", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float16", "FRACTAL_NZ"]]], "imply_type": "AiCPU"} {"op_name": "QuantDTypeCast", "inputs": [{"index": 0, "name": "x", "param_type": "required"},{"index": 1, "name": "scales", "param_type": "required"},{"index": 2, "name": "zps", "param_type": "required"},{"index": 3, "name": "mean_corrs", "param_type": "required"},{"index": 4, "name": "var_corr", "param_type": "required"}], "outputs": [{"index": 0, "name": "y", "param_type": "required"}], "attr": [{"name": "src_t", "type": "int"},{"name": "dst_t", "type": "int"},{"name": "axis", "type": "int"}], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]],[["int8", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float16", "FRACTAL_NZ"]]], "imply_type": "AiCPU"}
{"op_name": "FSEDecode", "inputs": [{"index": 0, "name": "x", "param_type": "required"},{"index": 1, "name": "states_table", "param_type": "required"},{"index": 2, "name": "bit_count_table", "param_type": "required"},{"index": 3, "name": "symbol_table", "param_type": "required"},{"index": 4, "name": "centroids", "param_type": "required"},{"index": 5, "name": "input_shape", "param_type": "required"}], "outputs": [{"index": 0, "name": "y", "param_type": "required"}], "attr": [{"name": "dst_t", "type": "int"}, {"name": "curr_chunk", "type": "int"}, {"name": "curr_chunk_index", "type": "int"}, {"name": "curr_bit_count", "type": "int"}, {"name": "table_log", "type": "int"}], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "DefaultFormat"], ["uint16", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint16", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]],[["int8", "DefaultFormat"], ["uint16", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint16", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "FRACTAL_NZ"]]], "imply_type": "AiCPU"} {"op_name": "FSEDecode", "inputs": [{"index": 0, "name": "x", "param_type": "required"},{"index": 1, "name": "states_table", "param_type": "required"},{"index": 2, "name": "bit_count_table", "param_type": "required"},{"index": 3, "name": "symbol_table", "param_type": "required"},{"index": 4, "name": "centroids", "param_type": "required"},{"index": 5, "name": "input_shape", "param_type": "required"}], "outputs": [{"index": 0, "name": "y", "param_type": "required"}], "attr": [{"name": "dst_t", "type": "int"}, {"name": "curr_chunk", "type": "int"}, {"name": "curr_chunk_index", "type": "int"}, {"name": "curr_bit_count", "type": "int"}, {"name": "table_log", "type": "int"}], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "DefaultFormat"], ["uint16", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint16", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]],[["int8", "DefaultFormat"], ["uint16", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint16", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "FRACTAL_NZ"]]], "imply_type": "AiCPU"}
{"op_name": "AssignAdd", "inputs": [{"index": 0, "name": "ref", "needCompile": false, "paramType": "required", "shape": "all"}, {"index": 1, "name": "value", "needCompile": false, "paramType": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "ref", "need_compile": false, "paramType": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int8", "NC1HWC0"], ["int8", "NC1HWC0"], ["int8", "NC1HWC0"]], [["int8", "C1HWNCoC0"], ["int8", "C1HWNCoC0"], ["int8", "C1HWNCoC0"]], [["int8", "FRACTAL_Z"], ["int8", "FRACTAL_Z"], ["int8", "FRACTAL_Z"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint8", "NC1HWC0"], ["uint8", "NC1HWC0"], ["uint8", "NC1HWC0"]], [["uint8", "C1HWNCoC0"], ["uint8", "C1HWNCoC0"], ["uint8", "C1HWNCoC0"]], [["uint8", "FRACTAL_Z"], ["uint8", "FRACTAL_Z"], ["uint8", "FRACTAL_Z"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"]], [["int32", "C1HWNCoC0"], ["int32", "C1HWNCoC0"], ["int32", "C1HWNCoC0"]], [["int32", "FRACTAL_Z"], ["int32", "FRACTAL_Z"], ["int32", "FRACTAL_Z"]], [["int64", "DefaultFormat"], ["int64", "DefaultFormat"], ["int64", "DefaultFormat"]], [["int64", "NC1HWC0"], ["int64", "NC1HWC0"], ["int64", "NC1HWC0"]], [["int64", "C1HWNCoC0"], ["int64", "C1HWNCoC0"], ["int64", "C1HWNCoC0"]], [["int64", "FRACTAL_Z"], ["int64", "FRACTAL_Z"], ["int64", "FRACTAL_Z"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "FRACTAL_Z"], ["float16", "FRACTAL_Z"], ["float16", "FRACTAL_Z"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "FRACTAL_Z"], ["float32", "FRACTAL_Z"], ["float32", "FRACTAL_Z"]]], "imply_type": "TBE", "async_flag": false, "binfile": "assign_add.so", "compute_cost": 10, "kernel": "assign_add", "partial_flag": true, "reshape_type": "", "dynamicRankSupport": false, "dynamicShapeSupport": true, "dynamicCompileStatic": true, "needCheckSupport": false, "dynamicFormat": false, "op_pattern": "", "real_input_index": [], "input_to_attr_index": [], "unknown_shape_formats": []} {"op_name": "AssignAdd", "inputs": [{"index": 0, "name": "ref", "needCompile": false, "paramType": "required", "shape": "all"}, {"index": 1, "name": "value", "needCompile": false, "paramType": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "ref", "need_compile": false, "paramType": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int8", "NC1HWC0"], ["int8", "NC1HWC0"], ["int8", "NC1HWC0"]], [["int8", "C1HWNCoC0"], ["int8", "C1HWNCoC0"], ["int8", "C1HWNCoC0"]], [["int8", "FRACTAL_Z"], ["int8", "FRACTAL_Z"], ["int8", "FRACTAL_Z"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint8", "NC1HWC0"], ["uint8", "NC1HWC0"], ["uint8", "NC1HWC0"]], [["uint8", "C1HWNCoC0"], ["uint8", "C1HWNCoC0"], ["uint8", "C1HWNCoC0"]], [["uint8", "FRACTAL_Z"], ["uint8", "FRACTAL_Z"], ["uint8", "FRACTAL_Z"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"]], [["int32", "C1HWNCoC0"], ["int32", "C1HWNCoC0"], ["int32", "C1HWNCoC0"]], [["int32", "FRACTAL_Z"], ["int32", "FRACTAL_Z"], ["int32", "FRACTAL_Z"]], [["int64", "DefaultFormat"], ["int64", "DefaultFormat"], ["int64", "DefaultFormat"]], [["int64", "NC1HWC0"], ["int64", "NC1HWC0"], ["int64", "NC1HWC0"]], [["int64", "C1HWNCoC0"], ["int64", "C1HWNCoC0"], ["int64", "C1HWNCoC0"]], [["int64", "FRACTAL_Z"], ["int64", "FRACTAL_Z"], ["int64", "FRACTAL_Z"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "FRACTAL_Z"], ["float16", "FRACTAL_Z"], ["float16", "FRACTAL_Z"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "FRACTAL_Z"], ["float32", "FRACTAL_Z"], ["float32", "FRACTAL_Z"]]], "imply_type": "TBE", "async_flag": false, "binfile": "assign_add.so", "compute_cost": 10, "kernel": "assign_add", "partial_flag": true, "reshape_type": "", "dynamicRankSupport": false, "dynamicShapeSupport": true, "dynamicCompileStatic": true, "needCheckSupport": false, "dynamicFormat": false, "op_pattern": "", "real_input_index": [], "input_to_attr_index": [], "unknown_shape_formats": []}
{'op_name': 'AtomicAddrClean', 'inputs': [], 'outputs': [], 'attr': [{'name': 'automic_add_mem_size', 'paramType': 'required', 'type': 'listInt64', 'value': 'all'}], 'fusion_type': 'ELEMWISE', 'dtype_format': [], 'imply_type': 'TBE', 'async_flag': False, 'binfile': 'atomic_addr_clean.so', 'compute_cost': 10, 'kernel': 'atomic_addr_clean', 'partial_flag': True, 'reshape_type': '', 'dynamicRankSupport': False, 'dynamicShapeSupport': False, 'dynamicCompileStatic': False, 'needCheckSupport': False, 'dynamicFormat': False, 'op_pattern': '', 'real_input_index': [], 'input_to_attr_index': [], 'unknown_shape_formats': []}

View File

@ -415,7 +415,8 @@
"TransData ": "support boll", "TransData ": "support boll",
"ScatterNdD ": "Accuracy issues", "ScatterNdD ": "Accuracy issues",
"Trace": "Hadn't adapted tbe implementation", "Trace": "Hadn't adapted tbe implementation",
"AssignAdd": "Frac_nz in pangu not support" "AssignAdd": "Frac_nz in pangu not support",
"AtomicAddrClean": "need to clean addr larger than 2G, int32 is not enough"
}, },
"SkipNodes": [ "SkipNodes": [
"BroadcastTo", "BroadcastTo",
@ -444,7 +445,9 @@
"ACos", "ACos",
"TransData", "TransData",
"ScatterNdD", "ScatterNdD",
"AssignAdd" "AssignAdd",
"Assign",
"AtomicAddrClean"
], ],
"FallbackOps": { "FallbackOps": {
"DeformableOffsets": [ "DeformableOffsets": [
@ -452,4 +455,4 @@
2 2
] ]
} }
} }

View File

@ -8,10 +8,12 @@
.. code-block:: .. code-block::
n02119789 0 n01440764 0
n02100735 1 n01443537 1
n02110185 2 n01484850 2
n02096294 3 n01491361 3
...
n15075141 999
- **image_dir** (str) - ImageNet数据集的目录路径目录中包含类似n02119789、n02100735、n02110185和n02096294的子目录。 - **image_dir** (str) - ImageNet数据集的目录路径目录中包含类似n02119789、n02100735、n02110185和n02096294的子目录。
- **destination** (str) - 转换生成的MindRecord文件路径需提前创建目录并且目录下不能存在同名文件。 - **destination** (str) - 转换生成的MindRecord文件路径需提前创建目录并且目录下不能存在同名文件。

View File

@ -104,6 +104,10 @@ void SetStridedSliceStrategy(const AnfNodePtr &node) {
if (skip_redis && !full_batch && input_strategy.size() > 0) { if (skip_redis && !full_batch && input_strategy.size() > 0) {
input_strategy[0] = dev_num < shape_list[1][0][0] ? dev_num : shape_list[1][0][0]; input_strategy[0] = dev_num < shape_list[1][0][0] ? dev_num : shape_list[1][0][0];
auto prim = GetCNodePrimitive(node); auto prim = GetCNodePrimitive(node);
if (prim->HasAttr("out_shard_size")) {
auto out_shard_size = GetValue<int64_t>(prim->GetAttr("out_shard_size"));
input_strategy[0] = out_shard_size;
}
auto attrs = prim->attrs(); auto attrs = prim->attrs();
attrs[parallel::SKIP_REDISTRIBUTION] = MakeValue<bool>(true); attrs[parallel::SKIP_REDISTRIBUTION] = MakeValue<bool>(true);
prim->SetAttrs(attrs); prim->SetAttrs(attrs);

View File

@ -354,6 +354,7 @@ Status GatherInfo::CheckStrategy(const StrategyPtr &strategy) {
// parameter not split axis // parameter not split axis
if (param_strategy.at(LongToSize(axis_)) == 1) { if (param_strategy.at(LongToSize(axis_)) == 1) {
SetAttribute(strategy);
return SUCCESS; return SUCCESS;
} }

View File

@ -456,21 +456,11 @@ void AddCommOpMeanFlag(const CNodePtr &comm_node) {
(void)prim->SetAttrs(attrs); (void)prim->SetAttrs(attrs);
} }
void AddCommOpMirrorFlag(const CNodePtr &comm_node, bool do_mirror) { void AddCNodePrimAttr(const CNodePtr &comm_node, const std::string &attr_name, const ValuePtr &attr_val) {
MS_EXCEPTION_IF_NULL(comm_node); MS_EXCEPTION_IF_NULL(comm_node);
auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0)); auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
auto attrs = prim->attrs(); auto attrs = prim->attrs();
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); attrs[attr_name] = attr_val;
attrs[DO_MIRROR] = MakeValue<bool>(do_mirror);
(void)prim->SetAttrs(attrs);
}
void AddCommOpAddAccuFlag(const CNodePtr &comm_node, bool add_accu) {
MS_EXCEPTION_IF_NULL(comm_node);
auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
auto attrs = prim->attrs();
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
attrs[ADD_ACCU] = MakeValue<bool>(add_accu);
(void)prim->SetAttrs(attrs); (void)prim->SetAttrs(attrs);
} }

View File

@ -351,9 +351,8 @@ Operator CreateAllGatherOp(const std::string &group);
Operator CreateCastOp(TypePtr type); Operator CreateCastOp(TypePtr type);
Operator CreateDivOp(float scale); Operator CreateDivOp(float scale);
Operator CreateMiniStepAllGatherOp(const std::string &group); Operator CreateMiniStepAllGatherOp(const std::string &group);
void AddCNodePrimAttr(const CNodePtr &comm_node, const std::string &attr_name, const ValuePtr &attr_val);
int32_t AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node); int32_t AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node);
void AddCommOpMirrorFlag(const CNodePtr &comm_node, bool do_mirror);
void AddCommOpAddAccuFlag(const CNodePtr &comm_node, bool add_accu);
Operator CreateMicroStepAllGatherOp(const std::string &group); Operator CreateMicroStepAllGatherOp(const std::string &group);
void AddCommOpMeanFlag(const CNodePtr &comm_node); void AddCommOpMeanFlag(const CNodePtr &comm_node);
void AddCommOpParamFlag(const CNodePtr &comm_node); void AddCommOpParamFlag(const CNodePtr &comm_node);

View File

@ -117,9 +117,9 @@ class OptParamMgrImpl : public OptParamMgr {
return false; return false;
} }
if (!ParameterRequireGrad(parameter)) { auto param_ptr = parameter->cast<ParameterPtr>();
// only trainable parameters need parallel optimizer if ((!param_ptr) || (!param_ptr->has_default())) {
MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not trainable parameter."; MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not a parameter.";
return false; return false;
} }

View File

@ -396,10 +396,13 @@ void SliceParameterObj(const ParameterPtr &parameter, const TensorLayoutPtr &ten
// create python layout obj // create python layout obj
const auto &device_arrangement = tensor_layout->device_arrangement().array(); const auto &device_arrangement = tensor_layout->device_arrangement().array();
const auto &tensor_map = tensor_layout->tensor_map().array(); const auto &tensor_map = tensor_layout->tensor_map().array();
const auto &slice_shape = tensor_layout->slice_shape().array(); auto slice_shape = tensor_layout->slice_shape().array();
int64_t field_size = tensor_layout->get_field_size(); int64_t field_size = tensor_layout->get_field_size();
bool uniform_split = tensor_layout->uniform_split(); bool uniform_split = tensor_layout->uniform_split();
std::string opt_shard_group = tensor_layout->opt_shard_group(); std::string opt_shard_group = tensor_layout->opt_shard_group();
if (!opt_shard_group.empty()) {
slice_shape = tensor_layout->opt_shard_slice_shape();
}
py::tuple layout = py::tuple layout =
py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group); py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group);

View File

@ -1271,6 +1271,7 @@ static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group
if (param_ptr->user_data<TensorLayout>()) { if (param_ptr->user_data<TensorLayout>()) {
opt_shard_mirror_group = param_ptr->user_data<TensorLayout>()->opt_shard_mirror_group(); opt_shard_mirror_group = param_ptr->user_data<TensorLayout>()->opt_shard_mirror_group();
} }
bool is_with_mirror = !opt_shard_mirror_group.empty();
if (!is_shared_param && cast_node) { if (!is_shared_param && cast_node) {
allgather = ReplaceNode(op, cast_node, graph, PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE, param_name, root); allgather = ReplaceNode(op, cast_node, graph, PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE, param_name, root);
MS_LOG(INFO) << "Parallel optimizer is applied before Cast for " << param_name; MS_LOG(INFO) << "Parallel optimizer is applied before Cast for " << param_name;
@ -1294,16 +1295,16 @@ static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group
AddNodeFusionInfo(cnode, allgather, "reduce_scatter", fusion_id); AddNodeFusionInfo(cnode, allgather, "reduce_scatter", fusion_id);
// add gradients mean // add gradients mean
AddCommOpMeanFlag(allgather); AddCommOpMeanFlag(allgather);
AddCNodePrimAttr(allgather, "with_mirror_operator", MakeValue<bool>(is_with_mirror));
if (op_name == MICRO_STEP_ALL_GATHER) { if (op_name == MICRO_STEP_ALL_GATHER) {
// When grad_accumulation_shard is enabled, the ReduceScatter is inserted at each micro step // When grad_accumulation_shard is enabled, the ReduceScatter is inserted at each micro step
// so no need to do backward for the micro_step_allgather // so no need to do backward for the micro_step_allgather
AddCommOpMirrorFlag(allgather, !grad_accumulation_shard); AddCNodePrimAttr(allgather, DO_MIRROR, MakeValue<bool>(!grad_accumulation_shard));
} else if (op_name == MINI_STEP_ALL_GATHER) { } else if (op_name == MINI_STEP_ALL_GATHER) {
// We need to manually set the add_accu to be false if it's father node is MirrorMiniStep // We need to manually set the add_accu to be false if it's father node is MirrorMiniStep
bool add_accu = root->has_flag(kAccumulation); bool add_accu = root->has_flag(kAccumulation);
bool is_with_mirror = opt_shard_mirror_group.size() > 1; AddCNodePrimAttr(allgather, ADD_ACCU, MakeValue<bool>(!add_accu && !is_with_mirror));
AddCommOpAddAccuFlag(allgather, !add_accu && !is_with_mirror); AddCNodePrimAttr(allgather, DO_MIRROR, MakeValue<bool>(!grad_accumulation_shard || !add_accu));
AddCommOpMirrorFlag(allgather, grad_accumulation_shard || !add_accu);
} }
} }
@ -1311,17 +1312,20 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
const std::string &opt_shard_group) { const std::string &opt_shard_group) {
int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num(); int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
auto enable_opt_shard = ParallelContext::GetInstance()->enable_parallel_optimizer(); auto enable_opt_shard = ParallelContext::GetInstance()->enable_parallel_optimizer();
if ((opt_shard_group.empty() && split_stage_num <= 1) || (!enable_opt_shard) || (!ParameterRequireGrad(parameter))) { if ((opt_shard_group.empty() && split_stage_num <= 1) || (!enable_opt_shard)) {
return; return;
} }
if (opt_shard_group.empty() && !ParameterRequireGrad(parameter)) {
return;
}
// set all gather type // set all gather type
MS_EXCEPTION_IF_NULL(parameter); MS_EXCEPTION_IF_NULL(parameter);
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
std::string op_name; std::string op_name;
if (grad_accumulation_step > 1) { if (grad_accumulation_step > 1) {
op_name = MINI_STEP_ALL_GATHER; op_name = MINI_STEP_ALL_GATHER;
} else if (split_stage_num > 1) { } else if (split_stage_num > 1 && ParameterRequireGrad(parameter)) {
op_name = MICRO_STEP_ALL_GATHER; op_name = MICRO_STEP_ALL_GATHER;
} else { } else {
op_name = ALL_GATHER; op_name = ALL_GATHER;

View File

@ -129,8 +129,9 @@ void BindShardWriter(py::module *m) {
return SUCCESS; return SUCCESS;
}) })
.def("write_raw_data", .def("write_raw_data",
[](ShardWriter &s, std::map<uint64_t, std::vector<py::handle>> &raw_data, vector<vector<uint8_t>> &blob_data, [](ShardWriter &s, std::map<uint64_t, std::vector<py::handle>> &raw_data, vector<py::bytes> &blob_data,
bool sign, bool parallel_writer) { bool sign, bool parallel_writer) {
// convert the raw_data from dict to json
std::map<uint64_t, std::vector<json>> raw_data_json; std::map<uint64_t, std::vector<json>> raw_data_json;
(void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()), (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()),
[](const std::pair<uint64_t, std::vector<py::handle>> &p) { [](const std::pair<uint64_t, std::vector<py::handle>> &p) {
@ -141,7 +142,54 @@ void BindShardWriter(py::module *m) {
[](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); }); [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); });
return std::make_pair(p.first, std::move(json_raw_data)); return std::make_pair(p.first, std::move(json_raw_data));
}); });
THROW_IF_ERROR(s.WriteRawData(raw_data_json, blob_data, sign, parallel_writer));
// parallel convert blob_data from vector<py::bytes> to vector<vector<uint8_t>>
int32_t parallel_convert = kParallelConvert;
if (parallel_convert > blob_data.size()) {
parallel_convert = blob_data.size();
}
parallel_convert = parallel_convert != 0 ? parallel_convert : 1;
std::vector<std::thread> thread_set(parallel_convert);
vector<vector<uint8_t>> vector_blob_data(blob_data.size());
uint32_t step = uint32_t(blob_data.size() / parallel_convert);
if (blob_data.size() % parallel_convert != 0) {
step = step + 1;
}
for (int x = 0; x < parallel_convert; ++x) {
uint32_t start = x * step;
uint32_t end = ((x + 1) * step) < blob_data.size() ? ((x + 1) * step) : blob_data.size();
thread_set[x] = std::thread([&vector_blob_data, &blob_data, start, end]() {
for (auto i = start; i < end; i++) {
char *buffer = nullptr;
ssize_t length = 0;
if (PYBIND11_BYTES_AS_STRING_AND_SIZE(blob_data[i].ptr(), &buffer, &length)) {
MS_LOG(ERROR) << "Unable to extract bytes contents!";
return FAILED;
}
vector<uint8_t> blob_data_item(length);
if (length < SECUREC_MEM_MAX_LEN) {
int ret_code = memcpy_s(&blob_data_item[0], length, buffer, length);
if (ret_code != EOK) {
MS_LOG(ERROR) << "memcpy_s failed for py::bytes to vector<uint8_t>.";
return FAILED;
}
} else {
auto ret_code = std::memcpy(&blob_data_item[0], buffer, length);
if (ret_code != &blob_data_item[0]) {
MS_LOG(ERROR) << "memcpy failed for py::bytes to vector<uint8_t>.";
return FAILED;
}
}
vector_blob_data[i] = blob_data_item;
}
});
}
// wait for the threads join
for (int x = 0; x < parallel_convert; ++x) {
thread_set[x].join();
}
THROW_IF_ERROR(s.WriteRawData(raw_data_json, vector_blob_data, sign, parallel_writer));
return SUCCESS; return SUCCESS;
}) })
.def("commit", [](ShardWriter &s) { .def("commit", [](ShardWriter &s) {

View File

@ -160,6 +160,9 @@ const std::unordered_map<std::string, std::string> kTypesMap = {
/// \brief the max number of samples to enable lazy load /// \brief the max number of samples to enable lazy load
const uint32_t LAZY_LOAD_THRESHOLD = 5000000; const uint32_t LAZY_LOAD_THRESHOLD = 5000000;
/// \brief parallel convert from vector<py::bytes> to vector<vector<uint8_t>>
const uint32_t kParallelConvert = 4;
/// \brief split a string using a character /// \brief split a string using a character
/// \param[in] field target string /// \param[in] field target string
/// \param[in] separator a character for splitting /// \param[in] separator a character for splitting

View File

@ -1500,7 +1500,8 @@ static std::vector<ActionItem> CommonPipeline() {
auto parallel_mode = parallel_context->parallel_mode(); auto parallel_mode = parallel_context->parallel_mode();
const bool is_parallel_mode = const bool is_parallel_mode =
parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel; parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel;
if (!is_cluster_initialized && !is_parallel_mode && pipeline::GetJitLevel() != "O0") { static const auto combine_like_graphs = (common::GetEnv("COMBINE_LIKE_GRAPHS") == "1");
if (!is_cluster_initialized && (!is_parallel_mode || combine_like_graphs) && pipeline::GetJitLevel() != "O0") {
(void)actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs)); (void)actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
} }

View File

@ -567,7 +567,7 @@ CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr<session::Kerne
ValuePtr cond = MakeValue(condition); ValuePtr cond = MakeValue(condition);
common::AnfAlgo::SetNodeAttr(kAttrSwitchCondition, cond, stream_switch_app); common::AnfAlgo::SetNodeAttr(kAttrSwitchCondition, cond, stream_switch_app);
// set attr:data_type // set attr:data_type
int data_type = static_cast<int>(RT_SWITCH_INT64); int data_type = static_cast<int>(RT_SWITCH_INT32);
ValuePtr dt = MakeValue(data_type); ValuePtr dt = MakeValue(data_type);
common::AnfAlgo::SetNodeAttr(kAttrDataType, dt, stream_switch_app); common::AnfAlgo::SetNodeAttr(kAttrDataType, dt, stream_switch_app);
// set distinction label and graph id // set distinction label and graph id

View File

@ -239,11 +239,11 @@ bool IfAtomicOpNeedFusion(const size_t clean_total_num, const CNodePtr &first_no
return false; return false;
} }
std::vector<int32_t> GetClearSize(const CNodePtr &pre_node) { std::vector<int64_t> GetClearSize(const CNodePtr &pre_node) {
MS_EXCEPTION_IF_NULL(pre_node); MS_EXCEPTION_IF_NULL(pre_node);
auto kernel_mod = AnfAlgo::GetKernelMod(pre_node); auto kernel_mod = AnfAlgo::GetKernelMod(pre_node);
MS_EXCEPTION_IF_NULL(kernel_mod); MS_EXCEPTION_IF_NULL(kernel_mod);
std::vector<int32_t> clean_size_list; std::vector<int64_t> clean_size_list;
constexpr size_t kAlignBytes = 32 - 1; constexpr size_t kAlignBytes = 32 - 1;
// clean output // clean output
if (common::AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { if (common::AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
@ -251,7 +251,7 @@ std::vector<int32_t> GetClearSize(const CNodePtr &pre_node) {
auto output_men_size = kernel_mod->GetOutputSizeList(); auto output_men_size = kernel_mod->GetOutputSizeList();
for (auto index : output_indexes) { for (auto index : output_indexes) {
auto clean_item = auto clean_item =
SizeToInt((output_men_size.at(index) + kMemAlignSize + kAlignBytes) / kMemAlignSize * kMemAlignSize); SizeToLong((output_men_size.at(index) + kMemAlignSize + kAlignBytes) / kMemAlignSize * kMemAlignSize);
(void)clean_size_list.emplace_back(clean_item); (void)clean_size_list.emplace_back(clean_item);
} }
} }
@ -261,7 +261,7 @@ std::vector<int32_t> GetClearSize(const CNodePtr &pre_node) {
auto workspace_men_sizes = kernel_mod->GetWorkspaceSizeList(); auto workspace_men_sizes = kernel_mod->GetWorkspaceSizeList();
for (const auto &index : workspace_indexes) { for (const auto &index : workspace_indexes) {
auto clean_item = auto clean_item =
SizeToInt((workspace_men_sizes.at(index) + kMemAlignSize + kAlignBytes) / kMemAlignSize * kMemAlignSize); SizeToLong((workspace_men_sizes.at(index) + kMemAlignSize + kAlignBytes) / kMemAlignSize * kMemAlignSize);
(void)clean_size_list.emplace_back(clean_item); (void)clean_size_list.emplace_back(clean_item);
} }
} }
@ -303,7 +303,7 @@ CNodePtr NewAtomicOp(const CNodePtr &pre_node, const std::vector<AnfNodePtr> &fu
} }
void InsertFusionAtomicOp(const CNodePtr &first_clear_node, const std::vector<AnfNodePtr> &fusion_clear_inputs, void InsertFusionAtomicOp(const CNodePtr &first_clear_node, const std::vector<AnfNodePtr> &fusion_clear_inputs,
const std::vector<int32_t> &clean_size_list, CleanOpsMap *clean_ops) { const std::vector<int64_t> &clean_size_list, CleanOpsMap *clean_ops) {
MS_EXCEPTION_IF_NULL(first_clear_node); MS_EXCEPTION_IF_NULL(first_clear_node);
MS_EXCEPTION_IF_NULL(clean_ops); MS_EXCEPTION_IF_NULL(clean_ops);
auto clear_zero = NewAtomicOp(first_clear_node, fusion_clear_inputs); auto clear_zero = NewAtomicOp(first_clear_node, fusion_clear_inputs);
@ -355,7 +355,7 @@ void SpecialAkgOps(const std::string &op_name, const CNodePtr &node, CleanOpsMap
void ProcessAtomicFusion(const std::vector<CNodePtr> &kernels, CleanOpsMap *clean_ops) { void ProcessAtomicFusion(const std::vector<CNodePtr> &kernels, CleanOpsMap *clean_ops) {
MS_EXCEPTION_IF_NULL(clean_ops); MS_EXCEPTION_IF_NULL(clean_ops);
std::vector<int32_t> clean_size_list; std::vector<int64_t> clean_size_list;
std::vector<AnfNodePtr> fusion_clear_inputs; std::vector<AnfNodePtr> fusion_clear_inputs;
CNodePtr first_node = nullptr; CNodePtr first_node = nullptr;
for (const auto &anf_node : kernels) { for (const auto &anf_node : kernels) {

View File

@ -153,7 +153,7 @@ void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressP
MS_LOG(DEBUG) << "AtomicAddClean clean workspace size:" << clean_workspace_indexs.size(); MS_LOG(DEBUG) << "AtomicAddClean clean workspace size:" << clean_workspace_indexs.size();
} }
} }
auto clear_mems = common::AnfAlgo::GetNodeAttr<std::vector<int32_t>>(anf_node_ptr, kAttrAtomicAddMemSize); auto clear_mems = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(anf_node_ptr, kAttrAtomicAddMemSize);
if (kernel_inputs->size() != clear_mems.size()) { if (kernel_inputs->size() != clear_mems.size()) {
MS_LOG(EXCEPTION) << "AtomicAddClean kernel inputs size not equal clear memory size, kernel inputs size:" MS_LOG(EXCEPTION) << "AtomicAddClean kernel inputs size not equal clear memory size, kernel inputs size:"
<< kernel_inputs->size() << ",clean mem size" << clear_mems.size(); << kernel_inputs->size() << ",clean mem size" << clear_mems.size();

View File

@ -73,7 +73,7 @@ std::vector<TaskInfoPtr> LabelSwitchKernel::GenTask(const std::vector<AddressPtr
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> LabelSwitchDesc::GetKernelInfo(const CNodePtr &) { std::vector<std::shared_ptr<kernel::KernelBuildInfo>> LabelSwitchDesc::GetKernelInfo(const CNodePtr &) {
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> label_switch_build_info{}; std::vector<std::shared_ptr<kernel::KernelBuildInfo>> label_switch_build_info{};
std::vector<string> input_format{kOpFormat_DEFAULT}; std::vector<string> input_format{kOpFormat_DEFAULT};
std::vector<TypeId> input_type{kNumberTypeInt32}; std::vector<TypeId> input_type{kNumberTypeUInt64};
if (input_format.size() != input_type.size()) { if (input_format.size() != input_type.size()) {
MS_LOG(EXCEPTION) << "Invalid param num, input_format size " << input_format.size() << " input_type size " MS_LOG(EXCEPTION) << "Invalid param num, input_format size " << input_format.size() << " input_type size "
<< input_type.size(); << input_type.size();

View File

@ -48,6 +48,7 @@ static std::unordered_map<std::string, ATTR_DTYPE> type_attr_dtype_map = {
{kVTypeFloat, ATTR_DTYPE::ATTR_FLOAT32}, {kVTypeFloat, ATTR_DTYPE::ATTR_FLOAT32},
{kVTypeListInt, ATTR_DTYPE::ATTR_LIST_INT32}, {kVTypeListInt, ATTR_DTYPE::ATTR_LIST_INT32},
{kVTypeListFloat, ATTR_DTYPE::ATTR_LIST_FLOAT32}, {kVTypeListFloat, ATTR_DTYPE::ATTR_LIST_FLOAT32},
{kVTypeListInt64, ATTR_DTYPE::ATTR_LIST_INT64},
{kVTypeListUInt64, ATTR_DTYPE::ATTR_LIST_UINT64}, {kVTypeListUInt64, ATTR_DTYPE::ATTR_LIST_UINT64},
{kVTypeListListInt, ATTR_DTYPE::ATTR_LIST_LIST_INT64}}; {kVTypeListListInt, ATTR_DTYPE::ATTR_LIST_LIST_INT64}};
@ -181,6 +182,7 @@ bool ParseAttrValue(const std::string &type, const mindspore::ValuePtr &value, n
case ATTR_DTYPE::ATTR_FLOAT32: case ATTR_DTYPE::ATTR_FLOAT32:
return ParseAttrFloat(value, attr_obj); return ParseAttrFloat(value, attr_obj);
case ATTR_DTYPE::ATTR_LIST_INT32: case ATTR_DTYPE::ATTR_LIST_INT32:
case ATTR_DTYPE::ATTR_LIST_INT64:
return ParseAttrListInt(value, attr_obj); return ParseAttrListInt(value, attr_obj);
case ATTR_DTYPE::ATTR_LIST_FLOAT32: case ATTR_DTYPE::ATTR_LIST_FLOAT32:
return ParseAttrListFloat(value, attr_obj); return ParseAttrListFloat(value, attr_obj);
@ -232,7 +234,8 @@ bool ParseAttrDefaultValue(const std::string &type, const std::string &value, nl
case ATTR_DTYPE::ATTR_FLOAT32: case ATTR_DTYPE::ATTR_FLOAT32:
(*attr_obj)[kJValue] = std::stof(value); (*attr_obj)[kJValue] = std::stof(value);
break; break;
case ATTR_DTYPE::ATTR_LIST_INT32: { case ATTR_DTYPE::ATTR_LIST_INT32:
case ATTR_DTYPE::ATTR_LIST_INT64: {
std::stringstream string_value(value); std::stringstream string_value(value);
std::string list_elem; std::string list_elem;
std::vector<int64_t> attrs_value; std::vector<int64_t> attrs_value;

View File

@ -60,6 +60,7 @@ constexpr auto kVTypeFloat32 = "float32";
constexpr auto kVTypeListInt = "listInt"; constexpr auto kVTypeListInt = "listInt";
constexpr auto kVTypeInt32 = "Int32"; constexpr auto kVTypeInt32 = "Int32";
constexpr auto kVTypeInt64 = "Int64"; constexpr auto kVTypeInt64 = "Int64";
constexpr auto kVTypeListInt64 = "listInt64";
constexpr auto kVTypeListUInt64 = "listUInt64"; constexpr auto kVTypeListUInt64 = "listUInt64";
constexpr auto kVTypeListFloat = "listFloat"; constexpr auto kVTypeListFloat = "listFloat";
constexpr auto kVTypeListListInt = "listListInt"; constexpr auto kVTypeListListInt = "listListInt";

View File

@ -40,7 +40,7 @@ std::string VecToString(const std::vector<T> &vec) {
return res; return res;
} }
std::string GenCommOpKey(const CNodePtr &node) { std::string GenCommOpKey(const CNodePtr &node, const KernelGraphPtr &root_graph) {
std::string op_key; std::string op_key;
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
auto comm_prim = GetCNodePrimitive(node); auto comm_prim = GetCNodePrimitive(node);
@ -68,6 +68,8 @@ std::string GenCommOpKey(const CNodePtr &node) {
if (comm_prim->HasAttr(kAttrRecvRankIds)) { if (comm_prim->HasAttr(kAttrRecvRankIds)) {
op_key += "_" + VecToString(GetValue<std::vector<int64_t>>(comm_prim->GetAttr(kAttrRecvRankIds))); op_key += "_" + VecToString(GetValue<std::vector<int64_t>>(comm_prim->GetAttr(kAttrRecvRankIds)));
} }
// model identifier, aka. root_graph_id
op_key += "_" + std::to_string(root_graph->root_graph_id());
MS_LOG(INFO) << node->DebugString() << " key " << op_key; MS_LOG(INFO) << node->DebugString() << " key " << op_key;
return op_key; return op_key;
} }
@ -198,7 +200,7 @@ void AscendCommOpReuse::AnalyseCommOpReuse() {
if (!IsReusable(comm_op)) { if (!IsReusable(comm_op)) {
continue; continue;
} }
reuse_map[GenCommOpKey(comm_op)].push_back(comm_op); reuse_map[GenCommOpKey(comm_op, root_graph_)].push_back(comm_op);
} }
for (const auto &[key, comm_op_set] : reuse_map) { for (const auto &[key, comm_op_set] : reuse_map) {
@ -255,7 +257,7 @@ KernelGraphPtr AscendCommOpReuse::CreateCommSubGraph(const CNodePtr &comm_op) {
MS_EXCEPTION_IF_NULL(new_comm_op); MS_EXCEPTION_IF_NULL(new_comm_op);
new_comm_op->set_abstract(comm_op->abstract()); new_comm_op->set_abstract(comm_op->abstract());
std::string group_name = GenCommOpKey(comm_op); std::string group_name = GenCommOpKey(comm_op, root_graph_);
auto rank_list = common::AnfAlgo::GetNodeAttr<std::vector<unsigned int>>(comm_op, kAttrRankList); auto rank_list = common::AnfAlgo::GetNodeAttr<std::vector<unsigned int>>(comm_op, kAttrRankList);
if (!CommManager::GetInstance().CreateGroupSync(group_name, rank_list)) { if (!CommManager::GetInstance().CreateGroupSync(group_name, rank_list)) {
MS_LOG(EXCEPTION) << "Create new group " << group_name << " failed, rank list = " << VecToString(rank_list); MS_LOG(EXCEPTION) << "Create new group " << group_name << " failed, rank list = " << VecToString(rank_list);

View File

@ -710,7 +710,7 @@ class Parameter(Tensor_):
raise TypeError("The argument 'layout' should be tuple, but got {}.".format(type(layout))) raise TypeError("The argument 'layout' should be tuple, but got {}.".format(type(layout)))
if len(layout) < 6: if len(layout) < 6:
raise ValueError("The length of 'layout' must be larger than 5, but got {}.".format(len(layout))) raise ValueError("The length of 'layout' must be larger than 5, but got {}.".format(len(layout)))
slice_index = int(_get_slice_index(layout[0], layout[1])) slice_index = int(_get_slice_index(layout[0], layout[1], layout[5]))
init_data_args += (slice_index, layout[2], layout[5]) init_data_args += (slice_index, layout[2], layout[5])
return init_data_args return init_data_args

View File

@ -20,7 +20,7 @@ import math
import numbers import numbers
import numpy as np import numpy as np
from mindspore.communication.management import get_rank, get_group_size from mindspore.communication.management import get_group_size
from mindspore.common._utils import is_shape_unknown, is_stub_tensor from mindspore.common._utils import is_shape_unknown, is_stub_tensor
from mindspore.common.seed import get_seed from mindspore.common.seed import get_seed
from mindspore import context from mindspore import context
@ -2280,9 +2280,9 @@ class Tensor(Tensor_):
self._np_seed = np.random.get_state()[1][0] self._np_seed = np.random.get_state()[1][0]
self.need_set_seed = (slice_index is not None) self.need_set_seed = (slice_index is not None)
self._global_seed = global_seed self._global_seed = global_seed
self._device_num = 1 self._seed_offset = 1
if self.need_set_seed: if self.need_set_seed:
self._device_num = get_group_size() self._seed_offset = get_group_size() * 2
def __enter__(self): def __enter__(self):
if self.need_set_seed: if self.need_set_seed:
@ -2293,7 +2293,7 @@ class Tensor(Tensor_):
else: else:
np.random.seed(slice_index + Tensor.delta_seed) np.random.seed(slice_index + Tensor.delta_seed)
self.init.seed = slice_index + Tensor.delta_seed self.init.seed = slice_index + Tensor.delta_seed
Tensor.delta_seed += self._device_num Tensor.delta_seed += self._seed_offset
def __exit__(self, ptype, value, trace): def __exit__(self, ptype, value, trace):
if self.need_set_seed: if self.need_set_seed:
@ -2302,10 +2302,6 @@ class Tensor(Tensor_):
with seed_context(self.init): with seed_context(self.init):
self.init(data) self.init(data)
if opt_shard_group:
rank = get_rank(opt_shard_group)
size = get_group_size(opt_shard_group)
data = np.split(data, size)[rank]
self.init = None self.init = None
# At embedding cache scenes. When size of tensor is out of range, we store data to persistent storage # At embedding cache scenes. When size of tensor is out of range, we store data to persistent storage

View File

@ -17,8 +17,11 @@ This module is to write data into mindrecord.
""" """
import os import os
import platform import platform
import queue
import re import re
import stat import stat
import time
import multiprocessing as mp
import numpy as np import numpy as np
from mindspore import log as logger from mindspore import log as logger
from .shardwriter import ShardWriter from .shardwriter import ShardWriter
@ -26,7 +29,7 @@ from .shardreader import ShardReader
from .shardheader import ShardHeader from .shardheader import ShardHeader
from .shardindexgenerator import ShardIndexGenerator from .shardindexgenerator import ShardIndexGenerator
from .shardutils import MIN_SHARD_COUNT, MAX_SHARD_COUNT, VALID_ATTRIBUTES, VALID_ARRAY_ATTRIBUTES, \ from .shardutils import MIN_SHARD_COUNT, MAX_SHARD_COUNT, VALID_ATTRIBUTES, VALID_ARRAY_ATTRIBUTES, \
check_filename, VALUE_TYPE_MAP check_filename, VALUE_TYPE_MAP, SUCCESS
from .common.exceptions import ParamValueError, ParamTypeError, MRMInvalidSchemaError, MRMDefineIndexError from .common.exceptions import ParamValueError, ParamTypeError, MRMInvalidSchemaError, MRMDefineIndexError
__all__ = ['FileWriter'] __all__ = ['FileWriter']
@ -103,6 +106,13 @@ class FileWriter:
self._writer = ShardWriter() self._writer = ShardWriter()
self._generator = None self._generator = None
# parallel write mode
self._parallel_writer = None
self._writers = None
self._queue = None
self._workers = None
self._index_workers = None
@classmethod @classmethod
def open_for_append(cls, file_name): def open_for_append(cls, file_name):
r""" r"""
@ -259,22 +269,67 @@ class FileWriter:
MRMSetHeaderError: If failed to set header. MRMSetHeaderError: If failed to set header.
MRMWriteDatasetError: If failed to write dataset. MRMWriteDatasetError: If failed to write dataset.
""" """
if not self._writer.is_open: if self._parallel_writer is None:
self._writer.open(self._paths, self._overwrite) self._parallel_writer = parallel_writer
if not self._writer.get_shard_header(): if self._parallel_writer != parallel_writer:
self._writer.set_shard_header(self._header) raise RuntimeError("The parameter `parallel_writer` must be consistent during use.")
if not isinstance(raw_data, list): if not self._parallel_writer:
raise ParamTypeError('raw_data', 'list') if not self._writer.is_open:
if self._flush and not self._append: self._writer.open(self._paths, self._overwrite)
raise RuntimeError("Not allowed to call `write_raw_data` on flushed MindRecord files." \ if not self._writer.get_shard_header():
"When creating new Mindrecord files, please remove `commit` before `write_raw_data`." \ self._writer.set_shard_header(self._header)
"In other cases, when appending to existing MindRecord files, " \ if not isinstance(raw_data, list):
"please call `open_for_append` first and then `write_raw_data`.") raise ParamTypeError('raw_data', 'list')
for each_raw in raw_data: if self._flush and not self._append:
if not isinstance(each_raw, dict): raise RuntimeError("Not allowed to call `write_raw_data` on flushed MindRecord files." \
raise ParamTypeError('raw_data item', 'dict') "When creating new Mindrecord files, please remove `commit` before " \
self._verify_based_on_schema(raw_data) "`write_raw_data`. In other cases, when appending to existing MindRecord files, " \
return self._writer.write_raw_data(raw_data, True, parallel_writer) "please call `open_for_append` first and then `write_raw_data`.")
for each_raw in raw_data:
if not isinstance(each_raw, dict):
raise ParamTypeError('raw_data item', 'dict')
self._verify_based_on_schema(raw_data)
return self._writer.write_raw_data(raw_data, True, parallel_writer)
## parallel write mode
# init the _writers and launch the workers
if self._writers is None:
self._writers = [None] * len(self._paths) # writers used by worker
self._queue = mp.Queue(len(self._paths) * 2) # queue for worker
self._workers = [None] * len(self._paths) # worker process
for i, path in enumerate(self._paths):
self._writers[i] = ShardWriter()
self._writers[i].open([path], self._overwrite)
self._writers[i].set_shard_header(self._header)
# launch the workers for parallel write
self._queue._joincancelled = True # pylint: disable=W0212
p = mp.Process(target=self._write_worker, args=(i, self._queue))
p.daemon = True
p.start()
logger.info("Start worker process(pid:{}) to parallel write.".format(p.pid))
self._workers[i] = p
# fill the self._queue
check_interval = 0.5 # 0.5s
start_time = time.time()
while True:
try:
self._queue.put(raw_data, block=False)
except queue.Full:
if time.time() - start_time > check_interval:
start_time = time.time()
logger.warning("Because there are too few MindRecord file shards, the efficiency of parallel " \
"writing is too low. You can stop the current task and add the parameter " \
"`shard_num` of `FileWriter` to upgrade the task.")
# check the status of worker process
for i in range(len(self._paths)):
if not self._workers[i].is_alive():
raise RuntimeError("Worker process(pid:{}) has stopped. Please check " \
"the above log".format(self._workers[i].pid))
continue
return SUCCESS
def set_header_size(self, header_size): def set_header_size(self, header_size):
""" """
@ -326,7 +381,7 @@ class FileWriter:
""" """
return self._writer.set_page_size(page_size) return self._writer.set_page_size(page_size)
def commit(self): def commit(self): # pylint: disable=W0212
""" """
Flush data in memory to disk and generate the corresponding database files. Flush data in memory to disk and generate the corresponding database files.
@ -343,24 +398,35 @@ class FileWriter:
MRMGenerateIndexError: If failed to write to database. MRMGenerateIndexError: If failed to write to database.
MRMCommitError: If failed to flush data to disk. MRMCommitError: If failed to flush data to disk.
""" """
self._flush = True if not self._parallel_writer:
if not self._writer.is_open: self._flush = True
self._writer.open(self._paths, self._overwrite) if not self._writer.is_open:
# permit commit without data self._writer.open(self._paths, self._overwrite)
if not self._writer.get_shard_header(): # permit commit without data
self._writer.set_shard_header(self._header) if not self._writer.get_shard_header():
ret = self._writer.commit() self._writer.set_shard_header(self._header)
if self._index_generator: self._writer.commit()
if self._append: if self._index_generator:
self._generator = ShardIndexGenerator(self._file_name, self._append) if self._append:
elif len(self._paths) >= 1: self._generator = ShardIndexGenerator(self._file_name, self._append)
self._generator = ShardIndexGenerator(os.path.realpath(self._paths[0]), self._append) elif len(self._paths) >= 1:
self._generator.build() self._generator = ShardIndexGenerator(os.path.realpath(self._paths[0]), self._append)
self._generator.write_to_db() self._generator.build()
self._generator.write_to_db()
else:
# maybe a empty mindrecord, so need check _writers
if self._writers is None:
self._writers = [None] * len(self._paths)
for i, path in enumerate(self._paths):
self._writers[i] = ShardWriter()
self._writers[i].open(path, self._overwrite)
self._writers[i].set_shard_header(self._header)
self._parallel_commit()
# change the file mode to 600
mindrecord_files = [] mindrecord_files = []
index_files = [] index_files = []
# change the file mode to 600
for item in self._paths: for item in self._paths:
if os.path.exists(item): if os.path.exists(item):
os.chmod(item, stat.S_IRUSR | stat.S_IWUSR) os.chmod(item, stat.S_IRUSR | stat.S_IWUSR)
@ -373,7 +439,62 @@ class FileWriter:
logger.info("The list of mindrecord files created are: {}, and the list of index files are: {}".format( logger.info("The list of mindrecord files created are: {}, and the list of index files are: {}".format(
mindrecord_files, index_files)) mindrecord_files, index_files))
return ret return SUCCESS
def _index_worker(self, i):
"""The worker do the index generator"""
generator = ShardIndexGenerator(os.path.realpath(self._paths[i]), False)
generator.build()
generator.write_to_db()
def _parallel_commit(self):
"""Parallel commit"""
# send EOF to worker process
for _ in range(len(self._paths)):
while True:
try:
self._queue.put("EOF", block=False)
except queue.Full:
time.sleep(1)
continue
break
# wait the worker processing
while True:
if not self._queue.empty():
logger.info("Waiting for worker process write done.")
time.sleep(1)
continue
break
del self._queue
# wait for worker process stop
for index in range(len(self._paths)):
while True:
logger.info("Waiting for the worker process(pid:{}) to process all the data.".format(
self._workers[index].pid))
if self._workers[index].is_alive():
time.sleep(1)
continue
elif self._workers[index].exitcode != 0:
raise RuntimeError("Worker process(pid:{}) has stopped abnormal. Please check " \
"the above log".format(self._workers[index].pid))
break
if self._index_generator:
# use parallel index workers to generator index
self._index_workers = [None] * len(self._paths)
for index in range(len(self._paths)):
p = mp.Process(target=self._index_worker, args=(index,))
p.daemon = True
p.start()
logger.info("Start worker process(pid:{}) to generate index.".format(p.pid))
self._index_workers[index] = p
# wait the index workers stop
for index in range(len(self._paths)):
self._index_workers[index].join()
def _validate_array(self, k, v): def _validate_array(self, k, v):
""" """
@ -487,3 +608,29 @@ class FileWriter:
error = "Field '{}' should be dict.".format(k) error = "Field '{}' should be dict.".format(k)
return False, error return False, error
return True, error return True, error
def _write_worker(self, i, in_queue):
"""The worker do the data check and write to disk for parallel mode"""
while True:
# try to get new raw_data from master
try:
raw_data = in_queue.get()
except queue.Empty:
continue
# get EOF from master, worker should commit and stop
if raw_data == "EOF":
ret = self._writers[i].commit()
if ret != SUCCESS:
raise RuntimeError("Commit the {}th shard of MindRecord file failed.".format(index))
break
# check the raw_data
if not isinstance(raw_data, list):
raise ParamTypeError('raw_data', 'list')
for each_raw in raw_data:
if not isinstance(each_raw, dict):
raise ParamTypeError('raw_data item', 'dict')
self._verify_based_on_schema(raw_data)
self._writers[i].write_raw_data(raw_data, True, False)

View File

@ -173,7 +173,7 @@ class ShardWriter:
for item in data: for item in data:
row_blob = self._merge_blob({field: item[field] for field in self._header.blob_fields}) row_blob = self._merge_blob({field: item[field] for field in self._header.blob_fields})
if row_blob: if row_blob:
blob_data.append(list(row_blob)) blob_data.append(row_blob)
# filter raw data according to schema # filter raw data according to schema
row_raw = {field: self.convert_np_types(item[field]) row_raw = {field: self.convert_np_types(item[field])
for field in self._header.schema.keys() - self._header.blob_fields if field in item} for field in self._header.schema.keys() - self._header.blob_fields if field in item}

View File

@ -38,10 +38,12 @@ class ImageNetToMR:
.. code-block:: .. code-block::
n02119789 0 n01440764 0
n02100735 1 n01443537 1
n02110185 2 n01484850 2
n02096294 3 n01491361 3
...
n15075141 999
image_dir (str): Image directory contains n02119789, n02100735, n02110185 and n02096294 directory. image_dir (str): Image directory contains n02119789, n02100735, n02110185 and n02096294 directory.
destination (str): MindRecord file path to transform into, ensure that the directory is created in advance and destination (str): MindRecord file path to transform into, ensure that the directory is created in advance and
@ -108,11 +110,11 @@ class ImageNetToMR:
for _ in range(batch_size): for _ in range(batch_size):
data_list.append(imagenet_iter.__next__()) data_list.append(imagenet_iter.__next__())
transform_count += 1 transform_count += 1
self.writer.write_raw_data(data_list) self.writer.write_raw_data(data_list, True)
logger.info("transformed {} record...".format(transform_count)) logger.info("transformed {} record...".format(transform_count))
except StopIteration: except StopIteration:
if data_list: if data_list:
self.writer.write_raw_data(data_list) self.writer.write_raw_data(data_list, True)
logger.info("transformed {} record...".format(transform_count)) logger.info("transformed {} record...".format(transform_count))
break break

View File

@ -307,11 +307,11 @@ class TFRecordToMR:
data_list.append(tf_iter.__next__()) data_list.append(tf_iter.__next__())
transform_count += 1 transform_count += 1
writer.write_raw_data(data_list) writer.write_raw_data(data_list, True)
logger.info("Transformed {} records...".format(transform_count)) logger.info("Transformed {} records...".format(transform_count))
except StopIteration: except StopIteration:
if data_list: if data_list:
writer.write_raw_data(data_list) writer.write_raw_data(data_list, True)
logger.info("Transformed {} records...".format(transform_count)) logger.info("Transformed {} records...".format(transform_count))
break break
return writer.commit() return writer.commit()

View File

@ -198,7 +198,6 @@ def get_bprop_mirror_micro_step_operator(self):
assign.add_prim_attr("parameter_micro", 0) assign.add_prim_attr("parameter_micro", 0)
out_tensor = Tensor(1.0, mstype.float16) out_tensor = Tensor(1.0, mstype.float16)
opt_shard = _get_enable_parallel_optimizer() opt_shard = _get_enable_parallel_optimizer()
def bprop(x, z, out, dout): def bprop(x, z, out, dout):
real_grad = z real_grad = z
assign_out = dout assign_out = dout
@ -207,16 +206,16 @@ def get_bprop_mirror_micro_step_operator(self):
z = F.depend(z, dout) z = F.depend(z, dout)
real_grad = all_reduce(z) real_grad = all_reduce(z)
real_grad = F.tensor_mul(real_grad, scale) real_grad = F.tensor_mul(real_grad, scale)
assign(z, real_grad) if opt_shard:
assign_out = z return (real_grad, cast(out_tensor, dtype(z)))
return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign(z, real_grad))
else: else:
if issubclass_(F.typeof(dout), mstype.tensor): if issubclass_(F.typeof(dout), mstype.tensor):
z = F.depend(z, dout) z = F.depend(z, dout)
real_grad = all_reduce(z) real_grad = all_reduce(z)
assign(z, real_grad) if opt_shard:
assign_out = z return (real_grad, cast(out_tensor, dtype(z)))
if opt_shard: return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign(z, real_grad))
return (real_grad, cast(out_tensor, dtype(z)))
return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign_out) return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign_out)
return bprop return bprop
@ -314,9 +313,18 @@ def get_bprop_micro_step_all_gather(self):
cast = P.Cast() cast = P.Cast()
dtype = P.DType() dtype = P.DType()
out_tensor = Tensor(1.0, mstype.float16) out_tensor = Tensor(1.0, mstype.float16)
with_mirror_operator = self.get_attr_dict()["with_mirror_operator"]
# z: accu_grad # z: accu_grad
def bprop(x, z, out, dout): def bprop(x, z, out, dout):
if with_mirror_operator:
if not do_mirror:
return (dout, cast(out_tensor, dtype(z)))
real_grad = all_reduce(dout)
real_grad = split(real_grad)[rank]
if mean_flag:
real_grad = F.tensor_mul(real_grad, scale)
return (real_grad, cast(out_tensor, dtype(z)))
z = F.depend(z, dout) z = F.depend(z, dout)
if not do_mirror: if not do_mirror:
return (z, cast(out_tensor, dtype(z))) return (z, cast(out_tensor, dtype(z)))

View File

@ -35,3 +35,5 @@ from .acos import _acos_tbe # Accuracy issues(task error in parallel)
from .trans_data_ds import _trans_data_ds_tbe # support bool from .trans_data_ds import _trans_data_ds_tbe # support bool
from .scatter_nd_d import _scatter_nd_d_tbe # in python no check supported from .scatter_nd_d import _scatter_nd_d_tbe # in python no check supported
from .assign_add_ds import _assign_add_ds_tbe # "Frac_nz in pangu not support" from .assign_add_ds import _assign_add_ds_tbe # "Frac_nz in pangu not support"
from .assign import _assign_tbe # Different formats of assign inputs cause memory to increase
from .atomic_addr_clean import _atomic_addr_clean_tbe # need to clean addr larger than 2G, int32 is not enough

View File

@ -23,7 +23,7 @@ atomic_addr_clean_op_info = TBERegOp("AtomicAddrClean") \
.compute_cost(10) \ .compute_cost(10) \
.kernel_name("atomic_addr_clean") \ .kernel_name("atomic_addr_clean") \
.partial_flag(True) \ .partial_flag(True) \
.attr("automic_add_mem_size", "required", "listInt", "all") \ .attr("automic_add_mem_size", "required", "listInt64", "all") \
.get_op_info() .get_op_info()

View File

@ -1186,6 +1186,7 @@ class _MirrorMicroStepOperator(PrimitiveWithInfer):
self.dev_num = dev_num self.dev_num = dev_num
self.mean_flag = mean_flag self.mean_flag = mean_flag
self.add_prim_attr('order_enforce_skip', True) self.add_prim_attr('order_enforce_skip', True)
self.add_prim_attr('side_effect_backprop_mem', True)
def infer_shape(self, x_shape, z_shape): def infer_shape(self, x_shape, z_shape):
return x_shape return x_shape

View File

@ -175,20 +175,26 @@ def _chunk_tensor_by_strategy(np_tensor, strategy):
return _chunk_tensor(np_tensor, strategy, len(strategy)) return _chunk_tensor(np_tensor, strategy, len(strategy))
def _get_slice_index(dev_mat, tensor_map): def _get_slice_index(dev_mat, tensor_map, opt_shard_group):
""" """
Get the slice index for current slice. Get the slice index for current slice.
Args: Args:
dev_mat (list): The device matrix of devices. dev_mat (list): The device matrix of devices.
tensor_map (list): The split strategy of tensor. tensor_map (list): The split strategy of tensor.
opt_shard_group(string): The group of optimizer shard
Returns: Returns:
Integer, the slice index for slice on this device. Integer, the slice index for slice on this device.
""" """
rank = get_rank() rank = get_rank()
dev_num = get_group_size()
tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map) tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank) tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
if opt_shard_group:
tensor_slice_index += dev_num
opt_rank = get_rank(opt_shard_group)
tensor_slice_index += opt_rank
return tensor_slice_index return tensor_slice_index

View File

@ -67,7 +67,10 @@ def test_imagenet_to_mindrecord(fixture_file):
for i in range(PARTITION_NUMBER): for i in range(PARTITION_NUMBER):
assert os.path.exists(file_name + str(i)) assert os.path.exists(file_name + str(i))
assert os.path.exists(file_name + str(i) + ".db") assert os.path.exists(file_name + str(i) + ".db")
read(file_name + "0") read([file_name + "0",
file_name + "1",
file_name + "2",
file_name + "3"])
def test_imagenet_to_mindrecord_default_partition_number(fixture_file): def test_imagenet_to_mindrecord_default_partition_number(fixture_file):
""" """
@ -76,7 +79,7 @@ def test_imagenet_to_mindrecord_default_partition_number(fixture_file):
""" """
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0] file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE, IMAGENET_IMAGE_DIR, imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE, IMAGENET_IMAGE_DIR,
file_name) file_name, 1)
imagenet_transformer.transform() imagenet_transformer.transform()
assert os.path.exists(file_name) assert os.path.exists(file_name)
assert os.path.exists(file_name + ".db") assert os.path.exists(file_name + ".db")

View File

@ -1270,3 +1270,106 @@ def test_cv_file_overwrite_exception_02():
writer.write_raw_data(data) writer.write_raw_data(data)
assert 'Invalid file, mindrecord files already exist. Please check file path:' in str(err.value) assert 'Invalid file, mindrecord files already exist. Please check file path:' in str(err.value)
remove_multi_files(mindrecord_file_name, FILES_NUM) remove_multi_files(mindrecord_file_name, FILES_NUM)
def test_file_writer_parallel(file_name=None, remove_file=True):
"""
Feature: FileWriter
Description: parallel for writer
Expectation: generated mindrecord file
"""
if not file_name:
file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
# single file
remove_one_file(file_name)
remove_one_file(file_name + ".db")
writer = FileWriter(file_name)
data = get_data("../data/mindrecord/testImageNetData/")
cv_schema_json = {"file_name": {"type": "string"},
"label": {"type": "int64"}, "data": {"type": "bytes"}}
writer.add_schema(cv_schema_json, "img_schema")
writer.add_index(["file_name", "label"])
for _ in range(5):
writer.write_raw_data(data, True)
writer.commit()
if remove_file:
remove_one_file(file_name)
remove_one_file(file_name + ".db")
# write_raw_data with empty
remove_one_file(file_name)
remove_one_file(file_name + ".db")
writer = FileWriter(file_name)
data = get_data("../data/mindrecord/testImageNetData/")
cv_schema_json = {"file_name": {"type": "string"},
"label": {"type": "int64"}, "data": {"type": "bytes"}}
writer.add_schema(cv_schema_json, "img_schema")
writer.add_index(["file_name", "label"])
with pytest.raises(RuntimeError):
writer.write_raw_data([])
# multi files
# len(data) > FILES_NUM which is parallel size
remove_multi_files(file_name, FILES_NUM)
writer = FileWriter(file_name, FILES_NUM)
data = get_data("../data/mindrecord/testImageNetData/")
cv_schema_json = {"file_name": {"type": "string"},
"label": {"type": "int64"}, "data": {"type": "bytes"}}
writer.add_schema(cv_schema_json, "img_schema")
writer.add_index(["file_name", "label"])
for _ in range(10):
writer.write_raw_data(data, True)
writer.commit()
if remove_file:
remove_multi_files(file_name, FILES_NUM)
# len(data) < FILES_NUM which is parallel size
remove_multi_files(file_name, FILES_NUM)
writer = FileWriter(file_name, FILES_NUM)
data = get_data("../data/mindrecord/testImageNetData/")
cv_schema_json = {"file_name": {"type": "string"},
"label": {"type": "int64"}, "data": {"type": "bytes"}}
writer.add_schema(cv_schema_json, "img_schema")
writer.add_index(["file_name", "label"])
for _ in range(2):
writer.write_raw_data(data[0:2], True)
writer.commit()
if remove_file:
remove_multi_files(file_name, FILES_NUM)
# write_raw_data(.., True) and write_raw_data(.., False)
remove_multi_files(file_name, FILES_NUM)
writer = FileWriter(file_name, FILES_NUM)
data = get_data("../data/mindrecord/testImageNetData/")
cv_schema_json = {"file_name": {"type": "string"},
"label": {"type": "int64"}, "data": {"type": "bytes"}}
writer.add_schema(cv_schema_json, "img_schema")
writer.add_index(["file_name", "label"])
with pytest.raises(RuntimeError):
writer.write_raw_data(data[0:2], True)
writer.write_raw_data(data[0:2])
# without write_raw_data
remove_multi_files(file_name, FILES_NUM)
writer = FileWriter(file_name, FILES_NUM)
data = get_data("../data/mindrecord/testImageNetData/")
cv_schema_json = {"file_name": {"type": "string"},
"label": {"type": "int64"}, "data": {"type": "bytes"}}
writer.add_schema(cv_schema_json, "img_schema")
writer.add_index(["file_name", "label"])
writer.commit()
if remove_file:
remove_multi_files(file_name, FILES_NUM)
# write_raw_data with empty
remove_multi_files(file_name, FILES_NUM)
writer = FileWriter(file_name, FILES_NUM)
data = get_data("../data/mindrecord/testImageNetData/")
cv_schema_json = {"file_name": {"type": "string"},
"label": {"type": "int64"}, "data": {"type": "bytes"}}
writer.add_schema(cv_schema_json, "img_schema")
writer.add_index(["file_name", "label"])
with pytest.raises(RuntimeError):
writer.write_raw_data([], True)
writer.commit()

View File

@ -0,0 +1,133 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore import context, Model
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.nn.wrap.cell_wrapper import PipelineCell
from tests.ut.python.ops.test_math_ops import VirtualLoss
from parallel.utils.utils import ParallelValidator
from .test_pipeline_split import DatasetLenet
def setup_function():
context.set_auto_parallel_context(dataset_strategy="full_batch")
grad_all = C.GradOperation(get_all=True)
class NetWithLoss(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.loss = VirtualLoss()
self.network = network
def construct(self, x, y):
predict = self.network(x, y)
return self.loss(predict)
class GradWrap(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
def construct(self, x, y):
return grad_all(self.network)(x, y)
def test_opt_parallel_without_grad():
"""
Feature: Test optimizer parallel with parameter's requires_grad=False.
Description: Need insert AllGather.
Expectation: Successful graph compilation.
"""
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.fc1 = P.MatMul().shard(((4, 1), (1, 2)))
self.fc2 = P.MatMul().shard(((2, 2), (2, 1)))
self.p1 = Parameter(Tensor(np.ones([1024, 1024]).astype(np.float32)), name="weight1", requires_grad=False)
self.p2 = Parameter(Tensor(np.ones([1024, 64]).astype(np.float32)), name="weight2")
def construct(self, x, y):
x = self.fc1(x, self.p1)
x = self.fc2(x, self.p2)
return x - y
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=8, global_rank=0, enable_parallel_optimizer=True)
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
x = Tensor(np.ones([128, 1024]), dtype=ms.float32)
y = Tensor(np.ones([128, 64]), dtype=ms.float32)
net.set_train()
phase, _ = _cell_graph_executor.compile(net, x, y)
validator = ParallelValidator(net, phase)
expect_layout = ([4, 2], [-1, 0], [1024, 512], 0, True, '4-5226697808808137312')
assert validator.check_parameter_layout("network.network.p1", expect_layout)
def test_opt_parallel_without_grad_pipeline():
"""
Feature: Test optimizer parallel + pipeline with parameter's requires_grad=False.
Description: Need insert AllGather.
Expectation: Successful graph compilation.
"""
class MatMulNet(nn.Cell):
def __init__(self):
super().__init__()
self.fc1 = P.MatMul().shard(((4, 1), (1, 2)))
self.fc2 = P.MatMul().shard(((2, 2), (2, 1)))
self.p1 = Parameter(Tensor(np.ones([1024, 1024]).astype(np.float32)), name="weight1", requires_grad=False)
self.p2 = Parameter(Tensor(np.ones([1024, 1024]).astype(np.float32)), name="weight2")
def construct(self, x):
x = self.fc1(x, self.p1)
x = self.fc2(x, self.p2)
return x
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.block = nn.CellList()
for i in range(2):
cell = MatMulNet()
cell.pipeline_stage = i
self.block.append(cell)
def construct(self, x, y):
for i in range(2):
x = self.block[i](x)
return x
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=16, global_rank=0, enable_parallel_optimizer=True)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", pipeline_stages=2)
net = PipelineCell(Net(), 4)
x = Tensor(np.ones([128, 1024]), dtype=ms.float32)
y = Tensor(np.ones([128, 128]), dtype=ms.float32)
dataset = DatasetLenet(x, y, 3)
optimizer = nn.Lamb(net.trainable_params(), learning_rate=0.01)
model = Model(net, optimizer=optimizer)
model.train(2, dataset, dataset_sink_mode=False)
assert net.network.block[0].p1.shape == (256, 512)