forked from OSSInnovation/mindspore
!3859 embeddinglookup support cpu in auto parallel
Merge pull request !3859 from yao_yf/embeddinglookup_parallel_ops_host_device
This commit is contained in:
commit
8ff7c0b640
|
@ -455,6 +455,9 @@ Status GatherV2PInfo::InferForwardCommunication() {
|
||||||
MS_LOG(ERROR) << name_ << ": Infer Group failed.";
|
MS_LOG(ERROR) << name_ << ": Infer Group failed.";
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
|
if (group_.name().empty()) {
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
|
attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
|
||||||
Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
|
Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
|
||||||
OperatorAttrs attrs = {attr_op, attr_group};
|
OperatorAttrs attrs = {attr_op, attr_group};
|
||||||
|
@ -472,7 +475,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
||||||
MS_LOG(ERROR) << "GenerateGraph Init failed";
|
MS_LOG(ERROR) << "GenerateGraph Init failed";
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
if (manual_split_) {
|
if (manual_split_ && target_ != CPU) {
|
||||||
if (InferOffset() != SUCCESS) {
|
if (InferOffset() != SUCCESS) {
|
||||||
MS_LOG(ERROR) << name_ << ": Infer Bias failed.";
|
MS_LOG(ERROR) << name_ << ": Infer Bias failed.";
|
||||||
return FAILED;
|
return FAILED;
|
||||||
|
@ -519,7 +522,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
|
ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
|
||||||
if (manual_split_) {
|
if (manual_split_ && target_ != CPU) {
|
||||||
if (ComputeReplaceGraph(cnode) != SUCCESS) {
|
if (ComputeReplaceGraph(cnode) != SUCCESS) {
|
||||||
MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
|
MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -540,13 +543,24 @@ ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GatherV2PInfo::ComputeReplaceOp() {
|
Status GatherV2PInfo::ComputeReplaceOp() {
|
||||||
if (InferBias() != SUCCESS) {
|
int32_t bias = 0;
|
||||||
MS_LOG(ERROR) << name_ << ": Infer offset failed.";
|
if (manual_split_) {
|
||||||
return FAILED;
|
if (InferOffset() != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Infer offset failed.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
bias = index_offset_;
|
||||||
|
} else {
|
||||||
|
if (InferBias() != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << name_ << ": Infer offset failed.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
bias = bias_;
|
||||||
}
|
}
|
||||||
|
|
||||||
OperatorName op_name = EMBEDDING_LOOKUP;
|
OperatorName op_name = EMBEDDING_LOOKUP;
|
||||||
OperatorAttrs attrs;
|
OperatorAttrs attrs;
|
||||||
Attr param_offset = std::make_pair("offset", MakeValue(bias_));
|
Attr param_offset = std::make_pair("offset", MakeValue(bias));
|
||||||
OperatorParams params = {std::make_pair(param_offset, 3)};
|
OperatorParams params = {std::make_pair(param_offset, 3)};
|
||||||
OperatorArgs args = std::make_pair(attrs, params);
|
OperatorArgs args = std::make_pair(attrs, params);
|
||||||
Operator op = std::make_pair(op_name, args);
|
Operator op = std::make_pair(op_name, args);
|
||||||
|
|
|
@ -56,7 +56,6 @@ class GatherV2PInfo : public OperatorInfo {
|
||||||
Status InferTensorMap() override;
|
Status InferTensorMap() override;
|
||||||
Status GetAttrs() override;
|
Status GetAttrs() override;
|
||||||
|
|
||||||
private:
|
|
||||||
Status ComputeReplaceGraph(const CNodePtr &cnode);
|
Status ComputeReplaceGraph(const CNodePtr &cnode);
|
||||||
Status CheckManualSplit();
|
Status CheckManualSplit();
|
||||||
Status ComputeReplaceOp();
|
Status ComputeReplaceOp();
|
||||||
|
|
Loading…
Reference in New Issue