diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
index e2d01fb7795..3d9470e7d8e 100644
--- a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
+++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
@@ -44,6 +44,18 @@ Status GatherV2PInfo::GetAttrs() {
   }
   axis_ = axis;
 
+  // get target
+  auto target_iter = attrs_.find(TARGET);
+  if (target_iter != attrs_.end()) {
+    MS_EXCEPTION_IF_NULL(target_iter->second);
+    if (target_iter->second->isa<StringImm>()) {
+      target_ = target_iter->second->cast<StringImmPtr>()->value();
+    } else {
+      MS_LOG(ERROR) << name_ << " : The value of target is not a string.";
+      return FAILED;
+    }
+  }
+
   return SUCCESS;
 }
 
@@ -61,8 +73,8 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
   auto param_shape = inputs_shape_.at(0);
   auto param_strategy = strategy->GetInputDim().at(0);
   auto slice_shape = param_shape.at(param_shape.size() - 1) / param_strategy.at(param_strategy.size() - 1);
-  if (slice_shape % 8 != 0) {
-    MS_LOG(ERROR) << name_ << ": Last dim of param slice shape need 32Byte aligned.";
+  if (slice_shape % 8 != 0 && slice_shape != 1) {
+    MS_LOG(DEBUG) << name_ << ": Last dim of param slice shape need 32Byte aligned.";
     return FAILED;
   }
 
@@ -74,20 +86,20 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
 
   // don't support scalar index
   if (inputs_shape_.at(1).size() == 0) {
-    MS_LOG(ERROR) << name_ << ": Don't support scalar index.";
+    MS_LOG(DEBUG) << name_ << ": Don't support scalar index.";
     return FAILED;
   }
 
   // axis=0, index_shape(0)%param_strategy(0) must be 0
   Shape index_shape = inputs_shape_.at(1);
   if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0)) {
-    MS_LOG(ERROR) << name_ << ": index_shape(0) can't be divided by param_strategy(0).";
+    MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by param_strategy(0).";
     return FAILED;
   }
 
   // axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0
   if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(IntToSize(axis_))) != 0) {
-    MS_LOG(ERROR) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
+    MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
     return FAILED;
   }
 
@@ -95,7 +107,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
   auto index_strategy = strategy->GetInputDim().at(1);
   auto product_i = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int>());
   if ((param_strategy.at(IntToSize(axis_)) != 1) && (product_i != 1)) {
-    MS_LOG(ERROR) << name_ << ": param is splited at dim (axis)" << axis_ << " ,index can't be splited.";
+    MS_LOG(DEBUG) << name_ << ": param is splited at dim (axis)" << axis_ << " ,index can't be splited.";
     return FAILED;
   }
 
@@ -104,7 +116,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
   size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
   auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
   if (IntToSize(product_p) != dev_num && param_strategy.at(IntToSize(axis_)) != 1) {
-    MS_LOG(ERROR) << name_ << ": Invalid strategy. Don't support repeated calc.";
+    MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc.";
     return FAILED;
   }
 
@@ -290,18 +302,85 @@ Status GatherV2PInfo::InferBias() {
 }
 
 Status GatherV2PInfo::InferGroup() {
-  std::vector<Group> group_list;
   auto param_strategy = strategy_->GetInputDim().at(0);
   size_t dim = IntToSize(axis_);
   if (param_strategy.at(IntToSize(axis_)) != 1 && inputs_shape_.at(0).size() == 2) {
     dim = (axis_ + 1) % 2;
   }
-  if (CreateGroupByDim(dim, &group_list) != SUCCESS) {
+  CheckGlobalDeviceManager();
+  MS_EXCEPTION_IF_NULL(g_device_manager);
+  int32_t rank = g_device_manager->global_rank();
+  RankList dev_list = g_device_manager->GetDeviceListByStageId(0);
+  DeviceMatrix dev_matrix(rank, dev_list, dev_matrix_shape_);
+  RankList group_devices;
+  if (dev_matrix.GetDevicesAlongDim(SizeToUint(dim), &group_devices) != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": Create group failed.";
     return FAILED;
   }
+  if (group_devices.size() == 1) {
+    MS_LOG(INFO) << "the group is empty";
+    return SUCCESS;
+  }
 
-  group_ = group_list.at(0);
+  group_ = g_device_manager->CreateGroup(group_devices);
+  return SUCCESS;
+}
+
+std::vector<int32_t> GetRankFromGroup(const Group &group) {
+  std::vector<int32_t> rank_list;
+  auto device_list = group.GetDevicesList();
+  for (auto &device : device_list) {
+    rank_list.insert(rank_list.end(), device.rank() % 8);
+  }
+  return rank_list;
+}
+
+Status GatherV2PInfo::InferForwardCommunication() {
+  forward_op_.clear();
+  if (target_ != CPU) {
+    return SUCCESS;
+  }
+  auto param_strategy = strategy_->GetInputDim().at(0);
+  // don't split axis, no need forward communication
+  if (param_strategy.at(IntToSize(axis_)) == 1) {
+    return SUCCESS;
+  }
+  // split axis
+  OperatorName operator_name;
+  if (InferGroup() != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Infer Group failed.";
+    return FAILED;
+  }
+  auto group_size = group_.GetDevNum();
+  Attr attr_group;
+  // group size <= 8
+  std::vector<int32_t> rank_list;
+  if (group_size <= 8) {
+    reduce_scatter_flag_ = false;
+    operator_name = HOST_REDUCE_SCATTER;
+    rank_list = GetRankFromGroup(group_);
+    attr_group = std::make_pair(GROUP, MakeValue(rank_list));
+  } else {
+    // group size > 8
+    reduce_scatter_flag_ = true;
+    split_num_ = SizeToInt(group_size / 8);
+    CheckGlobalDeviceManager();
+    operator_name = REDUCE_SCATTER;
+    int32_t rank = g_device_manager->global_rank();
+    size_t repeat = group_size / 8;
+    for (size_t i = 0; i < repeat; ++i) {
+      rank_list.push_back(rank + SizeToInt(i * 8));
+    }
+    Group g = g_device_manager->CreateGroup(rank_list);
+    attr_group = std::make_pair(GROUP, MakeValue(g.name()));
+  }
+  Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
+  OperatorAttrs attrs = {attr_op, attr_group};
+  OperatorParams params;
+  OperatorArgs args = std::make_pair(attrs, params);
+  Operator op = std::make_pair(operator_name, args);
+
+  forward_op_.push_back(op);
   return SUCCESS;
 }
 
@@ -346,6 +425,10 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
 
 ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
   auto param_strategy = strategy_->GetInputDim().at(0);
+  // target_ == CPU, no need to replace graph
+  if (target_ == CPU) {
+    return nullptr;
+  }
   if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode) != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
     return nullptr;
@@ -353,11 +436,34 @@ ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
   return replace_graph_;
 }
 
+Status GatherV2PInfo::ComputeReplaceOp() {
+  if (InferBias() != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Infer offset failed.";
+    return FAILED;
+  }
+  OperatorName op_name = EMBEDDING_LOOKUP;
+  OperatorAttrs attrs;
+  Attr param_offset = std::make_pair("offset", MakeValue(bias_));
+  Attr param_flag = std::make_pair("reduce_scatter_flag", MakeValue(reduce_scatter_flag_));
+  Attr param_split_num = std::make_pair("split_num", MakeValue(split_num_));
+  OperatorParams params = {std::make_pair(param_offset, 4), std::make_pair(param_flag, 5),
+                           std::make_pair(param_split_num, 6)};
+  OperatorArgs args = std::make_pair(attrs, params);
+  Operator op = std::make_pair(op_name, args);
+  replace_op_.push_back(op);
+
+  return SUCCESS;
+}
+
 Status GatherV2PInfo::Init(const StrategyPtr &strategy) {
   if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": Init failed.";
     return FAILED;
   }
+  // only if target_ == CPU, we need to replace op
+  if (target_ == CPU && ComputeReplaceOp() != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": ComputeReplaceOp failed.";
+  }
   MS_LOG(INFO) << name_ << ": Init success.";
   return SUCCESS;
 }
diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
index a87b9838c9c..22aff16b493 100644
--- a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
@@ -49,7 +49,7 @@ class GatherV2PInfo : public OperatorInfo {
  protected:
   Status CheckStrategy(const StrategyPtr &strategy) override;
   Status InferMirrorOps() override;
-  Status InferForwardCommunication() override { return SUCCESS; }
+  Status InferForwardCommunication() override;
   Status InferTensorInfo() override;
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;
@@ -57,14 +57,18 @@ class GatherV2PInfo : public OperatorInfo {
 
  private:
   Status ComputeReplaceGraph(const CNodePtr &cnode);
+  Status ComputeReplaceOp();
   Status InferBias();
   Status InferGroup();
 
   int32_t axis_;
+  std::string target_;
   int32_t bias_;
   int32_t slice_size_;
   Shape out_dev_matrix_shape_;
   Group group_;
+  bool reduce_scatter_flag_ = false;
+  int32_t split_num_ = 1;
 };
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/parallel/ops_info/ops_utils.h
index 4da54a358d5..d0c874bb6f8 100644
--- a/mindspore/ccsrc/parallel/ops_info/ops_utils.h
+++ b/mindspore/ccsrc/parallel/ops_info/ops_utils.h
@@ -76,6 +76,8 @@ constexpr char DEPEND[] = "depend";
 constexpr char BATCH_PARALLEL[] = "BatchParallel";
 
 constexpr char ACTIVATION_TYPE[] = "activation_type";
+constexpr char TARGET[] = "target";
+constexpr char CPU[] = "CPU";
 constexpr char TRANSPOSE_A[] = "transpose_a";
 constexpr char TRANSPOSE_B[] = "transpose_b";
 constexpr char SHAPE[] = "shape";
@@ -141,6 +143,8 @@ constexpr char MIRROR_OPERATOR[] = "_MirrorOperator";
 constexpr char STRIDED_SLICE[] = "StridedSlice";
 constexpr char ALL_GATHER[] = "AllGather";
 constexpr char REDUCE_SCATTER[] = "ReduceScatter";
+constexpr char HOST_REDUCE_SCATTER[] = "HostReduceScatter";
+constexpr char EMBEDDING_LOOKUP[] = "EmbeddingLookup";
 constexpr char CONCAT[] = "Concat";
 constexpr char SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SoftmaxCrossEntropyWithLogits";
 constexpr char SIGMOID_CROSS_ENTROPY_WITH_LOGITS[] = "SigmoidCrossEntropyWithLogits";
diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc
index fd09b5e0b57..d11d78d9bd4 100644
--- a/mindspore/ccsrc/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_parallel.cc
@@ -534,6 +534,10 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
     MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2";
   }
   std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
+  auto prim = GetValueNode<PrimitivePtr>(node->input(0));
+  if (prim->name() == GATHERV2) {
+    replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2), node->input(3)};
+  }
   if (!params.empty()) {
     Param param_first = *(params.begin());
     int32_t first_position = param_first.second;
diff --git a/tests/ut/python/parallel/test_gather_v2.py b/tests/ut/python/parallel/test_gather_v2.py
index 2720cb33e17..c295bf93abe 100644
--- a/tests/ut/python/parallel/test_gather_v2.py
+++ b/tests/ut/python/parallel/test_gather_v2.py
@@ -182,3 +182,39 @@ def test_gatherv2_auto1():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
     _executor.compile(net, x, y)
+
+
+def test_gatherv2_cpu0():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((8, 1), (1, 1))
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = NetWithLoss(Net(0, strategy1, strategy2))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_cpu1():
+    context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((16, 1), (1, 1))
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = NetWithLoss(Net(0, strategy1, strategy2))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)
+
+
+def test_gatherv2_cpu2():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((1, 8), (1, 1))
+    strategy2 = ((4, 2, 1), (4, 2, 1))
+    net = NetWithLoss(Net(0, strategy1, strategy2))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
+    _executor.compile(net, x, y)