From 99e515bd3c584d2aec503d326c8c68ee292234f1 Mon Sep 17 00:00:00 2001 From: yao_yf Date: Sat, 25 Feb 2023 19:07:12 +0800 Subject: [PATCH] adjust pp slice skip stra fix_opt_shard_param_init fix_opt_shard_with_no_grad lable switch condition with int64 Different formats of assign inputs cause memory cost to increase. make different models use different comm groups AtomicAddrClean uses list_int64 attr enhance mindrecord parallel write --- config/op_info.config | 1 + config/super_bar_config.json | 9 +- .../mindspore.mindrecord.ImageNetToMR.rst | 10 +- .../graph_util/pipeline_split_utils.cc | 4 + .../frontend/parallel/ops_info/gather_info.cc | 1 + .../parallel/ops_info/operator_info.cc | 14 +- .../parallel/ops_info/operator_info.h | 3 +- .../parallel_optimizer/opt_param_mgr.cc | 6 +- .../frontend/parallel/parameter_manager.cc | 5 +- .../ccsrc/frontend/parallel/step_parallel.cc | 16 +- .../mindrecord/common/shard_pybind.cc | 52 ++++- .../mindrecord/include/common/shard_utils.h | 3 + mindspore/ccsrc/pipeline/jit/action.cc | 3 +- .../device/ascend/hal/device/kernel_adjust.cc | 2 +- .../ascend/hal/device/kernel_build_ascend.cc | 12 +- .../hal/device/tasksink/task_generator.cc | 2 +- .../device/ascend/kernel/rts/label_switch.cc | 2 +- .../kernel/tbe/tbe_json/tbe_json_creator.cc | 5 +- .../kernel/tbe/tbe_json/tbe_json_utils.h | 1 + .../ascend/optimizer/ascend_comm_op_reuse.cc | 8 +- .../python/mindspore/common/parameter.py | 2 +- mindspore/python/mindspore/common/tensor.py | 12 +- .../python/mindspore/mindrecord/filewriter.py | 215 +++++++++++++++--- .../mindspore/mindrecord/shardwriter.py | 2 +- .../mindrecord/tools/imagenet_to_mr.py | 14 +- .../mindrecord/tools/tfrecord_to_mr.py | 4 +- .../mindspore/ops/_grad/grad_comm_ops.py | 22 +- .../mindspore/ops/_op_impl/tbe/__init__.py | 2 + .../ops/_op_impl/tbe/atomic_addr_clean.py | 2 +- .../mindspore/ops/operations/comm_ops.py | 1 + .../python/mindspore/parallel/_tensor.py | 8 +- .../mindrecord/test_imagenet_to_mindrecord.py | 7 +- .../python/mindrecord/test_mindrecord_base.py | 103 +++++++++ .../test_parallel_optimizer_without_grad.py | 133 +++++++++++ 34 files changed, 576 insertions(+), 110 deletions(-) create mode 100644 tests/ut/python/parallel/test_parallel_optimizer_without_grad.py diff --git a/config/op_info.config b/config/op_info.config index 9b8c9774037..8090d02711a 100644 --- a/config/op_info.config +++ b/config/op_info.config @@ -283,3 +283,4 @@ {"op_name": "QuantDTypeCast", "inputs": [{"index": 0, "name": "x", "param_type": "required"},{"index": 1, "name": "scales", "param_type": "required"},{"index": 2, "name": "zps", "param_type": "required"},{"index": 3, "name": "mean_corrs", "param_type": "required"},{"index": 4, "name": "var_corr", "param_type": "required"}], "outputs": [{"index": 0, "name": "y", "param_type": "required"}], "attr": [{"name": "src_t", "type": "int"},{"name": "dst_t", "type": "int"},{"name": "axis", "type": "int"}], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]],[["int8", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float16", "FRACTAL_NZ"]]], "imply_type": "AiCPU"} {"op_name": "FSEDecode", "inputs": [{"index": 0, "name": "x", "param_type": "required"},{"index": 1, "name": "states_table", "param_type": "required"},{"index": 2, "name": "bit_count_table", "param_type": "required"},{"index": 3, "name": "symbol_table", "param_type": "required"},{"index": 4, "name": "centroids", "param_type": "required"},{"index": 5, "name": "input_shape", "param_type": "required"}], "outputs": [{"index": 0, "name": "y", "param_type": "required"}], "attr": [{"name": "dst_t", "type": "int"}, {"name": "curr_chunk", "type": "int"}, {"name": "curr_chunk_index", "type": "int"}, {"name": "curr_bit_count", "type": "int"}, {"name": "table_log", "type": "int"}], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "DefaultFormat"], ["uint16", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint16", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]],[["int8", "DefaultFormat"], ["uint16", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint16", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "FRACTAL_NZ"]]], "imply_type": "AiCPU"} {"op_name": "AssignAdd", "inputs": [{"index": 0, "name": "ref", "needCompile": false, "paramType": "required", "shape": "all"}, {"index": 1, "name": "value", "needCompile": false, "paramType": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "ref", "need_compile": false, "paramType": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int8", "NC1HWC0"], ["int8", "NC1HWC0"], ["int8", "NC1HWC0"]], [["int8", "C1HWNCoC0"], ["int8", "C1HWNCoC0"], ["int8", "C1HWNCoC0"]], [["int8", "FRACTAL_Z"], ["int8", "FRACTAL_Z"], ["int8", "FRACTAL_Z"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint8", "NC1HWC0"], ["uint8", "NC1HWC0"], ["uint8", "NC1HWC0"]], [["uint8", "C1HWNCoC0"], ["uint8", "C1HWNCoC0"], ["uint8", "C1HWNCoC0"]], [["uint8", "FRACTAL_Z"], ["uint8", "FRACTAL_Z"], ["uint8", "FRACTAL_Z"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"]], [["int32", "C1HWNCoC0"], ["int32", "C1HWNCoC0"], ["int32", "C1HWNCoC0"]], [["int32", "FRACTAL_Z"], ["int32", "FRACTAL_Z"], ["int32", "FRACTAL_Z"]], [["int64", "DefaultFormat"], ["int64", "DefaultFormat"], ["int64", "DefaultFormat"]], [["int64", "NC1HWC0"], ["int64", "NC1HWC0"], ["int64", "NC1HWC0"]], [["int64", "C1HWNCoC0"], ["int64", "C1HWNCoC0"], ["int64", "C1HWNCoC0"]], [["int64", "FRACTAL_Z"], ["int64", "FRACTAL_Z"], ["int64", "FRACTAL_Z"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "FRACTAL_Z"], ["float16", "FRACTAL_Z"], ["float16", "FRACTAL_Z"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "FRACTAL_Z"], ["float32", "FRACTAL_Z"], ["float32", "FRACTAL_Z"]]], "imply_type": "TBE", "async_flag": false, "binfile": "assign_add.so", "compute_cost": 10, "kernel": "assign_add", "partial_flag": true, "reshape_type": "", "dynamicRankSupport": false, "dynamicShapeSupport": true, "dynamicCompileStatic": true, "needCheckSupport": false, "dynamicFormat": false, "op_pattern": "", "real_input_index": [], "input_to_attr_index": [], "unknown_shape_formats": []} +{'op_name': 'AtomicAddrClean', 'inputs': [], 'outputs': [], 'attr': [{'name': 'automic_add_mem_size', 'paramType': 'required', 'type': 'listInt64', 'value': 'all'}], 'fusion_type': 'ELEMWISE', 'dtype_format': [], 'imply_type': 'TBE', 'async_flag': False, 'binfile': 'atomic_addr_clean.so', 'compute_cost': 10, 'kernel': 'atomic_addr_clean', 'partial_flag': True, 'reshape_type': '', 'dynamicRankSupport': False, 'dynamicShapeSupport': False, 'dynamicCompileStatic': False, 'needCheckSupport': False, 'dynamicFormat': False, 'op_pattern': '', 'real_input_index': [], 'input_to_attr_index': [], 'unknown_shape_formats': []} diff --git a/config/super_bar_config.json b/config/super_bar_config.json index c311b05c8d7..e24d8e203ff 100644 --- a/config/super_bar_config.json +++ b/config/super_bar_config.json @@ -415,7 +415,8 @@ "TransData ": "support boll", "ScatterNdD ": "Accuracy issues", "Trace": "Hadn't adapted tbe implementation", - "AssignAdd": "Frac_nz in pangu not support" + "AssignAdd": "Frac_nz in pangu not support", + "AtomicAddrClean": "need to clean addr larger than 2G, int32 is not enough" }, "SkipNodes": [ "BroadcastTo", @@ -444,7 +445,9 @@ "ACos", "TransData", "ScatterNdD", - "AssignAdd" + "AssignAdd", + "Assign", + "AtomicAddrClean" ], "FallbackOps": { "DeformableOffsets": [ @@ -452,4 +455,4 @@ 2 ] } -} \ No newline at end of file +} diff --git a/docs/api/api_python/mindrecord/mindspore.mindrecord.ImageNetToMR.rst b/docs/api/api_python/mindrecord/mindspore.mindrecord.ImageNetToMR.rst index a161b9b9224..6afa7a8bc18 100644 --- a/docs/api/api_python/mindrecord/mindspore.mindrecord.ImageNetToMR.rst +++ b/docs/api/api_python/mindrecord/mindspore.mindrecord.ImageNetToMR.rst @@ -8,10 +8,12 @@ .. code-block:: - n02119789 0 - n02100735 1 - n02110185 2 - n02096294 3 + n01440764 0 + n01443537 1 + n01484850 2 + n01491361 3 + ... + n15075141 999 - **image_dir** (str) - ImageNet数据集的目录路径,目录中包含类似n02119789、n02100735、n02110185和n02096294的子目录。 - **destination** (str) - 转换生成的MindRecord文件路径,需提前创建目录并且目录下不能存在同名文件。 diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc b/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc index 500b8937d8c..c7f8a234bf8 100644 --- a/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc +++ b/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc @@ -104,6 +104,10 @@ void SetStridedSliceStrategy(const AnfNodePtr &node) { if (skip_redis && !full_batch && input_strategy.size() > 0) { input_strategy[0] = dev_num < shape_list[1][0][0] ? dev_num : shape_list[1][0][0]; auto prim = GetCNodePrimitive(node); + if (prim->HasAttr("out_shard_size")) { + auto out_shard_size = GetValue(prim->GetAttr("out_shard_size")); + input_strategy[0] = out_shard_size; + } auto attrs = prim->attrs(); attrs[parallel::SKIP_REDISTRIBUTION] = MakeValue(true); prim->SetAttrs(attrs); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_info.cc index b1c39a90fce..79791a88d4c 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_info.cc @@ -354,6 +354,7 @@ Status GatherInfo::CheckStrategy(const StrategyPtr &strategy) { // parameter not split axis if (param_strategy.at(LongToSize(axis_)) == 1) { + SetAttribute(strategy); return SUCCESS; } diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc index 3ac76459e1e..42c3bd86424 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc @@ -456,21 +456,11 @@ void AddCommOpMeanFlag(const CNodePtr &comm_node) { (void)prim->SetAttrs(attrs); } -void AddCommOpMirrorFlag(const CNodePtr &comm_node, bool do_mirror) { +void AddCNodePrimAttr(const CNodePtr &comm_node, const std::string &attr_name, const ValuePtr &attr_val) { MS_EXCEPTION_IF_NULL(comm_node); auto prim = GetValueNode(comm_node->input(0)); auto attrs = prim->attrs(); - MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); - attrs[DO_MIRROR] = MakeValue(do_mirror); - (void)prim->SetAttrs(attrs); -} - -void AddCommOpAddAccuFlag(const CNodePtr &comm_node, bool add_accu) { - MS_EXCEPTION_IF_NULL(comm_node); - auto prim = GetValueNode(comm_node->input(0)); - auto attrs = prim->attrs(); - MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); - attrs[ADD_ACCU] = MakeValue(add_accu); + attrs[attr_name] = attr_val; (void)prim->SetAttrs(attrs); } diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h index 36aeaa82afc..d8d5e7dd9bf 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h @@ -351,9 +351,8 @@ Operator CreateAllGatherOp(const std::string &group); Operator CreateCastOp(TypePtr type); Operator CreateDivOp(float scale); Operator CreateMiniStepAllGatherOp(const std::string &group); +void AddCNodePrimAttr(const CNodePtr &comm_node, const std::string &attr_name, const ValuePtr &attr_val); int32_t AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr ¶m_node); -void AddCommOpMirrorFlag(const CNodePtr &comm_node, bool do_mirror); -void AddCommOpAddAccuFlag(const CNodePtr &comm_node, bool add_accu); Operator CreateMicroStepAllGatherOp(const std::string &group); void AddCommOpMeanFlag(const CNodePtr &comm_node); void AddCommOpParamFlag(const CNodePtr &comm_node); diff --git a/mindspore/ccsrc/frontend/parallel/parallel_optimizer/opt_param_mgr.cc b/mindspore/ccsrc/frontend/parallel/parallel_optimizer/opt_param_mgr.cc index 9450ba32f4e..da53a282635 100644 --- a/mindspore/ccsrc/frontend/parallel/parallel_optimizer/opt_param_mgr.cc +++ b/mindspore/ccsrc/frontend/parallel/parallel_optimizer/opt_param_mgr.cc @@ -117,9 +117,9 @@ class OptParamMgrImpl : public OptParamMgr { return false; } - if (!ParameterRequireGrad(parameter)) { - // only trainable parameters need parallel optimizer - MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not trainable parameter."; + auto param_ptr = parameter->cast(); + if ((!param_ptr) || (!param_ptr->has_default())) { + MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not a parameter."; return false; } diff --git a/mindspore/ccsrc/frontend/parallel/parameter_manager.cc b/mindspore/ccsrc/frontend/parallel/parameter_manager.cc index 38dab196c05..90f6dde1d2a 100644 --- a/mindspore/ccsrc/frontend/parallel/parameter_manager.cc +++ b/mindspore/ccsrc/frontend/parallel/parameter_manager.cc @@ -396,10 +396,13 @@ void SliceParameterObj(const ParameterPtr ¶meter, const TensorLayoutPtr &ten // create python layout obj const auto &device_arrangement = tensor_layout->device_arrangement().array(); const auto &tensor_map = tensor_layout->tensor_map().array(); - const auto &slice_shape = tensor_layout->slice_shape().array(); + auto slice_shape = tensor_layout->slice_shape().array(); int64_t field_size = tensor_layout->get_field_size(); bool uniform_split = tensor_layout->uniform_split(); std::string opt_shard_group = tensor_layout->opt_shard_group(); + if (!opt_shard_group.empty()) { + slice_shape = tensor_layout->opt_shard_slice_shape(); + } py::tuple layout = py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group); diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index 1e3858fc2aa..a0edbc07c94 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -1271,6 +1271,7 @@ static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group if (param_ptr->user_data()) { opt_shard_mirror_group = param_ptr->user_data()->opt_shard_mirror_group(); } + bool is_with_mirror = !opt_shard_mirror_group.empty(); if (!is_shared_param && cast_node) { allgather = ReplaceNode(op, cast_node, graph, PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE, param_name, root); MS_LOG(INFO) << "Parallel optimizer is applied before Cast for " << param_name; @@ -1294,16 +1295,16 @@ static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group AddNodeFusionInfo(cnode, allgather, "reduce_scatter", fusion_id); // add gradients mean AddCommOpMeanFlag(allgather); + AddCNodePrimAttr(allgather, "with_mirror_operator", MakeValue(is_with_mirror)); if (op_name == MICRO_STEP_ALL_GATHER) { // When grad_accumulation_shard is enabled, the ReduceScatter is inserted at each micro step // so no need to do backward for the micro_step_allgather - AddCommOpMirrorFlag(allgather, !grad_accumulation_shard); + AddCNodePrimAttr(allgather, DO_MIRROR, MakeValue(!grad_accumulation_shard)); } else if (op_name == MINI_STEP_ALL_GATHER) { // We need to manually set the add_accu to be false if it's father node is MirrorMiniStep bool add_accu = root->has_flag(kAccumulation); - bool is_with_mirror = opt_shard_mirror_group.size() > 1; - AddCommOpAddAccuFlag(allgather, !add_accu && !is_with_mirror); - AddCommOpMirrorFlag(allgather, grad_accumulation_shard || !add_accu); + AddCNodePrimAttr(allgather, ADD_ACCU, MakeValue(!add_accu && !is_with_mirror)); + AddCNodePrimAttr(allgather, DO_MIRROR, MakeValue(!grad_accumulation_shard || !add_accu)); } } @@ -1311,17 +1312,20 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr & const std::string &opt_shard_group) { int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num(); auto enable_opt_shard = ParallelContext::GetInstance()->enable_parallel_optimizer(); - if ((opt_shard_group.empty() && split_stage_num <= 1) || (!enable_opt_shard) || (!ParameterRequireGrad(parameter))) { + if ((opt_shard_group.empty() && split_stage_num <= 1) || (!enable_opt_shard)) { return; } + if (opt_shard_group.empty() && !ParameterRequireGrad(parameter)) { + return; + } // set all gather type MS_EXCEPTION_IF_NULL(parameter); int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); std::string op_name; if (grad_accumulation_step > 1) { op_name = MINI_STEP_ALL_GATHER; - } else if (split_stage_num > 1) { + } else if (split_stage_num > 1 && ParameterRequireGrad(parameter)) { op_name = MICRO_STEP_ALL_GATHER; } else { op_name = ALL_GATHER; diff --git a/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc b/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc index 8bd2ddf28fd..dee78a25ec2 100644 --- a/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc +++ b/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc @@ -129,8 +129,9 @@ void BindShardWriter(py::module *m) { return SUCCESS; }) .def("write_raw_data", - [](ShardWriter &s, std::map> &raw_data, vector> &blob_data, + [](ShardWriter &s, std::map> &raw_data, vector &blob_data, bool sign, bool parallel_writer) { + // convert the raw_data from dict to json std::map> raw_data_json; (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()), [](const std::pair> &p) { @@ -141,7 +142,54 @@ void BindShardWriter(py::module *m) { [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); }); return std::make_pair(p.first, std::move(json_raw_data)); }); - THROW_IF_ERROR(s.WriteRawData(raw_data_json, blob_data, sign, parallel_writer)); + + // parallel convert blob_data from vector to vector> + int32_t parallel_convert = kParallelConvert; + if (parallel_convert > blob_data.size()) { + parallel_convert = blob_data.size(); + } + parallel_convert = parallel_convert != 0 ? parallel_convert : 1; + std::vector thread_set(parallel_convert); + vector> vector_blob_data(blob_data.size()); + uint32_t step = uint32_t(blob_data.size() / parallel_convert); + if (blob_data.size() % parallel_convert != 0) { + step = step + 1; + } + for (int x = 0; x < parallel_convert; ++x) { + uint32_t start = x * step; + uint32_t end = ((x + 1) * step) < blob_data.size() ? ((x + 1) * step) : blob_data.size(); + thread_set[x] = std::thread([&vector_blob_data, &blob_data, start, end]() { + for (auto i = start; i < end; i++) { + char *buffer = nullptr; + ssize_t length = 0; + if (PYBIND11_BYTES_AS_STRING_AND_SIZE(blob_data[i].ptr(), &buffer, &length)) { + MS_LOG(ERROR) << "Unable to extract bytes contents!"; + return FAILED; + } + vector blob_data_item(length); + if (length < SECUREC_MEM_MAX_LEN) { + int ret_code = memcpy_s(&blob_data_item[0], length, buffer, length); + if (ret_code != EOK) { + MS_LOG(ERROR) << "memcpy_s failed for py::bytes to vector."; + return FAILED; + } + } else { + auto ret_code = std::memcpy(&blob_data_item[0], buffer, length); + if (ret_code != &blob_data_item[0]) { + MS_LOG(ERROR) << "memcpy failed for py::bytes to vector."; + return FAILED; + } + } + vector_blob_data[i] = blob_data_item; + } + }); + } + + // wait for the threads join + for (int x = 0; x < parallel_convert; ++x) { + thread_set[x].join(); + } + THROW_IF_ERROR(s.WriteRawData(raw_data_json, vector_blob_data, sign, parallel_writer)); return SUCCESS; }) .def("commit", [](ShardWriter &s) { diff --git a/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h b/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h index 334f6360c40..4e5175043b6 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h @@ -160,6 +160,9 @@ const std::unordered_map kTypesMap = { /// \brief the max number of samples to enable lazy load const uint32_t LAZY_LOAD_THRESHOLD = 5000000; +/// \brief parallel convert from vector to vector> +const uint32_t kParallelConvert = 4; + /// \brief split a string using a character /// \param[in] field target string /// \param[in] separator a character for splitting diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 9fc4eace012..c1dab9a41f4 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -1500,7 +1500,8 @@ static std::vector CommonPipeline() { auto parallel_mode = parallel_context->parallel_mode(); const bool is_parallel_mode = parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel; - if (!is_cluster_initialized && !is_parallel_mode && pipeline::GetJitLevel() != "O0") { + static const auto combine_like_graphs = (common::GetEnv("COMBINE_LIKE_GRAPHS") == "1"); + if (!is_cluster_initialized && (!is_parallel_mode || combine_like_graphs) && pipeline::GetJitLevel() != "O0") { (void)actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs)); } diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/device/kernel_adjust.cc b/mindspore/ccsrc/plugin/device/ascend/hal/device/kernel_adjust.cc index 707861a7e21..a59c8f3f85b 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/device/kernel_adjust.cc +++ b/mindspore/ccsrc/plugin/device/ascend/hal/device/kernel_adjust.cc @@ -567,7 +567,7 @@ CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr(RT_SWITCH_INT64); + int data_type = static_cast(RT_SWITCH_INT32); ValuePtr dt = MakeValue(data_type); common::AnfAlgo::SetNodeAttr(kAttrDataType, dt, stream_switch_app); // set distinction label and graph id diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/device/kernel_build_ascend.cc b/mindspore/ccsrc/plugin/device/ascend/hal/device/kernel_build_ascend.cc index 021e7c98396..8146c8ef04a 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/device/kernel_build_ascend.cc +++ b/mindspore/ccsrc/plugin/device/ascend/hal/device/kernel_build_ascend.cc @@ -239,11 +239,11 @@ bool IfAtomicOpNeedFusion(const size_t clean_total_num, const CNodePtr &first_no return false; } -std::vector GetClearSize(const CNodePtr &pre_node) { +std::vector GetClearSize(const CNodePtr &pre_node) { MS_EXCEPTION_IF_NULL(pre_node); auto kernel_mod = AnfAlgo::GetKernelMod(pre_node); MS_EXCEPTION_IF_NULL(kernel_mod); - std::vector clean_size_list; + std::vector clean_size_list; constexpr size_t kAlignBytes = 32 - 1; // clean output if (common::AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { @@ -251,7 +251,7 @@ std::vector GetClearSize(const CNodePtr &pre_node) { auto output_men_size = kernel_mod->GetOutputSizeList(); for (auto index : output_indexes) { auto clean_item = - SizeToInt((output_men_size.at(index) + kMemAlignSize + kAlignBytes) / kMemAlignSize * kMemAlignSize); + SizeToLong((output_men_size.at(index) + kMemAlignSize + kAlignBytes) / kMemAlignSize * kMemAlignSize); (void)clean_size_list.emplace_back(clean_item); } } @@ -261,7 +261,7 @@ std::vector GetClearSize(const CNodePtr &pre_node) { auto workspace_men_sizes = kernel_mod->GetWorkspaceSizeList(); for (const auto &index : workspace_indexes) { auto clean_item = - SizeToInt((workspace_men_sizes.at(index) + kMemAlignSize + kAlignBytes) / kMemAlignSize * kMemAlignSize); + SizeToLong((workspace_men_sizes.at(index) + kMemAlignSize + kAlignBytes) / kMemAlignSize * kMemAlignSize); (void)clean_size_list.emplace_back(clean_item); } } @@ -303,7 +303,7 @@ CNodePtr NewAtomicOp(const CNodePtr &pre_node, const std::vector &fu } void InsertFusionAtomicOp(const CNodePtr &first_clear_node, const std::vector &fusion_clear_inputs, - const std::vector &clean_size_list, CleanOpsMap *clean_ops) { + const std::vector &clean_size_list, CleanOpsMap *clean_ops) { MS_EXCEPTION_IF_NULL(first_clear_node); MS_EXCEPTION_IF_NULL(clean_ops); auto clear_zero = NewAtomicOp(first_clear_node, fusion_clear_inputs); @@ -355,7 +355,7 @@ void SpecialAkgOps(const std::string &op_name, const CNodePtr &node, CleanOpsMap void ProcessAtomicFusion(const std::vector &kernels, CleanOpsMap *clean_ops) { MS_EXCEPTION_IF_NULL(clean_ops); - std::vector clean_size_list; + std::vector clean_size_list; std::vector fusion_clear_inputs; CNodePtr first_node = nullptr; for (const auto &anf_node : kernels) { diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/device/tasksink/task_generator.cc b/mindspore/ccsrc/plugin/device/ascend/hal/device/tasksink/task_generator.cc index 1d35cd00b71..f458fb7c604 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/device/tasksink/task_generator.cc +++ b/mindspore/ccsrc/plugin/device/ascend/hal/device/tasksink/task_generator.cc @@ -153,7 +153,7 @@ void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressP MS_LOG(DEBUG) << "AtomicAddClean clean workspace size:" << clean_workspace_indexs.size(); } } - auto clear_mems = common::AnfAlgo::GetNodeAttr>(anf_node_ptr, kAttrAtomicAddMemSize); + auto clear_mems = common::AnfAlgo::GetNodeAttr>(anf_node_ptr, kAttrAtomicAddMemSize); if (kernel_inputs->size() != clear_mems.size()) { MS_LOG(EXCEPTION) << "AtomicAddClean kernel inputs size not equal clear memory size, kernel inputs size:" << kernel_inputs->size() << ",clean mem size" << clear_mems.size(); diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/rts/label_switch.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/rts/label_switch.cc index 540928cfc6e..ef444ee8286 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/rts/label_switch.cc +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/rts/label_switch.cc @@ -73,7 +73,7 @@ std::vector LabelSwitchKernel::GenTask(const std::vector> LabelSwitchDesc::GetKernelInfo(const CNodePtr &) { std::vector> label_switch_build_info{}; std::vector input_format{kOpFormat_DEFAULT}; - std::vector input_type{kNumberTypeInt32}; + std::vector input_type{kNumberTypeUInt64}; if (input_format.size() != input_type.size()) { MS_LOG(EXCEPTION) << "Invalid param num, input_format size " << input_format.size() << " input_type size " << input_type.size(); diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_json/tbe_json_creator.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_json/tbe_json_creator.cc index 02e0a90745e..f6ca203f2d7 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_json/tbe_json_creator.cc +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_json/tbe_json_creator.cc @@ -48,6 +48,7 @@ static std::unordered_map type_attr_dtype_map = { {kVTypeFloat, ATTR_DTYPE::ATTR_FLOAT32}, {kVTypeListInt, ATTR_DTYPE::ATTR_LIST_INT32}, {kVTypeListFloat, ATTR_DTYPE::ATTR_LIST_FLOAT32}, + {kVTypeListInt64, ATTR_DTYPE::ATTR_LIST_INT64}, {kVTypeListUInt64, ATTR_DTYPE::ATTR_LIST_UINT64}, {kVTypeListListInt, ATTR_DTYPE::ATTR_LIST_LIST_INT64}}; @@ -181,6 +182,7 @@ bool ParseAttrValue(const std::string &type, const mindspore::ValuePtr &value, n case ATTR_DTYPE::ATTR_FLOAT32: return ParseAttrFloat(value, attr_obj); case ATTR_DTYPE::ATTR_LIST_INT32: + case ATTR_DTYPE::ATTR_LIST_INT64: return ParseAttrListInt(value, attr_obj); case ATTR_DTYPE::ATTR_LIST_FLOAT32: return ParseAttrListFloat(value, attr_obj); @@ -232,7 +234,8 @@ bool ParseAttrDefaultValue(const std::string &type, const std::string &value, nl case ATTR_DTYPE::ATTR_FLOAT32: (*attr_obj)[kJValue] = std::stof(value); break; - case ATTR_DTYPE::ATTR_LIST_INT32: { + case ATTR_DTYPE::ATTR_LIST_INT32: + case ATTR_DTYPE::ATTR_LIST_INT64: { std::stringstream string_value(value); std::string list_elem; std::vector attrs_value; diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_json/tbe_json_utils.h b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_json/tbe_json_utils.h index 438beeac044..f7e4fde0a2b 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_json/tbe_json_utils.h +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_json/tbe_json_utils.h @@ -60,6 +60,7 @@ constexpr auto kVTypeFloat32 = "float32"; constexpr auto kVTypeListInt = "listInt"; constexpr auto kVTypeInt32 = "Int32"; constexpr auto kVTypeInt64 = "Int64"; +constexpr auto kVTypeListInt64 = "listInt64"; constexpr auto kVTypeListUInt64 = "listUInt64"; constexpr auto kVTypeListFloat = "listFloat"; constexpr auto kVTypeListListInt = "listListInt"; diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ascend_comm_op_reuse.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/ascend_comm_op_reuse.cc index 4558dd3ba77..3967f620062 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ascend_comm_op_reuse.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ascend_comm_op_reuse.cc @@ -40,7 +40,7 @@ std::string VecToString(const std::vector &vec) { return res; } -std::string GenCommOpKey(const CNodePtr &node) { +std::string GenCommOpKey(const CNodePtr &node, const KernelGraphPtr &root_graph) { std::string op_key; MS_EXCEPTION_IF_NULL(node); auto comm_prim = GetCNodePrimitive(node); @@ -68,6 +68,8 @@ std::string GenCommOpKey(const CNodePtr &node) { if (comm_prim->HasAttr(kAttrRecvRankIds)) { op_key += "_" + VecToString(GetValue>(comm_prim->GetAttr(kAttrRecvRankIds))); } + // model identifier, aka. root_graph_id + op_key += "_" + std::to_string(root_graph->root_graph_id()); MS_LOG(INFO) << node->DebugString() << " key " << op_key; return op_key; } @@ -198,7 +200,7 @@ void AscendCommOpReuse::AnalyseCommOpReuse() { if (!IsReusable(comm_op)) { continue; } - reuse_map[GenCommOpKey(comm_op)].push_back(comm_op); + reuse_map[GenCommOpKey(comm_op, root_graph_)].push_back(comm_op); } for (const auto &[key, comm_op_set] : reuse_map) { @@ -255,7 +257,7 @@ KernelGraphPtr AscendCommOpReuse::CreateCommSubGraph(const CNodePtr &comm_op) { MS_EXCEPTION_IF_NULL(new_comm_op); new_comm_op->set_abstract(comm_op->abstract()); - std::string group_name = GenCommOpKey(comm_op); + std::string group_name = GenCommOpKey(comm_op, root_graph_); auto rank_list = common::AnfAlgo::GetNodeAttr>(comm_op, kAttrRankList); if (!CommManager::GetInstance().CreateGroupSync(group_name, rank_list)) { MS_LOG(EXCEPTION) << "Create new group " << group_name << " failed, rank list = " << VecToString(rank_list); diff --git a/mindspore/python/mindspore/common/parameter.py b/mindspore/python/mindspore/common/parameter.py index 7295bdec16f..59845cb5232 100644 --- a/mindspore/python/mindspore/common/parameter.py +++ b/mindspore/python/mindspore/common/parameter.py @@ -710,7 +710,7 @@ class Parameter(Tensor_): raise TypeError("The argument 'layout' should be tuple, but got {}.".format(type(layout))) if len(layout) < 6: raise ValueError("The length of 'layout' must be larger than 5, but got {}.".format(len(layout))) - slice_index = int(_get_slice_index(layout[0], layout[1])) + slice_index = int(_get_slice_index(layout[0], layout[1], layout[5])) init_data_args += (slice_index, layout[2], layout[5]) return init_data_args diff --git a/mindspore/python/mindspore/common/tensor.py b/mindspore/python/mindspore/common/tensor.py index 38bc927d83b..8b76764159c 100644 --- a/mindspore/python/mindspore/common/tensor.py +++ b/mindspore/python/mindspore/common/tensor.py @@ -20,7 +20,7 @@ import math import numbers import numpy as np -from mindspore.communication.management import get_rank, get_group_size +from mindspore.communication.management import get_group_size from mindspore.common._utils import is_shape_unknown, is_stub_tensor from mindspore.common.seed import get_seed from mindspore import context @@ -2280,9 +2280,9 @@ class Tensor(Tensor_): self._np_seed = np.random.get_state()[1][0] self.need_set_seed = (slice_index is not None) self._global_seed = global_seed - self._device_num = 1 + self._seed_offset = 1 if self.need_set_seed: - self._device_num = get_group_size() + self._seed_offset = get_group_size() * 2 def __enter__(self): if self.need_set_seed: @@ -2293,7 +2293,7 @@ class Tensor(Tensor_): else: np.random.seed(slice_index + Tensor.delta_seed) self.init.seed = slice_index + Tensor.delta_seed - Tensor.delta_seed += self._device_num + Tensor.delta_seed += self._seed_offset def __exit__(self, ptype, value, trace): if self.need_set_seed: @@ -2302,10 +2302,6 @@ class Tensor(Tensor_): with seed_context(self.init): self.init(data) - if opt_shard_group: - rank = get_rank(opt_shard_group) - size = get_group_size(opt_shard_group) - data = np.split(data, size)[rank] self.init = None # At embedding cache scenes. When size of tensor is out of range, we store data to persistent storage diff --git a/mindspore/python/mindspore/mindrecord/filewriter.py b/mindspore/python/mindspore/mindrecord/filewriter.py index 72ecac6a2fd..b2e1a56e5ce 100644 --- a/mindspore/python/mindspore/mindrecord/filewriter.py +++ b/mindspore/python/mindspore/mindrecord/filewriter.py @@ -17,8 +17,11 @@ This module is to write data into mindrecord. """ import os import platform +import queue import re import stat +import time +import multiprocessing as mp import numpy as np from mindspore import log as logger from .shardwriter import ShardWriter @@ -26,7 +29,7 @@ from .shardreader import ShardReader from .shardheader import ShardHeader from .shardindexgenerator import ShardIndexGenerator from .shardutils import MIN_SHARD_COUNT, MAX_SHARD_COUNT, VALID_ATTRIBUTES, VALID_ARRAY_ATTRIBUTES, \ - check_filename, VALUE_TYPE_MAP + check_filename, VALUE_TYPE_MAP, SUCCESS from .common.exceptions import ParamValueError, ParamTypeError, MRMInvalidSchemaError, MRMDefineIndexError __all__ = ['FileWriter'] @@ -103,6 +106,13 @@ class FileWriter: self._writer = ShardWriter() self._generator = None + # parallel write mode + self._parallel_writer = None + self._writers = None + self._queue = None + self._workers = None + self._index_workers = None + @classmethod def open_for_append(cls, file_name): r""" @@ -259,22 +269,67 @@ class FileWriter: MRMSetHeaderError: If failed to set header. MRMWriteDatasetError: If failed to write dataset. """ - if not self._writer.is_open: - self._writer.open(self._paths, self._overwrite) - if not self._writer.get_shard_header(): - self._writer.set_shard_header(self._header) - if not isinstance(raw_data, list): - raise ParamTypeError('raw_data', 'list') - if self._flush and not self._append: - raise RuntimeError("Not allowed to call `write_raw_data` on flushed MindRecord files." \ - "When creating new Mindrecord files, please remove `commit` before `write_raw_data`." \ - "In other cases, when appending to existing MindRecord files, " \ - "please call `open_for_append` first and then `write_raw_data`.") - for each_raw in raw_data: - if not isinstance(each_raw, dict): - raise ParamTypeError('raw_data item', 'dict') - self._verify_based_on_schema(raw_data) - return self._writer.write_raw_data(raw_data, True, parallel_writer) + if self._parallel_writer is None: + self._parallel_writer = parallel_writer + if self._parallel_writer != parallel_writer: + raise RuntimeError("The parameter `parallel_writer` must be consistent during use.") + if not self._parallel_writer: + if not self._writer.is_open: + self._writer.open(self._paths, self._overwrite) + if not self._writer.get_shard_header(): + self._writer.set_shard_header(self._header) + if not isinstance(raw_data, list): + raise ParamTypeError('raw_data', 'list') + if self._flush and not self._append: + raise RuntimeError("Not allowed to call `write_raw_data` on flushed MindRecord files." \ + "When creating new Mindrecord files, please remove `commit` before " \ + "`write_raw_data`. In other cases, when appending to existing MindRecord files, " \ + "please call `open_for_append` first and then `write_raw_data`.") + for each_raw in raw_data: + if not isinstance(each_raw, dict): + raise ParamTypeError('raw_data item', 'dict') + self._verify_based_on_schema(raw_data) + return self._writer.write_raw_data(raw_data, True, parallel_writer) + + ## parallel write mode + # init the _writers and launch the workers + if self._writers is None: + self._writers = [None] * len(self._paths) # writers used by worker + self._queue = mp.Queue(len(self._paths) * 2) # queue for worker + self._workers = [None] * len(self._paths) # worker process + for i, path in enumerate(self._paths): + self._writers[i] = ShardWriter() + self._writers[i].open([path], self._overwrite) + self._writers[i].set_shard_header(self._header) + + # launch the workers for parallel write + self._queue._joincancelled = True # pylint: disable=W0212 + p = mp.Process(target=self._write_worker, args=(i, self._queue)) + p.daemon = True + p.start() + logger.info("Start worker process(pid:{}) to parallel write.".format(p.pid)) + self._workers[i] = p + + # fill the self._queue + check_interval = 0.5 # 0.5s + start_time = time.time() + while True: + try: + self._queue.put(raw_data, block=False) + except queue.Full: + if time.time() - start_time > check_interval: + start_time = time.time() + logger.warning("Because there are too few MindRecord file shards, the efficiency of parallel " \ + "writing is too low. You can stop the current task and add the parameter " \ + "`shard_num` of `FileWriter` to upgrade the task.") + + # check the status of worker process + for i in range(len(self._paths)): + if not self._workers[i].is_alive(): + raise RuntimeError("Worker process(pid:{}) has stopped. Please check " \ + "the above log".format(self._workers[i].pid)) + continue + return SUCCESS def set_header_size(self, header_size): """ @@ -326,7 +381,7 @@ class FileWriter: """ return self._writer.set_page_size(page_size) - def commit(self): + def commit(self): # pylint: disable=W0212 """ Flush data in memory to disk and generate the corresponding database files. @@ -343,24 +398,35 @@ class FileWriter: MRMGenerateIndexError: If failed to write to database. MRMCommitError: If failed to flush data to disk. """ - self._flush = True - if not self._writer.is_open: - self._writer.open(self._paths, self._overwrite) - # permit commit without data - if not self._writer.get_shard_header(): - self._writer.set_shard_header(self._header) - ret = self._writer.commit() - if self._index_generator: - if self._append: - self._generator = ShardIndexGenerator(self._file_name, self._append) - elif len(self._paths) >= 1: - self._generator = ShardIndexGenerator(os.path.realpath(self._paths[0]), self._append) - self._generator.build() - self._generator.write_to_db() + if not self._parallel_writer: + self._flush = True + if not self._writer.is_open: + self._writer.open(self._paths, self._overwrite) + # permit commit without data + if not self._writer.get_shard_header(): + self._writer.set_shard_header(self._header) + self._writer.commit() + if self._index_generator: + if self._append: + self._generator = ShardIndexGenerator(self._file_name, self._append) + elif len(self._paths) >= 1: + self._generator = ShardIndexGenerator(os.path.realpath(self._paths[0]), self._append) + self._generator.build() + self._generator.write_to_db() + else: + # maybe a empty mindrecord, so need check _writers + if self._writers is None: + self._writers = [None] * len(self._paths) + for i, path in enumerate(self._paths): + self._writers[i] = ShardWriter() + self._writers[i].open(path, self._overwrite) + self._writers[i].set_shard_header(self._header) + self._parallel_commit() + + # change the file mode to 600 mindrecord_files = [] index_files = [] - # change the file mode to 600 for item in self._paths: if os.path.exists(item): os.chmod(item, stat.S_IRUSR | stat.S_IWUSR) @@ -373,7 +439,62 @@ class FileWriter: logger.info("The list of mindrecord files created are: {}, and the list of index files are: {}".format( mindrecord_files, index_files)) - return ret + return SUCCESS + + def _index_worker(self, i): + """The worker do the index generator""" + generator = ShardIndexGenerator(os.path.realpath(self._paths[i]), False) + generator.build() + generator.write_to_db() + + def _parallel_commit(self): + """Parallel commit""" + # send EOF to worker process + for _ in range(len(self._paths)): + while True: + try: + self._queue.put("EOF", block=False) + except queue.Full: + time.sleep(1) + continue + break + + # wait the worker processing + while True: + if not self._queue.empty(): + logger.info("Waiting for worker process write done.") + time.sleep(1) + continue + break + + del self._queue + + # wait for worker process stop + for index in range(len(self._paths)): + while True: + logger.info("Waiting for the worker process(pid:{}) to process all the data.".format( + self._workers[index].pid)) + if self._workers[index].is_alive(): + time.sleep(1) + continue + elif self._workers[index].exitcode != 0: + raise RuntimeError("Worker process(pid:{}) has stopped abnormal. Please check " \ + "the above log".format(self._workers[index].pid)) + break + + if self._index_generator: + # use parallel index workers to generator index + self._index_workers = [None] * len(self._paths) + for index in range(len(self._paths)): + p = mp.Process(target=self._index_worker, args=(index,)) + p.daemon = True + p.start() + logger.info("Start worker process(pid:{}) to generate index.".format(p.pid)) + self._index_workers[index] = p + + # wait the index workers stop + for index in range(len(self._paths)): + self._index_workers[index].join() def _validate_array(self, k, v): """ @@ -487,3 +608,29 @@ class FileWriter: error = "Field '{}' should be dict.".format(k) return False, error return True, error + + def _write_worker(self, i, in_queue): + """The worker do the data check and write to disk for parallel mode""" + while True: + # try to get new raw_data from master + try: + raw_data = in_queue.get() + except queue.Empty: + continue + + # get EOF from master, worker should commit and stop + if raw_data == "EOF": + ret = self._writers[i].commit() + if ret != SUCCESS: + raise RuntimeError("Commit the {}th shard of MindRecord file failed.".format(index)) + break + + # check the raw_data + if not isinstance(raw_data, list): + raise ParamTypeError('raw_data', 'list') + for each_raw in raw_data: + if not isinstance(each_raw, dict): + raise ParamTypeError('raw_data item', 'dict') + + self._verify_based_on_schema(raw_data) + self._writers[i].write_raw_data(raw_data, True, False) diff --git a/mindspore/python/mindspore/mindrecord/shardwriter.py b/mindspore/python/mindspore/mindrecord/shardwriter.py index 92f5cb63a4f..b97b8d88732 100644 --- a/mindspore/python/mindspore/mindrecord/shardwriter.py +++ b/mindspore/python/mindspore/mindrecord/shardwriter.py @@ -173,7 +173,7 @@ class ShardWriter: for item in data: row_blob = self._merge_blob({field: item[field] for field in self._header.blob_fields}) if row_blob: - blob_data.append(list(row_blob)) + blob_data.append(row_blob) # filter raw data according to schema row_raw = {field: self.convert_np_types(item[field]) for field in self._header.schema.keys() - self._header.blob_fields if field in item} diff --git a/mindspore/python/mindspore/mindrecord/tools/imagenet_to_mr.py b/mindspore/python/mindspore/mindrecord/tools/imagenet_to_mr.py index 9de1f9458e5..aa54383455f 100644 --- a/mindspore/python/mindspore/mindrecord/tools/imagenet_to_mr.py +++ b/mindspore/python/mindspore/mindrecord/tools/imagenet_to_mr.py @@ -38,10 +38,12 @@ class ImageNetToMR: .. code-block:: - n02119789 0 - n02100735 1 - n02110185 2 - n02096294 3 + n01440764 0 + n01443537 1 + n01484850 2 + n01491361 3 + ... + n15075141 999 image_dir (str): Image directory contains n02119789, n02100735, n02110185 and n02096294 directory. destination (str): MindRecord file path to transform into, ensure that the directory is created in advance and @@ -108,11 +110,11 @@ class ImageNetToMR: for _ in range(batch_size): data_list.append(imagenet_iter.__next__()) transform_count += 1 - self.writer.write_raw_data(data_list) + self.writer.write_raw_data(data_list, True) logger.info("transformed {} record...".format(transform_count)) except StopIteration: if data_list: - self.writer.write_raw_data(data_list) + self.writer.write_raw_data(data_list, True) logger.info("transformed {} record...".format(transform_count)) break diff --git a/mindspore/python/mindspore/mindrecord/tools/tfrecord_to_mr.py b/mindspore/python/mindspore/mindrecord/tools/tfrecord_to_mr.py index 6d1c2dc26b6..783c005ec5f 100644 --- a/mindspore/python/mindspore/mindrecord/tools/tfrecord_to_mr.py +++ b/mindspore/python/mindspore/mindrecord/tools/tfrecord_to_mr.py @@ -307,11 +307,11 @@ class TFRecordToMR: data_list.append(tf_iter.__next__()) transform_count += 1 - writer.write_raw_data(data_list) + writer.write_raw_data(data_list, True) logger.info("Transformed {} records...".format(transform_count)) except StopIteration: if data_list: - writer.write_raw_data(data_list) + writer.write_raw_data(data_list, True) logger.info("Transformed {} records...".format(transform_count)) break return writer.commit() diff --git a/mindspore/python/mindspore/ops/_grad/grad_comm_ops.py b/mindspore/python/mindspore/ops/_grad/grad_comm_ops.py index 7700f7e918b..e1124aedc50 100644 --- a/mindspore/python/mindspore/ops/_grad/grad_comm_ops.py +++ b/mindspore/python/mindspore/ops/_grad/grad_comm_ops.py @@ -198,7 +198,6 @@ def get_bprop_mirror_micro_step_operator(self): assign.add_prim_attr("parameter_micro", 0) out_tensor = Tensor(1.0, mstype.float16) opt_shard = _get_enable_parallel_optimizer() - def bprop(x, z, out, dout): real_grad = z assign_out = dout @@ -207,16 +206,16 @@ def get_bprop_mirror_micro_step_operator(self): z = F.depend(z, dout) real_grad = all_reduce(z) real_grad = F.tensor_mul(real_grad, scale) - assign(z, real_grad) - assign_out = z + if opt_shard: + return (real_grad, cast(out_tensor, dtype(z))) + return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign(z, real_grad)) else: if issubclass_(F.typeof(dout), mstype.tensor): z = F.depend(z, dout) real_grad = all_reduce(z) - assign(z, real_grad) - assign_out = z - if opt_shard: - return (real_grad, cast(out_tensor, dtype(z))) + if opt_shard: + return (real_grad, cast(out_tensor, dtype(z))) + return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign(z, real_grad)) return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign_out) return bprop @@ -314,9 +313,18 @@ def get_bprop_micro_step_all_gather(self): cast = P.Cast() dtype = P.DType() out_tensor = Tensor(1.0, mstype.float16) + with_mirror_operator = self.get_attr_dict()["with_mirror_operator"] # z: accu_grad def bprop(x, z, out, dout): + if with_mirror_operator: + if not do_mirror: + return (dout, cast(out_tensor, dtype(z))) + real_grad = all_reduce(dout) + real_grad = split(real_grad)[rank] + if mean_flag: + real_grad = F.tensor_mul(real_grad, scale) + return (real_grad, cast(out_tensor, dtype(z))) z = F.depend(z, dout) if not do_mirror: return (z, cast(out_tensor, dtype(z))) diff --git a/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py index ae53bc3ae0d..886c834f5f0 100644 --- a/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py @@ -35,3 +35,5 @@ from .acos import _acos_tbe # Accuracy issues(task error in parallel) from .trans_data_ds import _trans_data_ds_tbe # support bool from .scatter_nd_d import _scatter_nd_d_tbe # in python no check supported from .assign_add_ds import _assign_add_ds_tbe # "Frac_nz in pangu not support" +from .assign import _assign_tbe # Different formats of assign inputs cause memory to increase +from .atomic_addr_clean import _atomic_addr_clean_tbe # need to clean addr larger than 2G, int32 is not enough diff --git a/mindspore/python/mindspore/ops/_op_impl/tbe/atomic_addr_clean.py b/mindspore/python/mindspore/ops/_op_impl/tbe/atomic_addr_clean.py index e707a1f26f7..909142f5e69 100644 --- a/mindspore/python/mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +++ b/mindspore/python/mindspore/ops/_op_impl/tbe/atomic_addr_clean.py @@ -23,7 +23,7 @@ atomic_addr_clean_op_info = TBERegOp("AtomicAddrClean") \ .compute_cost(10) \ .kernel_name("atomic_addr_clean") \ .partial_flag(True) \ - .attr("automic_add_mem_size", "required", "listInt", "all") \ + .attr("automic_add_mem_size", "required", "listInt64", "all") \ .get_op_info() diff --git a/mindspore/python/mindspore/ops/operations/comm_ops.py b/mindspore/python/mindspore/ops/operations/comm_ops.py index b522bddf1ee..f7d1be7d6d8 100644 --- a/mindspore/python/mindspore/ops/operations/comm_ops.py +++ b/mindspore/python/mindspore/ops/operations/comm_ops.py @@ -1186,6 +1186,7 @@ class _MirrorMicroStepOperator(PrimitiveWithInfer): self.dev_num = dev_num self.mean_flag = mean_flag self.add_prim_attr('order_enforce_skip', True) + self.add_prim_attr('side_effect_backprop_mem', True) def infer_shape(self, x_shape, z_shape): return x_shape diff --git a/mindspore/python/mindspore/parallel/_tensor.py b/mindspore/python/mindspore/parallel/_tensor.py index 12a96b9659f..072e7ffd36f 100644 --- a/mindspore/python/mindspore/parallel/_tensor.py +++ b/mindspore/python/mindspore/parallel/_tensor.py @@ -175,20 +175,26 @@ def _chunk_tensor_by_strategy(np_tensor, strategy): return _chunk_tensor(np_tensor, strategy, len(strategy)) -def _get_slice_index(dev_mat, tensor_map): +def _get_slice_index(dev_mat, tensor_map, opt_shard_group): """ Get the slice index for current slice. Args: dev_mat (list): The device matrix of devices. tensor_map (list): The split strategy of tensor. + opt_shard_group(string): The group of optimizer shard Returns: Integer, the slice index for slice on this device. """ rank = get_rank() + dev_num = get_group_size() tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map) tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank) + if opt_shard_group: + tensor_slice_index += dev_num + opt_rank = get_rank(opt_shard_group) + tensor_slice_index += opt_rank return tensor_slice_index diff --git a/tests/ut/python/mindrecord/test_imagenet_to_mindrecord.py b/tests/ut/python/mindrecord/test_imagenet_to_mindrecord.py index cc42239712f..f9733d470fc 100644 --- a/tests/ut/python/mindrecord/test_imagenet_to_mindrecord.py +++ b/tests/ut/python/mindrecord/test_imagenet_to_mindrecord.py @@ -67,7 +67,10 @@ def test_imagenet_to_mindrecord(fixture_file): for i in range(PARTITION_NUMBER): assert os.path.exists(file_name + str(i)) assert os.path.exists(file_name + str(i) + ".db") - read(file_name + "0") + read([file_name + "0", + file_name + "1", + file_name + "2", + file_name + "3"]) def test_imagenet_to_mindrecord_default_partition_number(fixture_file): """ @@ -76,7 +79,7 @@ def test_imagenet_to_mindrecord_default_partition_number(fixture_file): """ file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0] imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE, IMAGENET_IMAGE_DIR, - file_name) + file_name, 1) imagenet_transformer.transform() assert os.path.exists(file_name) assert os.path.exists(file_name + ".db") diff --git a/tests/ut/python/mindrecord/test_mindrecord_base.py b/tests/ut/python/mindrecord/test_mindrecord_base.py index 93278d1eb6d..75222a80fb3 100644 --- a/tests/ut/python/mindrecord/test_mindrecord_base.py +++ b/tests/ut/python/mindrecord/test_mindrecord_base.py @@ -1270,3 +1270,106 @@ def test_cv_file_overwrite_exception_02(): writer.write_raw_data(data) assert 'Invalid file, mindrecord files already exist. Please check file path:' in str(err.value) remove_multi_files(mindrecord_file_name, FILES_NUM) + + +def test_file_writer_parallel(file_name=None, remove_file=True): + """ + Feature: FileWriter + Description: parallel for writer + Expectation: generated mindrecord file + """ + if not file_name: + file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0] + + # single file + remove_one_file(file_name) + remove_one_file(file_name + ".db") + writer = FileWriter(file_name) + data = get_data("../data/mindrecord/testImageNetData/") + cv_schema_json = {"file_name": {"type": "string"}, + "label": {"type": "int64"}, "data": {"type": "bytes"}} + writer.add_schema(cv_schema_json, "img_schema") + writer.add_index(["file_name", "label"]) + for _ in range(5): + writer.write_raw_data(data, True) + writer.commit() + if remove_file: + remove_one_file(file_name) + remove_one_file(file_name + ".db") + + # write_raw_data with empty + remove_one_file(file_name) + remove_one_file(file_name + ".db") + writer = FileWriter(file_name) + data = get_data("../data/mindrecord/testImageNetData/") + cv_schema_json = {"file_name": {"type": "string"}, + "label": {"type": "int64"}, "data": {"type": "bytes"}} + writer.add_schema(cv_schema_json, "img_schema") + writer.add_index(["file_name", "label"]) + with pytest.raises(RuntimeError): + writer.write_raw_data([]) + + # multi files + # len(data) > FILES_NUM which is parallel size + remove_multi_files(file_name, FILES_NUM) + writer = FileWriter(file_name, FILES_NUM) + data = get_data("../data/mindrecord/testImageNetData/") + cv_schema_json = {"file_name": {"type": "string"}, + "label": {"type": "int64"}, "data": {"type": "bytes"}} + writer.add_schema(cv_schema_json, "img_schema") + writer.add_index(["file_name", "label"]) + for _ in range(10): + writer.write_raw_data(data, True) + writer.commit() + if remove_file: + remove_multi_files(file_name, FILES_NUM) + + # len(data) < FILES_NUM which is parallel size + remove_multi_files(file_name, FILES_NUM) + writer = FileWriter(file_name, FILES_NUM) + data = get_data("../data/mindrecord/testImageNetData/") + cv_schema_json = {"file_name": {"type": "string"}, + "label": {"type": "int64"}, "data": {"type": "bytes"}} + writer.add_schema(cv_schema_json, "img_schema") + writer.add_index(["file_name", "label"]) + for _ in range(2): + writer.write_raw_data(data[0:2], True) + writer.commit() + if remove_file: + remove_multi_files(file_name, FILES_NUM) + + # write_raw_data(.., True) and write_raw_data(.., False) + remove_multi_files(file_name, FILES_NUM) + writer = FileWriter(file_name, FILES_NUM) + data = get_data("../data/mindrecord/testImageNetData/") + cv_schema_json = {"file_name": {"type": "string"}, + "label": {"type": "int64"}, "data": {"type": "bytes"}} + writer.add_schema(cv_schema_json, "img_schema") + writer.add_index(["file_name", "label"]) + with pytest.raises(RuntimeError): + writer.write_raw_data(data[0:2], True) + writer.write_raw_data(data[0:2]) + + # without write_raw_data + remove_multi_files(file_name, FILES_NUM) + writer = FileWriter(file_name, FILES_NUM) + data = get_data("../data/mindrecord/testImageNetData/") + cv_schema_json = {"file_name": {"type": "string"}, + "label": {"type": "int64"}, "data": {"type": "bytes"}} + writer.add_schema(cv_schema_json, "img_schema") + writer.add_index(["file_name", "label"]) + writer.commit() + if remove_file: + remove_multi_files(file_name, FILES_NUM) + + # write_raw_data with empty + remove_multi_files(file_name, FILES_NUM) + writer = FileWriter(file_name, FILES_NUM) + data = get_data("../data/mindrecord/testImageNetData/") + cv_schema_json = {"file_name": {"type": "string"}, + "label": {"type": "int64"}, "data": {"type": "bytes"}} + writer.add_schema(cv_schema_json, "img_schema") + writer.add_index(["file_name", "label"]) + with pytest.raises(RuntimeError): + writer.write_raw_data([], True) + writer.commit() diff --git a/tests/ut/python/parallel/test_parallel_optimizer_without_grad.py b/tests/ut/python/parallel/test_parallel_optimizer_without_grad.py new file mode 100644 index 00000000000..0e6d792bb47 --- /dev/null +++ b/tests/ut/python/parallel/test_parallel_optimizer_without_grad.py @@ -0,0 +1,133 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + + +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor, Parameter +from mindspore import context, Model +from mindspore.common.api import _cell_graph_executor +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from mindspore.nn.wrap.cell_wrapper import PipelineCell +from tests.ut.python.ops.test_math_ops import VirtualLoss +from parallel.utils.utils import ParallelValidator +from .test_pipeline_split import DatasetLenet + + +def setup_function(): + context.set_auto_parallel_context(dataset_strategy="full_batch") + + +grad_all = C.GradOperation(get_all=True) + + +class NetWithLoss(nn.Cell): + def __init__(self, network): + super(NetWithLoss, self).__init__() + self.loss = VirtualLoss() + self.network = network + + def construct(self, x, y): + predict = self.network(x, y) + return self.loss(predict) + + +class GradWrap(nn.Cell): + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + + def construct(self, x, y): + return grad_all(self.network)(x, y) + + +def test_opt_parallel_without_grad(): + """ + Feature: Test optimizer parallel with parameter's requires_grad=False. + Description: Need insert AllGather. + Expectation: Successful graph compilation. + """ + class Net(nn.Cell): + def __init__(self): + super().__init__() + self.fc1 = P.MatMul().shard(((4, 1), (1, 2))) + self.fc2 = P.MatMul().shard(((2, 2), (2, 1))) + self.p1 = Parameter(Tensor(np.ones([1024, 1024]).astype(np.float32)), name="weight1", requires_grad=False) + self.p2 = Parameter(Tensor(np.ones([1024, 64]).astype(np.float32)), name="weight2") + + def construct(self, x, y): + x = self.fc1(x, self.p1) + x = self.fc2(x, self.p2) + return x - y + context.reset_auto_parallel_context() + context.set_auto_parallel_context(device_num=8, global_rank=0, enable_parallel_optimizer=True) + net = GradWrap(NetWithLoss(Net())) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + + x = Tensor(np.ones([128, 1024]), dtype=ms.float32) + y = Tensor(np.ones([128, 64]), dtype=ms.float32) + net.set_train() + phase, _ = _cell_graph_executor.compile(net, x, y) + validator = ParallelValidator(net, phase) + expect_layout = ([4, 2], [-1, 0], [1024, 512], 0, True, '4-5226697808808137312') + assert validator.check_parameter_layout("network.network.p1", expect_layout) + + +def test_opt_parallel_without_grad_pipeline(): + """ + Feature: Test optimizer parallel + pipeline with parameter's requires_grad=False. + Description: Need insert AllGather. + Expectation: Successful graph compilation. + """ + class MatMulNet(nn.Cell): + def __init__(self): + super().__init__() + self.fc1 = P.MatMul().shard(((4, 1), (1, 2))) + self.fc2 = P.MatMul().shard(((2, 2), (2, 1))) + self.p1 = Parameter(Tensor(np.ones([1024, 1024]).astype(np.float32)), name="weight1", requires_grad=False) + self.p2 = Parameter(Tensor(np.ones([1024, 1024]).astype(np.float32)), name="weight2") + + def construct(self, x): + x = self.fc1(x, self.p1) + x = self.fc2(x, self.p2) + return x + + class Net(nn.Cell): + def __init__(self): + super().__init__() + self.block = nn.CellList() + for i in range(2): + cell = MatMulNet() + cell.pipeline_stage = i + self.block.append(cell) + + def construct(self, x, y): + for i in range(2): + x = self.block[i](x) + return x + context.reset_auto_parallel_context() + context.set_auto_parallel_context(device_num=16, global_rank=0, enable_parallel_optimizer=True) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", pipeline_stages=2) + net = PipelineCell(Net(), 4) + x = Tensor(np.ones([128, 1024]), dtype=ms.float32) + y = Tensor(np.ones([128, 128]), dtype=ms.float32) + dataset = DatasetLenet(x, y, 3) + optimizer = nn.Lamb(net.trainable_params(), learning_rate=0.01) + model = Model(net, optimizer=optimizer) + model.train(2, dataset, dataset_sink_mode=False) + assert net.network.block[0].p1.shape == (256, 512)