From 4de4eb1a4357f0d3c7941aa30ecda8387b4f71bb Mon Sep 17 00:00:00 2001
From: hedongdong
Date: Wed, 22 Jun 2022 21:50:53 +0800
Subject: [PATCH] Refactor Ops MaxUnpool2D, MaxUnpool3D, MultiMarginLoss,
 MultilabelMarginLoss, TripletMarginLoss, BartlettWindow

---
 .../cpu/kernel/bartlett_window_cpu_kernel.cc  | 155 +++++-----
 .../cpu/kernel/bartlett_window_cpu_kernel.h   |  47 ++-
 .../cpu/kernel/max_unpool2d_cpu_kernel.cc     | 119 ++++++--
 .../cpu/kernel/max_unpool2d_cpu_kernel.h      | 122 ++------
 .../kernel/max_unpool2d_grad_cpu_kernel.cc    | 208 +++++++++++--
 .../cpu/kernel/max_unpool2d_grad_cpu_kernel.h | 194 ++----------
 .../cpu/kernel/max_unpool3d_cpu_kernel.cc     | 124 ++++++--
 .../cpu/kernel/max_unpool3d_cpu_kernel.h      | 122 ++------
 .../kernel/max_unpool3d_grad_cpu_kernel.cc    | 213 +++++++++++--
 .../cpu/kernel/max_unpool3d_grad_cpu_kernel.h | 202 ++----------
 .../kernel/multi_margin_loss_cpu_kernel.cc    |  68 +++--
 .../cpu/kernel/multi_margin_loss_cpu_kernel.h |  78 ++---
 .../multi_margin_loss_grad_cpu_kernel.cc      |  86 +++---
 .../multi_margin_loss_grad_cpu_kernel.h       | 104 +++----
 .../device/cpu/kernel/mvlgamma_cpu_kernel.cc  |   2 +-
 .../device/cpu/kernel/mvlgamma_cpu_kernel.h   |   5 +-
 .../cpu/kernel/mvlgamma_grad_cpu_kernel.cc    |   3 +-
 .../cpu/kernel/mvlgamma_grad_cpu_kernel.h     |   5 +-
 .../kernel/triplet_margin_loss_cpu_kernel.cc  | 288 +++++++++---------
 .../kernel/triplet_margin_loss_cpu_kernel.h   | 213 ++++++-------
 mindspore/core/ops/core_ops.h                 |   2 +-
 mindspore/core/ops/grad/max_unpool2d_grad.cc  |  18 +-
 mindspore/core/ops/grad/max_unpool2d_grad.h   |  19 +-
 mindspore/core/ops/grad/max_unpool3d_grad.cc  |  19 +-
 mindspore/core/ops/grad/max_unpool3d_grad.h   |  19 +-
 .../core/ops/grad/multi_margin_loss_grad.cc   |  21 +-
 .../core/ops/grad/multi_margin_loss_grad.h    |  19 +-
 .../ops/grad/multilabel_margin_loss_grad.cc   |  12 +-
 .../ops/grad/multilabel_margin_loss_grad.h    |  20 +-
 mindspore/core/ops/grad/mvlgamma_grad.cc      |  11 +-
 mindspore/core/ops/grad/mvlgamma_grad.h       |   2 +-
 mindspore/core/ops/max_unpool2d.cc            |  47 +--
 mindspore/core/ops/max_unpool2d.h             |  19 +-
 mindspore/core/ops/max_unpool3d.cc            |  47 +--
 mindspore/core/ops/max_unpool3d.h             |  19 +-
 mindspore/core/ops/multi_margin_loss.cc       |  29 +-
 mindspore/core/ops/multi_margin_loss.h        |  19 +-
 mindspore/core/ops/multilabel_margin_loss.cc  |  10 +-
 mindspore/core/ops/multilabel_margin_loss.h   |  25 +-
 mindspore/core/ops/mvlgamma.cc                |   6 +-
 mindspore/core/ops/mvlgamma.h                 |   4 +-
 mindspore/core/ops/triplet_margin_loss.cc     |  21 +-
 mindspore/core/ops/triplet_margin_loss.h      |  19 +-
 mindspore/python/mindspore/nn/loss/loss.py    |  77 +----
 .../ops/_grad_experimental/grad_array_ops.py  |   1 -
 .../ops/_grad_experimental/grad_nn_ops.py     |  16 +-
 .../ops/_op_impl/aicpu/bartlett_window.py     |   1 +
 .../ops/_op_impl/aicpu/max_unpool2d.py        |   1 +
 .../ops/_op_impl/aicpu/max_unpool2d_grad.py   |   1 +
 .../ops/_op_impl/aicpu/max_unpool3d.py        |   1 +
 .../ops/_op_impl/aicpu/max_unpool3d_grad.py   |   1 +
 .../ops/_op_impl/aicpu/multi_margin_loss.py   |   1 +
 .../_op_impl/aicpu/multi_margin_loss_grad.py  |   1 +
 .../aicpu/multilabel_margin_loss_grad.py      |   3 +-
 .../mindspore/ops/_op_impl/aicpu/mvlgamma.py  |   1 +
 .../ops/_op_impl/aicpu/mvlgamma_grad.py       |   1 +
 .../mindspore/ops/_op_impl/tbe/__init__.py    |   1 -
 .../_op_impl/tbe/multilabel_margin_loss.py    |   2 +-
 .../python/mindspore/ops/operations/nn_ops.py |  12 +-
 tests/ut/python/ops/test_ops.py               |  57 ++--
 60 files changed, 1426 insertions(+), 1537 deletions(-)

diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/bartlett_window_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/bartlett_window_cpu_kernel.cc
index 79eb80a9e80..69f852bc7d1 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/bartlett_window_cpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/bartlett_window_cpu_kernel.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,8 +14,11 @@
  * limitations under the License.
  */
 
-#include "backend/kernel_compiler/cpu/bartlett_window_cpu_kernel.h"
-#include "runtime/device/cpu/cpu_device_address.h"
+#include <algorithm>
+#include <utility>
+#include "plugin/device/cpu/kernel/bartlett_window_cpu_kernel.h"
+#include "plugin/device/cpu/hal/device/cpu_device_address.h"
 
 namespace mindspore {
 namespace kernel {
@@ -24,102 +27,94 @@ constexpr size_t kBartlettWindowInputsNum = 1;
 constexpr size_t kBartlettWindowOutputsNum = 1;
 }  // namespace
 
-void BartlettWindowCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
+void BartlettWindowCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
-  kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
-  input_dtype = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
-  output_dtype = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
-  periodic_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, PERIODIC);
-  if ((input_dtype != kNumberTypeInt32) && (input_dtype != kNumberTypeInt64)) {
-    MS_LOG(EXCEPTION) << "Input tensor types must be int32 or int64";
-  }
-  input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
+  periodic_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, "periodic");
+  auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
   if (input_shape.size() > 0) {
-    MS_EXCEPTION(ValueError) << "The dim of window_length must be 0.";
+    MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the dim of window_length should be 0, but got "
+                             << input_shape.size();
   }
+  node_wpt_ = kernel_node;
+  cnode_ptr_ = kernel_node;
+  auto kernel_attr = GetKernelAttrFromNode(kernel_node);
+  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
+  if (!is_match) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
+  }
+  kernel_func_ = func_list_[index].second;
 }
 
-// template <typename T, typename S>
-// void BartlettWindowCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
-//                                               const std::vector<AddressPtr> &outputs) {
-//   auto input = reinterpret_cast<T *>(inputs[0]->addr);
-//   auto output = reinterpret_cast<S *>(outputs[0]->addr);
-//   auto input_data = *input;
-//   const size_t window_length = static_cast<size_t>(*input);
-//   const S output_one = static_cast<S>(1.);
-//   if (input_data < 0) {
-//     MS_EXCEPTION(ValueError) << "Input window_length must ≥ 0!";
-//   }
-//   if (input_data == 1) {
-//     *output = output_one;
-//   } else {
-//     if (periodic_) {
-//       input_data += 1;
-//     }
-//     const size_t first_half_size = static_cast<size_t>((input_data - 1) / 2);
-//     const double x = static_cast<double>(input_data);
-//     for (size_t i = 0; i <= first_half_size; i++) {
-//       auto value = static_cast<S>((2. * i) / (x - 1.));
-//       *(output + i) = value;
-//     }
-//     for (size_t i = first_half_size + 1; i < window_length; i++) {
-//       auto value = static_cast<S>(2. - (2. * i) / (x - 1.));
-//       *(output + i) = value;
-//     }
-//   }
-// }
+template <typename T, typename S>
+bool BartlettWindowCpuKernelMod::BartlettWindowKernelFunc(const std::vector<AddressPtr> &inputs,
+                                                          const std::vector<AddressPtr> &,
+                                                          const std::vector<AddressPtr> &outputs) {
+  auto node_ = cnode_ptr_.lock();
+  if (!node_) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', node_wpt_ is expired.";
+  }
 
-template <typename T, typename S>
-bool BartlettWindowCpuKernelMod<T, S>::Launch(const std::vector<AddressPtr> &inputs,
-                                              const std::vector<AddressPtr> &,
-                                              const std::vector<AddressPtr> &outputs) {
   CHECK_KERNEL_INPUTS_NUM(inputs.size(), kBartlettWindowInputsNum, kernel_name_);
   CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kBartlettWindowOutputsNum, kernel_name_);
-  auto input = reinterpret_cast<T *>(inputs[0]->addr);
-  auto output = reinterpret_cast<S *>(outputs[0]->addr);
-  auto input_data = *input;
-  const size_t window_length = static_cast<size_t>(*input);
-  const S output_one = static_cast<S>(1.);
-  if (input_data < 0) {
-    MS_EXCEPTION(ValueError) << "Input window_length must ≥ 0!";
+  auto input = reinterpret_cast<T *>(inputs[0]->addr);
+  auto output = reinterpret_cast<S *>(outputs[0]->addr);
+
+  if (*input < 0) {
+    MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', input window_length should be >= 0, but got " << *input;
   }
-  if (input_data == 1) {
-    *output = output_one;
+
+  auto window_length = static_cast<int64_t>(*input);
+  double pre_window_length = static_cast<double>(window_length);
+  const size_t OUTPUTISONE = 1;
+
+  ShapeVector out_shape = {window_length};
+  std::vector<TypeId> dtypes = {AnfAlgo::GetOutputDeviceDataType(node_, 0)};
+
+  if (*input == 1) {
+    *output = static_cast<S>(OUTPUTISONE);
   } else {
     if (periodic_) {
-      input_data += 1;
+      window_length += 1;
     }
-    const size_t first_half_size = static_cast<size_t>((input_data - 1) / 2);
-    const double x = static_cast<double>(input_data);
+    const size_t first_half_size = static_cast<size_t>((window_length - 1) / 2);
+    const double x = static_cast<double>(window_length);
     for (size_t i = 0; i <= first_half_size; i++) {
-      auto value = static_cast<S>((2. * i) / (x - 1.));
+      auto value = static_cast<S>((2. * i) / (x - 1.));
       *(output + i) = value;
     }
-    for (size_t i = first_half_size + 1; i < window_length; i++) {
-      auto value = static_cast<S>(2. - (2. * i) / (x - 1.));
+    for (size_t i = first_half_size + 1; i < pre_window_length; i++) {
+      auto value = static_cast<S>(2. - (2. * i) / (x - 1.));
       *(output + i) = value;
     }
   }
-  // if (output_dtype == kNumberTypeFloat16) {
-  //   if (input_dtype == kNumberTypeInt32) {
-  //     LaunchKernel<int32_t, float16>(inputs, outputs);
-  //   } else if (input_dtype == kNumberTypeInt64) {
-  //     LaunchKernel<int64_t, float16>(inputs, outputs);
-  //   }
-  // } else if (output_dtype == kNumberTypeFloat32) {
-  //   if (input_dtype == kNumberTypeInt32) {
-  //     LaunchKernel<int32_t, float>(inputs, outputs);
-  //   } else if (input_dtype == kNumberTypeInt64) {
-  //     LaunchKernel<int64_t, float>(inputs, outputs);
-  //   }
-  // } else if (output_dtype == kNumberTypeFloat64) {
-  //   if (input_dtype == kNumberTypeInt32) {
-  //     LaunchKernel<int32_t, double>(inputs, outputs);
-  //   } else if (input_dtype == kNumberTypeInt64) {
-  //     LaunchKernel<int64_t, double>(inputs, outputs);
-  //   }
-  // }
+
+  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, {out_shape}, node_.get());
   return true;
 }
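+// The func_list_ table below pairs each supported KernelAttr (input/output dtype
+// combination) with the matching BartlettWindowKernelFunc<T, S> instantiation;
+// InitKernel resolves kernel_func_ from it once via MatchKernelAttr, so Launch
+// is a single indirect call with no per-invocation dtype dispatch.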
+
+std::vector<std::pair<KernelAttr, BartlettWindowCpuKernelMod::BartlettWindowFunc>>
+  BartlettWindowCpuKernelMod::func_list_ = {
+    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
+     &BartlettWindowCpuKernelMod::BartlettWindowKernelFunc<int32_t, float>},
+    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
+     &BartlettWindowCpuKernelMod::BartlettWindowKernelFunc<int32_t, float16>},
+    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
+     &BartlettWindowCpuKernelMod::BartlettWindowKernelFunc<int32_t, double>},
+    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
+     &BartlettWindowCpuKernelMod::BartlettWindowKernelFunc<int64_t, float>},
+    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
+     &BartlettWindowCpuKernelMod::BartlettWindowKernelFunc<int64_t, float16>},
+    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
+     &BartlettWindowCpuKernelMod::BartlettWindowKernelFunc<int64_t, double>}};
+
+std::vector<KernelAttr> BartlettWindowCpuKernelMod::GetOpSupport() {
+  std::vector<KernelAttr> support_list;
+  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+                       [](const std::pair<KernelAttr, BartlettWindowFunc> &pair) { return pair.first; });
+
+  return support_list;
+}
+MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, BartlettWindow, BartlettWindowCpuKernelMod);
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/bartlett_window_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/bartlett_window_cpu_kernel.h
index 46bf1baea0a..e74185ab3e9 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/bartlett_window_cpu_kernel.h
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/bartlett_window_cpu_kernel.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,46 +13,45 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+
 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BARTLETT_WINDOW_CPU_KERNEL_H_
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BARTLETT_WINDOW_CPU_KERNEL_H_
 
 #include <vector>
-#include "backend/kernel_compiler/cpu/cpu_kernel.h"
-#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+#include <utility>
+#include "plugin/device/cpu/kernel/cpu_kernel.h"
+#include "plugin/factory/ms_factory.h"
 
 namespace mindspore {
 namespace kernel {
-template <typename T, typename S>
-class BartlettWindowCpuKernelMod : public NativeCpuKernelMod {
+class BartlettWindowCpuKernelMod : public DeprecatedNativeCpuKernelMod {
  public:
   BartlettWindowCpuKernelMod() = default;
   ~BartlettWindowCpuKernelMod() override = default;
 
   void InitKernel(const CNodePtr &kernel_node) override;
 
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
-              const std::vector<AddressPtr> &outputs) override;
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override {
+    return kernel_func_(this, inputs, workspace, outputs);
+  }
+
+ protected:
+  std::vector<KernelAttr> GetOpSupport() override;
 
  private:
-  // template <typename T, typename S>
-  // void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+  template <typename T, typename S>
+  bool BartlettWindowKernelFunc(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                                const std::vector<AddressPtr> &outputs);
   bool periodic_{true};
-  TypeId output_dtype{kNumberTypeFloat32};
   TypeId input_dtype{kTypeUnknown};
-  std::vector<size_t> input_shape;
+  using BartlettWindowFunc =
+    std::function<bool(BartlettWindowCpuKernelMod *, const std::vector<AddressPtr> &,
+                       const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>;
+  static std::vector<std::pair<KernelAttr, BartlettWindowFunc>> func_list_;
+  BartlettWindowFunc kernel_func_;
+  ShapeVector input_shape;
+  CNodePtr node_wpt_;
 };
-
-MS_REG_CPU_KERNEL_T_S(BartlettWindow, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
-                      BartlettWindowCpuKernelMod, int32_t, float);
-MS_REG_CPU_KERNEL_T_S(BartlettWindow, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
-                      BartlettWindowCpuKernelMod, int32_t, float16);
-MS_REG_CPU_KERNEL_T_S(BartlettWindow, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
-                      BartlettWindowCpuKernelMod, int32_t, double);
-MS_REG_CPU_KERNEL_T_S(BartlettWindow, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
-                      BartlettWindowCpuKernelMod, int64_t, float);
-MS_REG_CPU_KERNEL_T_S(BartlettWindow, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
-                      BartlettWindowCpuKernelMod, int64_t, float16);
-MS_REG_CPU_KERNEL_T_S(BartlettWindow, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
-                      BartlettWindowCpuKernelMod, int64_t, double);
 }  // namespace kernel
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BARTLETT_WINDOW_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool2d_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool2d_cpu_kernel.cc
index 472e955e37d..756f33b7e99 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool2d_cpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool2d_cpu_kernel.cc
@@ -15,8 +15,10 @@
  */
 #include <memory>
 #include <string>
-#include "backend/kernel_compiler/cpu/max_unpool2d_cpu_kernel.h"
-#include "runtime/device/cpu/cpu_device_address.h"
+#include <algorithm>
+#include <utility>
+#include "plugin/device/cpu/kernel/max_unpool2d_cpu_kernel.h"
+#include "plugin/device/cpu/hal/device/cpu_device_address.h"
 
 namespace mindspore {
 namespace kernel {
@@ -29,28 +31,40 @@ constexpr size_t kInputIndex2 = 2;
 constexpr size_t kInputIndex3 = 3;
 }  // namespace
 
-template <typename DATA_T, typename INDICES_T>
-void MaxUnpool2DCPUKernel<DATA_T, INDICES_T>::InitKernel(const CNodePtr &kernel_node) {
+void MaxUnpool2DCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
-  kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
+  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
   node_wpt_ = kernel_node;
   input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex0);
   indices_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex1);
   output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, kInputIndex0);
-  data_format_ = AnfAlgo::GetNodeAttr<std::string>(kernel_node, FORMAT);
+  data_format_ = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, FORMAT);
+
+  if (AnfAlgo::IsShapesDynamic({input_shape_, indices_shape_, output_shape_})) {
+    return;
+  }
+
+  auto kernel_attr = GetKernelAttrFromNode(kernel_node);
+  std::vector<KernelAttr> support_list;
+  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+                       [](const std::pair<KernelAttr, MaxUnpool2DFunc> &pair) { return pair.first; });
+  auto [is_match, index] = MatchKernelAttr(kernel_attr, support_list);
+  if (!is_match) {
+    MS_LOG(EXCEPTION) << "MaxUnpool2D does not support this kernel data type: " << kernel_attr;
+  }
+  kernel_func_ = func_list_[index].second;
 }
 
-template <typename DATA_T, typename INDICES_T>
-void MaxUnpool2DCPUKernel<DATA_T, INDICES_T>::OutPutInitKernel(DATA_T *raw_output, size_t length) {
+template <typename DATA_T>
+void MaxUnpool2DCpuKernelMod::OutPutInitKernel(DATA_T *raw_output, size_t length) {
   for (size_t s = 0; s < length; s++) {
     raw_output[s] = (DATA_T)0;
   }
 }
 
 template <typename DATA_T, typename INDICES_T>
-bool MaxUnpool2DCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<AddressPtr> &inputs,
-                                                     const std::vector<AddressPtr> &,
-                                                     const std::vector<AddressPtr> &outputs) {
+bool MaxUnpool2DCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
+                                           const std::vector<AddressPtr> &outputs) {
   auto node = node_wpt_.lock();
   if (!node) {
     MS_LOG(EXCEPTION) << "node_wpt_ is expired.";
@@ -66,14 +80,14 @@
   auto *raw_input = reinterpret_cast<DATA_T *>(inputs[kInputIndex0]->addr);
   auto *raw_indices = reinterpret_cast<INDICES_T *>(inputs[kInputIndex1]->addr);
   auto *raw_output = reinterpret_cast<DATA_T *>(outputs[kInputIndex0]->addr);
   if (data_format_ == "NHWC") {
-    size_t num_batch = input_shape_[kInputIndex0];
-    size_t input_height = input_shape_[kInputIndex1];
-    size_t input_width = input_shape_[kInputIndex2];
-    size_t num_channels = input_shape_[kInputIndex3];
-    size_t oheight = output_shape_[kInputIndex1];
-    size_t owidth = output_shape_[kInputIndex2];
+    size_t num_batch = LongToSize(input_shape_[kInputIndex0]);
+    size_t input_height = LongToSize(input_shape_[kInputIndex1]);
+    size_t input_width = LongToSize(input_shape_[kInputIndex2]);
+    size_t num_channels = LongToSize(input_shape_[kInputIndex3]);
+    size_t oheight = LongToSize(output_shape_[kInputIndex1]);
+    size_t owidth = LongToSize(output_shape_[kInputIndex2]);
     size_t length = num_batch * oheight * owidth * num_channels;
-    OutPutInitKernel(raw_output, length);
+    OutPutInitKernel<DATA_T>(raw_output, length);
     for (size_t n = 0; n < num_batch; n++) {
       size_t noutput_offset = n * num_channels * owidth * oheight;
       size_t n_input_offset = n * num_channels * input_width * input_height;
@@ -100,14 +114,14 @@
       }
     }
   } else {
-    size_t num_batch = input_shape_[kInputIndex0];
-    size_t num_channels = input_shape_[kInputIndex1];
-    size_t input_height = input_shape_[kInputIndex2];
-    size_t input_width = input_shape_[kInputIndex3];
-    size_t oheight = output_shape_[kInputIndex2];
-    size_t owidth = output_shape_[kInputIndex3];
+    size_t num_batch = LongToSize(input_shape_[kInputIndex0]);
+    size_t num_channels = LongToSize(input_shape_[kInputIndex1]);
+    size_t input_height = LongToSize(input_shape_[kInputIndex2]);
+    size_t input_width = LongToSize(input_shape_[kInputIndex3]);
+    size_t oheight = LongToSize(output_shape_[kInputIndex2]);
+    size_t owidth = LongToSize(output_shape_[kInputIndex3]);
     size_t length = num_batch * oheight * owidth * num_channels;
-    OutPutInitKernel(raw_output, length);
+    OutPutInitKernel<DATA_T>(raw_output, length);
     for (size_t n = 0; n < num_batch; n++) {
       size_t noutput_offset = n * num_channels * owidth * oheight;
       size_t n_input_offset = n * num_channels * input_width * input_height;
@@ -139,5 +153,60 @@
     }
   }
   return true;
 }
+
+std::vector<std::pair<KernelAttr, MaxUnpool2DCpuKernelMod::MaxUnpool2DFunc>> MaxUnpool2DCpuKernelMod::func_list_ = {
+  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<uint8_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<uint8_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<uint16_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<uint16_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<uint32_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<uint32_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<uint64_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<uint64_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<int8_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<int8_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<int16_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<int16_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<int32_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<int32_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<int64_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<int64_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<float16, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<float16, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<float, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<float, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<double, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
+   &MaxUnpool2DCpuKernelMod::LaunchKernel<double, int64_t>}};
+
+std::vector<KernelAttr> MaxUnpool2DCpuKernelMod::GetOpSupport() {
+  std::vector<KernelAttr> support_list;
+  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+                       [](const std::pair<KernelAttr, MaxUnpool2DFunc> &pair) { return pair.first; });
+  return support_list;
+}
+
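+// GetOpSupport() above advertises exactly the dtype combinations present in
+// func_list_, and the factory registration below exposes the kernel under the
+// MaxUnpool2D op name in place of the old per-dtype MS_REG_CPU_KERNEL_T_S macros.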
+MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MaxUnpool2D, MaxUnpool2DCpuKernelMod);
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool2d_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool2d_cpu_kernel.h
index 3b0ab763d9b..c4f5ea98481 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool2d_cpu_kernel.h
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool2d_cpu_kernel.h
@@ -19,117 +19,43 @@
 #include <memory>
 #include <string>
 #include <vector>
-#include "backend/kernel_compiler/cpu/cpu_kernel.h"
-#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+#include <utility>
+#include "plugin/device/cpu/kernel/cpu_kernel.h"
+#include "plugin/factory/ms_factory.h"
 
 namespace mindspore {
 namespace kernel {
-template <typename DATA_T, typename INDICES_T>
-class MaxUnpool2DCPUKernel : public CPUKernel {
+class MaxUnpool2DCpuKernelMod : public DeprecatedNativeCpuKernelMod {
  public:
-  MaxUnpool2DCPUKernel() = default;
-  ~MaxUnpool2DCPUKernel() override = default;
+  MaxUnpool2DCpuKernelMod() = default;
+  ~MaxUnpool2DCpuKernelMod() override = default;
 
   void InitKernel(const CNodePtr &kernel_node) override;
 
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
-              const std::vector<AddressPtr> &outputs) override;
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs) override {
+    return kernel_func_(this, inputs, outputs);
+  };
+
+ protected:
+  std::vector<KernelAttr> GetOpSupport() override;
 
  private:
+  template <typename DATA_T, typename INDICES_T>
+  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+  using MaxUnpool2DFunc = std::function<bool(MaxUnpool2DCpuKernelMod *, const std::vector<AddressPtr> &,
+                                             const std::vector<AddressPtr> &)>;
+  static std::vector<std::pair<KernelAttr, MaxUnpool2DFunc>> func_list_;
+  MaxUnpool2DFunc kernel_func_;
+
+  template <typename DATA_T>
   void OutPutInitKernel(DATA_T *rawOutput, size_t length);
   CNodeWeakPtr node_wpt_;
-  std::vector<size_t> input_shape_;
-  std::vector<size_t> indices_shape_;
-  std::vector<size_t> output_shape_;
+  ShapeVector input_shape_;
+  ShapeVector indices_shape_;
+  ShapeVector output_shape_;
   std::string data_format_;
 };
-
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
-  MaxUnpool2DCPUKernel, uint8_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
-  MaxUnpool2DCPUKernel, uint8_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
-  MaxUnpool2DCPUKernel, uint16_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
-  MaxUnpool2DCPUKernel, uint16_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
-  MaxUnpool2DCPUKernel, uint32_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
-  MaxUnpool2DCPUKernel, uint32_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
-  MaxUnpool2DCPUKernel, uint64_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
-  MaxUnpool2DCPUKernel, uint64_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
-  MaxUnpool2DCPUKernel, int8_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
-  MaxUnpool2DCPUKernel, int8_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
-  MaxUnpool2DCPUKernel, int16_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
-  MaxUnpool2DCPUKernel, int16_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  MaxUnpool2DCPUKernel, int32_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
-  MaxUnpool2DCPUKernel, int32_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
-  MaxUnpool2DCPUKernel, int64_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
-  MaxUnpool2DCPUKernel, int64_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
-  MaxUnpool2DCPUKernel, float16, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
-  MaxUnpool2DCPUKernel, float16, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
-  MaxUnpool2DCPUKernel, float, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
-  MaxUnpool2DCPUKernel, float, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
-  MaxUnpool2DCPUKernel, double, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool2D,
-  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
-  MaxUnpool2DCPUKernel, double, int64_t);
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool2d_grad_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool2d_grad_cpu_kernel.cc
index a763ed44c73..705f0dba778 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool2d_grad_cpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool2d_grad_cpu_kernel.cc
@@ -15,8 +15,10 @@
  */
 #include <memory>
 #include <string>
-#include "backend/kernel_compiler/cpu/max_unpool2d_grad_cpu_kernel.h"
-#include "runtime/device/cpu/cpu_device_address.h"
+#include <algorithm>
+#include <utility>
+#include "plugin/device/cpu/kernel/max_unpool2d_grad_cpu_kernel.h"
+#include "plugin/device/cpu/hal/device/cpu_device_address.h"
 
 namespace mindspore {
 namespace kernel {
@@ -28,29 +30,42 @@ constexpr size_t kInputIndex1 = 1;
 constexpr size_t kInputIndex2 = 2;
 constexpr size_t kInputIndex3 = 3;
 }  // namespace
 
-template <typename DATA_T, typename INDICES_T>
-void MaxUnpool2DGradCPUKernel<DATA_T, INDICES_T>::InitKernel(const CNodePtr &kernel_node) {
+
+void MaxUnpool2DGradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
-  kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
+  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
   node_wpt_ = kernel_node;
   input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex0);
   grads_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex1);
   indices_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex2);
   output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, kInputIndex0);
-  data_format_ = AnfAlgo::GetNodeAttr<std::string>(kernel_node, FORMAT);
+  data_format_ = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, FORMAT);
+
+  if (AnfAlgo::IsShapesDynamic({input_shape_, grads_shape_, indices_shape_, output_shape_})) {
+    return;
+  }
+
+  auto kernel_attr = GetKernelAttrFromNode(kernel_node);
+  std::vector<KernelAttr> support_list;
+  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+                       [](const std::pair<KernelAttr, MaxUnpool2DGradFunc> &pair) { return pair.first; });
+  auto [is_match, index] = MatchKernelAttr(kernel_attr, support_list);
+  if (!is_match) {
+    MS_LOG(EXCEPTION) << "MaxUnpool2DGrad does not support this kernel data type: " << kernel_attr;
+  }
+  kernel_func_ = func_list_[index].second;
 }
 
-template <typename DATA_T, typename INDICES_T>
-void MaxUnpool2DGradCPUKernel<DATA_T, INDICES_T>::OutPutInitKernel(DATA_T *raw_output, size_t length) {
+template <typename DATA_T>
+void MaxUnpool2DGradCpuKernelMod::OutPutInitKernel(DATA_T *raw_output, size_t length) {
   for (size_t s = 0; s < length; s++) {
     raw_output[s] = (DATA_T)0;
   }
 }
 
 template <typename DATA_T, typename INDICES_T>
-bool MaxUnpool2DGradCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<AddressPtr> &inputs,
-                                                         const std::vector<AddressPtr> &,
-                                                         const std::vector<AddressPtr> &outputs) {
+bool MaxUnpool2DGradCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
+                                               const std::vector<AddressPtr> &outputs) {
   auto node = node_wpt_.lock();
   if (!node) {
     MS_LOG(EXCEPTION) << "node_wpt_ is expired.";
@@ -66,14 +81,14 @@
   auto *raw_grads = reinterpret_cast<DATA_T *>(inputs[kInputIndex1]->addr);
   auto *raw_indices = reinterpret_cast<INDICES_T *>(inputs[kInputIndex2]->addr);
   auto *raw_output = reinterpret_cast<DATA_T *>(outputs[kInputIndex0]->addr);
   if (data_format_ == "NHWC") {
-    size_t num_batch = grads_shape_[kInputIndex0];
-    size_t oheight = grads_shape_[kInputIndex1];
-    size_t owidth = grads_shape_[kInputIndex2];
-    size_t num_channels = grads_shape_[kInputIndex3];
-    size_t iheight = output_shape_[kInputIndex1];
-    size_t iwidth = output_shape_[kInputIndex2];
+    size_t num_batch = LongToSize(grads_shape_[kInputIndex0]);
+    size_t oheight = LongToSize(grads_shape_[kInputIndex1]);
+    size_t owidth = LongToSize(grads_shape_[kInputIndex2]);
+    size_t num_channels = LongToSize(grads_shape_[kInputIndex3]);
+    size_t iheight = LongToSize(output_shape_[kInputIndex1]);
+    size_t iwidth = LongToSize(output_shape_[kInputIndex2]);
     size_t length = num_batch * iheight * iwidth * num_channels;
-    OutPutInitKernel(raw_output, length);
+    OutPutInitKernel<DATA_T>(raw_output, length);
     for (size_t n = 0; n < num_batch; n++) {
       size_t noutput_offset = n * num_channels * iwidth * iheight;
      size_t n_grads_offset = n * num_channels * owidth * oheight;
@@ -100,14 +115,14 @@
       }
     }
   } else {
-    size_t num_batch = grads_shape_[kInputIndex0];
-    size_t num_channels = grads_shape_[kInputIndex1];
-    size_t oheight = grads_shape_[kInputIndex2];
-    size_t owidth = grads_shape_[kInputIndex3];
-    size_t iheight = output_shape_[kInputIndex2];
-    size_t iwidth = output_shape_[kInputIndex3];
+    size_t num_batch = LongToSize(grads_shape_[kInputIndex0]);
+    size_t num_channels = LongToSize(grads_shape_[kInputIndex1]);
+    size_t oheight = LongToSize(grads_shape_[kInputIndex2]);
+    size_t owidth = LongToSize(grads_shape_[kInputIndex3]);
+    size_t iheight = LongToSize(output_shape_[kInputIndex2]);
+    size_t iwidth = LongToSize(output_shape_[kInputIndex3]);
     size_t length = num_batch * iheight * iwidth * num_channels;
-    OutPutInitKernel(raw_output, length);
+    OutPutInitKernel<DATA_T>(raw_output, length);
     for (size_t n = 0; n < num_batch; n++) {
       size_t noutput_offset = n * num_channels * iwidth * iheight;
       size_t n_grads_offset = n * num_channels * owidth * oheight;
@@ -139,5 +154,148 @@
     }
   }
   return true;
 }
+
+std::vector<std::pair<KernelAttr, MaxUnpool2DGradCpuKernelMod::MaxUnpool2DGradFunc>>
+  MaxUnpool2DGradCpuKernelMod::func_list_ = {{KernelAttr()
+                                                .AddInputAttr(kNumberTypeUInt8)
+                                                .AddInputAttr(kNumberTypeUInt8)
+                                                .AddInputAttr(kNumberTypeInt32)
+                                                .AddOutputAttr(kNumberTypeUInt8),
+                                              &MaxUnpool2DGradCpuKernelMod::LaunchKernel<uint8_t, int32_t>},
+                                             {KernelAttr()
+                                                .AddInputAttr(kNumberTypeUInt8)
+                                                .AddInputAttr(kNumberTypeUInt8)
+                                                .AddInputAttr(kNumberTypeInt64)
+                                                .AddOutputAttr(kNumberTypeUInt8),
+                                              &MaxUnpool2DGradCpuKernelMod::LaunchKernel<uint8_t, int64_t>},
+                                             {KernelAttr()
+                                                .AddInputAttr(kNumberTypeUInt16)
+                                                .AddInputAttr(kNumberTypeUInt16)
+                                                .AddInputAttr(kNumberTypeInt32)
+                                                .AddOutputAttr(kNumberTypeUInt16),
+                                              &MaxUnpool2DGradCpuKernelMod::LaunchKernel<uint16_t, int32_t>},
+                                             {KernelAttr()
+                                                .AddInputAttr(kNumberTypeUInt16)
+                                                .AddInputAttr(kNumberTypeUInt16)
+                                                .AddInputAttr(kNumberTypeInt64)
+                                                .AddOutputAttr(kNumberTypeUInt16),
+                                              &MaxUnpool2DGradCpuKernelMod::LaunchKernel<uint16_t, int64_t>},
+                                             {KernelAttr()
+                                                .AddInputAttr(kNumberTypeUInt32)
+                                                .AddInputAttr(kNumberTypeUInt32)
+                                                .AddInputAttr(kNumberTypeInt32)
+                                                .AddOutputAttr(kNumberTypeUInt32),
+                                              &MaxUnpool2DGradCpuKernelMod::LaunchKernel<uint32_t, int32_t>},
+                                             {KernelAttr()
+                                                .AddInputAttr(kNumberTypeUInt32)
+                                                .AddInputAttr(kNumberTypeUInt32)
+                                                .AddInputAttr(kNumberTypeInt64)
+                                                .AddOutputAttr(kNumberTypeUInt32),
+                                              &MaxUnpool2DGradCpuKernelMod::LaunchKernel<uint32_t, int64_t>},
+                                             {KernelAttr()
+                                                .AddInputAttr(kNumberTypeUInt64)
+                                                .AddInputAttr(kNumberTypeUInt64)
+                                                .AddInputAttr(kNumberTypeInt32)
+                                                .AddOutputAttr(kNumberTypeUInt64),
+                                              &MaxUnpool2DGradCpuKernelMod::LaunchKernel<uint64_t, int32_t>},
+                                             {KernelAttr()
+                                                .AddInputAttr(kNumberTypeUInt64)
+                                                .AddInputAttr(kNumberTypeUInt64)
+                                                .AddInputAttr(kNumberTypeInt64)
+                                                .AddOutputAttr(kNumberTypeUInt64),
+                                              &MaxUnpool2DGradCpuKernelMod::LaunchKernel<uint64_t, int64_t>},
+                                             {KernelAttr()
+                                                .AddInputAttr(kNumberTypeInt8)
+                                                .AddInputAttr(kNumberTypeInt8)
+                                                .AddInputAttr(kNumberTypeInt32)
+                                                .AddOutputAttr(kNumberTypeInt8),
+                                              &MaxUnpool2DGradCpuKernelMod::LaunchKernel<int8_t, int32_t>},
+                                             {KernelAttr()
+                                                .AddInputAttr(kNumberTypeInt8)
+                                                .AddInputAttr(kNumberTypeInt8)
+                                                .AddInputAttr(kNumberTypeInt64)
+                                                .AddOutputAttr(kNumberTypeInt8),
+                                              &MaxUnpool2DGradCpuKernelMod::LaunchKernel<int8_t, int64_t>},
+                                             {KernelAttr()
+                                                .AddInputAttr(kNumberTypeInt16)
+                                                .AddInputAttr(kNumberTypeInt16)
+                                                .AddInputAttr(kNumberTypeInt32)
+                                                .AddOutputAttr(kNumberTypeInt16),
+                                              &MaxUnpool2DGradCpuKernelMod::LaunchKernel<int16_t, int32_t>},
+                                             {KernelAttr()
+                                                .AddInputAttr(kNumberTypeInt16)
+                                                .AddInputAttr(kNumberTypeInt16)
+                                                .AddInputAttr(kNumberTypeInt64)
+                                                .AddOutputAttr(kNumberTypeInt16),
+                                              &MaxUnpool2DGradCpuKernelMod::LaunchKernel<int16_t, int64_t>},
+                                             {KernelAttr()
+                                                .AddInputAttr(kNumberTypeInt32)
+                                                .AddInputAttr(kNumberTypeInt32)
+                                                .AddInputAttr(kNumberTypeInt32)
+                                                .AddOutputAttr(kNumberTypeInt32),
+                                              &MaxUnpool2DGradCpuKernelMod::LaunchKernel<int32_t, int32_t>},
+                                             {KernelAttr()
+                                                .AddInputAttr(kNumberTypeInt32)
+                                                .AddInputAttr(kNumberTypeInt32)
+                                                .AddInputAttr(kNumberTypeInt64)
+                                                .AddOutputAttr(kNumberTypeInt32),
+                                              &MaxUnpool2DGradCpuKernelMod::LaunchKernel<int32_t, int64_t>},
+                                             {KernelAttr()
+                                                .AddInputAttr(kNumberTypeInt64)
+                                                .AddInputAttr(kNumberTypeInt64)
+                                                .AddInputAttr(kNumberTypeInt32)
+                                                .AddOutputAttr(kNumberTypeInt64),
+                                              &MaxUnpool2DGradCpuKernelMod::LaunchKernel<int64_t, int32_t>},
+                                             {KernelAttr()
+                                                .AddInputAttr(kNumberTypeInt64)
+                                                .AddInputAttr(kNumberTypeInt64)
+                                                .AddInputAttr(kNumberTypeInt64)
+                                                .AddOutputAttr(kNumberTypeInt64),
+                                              &MaxUnpool2DGradCpuKernelMod::LaunchKernel<int64_t, int64_t>},
+                                             {KernelAttr()
+                                                .AddInputAttr(kNumberTypeFloat16)
+                                                .AddInputAttr(kNumberTypeFloat16)
+                                                .AddInputAttr(kNumberTypeInt32)
+                                                .AddOutputAttr(kNumberTypeFloat16),
+                                              &MaxUnpool2DGradCpuKernelMod::LaunchKernel<float16, int32_t>},
+                                             {KernelAttr()
+                                                .AddInputAttr(kNumberTypeFloat16)
+                                                .AddInputAttr(kNumberTypeFloat16)
+                                                .AddInputAttr(kNumberTypeInt64)
+                                                .AddOutputAttr(kNumberTypeFloat16),
+                                              &MaxUnpool2DGradCpuKernelMod::LaunchKernel<float16, int64_t>},
+                                             {KernelAttr()
+                                                .AddInputAttr(kNumberTypeFloat32)
+                                                .AddInputAttr(kNumberTypeFloat32)
+                                                .AddInputAttr(kNumberTypeInt32)
+                                                .AddOutputAttr(kNumberTypeFloat32),
+                                              &MaxUnpool2DGradCpuKernelMod::LaunchKernel<float, int32_t>},
+                                             {KernelAttr()
+                                                .AddInputAttr(kNumberTypeFloat32)
+                                                .AddInputAttr(kNumberTypeFloat32)
+                                                .AddInputAttr(kNumberTypeInt64)
+                                                .AddOutputAttr(kNumberTypeFloat32),
+                                              &MaxUnpool2DGradCpuKernelMod::LaunchKernel<float, int64_t>},
+                                             {KernelAttr()
+                                                .AddInputAttr(kNumberTypeFloat64)
+                                                .AddInputAttr(kNumberTypeFloat64)
+                                                .AddInputAttr(kNumberTypeInt32)
+                                                .AddOutputAttr(kNumberTypeFloat64),
+                                              &MaxUnpool2DGradCpuKernelMod::LaunchKernel<double, int32_t>},
+                                             {KernelAttr()
+                                                .AddInputAttr(kNumberTypeFloat64)
+                                                .AddInputAttr(kNumberTypeFloat64)
+                                                .AddInputAttr(kNumberTypeInt64)
+                                                .AddOutputAttr(kNumberTypeFloat64),
+                                              &MaxUnpool2DGradCpuKernelMod::LaunchKernel<double, int64_t>}};
+
+std::vector<KernelAttr> MaxUnpool2DGradCpuKernelMod::GetOpSupport() {
+  std::vector<KernelAttr> support_list;
+  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+                       [](const std::pair<KernelAttr, MaxUnpool2DGradFunc> &pair) { return pair.first; });
+  return support_list;
+}
+
+MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MaxUnpool2DGrad, MaxUnpool2DGradCpuKernelMod);
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool2d_grad_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool2d_grad_cpu_kernel.h
index 656b5ffc913..f5c5cfaf830 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool2d_grad_cpu_kernel.h
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool2d_grad_cpu_kernel.h
@@ -19,187 +19,45 @@
 #include <memory>
 #include <string>
 #include <vector>
-#include "backend/kernel_compiler/cpu/cpu_kernel.h"
-#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+#include <utility>
+#include "plugin/device/cpu/kernel/cpu_kernel.h"
+#include "plugin/factory/ms_factory.h"
 
 namespace mindspore {
 namespace kernel {
-template <typename DATA_T, typename INDICES_T>
-class MaxUnpool2DGradCPUKernel : public CPUKernel {
+class MaxUnpool2DGradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
  public:
-  MaxUnpool2DGradCPUKernel() = default;
-  ~MaxUnpool2DGradCPUKernel() override = default;
+  MaxUnpool2DGradCpuKernelMod() = default;
+  ~MaxUnpool2DGradCpuKernelMod() override = default;
 
   void InitKernel(const CNodePtr &kernel_node) override;
 
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
-              const std::vector<AddressPtr> &outputs) override;
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs) override {
+    return kernel_func_(this, inputs, outputs);
+  };
+
+ protected:
+  std::vector<KernelAttr> GetOpSupport() override;
 
  private:
+  template <typename DATA_T, typename INDICES_T>
+  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+  using MaxUnpool2DGradFunc = std::function<bool(MaxUnpool2DGradCpuKernelMod *, const std::vector<AddressPtr> &,
+                                                 const std::vector<AddressPtr> &)>;
+  static std::vector<std::pair<KernelAttr, MaxUnpool2DGradFunc>> func_list_;
+  MaxUnpool2DGradFunc kernel_func_;
+
+  template <typename DATA_T>
   void OutPutInitKernel(DATA_T *rawOutput, size_t length);
   CNodeWeakPtr node_wpt_;
-  std::vector<size_t> input_shape_;
-  std::vector<size_t> grads_shape_;
-  std::vector<size_t> indices_shape_;
-  std::vector<size_t> output_shape_;
+  ShapeVector input_shape_;
+  ShapeVector grads_shape_;
+  ShapeVector indices_shape_;
+  ShapeVector output_shape_;
   std::string data_format_;
 };
-
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2D,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeUInt8)
-                        .AddInputAttr(kNumberTypeUInt8)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddOutputAttr(kNumberTypeUInt8),
-                      MaxUnpool2DGradCPUKernel, uint8_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeUInt8)
-                        .AddInputAttr(kNumberTypeUInt8)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeUInt8),
-                      MaxUnpool2DGradCPUKernel, uint8_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeUInt16)
-                        .AddInputAttr(kNumberTypeUInt16)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddOutputAttr(kNumberTypeUInt16),
-                      MaxUnpool2DGradCPUKernel, uint16_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeUInt16)
-                        .AddInputAttr(kNumberTypeUInt16)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeUInt16),
-                      MaxUnpool2DGradCPUKernel, uint16_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeUInt32)
-                        .AddInputAttr(kNumberTypeUInt32)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddOutputAttr(kNumberTypeUInt32),
-                      MaxUnpool2DGradCPUKernel, uint32_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeUInt32)
-                        .AddInputAttr(kNumberTypeUInt32)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeUInt32),
-                      MaxUnpool2DGradCPUKernel, uint32_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeUInt64)
-                        .AddInputAttr(kNumberTypeUInt64)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddOutputAttr(kNumberTypeUInt64),
-                      MaxUnpool2DGradCPUKernel, uint64_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeUInt64)
-                        .AddInputAttr(kNumberTypeUInt64)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeUInt64),
-                      MaxUnpool2DGradCPUKernel, uint64_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeInt8)
-                        .AddInputAttr(kNumberTypeInt8)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddOutputAttr(kNumberTypeInt8),
-                      MaxUnpool2DGradCPUKernel, int8_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeInt8)
-                        .AddInputAttr(kNumberTypeInt8)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeInt8),
-                      MaxUnpool2DGradCPUKernel, int8_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeInt16)
-                        .AddInputAttr(kNumberTypeInt16)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddOutputAttr(kNumberTypeInt16),
-                      MaxUnpool2DGradCPUKernel, int16_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeInt16)
-                        .AddInputAttr(kNumberTypeInt16)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeInt16),
-                      MaxUnpool2DGradCPUKernel, int16_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddOutputAttr(kNumberTypeInt32),
-                      MaxUnpool2DGradCPUKernel, int32_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeInt32),
-                      MaxUnpool2DGradCPUKernel, int32_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddOutputAttr(kNumberTypeInt64),
-                      MaxUnpool2DGradCPUKernel, int64_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeInt64),
-                      MaxUnpool2DGradCPUKernel, int64_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddOutputAttr(kNumberTypeFloat16),
-                      MaxUnpool2DGradCPUKernel, float16, int32_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeFloat16),
-                      MaxUnpool2DGradCPUKernel, float16, int64_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddOutputAttr(kNumberTypeFloat32),
-                      MaxUnpool2DGradCPUKernel, float, int32_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeFloat32),
-                      MaxUnpool2DGradCPUKernel, float, int64_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat64)
-                        .AddInputAttr(kNumberTypeFloat64)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddOutputAttr(kNumberTypeFloat64),
-                      MaxUnpool2DGradCPUKernel, double, int32_t);
-MS_REG_CPU_KERNEL_T_S(MaxUnpool2DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat64)
-                        .AddInputAttr(kNumberTypeFloat64)
-                        .AddInputAttr(kNumberTypeInt64)
-                        .AddOutputAttr(kNumberTypeFloat64),
-                      MaxUnpool2DGradCPUKernel, double, int64_t);
 }  // namespace kernel
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MaxUnpool2DGradGRAD_CPU_KERNEL_H_
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MAXUNPOOL2DGRAD_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool3d_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool3d_cpu_kernel.cc
index dadaeb8904e..a48a5903d07 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool3d_cpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool3d_cpu_kernel.cc
@@ -15,8 +15,10 @@
  */
 #include <memory>
 #include <string>
-#include "backend/kernel_compiler/cpu/max_unpool3d_cpu_kernel.h"
-#include "runtime/device/cpu/cpu_device_address.h"
+#include <algorithm>
+#include <utility>
+#include "plugin/device/cpu/kernel/max_unpool3d_cpu_kernel.h"
+#include "plugin/device/cpu/hal/device/cpu_device_address.h"
 
 namespace mindspore {
 namespace kernel {
@@ -29,28 +31,41 @@ constexpr size_t kInputIndex2 = 2;
 constexpr size_t kInputIndex3 = 3;
 constexpr size_t kInputIndex4 = 4;
 }  // namespace
 
-template <typename DATA_T, typename INDICES_T>
-void MaxUnpool3DCPUKernel<DATA_T, INDICES_T>::InitKernel(const CNodePtr &kernel_node) {
+void MaxUnpool3DCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
-  kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
+  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
   node_wpt_ = kernel_node;
   input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex0);
   indices_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex1);
   output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, kInputIndex0);
-  data_format_ = AnfAlgo::GetNodeAttr<std::string>(kernel_node, FORMAT);
+  data_format_ = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, FORMAT);
+
+  if (AnfAlgo::IsShapesDynamic({input_shape_, indices_shape_, output_shape_})) {
+    return;
+  }
+
+  auto kernel_attr = GetKernelAttrFromNode(kernel_node);
+  std::vector<KernelAttr> support_list;
+  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+                       [](const std::pair<KernelAttr, MaxUnpool3DFunc> &pair) { return pair.first; });
+  auto [is_match, index] = MatchKernelAttr(kernel_attr, support_list);
+  if (!is_match) {
+    MS_LOG(EXCEPTION) << "MaxUnpool3D does not support this kernel data type: " << kernel_attr;
+  }
+  kernel_func_ = func_list_[index].second;
 }
 
-template <typename DATA_T, typename INDICES_T>
-void MaxUnpool3DCPUKernel<DATA_T, INDICES_T>::OutPutInitKernel(DATA_T *raw_output, size_t length) {
+template <typename DATA_T>
+void MaxUnpool3DCpuKernelMod::OutPutInitKernel(DATA_T *raw_output, size_t length) {
   for (size_t s = 0; s < length; s++) {
     raw_output[s] = (DATA_T)0;
   }
 }
 
 template <typename DATA_T, typename INDICES_T>
-bool MaxUnpool3DCPUKernel<DATA_T, INDICES_T>::Launch(const std::vector<AddressPtr> &inputs,
-                                                     const std::vector<AddressPtr> &,
-                                                     const std::vector<AddressPtr> &outputs) {
+bool MaxUnpool3DCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
+                                           const std::vector<AddressPtr> &outputs) {
   auto node = node_wpt_.lock();
   if (!node) {
     MS_LOG(EXCEPTION) << "node_wpt_ is expired.";
@@ -66,15 +81,15 @@
   auto *raw_input = reinterpret_cast<DATA_T *>(inputs[kInputIndex0]->addr);
   auto *raw_indices = reinterpret_cast<INDICES_T *>(inputs[kInputIndex1]->addr);
   auto *raw_output = reinterpret_cast<DATA_T *>(outputs[kInputIndex0]->addr);
   size_t num_batch = input_shape_[kInputIndex0];
   if (data_format_ == "NDHWC") {
-    size_t input_depth = input_shape_[kInputIndex1];
-    size_t input_height = input_shape_[kInputIndex2];
-    size_t input_width = input_shape_[kInputIndex3];
-    size_t num_channels = input_shape_[kInputIndex4];
-    size_t odepth = output_shape_[kInputIndex1];
-    size_t oheight = output_shape_[kInputIndex2];
-    size_t owidth = output_shape_[kInputIndex3];
+    size_t input_depth = LongToSize(input_shape_[kInputIndex1]);
+    size_t input_height = LongToSize(input_shape_[kInputIndex2]);
+    size_t input_width = LongToSize(input_shape_[kInputIndex3]);
+    size_t num_channels = LongToSize(input_shape_[kInputIndex4]);
+    size_t odepth = LongToSize(output_shape_[kInputIndex1]);
+    size_t oheight = LongToSize(output_shape_[kInputIndex2]);
+    size_t owidth = LongToSize(output_shape_[kInputIndex3]);
     size_t length = num_batch * odepth * oheight * owidth * num_channels;
-    OutPutInitKernel(raw_output, length);
+    OutPutInitKernel<DATA_T>(raw_output, length);
     for (size_t n = 0; n < num_batch; n++) {
       size_t noutput_offset = n * num_channels * odepth * owidth * oheight;
       size_t n_input_offset = n * num_channels * input_depth * input_width * input_height;
@@ -105,15 +120,15 @@
       }
     }
   } else {
-    size_t num_channels = input_shape_[kInputIndex1];
-    size_t input_depth = input_shape_[kInputIndex2];
-    size_t input_height = input_shape_[kInputIndex3];
-    size_t input_width = input_shape_[kInputIndex4];
-    size_t odepth = output_shape_[kInputIndex2];
-    size_t oheight = output_shape_[kInputIndex3];
-    size_t owidth = output_shape_[kInputIndex4];
+    size_t num_channels = LongToSize(input_shape_[kInputIndex1]);
+    size_t input_depth = LongToSize(input_shape_[kInputIndex2]);
+    size_t input_height = LongToSize(input_shape_[kInputIndex3]);
+    size_t input_width = LongToSize(input_shape_[kInputIndex4]);
+    size_t odepth = LongToSize(output_shape_[kInputIndex2]);
+    size_t oheight = LongToSize(output_shape_[kInputIndex3]);
+    size_t owidth = LongToSize(output_shape_[kInputIndex4]);
     size_t length = num_batch * odepth * oheight * owidth * num_channels;
-    OutPutInitKernel(raw_output, length);
+    OutPutInitKernel<DATA_T>(raw_output, length);
     for (size_t n = 0; n < num_batch; n++) {
       size_t noutput_offset = n * num_channels * odepth * owidth * oheight;
       size_t n_input_offset = n * num_channels * input_depth * input_width * input_height;
@@ -148,5 +163,60 @@
     }
   }
   return true;
 }
+
+std::vector<std::pair<KernelAttr, MaxUnpool3DCpuKernelMod::MaxUnpool3DFunc>> MaxUnpool3DCpuKernelMod::func_list_ = {
+  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<uint8_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<uint8_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<uint16_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<uint16_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<uint32_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<uint32_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<uint64_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<uint64_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<int8_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<int8_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<int16_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<int16_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<int32_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<int32_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<int64_t, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<int64_t, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<float16, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<float16, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<float, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<float, int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<double, int32_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
+   &MaxUnpool3DCpuKernelMod::LaunchKernel<double, int64_t>}};
+
+std::vector<KernelAttr> MaxUnpool3DCpuKernelMod::GetOpSupport() {
+  std::vector<KernelAttr> support_list;
+  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+                       [](const std::pair<KernelAttr, MaxUnpool3DFunc> &pair) { return pair.first; });
+  return support_list;
+}
+
+MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MaxUnpool3D, MaxUnpool3DCpuKernelMod);
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool3d_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool3d_cpu_kernel.h
index a59862b6a0e..587a0e980b1 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool3d_cpu_kernel.h
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool3d_cpu_kernel.h
@@ -19,117 +19,43 @@
 #include <memory>
 #include <string>
 #include <vector>
-#include "backend/kernel_compiler/cpu/cpu_kernel.h"
-#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+#include <utility>
+#include "plugin/device/cpu/kernel/cpu_kernel.h"
+#include "plugin/factory/ms_factory.h"
 
 namespace mindspore {
 namespace kernel {
-template <typename DATA_T, typename INDICES_T>
-class MaxUnpool3DCPUKernel : public CPUKernel {
+class MaxUnpool3DCpuKernelMod : public DeprecatedNativeCpuKernelMod {
  public:
-  MaxUnpool3DCPUKernel() = default;
-  ~MaxUnpool3DCPUKernel() override = default;
+  MaxUnpool3DCpuKernelMod() = default;
+  ~MaxUnpool3DCpuKernelMod() override = default;
 
   void InitKernel(const CNodePtr &kernel_node) override;
 
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
-              const std::vector<AddressPtr> &outputs) override;
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs) override {
+    return kernel_func_(this, inputs, outputs);
+  };
+
+ protected:
+  std::vector<KernelAttr> GetOpSupport() override;
 
  private:
+  template <typename DATA_T, typename INDICES_T>
+  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+  using MaxUnpool3DFunc = std::function<bool(MaxUnpool3DCpuKernelMod *, const std::vector<AddressPtr> &,
+                                             const std::vector<AddressPtr> &)>;
+  static std::vector<std::pair<KernelAttr, MaxUnpool3DFunc>> func_list_;
+  MaxUnpool3DFunc kernel_func_;
+
+  template <typename DATA_T>
   void OutPutInitKernel(DATA_T *rawOutput, size_t length);
   CNodeWeakPtr node_wpt_;
-  std::vector<size_t> input_shape_;
-  std::vector<size_t> indices_shape_;
-  std::vector<size_t> output_shape_;
+  ShapeVector input_shape_;
+  ShapeVector indices_shape_;
+  ShapeVector output_shape_;
   std::string data_format_;
 };
-
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool3D,
-  KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
-  MaxUnpool3DCPUKernel, uint8_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool3D,
-  KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
-  MaxUnpool3DCPUKernel, uint8_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool3D,
-  KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
-  MaxUnpool3DCPUKernel, uint16_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool3D,
-  KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
-  MaxUnpool3DCPUKernel, uint16_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool3D,
-  KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
-  MaxUnpool3DCPUKernel, uint32_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool3D,
-  KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
-  MaxUnpool3DCPUKernel, uint32_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool3D,
-  KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
-  MaxUnpool3DCPUKernel, uint64_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool3D,
-  KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
-  MaxUnpool3DCPUKernel, uint64_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool3D, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
-  MaxUnpool3DCPUKernel, int8_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool3D, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
-  MaxUnpool3DCPUKernel, int8_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool3D,
-  KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
-  MaxUnpool3DCPUKernel, int16_t, int32_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool3D,
-  KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
-  MaxUnpool3DCPUKernel, int16_t, int64_t);
-MS_REG_CPU_KERNEL_T_S(
-  MaxUnpool3D,
-  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
MaxUnpool3DCPUKernel, int32_t, int32_t); -MS_REG_CPU_KERNEL_T_S( - MaxUnpool3D, - KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), - MaxUnpool3DCPUKernel, int32_t, int64_t); -MS_REG_CPU_KERNEL_T_S( - MaxUnpool3D, - KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), - MaxUnpool3DCPUKernel, int64_t, int32_t); -MS_REG_CPU_KERNEL_T_S( - MaxUnpool3D, - KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - MaxUnpool3DCPUKernel, int64_t, int64_t); -MS_REG_CPU_KERNEL_T_S( - MaxUnpool3D, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), - MaxUnpool3DCPUKernel, float16, int32_t); -MS_REG_CPU_KERNEL_T_S( - MaxUnpool3D, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), - MaxUnpool3DCPUKernel, float16, int64_t); -MS_REG_CPU_KERNEL_T_S( - MaxUnpool3D, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - MaxUnpool3DCPUKernel, float, int32_t); -MS_REG_CPU_KERNEL_T_S( - MaxUnpool3D, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), - MaxUnpool3DCPUKernel, float, int64_t); -MS_REG_CPU_KERNEL_T_S( - MaxUnpool3D, - KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), - MaxUnpool3DCPUKernel, double, int32_t); -MS_REG_CPU_KERNEL_T_S( - MaxUnpool3D, - KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), - MaxUnpool3DCPUKernel, double, int64_t); } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool3d_grad_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool3d_grad_cpu_kernel.cc index 3bfc3959f3c..53fdcd830d4 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool3d_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool3d_grad_cpu_kernel.cc @@ -15,8 +15,10 @@ */ #include #include -#include "backend/kernel_compiler/cpu/max_unpool3d_grad_cpu_kernel.h" -#include "runtime/device/cpu/cpu_device_address.h" +#include +#include +#include "plugin/device/cpu/kernel/max_unpool3d_grad_cpu_kernel.h" +#include "plugin/device/cpu/hal/device/cpu_device_address.h" namespace mindspore { namespace kernel { @@ -30,29 +32,41 @@ constexpr size_t kInputIndex3 = 3; constexpr size_t kInputIndex4 = 4; } // namespace -template -void MaxUnpool3DGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { +void MaxUnpool3DGradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); - kernel_name_ = AnfAlgo::GetCNodeName(kernel_node); + kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node); node_wpt_ = kernel_node; input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex0); grads_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex1); indices_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, kInputIndex2); output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, kInputIndex0); - data_format_ = AnfAlgo::GetNodeAttr(kernel_node, FORMAT); + data_format_ = common::AnfAlgo::GetNodeAttr(kernel_node, FORMAT); + + if (AnfAlgo::IsShapesDynamic({input_shape_, grads_shape_, indices_shape_, output_shape_})) { + return; + } + + auto kernel_attr = 
GetKernelAttrFromNode(kernel_node); + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + auto [is_match, index] = MatchKernelAttr(kernel_attr, support_list); + if (!is_match) { + MS_LOG(EXCEPTION) << "MaxUnpool3DGrad does not support this kernel data type: " << kernel_attr; + } + kernel_func_ = func_list_[index].second; } -template -void MaxUnpool3DGradCPUKernel::OutPutInitKernel(DATA_T *raw_output, size_t length) { +template +void MaxUnpool3DGradCpuKernelMod::OutPutInitKernel(DATA_T *raw_output, size_t length) { for (size_t s = 0; s < length; s++) { raw_output[s] = (DATA_T)0; } } template -bool MaxUnpool3DGradCPUKernel::Launch(const std::vector &inputs, - const std::vector &, - const std::vector &outputs) { +bool MaxUnpool3DGradCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { auto node = node_wpt_.lock(); if (!node) { MS_LOG(EXCEPTION) << "node_wpt_ is expired."; @@ -66,17 +80,17 @@ bool MaxUnpool3DGradCPUKernel::Launch(const std::vector(inputs[kInputIndex1]->addr); auto *raw_indices = reinterpret_cast(inputs[kInputIndex2]->addr); auto *raw_output = reinterpret_cast(outputs[kInputIndex0]->addr); - size_t num_batch = grads_shape_[kInputIndex0]; + auto num_batch = LongToSize(grads_shape_[kInputIndex0]); if (data_format_ == "NDHWC") { - size_t odepth = grads_shape_[kInputIndex1]; - size_t oheight = grads_shape_[kInputIndex2]; - size_t owidth = grads_shape_[kInputIndex3]; - size_t num_channels = grads_shape_[kInputIndex4]; - size_t idepth = output_shape_[kInputIndex1]; - size_t iheight = output_shape_[kInputIndex2]; - size_t iwidth = output_shape_[kInputIndex3]; + size_t odepth = LongToSize(grads_shape_[kInputIndex1]); + size_t oheight = LongToSize(grads_shape_[kInputIndex2]); + size_t owidth = LongToSize(grads_shape_[kInputIndex3]); + size_t num_channels = LongToSize(grads_shape_[kInputIndex4]); + size_t idepth = LongToSize(output_shape_[kInputIndex1]); + size_t iheight = LongToSize(output_shape_[kInputIndex2]); + size_t iwidth = LongToSize(output_shape_[kInputIndex3]); size_t length = num_batch * iheight * iwidth * idepth * num_channels; - OutPutInitKernel(raw_output, length); + OutPutInitKernel(raw_output, length); for (size_t n = 0; n < num_batch; n++) { size_t noutput_offset = n * num_channels * iwidth * iheight * idepth; size_t n_grads_offset = n * num_channels * owidth * oheight * odepth; @@ -106,15 +120,15 @@ bool MaxUnpool3DGradCPUKernel::Launch(const std::vector(raw_output, length); for (size_t n = 0; n < num_batch; n++) { size_t noutput_offset = n * num_channels * iwidth * iheight * idepth; size_t n_grads_offset = n * num_channels * owidth * oheight * odepth; @@ -149,5 +163,148 @@ bool MaxUnpool3DGradCPUKernel::Launch(const std::vector> + MaxUnpool3DGradCpuKernelMod::func_list_ = {{KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeUInt8), + &MaxUnpool3DGradCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeUInt8), + &MaxUnpool3DGradCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeUInt16), + &MaxUnpool3DGradCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt16) 
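// The shape members in these kernels change from std::vector<size_t> to
// ShapeVector (std::vector<int64_t>) so dynamic dimensions can be represented;
// InitKernel returns early on dynamic shapes, and LaunchKernel converts each
// extent with LongToSize only once it is known to be concrete. A small sketch
// of what that sizing plus the OutPutInitKernel zero-fill amounts to:
size_t ElementCount(const ShapeVector &shape) {  // e.g. {2, 3, 4, 5, 6} for NCDHW
  size_t length = 1;
  for (auto dim : shape) {
    length *= LongToSize(dim);  // checked int64_t -> size_t; dynamic (-1) never reaches here
  }
  return length;
}
// OutPutInitKernel then zeroes that many elements before the scatter, because
// MaxUnpool3DGrad writes only at positions named by the indices tensor and
// every untouched cell must stay 0.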
+ .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeUInt16), + &MaxUnpool3DGradCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeUInt32), + &MaxUnpool3DGradCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeUInt32), + &MaxUnpool3DGradCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeUInt64), + &MaxUnpool3DGradCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeUInt64), + &MaxUnpool3DGradCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt8), + &MaxUnpool3DGradCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt8), + &MaxUnpool3DGradCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt16), + &MaxUnpool3DGradCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt16), + &MaxUnpool3DGradCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + &MaxUnpool3DGradCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt32), + &MaxUnpool3DGradCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt64), + &MaxUnpool3DGradCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + &MaxUnpool3DGradCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat16), + &MaxUnpool3DGradCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat16), + &MaxUnpool3DGradCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + &MaxUnpool3DGradCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat32), + &MaxUnpool3DGradCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + 
.AddOutputAttr(kNumberTypeFloat64), + &MaxUnpool3DGradCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat64), + &MaxUnpool3DGradCpuKernelMod::LaunchKernel}}; + +std::vector MaxUnpool3DGradCpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MaxUnpool3DGrad, MaxUnpool3DGradCpuKernelMod); } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool3d_grad_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool3d_grad_cpu_kernel.h index c1e94d58d97..efeb24b9cbc 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool3d_grad_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/max_unpool3d_grad_cpu_kernel.h @@ -19,196 +19,44 @@ #include #include #include -#include "backend/kernel_compiler/cpu/cpu_kernel.h" -#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" namespace mindspore { namespace kernel { -template -class MaxUnpool3DGradCPUKernel : public CPUKernel { +class MaxUnpool3DGradCpuKernelMod : public DeprecatedNativeCpuKernelMod { public: - MaxUnpool3DGradCPUKernel() = default; - ~MaxUnpool3DGradCPUKernel() override = default; + MaxUnpool3DGradCpuKernelMod() = default; + ~MaxUnpool3DGradCpuKernelMod() override = default; void InitKernel(const CNodePtr &kernel_node) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs) override { + return kernel_func_(this, inputs, outputs); + }; + + protected: + std::vector GetOpSupport() override; private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &outputs); + using MaxUnpool3DGradFunc = std::function &, + const std::vector &)>; + static std::vector> func_list_; + MaxUnpool3DGradFunc kernel_func_; + + template void OutPutInitKernel(DATA_T *rawOutput, size_t length); CNodeWeakPtr node_wpt_; - std::vector input_shape_; - std::vector grads_shape_; - std::vector indices_shape_; - std::vector output_shape_; + ShapeVector input_shape_; + ShapeVector grads_shape_; + ShapeVector indices_shape_; + ShapeVector output_shape_; std::string data_format_; }; - -MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad, - KernelAttr() - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeUInt8), - MaxUnpool3DGradCPUKernel, uint8_t, int32_t); -MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad, - KernelAttr() - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeUInt8), - MaxUnpool3DGradCPUKernel, uint8_t, int64_t); - -MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad, - KernelAttr() - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeUInt16), - MaxUnpool3DGradCPUKernel, uint16_t, int32_t); -MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad, - KernelAttr() - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kNumberTypeInt64) - 
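// GetOpSupport above projects func_list_ onto its KernelAttr column, so the
// table stays the single source of truth for supported type combinations.
// The idiom in isolation (MyOpFunc stands in for the kernel's functor alias):
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
                     [](const std::pair<KernelAttr, MyOpFunc> &pair) { return pair.first; });
// The (void) cast only silences the unused-result lint on std::transform's return value.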
.AddOutputAttr(kNumberTypeUInt16), - MaxUnpool3DGradCPUKernel, uint16_t, int64_t); - -MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad, - KernelAttr() - .AddInputAttr(kNumberTypeUInt32) - .AddInputAttr(kNumberTypeUInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeUInt32), - MaxUnpool3DGradCPUKernel, uint32_t, int32_t); -MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad, - KernelAttr() - .AddInputAttr(kNumberTypeUInt32) - .AddInputAttr(kNumberTypeUInt32) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeUInt32), - MaxUnpool3DGradCPUKernel, uint32_t, int64_t); - -MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad, - KernelAttr() - .AddInputAttr(kNumberTypeUInt64) - .AddInputAttr(kNumberTypeUInt64) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeUInt64), - MaxUnpool3DGradCPUKernel, uint64_t, int32_t); -MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad, - KernelAttr() - .AddInputAttr(kNumberTypeUInt64) - .AddInputAttr(kNumberTypeUInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeUInt64), - MaxUnpool3DGradCPUKernel, uint64_t, int64_t); - -MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad, - KernelAttr() - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt8), - MaxUnpool3DGradCPUKernel, int8_t, int32_t); -MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad, - KernelAttr() - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt8), - MaxUnpool3DGradCPUKernel, int8_t, int64_t); - -MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad, - KernelAttr() - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt16), - MaxUnpool3DGradCPUKernel, int16_t, int32_t); -MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad, - KernelAttr() - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt16), - MaxUnpool3DGradCPUKernel, int16_t, int64_t); - -MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad, - KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt32), - MaxUnpool3DGradCPUKernel, int32_t, int32_t); -MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad, - KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt32), - MaxUnpool3DGradCPUKernel, int32_t, int64_t); - -MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad, - KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt64), - MaxUnpool3DGradCPUKernel, int64_t, int32_t); -MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad, - KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeInt64), - MaxUnpool3DGradCPUKernel, int64_t, int64_t); - -MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat16), - MaxUnpool3DGradCPUKernel, float16, int32_t); -MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat16), - MaxUnpool3DGradCPUKernel, float16, int64_t); - -MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad, - KernelAttr() - 
.AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32), - MaxUnpool3DGradCPUKernel, float, int32_t); -MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat32), - MaxUnpool3DGradCPUKernel, float, int64_t); - -MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat64), - MaxUnpool3DGradCPUKernel, double, int32_t); -MS_REG_CPU_KERNEL_T_S(MaxUnpool3DGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat64), - MaxUnpool3DGradCPUKernel, double, int64_t); } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/multi_margin_loss_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/multi_margin_loss_cpu_kernel.cc index 2465532d443..15483eff2dc 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/multi_margin_loss_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/multi_margin_loss_cpu_kernel.cc @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,8 +14,8 @@ * specific language governing permissions and limitations under the License. */ -#include "backend/kernel_compiler/cpu/multi_margin_loss_cpu_kernel.h" -#include "runtime/device/cpu/cpu_device_address.h" +#include "plugin/device/cpu/kernel/multi_margin_loss_cpu_kernel.h" +#include "plugin/device/cpu/hal/device/cpu_device_address.h" namespace mindspore { namespace kernel { @@ -23,24 +23,30 @@ namespace { constexpr size_t kMultiMarginLossInputNumWithWeight = 3; constexpr size_t kMultiMarginLossInputNumWithoutWeight = 2; constexpr size_t kMultiMarginLossOutputsNum = 1; +const size_t kZero = 0; +const size_t kOne = 1; +const size_t kTwo = 2; constexpr char kKernelName[] = "MultiMarginLoss"; } // namespace -void MultiMarginLossCPUKernel::InitKernel(const CNodePtr &kernel_node) { +void MultiMarginLossCPUKernelMod::InitKernel(const CNodePtr &kernel_node) { CheckParam(kernel_node); - std::vector x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - batch_size = x_shape[0]; - dims = x_shape[1]; - reduction = AnfAlgo::GetNodeAttr(kernel_node, REDUCTION); - p = AnfAlgo::GetNodeAttr(kernel_node, "p"); - margin = AnfAlgo::GetNodeAttr(kernel_node, "margin"); - dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); - input_num = AnfAlgo::GetInputTensorNum(kernel_node); + ShapeVector x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kZero); + if (IsDynamic({x_shape})) { + return; + } + batch_size = LongToSize(x_shape[kZero]); + dims = LongToSize(x_shape[kOne]); + reduction = common::AnfAlgo::GetNodeAttr(kernel_node, REDUCTION); + p = common::AnfAlgo::GetNodeAttr(kernel_node, "p"); + margin = common::AnfAlgo::GetNodeAttr(kernel_node, "margin"); + dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kZero); + input_num = common::AnfAlgo::GetInputTensorNum(kernel_node); } -bool MultiMarginLossCPUKernel::Launch(const std::vector &inputs, - const std::vector &, - const std::vector &outputs) { +bool 
MultiMarginLossCPUKernelMod::Launch(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) { if (dtype_ == kNumberTypeFloat16) { LaunchKernelFP16(inputs, outputs); } else if (dtype_ == kNumberTypeFloat32) { @@ -54,10 +60,10 @@ bool MultiMarginLossCPUKernel::Launch(const std::vector &inp } template -void MultiMarginLossCPUKernel::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { - auto x_addr = reinterpret_cast(inputs[0]->addr); - auto target_addr = reinterpret_cast(inputs[1]->addr); +void MultiMarginLossCPUKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + auto x_addr = reinterpret_cast(inputs[kZero]->addr); + auto target_addr = reinterpret_cast(inputs[kOne]->addr); for (size_t i = 0; i < batch_size; i++) { if (target_addr[i] < 0 || target_addr[i] >= SizeToLong(dims)) { MS_EXCEPTION(ValueError) << "Target out of range."; @@ -66,9 +72,9 @@ void MultiMarginLossCPUKernel::LaunchKernel(const std::vector(inputs[2]->addr); + weight_addr = reinterpret_cast(inputs[kTwo]->addr); } - auto y_addr = reinterpret_cast(outputs[0]->addr); + auto y_addr = reinterpret_cast(outputs[kZero]->addr); std::vector tmp_loss(batch_size); auto task = [&](size_t start, size_t end) { start *= dims; @@ -117,10 +123,10 @@ void MultiMarginLossCPUKernel::LaunchKernel(const std::vector -void MultiMarginLossCPUKernel::LaunchKernelFP16(const std::vector &inputs, - const std::vector &outputs) { - auto x_addr = reinterpret_cast(inputs[0]->addr); - auto target_addr = reinterpret_cast(inputs[1]->addr); +void MultiMarginLossCPUKernelMod::LaunchKernelFP16(const std::vector &inputs, + const std::vector &outputs) { + auto x_addr = reinterpret_cast(inputs[kZero]->addr); + auto target_addr = reinterpret_cast(inputs[kOne]->addr); for (size_t i = 0; i < batch_size; i++) { if (target_addr[i] < 0 || target_addr[i] >= SizeToLong(dims)) { MS_EXCEPTION(ValueError) << "Target out of range."; @@ -129,9 +135,9 @@ void MultiMarginLossCPUKernel::LaunchKernelFP16(const std::vector(inputs[2]->addr); + weight_addr = reinterpret_cast(inputs[kTwo]->addr); } - auto y_addr = reinterpret_cast(outputs[0]->addr); + auto y_addr = reinterpret_cast(outputs[kZero]->addr); std::vector tmp_loss(batch_size); auto task = [&](size_t start, size_t end) { start *= dims; @@ -180,13 +186,15 @@ void MultiMarginLossCPUKernel::LaunchKernelFP16(const std::vector #include #include -#include "backend/kernel_compiler/cpu/cpu_kernel.h" -#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" namespace mindspore { namespace kernel { -class MultiMarginLossCPUKernel : public CPUKernel { +class MultiMarginLossCPUKernelMod : public DeprecatedNativeCpuKernelMod { public: - MultiMarginLossCPUKernel() = default; + MultiMarginLossCPUKernelMod() = default; - ~MultiMarginLossCPUKernel() override = default; + ~MultiMarginLossCPUKernelMod() override = default; void InitKernel(const CNodePtr &kernel_node) override; bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override; + protected: + std::vector GetOpSupport() override { + static std::vector support_list = { + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat32) + 
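// MultiMarginLoss keeps the older launch shape: InitKernel records dtype_ and
// Launch branches on it at run time instead of binding a functor. A sketch of
// that branch, mirroring the hunk above (the exception text is illustrative):
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
            const std::vector<AddressPtr> &outputs) override {
  if (dtype_ == kNumberTypeFloat16) {
    LaunchKernelFP16<float16>(inputs, outputs);  // half path widens to float internally
  } else if (dtype_ == kNumberTypeFloat32) {
    LaunchKernel<float>(inputs, outputs);
  } else if (dtype_ == kNumberTypeFloat64) {
    LaunchKernel<double>(inputs, outputs);
  } else {
    MS_EXCEPTION(TypeError) << "Unsupported input data type for " << kKernelName << ".";
  }
  return true;
}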
.AddOutputAttr(kNumberTypeFloat32), + KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64)}; + return support_list; + } + + private: template void LaunchKernel(const std::vector &inputs, const std::vector &outputs); - template void LaunchKernelFP16(const std::vector &inputs, const std::vector &outputs); - private: void CheckParam(const CNodePtr &kernel_node); size_t batch_size = 2; size_t dims = 1; @@ -52,45 +75,6 @@ class MultiMarginLossCPUKernel : public CPUKernel { size_t input_num = 1; TypeId dtype_{kTypeUnknown}; }; - -MS_REG_CPU_KERNEL(MultiMarginLoss, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - MultiMarginLossCPUKernel); - -MS_REG_CPU_KERNEL(MultiMarginLoss, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - MultiMarginLossCPUKernel); - -MS_REG_CPU_KERNEL(MultiMarginLoss, - KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeFloat64) - .AddOutputAttr(kNumberTypeFloat64), - MultiMarginLossCPUKernel); - -MS_REG_CPU_KERNEL( - MultiMarginLoss, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), - MultiMarginLossCPUKernel); - -MS_REG_CPU_KERNEL( - MultiMarginLoss, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), - MultiMarginLossCPUKernel); - -MS_REG_CPU_KERNEL( - MultiMarginLoss, - KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), - MultiMarginLossCPUKernel); } // namespace kernel } // namespace mindspore #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MULTI_MARGIN_LOSS_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/multi_margin_loss_grad_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/multi_margin_loss_grad_cpu_kernel.cc index cb729e39647..16b5363d8a8 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/multi_margin_loss_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/multi_margin_loss_grad_cpu_kernel.cc @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#include "backend/kernel_compiler/cpu/multi_margin_loss_grad_cpu_kernel.h" -#include "runtime/device/cpu/cpu_device_address.h" +#include "plugin/device/cpu/kernel/multi_margin_loss_grad_cpu_kernel.h" +#include "plugin/device/cpu/hal/device/cpu_device_address.h" namespace mindspore { namespace kernel { @@ -23,25 +23,32 @@ namespace { constexpr size_t kMultiMarginLossGradInputNumWithWeight = 4; constexpr size_t kMultiMarginLossGradInputNumWithoutWeight = 3; constexpr size_t kMultiMarginLossGradOutputsNum = 1; +const size_t kZero = 0; +const size_t kOne = 1; +const size_t kTwo = 2; +const size_t kThree = 3; constexpr char kKernelName[] = "MultiMarginLossGrad"; } // namespace -void MultiMarginLossGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { +void MultiMarginLossGradCPUKernelMod::InitKernel(const CNodePtr &kernel_node) { CheckParam(kernel_node); - std::vector x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - batch_size = x_shape[0]; - dims = x_shape[1]; - reduction = AnfAlgo::GetNodeAttr(kernel_node, REDUCTION); - p = AnfAlgo::GetNodeAttr(kernel_node, "p"); - margin = AnfAlgo::GetNodeAttr(kernel_node, "margin"); - y_grad_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0).size(); - dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); - input_num = AnfAlgo::GetInputTensorNum(kernel_node); + ShapeVector x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kOne); + if (IsDynamic({x_shape})) { + return; + } + batch_size = LongToSize(x_shape[kZero]); + dims = LongToSize(x_shape[kOne]); + reduction = common::AnfAlgo::GetNodeAttr(kernel_node, REDUCTION); + p = common::AnfAlgo::GetNodeAttr(kernel_node, "p"); + margin = common::AnfAlgo::GetNodeAttr(kernel_node, "margin"); + y_grad_dims = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kZero).size(); + dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kZero); + input_num = common::AnfAlgo::GetInputTensorNum(kernel_node); } -bool MultiMarginLossGradCPUKernel::Launch(const std::vector &inputs, - const std::vector &, - const std::vector &outputs) { +bool MultiMarginLossGradCPUKernelMod::Launch(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) { if (dtype_ == kNumberTypeFloat16) { LaunchKernelFP16(inputs, outputs); } else if (dtype_ == kNumberTypeFloat32) { @@ -55,11 +62,11 @@ bool MultiMarginLossGradCPUKernel::Launch(const std::vector } template -void MultiMarginLossGradCPUKernel::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { - auto y_grad_addr = reinterpret_cast(inputs[0]->addr); - auto x_addr = reinterpret_cast(inputs[1]->addr); - auto target_addr = reinterpret_cast(inputs[2]->addr); +void MultiMarginLossGradCPUKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + auto y_grad_addr = reinterpret_cast(inputs[kZero]->addr); + auto x_addr = reinterpret_cast(inputs[kOne]->addr); + auto target_addr = reinterpret_cast(inputs[kTwo]->addr); for (size_t i = 0; i < batch_size; i++) { if (target_addr[i] < 0 || target_addr[i] >= SizeToLong(dims)) { MS_EXCEPTION(ValueError) << "Target out of range."; @@ -68,12 +75,9 @@ void MultiMarginLossGradCPUKernel::LaunchKernel(const std::vector(inputs[3]->addr); + weight_addr = reinterpret_cast(inputs[kThree]->addr); } - auto x_grad_addr = reinterpret_cast(outputs[0]->addr); - T weights; - weights = reduction == MEAN ? 
(static_cast(1) / (static_cast(dims) * static_cast(batch_size))) - : (static_cast(1) / static_cast(dims)); + auto x_grad_addr = reinterpret_cast(outputs[kZero]->addr); auto task = [&](size_t start, size_t end) { start *= dims; end *= dims; @@ -91,6 +95,8 @@ void MultiMarginLossGradCPUKernel::LaunchKernel(const std::vector static_cast(0)) { + auto weights = reduction == MEAN ? (static_cast(1) / (static_cast(dims) * static_cast(batch_size))) + : (static_cast(1) / static_cast(dims)); calc_data[d] = (p == 1) ? weights : static_cast(2) * weights * calc_data[d]; if (weight_defined_) { calc_data[d] *= static_cast(weight_addr[target_idx]); @@ -122,11 +128,11 @@ void MultiMarginLossGradCPUKernel::LaunchKernel(const std::vector -void MultiMarginLossGradCPUKernel::LaunchKernelFP16(const std::vector &inputs, - const std::vector &outputs) { - auto y_grad_addr = reinterpret_cast(inputs[0]->addr); - auto x_addr = reinterpret_cast(inputs[1]->addr); - auto target_addr = reinterpret_cast(inputs[2]->addr); +void MultiMarginLossGradCPUKernelMod::LaunchKernelFP16(const std::vector &inputs, + const std::vector &outputs) { + auto y_grad_addr = reinterpret_cast(inputs[kZero]->addr); + auto x_addr = reinterpret_cast(inputs[kOne]->addr); + auto target_addr = reinterpret_cast(inputs[kTwo]->addr); for (size_t i = 0; i < batch_size; i++) { if (target_addr[i] < 0 || target_addr[i] >= SizeToLong(dims)) { MS_EXCEPTION(ValueError) << "Target out of range."; @@ -135,12 +141,9 @@ void MultiMarginLossGradCPUKernel::LaunchKernelFP16(const std::vector(inputs[3]->addr); + weight_addr = reinterpret_cast(inputs[kThree]->addr); } - auto x_grad_addr = reinterpret_cast(outputs[0]->addr); - float weights; - weights = reduction == MEAN ? (static_cast(1) / (static_cast(dims) * static_cast(batch_size))) - : (static_cast(1) / static_cast(dims)); + auto x_grad_addr = reinterpret_cast(outputs[kZero]->addr); auto task = [&](size_t start, size_t end) { start *= dims; end *= dims; @@ -158,6 +161,9 @@ void MultiMarginLossGradCPUKernel::LaunchKernelFP16(const std::vector static_cast(0)) { + auto weights = reduction == MEAN + ? (static_cast(1) / (static_cast(dims) * static_cast(batch_size))) + : (static_cast(1) / static_cast(dims)); calc_data[d] = (p == 1) ? 
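// The hunk above moves the `weights` computation inside the loop so it is
// evaluated only when a margin violation exists (calc_data[d] > 0). Written
// out, the scaling MultiMarginLossGrad applies per violating class j != target is
//   w = 1 / (dims * batch_size)  when reduction == "mean"
//   w = 1 / dims                 otherwise
//   dL/dx_j = w             for p == 1
//   dL/dx_j = 2 * w * m_j   for p == 2, with m_j = max(0, margin - x[target] + x[j])
// Illustrative per-term helper (not in the patch):
template <typename T>
T GradTerm(T m_j, int64_t p, T w) {
  if (m_j <= static_cast<T>(0)) return static_cast<T>(0);  // no violation, no gradient
  return (p == 1) ? w : static_cast<T>(2) * w * m_j;
}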
weights : static_cast(2) * weights * calc_data[d]; if (weight_defined_) { calc_data[d] *= static_cast(weight_addr[target_idx]); @@ -189,13 +195,15 @@ void MultiMarginLossGradCPUKernel::LaunchKernelFP16(const std::vector #include #include -#include "backend/kernel_compiler/cpu/cpu_kernel.h" -#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" namespace mindspore { namespace kernel { -class MultiMarginLossGradCPUKernel : public CPUKernel { +class MultiMarginLossGradCPUKernelMod : public DeprecatedNativeCpuKernelMod { public: - MultiMarginLossGradCPUKernel() = default; + MultiMarginLossGradCPUKernelMod() = default; - ~MultiMarginLossGradCPUKernel() override = default; + ~MultiMarginLossGradCPUKernelMod() override = default; void InitKernel(const CNodePtr &kernel_node) override; bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override; + protected: + std::vector GetOpSupport() override { + static std::vector support_list = {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat16), + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat32), + KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat64)}; + return support_list; + } + + private: template void LaunchKernel(const std::vector &inputs, const std::vector &outputs); - template void LaunchKernelFP16(const std::vector &inputs, const std::vector &outputs); - private: void CheckParam(const CNodePtr &kernel_node); size_t batch_size = 2; size_t dims = 1; @@ -53,57 +90,6 @@ class MultiMarginLossGradCPUKernel : public CPUKernel { size_t y_grad_dims = 1; TypeId dtype_{kTypeUnknown}; }; - -MS_REG_CPU_KERNEL(MultiMarginLossGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - MultiMarginLossGradCPUKernel); - -MS_REG_CPU_KERNEL(MultiMarginLossGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - MultiMarginLossGradCPUKernel); - -MS_REG_CPU_KERNEL(MultiMarginLossGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeFloat64) - .AddOutputAttr(kNumberTypeFloat64), - MultiMarginLossGradCPUKernel); - -MS_REG_CPU_KERNEL(MultiMarginLossGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeInt64) 
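// LaunchKernelFP16 reads float16 tensors but does its arithmetic in float,
// narrowing back to half precision only on the final store; this keeps the
// summation over `dims` terms from losing precision at 16-bit width. Minimal
// standalone illustration, not patch code (`float16` is MindSpore's half type):
float16 MeanFP16(const float16 *x, size_t n) {
  float acc = 0.0f;  // widen: accumulate in single precision
  for (size_t i = 0; i < n; ++i) {
    acc += static_cast<float>(x[i]);
  }
  return static_cast<float16>(acc / static_cast<float>(n));  // narrow once at the end
}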
- .AddOutputAttr(kNumberTypeFloat16), - MultiMarginLossGradCPUKernel); - -MS_REG_CPU_KERNEL(MultiMarginLossGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat32), - MultiMarginLossGradCPUKernel); - -MS_REG_CPU_KERNEL(MultiMarginLossGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeInt64) - .AddOutputAttr(kNumberTypeFloat64), - MultiMarginLossGradCPUKernel); } // namespace kernel } // namespace mindspore #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MULTI_MARGIN_LOSS_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/mvlgamma_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/mvlgamma_cpu_kernel.cc index fca2007870a..8d04637b56a 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/mvlgamma_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/mvlgamma_cpu_kernel.cc @@ -91,4 +91,4 @@ std::vector MvlgammaCpuKernelMod::GetOpSupport() { } MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Mvlgamma, MvlgammaCpuKernelMod); } // namespace kernel -} // namespace mindspore \ No newline at end of file +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/mvlgamma_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/mvlgamma_cpu_kernel.h index 3d75c97034d..63b25dd8ad1 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/mvlgamma_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/mvlgamma_cpu_kernel.h @@ -24,7 +24,6 @@ namespace mindspore { namespace kernel { - class MvlgammaCpuKernelMod : public DeprecatedNativeCpuKernelMod { public: MvlgammaCpuKernelMod() = default; @@ -51,8 +50,6 @@ class MvlgammaCpuKernelMod : public DeprecatedNativeCpuKernelMod { protected: std::vector GetOpSupport() override; }; - } // namespace kernel } // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MVLGAMMA_CPU_KERNEL_H_ \ No newline at end of file +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MVLGAMMA_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/mvlgamma_grad_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/mvlgamma_grad_cpu_kernel.cc index 707a0d1c15f..5482018d0be 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/mvlgamma_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/mvlgamma_grad_cpu_kernel.cc @@ -152,6 +152,5 @@ std::vector MvlgammaGradCpuKernelMod::GetOpSupport() { } MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MvlgammaGrad, MvlgammaGradCpuKernelMod); - } // namespace kernel -} // namespace mindspore \ No newline at end of file +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/mvlgamma_grad_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/mvlgamma_grad_cpu_kernel.h index 7e4889a0bdf..856dc323ac7 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/mvlgamma_grad_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/mvlgamma_grad_cpu_kernel.h @@ -24,7 +24,6 @@ namespace mindspore { namespace kernel { - class MvlgammaGradCpuKernelMod : public DeprecatedNativeCpuKernelMod { public: MvlgammaGradCpuKernelMod() = default; @@ -54,8 +53,6 @@ class MvlgammaGradCpuKernelMod : public DeprecatedNativeCpuKernelMod { protected: std::vector GetOpSupport() override; }; - } // namespace kernel } // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MVLGAMMA_GRAD_CPU_KERNEL_H_ \ No newline at end of file +#endif // 
MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MVLGAMMA_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/triplet_margin_loss_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/triplet_margin_loss_cpu_kernel.cc index 929c27c6040..15b15a9324d 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/triplet_margin_loss_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/triplet_margin_loss_cpu_kernel.cc @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,32 +14,35 @@ * limitations under the License. */ -#include "backend/kernel_compiler/cpu/triplet_margin_loss_cpu_kernel.h" -#include "runtime/device/cpu/cpu_device_address.h" +#include "plugin/device/cpu/kernel/triplet_margin_loss_cpu_kernel.h" +#include "plugin/device/cpu/hal/device/cpu_device_address.h" namespace mindspore { namespace kernel { -void TripletMarginLossCPUKernel::InitKernel(const CNodePtr &kernel_node) { +void TripletMarginLossCPUKernelMod::InitKernel(const CNodePtr &kernel_node) { CheckParam(kernel_node); - constexpr int kzero = 0; - constexpr int kone = 1; - constexpr int ktwo = 2; - constexpr int kthree = 3; + constexpr int kZero = 0; + constexpr int kOne = 1; + constexpr int kTwo = 2; + constexpr int kThree = 3; constexpr int kParallel = 28; constexpr int kParallelunit = 1024; - p = AnfAlgo::GetNodeAttr(kernel_node, "p"); - swap = AnfAlgo::GetNodeAttr(kernel_node, "swap"); - eps = AnfAlgo::GetNodeAttr(kernel_node, "eps"); - reduction = AnfAlgo::GetNodeAttr(kernel_node, "reduction"); - dtype_0 = AnfAlgo::GetInputDeviceDataType(kernel_node, kzero); - dtype_1 = AnfAlgo::GetInputDeviceDataType(kernel_node, kone); - dtype_2 = AnfAlgo::GetInputDeviceDataType(kernel_node, ktwo); - dtype_3 = AnfAlgo::GetInputDeviceDataType(kernel_node, kthree); - x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kzero); - positive_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kone); - negative_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, ktwo); + p = common::AnfAlgo::GetNodeAttr(kernel_node, "p"); + swap = common::AnfAlgo::GetNodeAttr(kernel_node, "swap"); + eps = common::AnfAlgo::GetNodeAttr(kernel_node, "eps"); + reduction = common::AnfAlgo::GetNodeAttr(kernel_node, "reduction"); + dtype_0 = AnfAlgo::GetInputDeviceDataType(kernel_node, kZero); + dtype_1 = AnfAlgo::GetInputDeviceDataType(kernel_node, kOne); + dtype_2 = AnfAlgo::GetInputDeviceDataType(kernel_node, kTwo); + dtype_3 = AnfAlgo::GetInputDeviceDataType(kernel_node, kThree); + x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kZero); + positive_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kOne); + negative_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kTwo); + if (AnfAlgo::IsShapesDynamic({x_shape, positive_shape, negative_shape})) { + return; + } kParallelDataNum = kParallel * kParallelunit; - std::vector broadcast_shape_x_and_positive = CPUKernelUtils::GetBroadcastShape(x_shape, positive_shape); + auto broadcast_shape_x_and_positive = CPUKernelUtils::GetBroadcastShape(x_shape, positive_shape); broadcast_shape = CPUKernelUtils::GetBroadcastShape(broadcast_shape_x_and_positive, negative_shape); size_t dim_x = x_shape.size(); size_t dim_positive = positive_shape.size(); @@ -51,30 +54,28 @@ void TripletMarginLossCPUKernel::InitKernel(const CNodePtr &kernel_node) 
{ std::reverse(x_reshape_vector.begin(), x_reshape_vector.end()); std::reverse(positive_reshape_vector.begin(), positive_reshape_vector.end()); std::reverse(negative_reshape_vector.begin(), negative_reshape_vector.end()); - if (dim_x < max_size) x_reshape_vector.resize(max_size, kone); - if (dim_positive < max_size) positive_reshape_vector.resize(max_size, kone); - if (dim_negative < max_size) negative_reshape_vector.resize(max_size, kone); + if (dim_x < max_size) x_reshape_vector.resize(max_size, kOne); + if (dim_positive < max_size) positive_reshape_vector.resize(max_size, kOne); + if (dim_negative < max_size) negative_reshape_vector.resize(max_size, kOne); std::reverse(x_reshape_vector.begin(), x_reshape_vector.end()); std::reverse(positive_reshape_vector.begin(), positive_reshape_vector.end()); std::reverse(negative_reshape_vector.begin(), negative_reshape_vector.end()); - numelements = 1; - for (size_t i = 0; i < broadcast_shape.size(); i++) { - numelements *= broadcast_shape[i]; - } - data_num = (numelements) / (broadcast_shape[1]); - data_num_each_batch = (numelements) / (broadcast_shape[0]); - index = data_num / (broadcast_shape[0]); - batch_size = broadcast_shape[0]; - once_compute_size = broadcast_shape[1]; + numelements = LongToSize(SizeOf(broadcast_shape)); + + data_num = (numelements) / LongToSize(broadcast_shape[1]); + data_num_each_batch = (numelements) / LongToSize(broadcast_shape[0]); + index = data_num / LongToSize(broadcast_shape[0]); + batch_size = LongToSize(broadcast_shape[0]); + once_compute_size = LongToSize(broadcast_shape[1]); broadcast = false; if (x_shape != positive_shape || x_shape != negative_shape || positive_shape != negative_shape) { broadcast = true; } } -bool TripletMarginLossCPUKernel::Launch(const std::vector &inputs, - const std::vector &, - const std::vector &outputs) { +bool TripletMarginLossCPUKernelMod::Launch(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) { switch (dtype_0) { case kNumberTypeFloat16: TripletMarginLossCompute_realtype(inputs, outputs); @@ -123,73 +124,23 @@ bool TripletMarginLossCPUKernel::Launch(const std::vector &i } template -void TripletMarginLossCPUKernel::TripletMarginLossCompute_realtype(const std::vector &inputs, - const std::vector &outputs) { - auto out_data = reinterpret_cast(outputs[0]->addr); - Eigen::Array out(data_num, 1); - float *output_reduction_none_data = reinterpret_cast(out.data()); - auto task_nobroadcast = [&](size_t start, size_t end) { - TripletMarginLossCPUKernel::realtype_nobroadcast_task(start, end, output_reduction_none_data, inputs, outputs); - }; - auto task_broadcast = [&](size_t start, size_t end) { - TripletMarginLossCPUKernel::realtype_broadcast_task(start, end, output_reduction_none_data, inputs, outputs); - }; - if (broadcast == true) { - if (numelements * sizeof(T) > kParallelDataNum) { - CPUKernelUtils::ParallelFor(task_broadcast, batch_size); - } else { - TripletMarginLossCPUKernel::realtype_broadcast_compute(output_reduction_none_data, inputs, outputs); - } - if (reduction == NONE) { - for (size_t i = 0; i < data_num; i++) { - *(out_data + i) = *(output_reduction_none_data + i); - } - } - if (reduction == MEAN) { - *(out_data) = (out.mean()); - } - if (reduction == SUM) { - *(out_data) = (out.sum()); - } - return; - } - if (numelements * sizeof(T) > kParallelDataNum) { - CPUKernelUtils::ParallelFor(task_nobroadcast, batch_size); - } else { - TripletMarginLossCPUKernel::realtype_nobroadcast_compute(output_reduction_none_data, inputs, outputs); - } - if 
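// The reverse/resize/reverse sequence above is NumPy-style right alignment:
// each operand's shape is padded with leading 1s up to the broadcast rank.
// The same trick in isolation (PadLeft is a hypothetical helper name):
ShapeVector PadLeft(ShapeVector shape, size_t rank) {
  std::reverse(shape.begin(), shape.end());
  if (shape.size() < rank) {
    shape.resize(rank, 1);  // after the first reverse, "leading" dims sit at the back
  }
  std::reverse(shape.begin(), shape.end());
  return shape;
}
// PadLeft({5, 4}, 4) yields {1, 1, 5, 4}, which broadcasts against {2, 3, 5, 4}.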
(reduction == NONE) { - for (size_t i = 0; i < data_num; i++) { - *(out_data + i) = *(output_reduction_none_data + i); - } - } - if (reduction == MEAN) { - *(out_data) = (out.mean()); - } - if (reduction == SUM) { - *(out_data) = (out.sum()); - } - return; -} - -template -void TripletMarginLossCPUKernel::TripletMarginLossCompute_complextype(const std::vector &inputs, +void TripletMarginLossCPUKernelMod::TripletMarginLossCompute_realtype(const std::vector &inputs, const std::vector &outputs) { auto out_data = reinterpret_cast(outputs[0]->addr); Eigen::Array out(data_num, 1); float *output_reduction_none_data = reinterpret_cast(out.data()); auto task_nobroadcast = [&](size_t start, size_t end) { - TripletMarginLossCPUKernel::complextype_nobroadcast_task(start, end, output_reduction_none_data, inputs, + TripletMarginLossCPUKernelMod::realtype_nobroadcast_task(start, end, output_reduction_none_data, inputs, outputs); }; auto task_broadcast = [&](size_t start, size_t end) { - TripletMarginLossCPUKernel::complextype_broadcast_task(start, end, output_reduction_none_data, inputs, outputs); + TripletMarginLossCPUKernelMod::realtype_broadcast_task(start, end, output_reduction_none_data, inputs, outputs); }; if (broadcast == true) { if (numelements * sizeof(T) > kParallelDataNum) { CPUKernelUtils::ParallelFor(task_broadcast, batch_size); } else { - TripletMarginLossCPUKernel::complextype_broadcast_compute(output_reduction_none_data, inputs, outputs); + TripletMarginLossCPUKernelMod::realtype_broadcast_compute(output_reduction_none_data, inputs, outputs); } if (reduction == NONE) { for (size_t i = 0; i < data_num; i++) { @@ -207,7 +158,7 @@ void TripletMarginLossCPUKernel::TripletMarginLossCompute_complextype(const std: if (numelements * sizeof(T) > kParallelDataNum) { CPUKernelUtils::ParallelFor(task_nobroadcast, batch_size); } else { - TripletMarginLossCPUKernel::complextype_nobroadcast_compute(output_reduction_none_data, inputs, outputs); + TripletMarginLossCPUKernelMod::realtype_nobroadcast_compute(output_reduction_none_data, inputs, outputs); } if (reduction == NONE) { for (size_t i = 0; i < data_num; i++) { @@ -224,9 +175,62 @@ void TripletMarginLossCPUKernel::TripletMarginLossCompute_complextype(const std: } template -void TripletMarginLossCPUKernel::realtype_nobroadcast_task(size_t start, size_t end, float *output_reduction_none_data, - const std::vector &inputs, - const std::vector &outputs) { +void TripletMarginLossCPUKernelMod::TripletMarginLossCompute_complextype( + const std::vector &inputs, const std::vector &outputs) { + auto out_data = reinterpret_cast(outputs[0]->addr); + Eigen::Array out(data_num, 1); + float *output_reduction_none_data = reinterpret_cast(out.data()); + auto task_nobroadcast = [&](size_t start, size_t end) { + TripletMarginLossCPUKernelMod::complextype_nobroadcast_task(start, end, output_reduction_none_data, inputs, + outputs); + }; + auto task_broadcast = [&](size_t start, size_t end) { + TripletMarginLossCPUKernelMod::complextype_broadcast_task(start, end, output_reduction_none_data, inputs, + outputs); + }; + if (broadcast == true) { + if (numelements * sizeof(T) > kParallelDataNum) { + CPUKernelUtils::ParallelFor(task_broadcast, batch_size); + } else { + TripletMarginLossCPUKernelMod::complextype_broadcast_compute(output_reduction_none_data, inputs, outputs); + } + if (reduction == NONE) { + for (size_t i = 0; i < data_num; i++) { + *(out_data + i) = *(output_reduction_none_data + i); + } + } + if (reduction == MEAN) { + *(out_data) = (out.mean()); + } + if 
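// Every compute path in this kernel gates parallelism on payload size: inputs
// smaller than kParallelDataNum (28 * 1024 bytes here) stay on one thread,
// where ParallelFor's scheduling overhead is unlikely to pay off. The guard in
// isolation (ComputeSerial is a stand-in for the serial compute call):
if (numelements * sizeof(T) > kParallelDataNum) {
  CPUKernelUtils::ParallelFor(task, batch_size);  // split the batch across workers
} else {
  ComputeSerial(output_reduction_none_data, inputs, outputs);
}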
(reduction == SUM) { + *(out_data) = (out.sum()); + } + return; + } + if (numelements * sizeof(T) > kParallelDataNum) { + CPUKernelUtils::ParallelFor(task_nobroadcast, batch_size); + } else { + TripletMarginLossCPUKernelMod::complextype_nobroadcast_compute(output_reduction_none_data, inputs, outputs); + } + if (reduction == NONE) { + for (size_t i = 0; i < data_num; i++) { + *(out_data + i) = *(output_reduction_none_data + i); + } + } + if (reduction == MEAN) { + *(out_data) = (out.mean()); + } + if (reduction == SUM) { + *(out_data) = (out.sum()); + } + return; +} + +template +void TripletMarginLossCPUKernelMod::realtype_nobroadcast_task(size_t start, size_t end, + float *output_reduction_none_data, + const std::vector &inputs, + const std::vector &outputs) { auto x_addr = reinterpret_cast(inputs[0]->addr); auto positive_addr = reinterpret_cast(inputs[1]->addr); auto negative_addr = reinterpret_cast(inputs[2]->addr); @@ -298,9 +302,9 @@ void TripletMarginLossCPUKernel::realtype_nobroadcast_task(size_t start, size_t } template -void TripletMarginLossCPUKernel::realtype_broadcast_task(size_t start, size_t end, float *output_reduction_none_data, - const std::vector &inputs, - const std::vector &outputs) { +void TripletMarginLossCPUKernelMod::realtype_broadcast_task(size_t start, size_t end, float *output_reduction_none_data, + const std::vector &inputs, + const std::vector &outputs) { auto x_addr = reinterpret_cast(inputs[0]->addr); auto positive_addr = reinterpret_cast(inputs[1]->addr); auto negative_addr = reinterpret_cast(inputs[2]->addr); @@ -349,8 +353,8 @@ void TripletMarginLossCPUKernel::realtype_broadcast_task(size_t start, size_t en } calc_1_sum += calculate_positive[k]; calc_2_sum += calculate_negative[k]; - TripletMarginLossCPUKernel::realtype_swap(start, positive_broadcast, negative_broadcast, calculate_swap, j, - k, calc_swap_sum, inputs, outputs); + TripletMarginLossCPUKernelMod::realtype_swap(start, positive_broadcast, negative_broadcast, calculate_swap, + j, k, calc_swap_sum, inputs, outputs); } positive_distance = std::pow(static_cast(calc_1_sum), (1 / static_cast(p))); if (x_reshape_vector[1] == 1 && positive_reshape_vector[1] == 1 && broadcast_shape[1] != 1) { @@ -375,9 +379,9 @@ void TripletMarginLossCPUKernel::realtype_broadcast_task(size_t start, size_t en } template -void TripletMarginLossCPUKernel::realtype_broadcast_compute(float *output_reduction_none_data, - const std::vector &inputs, - const std::vector &outputs) { +void TripletMarginLossCPUKernelMod::realtype_broadcast_compute(float *output_reduction_none_data, + const std::vector &inputs, + const std::vector &outputs) { auto x_addr = reinterpret_cast(inputs[0]->addr); auto positive_addr = reinterpret_cast(inputs[1]->addr); auto negative_addr = reinterpret_cast(inputs[2]->addr); @@ -424,8 +428,8 @@ void TripletMarginLossCPUKernel::realtype_broadcast_compute(float *output_reduct } calc_1_sum += calculate_positive[k]; calc_2_sum += calculate_negative[k]; - TripletMarginLossCPUKernel::realtype_swap(i * data_num_each_batch, positive_broadcast, negative_broadcast, - calculate_swap, j, k, calc_swap_sum, inputs, outputs); + TripletMarginLossCPUKernelMod::realtype_swap(i * data_num_each_batch, positive_broadcast, negative_broadcast, + calculate_swap, j, k, calc_swap_sum, inputs, outputs); } positive_distance = std::pow(static_cast(calc_1_sum), (1 / static_cast(p))); if (x_reshape_vector[1] == 1 && positive_reshape_vector[1] == 1 && broadcast_shape[1] != 1) { @@ -449,9 +453,9 @@ void 
TripletMarginLossCPUKernel::realtype_broadcast_compute(float *output_reduct } template -void TripletMarginLossCPUKernel::realtype_nobroadcast_compute(float *output_reduction_none_data, - const std::vector &inputs, - const std::vector &outputs) { +void TripletMarginLossCPUKernelMod::realtype_nobroadcast_compute(float *output_reduction_none_data, + const std::vector &inputs, + const std::vector &outputs) { auto x_addr = reinterpret_cast(inputs[0]->addr); auto positive_addr = reinterpret_cast(inputs[1]->addr); auto negative_addr = reinterpret_cast(inputs[2]->addr); @@ -518,10 +522,10 @@ void TripletMarginLossCPUKernel::realtype_nobroadcast_compute(float *output_redu } template -void TripletMarginLossCPUKernel::complextype_nobroadcast_task(size_t start, size_t end, - float *output_reduction_none_data, - const std::vector &inputs, - const std::vector &outputs) { +void TripletMarginLossCPUKernelMod::complextype_nobroadcast_task(size_t start, size_t end, + float *output_reduction_none_data, + const std::vector &inputs, + const std::vector &outputs) { auto x_addr = reinterpret_cast(inputs[0]->addr); auto positive_addr = reinterpret_cast(inputs[1]->addr); auto negative_addr = reinterpret_cast(inputs[2]->addr); @@ -575,9 +579,10 @@ void TripletMarginLossCPUKernel::complextype_nobroadcast_task(size_t start, size } template -void TripletMarginLossCPUKernel::complextype_broadcast_task(size_t start, size_t end, float *output_reduction_none_data, - const std::vector &inputs, - const std::vector &outputs) { +void TripletMarginLossCPUKernelMod::complextype_broadcast_task(size_t start, size_t end, + float *output_reduction_none_data, + const std::vector &inputs, + const std::vector &outputs) { auto x_addr = reinterpret_cast(inputs[0]->addr); auto positive_addr = reinterpret_cast(inputs[1]->addr); auto negative_addr = reinterpret_cast(inputs[2]->addr); @@ -630,8 +635,8 @@ void TripletMarginLossCPUKernel::complextype_broadcast_task(size_t start, size_t } calc_1_sum += calculate_positive_float; calc_2_sum += calculate_negative_float; - TripletMarginLossCPUKernel::complextype_swap(start, positive_broadcast, negative_broadcast, calculate_swap, - j, k, calc_swap_sum, inputs, outputs); + TripletMarginLossCPUKernelMod::complextype_swap(start, positive_broadcast, negative_broadcast, + calculate_swap, j, k, calc_swap_sum, inputs, outputs); } positive_distance = std::pow(static_cast(calc_1_sum), (1 / static_cast(p))); if (x_reshape_vector[1] == 1 && positive_reshape_vector[1] == 1 && broadcast_shape[1] != 1) { @@ -656,9 +661,9 @@ void TripletMarginLossCPUKernel::complextype_broadcast_task(size_t start, size_t } template -void TripletMarginLossCPUKernel::complextype_broadcast_compute(float *output_reduction_none_data, - const std::vector &inputs, - const std::vector &outputs) { +void TripletMarginLossCPUKernelMod::complextype_broadcast_compute(float *output_reduction_none_data, + const std::vector &inputs, + const std::vector &outputs) { auto x_addr = reinterpret_cast(inputs[0]->addr); auto positive_addr = reinterpret_cast(inputs[1]->addr); auto negative_addr = reinterpret_cast(inputs[2]->addr); @@ -707,8 +712,9 @@ void TripletMarginLossCPUKernel::complextype_broadcast_compute(float *output_red } calc_1_sum += calculate_positive_float; calc_2_sum += calculate_negative_float; - TripletMarginLossCPUKernel::complextype_swap(i * data_num_each_batch, positive_broadcast, negative_broadcast, - calculate_swap, j, k, calc_swap_sum, inputs, outputs); + TripletMarginLossCPUKernelMod::complextype_swap(i * data_num_each_batch, 
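// The real- and complex-typed, broadcast and no-broadcast paths above all
// evaluate the same quantity; only memory layout and element type differ. With
//   d(u, v) = (sum_k |u_k - v_k + eps|^p)^(1/p)
// the per-sample result is max(d(a, pos) - d(a, neg) + margin, 0), and when
// swap is set the negative distance becomes min(d(a, neg), d(pos, neg)), the
// "distance swap" that keeps the harder of the two negatives. Scalar sketch
// (margin arrives as the kernel's fourth input tensor):
float TripletTerm(float d_ap, float d_an, float d_pn, float margin, bool swap) {
  if (swap) {
    d_an = std::min(d_an, d_pn);  // compare anchor-negative against positive-negative
  }
  return std::max(d_ap - d_an + margin, 0.0f);
}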
positive_broadcast, + negative_broadcast, calculate_swap, j, k, calc_swap_sum, + inputs, outputs); } positive_distance = std::pow(static_cast(calc_1_sum), (1 / static_cast(p))); if (x_reshape_vector[1] == 1 && positive_reshape_vector[1] == 1 && broadcast_shape[1] != 1) { @@ -732,9 +738,9 @@ void TripletMarginLossCPUKernel::complextype_broadcast_compute(float *output_red } template -void TripletMarginLossCPUKernel::complextype_nobroadcast_compute(float *output_reduction_none_data, - const std::vector &inputs, - const std::vector &outputs) { +void TripletMarginLossCPUKernelMod::complextype_nobroadcast_compute(float *output_reduction_none_data, + const std::vector &inputs, + const std::vector &outputs) { auto x_addr = reinterpret_cast(inputs[0]->addr); auto positive_addr = reinterpret_cast(inputs[1]->addr); auto negative_addr = reinterpret_cast(inputs[2]->addr); @@ -786,11 +792,11 @@ void TripletMarginLossCPUKernel::complextype_nobroadcast_compute(float *output_r } template -void TripletMarginLossCPUKernel::realtype_swap(size_t start, std::vector &positive_broadcast, - std::vector &negative_broadcast, std::vector &calculate_swap, - size_t j, size_t k, float &calc_swap_sum, - const std::vector &inputs, - const std::vector &outputs) { +void TripletMarginLossCPUKernelMod::realtype_swap(size_t start, std::vector &positive_broadcast, + std::vector &negative_broadcast, + std::vector &calculate_swap, size_t j, size_t k, + float &calc_swap_sum, const std::vector &inputs, + const std::vector &outputs) { if (swap == true) { calculate_swap[k] = abs(static_cast(positive_broadcast[start + j + k * index]) - static_cast(negative_broadcast[start + j + k * index]) + eps); @@ -803,11 +809,11 @@ void TripletMarginLossCPUKernel::realtype_swap(size_t start, std::vector &pos } template -void TripletMarginLossCPUKernel::complextype_swap(size_t start, std::vector &positive_broadcast, - std::vector &negative_broadcast, std::vector &calculate_swap, - size_t j, size_t k, float &calc_swap_sum, - const std::vector &inputs, - const std::vector &outputs) { +void TripletMarginLossCPUKernelMod::complextype_swap(size_t start, std::vector &positive_broadcast, + std::vector &negative_broadcast, std::vector &calculate_swap, + size_t j, size_t k, float &calc_swap_sum, + const std::vector &inputs, + const std::vector &outputs) { if (swap == true) { calculate_swap[k] = positive_broadcast[start + j + k * index] - negative_broadcast[start + j + k * index] + static_cast(eps); @@ -821,17 +827,19 @@ void TripletMarginLossCPUKernel::complextype_swap(size_t start, std::vector & } } -void TripletMarginLossCPUKernel::CheckParam(const CNodePtr &kernel_node) { - auto input_num = AnfAlgo::GetInputTensorNum(kernel_node); - constexpr int kone = 1; +void TripletMarginLossCPUKernelMod::CheckParam(const CNodePtr &kernel_node) { + auto input_num = common::AnfAlgo::GetInputTensorNum(kernel_node); + constexpr int kOne = 1; constexpr int kfour = 4; if (input_num != kfour) { - MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but TripletMarginLossCPUKernel needs 4 inputs."; + MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but TripletMarginLossCPUKernelMod needs 4 inputs."; } - auto output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != kone) { - MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but TripletMarginLossCPUKernel needs 1 output."; + auto output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != kOne) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", 
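The realtype_swap/complextype_swap helpers above implement TripletMarginLoss's swap option: when swap is true, the negative distance becomes the smaller of d(anchor, negative) and d(positive, negative). A scalar sketch of one sample's loss under that convention, assuming the usual p-norm distance with the eps shift (not the kernel's exact code):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Scalar sketch of one triplet's loss; a, pos, neg each point at dim features.
float TripletLossSample(const float *a, const float *pos, const float *neg, std::size_t dim,
                        float margin, int64_t p, float eps, bool swap) {
  auto dist = [&](const float *u, const float *v) {
    double acc = 0.0;
    for (std::size_t i = 0; i < dim; ++i) {
      acc += std::pow(std::abs(static_cast<double>(u[i]) - v[i] + eps), static_cast<double>(p));
    }
    return std::pow(acc, 1.0 / static_cast<double>(p));
  };
  double d_pos = dist(a, pos);
  double d_neg = dist(a, neg);
  if (swap) {
    d_neg = std::min(d_neg, dist(pos, neg));  // smaller distance -> harder negative -> larger loss
  }
  return static_cast<float>(std::max(d_pos - d_neg + margin, 0.0));
}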
but TripletMarginLossCPUKernelMod needs 1 output."; } } + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, TripletMarginLoss, TripletMarginLossCPUKernelMod); } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/triplet_margin_loss_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/triplet_margin_loss_cpu_kernel.h index d052da227ea..c76f09c785d 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/triplet_margin_loss_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/triplet_margin_loss_cpu_kernel.h @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,21 +25,99 @@ #include #include #include -#include "backend/kernel_compiler/cpu/cpu_kernel.h" -#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" namespace mindspore { namespace kernel { -class TripletMarginLossCPUKernel : public CPUKernel { +class TripletMarginLossCPUKernelMod : public DeprecatedNativeCpuKernelMod { public: - TripletMarginLossCPUKernel() = default; - ~TripletMarginLossCPUKernel() override = default; + TripletMarginLossCPUKernelMod() = default; + ~TripletMarginLossCPUKernelMod() override = default; void InitKernel(const CNodePtr &kernel_node) override; bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override; + protected: + std::vector GetOpSupport() override { + static std::vector support_list = {KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + KernelAttr() + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + KernelAttr() + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + KernelAttr() + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + KernelAttr() + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + 
.AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + KernelAttr() + .AddInputAttr(kNumberTypeComplex64) + .AddInputAttr(kNumberTypeComplex64) + .AddInputAttr(kNumberTypeComplex64) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + KernelAttr() + .AddInputAttr(kNumberTypeComplex128) + .AddInputAttr(kNumberTypeComplex128) + .AddInputAttr(kNumberTypeComplex128) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32)}; + return support_list; + } + + private: template void LaunchKernel(const std::vector &inputs, const std::vector &outputs); @@ -96,7 +174,6 @@ class TripletMarginLossCPUKernel : public CPUKernel { std::vector &calculate_swap, size_t j, size_t k, float &calc_swap_sum, const std::vector &inputs, const std::vector &outputs); - private: void CheckParam(const CNodePtr &kernel_node); int64_t p = 2; bool swap = false; @@ -109,13 +186,13 @@ class TripletMarginLossCPUKernel : public CPUKernel { TypeId dtype_1{kTypeUnknown}; TypeId dtype_2{kTypeUnknown}; TypeId dtype_3{kTypeUnknown}; - std::vector x_shape; - std::vector positive_shape; - std::vector negative_shape; - std::vector broadcast_shape; - std::vector x_reshape_vector; - std::vector positive_reshape_vector; - std::vector negative_reshape_vector; + ShapeVector x_shape; + ShapeVector positive_shape; + ShapeVector negative_shape; + ShapeVector broadcast_shape; + ShapeVector x_reshape_vector; + ShapeVector positive_reshape_vector; + ShapeVector negative_reshape_vector; size_t numelements = 1; size_t data_num = 1; size_t data_num_each_batch = 1; @@ -124,114 +201,6 @@ class TripletMarginLossCPUKernel : public CPUKernel { size_t once_compute_size = 1; bool broadcast = false; }; - -MS_REG_CPU_KERNEL(TripletMarginLoss, - KernelAttr() - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeInt8) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - TripletMarginLossCPUKernel); - -MS_REG_CPU_KERNEL(TripletMarginLoss, - KernelAttr() - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kNumberTypeInt16) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - TripletMarginLossCPUKernel); - -MS_REG_CPU_KERNEL(TripletMarginLoss, - KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - TripletMarginLossCPUKernel); - -MS_REG_CPU_KERNEL(TripletMarginLoss, - KernelAttr() - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeInt64) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - TripletMarginLossCPUKernel); - -MS_REG_CPU_KERNEL(TripletMarginLoss, - KernelAttr() - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeUInt8) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - TripletMarginLossCPUKernel); - -MS_REG_CPU_KERNEL(TripletMarginLoss, - KernelAttr() - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kNumberTypeUInt16) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - TripletMarginLossCPUKernel); - -MS_REG_CPU_KERNEL(TripletMarginLoss, - KernelAttr() - .AddInputAttr(kNumberTypeUInt32) - .AddInputAttr(kNumberTypeUInt32) - .AddInputAttr(kNumberTypeUInt32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - 
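For context, the per-dtype MS_REG_CPU_KERNEL registrations being deleted here are subsumed by the GetOpSupport() list shown above: InitKernel matches the node's dtype signature against that list and picks the corresponding typed launch function. A simplified sketch of the matching step (KernelAttrLite and MatchAttr are stand-ins, not MindSpore's KernelAttr/MatchKernelAttr):

#include <cstddef>
#include <utility>
#include <vector>

// One supported (inputs, output) dtype signature, mirroring a KernelAttr entry;
// dtypes are plain ints here for brevity.
struct KernelAttrLite {
  std::vector<int> input_dtypes;
  std::vector<int> output_dtypes;
};

// Returns whether the requested signature is supported and, if so, its index,
// which selects the matching typed launch function from a parallel table.
std::pair<bool, std::size_t> MatchAttr(const KernelAttrLite &wanted,
                                       const std::vector<KernelAttrLite> &supported) {
  for (std::size_t i = 0; i < supported.size(); ++i) {
    if (supported[i].input_dtypes == wanted.input_dtypes &&
        supported[i].output_dtypes == wanted.output_dtypes) {
      return {true, i};
    }
  }
  return {false, 0};
}

Keeping the support list next to the kernel (rather than scattered across registration macros) is what lets a single MS_KERNEL_FACTORY_REG line replace the dozen MS_REG_CPU_KERNEL blocks.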
TripletMarginLossCPUKernel); - -MS_REG_CPU_KERNEL(TripletMarginLoss, - KernelAttr() - .AddInputAttr(kNumberTypeUInt64) - .AddInputAttr(kNumberTypeUInt64) - .AddInputAttr(kNumberTypeUInt64) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - TripletMarginLossCPUKernel); - -MS_REG_CPU_KERNEL(TripletMarginLoss, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - TripletMarginLossCPUKernel); - -MS_REG_CPU_KERNEL(TripletMarginLoss, - KernelAttr() - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeFloat64) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - TripletMarginLossCPUKernel); - -MS_REG_CPU_KERNEL(TripletMarginLoss, - KernelAttr() - .AddInputAttr(kNumberTypeComplex64) - .AddInputAttr(kNumberTypeComplex64) - .AddInputAttr(kNumberTypeComplex64) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - TripletMarginLossCPUKernel); - -MS_REG_CPU_KERNEL(TripletMarginLoss, - KernelAttr() - .AddInputAttr(kNumberTypeComplex128) - .AddInputAttr(kNumberTypeComplex128) - .AddInputAttr(kNumberTypeComplex128) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - TripletMarginLossCPUKernel); } // namespace kernel } // namespace mindspore #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRIPLET_MARGIN_LOSS_CPU_KERNEL_H_ diff --git a/mindspore/core/ops/core_ops.h b/mindspore/core/ops/core_ops.h index 46dd29d40f9..771a1123a50 100644 --- a/mindspore/core/ops/core_ops.h +++ b/mindspore/core/ops/core_ops.h @@ -592,7 +592,7 @@ GVAR_DEF(PrimitivePtr, kPrimMaxPoolGradWithArgmax, std::make_shared(" GVAR_DEF(PrimitivePtr, kPrimMaxPoolGradGradWithArgmax, std::make_shared("MaxPoolGradGradWithArgmax")); GVAR_DEF(PrimitivePtr, kPrimMaxPool3DWithArgmax, std::make_shared("MaxPool3DWithArgmax")); GVAR_DEF(PrimitivePtr, kPrimMaxPool3DGradWithArgmax, std::make_shared("MaxPool3DGradWithArgmax")); -GVAR_DEF(PrimitivePtr, kPrimMaxUnpool2D, std::make_shared(kMaxUnpool2DGrad)); +GVAR_DEF(PrimitivePtr, kPrimMaxUnpool2D, std::make_shared(kMaxUnpool2D)); GVAR_DEF(PrimitivePtr, kPrimMaxUnpool2DGrad, std::make_shared(kMaxUnpool2DGrad)); GVAR_DEF(PrimitivePtr, kPrimMaxUnpool3D, std::make_shared(kMaxUnpool3D)); GVAR_DEF(PrimitivePtr, kPrimMaxUnpool3DGrad, std::make_shared(kMaxUnpool3DGrad)); diff --git a/mindspore/core/ops/grad/max_unpool2d_grad.cc b/mindspore/core/ops/grad/max_unpool2d_grad.cc index 4ff58591fa0..d7d735884e2 100644 --- a/mindspore/core/ops/grad/max_unpool2d_grad.cc +++ b/mindspore/core/ops/grad/max_unpool2d_grad.cc @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,9 +16,13 @@ #include "ops/grad/max_unpool2d_grad.h" #include +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" + namespace mindspore { namespace ops { -constexpr int64_t k4DInputDims = 4; namespace { abstract::ShapePtr MaxUnpool2DGradInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { @@ -34,11 +38,10 @@ abstract::ShapePtr MaxUnpool2DGradInferShape(const PrimitivePtr &primitive, auto grads_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShapeTrack())[kShape]; auto argmax_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShapeTrack())[kShape]; - (void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, k4DInputDims, op_name); - (void)CheckAndConvertUtils::CheckInteger("grads_rank", SizeToLong(grads_shape.size()), kEqual, k4DInputDims, op_name); - (void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, k4DInputDims, - op_name); - CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, "argmax_shape", argmax_shape, op_name); + (void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, kDim4, op_name); + (void)CheckAndConvertUtils::CheckInteger("grads_rank", SizeToLong(grads_shape.size()), kEqual, kDim4, op_name); + (void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, kDim4, op_name); + CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, argmax_shape, op_name, ValueError); auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; return std::make_shared(x1_shape); } @@ -59,6 +62,7 @@ TypePtr MaxUnpool2DGradInferType(const PrimitivePtr &primitive, const std::vecto } } // namespace +MIND_API_OPERATOR_IMPL(MaxUnpool2DGrad, BaseOperator); AbstractBasePtr MaxUnpool2DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); diff --git a/mindspore/core/ops/grad/max_unpool2d_grad.h b/mindspore/core/ops/grad/max_unpool2d_grad.h index fae69cee737..0474f4fde1d 100644 --- a/mindspore/core/ops/grad/max_unpool2d_grad.h +++ b/mindspore/core/ops/grad/max_unpool2d_grad.h @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -19,23 +19,20 @@ #include #include -#include "ops/primitive_c.h" -#include "ops/op_utils.h" -#include "abstract/abstract_value.h" -#include "utils/check_convert_utils.h" +#include "ops/base_operator.h" +#include "mindapi/base/types.h" namespace mindspore { namespace ops { constexpr auto kNameMaxUnpool2DGrad = "MaxUnpool2DGrad"; -class MS_CORE_API MaxUnpool2DGrad : public PrimitiveC { +class MIND_API MaxUnpool2DGrad : public BaseOperator { public: - MaxUnpool2DGrad() : PrimitiveC(kNameMaxUnpool2DGrad) { InitIOName({"x", "grads", "argmax"}, {"y"}); } - ~MaxUnpool2DGrad() = default; - MS_DECLARE_PARENT(MaxUnpool2DGrad, PrimitiveC); + MIND_API_BASE_MEMBER(MaxUnpool2DGrad); + MaxUnpool2DGrad() : BaseOperator(kNameMaxUnpool2DGrad) { InitIOName({"x", "grads", "argmax"}, {"y"}); } }; -AbstractBasePtr MaxUnpool2DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); +abstract::AbstractBasePtr MaxUnpool2DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); using PrimMaxUnpool2DGradPtr = std::shared_ptr; } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/grad/max_unpool3d_grad.cc b/mindspore/core/ops/grad/max_unpool3d_grad.cc index 3f9447adc47..10ffc40530d 100644 --- a/mindspore/core/ops/grad/max_unpool3d_grad.cc +++ b/mindspore/core/ops/grad/max_unpool3d_grad.cc @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,10 +16,13 @@ #include "ops/grad/max_unpool3d_grad.h" #include +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" + namespace mindspore { namespace ops { -constexpr int64_t k5DInputDims = 5; - namespace { abstract::ShapePtr MaxUnpool3DGradInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { @@ -35,11 +38,10 @@ abstract::ShapePtr MaxUnpool3DGradInferShape(const PrimitivePtr &primitive, auto grads_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShapeTrack())[kShape]; auto argmax_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShapeTrack())[kShape]; - (void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, k5DInputDims, op_name); - (void)CheckAndConvertUtils::CheckInteger("grads_rank", SizeToLong(grads_shape.size()), kEqual, k5DInputDims, op_name); - (void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, k5DInputDims, - op_name); - CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, "argmax_shape", argmax_shape, op_name); + (void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, kDim5, op_name); + (void)CheckAndConvertUtils::CheckInteger("grads_rank", SizeToLong(grads_shape.size()), kEqual, kDim5, op_name); + (void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, kDim5, op_name); + CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, argmax_shape, op_name, ValueError); auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; return std::make_shared(x1_shape); } @@ -60,6 +62,7 @@ TypePtr MaxUnpool3DGradInferType(const PrimitivePtr &primitive, const 
std::vecto } } // namespace +MIND_API_OPERATOR_IMPL(MaxUnpool3DGrad, BaseOperator); AbstractBasePtr MaxUnpool3DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); diff --git a/mindspore/core/ops/grad/max_unpool3d_grad.h b/mindspore/core/ops/grad/max_unpool3d_grad.h index dc20a00fd7f..463710baa30 100644 --- a/mindspore/core/ops/grad/max_unpool3d_grad.h +++ b/mindspore/core/ops/grad/max_unpool3d_grad.h @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,23 +19,20 @@ #include #include -#include "ops/primitive_c.h" -#include "ops/op_utils.h" -#include "abstract/abstract_value.h" -#include "utils/check_convert_utils.h" +#include "ops/base_operator.h" +#include "mindapi/base/types.h" namespace mindspore { namespace ops { constexpr auto kNameMaxUnpool3DGrad = "MaxUnpool3DGrad"; -class MS_CORE_API MaxUnpool3DGrad : public PrimitiveC { +class MIND_API MaxUnpool3DGrad : public BaseOperator { public: - MaxUnpool3DGrad() : PrimitiveC(kNameMaxUnpool3DGrad) { InitIOName({"x", "grads", "argmax"}, {"y"}); } - ~MaxUnpool3DGrad() = default; - MS_DECLARE_PARENT(MaxUnpool3DGrad, PrimitiveC); + MIND_API_BASE_MEMBER(MaxUnpool3DGrad); + MaxUnpool3DGrad() : BaseOperator(kNameMaxUnpool3DGrad) { InitIOName({"x", "grads", "argmax"}, {"y"}); } }; -AbstractBasePtr MaxUnpool3DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); +abstract::AbstractBasePtr MaxUnpool3DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); using PrimMaxUnpool3DGradPtr = std::shared_ptr; } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/grad/multi_margin_loss_grad.cc b/mindspore/core/ops/grad/multi_margin_loss_grad.cc index aba2f27d645..89d68db56a3 100644 --- a/mindspore/core/ops/grad/multi_margin_loss_grad.cc +++ b/mindspore/core/ops/grad/multi_margin_loss_grad.cc @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -15,16 +15,14 @@ */ #include "ops/grad/multi_margin_loss_grad.h" -#include "abstract/primitive_infer_map.h" #include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" namespace mindspore { namespace ops { namespace { -const size_t kone = 1; -const size_t ktwo = 2; -const size_t kfour = 4; - TypePtr MultiMarginLossGradInferType(const PrimitivePtr &prim, const std::vector &input_args) { (void)CheckAndConvertUtils::CheckTensorTypeValid("target", input_args[kInputIndex2]->BuildType(), {kInt64}, prim->name()); @@ -32,7 +30,7 @@ TypePtr MultiMarginLossGradInferType(const PrimitivePtr &prim, const std::vector std::map types; (void)types.emplace("y_grad", input_args[kInputIndex0]->BuildType()); (void)types.emplace("x", input_args[kInputIndex1]->BuildType()); - if (input_args.size() == kfour && input_args[kInputIndex3]->BuildType()->isa()) { + if (input_args.size() == kDim4 && input_args[kInputIndex3]->BuildType()->isa()) { auto tensor_type = input_args[kInputIndex3]->BuildType()->cast(); MS_EXCEPTION_IF_NULL(tensor_type); auto element = tensor_type->element(); @@ -50,7 +48,7 @@ abstract::ShapePtr MultiMarginLossGradInferShape(const PrimitivePtr &primitive, auto prim_name = primitive->name(); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; auto target_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; - if (x_shape.size() != ktwo || target_shape.size() != kone) { + if (x_shape.size() != kDim2 || target_shape.size() != kDim1) { MS_EXCEPTION(ValueError) << "For MultiMarginLossGrad, the rank of input x should be 2, and " "the rank of target should be 1," << " while rank of x is " << x_shape.size() << ", rank of target is " @@ -61,7 +59,7 @@ abstract::ShapePtr MultiMarginLossGradInferShape(const PrimitivePtr &primitive, << " while x_shape[0] is " << x_shape[kInputIndex0] << ", target_shape[0] is " << target_shape[kInputIndex0]; } - if (input_args.size() == kfour && input_args[kInputIndex3]->BuildType()->isa()) { + if (input_args.size() == kDim4 && input_args[kInputIndex3]->BuildType()->isa()) { auto tensor_type = input_args[kInputIndex3]->BuildType()->cast(); MS_EXCEPTION_IF_NULL(tensor_type); auto element = tensor_type->element(); @@ -69,7 +67,7 @@ abstract::ShapePtr MultiMarginLossGradInferShape(const PrimitivePtr &primitive, if (element->type_id() != kMetaTypeNone) { auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape]; - if (weight_shape.size() != kone) { + if (weight_shape.size() != kDim1) { MS_EXCEPTION(ValueError) << "For " << prim_name << " the rank of weight should be 1," << " but get " << weight_shape.size(); } @@ -84,6 +82,7 @@ abstract::ShapePtr MultiMarginLossGradInferShape(const PrimitivePtr &primitive, } } // namespace +MIND_API_OPERATOR_IMPL(MultiMarginLossGrad, BaseOperator); AbstractBasePtr MultiMarginLossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { constexpr size_t kInputNumWithWeight = 4; @@ -96,7 +95,7 @@ AbstractBasePtr MultiMarginLossGradInfer(const abstract::AnalysisEnginePtr &, co MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]); MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]); MS_EXCEPTION_IF_NULL(input_args[kInputIndex2]); - if (input_args.size() == kfour) { + if (input_args.size() == kInputNumWithWeight) { 
MS_EXCEPTION_IF_NULL(input_args[kInputIndex3]); } auto types = MultiMarginLossGradInferType(primitive, input_args); diff --git a/mindspore/core/ops/grad/multi_margin_loss_grad.h b/mindspore/core/ops/grad/multi_margin_loss_grad.h index 01738a32788..8eb02505974 100644 --- a/mindspore/core/ops/grad/multi_margin_loss_grad.h +++ b/mindspore/core/ops/grad/multi_margin_loss_grad.h @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,24 +23,23 @@ #include #include #include -#include "abstract/abstract_value.h" -#include "ops/primitive_c.h" -#include "utils/check_convert_utils.h" + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" namespace mindspore { namespace ops { constexpr auto kNameMultiMarginLossGrad = "MultiMarginLossGrad"; -class MS_CORE_API MultiMarginLossGrad : public PrimitiveC { +class MIND_API MultiMarginLossGrad : public BaseOperator { public: - MultiMarginLossGrad() : PrimitiveC(kNameMultiMarginLossGrad) { + MIND_API_BASE_MEMBER(MultiMarginLossGrad); + MultiMarginLossGrad() : BaseOperator(kNameMultiMarginLossGrad) { InitIOName({"y_grad", "x", "target", "weight"}, {"x_grad"}); } - ~MultiMarginLossGrad() = default; - MS_DECLARE_PARENT(MultiMarginLossGrad, PrimitiveC); }; -AbstractBasePtr MultiMarginLossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); +abstract::AbstractBasePtr MultiMarginLossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); using PrimMultiMarginLossGradPtr = std::shared_ptr; } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/grad/multilabel_margin_loss_grad.cc b/mindspore/core/ops/grad/multilabel_margin_loss_grad.cc index 7397c73f5f3..bd5c5168d61 100644 --- a/mindspore/core/ops/grad/multilabel_margin_loss_grad.cc +++ b/mindspore/core/ops/grad/multilabel_margin_loss_grad.cc @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
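The checks in multi_margin_loss_grad.cc above amount to a compact shape contract: x is rank 2 (batch, classes), target is rank 1 with target.shape[0] == x.shape[0], and weight, when supplied, is rank 1. Restated as a standalone validator (a sketch; CheckMultiMarginLossGradShapes is hypothetical, not MindSpore's CheckAndConvertUtils):

#include <cstdint>
#include <stdexcept>
#include <vector>

void CheckMultiMarginLossGradShapes(const std::vector<int64_t> &x_shape,
                                    const std::vector<int64_t> &target_shape) {
  if (x_shape.size() != 2 || target_shape.size() != 1) {
    throw std::invalid_argument("x must be rank 2 (batch, classes) and target rank 1");
  }
  if (x_shape[0] != target_shape[0]) {
    throw std::invalid_argument("x.shape[0] must equal target.shape[0]");
  }
}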
@@ -16,8 +16,9 @@
 
 #include "ops/grad/multilabel_margin_loss_grad.h"
 #include "ops/op_utils.h"
-#include "utils/tensor_construct_utils.h"
-#include "abstract/primitive_infer_map.h"
+#include "utils/check_convert_utils.h"
+#include "abstract/ops/primitive_infer_map.h"
+#include "mindapi/src/helper.h"
 
 namespace mindspore {
 namespace ops {
@@ -27,9 +28,7 @@ abstract::ShapePtr MultilabelMarginLossGradInferShape(const PrimitivePtr &primit
   auto op_name = primitive->name();
   auto x = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
   auto target = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
-  const size_t kone = 1;
-  const size_t ktwo = 2;
-  if ((x.size() != kone && x.size() != ktwo) || (target.size() != kone && target.size() != ktwo)) {
+  if ((x.size() != kDim1 && x.size() != kDim2) || (target.size() != kDim1 && target.size() != kDim2)) {
     MS_EXCEPTION(ValueError) << "For " << op_name << ", the rank of input x and target should be 1 or 2, "
                              << "while rank of x is : " << x.size() << ", rank of target is : " << target.size() << ".";
   }
@@ -57,6 +56,7 @@ TypePtr MultilabelMarginLossGradInferType(const PrimitivePtr &primitive,
 }
 }  // namespace
 
+MIND_API_OPERATOR_IMPL(MultilabelMarginLossGrad, BaseOperator);
 AbstractBasePtr MultilabelMarginLossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                               const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
diff --git a/mindspore/core/ops/grad/multilabel_margin_loss_grad.h b/mindspore/core/ops/grad/multilabel_margin_loss_grad.h
index 64c1f745a6b..fa3a994180b 100644
--- a/mindspore/core/ops/grad/multilabel_margin_loss_grad.h
+++ b/mindspore/core/ops/grad/multilabel_margin_loss_grad.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
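For reference, the forward loss that MultilabelMarginLossGrad differentiates is, in its standard form, a hinge sum over (target class, non-target class) pairs, normalized by the class count; target is read up to the first -1 padding value. A minimal scalar sketch under those assumptions (not the kernel's exact code):

#include <cstddef>
#include <cstdint>
#include <vector>

float MultilabelMarginSample(const std::vector<float> &x, const std::vector<int64_t> &target) {
  std::vector<bool> is_target(x.size(), false);
  std::size_t num_targets = 0;
  for (int64_t t : target) {
    if (t < 0) break;  // -1 terminates the list of valid target classes
    is_target[t] = true;
    ++num_targets;
  }
  float loss = 0.0f;
  for (std::size_t j = 0; j < num_targets; ++j) {
    for (std::size_t i = 0; i < x.size(); ++i) {
      if (is_target[i]) continue;
      float z = 1.0f - (x[target[j]] - x[i]);  // hinge on each target/non-target pair
      if (z > 0) loss += z;
    }
  }
  return loss / static_cast<float>(x.size());
}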
@@ -23,24 +23,24 @@ #include #include #include -#include "abstract/abstract_value.h" -#include "ops/primitive_c.h" -#include "utils/check_convert_utils.h" + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" namespace mindspore { namespace ops { constexpr auto kNameMultilabelMarginLossGrad = "MultilabelMarginLossGrad"; -class MS_CORE_API MultilabelMarginLossGrad : public PrimitiveC { +class MIND_API MultilabelMarginLossGrad : public BaseOperator { public: - MultilabelMarginLossGrad() : PrimitiveC(kNameMultilabelMarginLossGrad) { + MIND_API_BASE_MEMBER(MultilabelMarginLossGrad); + MultilabelMarginLossGrad() : BaseOperator(kNameMultilabelMarginLossGrad) { InitIOName({"y_grad", "x", "target", "is_target"}, {"x_grad"}); } - ~MultilabelMarginLossGrad() = default; - MS_DECLARE_PARENT(MultilabelMarginLossGrad, PrimitiveC); }; -AbstractBasePtr MultilabelMarginLossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); +abstract::AbstractBasePtr MultilabelMarginLossGradInfer(const abstract::AnalysisEnginePtr &, + const PrimitivePtr &primitive, + const std::vector &input_args); using PrimMultilabelMarginLossGradPtr = std::shared_ptr; } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/grad/mvlgamma_grad.cc b/mindspore/core/ops/grad/mvlgamma_grad.cc index a8aea8e98f0..a4bfab2fd80 100644 --- a/mindspore/core/ops/grad/mvlgamma_grad.cc +++ b/mindspore/core/ops/grad/mvlgamma_grad.cc @@ -27,19 +27,18 @@ namespace mindspore { namespace ops { namespace { -abstract::ShapePtr MvlgammaGradInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { +abstract::ShapePtr MvlgammaGradInferShape(const PrimitivePtr &primitive, + const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto prim_name = primitive->name(); - auto y_grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; + auto y_grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; return std::make_shared(y_grad_shape); } TypePtr MvlgammaGradInferType(const PrimitivePtr &prim, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(prim); std::map types; - (void)types.emplace("y_grad", input_args[0]->BuildType()); - (void)types.emplace("x", input_args[1]->BuildType()); + (void)types.emplace("y_grad", input_args[kInputIndex0]->BuildType()); + (void)types.emplace("x", input_args[kInputIndex1]->BuildType()); const std::set valid_types = {kFloat32, kFloat64}; return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); } diff --git a/mindspore/core/ops/grad/mvlgamma_grad.h b/mindspore/core/ops/grad/mvlgamma_grad.h index 0129fa7b074..2bb4a57944c 100644 --- a/mindspore/core/ops/grad/mvlgamma_grad.h +++ b/mindspore/core/ops/grad/mvlgamma_grad.h @@ -32,7 +32,7 @@ class MIND_API MvlgammaGrad : public BaseOperator { }; abstract::AbstractBasePtr MvlgammaGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); + const std::vector &input_args); using PrimMvlgammaGradPtr = std::shared_ptr; } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/max_unpool2d.cc b/mindspore/core/ops/max_unpool2d.cc index 0f34bde8ac6..2afe021513c 100644 --- a/mindspore/core/ops/max_unpool2d.cc +++ b/mindspore/core/ops/max_unpool2d.cc @@ -1,5 +1,5 @@ /** - * Copyright 2021 
Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,17 +20,17 @@ #include #include "ops/op_utils.h" #include "utils/check_convert_utils.h" -#include "abstract/primitive_infer_map.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" namespace mindspore { namespace ops { namespace { -constexpr int64_t k4DInputDims = 4; - -abstract::ShapePtr InferShapeCompute(const std::string &data_format, const ShapeVector &in_shape, - const std::vector &ksize, const std::vector &strides, - const std::vector &pads, const std::vector &attr_output_shape, - const std::string &op_name) { +abstract::ShapePtr MaxUnpool2DInferShapeCompute(const std::string &data_format, const ShapeVector &in_shape, + const std::vector &ksize, const std::vector &strides, + const std::vector &pads, + const std::vector &attr_output_shape, + const std::string &op_name) { if (data_format == "NCHW") { int64_t out_h = static_cast((in_shape[kInputIndex2] - 1) * strides[kInputIndex2] - 2 * pads[kInputIndex2] + ksize[kInputIndex2]); @@ -40,7 +40,7 @@ abstract::ShapePtr InferShapeCompute(const std::string &data_format, const Shape if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) { MS_LOG(EXCEPTION) << "MaxUnpool2D: Output size is not valid."; } - if (attr_output_shape.size() == k4DInputDims) { + if (attr_output_shape.size() == kDim4) { (void)CheckAndConvertUtils::CheckInteger("output_shape[0]", attr_output_shape[kInputIndex0], kEqual, in_shape[kInputIndex0], op_name); (void)CheckAndConvertUtils::CheckInteger("output_shape[1]", attr_output_shape[kInputIndex1], kEqual, @@ -74,7 +74,7 @@ abstract::ShapePtr InferShapeCompute(const std::string &data_format, const Shape if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) { MS_LOG(EXCEPTION) << "MaxUnpool2D: Output size is not valid."; } - if (attr_output_shape.size() == k4DInputDims) { + if (attr_output_shape.size() == kDim4) { (void)CheckAndConvertUtils::CheckInteger("output_shape[0]", attr_output_shape[kInputIndex0], kEqual, in_shape[kInputIndex0], op_name); (void)CheckAndConvertUtils::CheckInteger("output_shape[3]", attr_output_shape[kInputIndex3], kEqual, @@ -100,7 +100,8 @@ abstract::ShapePtr InferShapeCompute(const std::string &data_format, const Shape } } -abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { +abstract::ShapePtr MaxUnpool2DInferShape(const PrimitivePtr &primitive, + const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto op_name = primitive->name(); MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]); @@ -109,28 +110,27 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vectorGetShapeTrack())[kShape]; auto data_format = GetValue(primitive->GetAttr("format")); - (void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, k4DInputDims, op_name); - (void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, k4DInputDims, - op_name); - CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, "argmax_shape", argmax_shape, op_name); + (void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, kDim4, op_name); + (void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, kDim4, op_name); + 
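The NCHW/NHWC branches above (and the NCDHW/NDHWC branches of MaxUnpool3D later in this patch) size every spatial dimension with the same inverse-pooling rule; a small sketch with a worked example:

#include <cstdint>

// The size rule used by MaxUnpool2DInferShapeCompute, applied per spatial dim:
// out = (in - 1) * stride - 2 * pad + ksize
inline int64_t UnpoolDim(int64_t in, int64_t stride, int64_t pad, int64_t ksize) {
  return (in - 1) * stride - 2 * pad + ksize;
}
// Example: in = 16, stride = 2, pad = 1, ksize = 4 -> (16 - 1) * 2 - 2 + 4 = 32.

This is the exact inverse of the pooled-size formula, which is why the op then cross-checks any user-supplied output_shape attribute against the computed extents.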
CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, argmax_shape, op_name, ValueError); auto ksize = GetValue>(primitive->GetAttr("ksize")); auto strides = GetValue>(primitive->GetAttr("strides")); auto pads = GetValue>(primitive->GetAttr("pads")); auto attr_output_shape = GetValue>(primitive->GetAttr("output_shape")); - (void)CheckAndConvertUtils::CheckInteger("ksize_rank", SizeToLong(ksize.size()), kEqual, k4DInputDims, op_name); - (void)CheckAndConvertUtils::CheckInteger("strides_rank", SizeToLong(strides.size()), kEqual, k4DInputDims, op_name); - (void)CheckAndConvertUtils::CheckInteger("pads_rank", SizeToLong(pads.size()), kEqual, k4DInputDims, op_name); + (void)CheckAndConvertUtils::CheckInteger("ksize_rank", SizeToLong(ksize.size()), kEqual, kDim4, op_name); + (void)CheckAndConvertUtils::CheckInteger("strides_rank", SizeToLong(strides.size()), kEqual, kDim4, op_name); + (void)CheckAndConvertUtils::CheckInteger("pads_rank", SizeToLong(pads.size()), kEqual, kDim4, op_name); - if (attr_output_shape.size() != k4DInputDims && attr_output_shape.size() != 0) { + if (attr_output_shape.size() != kDim4 && attr_output_shape.size() != kDim0) { MS_EXCEPTION(ValueError) << "MaxUnpool2D: Output_shape size must be 0 or 4."; } - return InferShapeCompute(data_format, in_shape, ksize, strides, pads, attr_output_shape, op_name); + return MaxUnpool2DInferShapeCompute(data_format, in_shape, ksize, strides, pads, attr_output_shape, op_name); } -TypePtr InferType(const PrimitivePtr &prim, const std::vector &input_args) { +TypePtr MaxUnpool2DInferType(const PrimitivePtr &prim, const std::vector &input_args) { for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } @@ -143,13 +143,14 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector & } } // namespace +MIND_API_OPERATOR_IMPL(MaxUnpool2D, BaseOperator); AbstractBasePtr MaxUnpool2DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); const int64_t input_num = 2; CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name()); - auto infer_type = InferType(primitive, input_args); - auto infer_shape = InferShape(primitive, input_args); + auto infer_type = MaxUnpool2DInferType(primitive, input_args); + auto infer_shape = MaxUnpool2DInferShape(primitive, input_args); return abstract::MakeAbstract(infer_shape, infer_type); } REGISTER_PRIMITIVE_EVAL_IMPL(MaxUnpool2D, prim::kPrimMaxUnpool2D, MaxUnpool2DInfer, nullptr, true); diff --git a/mindspore/core/ops/max_unpool2d.h b/mindspore/core/ops/max_unpool2d.h index e9bf6e0c299..db16b21914a 100644 --- a/mindspore/core/ops/max_unpool2d.h +++ b/mindspore/core/ops/max_unpool2d.h @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -19,23 +19,20 @@ #include #include -#include "ops/primitive_c.h" -#include "ops/op_utils.h" -#include "abstract/abstract_value.h" -#include "utils/check_convert_utils.h" +#include "ops/base_operator.h" +#include "mindapi/base/types.h" namespace mindspore { namespace ops { constexpr auto kNameMaxUnpool2D = "MaxUnpool2D"; -class MS_CORE_API MaxUnpool2D : public PrimitiveC { +class MIND_API MaxUnpool2D : public BaseOperator { public: - MaxUnpool2D() : PrimitiveC(kNameMaxUnpool2D) { InitIOName({"x", "argmax"}, {"y"}); } - ~MaxUnpool2D() = default; - MS_DECLARE_PARENT(MaxUnpool2D, PrimitiveC); + MIND_API_BASE_MEMBER(MaxUnpool2D); + MaxUnpool2D() : BaseOperator(kNameMaxUnpool2D) { InitIOName({"x", "argmax"}, {"y"}); } }; -AbstractBasePtr MaxUnpool2DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); +abstract::AbstractBasePtr MaxUnpool2DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); using PrimMaxUnpool2DPtr = std::shared_ptr; } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/max_unpool3d.cc b/mindspore/core/ops/max_unpool3d.cc index 7eae3f02842..03609142253 100644 --- a/mindspore/core/ops/max_unpool3d.cc +++ b/mindspore/core/ops/max_unpool3d.cc @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,17 +20,17 @@ #include #include "ops/op_utils.h" #include "utils/check_convert_utils.h" -#include "abstract/primitive_infer_map.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" namespace mindspore { namespace ops { namespace { -constexpr int64_t k5DInputDims = 5; - -abstract::ShapePtr InferShapeCompute(const std::string &data_format, const ShapeVector &in_shape, - const std::vector &ksize, const std::vector &strides, - const std::vector &pads, const std::vector &attr_output_shape, - const std::string &op_name) { +abstract::ShapePtr MaxUnpool3DInferShapeCompute(const std::string &data_format, const ShapeVector &in_shape, + const std::vector &ksize, const std::vector &strides, + const std::vector &pads, + const std::vector &attr_output_shape, + const std::string &op_name) { if (data_format == "NCDHW") { int64_t out_d = static_cast((in_shape[kInputIndex2] - 1) * strides[kInputIndex2] - 2 * pads[kInputIndex2] + ksize[kInputIndex2]); @@ -42,7 +42,7 @@ abstract::ShapePtr InferShapeCompute(const std::string &data_format, const Shape if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) { MS_LOG(EXCEPTION) << "MaxUnpool3D: Output size is not valid."; } - if (attr_output_shape.size() == k5DInputDims) { + if (attr_output_shape.size() == kDim5) { (void)CheckAndConvertUtils::CheckInteger("output_shape[0]", attr_output_shape[kInputIndex0], kEqual, in_shape[kInputIndex0], op_name); (void)CheckAndConvertUtils::CheckInteger("output_shape[1]", attr_output_shape[kInputIndex1], kEqual, @@ -79,7 +79,7 @@ abstract::ShapePtr InferShapeCompute(const std::string &data_format, const Shape if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) { MS_LOG(EXCEPTION) << "MaxUnpool3D: Output size is not valid."; } - if (attr_output_shape.size() == k5DInputDims) { + if (attr_output_shape.size() == kDim5) { (void)CheckAndConvertUtils::CheckInteger("output_shape[0]", 
attr_output_shape[kInputIndex0], kEqual, in_shape[kInputIndex0], op_name); (void)CheckAndConvertUtils::CheckInteger("output_shape[4]", attr_output_shape[kInputIndex4], kEqual, @@ -108,7 +108,8 @@ abstract::ShapePtr InferShapeCompute(const std::string &data_format, const Shape } } -abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { +abstract::ShapePtr MaxUnpool3DInferShape(const PrimitivePtr &primitive, + const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto op_name = primitive->name(); MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]); @@ -117,26 +118,25 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vectorGetShapeTrack())[kShape]; auto data_format = GetValue(primitive->GetAttr("format")); - (void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, k5DInputDims, op_name); - (void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, k5DInputDims, - op_name); - CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, "argmax_shape", argmax_shape, op_name); + (void)CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, kDim5, op_name); + (void)CheckAndConvertUtils::CheckInteger("argmax_rank", SizeToLong(argmax_shape.size()), kEqual, kDim5, op_name); + CheckAndConvertUtils::Check("x_shape", in_shape, kEqual, argmax_shape, op_name, ValueError); auto ksize = GetValue>(primitive->GetAttr("ksize")); auto strides = GetValue>(primitive->GetAttr("strides")); auto pads = GetValue>(primitive->GetAttr("pads")); auto attr_output_shape = GetValue>(primitive->GetAttr("output_shape")); - (void)CheckAndConvertUtils::CheckInteger("ksize_rank", SizeToLong(ksize.size()), kEqual, k5DInputDims, op_name); - (void)CheckAndConvertUtils::CheckInteger("strides_rank", SizeToLong(strides.size()), kEqual, k5DInputDims, op_name); - (void)CheckAndConvertUtils::CheckInteger("pads_rank", SizeToLong(pads.size()), kEqual, k5DInputDims, op_name); + (void)CheckAndConvertUtils::CheckInteger("ksize_rank", SizeToLong(ksize.size()), kEqual, kDim5, op_name); + (void)CheckAndConvertUtils::CheckInteger("strides_rank", SizeToLong(strides.size()), kEqual, kDim5, op_name); + (void)CheckAndConvertUtils::CheckInteger("pads_rank", SizeToLong(pads.size()), kEqual, kDim5, op_name); - if (attr_output_shape.size() != k5DInputDims && attr_output_shape.size() != 0) { + if (attr_output_shape.size() != kDim5 && attr_output_shape.size() != kDim0) { MS_EXCEPTION(ValueError) << "MaxUnpool3D: Output_shape size must be 0 or 5."; } - return InferShapeCompute(data_format, in_shape, ksize, strides, pads, attr_output_shape, op_name); + return MaxUnpool3DInferShapeCompute(data_format, in_shape, ksize, strides, pads, attr_output_shape, op_name); } -TypePtr InferType(const PrimitivePtr &prim, const std::vector &input_args) { +TypePtr MaxUnpool3DInferType(const PrimitivePtr &prim, const std::vector &input_args) { for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } @@ -149,13 +149,14 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector & } } // namespace +MIND_API_OPERATOR_IMPL(MaxUnpool3D, BaseOperator); AbstractBasePtr MaxUnpool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); const int64_t input_num = 2; CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name()); - auto infer_type = InferType(primitive, input_args); - auto infer_shape = 
InferShape(primitive, input_args); + auto infer_type = MaxUnpool3DInferType(primitive, input_args); + auto infer_shape = MaxUnpool3DInferShape(primitive, input_args); return abstract::MakeAbstract(infer_shape, infer_type); } REGISTER_PRIMITIVE_EVAL_IMPL(MaxUnpool3D, prim::kPrimMaxUnpool3D, MaxUnpool3DInfer, nullptr, true); diff --git a/mindspore/core/ops/max_unpool3d.h b/mindspore/core/ops/max_unpool3d.h index 1cb908c1d76..0bad3287fe7 100644 --- a/mindspore/core/ops/max_unpool3d.h +++ b/mindspore/core/ops/max_unpool3d.h @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,23 +19,20 @@ #include #include -#include "ops/primitive_c.h" -#include "ops/op_utils.h" -#include "abstract/abstract_value.h" -#include "utils/check_convert_utils.h" +#include "ops/base_operator.h" +#include "mindapi/base/types.h" namespace mindspore { namespace ops { constexpr auto kNameMaxUnpool3D = "MaxUnpool3D"; -class MS_CORE_API MaxUnpool3D : public PrimitiveC { +class MIND_API MaxUnpool3D : public BaseOperator { public: - MaxUnpool3D() : PrimitiveC(kNameMaxUnpool3D) { InitIOName({"x", "argmax"}, {"y"}); } - ~MaxUnpool3D() = default; - MS_DECLARE_PARENT(MaxUnpool3D, PrimitiveC); + MIND_API_BASE_MEMBER(MaxUnpool3D); + MaxUnpool3D() : BaseOperator(kNameMaxUnpool3D) { InitIOName({"x", "argmax"}, {"y"}); } }; -AbstractBasePtr MaxUnpool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); +abstract::AbstractBasePtr MaxUnpool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); using PrimMaxUnpool3DPtr = std::shared_ptr; } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/multi_margin_loss.cc b/mindspore/core/ops/multi_margin_loss.cc index 8c20e05fa9c..9a12794ec34 100644 --- a/mindspore/core/ops/multi_margin_loss.cc +++ b/mindspore/core/ops/multi_margin_loss.cc @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -15,21 +15,20 @@ */ #include "ops/multi_margin_loss.h" -#include "abstract/primitive_infer_map.h" #include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" namespace mindspore { namespace ops { namespace { -const size_t kone = 1; -const size_t ktwo = 2; -const size_t kthree = 3; - TypePtr MultiMarginLossInferType(const PrimitivePtr &prim, const std::vector &input_args) { - (void)CheckAndConvertUtils::CheckTensorTypeValid("target", input_args[1]->BuildType(), {kInt64}, prim->name()); + (void)CheckAndConvertUtils::CheckTensorTypeValid("target", input_args[kInputIndex1]->BuildType(), {kInt64}, + prim->name()); const std::set valid_types = {kFloat16, kFloat32, kFloat64}; std::map types; - (void)types.emplace("x", input_args[0]->BuildType()); + (void)types.emplace("x", input_args[kInputIndex0]->BuildType()); if (input_args.size() == kInputIndex3 && input_args[kInputIndex2]->BuildType()->isa()) { auto tensor_type = input_args[kInputIndex2]->BuildType()->cast(); MS_EXCEPTION_IF_NULL(tensor_type); @@ -51,7 +50,7 @@ abstract::ShapePtr MultiMarginLossInferShape(const PrimitivePtr &primitive, MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; auto target_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; - if (x_shape.size() != ktwo || target_shape.size() != kone) { + if (x_shape.size() != kDim2 || target_shape.size() != kDim1) { MS_EXCEPTION(ValueError) << "For MultiMarginLoss, the rank of input " "x and target should be 2 and 1," << " while rank of x is " << x_shape.size() << ", rank of target is " @@ -62,14 +61,15 @@ abstract::ShapePtr MultiMarginLossInferShape(const PrimitivePtr &primitive, << " while x_shape[0] is " << x_shape[kInputIndex0] << ", target_shape[0] is " << target_shape[kInputIndex0]; } - if (input_args.size() == kthree && input_args[kInputIndex2]->BuildType()->isa()) { + if (input_args.size() == kDim3 && input_args[kInputIndex2]->BuildType()->isa()) { auto tensor_type = input_args[kInputIndex2]->BuildType()->cast(); MS_EXCEPTION_IF_NULL(tensor_type); auto element = tensor_type->element(); MS_EXCEPTION_IF_NULL(element); if (element->type_id() != kMetaTypeNone) { - auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape]; - if (weight_shape.size() != kone) { + auto weight_shape = + CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; + if (weight_shape.size() != kDim1) { MS_EXCEPTION(ValueError) << "For " << prim_name << " the rank of weight should be 1," << " but get " << weight_shape.size(); } @@ -90,16 +90,17 @@ abstract::ShapePtr MultiMarginLossInferShape(const PrimitivePtr &primitive, } } // namespace +MIND_API_OPERATOR_IMPL(MultiMarginLoss, BaseOperator); AbstractBasePtr MultiMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]); MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]); - if (input_args.size() == kthree) { + if (input_args.size() == kDim3) { MS_EXCEPTION_IF_NULL(input_args[kInputIndex2]); } (void)CheckAndConvertUtils::CheckInRange("multi_margin_loss_input_nums", input_args.size(), kIncludeBoth, - {ktwo, kthree}, primitive->name()); + {kDim2, kDim3}, primitive->name()); auto types = 
MultiMarginLossInferType(primitive, input_args); auto shapes = MultiMarginLossInferShape(primitive, input_args); return abstract::MakeAbstract(shapes, types); diff --git a/mindspore/core/ops/multi_margin_loss.h b/mindspore/core/ops/multi_margin_loss.h index 4aadd3fac03..cf019ca02d0 100644 --- a/mindspore/core/ops/multi_margin_loss.h +++ b/mindspore/core/ops/multi_margin_loss.h @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,22 +22,21 @@ #include #include #include -#include "abstract/abstract_value.h" -#include "ops/primitive_c.h" -#include "utils/check_convert_utils.h" + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" namespace mindspore { namespace ops { constexpr auto kNameMultiMarginLoss = "MultiMarginLoss"; -class MS_CORE_API MultiMarginLoss : public PrimitiveC { +class MIND_API MultiMarginLoss : public BaseOperator { public: - MultiMarginLoss() : PrimitiveC(kNameMultiMarginLoss) { InitIOName({"x", "target", "weight"}, {"y"}); } - ~MultiMarginLoss() = default; - MS_DECLARE_PARENT(MultiMarginLoss, PrimitiveC); + MIND_API_BASE_MEMBER(MultiMarginLoss); + MultiMarginLoss() : BaseOperator(kNameMultiMarginLoss) { InitIOName({"x", "target", "weight"}, {"y"}); } }; -AbstractBasePtr MultiMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); +abstract::AbstractBasePtr MultiMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); using PrimMultiMarginLossPtr = std::shared_ptr; } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/multilabel_margin_loss.cc b/mindspore/core/ops/multilabel_margin_loss.cc index 914ec72b809..113b73f8ab3 100644 --- a/mindspore/core/ops/multilabel_margin_loss.cc +++ b/mindspore/core/ops/multilabel_margin_loss.cc @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
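For reference, the forward loss behind MultiMarginLoss and its grad is the standard multi-class margin hinge, optionally weighted per target class. A per-sample sketch, assuming p is restricted to 1 or 2 as in the usual formulation (MultiMarginSample is illustrative, not the kernel's code):

#include <cstddef>
#include <vector>

float MultiMarginSample(const std::vector<float> &x, std::size_t target, float margin,
                        int p, const std::vector<float> *weight) {
  float loss = 0.0f;
  for (std::size_t i = 0; i < x.size(); ++i) {
    if (i == target) continue;
    float z = margin - x[target] + x[i];  // hinge against every wrong class
    if (z > 0) {
      float term = (p == 2) ? z * z : z;  // p is restricted to 1 or 2
      if (weight != nullptr) term *= (*weight)[target];
      loss += term;
    }
  }
  return loss / static_cast<float>(x.size());  // averaged over the class count
}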
@@ -16,8 +16,9 @@
 
 #include "ops/multilabel_margin_loss.h"
 #include "ops/op_utils.h"
-#include "utils/tensor_construct_utils.h"
-#include "abstract/primitive_infer_map.h"
+#include "utils/check_convert_utils.h"
+#include "abstract/ops/primitive_infer_map.h"
+#include "mindapi/src/helper.h"
 
 namespace mindspore {
 namespace ops {
@@ -37,7 +38,7 @@ abstract::TupleShapePtr MultilabelMarginLossInferShape(const PrimitivePtr &primi
     MS_EXCEPTION(ValueError) << "For " << op_name << ", x_shape and target_shape should be the same, "
                              << "while x_shape is : " << x << ", target_shape is : " << target << ".";
   }
-  int64_t batch = x[0];
+  int64_t batch = x[kInputIndex0];
   ShapeVector out_shape0 = {batch};
   ShapeVector out_shape1 = target;
   int64_t reduction;
@@ -66,6 +67,7 @@ TuplePtr MultilabelMarginLossInferType(const PrimitivePtr &primitive, const std:
 }
 }  // namespace
 
+MIND_API_OPERATOR_IMPL(MultilabelMarginLoss, BaseOperator);
 AbstractBasePtr MultilabelMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                           const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
diff --git a/mindspore/core/ops/multilabel_margin_loss.h b/mindspore/core/ops/multilabel_margin_loss.h
index 172c8b32747..25660640175 100644
--- a/mindspore/core/ops/multilabel_margin_loss.h
+++ b/mindspore/core/ops/multilabel_margin_loss.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -21,27 +21,24 @@
 #include
 #include
 #include
-#include "ops/primitive_c.h"
-#include "abstract/abstract_value.h"
-#include "utils/check_convert_utils.h"
+
+#include "ops/base_operator.h"
+#include "mindapi/base/types.h"
 
 namespace mindspore {
-constexpr auto kNameMultilabelMarginLoss = prim::kMultilabelMarginLoss;
+namespace ops {
+constexpr auto kNameMultilabelMarginLoss = "MultilabelMarginLoss";
 /// \brief Creates a criterion that optimizes a multi-class multi-classification hinge loss.
 /// Refer to Python API @ref mindspore.ops.MultilabelMarginLoss for more details.
-class MS_CORE_API MultilabelMarginLoss : public PrimitiveC {
+class MIND_API MultilabelMarginLoss : public BaseOperator {
  public:
+  MIND_API_BASE_MEMBER(MultilabelMarginLoss);
   /// \brief Constructor.
-  MultilabelMarginLoss() : PrimitiveC(kNameMultilabelMarginLoss) { InitIOName({"x", "target"}, {"y", "is_target"}); }
-  /// \brief Destructor.
-  ~MultilabelMarginLoss() = default;
-  MS_DECLARE_PARENT(MultilabelMarginLoss, PrimitiveC);
-  /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.MultilabelMarginLoss for the inputs.
- void Init() const {} + MultilabelMarginLoss() : BaseOperator(kNameMultilabelMarginLoss) { InitIOName({"x", "target"}, {"y", "is_target"}); } }; -AbstractBasePtr MultilabelMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector<AbstractBasePtr> &input_args); +abstract::AbstractBasePtr MultilabelMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector<abstract::AbstractBasePtr> &input_args); using PrimMultilabelMarginLossPtr = std::shared_ptr<MultilabelMarginLoss>; } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/mvlgamma.cc b/mindspore/core/ops/mvlgamma.cc index fe786807e25..81e209ab3b8 100644 --- a/mindspore/core/ops/mvlgamma.cc +++ b/mindspore/core/ops/mvlgamma.cc @@ -28,14 +28,14 @@ namespace ops { namespace { abstract::ShapePtr MvlgammaInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { MS_EXCEPTION_IF_NULL(primitive); - MS_EXCEPTION_IF_NULL(input_args[0]); - auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; + MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]); + auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShapeTrack())[kShape]; return std::make_shared<abstract::Shape>(in_shape); } TypePtr MvlgammaInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { MS_EXCEPTION_IF_NULL(prim); - auto input_type = input_args[0]->BuildType(); + auto input_type = input_args[kInputIndex0]->BuildType(); const std::set<TypePtr> valid_types = {kFloat32, kFloat64}; return CheckAndConvertUtils::CheckTensorTypeValid("x", input_type, valid_types, prim->name()); } diff --git a/mindspore/core/ops/mvlgamma.h b/mindspore/core/ops/mvlgamma.h index db00d78bc2e..0f3ae9c0b25 100644 --- a/mindspore/core/ops/mvlgamma.h +++ b/mindspore/core/ops/mvlgamma.h @@ -30,12 +30,12 @@ constexpr auto kNameMvlgamma = "Mvlgamma"; class MIND_API Mvlgamma : public BaseOperator { public: MIND_API_BASE_MEMBER(Mvlgamma); - /// \brief Constructor. + /// \brief Constructor. Mvlgamma() : BaseOperator(kNameMvlgamma) { InitIOName({"x"}, {"y"}); } }; abstract::AbstractBasePtr MvlgammaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector<abstract::AbstractBasePtr> &input_args); + const std::vector<abstract::AbstractBasePtr> &input_args); using PrimMvlgammaPtr = std::shared_ptr<Mvlgamma>; } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/triplet_margin_loss.cc b/mindspore/core/ops/triplet_margin_loss.cc index 52006c9a5aa..da9394f05fa 100644 --- a/mindspore/core/ops/triplet_margin_loss.cc +++ b/mindspore/core/ops/triplet_margin_loss.cc @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
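As a reading aid for the Mvlgamma hunks above: MvlgammaInferShape passes the input shape through unchanged, and MvlgammaInferType admits only float32/float64. A sketch under those assumptions, with input values borrowed from the test case at the end of this patch; Mvlgamma is an Ascend-registered op, so treat this as illustrative usage rather than a portable test:

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops.operations.array_ops import Mvlgamma

op = Mvlgamma(p=1)
x = Tensor(np.array([[3, 4, 5], [4, 2, 6]]).astype(np.float32))
y = op(x)
# Shape inference is identity, so the output shape matches the input shape.
assert y.shape == x.shape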
@@ -16,13 +16,13 @@ #include "ops/triplet_margin_loss.h" #include #include "ops/op_utils.h" -#include "utils/tensor_construct_utils.h" -#include "abstract/primitive_infer_map.h" +#include "utils/check_convert_utils.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" namespace mindspore { namespace ops { namespace { -constexpr size_t kInputSize = 4; abstract::ShapePtr TripletMarginLossInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { auto op_name = primitive->name(); @@ -30,21 +30,19 @@ abstract::ShapePtr TripletMarginLossInferShape(const PrimitivePtr &primitive, auto positive = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; auto negative = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; auto margin = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape]; - const int64_t keight = 8; - if (x.size() >= keight || positive.size() >= keight || negative.size() >= keight) { + if (x.size() >= kDim8 || positive.size() >= kDim8 || negative.size() >= kDim8) { MS_EXCEPTION(ValueError) << "For " << op_name << ", dimensions of input x positive and negative must be smaller than 8, x_dim: " << x.size() << ", positive_dim: " << positive.size() << ", negative_dim: " << negative.size() << "."; } - const int64_t kone = 1; - if (x.size() <= kone && positive.size() <= kone && negative.size() <= kone) { + if (x.size() <= kDim1 && positive.size() <= kDim1 && negative.size() <= kDim1) { MS_EXCEPTION(ValueError) << "For " << op_name << ", dimensions of input x, positive and negative cannot be less than 1 at the same time, x_dim: " << x.size() << ", positive_dim: " << positive.size() << ", negative_dim: " << negative.size() << "."; } - if (margin.size() != 0) { + if (margin.size() != kDim0) { MS_EXCEPTION(ValueError) << "For " << op_name << ", the dimension of input margin must be 0, margin_dim: " << margin.size() << "."; } @@ -61,8 +59,8 @@ abstract::ShapePtr TripletMarginLossInferShape(const PrimitivePtr &primitive, ShapeVector out_shape; for (size_t i = 0; i < dims; i++) { out_shape.push_back((int64_t)std::max(std::max(x[i], positive[i]), negative[i])); - if ((x[i] != out_shape[i] && x[i] != kone) || (positive[i] != out_shape[i] && positive[i] != kone) || - (negative[i] != out_shape[i] && negative[i] != kone)) { + if ((x[i] != out_shape[i] && x[i] != kDim1) || (positive[i] != out_shape[i] && positive[i] != kDim1) || + (negative[i] != out_shape[i] && negative[i] != kDim1)) { MS_EXCEPTION(ValueError) << "For " << op_name << ", inputs' shape can't broadcast."; } } @@ -98,6 +96,7 @@ TypePtr TripletMarginLossInferType(const PrimitivePtr &primitive, const std::vec } } // namespace +MIND_API_OPERATOR_IMPL(TripletMarginLoss, BaseOperator); AbstractBasePtr TripletMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { MS_EXCEPTION_IF_NULL(primitive); diff --git a/mindspore/core/ops/triplet_margin_loss.h b/mindspore/core/ops/triplet_margin_loss.h index 9f98f9f7b22..14bde96830f 100644 --- a/mindspore/core/ops/triplet_margin_loss.h +++ b/mindspore/core/ops/triplet_margin_loss.h @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
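The broadcast rule enforced by TripletMarginLossInferShape above, modelled in plain Python: every output dimension is the maximum of the three input dimensions, and each input dimension must equal that maximum or be 1. The rank-alignment step falls between the hunks shown, so padding shorter shapes with leading 1s is an assumption here:

def triplet_out_shape(x, positive, negative):
    # Mirrors the kDim8 / kDim1 guards in the C++ shape inference.
    if max(len(x), len(positive), len(negative)) >= 8:
        raise ValueError("dimensions of x, positive and negative must be smaller than 8")
    if len(x) <= 1 and len(positive) <= 1 and len(negative) <= 1:
        raise ValueError("x, positive and negative cannot all have dimension <= 1")
    dims = max(len(x), len(positive), len(negative))
    # Assumed alignment: pad the shorter shapes with leading 1s.
    x, positive, negative = [(1,) * (dims - len(s)) + tuple(s) for s in (x, positive, negative)]
    out_shape = []
    for xd, pd, nd in zip(x, positive, negative):
        od = max(xd, pd, nd)
        if any(d != od and d != 1 for d in (xd, pd, nd)):
            raise ValueError("inputs' shape can't broadcast")
        out_shape.append(od)
    return tuple(out_shape)

assert triplet_out_shape((2, 3), (1, 3), (2, 1)) == (2, 3)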
@@ -21,24 +21,23 @@ #include #include #include -#include "ops/primitive_c.h" -#include "abstract/abstract_value.h" -#include "utils/check_convert_utils.h" + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" namespace mindspore { namespace ops { constexpr auto kNameTripletMarginLoss = "TripletMarginLoss"; -class TripletMarginLoss : public PrimitiveC { +class MIND_API TripletMarginLoss : public BaseOperator { public: - TripletMarginLoss() : PrimitiveC(kNameTripletMarginLoss) { + MIND_API_BASE_MEMBER(TripletMarginLoss); + TripletMarginLoss() : BaseOperator(kNameTripletMarginLoss) { InitIOName({"x", "positive", "negative", "margin"}, {"y"}); } - ~TripletMarginLoss() = default; - MS_DECLARE_PARENT(TripletMarginLoss, PrimitiveC); }; -AbstractBasePtr TripletMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector<AbstractBasePtr> &input_args); +abstract::AbstractBasePtr TripletMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector<abstract::AbstractBasePtr> &input_args); using PrimTripletMarginLossPtr = std::shared_ptr<TripletMarginLoss>; } // namespace ops } // namespace mindspore diff --git a/mindspore/python/mindspore/nn/loss/loss.py b/mindspore/python/mindspore/nn/loss/loss.py index 20320b7ea3e..1a5bb5feb49 100644 --- a/mindspore/python/mindspore/nn/loss/loss.py +++ b/mindspore/python/mindspore/nn/loss/loss.py @@ -19,6 +19,9 @@ from mindspore import log from mindspore.common.tensor import Tensor from mindspore.common.parameter import Parameter from mindspore.ops import operations as P +from mindspore.ops.operations.nn_ops import MultiMarginLoss as MultiMarginLossOp +from mindspore.ops.operations.nn_ops import MultilabelMarginLoss as MultilabelMarginLossOp +from mindspore.ops.operations.nn_ops import TripletMarginLoss as TripletMarginLossOp from mindspore.ops import functional as F from mindspore import nn from mindspore.ops.primitive import constexpr @@ -1127,7 +1130,7 @@ class MultiMarginLoss(LossBase): def __init__(self, p=1, margin=1.0, reduction='mean'): """Initialize MultiMarginLoss.""" super(MultiMarginLoss, self).__init__() - self.multi_margin_loss = P.MultiMarginLoss(p=p, margin=margin, reduction=reduction) + self.multi_margin_loss = MultiMarginLossOp(p=p, margin=margin, reduction=reduction) self.ones = P.Ones() def construct(self, x, target, weight=None): @@ -1369,7 +1372,7 @@ class MultilabelMarginLoss(LossBase): def __init__(self, reduction='mean'): super(MultilabelMarginLoss, self).__init__() - self.multilabel_margin_loss = P.MultilabelMarginLoss(reduction=reduction) + self.multilabel_margin_loss = MultilabelMarginLossOp(reduction=reduction) def construct(self, x, target): return self.multilabel_margin_loss(x, target) @@ -1497,74 +1500,6 @@ def _check_input_dtype(labels_dtype, cls_name): [mstype.int32, mstype.int64, mstype.float16, mstype.float32], cls_name) -class MultilabelMarginLoss(LossBase): - r""" - MultilabelMarginLoss operation. - - Creates a criterion that optimizes a multi-class multi-classification - hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) - and output :math:`y` (which is a 2D `Tensor` of target class indices). - For each sample in the mini-batch: - - .. 
math:: - \text{loss}(x, y) = \sum_{ij}\frac{\max(0, 1 - (x[y[j]] - x[i]))}{\text{x.size}(0)} - - where :math:`x \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}`, \ - :math:`y \in \left\{0, \; \cdots , \; \text{y.size}(0) - 1\right\}`, \ - :math:`0 \leq y[j] \leq \text{x.size}(0)-1`, \ - and :math:`i \neq y[j]` for all :math:`i` and :math:`j`. - - :math:`y` and :math:`x` must have the same size. - - The criterion only considers a contiguous block of non-negative targets that - starts at the front. - - This allows for different samples to have variable amounts of target classes. - - Args: - reduction (str): Apply specific reduction method to the output: 'none', 'mean', 'sum'. Default: "mean". - - Inputs: - - **x** (Tensor) - Predict data. Tensor of shape :math:`(C)` or :math:`(N, C)`, where :math:`N` - is the batch size and :math:`C` is the number of classes. Data type must be float16 or float32. - - **target** (Tensor) - Ground truth data, with the same shape as `x`, data type must be int32 and - label targets padded by -1. - - Outputs: - - **y** (Union[Tensor, Scalar]) - The loss of MultilabelMarginLoss. If `reduction` is "none", its shape - is :math:`(N)`. Otherwise, a scalar value will be returned. - - **is_target** (Tensor) - Output tensor for backward input, with the same shape as `target`, - data type must be int32. - - Raises: - TypeError: If `x` or `target` is not a Tensor. - TypeError: If dtype of `x` is neither float16 nor float32. - TypeError: If dtype of `target` is not int32. - ValueError: If length of shape of `x` is neither 1 nor 2. - ValueError: If shape of `x` is not the same as `target`. - ValueError: If `reduction` is not one of 'none', 'mean', 'sum'. - - Supported Platforms: - ``Ascend`` - - Examples: - >>> loss = nn.MultilabelMarginLoss() - >>> x = Tensor(np.array([[0.1, 0.2, 0.4, 0.8], [0.2, 0.3, 0.5, 0.7]]), mindspore.float32) - >>> target = Tensor(np.array([[1, 2, 0, 3], [2, 3, -1, 1]]), mindspore.int32) - >>> output = loss(x, target) - >>> print(output) - (Tensor(shape=[], dtype=Float32, value= 0.325), Tensor(shape=[2, 4], dtype=Int32, value= - [[1, 1, 1, 1], [0, 0, 1, 1]])) - """ - - def __init__(self, reduction='mean'): - super(MultilabelMarginLoss, self).__init__() - self.multilabel_margin_loss = P.MultilabelMarginLoss(reduction=reduction) - - def construct(self, x, target): - return self.multilabel_margin_loss(x, target) - - class FocalLoss(LossBase): r""" The loss function proposed by Kaiming team in their paper ``Focal Loss for Dense Object Detection`` improves the @@ -1867,7 +1802,7 @@ class TripletMarginLoss(LossBase): def __init__(self, p=2, swap=False, eps=1e-6, reduction='mean'): super(TripletMarginLoss, self).__init__() - self.triplet_margin_loss = P.TripletMarginLoss(p=p, swap=swap, eps=eps, reduction=reduction) + self.triplet_margin_loss = TripletMarginLossOp(p=p, swap=swap, eps=eps, reduction=reduction) def construct(self, x, positive, negative, margin): return self.triplet_margin_loss(x, positive, negative, margin) diff --git a/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py b/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py index 2ef40fa4a46..79cc98427a7 100644 --- a/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py +++ b/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py @@ -40,7 +40,6 @@ from ..operations.array_ops import Expand from ..operations.array_ops import SegmentMean from .. import functional as F from .. 
import operations as P -from ..operations import _grad_ops as G from .._utils.utils import is_shape_unknown from ..operations import _grad_ops as G diff --git a/mindspore/python/mindspore/ops/_grad_experimental/grad_nn_ops.py b/mindspore/python/mindspore/ops/_grad_experimental/grad_nn_ops.py index e2b7faaf9dd..e4defe51f36 100644 --- a/mindspore/python/mindspore/ops/_grad_experimental/grad_nn_ops.py +++ b/mindspore/python/mindspore/ops/_grad_experimental/grad_nn_ops.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2021-2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,12 +24,16 @@ from .._grad.grad_base import bprop_getters from .. import operations as P from ..composite.multitype_ops.zeros_like_impl import zeros_like from ..operations import _grad_ops as G +from ..operations.nn_ops import MaxUnpool2D +from ..operations.nn_ops import MaxUnpool3D from ..operations.nn_ops import FractionalMaxPool from ..operations._grad_ops import FractionalMaxPoolGrad from ..operations.nn_ops import FractionalMaxPool3DWithFixedKsize from ..operations._grad_ops import FractionalMaxPool3DGradWithFixedKsize from ..operations.nn_ops import FractionalAvgPool from ..operations._grad_ops import FractionalAvgPoolGrad +from ..operations.nn_ops import MultiMarginLoss +from ..operations.nn_ops import MultilabelMarginLoss from ..operations.nn_ops import NthElement from ..operations.nn_ops import PSROIPooling from ..operations._grad_ops import PSROIPoolingGrad @@ -92,7 +96,7 @@ def get_bprop_hshrink(self): return bprop -@bprop_getters.register(P.MultilabelMarginLoss) +@bprop_getters.register(MultilabelMarginLoss) def get_bprop_multilabel_margin_loss(self): """Grad definition for `MultilabelMarginLoss` operation.""" input_grad = G.MultilabelMarginLossGrad(reduction=self.reduction) @@ -120,7 +124,7 @@ def get_bprop_celu(self): return bprop -@bprop_getters.register(P.MultiMarginLoss) +@bprop_getters.register(MultiMarginLoss) def get_bprop_multi_margin_loss(self): """Grad definition for `MultiMarginLoss` operation.""" input_grad = G.MultiMarginLossGrad(p=self.p, margin=self.margin, reduction=self.reduction) @@ -155,7 +159,7 @@ def get_bprop_relu(self): return bprop -@bprop_getters.register(P.MaxUnpool2D) +@bprop_getters.register(MaxUnpool2D) def get_bprop_maxunpool2d(self): """Grad definition for `MaxUnpool2D` operation.""" maxunpool2d_grad = G.MaxUnpool2DGrad( @@ -173,7 +177,7 @@ def get_bprop_maxunpool2d(self): return bprop -@bprop_getters.register(P.MaxUnpool3D) +@bprop_getters.register(MaxUnpool3D) def get_bprop_maxunpool3d(self): """Grad definition for `MaxUnpool3D` operation.""" maxunpool3d_grad = G.MaxUnpool3DGrad( @@ -188,6 +192,8 @@ def get_bprop_maxunpool3d(self): dargmax = zeros_like(argmax) return (dx, dargmax) + return bprop + @bprop_getters.register(NthElement) def get_bprop_nth_element(self): diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/bartlett_window.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/bartlett_window.py index 2617931cf25..fecc57c109f 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/bartlett_window.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/bartlett_window.py @@ -29,6 +29,7 @@ bartlett_window_op_info = AiCPURegOp("BartlettWindow") \ .dtype_format(DataType.I64_Default, DataType.F64_Default) \ .get_op_info() + @op_info_register(bartlett_window_op_info) def _bartlett_window_aicpu(): """BartlettWindow AiCPU 
register""" diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/max_unpool2d.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/max_unpool2d.py index a716539ebdb..2febfd222a7 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/max_unpool2d.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/max_unpool2d.py @@ -50,6 +50,7 @@ max_unpool2d_op_info = AiCPURegOp("MaxUnpool2D") \ .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \ .get_op_info() + @op_info_register(max_unpool2d_op_info) def _max_unpool2d_aicpu(): """MaxUnpool2D aicpu register""" diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py index 6a0ad24d61e..802e6234c8d 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py @@ -51,6 +51,7 @@ max_unpool2d_grad_op_info = AiCPURegOp("MaxUnpool2DGrad") \ .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \ .get_op_info() + @op_info_register(max_unpool2d_grad_op_info) def _max_unpool2d_grad_aicpu(): """MaxUnpool2DGrad aicpu register""" diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/max_unpool3d.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/max_unpool3d.py index 684e75618b5..3301571d684 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/max_unpool3d.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/max_unpool3d.py @@ -50,6 +50,7 @@ max_unpool3d_op_info = AiCPURegOp("MaxUnpool3D") \ .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \ .get_op_info() + @op_info_register(max_unpool3d_op_info) def _max_unpool3d_aicpu(): """MaxUnpool3D aicpu register""" diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py index cbce99801ee..673feda2eb9 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py @@ -51,6 +51,7 @@ max_unpool3d_grad_op_info = AiCPURegOp("MaxUnpool3DGrad") \ .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \ .get_op_info() + @op_info_register(max_unpool3d_grad_op_info) def _max_unpool3d_grad_aicpu(): """MaxUnpool3DGrad aicpu register""" diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/multi_margin_loss.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/multi_margin_loss.py index f2ae2ac5c5b..0ba3e6ddb2f 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/multi_margin_loss.py @@ -30,6 +30,7 @@ multi_margin_loss_op_info = AiCPURegOp("MultiMarginLoss") \ .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.F64_Default, DataType.F64_Default) \ .get_op_info() + @op_info_register(multi_margin_loss_op_info) def _multi_margin_loss_aicpu(): """MultiMarginLoss aicpu register""" diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py index b7907496fe7..74282514cfd 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py @@ -34,6 +34,7 @@ multi_margin_loss_grad_op_info = AiCPURegOp("MultiMarginLossGrad") \ 
DataType.F64_Default) \ .get_op_info() + @op_info_register(multi_margin_loss_grad_op_info) def _multi_margin_loss_grad_aicpu(): """MultiMarginLossGrad aicpu register""" diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py index 26904b7d392..af935aa9484 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -30,6 +30,7 @@ multilabel_margin_loss_grad_op_info = AiCPURegOp("MultilabelMarginLossGrad") \ DataType.F32_Default) \ .get_op_info() + @op_info_register(multilabel_margin_loss_grad_op_info) def _multilabel_margin_loss_grad_aicpu(): """MultilabelMarginLossGrad aicpu register""" diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/mvlgamma.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/mvlgamma.py index c7fd10f21fc..d68c5627784 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/mvlgamma.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/mvlgamma.py @@ -25,6 +25,7 @@ mvlgamma_op_info = AiCPURegOp("Mvlgamma") \ .dtype_format(DataType.F64_Default, DataType.F64_Default) \ .get_op_info() + @op_info_register(mvlgamma_op_info) def _mvlgamma_aicpu(): """Mvlgamma AiCPU register""" diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py index ff30aa18ad6..3ae3cba47de 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py @@ -26,6 +26,7 @@ mvlgamma_grad_op_info = AiCPURegOp("MvlgammaGrad") \ .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default,) \ .get_op_info() + @op_info_register(mvlgamma_grad_op_info) def _mvlgamma_grad_aicpu(): """MvlgammaGrad AiCPU register""" diff --git a/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py index 2a32a928541..a1253ade9ba 100644 --- a/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py @@ -591,7 +591,6 @@ from .extract_volume_patches import _extract_volume_patches_tbe from .multilabel_margin_loss import _multilabel_margin_loss_tbe from .round_ds import _round_ds_tbe from .is_close import _is_close_tbe -from .multilabel_margin_loss import _multilabel_margin_loss_tbe from .apply_adam_with_amsgrad import _apply_adam_with_amsgrad_tbe from .apply_adam_with_amsgrad_ds import _apply_adam_with_amsgrad_ds_tbe from .expm1_ds import _expm1_ds_tbe diff --git a/mindspore/python/mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py b/mindspore/python/mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py index 88b8d33c477..e8d46668b7e 100644 --- a/mindspore/python/mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +++ b/mindspore/python/mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
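All of the AiCPU registrations touched above share one shape; a generic sketch of that pattern, with a hypothetical "ExampleOp" and a single float32-in/float32-out dtype_format line standing in for the real tables:

from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

# Declare the op, its inputs/outputs and each supported dtype combination,
# then register the resulting op-info with the framework.
example_op_info = AiCPURegOp("ExampleOp") \
    .fusion_type("OPAQUE") \
    .input(0, "x", "required") \
    .output(0, "y", "required") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()

@op_info_register(example_op_info)
def _example_aicpu():
    """ExampleOp AiCPU register"""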
diff --git a/mindspore/python/mindspore/ops/operations/nn_ops.py b/mindspore/python/mindspore/ops/operations/nn_ops.py index 054737f8552..576030dc535 100644 --- a/mindspore/python/mindspore/ops/operations/nn_ops.py +++ b/mindspore/python/mindspore/ops/operations/nn_ops.py @@ -29,7 +29,8 @@ from ...common._decorator import deprecated from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register -def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=False, ret_four=False): +def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=False, + ret_four=False, strict_positive=True): """ Checks whether an argument is a positive int or tuple with 2 or 4(when allow_four is True) positive int elements. """ @@ -54,8 +55,11 @@ def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=Fals validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name) ret_value = _get_return_value() for item in ret_value: - if isinstance(item, int) and not isinstance(item, bool) and item > 0: - continue + if isinstance(item, int) and not isinstance(item, bool): + if item > 0: + continue + if not strict_positive and item == 0: + continue _raise_message() return ret_value @@ -2117,7 +2121,7 @@ class MaxUnpool2D(Primitive): if strides in (0, (0, 0)): strides = ksize self.strides = _check_positive_int_or_tuple('strides', strides, self.name, ret_four=True) - self.pads = _check_positive_int_or_tuple('pads', pads, self.name, ret_four=True, greater_zero=False) + self.pads = _check_positive_int_or_tuple('pads', pads, self.name, ret_four=True, strict_positive=False) self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'data_format', self.name) if data_format == "NHWC": diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 682f797bf85..be947190d4b 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -84,6 +84,13 @@ from mindspore.ops.operations.nn_ops import FractionalAvgPool from mindspore.ops.operations._grad_ops import FractionalAvgPoolGrad from mindspore.ops.operations.nn_ops import GridSampler2D from mindspore.ops.operations.nn_ops import GridSampler3D +from mindspore.ops.operations.nn_ops import MaxUnpool2D +from mindspore.ops.operations.nn_ops import MaxUnpool3D +from mindspore.nn.loss.loss import MultiMarginLoss +from mindspore.nn.loss.loss import MultilabelMarginLoss +from mindspore.nn.loss.loss import TripletMarginLoss +from mindspore.ops.operations.array_ops import Mvlgamma +from mindspore.ops.operations.other_ops import BartlettWindow from mindspore.ops.operations.nn_ops import NthElement from mindspore.ops.operations.nn_ops import SparseApplyAdagradDA from mindspore.ops.operations.nn_ops import PSROIPooling @@ -2352,27 +2359,15 @@ test_case_nn_ops = [ 'desc_inputs': [[10, 3, 28, 31, 24]], 'desc_bprop': [[10, 3, 14, 16, 12]]}), ('MaxUnpool2D', { - 'block': P.MaxUnpool2D(ksize=(4, 4), strides=(2, 2), pads=(2, 2)), + 'block': MaxUnpool2D(ksize=(4, 4), strides=(2, 2), pads=(2, 2)), 'desc_inputs': [([4, 3, 6, 6], {'dtype': np.float32}), ([4, 3, 6, 6], {'dtype': np.int64})], 'desc_bprop': [([4, 3, 10, 10], {'dtype': np.float32})]}), - ('MaxUnpool2DGrad', { - 'block': G.MaxUnpool2DGrad(ksize=(1, 1, 4, 4), strides=(1, 1, 2, 2), pads=(1, 1, 2, 2)), - 'desc_inputs': [([4, 3, 6, 6], {'dtype': np.float32}), - ([4, 3, 10, 10], {'dtype': np.float32}), - ([4, 3, 6, 6], {'dtype': np.int64})], - 'skip': ['backward']}), ('MaxUnpool3D', { - 'block': 
P.MaxUnpool3D(ksize=(4, 4, 4), strides=(2, 2, 2), pads=(2, 2, 2)), + 'block': MaxUnpool3D(ksize=(4, 4, 4), strides=(2, 2, 2), pads=(2, 2, 2)), 'desc_inputs': [([4, 3, 6, 6, 5], {'dtype': np.float32}), ([4, 3, 6, 6, 5], {'dtype': np.int64})], 'desc_bprop': [([4, 3, 10, 10, 8], {'dtype': np.float32})]}), - ('MaxUnpool3DGrad', { - 'block': G.MaxUnpool3DGrad(ksize=(1, 1, 4, 4, 4), strides=(1, 1, 2, 2, 2), pads=(1, 1, 2, 2, 2)), - 'desc_inputs': [([4, 3, 6, 6, 5], {'dtype': np.float32}), - ([4, 3, 10, 10, 8], {'dtype': np.float32}), - ([4, 3, 6, 6, 5], {'dtype': np.int64})], - 'skip': ['backward']}), ('MaxPoolWithArgmax', { 'block': P.MaxPoolWithArgmax(kernel_size=2, strides=2), 'desc_inputs': [[128, 32, 32, 64]], @@ -2752,18 +2747,10 @@ test_case_nn_ops = [ 'desc_inputs': [[2, 2, 3], Tensor(np.array([2, 2]).astype(np.int32))], 'skip': ['backward']}), ('MultiMarginLoss', { - 'block': nn.MultiMarginLoss(reduction="mean"), + 'block': MultiMarginLoss(reduction="mean"), 'desc_inputs': [Tensor(np.array([[0.3, 0.7], [0.5, 0.5]]).astype(np.float32)), Tensor(np.array([0, 0]).astype(np.int64))], 'desc_bprop': [[1]]}), - ('MultiMarginLossGrad', { - 'block': G.MultiMarginLossGrad(), - 'desc_inputs': [Tensor(np.array([1]).astype(np.float32)), - Tensor(np.array([[0.3, 0.7], [0.5, 0.5]]).astype(np.float32)), - Tensor(np.array([1, 1]).astype(np.int64)), - Tensor(np.array([1, 1]).astype(np.float32))], - 'desc_bprop': [Tensor([1], mstype.float32)], - 'skip': ['backward']}), ('L2Loss_1', { 'block': P.L2Loss(), 'desc_inputs': [Tensor(np.array([1, 2, 3, 4]), mstype.float32)], @@ -2903,7 +2890,7 @@ test_case_nn_ops = [ Tensor(np.array([[-4, -3, -2], [1, 2, 4]]), mstype.float16)], 'skip': ['backward']}), ('TripletMarginLoss', { - 'block': P.TripletMarginLoss(reduction="none"), + 'block': TripletMarginLoss(reduction="none"), 'desc_inputs': [Tensor(np.array([[0.3, 0.7], [0.5, 0.5]]).astype(np.float32)), Tensor(np.array([[0.4, 0.6], [0.4, 0.6]]).astype(np.float32)), Tensor(np.array([[0.2, 0.9], [0.3, 0.7]]).astype(np.float32)), @@ -2952,10 +2939,11 @@ test_case_nn_ops = [ Tensor(0.99, mstype.float32)], 'skip': ['backward']}), ('MultilabelMarginLoss', { - 'block': P.MultilabelMarginLoss(reduction="none"), + 'block': MultilabelMarginLoss(reduction="none"), 'desc_inputs': [Tensor(np.array([[0.1, 0.2, 0.4, 0.8], [0.1, 0.2, 0.3, 0.4]]).astype(np.float32)), Tensor(np.array([[2, 1, -1, 1], [1, -1, 2, 1]]).astype(np.int32))], - 'desc_bprop': [Tensor(np.array([1, 2]).astype(np.float32))]}), + 'desc_bprop': [Tensor(np.array([1, 2]).astype(np.float32)), + Tensor(np.array([[1, 1, 2, 1], [1, 1, 2, 1]]).astype(np.int32))]}), ('GridSampler3D', { 'block': GridSampler3D(interpolation_mode='bilinear', padding_mode='zeros', align_corners=False), 'desc_inputs': [Tensor(np.arange(32).reshape((2, 2, 2, 2, 2)).astype(np.float32)), @@ -2983,7 +2971,6 @@ test_case_nn_ops = [ Tensor(1, mstype.int64)], 'skip': ['backward']}), ] - test_case_array_ops = [ ('LeftShift', { 'block': LeftShift(), @@ -3158,15 +3145,10 @@ test_case_array_ops = [ 'desc_inputs': [(Tensor(np.array([-1.6, -0.1, 1.5, 2.0]).astype(np.float32)))], 'skip': ['backward']}), ('Mvlgamma', { - 'block': P.Mvlgamma(p=1), + 'block': Mvlgamma(p=1), 'desc_inputs': [Tensor(np.array([[3, 4, 5], [4, 2, 6]]).astype(np.float32))], 'desc_bprop': [Tensor(np.array([[3, 4, 5], [4, 2, 6]]).astype(np.float32))] }), - ('MvlgammaGrad', { - 'block': G.MvlgammaGrad(p=1), - 'desc_inputs': [Tensor(np.array([[3, 4, 5], [4, 2, 6]]).astype(np.float32)), - Tensor(np.array([[3, 4, 5], [4, 2, 
6]]).astype(np.float32))], - 'skip': ['backward']}), ('ConcatV2_0', { 'block': NetForConcat1(), 'desc_inputs': [ @@ -3506,6 +3488,7 @@ test_case_array_ops = [ }), ] + test_case_image_ops = [ ('AdjustHue', { 'block': AdjustHue(), @@ -3597,10 +3580,10 @@ test_case_other_ops = [ Tensor(np.array([[[0.38, 0.17, 0.95, 0.40]]], np.float32)), Tensor(np.array([0.8], np.float32))), 'skip': ['backward']}), - ('BartlettWindow', { - 'block': P.BartlettWindow(periodic=True, dtype=mstype.float32), - 'desc_inputs': (Tensor(np.array([10], np.int32))), - 'skip': ['backward']}), + ('BartlettWindow', { + 'block': BartlettWindow(periodic=True, dtype=mstype.float32), + 'desc_inputs': (Tensor(np.array([10], np.int32))), + 'skip': ['backward']}), ('GatherNd', { 'block': P.GatherNd(), 'desc_inputs': (Tensor(np.ones((1, 3, 6, 6), np.float32)),