diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/broadcast_to_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/broadcast_to_cpu_kernel.cc
index 577af5e936b..7b930fa6cd4 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/broadcast_to_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/broadcast_to_cpu_kernel.cc
@@ -20,7 +20,6 @@
 namespace mindspore {
 namespace kernel {
 namespace {
-constexpr size_t kBroadcastToInputsNum = 1;
 constexpr size_t kBroadcastToOutputsNum = 1;
 }  // namespace
 
@@ -32,6 +31,21 @@ void BroadcastToCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
   output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
   size_t input_shape_size = input_shape_.size();
   size_t output_shape_size = output_shape_.size();
+
+  for (size_t i = 0; i < input_shape_size; ++i) {
+    shape_info_.input_shape_[i] = SizeToInt(input_shape_[i]);
+  }
+  for (size_t i = 0; i < output_shape_size; ++i) {
+    shape_info_.output_shape_[i] = SizeToInt(output_shape_[i]);
+  }
+  shape_info_.input_shape_size_ = SizeToInt(input_shape_size);
+  shape_info_.output_shape_size_ = SizeToInt(output_shape_size);
+}
+
+template <typename T>
+void BroadcastToCPUKernel<T>::CheckArgs() {
+  size_t input_shape_size = input_shape_.size();
+  size_t output_shape_size = output_shape_.size();
   if (output_shape_size < input_shape_size) {
     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', input tensor 'input_x' and target shape 'shape' can't "
@@ -56,22 +70,13 @@ void BroadcastToCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
                         << Vector2Str(input_shape_) << ", and the dimension of target shape 'shape': " << Vector2Str(output_shape_);
     }
   }
-
-  for (size_t i = 0; i < input_shape_size; ++i) {
-    shape_info_.input_shape_[i] = SizeToInt(input_shape_[i]);
-  }
-  for (size_t i = 0; i < output_shape_size; ++i) {
-    shape_info_.output_shape_[i] = SizeToInt(output_shape_[i]);
-  }
-  shape_info_.input_shape_size_ = SizeToInt(input_shape_size);
-  shape_info_.output_shape_size_ = SizeToInt(output_shape_size);
 }
 
 template <typename T>
 bool BroadcastToCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
                                      const std::vector<kernel::AddressPtr> &outputs) {
-  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kBroadcastToInputsNum, kernel_name_);
   CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kBroadcastToOutputsNum, kernel_name_);
+  CheckArgs();
   const auto *input_addr = reinterpret_cast<T *>(inputs[0]->addr);
   auto *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
   int status = static_cast<int>(NNACL_OK);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/broadcast_to_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/broadcast_to_cpu_kernel.h
index eb0d24b5762..449f646733c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/broadcast_to_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/broadcast_to_cpu_kernel.h
@@ -36,6 +36,8 @@ class BroadcastToCPUKernel : public CPUKernel {
               const std::vector<AddressPtr> &outputs) override;
   void InitKernel(const CNodePtr &kernel_node) override;
 
+  void CheckArgs();
+
  private:
   std::vector<size_t> input_shape_;
   std::vector<size_t> output_shape_;
@@ -48,6 +50,18 @@ MS_REG_CPU_KERNEL_T(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeInt32).Add
                     BroadcastToCPUKernel, int);
 MS_REG_CPU_KERNEL_T(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
                     BroadcastToCPUKernel, bool);
+MS_REG_CPU_KERNEL_T(
+  DynamicBroadcastTo,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
+  BroadcastToCPUKernel, float);
+MS_REG_CPU_KERNEL_T(
+  DynamicBroadcastTo,
+  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  BroadcastToCPUKernel, int);
+MS_REG_CPU_KERNEL_T(
+  DynamicBroadcastTo,
+  KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
+  BroadcastToCPUKernel, bool);
 }  // namespace kernel
 }  // namespace mindspore
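Note on the refactor: the shape validation that previously ran once in InitKernel() is moved into CheckArgs(), which Launch() now calls on every invocation, and the fixed single-input count check is dropped, presumably because the new DynamicBroadcastTo registrations take the target shape as a second int32 runtime input. For reviewers unfamiliar with the rule being enforced, below is a minimal standalone C++ sketch of right-aligned (NumPy-style) broadcast validation; CheckBroadcastCompatible and its messages are illustrative names, not taken from the MindSpore sources, and the sketch omits the kernel's MS_LOG wording.

```cpp
// Standalone sketch (not MindSpore code) of the broadcast-compatibility rule:
// dimensions are aligned from the right, and every input dimension must either
// equal the corresponding output dimension or be 1.
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

// Returns normally when `input` can be broadcast to `output`; throws otherwise.
void CheckBroadcastCompatible(const std::vector<size_t> &input, const std::vector<size_t> &output) {
  if (output.size() < input.size()) {
    throw std::invalid_argument("target shape must not have fewer dimensions than the input");
  }
  size_t offset = output.size() - input.size();
  for (size_t i = 0; i < input.size(); ++i) {
    if (input[i] != output[i + offset] && input[i] != 1) {
      throw std::invalid_argument("input dimension " + std::to_string(i) + " cannot be broadcast");
    }
  }
}

int main() {
  CheckBroadcastCompatible({3, 1}, {2, 3, 4});  // OK: 3 matches 3, 1 broadcasts to 4
  try {
    CheckBroadcastCompatible({3, 2}, {2, 3, 4});  // throws: 2 vs 4 and 2 != 1
  } catch (const std::invalid_argument &e) {
    std::cout << "rejected: " << e.what() << std::endl;
  }
  return 0;
}
```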