diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc index d05b9fafa10..0a23e2da7bc 100644 --- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc @@ -82,6 +82,13 @@ bool IsValidKernelInfo(const std::shared_ptr &kernel_node, const kernel:: } return true; }; + if (AnfAlgo::GetCNodeName(kernel_node) == "Adam") { + auto input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_num - 1) != + kernel_build_info.GetInputFormat(input_num - 1)) { + return false; + } + } if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) { return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) && AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0);