diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc index eb3c9900f8..6c626e17e7 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc @@ -455,6 +455,9 @@ Status GatherV2PInfo::InferForwardCommunication() { MS_LOG(ERROR) << name_ << ": Infer Group failed."; return FAILED; } + if (group_.name().empty()) { + return SUCCESS; + } attr_group = std::make_pair(GROUP, MakeValue(group_.name())); Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM)); OperatorAttrs attrs = {attr_op, attr_group}; @@ -472,7 +475,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) { MS_LOG(ERROR) << "GenerateGraph Init failed"; return FAILED; } - if (manual_split_) { + if (manual_split_ && target_ != CPU) { if (InferOffset() != SUCCESS) { MS_LOG(ERROR) << name_ << ": Infer Bias failed."; return FAILED; @@ -519,7 +522,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) { } ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) { - if (manual_split_) { + if (manual_split_ && target_ != CPU) { if (ComputeReplaceGraph(cnode) != SUCCESS) { MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; return nullptr; @@ -540,13 +543,24 @@ ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) { } Status GatherV2PInfo::ComputeReplaceOp() { - if (InferBias() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Infer offset failed."; - return FAILED; + int32_t bias = 0; + if (manual_split_) { + if (InferOffset() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer offset failed."; + return FAILED; + } + bias = index_offset_; + } else { + if (InferBias() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer offset failed."; + return FAILED; + } + bias = bias_; } + OperatorName op_name = EMBEDDING_LOOKUP; OperatorAttrs attrs; - Attr param_offset = std::make_pair("offset", MakeValue(bias_)); + Attr param_offset = std::make_pair("offset", MakeValue(bias)); OperatorParams params = {std::make_pair(param_offset, 3)}; OperatorArgs args = std::make_pair(attrs, params); Operator op = std::make_pair(op_name, args); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h index 5c7ae10eb2..ed5b4e527e 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h @@ -56,7 +56,6 @@ class GatherV2PInfo : public OperatorInfo { Status InferTensorMap() override; Status GetAttrs() override; - private: Status ComputeReplaceGraph(const CNodePtr &cnode); Status CheckManualSplit(); Status ComputeReplaceOp();