!3859 embeddinglookup support cpu in auto parallel

Merge pull request !3859 from yao_yf/embeddinglookup_parallel_ops_host_device
This commit is contained in:
mindspore-ci-bot 2020-08-03 15:51:44 +08:00 committed by Gitee
commit 8ff7c0b640
2 changed files with 20 additions and 7 deletions

View File

@ -455,6 +455,9 @@ Status GatherV2PInfo::InferForwardCommunication() {
MS_LOG(ERROR) << name_ << ": Infer Group failed."; MS_LOG(ERROR) << name_ << ": Infer Group failed.";
return FAILED; return FAILED;
} }
if (group_.name().empty()) {
return SUCCESS;
}
attr_group = std::make_pair(GROUP, MakeValue(group_.name())); attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM)); Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
OperatorAttrs attrs = {attr_op, attr_group}; OperatorAttrs attrs = {attr_op, attr_group};
@ -472,7 +475,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
MS_LOG(ERROR) << "GenerateGraph Init failed"; MS_LOG(ERROR) << "GenerateGraph Init failed";
return FAILED; return FAILED;
} }
if (manual_split_) { if (manual_split_ && target_ != CPU) {
if (InferOffset() != SUCCESS) { if (InferOffset() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer Bias failed."; MS_LOG(ERROR) << name_ << ": Infer Bias failed.";
return FAILED; return FAILED;
@ -519,7 +522,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
} }
ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) { ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
if (manual_split_) { if (manual_split_ && target_ != CPU) {
if (ComputeReplaceGraph(cnode) != SUCCESS) { if (ComputeReplaceGraph(cnode) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
return nullptr; return nullptr;
@ -540,13 +543,24 @@ ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
} }
Status GatherV2PInfo::ComputeReplaceOp() { Status GatherV2PInfo::ComputeReplaceOp() {
if (InferBias() != SUCCESS) { int32_t bias = 0;
MS_LOG(ERROR) << name_ << ": Infer offset failed."; if (manual_split_) {
return FAILED; if (InferOffset() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer offset failed.";
return FAILED;
}
bias = index_offset_;
} else {
if (InferBias() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer offset failed.";
return FAILED;
}
bias = bias_;
} }
OperatorName op_name = EMBEDDING_LOOKUP; OperatorName op_name = EMBEDDING_LOOKUP;
OperatorAttrs attrs; OperatorAttrs attrs;
Attr param_offset = std::make_pair("offset", MakeValue(bias_)); Attr param_offset = std::make_pair("offset", MakeValue(bias));
OperatorParams params = {std::make_pair(param_offset, 3)}; OperatorParams params = {std::make_pair(param_offset, 3)};
OperatorArgs args = std::make_pair(attrs, params); OperatorArgs args = std::make_pair(attrs, params);
Operator op = std::make_pair(op_name, args); Operator op = std::make_pair(op_name, args);

View File

@ -56,7 +56,6 @@ class GatherV2PInfo : public OperatorInfo {
Status InferTensorMap() override; Status InferTensorMap() override;
Status GetAttrs() override; Status GetAttrs() override;
private:
Status ComputeReplaceGraph(const CNodePtr &cnode); Status ComputeReplaceGraph(const CNodePtr &cnode);
Status CheckManualSplit(); Status CheckManualSplit();
Status ComputeReplaceOp(); Status ComputeReplaceOp();