# Patch summary (reviewer preamble; text before the first "diff --git" header is not part of any hunk).
#
# Purpose: let a CPU kernel register ONE KernelAttr that stands for "every input/output has the same
# dtype", instead of one registration per input count.
#
# What the hunks below do:
#   * device/cpu/kernel_select_cpu.cc
#       - adds ExpandKernelAttr(kernel_node, kernel_attr): reads the dtype of input attr 0 and appends it
#         once for each input index 1..GetInputTensorNum(kernel_node)-1, then does the same for outputs
#         using output attr 0 and GetOutputTensorNum(kernel_node).
#       - SetKernelInfo now copies kernel_attrs[index] into a local `kernel_attr`, expands that copy when
#         GetAllSame() is true, and uses the copy for matching and for the format/dtype updates.
#   * device/cpu/kernel_select_cpu.h
#       - KernelAttr gains SetAllSameAttr(bool), GetAllSame() and the member `all_same_`.
#   * kernel/cpu/addn_cpu_kernel.h, kernel/cpu/concat_cpu_kernel.h
#       - the fixed-arity float32 registrations (AddN: 2 and 3 inputs; Concat: 2 inputs) are replaced by a
#         single registration with SetAllSameAttr(true), one float32 input attr and one float32 output attr.
#   * kernel/cpu/cpu_kernel_factory.{cc,h}
#       - CPUKernelSingleAttrCheck takes the KernelAttr directly (caller passes attr_creator.first); when
#         GetAllSame() is true every input/output dtype is compared against attr index 0.
#       - the two DEBUG log statements per mismatch are merged into one.
#       - the declaration `Create(const std::string &kernel_name)` is removed from the header.
#   * tests/st/ops/cpu/test_concat_op.py
#       - Concat_Axis2 is renamed to Concat_in3_Axis2 and now takes and concatenates three tensors;
#         test_in3_axis2 passes x1, x2, x3 and compares against a three-way np.concatenate.
#
# NOTE(review): `bool all_same_;` is added without an initializer in the visible hunk and the KernelAttr
#   constructor is not shown. Confirm it is initialized somewhere; otherwise GetAllSame() may read an
#   indeterminate value on default-initialized objects (`KernelAttr a;`). `bool all_same_{false};` would
#   remove the doubt.
# NOTE(review): ExpandKernelAttr calls GetInputAttr(0) / GetOutputAttr(0), which index input_type_ /
#   output_type_ directly. A registration that sets SetAllSameAttr(true) without adding at least one input
#   and one output attr would index an empty vector. Confirm this cannot happen or add a guard.
# NOTE(review): only the header hunk for the removed Create(const std::string &) overload is visible here;
#   verify that its definition and any callers are gone as well.
# NOTE(review): this copy of the patch has its line breaks collapsed and template arguments appear to be
#   stripped (e.g. `std::vector input_type_;`, `std::map>> name_to_attr_creator_`), so it will not apply
#   as-is. Regenerate it from the repository before use.
#
diff --git a/mindspore/ccsrc/device/cpu/kernel_select_cpu.cc b/mindspore/ccsrc/device/cpu/kernel_select_cpu.cc index 95a75e26882..76e91e059aa 100644 --- a/mindspore/ccsrc/device/cpu/kernel_select_cpu.cc +++ b/mindspore/ccsrc/device/cpu/kernel_select_cpu.cc @@ -109,6 +109,21 @@ bool IsInputFormatDtypeMatched(const KernelAttr &kernel_attr, const std::vector< } return true; } + +void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) { + MS_EXCEPTION_IF_NULL(kernel_attr); + TypeId input_dtype = kernel_attr->GetInputAttr(0).first; + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + for (size_t i = 1; i < input_num; ++i) { + kernel_attr->AddInputAttr(input_dtype); + } + + TypeId output_dtype = kernel_attr->GetOutputAttr(0).first; + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + for (size_t i = 1; i < output_num; ++i) { + kernel_attr->AddOutputAttr(output_dtype); + } +} } // namespace void SetKernelInfo(const CNodePtr &kernel_node) { @@ -125,12 +140,16 @@ void SetKernelInfo(const CNodePtr &kernel_node) { kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node)); for (size_t index = 0; index < kernel_attrs.size(); ++index) { - if (IsInputFormatDtypeMatched(kernel_attrs[index], input_formats, input_types, input_not_cnode_indexes)) { + auto kernel_attr = kernel_attrs[index]; + if (kernel_attr.GetAllSame()) { + ExpandKernelAttr(kernel_node, &kernel_attr); + } + if (IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) { MS_LOG(INFO) << "Input format and dtype is matched, index: " << index; - GetOutputFormatsAndDtypes(kernel_node, kernel_attrs[index], &output_formats, &output_types); - UpdatePrevNotCNodeFormatDtype(kernel_attrs[index], input_not_cnode_indexes, kernel_node); + GetOutputFormatsAndDtypes(kernel_node, kernel_attr, &output_formats, &output_types); + UpdatePrevNotCNodeFormatDtype(kernel_attr, input_not_cnode_indexes, kernel_node); for 
(auto &input_index : input_not_cnode_indexes) { - input_types[input_index] = kernel_attrs[index].GetInputAttr(input_index).first; + input_types[input_index] = kernel_attr.GetInputAttr(input_index).first; } break; } diff --git a/mindspore/ccsrc/device/cpu/kernel_select_cpu.h b/mindspore/ccsrc/device/cpu/kernel_select_cpu.h index 5242c7b34e6..d2138ec66da 100644 --- a/mindspore/ccsrc/device/cpu/kernel_select_cpu.h +++ b/mindspore/ccsrc/device/cpu/kernel_select_cpu.h @@ -46,8 +46,14 @@ class KernelAttr { return *this; } + KernelAttr &SetAllSameAttr(bool all_same) { + all_same_ = all_same; + return *this; + } + const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; } const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; } + bool GetAllSame() const { return all_same_; } size_t GetInputSize() const { return input_type_.size(); } size_t GetOutputSize() const { return output_type_.size(); } @@ -55,6 +61,7 @@ class KernelAttr { private: std::vector input_type_; std::vector output_type_; + bool all_same_; }; } // namespace cpu } // namespace device diff --git a/mindspore/ccsrc/kernel/cpu/addn_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/addn_cpu_kernel.h index 70b65233c73..1a1a9157d9c 100644 --- a/mindspore/ccsrc/kernel/cpu/addn_cpu_kernel.h +++ b/mindspore/ccsrc/kernel/cpu/addn_cpu_kernel.h @@ -39,16 +39,8 @@ class AddNCPUKernel : public CPUKernel { std::vector output_shape_; }; -MS_REG_CPU_KERNEL( - AddN, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - AddNCPUKernel); MS_REG_CPU_KERNEL(AddN, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), + KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), AddNCPUKernel); } // namespace kernel } // namespace mindspore diff --git 
a/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.h index 1d74979b558..46f9078178d 100644 --- a/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.h +++ b/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.h @@ -41,10 +41,9 @@ class ConcatCPUKernel : public CPUKernel { std::vector output_shape_; }; -MS_REG_CPU_KERNEL( - Concat, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ConcatCPUKernel); +MS_REG_CPU_KERNEL(Concat, + KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ConcatCPUKernel); } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.cc b/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.cc index 55b37ba13fa..bcda7af9fd9 100644 --- a/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.cc +++ b/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.cc @@ -59,28 +59,27 @@ std::pair CPUKernelFactory::CPUKernelAttrCheck(const std::string & auto creators = iter->second; for (size_t index = 0; index < creators.size(); ++index) { auto attr_creator = creators[index]; - if (CPUKernelSingleAttrCheck(attr_creator, kernel_info)) { + if (CPUKernelSingleAttrCheck(attr_creator.first, kernel_info)) { return std::make_pair(true, index); } } return std::make_pair(false, 0); } -bool CPUKernelFactory::CPUKernelSingleAttrCheck(const std::pair &attr_creator, - const KernelBuildInfo &kernel_info) { +bool CPUKernelFactory::CPUKernelSingleAttrCheck(const KernelAttr &kernel_attr, const KernelBuildInfo &kernel_info) { for (size_t i = 0; i < kernel_info.GetInputNum(); ++i) { - if (kernel_info.GetInputDeviceType(i) != attr_creator.first.GetInputAttr(i).first) { - MS_LOG(DEBUG) << "cpu kernel attr check failed. 
input index: " << i << "."; - MS_LOG(DEBUG) << "kernel info type:" << kernel_info.GetInputDeviceType(i) << ", " - << "register type:" << attr_creator.first.GetInputAttr(i).first; + auto dtype = kernel_attr.GetAllSame() ? kernel_attr.GetInputAttr(0).first : kernel_attr.GetInputAttr(i).first; + if (kernel_info.GetInputDeviceType(i) != dtype) { + MS_LOG(DEBUG) << "input index:" << i << ", kernel info type:" << kernel_info.GetInputDeviceType(i) + << ", register type:" << dtype; return false; } } for (size_t i = 0; i < kernel_info.GetOutputNum(); ++i) { - if (kernel_info.GetOutputDeviceType(i) != attr_creator.first.GetOutputAttr(i).first) { - MS_LOG(DEBUG) << "cpu kernel attr check failed. output index: " << i << "."; - MS_LOG(DEBUG) << "kernel info type:" << kernel_info.GetOutputDeviceType(i) << ", " - << "register type:" << attr_creator.first.GetOutputAttr(i).first; + auto dtype = kernel_attr.GetAllSame() ? kernel_attr.GetOutputAttr(0).first : kernel_attr.GetOutputAttr(i).first; + if (kernel_info.GetOutputDeviceType(i) != dtype) { + MS_LOG(DEBUG) << "output index:" << i << ", kernel info type:" << kernel_info.GetOutputDeviceType(i) + << ", register type:" << dtype; return false; } } diff --git a/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.h b/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.h index b3901d257ec..52eda12ba7c 100644 --- a/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.h +++ b/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.h @@ -35,7 +35,6 @@ class CPUKernelFactory { public: static CPUKernelFactory &GetInstance(); void Register(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator); - std::shared_ptr Create(const std::string &kernel_name); std::shared_ptr Create(const std::string &kernel_name, const CNodePtr &apply_kernel); std::vector GetSupportedKernelAttrList(const std::string &kernel_name); @@ -44,8 +43,7 @@ class CPUKernelFactory { ~CPUKernelFactory() = default; DISABLE_COPY_AND_ASSIGN(CPUKernelFactory) 
std::pair CPUKernelAttrCheck(const std::string &kernel_name, const KernelBuildInfo &kernel_info); - bool CPUKernelSingleAttrCheck(const std::pair &attr_creator, - const KernelBuildInfo &kernel_info); + bool CPUKernelSingleAttrCheck(const KernelAttr &kernel_attr, const KernelBuildInfo &kernel_info); std::map>> name_to_attr_creator_; }; diff --git a/tests/st/ops/cpu/test_concat_op.py b/tests/st/ops/cpu/test_concat_op.py index 42ac8e08435..9d5067a35d6 100644 --- a/tests/st/ops/cpu/test_concat_op.py +++ b/tests/st/ops/cpu/test_concat_op.py @@ -71,13 +71,13 @@ def test_in2_axis1(): assert np.all(diff < error) assert np.all(-diff < error) -class Concat_Axis2(nn.Cell): +class Concat_in3_Axis2(nn.Cell): def __init__(self): - super(Concat_Axis2, self).__init__() + super(Concat_in3_Axis2, self).__init__() self.cat = P.Concat(axis=-1) - def construct(self, x1, x2): - return self.cat((x1, x2)) + def construct(self, x1, x2, x3): + return self.cat((x1, x2, x3)) @pytest.mark.level0 @pytest.mark.platform_x86_cpu @@ -86,10 +86,10 @@ def test_in3_axis2(): x1 = Tensor(np.arange(2 * 2 * 1).reshape(2, 2, 1), mstype.float32) x2 = Tensor(np.arange(2 * 2 * 2).reshape(2, 2, 2), mstype.float32) x3 = Tensor(np.arange(2 * 2 * 3).reshape(2, 2, 3), mstype.float32) - cat = Concat_Axis2() - output_ms = cat(x1, x2) + cat = Concat_in3_Axis2() + output_ms = cat(x1, x2, x3) print("output:\n", output_ms) - output_np = np.concatenate((x1.asnumpy(), x2.asnumpy()), axis=-1) + output_np = np.concatenate((x1.asnumpy(), x2.asnumpy(), x3.asnumpy()), axis=-1) error = np.ones(shape=output_np.shape) * 10e-6 diff = output_ms.asnumpy() - output_np