forked from mindspore-Ecosystem/mindspore
!2960 Check attr existing before getting it in embeddinglookup cpu kernel
Merge pull request !2960 from YuJianfeng/master
This commit is contained in:
commit
44269cd288
|
@ -36,7 +36,9 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
}
|
||||
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
axis_ = 4 - input_shape_.size();
|
||||
reduce_scatter_flag_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "reduce_scatter_flag");
|
||||
if (AnfAlgo::HasNodeAttr(kAttrReduceScatterFlag, kernel_node)) {
|
||||
reduce_scatter_flag_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrReduceScatterFlag);
|
||||
}
|
||||
#ifdef ENABLE_MPI
|
||||
if (reduce_scatter_flag_) {
|
||||
size_t gatherv2_out_lens = 1;
|
||||
|
@ -65,7 +67,9 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
MS_LOG(EXCEPTION) << "Not Enable MPI, please build version with -M on when set reduce_scatter_flag true";
|
||||
}
|
||||
#endif
|
||||
offset_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "offset");
|
||||
if (AnfAlgo::HasNodeAttr(kAttrOffset, kernel_node)) {
|
||||
offset_ = AnfAlgo::GetNodeAttr<int>(kernel_node, kAttrOffset);
|
||||
}
|
||||
CPUKernelUtils::ExpandDimsTo4(&input_shape_);
|
||||
CPUKernelUtils::ExpandDimsTo4(&output_shape_);
|
||||
}
|
||||
|
|
|
@ -223,6 +223,8 @@ constexpr auto kAttrNumSplit = "num_split";
|
|||
constexpr auto kAttrOutputNum = "output_num";
|
||||
constexpr auto kAttrSizeSplits = "size_splits";
|
||||
constexpr auto kAttrOutputDefault = "output_default";
|
||||
constexpr auto kAttrReduceScatterFlag = "reduce_scatter_flag";
|
||||
constexpr auto kAttrOffset = "offset";
|
||||
|
||||
// attr value
|
||||
constexpr auto kValueTargetSwitch = "target_switch";
|
||||
|
|
Loading…
Reference in New Issue