diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/cum_op_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/cum_minmax_cpu_kernel.cc similarity index 54% rename from mindspore/ccsrc/plugin/device/cpu/kernel/cum_op_cpu_kernel.cc rename to mindspore/ccsrc/plugin/device/cpu/kernel/cum_minmax_cpu_kernel.cc index 1215063e103..d0559d49118 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/cum_op_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/cum_minmax_cpu_kernel.cc @@ -14,10 +14,9 @@ * limitations under the License. */ -#include "plugin/device/cpu/kernel/cum_op_cpu_kernel.h" +#include "plugin/device/cpu/kernel/cum_minmax_cpu_kernel.h" #include #include -#include #include #include "plugin/device/cpu/hal/device/cpu_device_address.h" @@ -30,7 +29,7 @@ template using CumMinMaxComputeFunc = std::function(const T &, const S &, const T &, const S &)>; template -std::pair cumop(const T &a_val, const S &a_idx, const T &b_val, const S &b_idx) { +std::pair cum_minmax(const T &a_val, const S &a_idx, const T &b_val, const S &b_idx) { OP op; if constexpr ((std::is_same_v) || (std::is_same_v)) { return std::isnan(a_val) || op(a_val, b_val) ? std::make_pair(a_val, a_idx) : std::make_pair(b_val, b_idx); @@ -41,8 +40,8 @@ std::pair cumop(const T &a_val, const S &a_idx, const T &b_val, const S &b } } // namespace -bool CumOpCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, - const std::vector &outputs) { +bool CumMinMaxCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) { kernel_name_ = base_operator->GetPrim()->name(); if (kernel_name_ != kernel_type_) { MS_LOG(EXCEPTION) << "Need to be " << kernel_type_ << ", but got kernel name as " << kernel_name_; @@ -57,13 +56,13 @@ bool CumOpCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::ve MS_LOG(EXCEPTION) << kernel_name_ << " does not support this kernel data type: " << kernel_attr; } base_operator_ = base_operator; - kernel_func_ = func_list_[index].second; + kernel_func_ = func_list_[kernel_type_][index].second; return true; } -bool CumOpCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, - const std::vector &outputs, - const std::map &others) { +bool CumMinMaxCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, + const std::map &others) { if (!NativeCpuKernelMod::Resize(base_operator, inputs, outputs, others)) { MS_LOG(WARNING) << kernel_name_ << " reinit failed."; return false; @@ -86,7 +85,7 @@ bool CumOpCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std:: return true; } -size_t CumOpCpuKernelMod::GetRealIndex(size_t index) { +size_t CumMinMaxCpuKernelMod::GetRealIndex(size_t index) { auto batch_idx = index / axis_size_; auto axis_idx = index - batch_idx * axis_size_; auto outer_idx = batch_idx / inner_size_; @@ -95,15 +94,15 @@ size_t CumOpCpuKernelMod::GetRealIndex(size_t index) { } template -bool CumOpCpuKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { +bool CumMinMaxCpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCumInputsNum, kernel_name_); CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCumOutputsNum, kernel_name_); // Select the minimum/maximum computation function static const std::map> cum_compute_func_map{ - {prim::kPrimCummax->name(), &cumop>}, - {prim::kPrimCummin->name(), &cumop>}, + {prim::kPrimCummax->name(), &cum_minmax>}, + {prim::kPrimCummin->name(), &cum_minmax>}, }; if (cum_compute_func_map.find(kernel_name_) == cum_compute_func_map.end()) { MS_LOG(EXCEPTION) << "For 'CumMinMaxOp', the current kernel only support this operator in " @@ -186,59 +185,71 @@ bool CumOpCpuKernelMod::LaunchKernel(const std::vector &inpu return true; } -std::vector> CumOpCpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32), - &CumOpCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), - &CumOpCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - &CumOpCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), - &CumOpCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32), - &CumOpCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32), - &CumOpCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32), - &CumOpCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32), - &CumOpCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), - &CumOpCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), - &CumOpCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32), - &CumOpCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), - &CumOpCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64), - &CumOpCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), - &CumOpCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - &CumOpCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), - &CumOpCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64), - &CumOpCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), - &CumOpCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64), - &CumOpCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64), - &CumOpCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64), - &CumOpCpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64), - &CumOpCpuKernelMod::LaunchKernel}}; +// Note that in definition of primitive, Cummin return int32 as indices and Cummax return int64 as indices. (see +// cummax.cc and cummin.cc). +std::map>> + CumMinMaxCpuKernelMod::func_list_ = { + {kCummin, + {{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32), + &CumMinMaxCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), + &CumMinMaxCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + &CumMinMaxCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), + &CumMinMaxCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32), + &CumMinMaxCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32), + &CumMinMaxCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32), + &CumMinMaxCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32), + &CumMinMaxCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), + &CumMinMaxCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), + &CumMinMaxCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32), + &CumMinMaxCpuKernelMod::LaunchKernel}}}, + {kCummax, + {{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), + &CumMinMaxCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64), + &CumMinMaxCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), + &CumMinMaxCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + &CumMinMaxCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), + &CumMinMaxCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64), + &CumMinMaxCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), + &CumMinMaxCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64), + &CumMinMaxCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64), + &CumMinMaxCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64), + &CumMinMaxCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64), + &CumMinMaxCpuKernelMod::LaunchKernel}}}}; + +std::vector CumMinMaxCpuKernelMod::GetOpSupport() { + auto iter = func_list_.find(kernel_type_); + if (iter == func_list_.end()) { + MS_LOG(EXCEPTION) << "Cum_minmax cpu does not support " << kernel_type_; + } -std::vector CumOpCpuKernelMod::GetOpSupport() { std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); + (void)std::transform( + iter->second.begin(), iter->second.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); return support_list; } -MS_KERNEL_FACTORY_REG_WITH_NAME_PARAM(NativeCpuKernelMod, Cummin, CumOpCpuKernelMod); +MS_KERNEL_FACTORY_REG_WITH_NAME_PARAM(NativeCpuKernelMod, Cummin, CumMinMaxCpuKernelMod); +MS_KERNEL_FACTORY_REG_WITH_NAME_PARAM(NativeCpuKernelMod, Cummax, CumMinMaxCpuKernelMod); } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/cum_op_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/cum_minmax_cpu_kernel.h similarity index 74% rename from mindspore/ccsrc/plugin/device/cpu/kernel/cum_op_cpu_kernel.h rename to mindspore/ccsrc/plugin/device/cpu/kernel/cum_minmax_cpu_kernel.h index 91547f099b0..1dc49f1a021 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/cum_op_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/cum_minmax_cpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CUM_OP_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CUM_OP_CPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CUM_MINMAX_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CUM_MINMAX_CPU_KERNEL_H_ #include #include @@ -27,12 +27,14 @@ namespace mindspore { namespace kernel { +constexpr auto kCummin = "Cummin"; +constexpr auto kCummax = "Cummax"; constexpr auto kUnKnown = "UnKnown"; -class CumOpCpuKernelMod : public NativeCpuKernelMod { +class CumMinMaxCpuKernelMod : public NativeCpuKernelMod { public: - CumOpCpuKernelMod() = default; - explicit CumOpCpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {} - ~CumOpCpuKernelMod() override = default; + CumMinMaxCpuKernelMod() = default; + explicit CumMinMaxCpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {} + ~CumMinMaxCpuKernelMod() override = default; bool Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, const std::vector &outputs) override; @@ -54,9 +56,9 @@ class CumOpCpuKernelMod : public NativeCpuKernelMod { template bool LaunchKernel(const std::vector &inputs, const std::vector &outputs); - using CumMinMaxLaunchFunc = std::function &, + using CumMinMaxLaunchFunc = std::function &, const std::vector &)>; - static std::vector> func_list_; + static std::map>> func_list_; CumMinMaxLaunchFunc kernel_func_; BaseOperatorPtr base_operator_; int64_t inner_size_{1}; @@ -69,4 +71,4 @@ class CumOpCpuKernelMod : public NativeCpuKernelMod { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CUM_OP_CPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_CUM_MINMAX_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/cummax_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/cummax_cpu_kernel.cc deleted file mode 100644 index 7c5fa83a3f4..00000000000 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/cummax_cpu_kernel.cc +++ /dev/null @@ -1,136 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/cpu/kernel/cummax_cpu_kernel.h" -#include -#include -#include "plugin/device/cpu/hal/device/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void CummaxCPUKernelMod::InitKernel(const CNodePtr &kernel_node) { - input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - output1_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); - output2_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 1); - dim_ = common::AnfAlgo::GetNodeAttr(kernel_node, "dim"); - - auto kernel_attr = GetKernelAttrFromNode(kernel_node); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(EXCEPTION) << "Cummax does not support this kernel data type: " << kernel_attr; - } - - kernel_func_ = func_list_[index].second; -} - -template -bool CummaxCPUKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { - auto input_data_addr = reinterpret_cast(inputs[0]->addr); - auto output1_data_addr = reinterpret_cast(outputs[0]->addr); - auto output2_data_addr = reinterpret_cast(outputs[1]->addr); - - const size_t dims = input_shape_.size(); - if (dims == 0) { - MS_LOG(EXCEPTION) << "The value of `dims` can not be 0"; - } - dim_ = (dim_ + dims) % dims; - std::vector p{1}; - - for (int64_t i = (int64_t)input_shape_.size() - 1; i >= 0; i--) - p.push_back(p[(int64_t)input_shape_.size() - 1 - i] * input_shape_[i]); - reverse(p.begin(), p.end()); - - size_t input_stride = p[dim_ + 1]; - size_t output1_stride = p[dim_ + 1]; - size_t output2_stride = p[dim_ + 1]; - size_t input_dim_size = input_shape_[dim_]; - - int exit_ok = 0; - std::vector counter(dims, 0); - - while (!exit_ok) { - T out = input_data_addr[0]; - int idx = 0; - for (size_t i = 0; i < input_dim_size; i++) { - T cur = input_data_addr[i * input_stride]; - if (cur >= out) { - out = cur; - idx = i; - } - output1_data_addr[i * output1_stride] = out; - output2_data_addr[i * output2_stride] = idx; - } - - if (dims == 1) break; - for (size_t dim_i = 0; dim_i < dims; dim_i++) { - if (dim_i == dim_) { - if (dim_i == dims - 1) { - exit_ok = 1; - break; - } - continue; - } - counter[dim_i]++; - input_data_addr += p[dim_i + 1]; - output1_data_addr += p[dim_i + 1]; - output2_data_addr += p[dim_i + 1]; - - if (counter[dim_i] == input_shape_[dim_i]) { - if (dim_i == dims - 1) { - exit_ok = 1; - break; - } else { - input_data_addr -= counter[dim_i] * p[dim_i + 1]; - output1_data_addr -= counter[dim_i] * p[dim_i + 1]; - output2_data_addr -= counter[dim_i] * p[dim_i + 1]; - counter[dim_i] = 0; - } - } else { - break; - } - } - } - return true; -} - -std::vector> CummaxCPUKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64), - &CummaxCPUKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64), - &CummaxCPUKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), - &CummaxCPUKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - &CummaxCPUKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), - &CummaxCPUKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), - &CummaxCPUKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), - &CummaxCPUKernelMod::LaunchKernel}}; - -std::vector CummaxCPUKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - - return support_list; -} - -MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Cummax, CummaxCPUKernelMod); -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/cummax_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/cummax_cpu_kernel.h deleted file mode 100644 index b02a92acdc6..00000000000 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/cummax_cpu_kernel.h +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMMAX_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMMAX_CPU_KERNEL_H_ - -#include -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "plugin/factory/ms_factory.h" - -namespace mindspore { -namespace kernel { -class CummaxCPUKernelMod : public DeprecatedNativeCpuKernelMod { - public: - CummaxCPUKernelMod() = default; - ~CummaxCPUKernelMod() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs) override { - return kernel_func_(this, inputs, outputs); - } - - protected: - std::vector GetOpSupport() override; - - private: - template - bool LaunchKernel(const std::vector &inputs, const std::vector &outputs); - using CummaxFunc = std::function &, - const std::vector &)>; - static std::vector> func_list_; - CummaxFunc kernel_func_; - std::vector input_shape_; - std::vector output1_shape_; - std::vector output2_shape_; - size_t dim_; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMMAX_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cum_minmax_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cum_minmax_impl.cu new file mode 100644 index 00000000000..a2e8ce165c6 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cum_minmax_impl.cu @@ -0,0 +1,192 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cum_minmax_impl.cuh" +#include +#include +#include +#include +#include +#include "include/cuda_fp16.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" + +template +__device__ bool IsNan(const T &x) { + return isnan(x); +} + +__device__ bool IsNan(const half &x) { return __hisnan(x); } + +template +struct binary_op { + const T *input_ptr_; + size_t axis_inner_size_; + size_t axis_size_; + size_t inner_size_; + OP op; + + __thrust_exec_check_disable__ __device__ size_t operator()(const size_t &lhs, const size_t &rhs) const { + if (rhs % axis_size_) { + size_t batch_idx = rhs / axis_size_; + size_t axis_idx = rhs - batch_idx * axis_size_; + size_t outer_idx = batch_idx / inner_size_; + size_t inner_idx = batch_idx - outer_idx * inner_size_; + size_t fix_part = outer_idx * axis_inner_size_ + inner_idx; + size_t lhs_idx = fix_part + lhs * inner_size_; + size_t rhs_idx = fix_part + axis_idx * inner_size_; + return IsNan(input_ptr_[lhs_idx]) || op(input_ptr_[lhs_idx], input_ptr_[rhs_idx]) ? lhs : axis_idx; + } else { + return 0; + } + } +}; + +template +__global__ void DecodeKernel(const T *input_ptr, const size_t *workspace_ptr, T *value_ptr, S *index_ptr, + size_t element_size, size_t axis_inner_size, size_t axis_size, size_t inner_size) { + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < element_size; tid += blockDim.x * gridDim.x) { + size_t batch_idx = tid / axis_size; + size_t axis_idx = tid - batch_idx * axis_size; + size_t outer_idx = batch_idx / inner_size; + size_t inner_idx = batch_idx - outer_idx * inner_size; + size_t fix_part = outer_idx * axis_inner_size + inner_idx; + size_t real_idx = fix_part + axis_idx * inner_size; + size_t cum_idx = fix_part + workspace_ptr[tid] * inner_size; + value_ptr[real_idx] = input_ptr[cum_idx]; + index_ptr[real_idx] = static_cast(workspace_ptr[tid]); + } +} + +template +void CumMinMax(enum CumOpType op_type, const T *input_ptr, size_t *workspace_ptr, T *value_ptr, S *index_ptr, + size_t element_size, size_t axis_size, size_t inner_size, cudaStream_t cuda_stream) { + // Cummin/Cummax cuda algorithm: + // 1. Generate a sequence from 0 to element_size-1; + // 2. Using thrust:inclusive_scan to get the cumulative maximum/minimum result of transposed array. + // Note that 1. Segmentation of array is done within binary_op of inclusive_scan; + // 2. it's not necessary to directly transpose the original array, but using the mapping rule; + // 3. Restore the transposed array using DecodeKernel, and also with the help of mapping rule. + auto device = thrust::cuda::par.on(cuda_stream); + auto thrust_ptr = thrust::device_pointer_cast(workspace_ptr); + thrust::sequence(device, thrust_ptr, thrust_ptr + element_size); + auto axis_inner_size = axis_size * inner_size; + switch (op_type) { + case CUMMIN: { + binary_op> op{input_ptr, axis_inner_size, axis_size, inner_size}; + thrust::inclusive_scan(device, thrust_ptr, thrust_ptr + element_size, thrust_ptr, op); + break; + } + case CUMMAX: { + binary_op> op{input_ptr, axis_inner_size, axis_size, inner_size}; + thrust::inclusive_scan(device, thrust_ptr, thrust_ptr + element_size, thrust_ptr, op); + break; + } + default: + break; + } + + DecodeKernel<<>>( + input_ptr, workspace_ptr, value_ptr, index_ptr, element_size, axis_inner_size, axis_size, inner_size); +} + +template CUDA_LIB_EXPORT void CumMinMax(enum CumOpType op_type, const int8_t *input_ptr, + size_t *workspace_ptr, int8_t *value_ptr, int32_t *index_ptr, + size_t element_size, size_t axis_size, size_t inner_size, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CumMinMax(enum CumOpType op_type, const int16_t *input_ptr, + size_t *workspace_ptr, int16_t *value_ptr, int32_t *index_ptr, + size_t element_size, size_t axis_size, size_t inner_size, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CumMinMax(enum CumOpType op_type, const int32_t *input_ptr, + size_t *workspace_ptr, int32_t *value_ptr, int32_t *index_ptr, + size_t element_size, size_t axis_size, size_t inner_size, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CumMinMax(enum CumOpType op_type, const int64_t *input_ptr, + size_t *workspace_ptr, int64_t *value_ptr, int32_t *index_ptr, + size_t element_size, size_t axis_size, size_t inner_size, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CumMinMax(enum CumOpType op_type, const uint8_t *input_ptr, + size_t *workspace_ptr, uint8_t *value_ptr, int32_t *index_ptr, + size_t element_size, size_t axis_size, size_t inner_size, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CumMinMax(enum CumOpType op_type, const uint16_t *input_ptr, + size_t *workspace_ptr, uint16_t *value_ptr, + int32_t *index_ptr, size_t element_size, size_t axis_size, + size_t inner_size, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CumMinMax(enum CumOpType op_type, const uint32_t *input_ptr, + size_t *workspace_ptr, uint32_t *value_ptr, + int32_t *index_ptr, size_t element_size, size_t axis_size, + size_t inner_size, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CumMinMax(enum CumOpType op_type, const uint64_t *input_ptr, + size_t *workspace_ptr, uint64_t *value_ptr, + int32_t *index_ptr, size_t element_size, size_t axis_size, + size_t inner_size, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CumMinMax(enum CumOpType op_type, const half *input_ptr, + size_t *workspace_ptr, half *value_ptr, int32_t *index_ptr, + size_t element_size, size_t axis_size, size_t inner_size, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CumMinMax(enum CumOpType op_type, const float *input_ptr, + size_t *workspace_ptr, float *value_ptr, int32_t *index_ptr, + size_t element_size, size_t axis_size, size_t inner_size, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CumMinMax(enum CumOpType op_type, const double *input_ptr, + size_t *workspace_ptr, double *value_ptr, int32_t *index_ptr, + size_t element_size, size_t axis_size, size_t inner_size, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CumMinMax(enum CumOpType op_type, const int8_t *input_ptr, + size_t *workspace_ptr, int8_t *value_ptr, int64_t *index_ptr, + size_t element_size, size_t axis_size, size_t inner_size, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CumMinMax(enum CumOpType op_type, const int16_t *input_ptr, + size_t *workspace_ptr, int16_t *value_ptr, int64_t *index_ptr, + size_t element_size, size_t axis_size, size_t inner_size, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CumMinMax(enum CumOpType op_type, const int32_t *input_ptr, + size_t *workspace_ptr, int32_t *value_ptr, int64_t *index_ptr, + size_t element_size, size_t axis_size, size_t inner_size, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CumMinMax(enum CumOpType op_type, const int64_t *input_ptr, + size_t *workspace_ptr, int64_t *value_ptr, int64_t *index_ptr, + size_t element_size, size_t axis_size, size_t inner_size, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CumMinMax(enum CumOpType op_type, const uint8_t *input_ptr, + size_t *workspace_ptr, uint8_t *value_ptr, int64_t *index_ptr, + size_t element_size, size_t axis_size, size_t inner_size, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CumMinMax(enum CumOpType op_type, const uint16_t *input_ptr, + size_t *workspace_ptr, uint16_t *value_ptr, + int64_t *index_ptr, size_t element_size, size_t axis_size, + size_t inner_size, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CumMinMax(enum CumOpType op_type, const uint32_t *input_ptr, + size_t *workspace_ptr, uint32_t *value_ptr, + int64_t *index_ptr, size_t element_size, size_t axis_size, + size_t inner_size, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CumMinMax(enum CumOpType op_type, const uint64_t *input_ptr, + size_t *workspace_ptr, uint64_t *value_ptr, + int64_t *index_ptr, size_t element_size, size_t axis_size, + size_t inner_size, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CumMinMax(enum CumOpType op_type, const half *input_ptr, + size_t *workspace_ptr, half *value_ptr, int64_t *index_ptr, + size_t element_size, size_t axis_size, size_t inner_size, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CumMinMax(enum CumOpType op_type, const float *input_ptr, + size_t *workspace_ptr, float *value_ptr, int64_t *index_ptr, + size_t element_size, size_t axis_size, size_t inner_size, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CumMinMax(enum CumOpType op_type, const double *input_ptr, + size_t *workspace_ptr, double *value_ptr, int64_t *index_ptr, + size_t element_size, size_t axis_size, size_t inner_size, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cum_op_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cum_minmax_impl.cuh similarity index 79% rename from mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cum_op_impl.cuh rename to mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cum_minmax_impl.cuh index 4189546a341..48dea925d10 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cum_op_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cum_minmax_impl.cuh @@ -22,8 +22,8 @@ enum CumOpType { CUMMIN = 0, CUMMAX, CUM_OP_INVALID_TYPE = 255 }; template -CUDA_LIB_EXPORT void CumOp(enum CumOpType op_type, const T *input_ptr, size_t *workspace_ptr, T *value_ptr, - S *index_ptr, size_t element_size, size_t axis_size, size_t inner_size, - cudaStream_t cuda_stream); +CUDA_LIB_EXPORT void CumMinMax(enum CumOpType op_type, const T *input_ptr, size_t *workspace_ptr, T *value_ptr, + S *index_ptr, size_t element_size, size_t axis_size, size_t inner_size, + cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CUM_OP_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cum_op_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cum_op_impl.cu deleted file mode 100644 index 132c2de20a6..00000000000 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/cum_op_impl.cu +++ /dev/null @@ -1,190 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "cum_op_impl.cuh" -#include -#include -#include -#include -#include -#include "include/cuda_fp16.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" - -template -__device__ bool IsNan(const T &x) { - return isnan(x); -} - -__device__ bool IsNan(const half &x) { return __hisnan(x); } - -template -struct binary_op { - const T *input_ptr_; - size_t axis_inner_size_; - size_t axis_size_; - size_t inner_size_; - OP op; - - __thrust_exec_check_disable__ __device__ size_t operator()(const size_t &lhs, const size_t &rhs) const { - if (rhs % axis_size_) { - size_t batch_idx = rhs / axis_size_; - size_t axis_idx = rhs - batch_idx * axis_size_; - size_t outer_idx = batch_idx / inner_size_; - size_t inner_idx = batch_idx - outer_idx * inner_size_; - size_t fix_part = outer_idx * axis_inner_size_ + inner_idx; - size_t lhs_idx = fix_part + lhs * inner_size_; - size_t rhs_idx = fix_part + axis_idx * inner_size_; - return IsNan(input_ptr_[lhs_idx]) || op(input_ptr_[lhs_idx], input_ptr_[rhs_idx]) ? lhs : axis_idx; - } else { - return 0; - } - } -}; - -template -__global__ void DecodeKernel(const T *input_ptr, const size_t *workspace_ptr, T *value_ptr, S *index_ptr, - size_t element_size, size_t axis_inner_size, size_t axis_size, size_t inner_size) { - for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < element_size; tid += blockDim.x * gridDim.x) { - size_t batch_idx = tid / axis_size; - size_t axis_idx = tid - batch_idx * axis_size; - size_t outer_idx = batch_idx / inner_size; - size_t inner_idx = batch_idx - outer_idx * inner_size; - size_t fix_part = outer_idx * axis_inner_size + inner_idx; - size_t real_idx = fix_part + axis_idx * inner_size; - size_t cum_idx = fix_part + workspace_ptr[tid] * inner_size; - value_ptr[real_idx] = input_ptr[cum_idx]; - index_ptr[real_idx] = workspace_ptr[tid]; - } -} - -template -void CumOp(enum CumOpType op_type, const T *input_ptr, size_t *workspace_ptr, T *value_ptr, S *index_ptr, - size_t element_size, size_t axis_size, size_t inner_size, cudaStream_t cuda_stream) { - // Cummin/Cummax cuda algorithm: - // 1. Generate a sequence from 0 to element_size-1; - // 2. Using thrust:inclusive_scan to get the cumulative maximum/minimum result of transposed array. - // Note that 1. Segmentation of array is done within binary_op of inclusive_scan; - // 2. it's not necessary to directly transpose the original array, but using the mapping rule; - // 3. Restore the transposed array using DecodeKernel, and also with the help of mapping rule. - auto device = thrust::cuda::par.on(cuda_stream); - auto thrust_ptr = thrust::device_pointer_cast(workspace_ptr); - thrust::sequence(device, thrust_ptr, thrust_ptr + element_size); - auto axis_inner_size = axis_size * inner_size; - switch (op_type) { - case CUMMIN: { - binary_op> op{input_ptr, axis_inner_size, axis_size, inner_size}; - thrust::inclusive_scan(device, thrust_ptr, thrust_ptr + element_size, thrust_ptr, op); - break; - } - case CUMMAX: { - binary_op> op{input_ptr, axis_inner_size, axis_size, inner_size}; - thrust::inclusive_scan(device, thrust_ptr, thrust_ptr + element_size, thrust_ptr, op); - break; - } - default: - break; - } - - DecodeKernel<<>>( - input_ptr, workspace_ptr, value_ptr, index_ptr, element_size, axis_inner_size, axis_size, inner_size); -} - -template CUDA_LIB_EXPORT void CumOp(enum CumOpType op_type, const int8_t *input_ptr, - size_t *workspace_ptr, int8_t *value_ptr, int32_t *index_ptr, - size_t element_size, size_t axis_size, size_t inner_size, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CumOp(enum CumOpType op_type, const int16_t *input_ptr, - size_t *workspace_ptr, int16_t *value_ptr, int32_t *index_ptr, - size_t element_size, size_t axis_size, size_t inner_size, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CumOp(enum CumOpType op_type, const int32_t *input_ptr, - size_t *workspace_ptr, int32_t *value_ptr, int32_t *index_ptr, - size_t element_size, size_t axis_size, size_t inner_size, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CumOp(enum CumOpType op_type, const int64_t *input_ptr, - size_t *workspace_ptr, int64_t *value_ptr, int32_t *index_ptr, - size_t element_size, size_t axis_size, size_t inner_size, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CumOp(enum CumOpType op_type, const uint8_t *input_ptr, - size_t *workspace_ptr, uint8_t *value_ptr, int32_t *index_ptr, - size_t element_size, size_t axis_size, size_t inner_size, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CumOp(enum CumOpType op_type, const uint16_t *input_ptr, - size_t *workspace_ptr, uint16_t *value_ptr, int32_t *index_ptr, - size_t element_size, size_t axis_size, size_t inner_size, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CumOp(enum CumOpType op_type, const uint32_t *input_ptr, - size_t *workspace_ptr, uint32_t *value_ptr, int32_t *index_ptr, - size_t element_size, size_t axis_size, size_t inner_size, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CumOp(enum CumOpType op_type, const uint64_t *input_ptr, - size_t *workspace_ptr, uint64_t *value_ptr, int32_t *index_ptr, - size_t element_size, size_t axis_size, size_t inner_size, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CumOp(enum CumOpType op_type, const half *input_ptr, size_t *workspace_ptr, - half *value_ptr, int32_t *index_ptr, size_t element_size, - size_t axis_size, size_t inner_size, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CumOp(enum CumOpType op_type, const float *input_ptr, - size_t *workspace_ptr, float *value_ptr, int32_t *index_ptr, - size_t element_size, size_t axis_size, size_t inner_size, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CumOp(enum CumOpType op_type, const double *input_ptr, - size_t *workspace_ptr, double *value_ptr, int32_t *index_ptr, - size_t element_size, size_t axis_size, size_t inner_size, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CumOp(enum CumOpType op_type, const int8_t *input_ptr, - size_t *workspace_ptr, int8_t *value_ptr, int64_t *index_ptr, - size_t element_size, size_t axis_size, size_t inner_size, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CumOp(enum CumOpType op_type, const int16_t *input_ptr, - size_t *workspace_ptr, int16_t *value_ptr, int64_t *index_ptr, - size_t element_size, size_t axis_size, size_t inner_size, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CumOp(enum CumOpType op_type, const int32_t *input_ptr, - size_t *workspace_ptr, int32_t *value_ptr, int64_t *index_ptr, - size_t element_size, size_t axis_size, size_t inner_size, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CumOp(enum CumOpType op_type, const int64_t *input_ptr, - size_t *workspace_ptr, int64_t *value_ptr, int64_t *index_ptr, - size_t element_size, size_t axis_size, size_t inner_size, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CumOp(enum CumOpType op_type, const uint8_t *input_ptr, - size_t *workspace_ptr, uint8_t *value_ptr, int64_t *index_ptr, - size_t element_size, size_t axis_size, size_t inner_size, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CumOp(enum CumOpType op_type, const uint16_t *input_ptr, - size_t *workspace_ptr, uint16_t *value_ptr, int64_t *index_ptr, - size_t element_size, size_t axis_size, size_t inner_size, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CumOp(enum CumOpType op_type, const uint32_t *input_ptr, - size_t *workspace_ptr, uint32_t *value_ptr, int64_t *index_ptr, - size_t element_size, size_t axis_size, size_t inner_size, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CumOp(enum CumOpType op_type, const uint64_t *input_ptr, - size_t *workspace_ptr, uint64_t *value_ptr, int64_t *index_ptr, - size_t element_size, size_t axis_size, size_t inner_size, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CumOp(enum CumOpType op_type, const half *input_ptr, size_t *workspace_ptr, - half *value_ptr, int64_t *index_ptr, size_t element_size, - size_t axis_size, size_t inner_size, cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CumOp(enum CumOpType op_type, const float *input_ptr, - size_t *workspace_ptr, float *value_ptr, int64_t *index_ptr, - size_t element_size, size_t axis_size, size_t inner_size, - cudaStream_t cuda_stream); -template CUDA_LIB_EXPORT void CumOp(enum CumOpType op_type, const double *input_ptr, - size_t *workspace_ptr, double *value_ptr, int64_t *index_ptr, - size_t element_size, size_t axis_size, size_t inner_size, - cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/cum_minmax_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/math/cum_minmax_gpu_kernel.cc new file mode 100644 index 00000000000..7e23dc4fbe7 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/cum_minmax_gpu_kernel.cc @@ -0,0 +1,188 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/math/cum_minmax_gpu_kernel.h" +#include +#include +#include "mindspore/core/abstract/utils.h" +#include "kernel/common_utils.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr int kCumInputsNum = 1; +constexpr int kCumOutputsNum = 2; +constexpr char AXIS[] = "axis"; + +static const std::map kCumOpTypeMap = { + {"Cummin", CUMMIN}, + {"Cummax", CUMMAX}, +}; +} // namespace + +void CumMinMaxGpuKernelMod::ResetResource() noexcept { + inner_size_ = 1; + outer_size_ = 1; + axis_size_ = 1; + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); +} + +bool CumMinMaxGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) { + kernel_name_ = base_operator->GetPrim()->name(); + if (kernel_name_ != kernel_type_) { + MS_LOG(EXCEPTION) << "Need to be " << kernel_type_ << ", but got kernel name as " << kernel_name_; + } + + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCumInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCumOutputsNum, kernel_name_); + + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; + } + + auto iter = kCumOpTypeMap.find(kernel_name_); + if (iter == kCumOpTypeMap.end()) { + MS_LOG(EXCEPTION) << "Only support these cum operators: " << Map2Str(kCumOpTypeMap) << " currently, but got " + << kernel_name_; + } + cum_op_type_ = iter->second; + t_size_ = abstract::TypeIdSize(kernel_attr.GetOutputAttr(kIndex0).first); + s_size_ = abstract::TypeIdSize(kernel_attr.GetOutputAttr(kIndex1).first); + kernel_func_ = func_list_[kernel_type_][index].second; + return true; +} + +bool CumMinMaxGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, + const std::map &others) { + ResetResource(); + std::vector input_shape = inputs[kIndex0]->GetShapeVector(); + auto rank = SizeToLong(input_shape.size()); + auto axis_input = GetValue(base_operator->GetAttr(AXIS)); + auto axis = axis_input < 0 ? LongToSize(axis_input + rank) : LongToSize(axis_input); + for (size_t i = 0; i < input_shape.size(); i++) { + if (i < axis) { + outer_size_ *= input_shape.at(i); + } else if (i > axis) { + inner_size_ *= input_shape.at(i); + } else { + axis_size_ = input_shape.at(i); + } + } + + element_size_ = outer_size_ * inner_size_ * axis_size_; + if (!element_size_) { + return true; + } + + input_size_list_.push_back(element_size_ * t_size_); + output_size_list_.push_back(element_size_ * t_size_); + output_size_list_.push_back(element_size_ * s_size_); + workspace_size_list_.push_back(element_size_ * sizeof(size_t)); + return true; +} + +template +bool CumMinMaxGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + if (!element_size_) { + return true; + } + auto cuda_stream = reinterpret_cast(stream_ptr); + auto input_ptr = reinterpret_cast(inputs.at(kIndex0)->addr); + auto value_ptr = reinterpret_cast(outputs.at(kIndex0)->addr); + auto index_ptr = reinterpret_cast(outputs.at(kIndex1)->addr); + auto workspace_ptr = reinterpret_cast(workspace.at(kIndex0)->addr); + + CumMinMax(cum_op_type_, input_ptr, workspace_ptr, value_ptr, index_ptr, element_size_, axis_size_, inner_size_, + cuda_stream); + return true; +} + +// Note that in definition of primitive, Cummin return int32 as indices and Cummax return int64 as indices. (see +// cummax.cc and cummin.cc). +std::map>> + CumMinMaxGpuKernelMod::func_list_ = { + {kCummin, + {{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32), + &CumMinMaxGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), + &CumMinMaxGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + &CumMinMaxGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), + &CumMinMaxGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32), + &CumMinMaxGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32), + &CumMinMaxGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32), + &CumMinMaxGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32), + &CumMinMaxGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), + &CumMinMaxGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), + &CumMinMaxGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32), + &CumMinMaxGpuKernelMod::LaunchKernel}}}, + {kCummax, + {{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), + &CumMinMaxGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64), + &CumMinMaxGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), + &CumMinMaxGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + &CumMinMaxGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), + &CumMinMaxGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64), + &CumMinMaxGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), + &CumMinMaxGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64), + &CumMinMaxGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64), + &CumMinMaxGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64), + &CumMinMaxGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64), + &CumMinMaxGpuKernelMod::LaunchKernel}}}}; + +std::vector CumMinMaxGpuKernelMod::GetOpSupport() { + auto iter = func_list_.find(kernel_type_); + if (iter == func_list_.end()) { + MS_LOG(EXCEPTION) << "Cum_minmax cpu does not support " << kernel_type_; + } + + std::vector support_list; + (void)std::transform( + iter->second.begin(), iter->second.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG_WITH_NAME_PARAM(NativeGpuKernelMod, Cummin, CumMinMaxGpuKernelMod); +MS_KERNEL_FACTORY_REG_WITH_NAME_PARAM(NativeGpuKernelMod, Cummax, CumMinMaxGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/cum_op_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/math/cum_minmax_gpu_kernel.h similarity index 69% rename from mindspore/ccsrc/plugin/device/gpu/kernel/math/cum_op_gpu_kernel.h rename to mindspore/ccsrc/plugin/device/gpu/kernel/math/cum_minmax_gpu_kernel.h index 0165290df62..6214ab895dc 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/cum_op_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/cum_minmax_gpu_kernel.h @@ -14,26 +14,28 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_CUM_OP_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_CUM_OP_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_CUM_MINMAX_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_CUM_MINMAX_GPU_KERNEL_H_ #include #include #include #include #include -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cum_op_impl.cuh" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cum_minmax_impl.cuh" #include "plugin/device/gpu/kernel/gpu_kernel.h" #include "plugin/factory/ms_factory.h" namespace mindspore { namespace kernel { +constexpr auto kCummin = "Cummin"; +constexpr auto kCummax = "Cummax"; constexpr auto kUnKnown = "UnKnown"; -class CumOpGpuKernelMod : public NativeGpuKernelMod { +class CumMinMaxGpuKernelMod : public NativeGpuKernelMod { public: - CumOpGpuKernelMod() = default; - explicit CumOpGpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {} - ~CumOpGpuKernelMod() override = default; + CumMinMaxGpuKernelMod() = default; + explicit CumMinMaxGpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {} + ~CumMinMaxGpuKernelMod() override = default; bool Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, const std::vector &outputs) override; @@ -55,11 +57,12 @@ class CumOpGpuKernelMod : public NativeGpuKernelMod { bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs, void *stream_ptr); - using CumOpLaunchFunc = std::function &, - const std::vector &, const std::vector &, void *)>; - static std::vector> func_list_; + using CumMinMaxLaunchFunc = + std::function &, const std::vector &, + const std::vector &, void *)>; + static std::map>> func_list_; CumOpType cum_op_type_; - CumOpLaunchFunc kernel_func_; + CumMinMaxLaunchFunc kernel_func_; size_t t_size_{0}; // Equal to sizeof(T). size_t s_size_{0}; // Equal to sizeof(S). size_t inner_size_{1}; @@ -71,4 +74,4 @@ class CumOpGpuKernelMod : public NativeGpuKernelMod { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_CUM_OP_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_CUM_MINMAX_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/cum_op_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/math/cum_op_gpu_kernel.cc deleted file mode 100644 index fee0ef2a0b3..00000000000 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/math/cum_op_gpu_kernel.cc +++ /dev/null @@ -1,175 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "plugin/device/gpu/kernel/math/cum_op_gpu_kernel.h" -#include -#include -#include "mindspore/core/abstract/utils.h" -#include "kernel/common_utils.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr int kCumInputsNum = 1; -constexpr int kCumOutputsNum = 2; -constexpr char AXIS[] = "axis"; - -static const std::map kCumOpTypeMap = { - {"Cummin", CUMMIN}, -}; -} // namespace - -void CumOpGpuKernelMod::ResetResource() noexcept { - inner_size_ = 1; - outer_size_ = 1; - axis_size_ = 1; - input_size_list_.clear(); - output_size_list_.clear(); - workspace_size_list_.clear(); -} - -bool CumOpGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, - const std::vector &outputs) { - kernel_name_ = base_operator->GetPrim()->name(); - if (kernel_name_ != kernel_type_) { - MS_LOG(EXCEPTION) << "Need to be " << kernel_type_ << ", but got kernel name as " << kernel_name_; - } - - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCumInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCumOutputsNum, kernel_name_); - - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; - } - - auto iter = kCumOpTypeMap.find(kernel_name_); - if (iter == kCumOpTypeMap.end()) { - MS_LOG(EXCEPTION) << "Only support these cum operators: " << Map2Str(kCumOpTypeMap) << " currently, but got " - << kernel_name_; - } - cum_op_type_ = iter->second; - t_size_ = abstract::TypeIdSize(kernel_attr.GetOutputAttr(kIndex0).first); - s_size_ = abstract::TypeIdSize(kernel_attr.GetOutputAttr(kIndex1).first); - kernel_func_ = func_list_[index].second; - return true; -} - -bool CumOpGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, - const std::vector &outputs, - const std::map &others) { - ResetResource(); - std::vector input_shape = inputs[kIndex0]->GetShapeVector(); - auto rank = SizeToLong(input_shape.size()); - auto axis_input = GetValue(base_operator->GetAttr(AXIS)); - auto axis = axis_input < 0 ? LongToSize(axis_input + rank) : LongToSize(axis_input); - for (size_t i = 0; i < input_shape.size(); i++) { - if (i < axis) { - outer_size_ *= input_shape.at(i); - } else if (i > axis) { - inner_size_ *= input_shape.at(i); - } else { - axis_size_ = input_shape.at(i); - } - } - - element_size_ = outer_size_ * inner_size_ * axis_size_; - if (!element_size_) { - return true; - } - - input_size_list_.push_back(element_size_ * t_size_); - output_size_list_.push_back(element_size_ * t_size_); - output_size_list_.push_back(element_size_ * s_size_); - workspace_size_list_.push_back(element_size_ * sizeof(size_t)); - return true; -} - -template -bool CumOpGpuKernelMod::LaunchKernel(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - if (!element_size_) { - return true; - } - auto cuda_stream = reinterpret_cast(stream_ptr); - auto input_ptr = reinterpret_cast(inputs.at(kIndex0)->addr); - auto value_ptr = reinterpret_cast(outputs.at(kIndex0)->addr); - auto index_ptr = reinterpret_cast(outputs.at(kIndex1)->addr); - auto workspace_ptr = reinterpret_cast(workspace.at(kIndex0)->addr); - - CumOp(cum_op_type_, input_ptr, workspace_ptr, value_ptr, index_ptr, element_size_, axis_size_, inner_size_, - cuda_stream); - return true; -} - -std::vector> CumOpGpuKernelMod::func_list_ = { - {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32), - &CumOpGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), - &CumOpGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - &CumOpGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), - &CumOpGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32), - &CumOpGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32), - &CumOpGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32), - &CumOpGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32), - &CumOpGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), - &CumOpGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), - &CumOpGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32), - &CumOpGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), - &CumOpGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64), - &CumOpGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), - &CumOpGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - &CumOpGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), - &CumOpGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64), - &CumOpGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), - &CumOpGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64), - &CumOpGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64), - &CumOpGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64), - &CumOpGpuKernelMod::LaunchKernel}, - {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64), - &CumOpGpuKernelMod::LaunchKernel}}; - -std::vector CumOpGpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform( - func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; -} - -MS_KERNEL_FACTORY_REG_WITH_NAME_PARAM(NativeGpuKernelMod, Cummin, CumOpGpuKernelMod); -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/core/ops/cummax.cc b/mindspore/core/ops/cummax.cc index d49e514038b..c35f84e4c7b 100644 --- a/mindspore/core/ops/cummax.cc +++ b/mindspore/core/ops/cummax.cc @@ -26,19 +26,20 @@ namespace mindspore { namespace ops { namespace { +constexpr char AXIS[] = "axis"; abstract::TupleShapePtr CummaxInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { auto x_shape = input_args[0]->BuildShape(); auto x_shape_value = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape)[kShape]; - auto dim = GetValue(primitive->GetAttr("dim")); + auto axis = GetValue(primitive->GetAttr(AXIS)); if (x_shape_value.size() <= 0) { - MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', inputs dim should be greater than 0, but got " + MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', inputs 'axis' should be greater than 0, but got " << x_shape_value.size() << "."; } - if (dim >= static_cast(x_shape_value.size()) || dim < -static_cast(x_shape_value.size())) { - MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "',The value of 'dim' should be in the range of [" + if (axis >= static_cast(x_shape_value.size()) || axis < -static_cast(x_shape_value.size())) { + MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "',The value of 'axis' should be in the range of [" << -static_cast(x_shape_value.size()) << "," - << static_cast(x_shape_value.size()) << "], but got dim:" << dim << "."; + << static_cast(x_shape_value.size()) << "], but got axis:" << axis << "."; } return std::make_shared(std::vector{x_shape, x_shape}); } diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/cummax.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/cummax.py index 4f6fb6ca155..0eaada5ddb2 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/cummax.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/cummax.py @@ -20,7 +20,7 @@ cummax_op_info = AiCPURegOp("Cummax") \ .input(0, "x", "required") \ .output(0, "y", "required") \ .output(1, "indices", "required") \ - .attr("dim", "int") \ + .attr("axis", "int") \ .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I64_Default) \ .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \ .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ diff --git a/mindspore/python/mindspore/ops/operations/_inner_ops.py b/mindspore/python/mindspore/ops/operations/_inner_ops.py index a6a3007c51f..a839603b617 100755 --- a/mindspore/python/mindspore/ops/operations/_inner_ops.py +++ b/mindspore/python/mindspore/ops/operations/_inner_ops.py @@ -1510,7 +1510,7 @@ class Cummin(Primitive): ValueError:If 'axis' is out the range from -len(`input_x`.shape) to len(`input_x`.shape) - 1 Supported Platforms: - ``Ascend`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> from mindspore import Tensor, ops diff --git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py index 9b76a67d44d..b74c2b99cf0 100755 --- a/mindspore/python/mindspore/ops/operations/array_ops.py +++ b/mindspore/python/mindspore/ops/operations/array_ops.py @@ -7530,8 +7530,7 @@ class Cummax(Primitive): y_i = max(x_1 , x_2 , x_3 ,... ,x_i) Args: - dim (int): The dim to accumulate the tensor's value. Must be in the range [-rank(input), rank(input)). - The default value is -1. + axis (int): The axis to accumulate the tensor's value. Must be in the range [-rank(input), rank(input)). Inputs: - **input** (Tensor) - The input tensor whose dtype is int8, int32, int64, uint8, uint32, float16, float32. @@ -7542,18 +7541,18 @@ class Cummax(Primitive): Raises: TypeError: If `input` is not a Tensor. - TypeError: If `dim` is not an int. - ValueError: If `dim` is out of range, `dim` should be [-len(input.shape), len(input.shape)-1]. + TypeError: If `axis` is not an int. + ValueError: If `axis` is out of range, `axis` should be [-len(input.shape), len(input.shape)-1]. Supported Platforms: - ``CPU`` + ``GPU`` ``CPU`` Examples: >>> import mindspore >>> import numpy as np >>> from mindspore import Tensor >>> import mindspore.ops as ops - >>> cummax = ops.Cummax(dim=0) + >>> cummax = ops.Cummax(axis=0) >>> x = Tensor(np.array([[3, 4, 6, 10], [1, 6, 7, 9], [4, 3, 8, 7], [1, 3, 7, 9]]).astype(np.float32)) >>> output = cummax(x) >>> print(output) @@ -7570,9 +7569,9 @@ class Cummax(Primitive): """ @prim_attr_register - def __init__(self, dim=-1): + def __init__(self, axis): """Initialize Cummax""" - validator.check_value_type("dim", dim, [int], self.name) + validator.check_value_type("axis", axis, [int], self.name) self.init_prim_io_names(inputs=['x'], outputs=['y', 'indices']) diff --git a/tests/st/ops/cpu/test_cum_min_max_op.py b/tests/st/ops/cpu/test_cum_minmax_op.py similarity index 54% rename from tests/st/ops/cpu/test_cum_min_max_op.py rename to tests/st/ops/cpu/test_cum_minmax_op.py index a213a4fa139..c8087a36449 100644 --- a/tests/st/ops/cpu/test_cum_min_max_op.py +++ b/tests/st/ops/cpu/test_cum_minmax_op.py @@ -16,23 +16,40 @@ import pytest import numpy as np import mindspore.context as context +import mindspore.nn as nn import mindspore.ops as ops from mindspore import Tensor +from mindspore.ops.operations import _inner_ops as inner -def cummin_compare(x, expected, axis, data_type): +class Net(nn.Cell): + def __init__(self, op, axis): + super(Net, self).__init__() + if op == "Cummin": + self.op = inner.Cummin(axis) + elif op == "Cummax": + self.op = ops.Cummax(axis) + else: + raise ValueError("op value error.") + + def construct(self, x): + return self.op(x) + + +def cum_minmax_compare(op, x, expected, axis, data_type): + net = Net(op, axis) x = np.array(x).astype(data_type) expected = (np.array(expected[0]).astype(data_type), np.array(expected[1]).astype(data_type)) # Pynative context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU") - output = ops.cummin(Tensor(x), axis=axis) + output = net(Tensor(x)) assert np.allclose(output[0].asnumpy(), expected[0], equal_nan=True) assert np.allclose(output[1].asnumpy(), expected[1]) # Graph context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - output = ops.cummin(Tensor(x), axis=axis) + output = net(Tensor(x)) assert np.allclose(output[0].asnumpy(), expected[0], equal_nan=True) assert np.allclose(output[1].asnumpy(), expected[1]) @@ -41,12 +58,13 @@ def cummin_compare(x, expected, axis, data_type): @pytest.mark.env_onecard @pytest.mark.platform_x86_cpu @pytest.mark.parametrize("data_type", [np.uint8, np.int8, np.int32, np.float16, np.float32]) -def test_cumop_multi_dims(data_type): +def test_cummin_multi_dims(data_type): """ Feature: Op Cummin Description: test Cummin operator with multiple dimension. Expectation: the result match expectation. """ + op = "Cummin" axis = 1 x = [[[9, 10, 0, 0, 2], [5, 4, 1, 9, 3], [5, 0, 3, 7, 5], [10, 4, 5, 4, 9]], [[5, 0, 8, 8, 10], [9, 0, 1, 5, 2], [9, 5, 8, 9, 7], [10, 9, 2, 2, 2]], @@ -58,17 +76,42 @@ def test_cumop_multi_dims(data_type): [[0, 0, 0, 0, 0], [0, 1, 1, 1, 1], [0, 1, 1, 1, 1], [0, 1, 1, 3, 3]], [[0, 0, 0, 0, 0], [0, 0, 1, 1, 0], [2, 0, 2, 1, 0], [3, 0, 2, 1, 3]]]) - cummin_compare(x, cummin_output, axis, data_type) + cum_minmax_compare(op, x, cummin_output, axis, data_type) + + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.platform_x86_cpu +@pytest.mark.parametrize("data_type", [np.uint8, np.uint32, np.int8, np.int32, np.int64, np.float16, np.float32]) +def test_cummax_multi_dims(data_type): + """ + Feature: Op Cummax + Description: test Cummax operator with multiple dimension. + Expectation: the result match expectation. + """ + op = "Cummax" + axis = 1 + x = [[[6, 11, 4, 9, 15], [1, 2, 14, 13, 15], [15, 10, 6, 13, 6], [9, 4, 11, 10, 11]], + [[5, 1, 5, 13, 7], [19, 4, 14, 11, 14], [5, 15, 6, 20, 0], [6, 2, 4, 15, 16]], + [[17, 4, 16, 13, 3], [15, 15, 14, 9, 13], [11, 0, 2, 19, 17], [20, 18, 13, 15, 17]]] + cummax_output = ([[[6, 11, 4, 9, 15], [6, 11, 14, 13, 15], [15, 11, 14, 13, 15], [15, 11, 14, 13, 15]], + [[5, 1, 5, 13, 7], [19, 4, 14, 13, 14], [19, 15, 14, 20, 14], [19, 15, 14, 20, 16]], + [[17, 4, 16, 13, 3], [17, 15, 16, 13, 13], [17, 15, 16, 19, 17], [20, 18, 16, 19, 17]]], + [[[0, 0, 0, 0, 0], [0, 0, 1, 1, 1], [2, 0, 1, 2, 1], [2, 0, 1, 2, 1]], + [[0, 0, 0, 0, 0], [1, 1, 1, 0, 1], [1, 2, 1, 2, 1], [1, 2, 1, 2, 3]], + [[0, 0, 0, 0, 0], [0, 1, 0, 0, 1], [0, 1, 0, 2, 2], [3, 3, 0, 2, 3]]]) + + cum_minmax_compare(op, x, cummax_output, axis, data_type) @pytest.mark.level0 @pytest.mark.env_onecard @pytest.mark.platform_x86_cpu @pytest.mark.parametrize("data_type", [np.float16, np.float32]) -def test_cumop_nan(data_type): +def test_cum_minmax_nan(data_type): """ - Feature: Op Cummin - Description: test Cummin operator with nan input. + Feature: Op Cummin/Cummax + Description: test Cummin/Cummax operator with nan input. Expectation: the result match expectation. """ inf = float('inf') @@ -76,5 +119,7 @@ def test_cumop_nan(data_type): axis = 0 x = [4, inf, 1.5, -inf, 0, nan, 1] cummin_output = ([4, 4, 1.5, -inf, -inf, nan, nan], [0, 0, 2, 3, 3, 5, 5]) + cummax_output = ([4, inf, inf, inf, inf, nan, nan], [0, 1, 1, 1, 1, 5, 5]) - cummin_compare(x, cummin_output, axis, data_type) + cum_minmax_compare("Cummin", x, cummin_output, axis, data_type) + cum_minmax_compare("Cummax", x, cummax_output, axis, data_type) diff --git a/tests/st/ops/gpu/test_cum_min_max_op.py b/tests/st/ops/gpu/test_cum_minmax_op.py similarity index 53% rename from tests/st/ops/gpu/test_cum_min_max_op.py rename to tests/st/ops/gpu/test_cum_minmax_op.py index 944320f4590..cd0fb53b6b2 100644 --- a/tests/st/ops/gpu/test_cum_min_max_op.py +++ b/tests/st/ops/gpu/test_cum_minmax_op.py @@ -15,24 +15,41 @@ import pytest import numpy as np -import mindspore.context as context +import mindspore.nn as nn import mindspore.ops as ops +import mindspore.context as context from mindspore import Tensor +from mindspore.ops.operations import _inner_ops as inner -def cummin_compare(x, expected, axis, data_type): +class Net(nn.Cell): + def __init__(self, op, axis): + super(Net, self).__init__() + if op == "Cummin": + self.op = inner.Cummin(axis) + elif op == "Cummax": + self.op = ops.Cummax(axis) + else: + raise ValueError("op value error.") + + def construct(self, x): + return self.op(x) + + +def cum_minmax_compare(op, x, expected, axis, data_type): + net = Net(op, axis) x = np.array(x).astype(data_type) expected = (np.array(expected[0]).astype(data_type), np.array(expected[1]).astype(data_type)) # Pynative context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") - output = ops.cummin(Tensor(x), axis=axis) + output = net(Tensor(x)) assert np.allclose(output[0].asnumpy(), expected[0], equal_nan=True) assert np.allclose(output[1].asnumpy(), expected[1]) # Graph context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - output = ops.cummin(Tensor(x), axis=axis) + output = net(Tensor(x)) assert np.allclose(output[0].asnumpy(), expected[0], equal_nan=True) assert np.allclose(output[1].asnumpy(), expected[1]) @@ -40,13 +57,14 @@ def cummin_compare(x, expected, axis, data_type): @pytest.mark.level0 @pytest.mark.env_onecard @pytest.mark.platform_x86_gpu_training -@pytest.mark.parametrize("data_type", [np.int32, np.float16, np.float32]) +@pytest.mark.parametrize("data_type", [np.uint8, np.int8, np.int32, np.float16, np.float32]) def test_cummin_multi_dims(data_type): """ Feature: Op Cummin Description: test Cummin operator with multiple dimension. Expectation: the result match expectation. """ + op = "Cummin" axis = 1 x = [[[14, 19, 18, 11, 6], [1, 4, 18, 6, 1], [15, 13, 12, 9, 19]], [[16, 16, 17, 10, 15], [9, 7, 10, 9, 4], [6, 14, 16, 3, 2]], @@ -59,17 +77,45 @@ def test_cummin_multi_dims(data_type): [[[0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [1, 1, 2, 1, 1]], [[0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [2, 1, 1, 2, 2]], [[0, 0, 0, 0, 0], [0, 1, 1, 0, 0], [0, 1, 1, 0, 0]], [[0, 0, 0, 0, 0], [1, 0, 1, 0, 0], [2, 0, 1, 2, 0]]]) - cummin_compare(x, cummin_output, axis, data_type) + cum_minmax_compare(op, x, cummin_output, axis, data_type) + + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.platform_x86_gpu_training +@pytest.mark.parametrize("data_type", [np.uint8, np.uint32, np.int8, np.int32, np.int64, np.float16, np.float32]) +def test_cummax_multi_dims(data_type): + """ + Feature: Op Cummax + Description: test Cummax operator with multiple dimension. + Expectation: the result match expectation. + """ + op = "Cummax" + axis = 1 + x = [[[11, 11, 1, 7, 11], [1, 8, 18, 0, 9], [12, 1, 16, 11, 8]], + [[18, 8, 10, 17, 14], [4, 20, 8, 20, 11], [14, 1, 8, 5, 16]], + [[6, 13, 19, 14, 8], [17, 19, 11, 0, 7], [18, 4, 13, 14, 16]], + [[10, 7, 7, 7, 19], [15, 0, 15, 5, 14], [9, 7, 10, 4, 14]]] + cummax_output = ([[[11, 11, 1, 7, 11], [11, 11, 18, 7, 11], [12, 11, 18, 11, 11]], + [[18, 8, 10, 17, 14], [18, 20, 10, 20, 14], [18, 20, 10, 20, 16]], + [[6, 13, 19, 14, 8], [17, 19, 19, 14, 8], [18, 19, 19, 14, 16]], + [[10, 7, 7, 7, 19], [15, 7, 15, 7, 19], [15, 7, 15, 7, 19]]], + [[[0, 0, 0, 0, 0], [0, 0, 1, 0, 0], [2, 0, 1, 2, 0]], + [[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 1, 0, 1, 2]], + [[0, 0, 0, 0, 0], [1, 1, 0, 0, 0], [2, 1, 0, 2, 2]], + [[0, 0, 0, 0, 0], [1, 0, 1, 0, 0], [1, 2, 1, 0, 0]]]) + + cum_minmax_compare(op, x, cummax_output, axis, data_type) @pytest.mark.level0 @pytest.mark.env_onecard @pytest.mark.platform_x86_gpu_training @pytest.mark.parametrize("data_type", [np.float16, np.float32]) -def test_cummin_nan(data_type): +def test_cumminmax_nan(data_type): """ - Feature: Op Cummin - Description: test Cummin operator with nan input. + Feature: Op Cummin/Cummax + Description: test Cummin/Cummax operator with nan input. Expectation: the result match expectation. """ inf = float('inf') @@ -77,5 +123,7 @@ def test_cummin_nan(data_type): axis = 0 x = [4, inf, 1.5, -inf, 0, nan, 1] cummin_output = ([4, 4, 1.5, -inf, -inf, nan, nan], [0, 0, 2, 3, 3, 5, 5]) + cummax_output = ([4, inf, inf, inf, inf, nan, nan], [0, 1, 1, 1, 1, 5, 5]) - cummin_compare(x, cummin_output, axis, data_type) + cum_minmax_compare("Cummin", x, cummin_output, axis, data_type) + cum_minmax_compare("Cummax", x, cummax_output, axis, data_type) diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index a3cee81e44e..d663e212c1e 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -3017,7 +3017,7 @@ test_case_array_ops = [ 'skip': ['backward'], }), ('Cummax', { - 'block': P.Cummax(dim=-1), + 'block': P.Cummax(axis=0), 'desc_inputs': [Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])], 'skip': ['backward'], }),