diff --git a/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc b/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc index 601d5cf1ea1..bfd13275486 100755 --- a/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc +++ b/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc @@ -16,12 +16,30 @@ #include "kernel/hccl/hccl_kernel_metadata.h" #include +#include #include "utils/utils.h" #include "kernel/hccl/hcom_util.h" #include "session/anf_runtime_algorithm.h" namespace mindspore { namespace kernel { +namespace { +std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) { + const std::set kReduceNoSupportedSet = {kOpFormat_FRAC_Z, kOpFormat_FRACTAL_Z_C04, kOpFormat_C1HWNCoC0}; + auto op_name = AnfAlgo::GetCNodeName(kernel_node); + auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index); + if (op_name != kReduceScatter && op_name != kAllGatherOpName) { + return format; + } + if (format == kOpFormat_FRAC_NZ && AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index).size() <= 2) { + return kOpFormat_DEFAULT; + } + if (kReduceNoSupportedSet.find(format) != kReduceNoSupportedSet.end()) { + return kOpFormat_DEFAULT; + } + return format; +} +} // namespace void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list) { const std::vector kHcclSupportTypes = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeInt16}; @@ -36,13 +54,13 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector inputs_format{}; std::vector inputs_type{}; for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { - inputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)); + inputs_format.emplace_back(GetKernelFormat(kernel_node, input_index)); inputs_type.push_back(type); } std::vector outputs_format; std::vector outputs_type; for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { - outputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, output_index)); + outputs_format.emplace_back(GetKernelFormat(kernel_node, output_index)); outputs_type.push_back(type); } auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); diff --git a/tests/st/pynative/test_pynative_resnet50.py b/tests/st/pynative/test_pynative_resnet50.py index 21d761dfcc0..de9ecebb9cf 100644 --- a/tests/st/pynative/test_pynative_resnet50.py +++ b/tests/st/pynative/test_pynative_resnet50.py @@ -428,5 +428,5 @@ def test_pynative_resnet50(): cost_time = end_time - start_time print("======step: ", step, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time) if step > 1: - assert cost_time < 0.5 + assert cost_time < 0.3 \ No newline at end of file