!2960 Check attr existing before getting it in embeddinglookup cpu kernel

Merge pull request !2960 from YuJianfeng/master
This commit is contained in:
mindspore-ci-bot 2020-07-09 20:09:44 +08:00 committed by Gitee
commit 44269cd288
2 changed files with 8 additions and 2 deletions

View File

@ -36,7 +36,9 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
}
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
axis_ = 4 - input_shape_.size();
reduce_scatter_flag_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "reduce_scatter_flag");
if (AnfAlgo::HasNodeAttr(kAttrReduceScatterFlag, kernel_node)) {
reduce_scatter_flag_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrReduceScatterFlag);
}
#ifdef ENABLE_MPI
if (reduce_scatter_flag_) {
size_t gatherv2_out_lens = 1;
@ -65,7 +67,9 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_LOG(EXCEPTION) << "Not Enable MPI, please build version with -M on when set reduce_scatter_flag true";
}
#endif
offset_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "offset");
if (AnfAlgo::HasNodeAttr(kAttrOffset, kernel_node)) {
offset_ = AnfAlgo::GetNodeAttr<int>(kernel_node, kAttrOffset);
}
CPUKernelUtils::ExpandDimsTo4(&input_shape_);
CPUKernelUtils::ExpandDimsTo4(&output_shape_);
}

View File

@ -223,6 +223,8 @@ constexpr auto kAttrNumSplit = "num_split";
constexpr auto kAttrOutputNum = "output_num";
constexpr auto kAttrSizeSplits = "size_splits";
constexpr auto kAttrOutputDefault = "output_default";
constexpr auto kAttrReduceScatterFlag = "reduce_scatter_flag";
constexpr auto kAttrOffset = "offset";
// attr value
constexpr auto kValueTargetSwitch = "target_switch";