diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/prelu_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/prelu_gpu_kernel.cc
index 75db0dbd95b..464e2c5f87f 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/prelu_gpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/prelu_gpu_kernel.cc
@@ -18,14 +18,70 @@
 namespace mindspore {
 namespace kernel {
-MS_REG_GPU_KERNEL_ONE(
-  PReLU,
-  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  PReLUGpuKernelMod, half)
+bool PReLUGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                             const std::vector<KernelTensorPtr> &outputs) {
+  constexpr size_t input_num = 2;
+  constexpr size_t output_num = 1;
+  CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_);
+  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num, kernel_name_);
+  kernel_name_ = base_operator->GetPrim()->name();
+  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
+  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
+  if (!is_match) {
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
+    return false;
+  }
+  kernel_func_ = func_list_[index].second;
+  return true;
+}
 
-MS_REG_GPU_KERNEL_ONE(
-  PReLU,
-  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  PReLUGpuKernelMod, float)
+int PReLUGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                              const std::vector<KernelTensorPtr> &outputs,
+                              const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
+  if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
+    return ret;
+  }
+  auto input_shape = LongVecToSizeVec(inputs[kIndex0]->GetShapeVector());
+  auto weight_shape = LongVecToSizeVec(inputs[kIndex1]->GetShapeVector());
+  is_null_input_ =
+    CHECK_SHAPE_NULL(input_shape, kernel_name_, "x") || CHECK_SHAPE_NULL(weight_shape, kernel_name_, "weight");
+  input_length_ = std::accumulate(input_shape.begin(), input_shape.end(), size_t(1), std::multiplies<>());
+  per_channel_length_ =
+    input_shape.size() <= 1 ? input_length_ : input_length_ / (input_shape[kIndex0] * input_shape[kIndex1]);
+  weight_length_ = weight_shape[0];
+  return KRET_OK;
+}
+
+template <typename T>
+bool PReLUGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                                     const std::vector<AddressPtr> &outputs, void *stream_ptr) {
+  if (is_null_input_) {
+    return true;
+  }
+  auto input = GetDeviceAddress<T>(inputs, 0);
+  auto weight = GetDeviceAddress<T>(inputs, 1);
+  auto output = GetDeviceAddress<T>(outputs, 0);
+
+  CalPReLU(input_length_, weight_length_, per_channel_length_, input, weight, output,
+           reinterpret_cast<cudaStream_t>(stream_ptr));
+  return true;
+}
+
+std::vector<std::pair<KernelAttr, PReLUGpuKernelMod::PReLULaunchFunc>> PReLUGpuKernelMod::func_list_ = {
+  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+   &PReLUGpuKernelMod::LaunchKernel<half>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+   &PReLUGpuKernelMod::LaunchKernel<float>},
+};
+
+std::vector<KernelAttr> PReLUGpuKernelMod::GetOpSupport() {
+  std::vector<KernelAttr> support_list;
+  (void)std::transform(
+    func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+    [](const std::pair<KernelAttr, PReLUGpuKernelMod::PReLULaunchFunc> &pair) { return pair.first; });
+  return support_list;
+}
+
+MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, PReLU, PReLUGpuKernelMod);
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/prelu_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/prelu_gpu_kernel.h
index 90259b2e5b1..8b53ffaf598 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/prelu_gpu_kernel.h
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/prelu_gpu_kernel.h
@@ -19,6 +19,7 @@
 #include <vector>
 #include <map>
+#include <functional>
 #include <utility>
 
 #include "plugin/device/gpu/kernel/gpu_kernel.h"
@@ -27,91 +28,34 @@
 namespace mindspore {
 namespace kernel {
-template <typename T>
-class PReLUGpuKernelMod : public DeprecatedNativeGpuKernelMod {
+class PReLUGpuKernelMod : public NativeGpuKernelMod {
  public:
   PReLUGpuKernelMod() = default;
   ~PReLUGpuKernelMod() override = default;
 
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
-    if (is_null_input_) {
-      return true;
-    }
-    auto *input = GetDeviceAddress<T>(inputs, 0);
-    auto *weight = GetDeviceAddress<T>(inputs, 1);
-    auto *output = GetDeviceAddress<T>(outputs, 0);
-
-    CalPReLU(input_length_, weight_length_, per_channel_length_, input, weight, output,
-             reinterpret_cast<cudaStream_t>(stream_ptr));
-    return true;
+    return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
   }
 
-  bool Init(const CNodePtr &kernel_node) override {
-    auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
-    kernel_node_ = kernel_node;
-    ResetResource();
-    size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
-    if (input_num != 2) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be 2, but got " << input_num;
-    }
-    size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
-    if (output_num != 1) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs must be 1, but got " << output_num;
-    }
+  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+            const std::vector<KernelTensorPtr> &outputs) override;
 
-    auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
-    auto weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
-    is_null_input_ =
-      CHECK_SHAPE_NULL(input_shape, kernel_name, "x") || CHECK_SHAPE_NULL(weight_shape, kernel_name, "weight");
-    if (is_null_input_) {
-      InitSizeLists();
-      return true;
-    }
-    input_length_ = std::accumulate(input_shape.begin(), input_shape.end(), size_t(1), std::multiplies<>());
-    size_t input_rank = input_shape.size();
-    int64_t channel_num;
-    if (input_rank == 0) {
-      channel_num = 1;
-      per_channel_length_ = 1;
-    } else if (input_rank == 1) {
-      channel_num = 1;
-      per_channel_length_ = input_shape[0];
-    } else {
-      channel_num = input_shape[1];
-      per_channel_length_ = std::accumulate(input_shape.begin() + 2, input_shape.end(), size_t(1), std::multiplies<>());
-    }
-
-    if (weight_shape.size() != 1 || (weight_shape[0] != 1 && weight_shape[0] != channel_num)) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of weight must be equal to 1 and "
-                        << "weight.shape[0] must be equal to 1 or the channel number, but got the dimension of "
-                        << "weight: " << weight_shape.size() << ", weight.shape[0]: " << weight_shape[0]
-                        << ", the channel num: " << channel_num;
-    }
-    weight_length_ = LongToSizeClipNeg(weight_shape[0]);
-    InitSizeLists();
-    return true;
-  }
-
-  void ResetResource() noexcept override {
-    input_length_ = 0;
-    weight_length_ = 0;
-    per_channel_length_ = 0;
-    is_null_input_ = false;
-    input_size_list_.clear();
-    output_size_list_.clear();
-    workspace_size_list_.clear();
-  }
+  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
 
  protected:
-  void InitSizeLists() override {
-    size_t data_size = sizeof(T);
-    input_size_list_.push_back(input_length_ * data_size);
-    input_size_list_.push_back(weight_length_ * data_size);
-    output_size_list_.push_back(input_length_ * data_size);
-  }
+  std::vector<KernelAttr> GetOpSupport() override;
 
  private:
+  template <typename T>
+  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                    const std::vector<AddressPtr> &outputs, void *stream_ptr);
+
+  using PReLULaunchFunc = std::function<bool(PReLUGpuKernelMod *, const std::vector<AddressPtr> &,
+                                             const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, void *)>;
+  static std::vector<std::pair<KernelAttr, PReLULaunchFunc>> func_list_;
+  PReLULaunchFunc kernel_func_;
   bool is_null_input_{false};
   size_t input_length_{0};
   size_t weight_length_{0};
diff --git a/mindspore/core/ops/prelu.cc b/mindspore/core/ops/prelu.cc
index 9aa3ae4e7cb..3117a078c9d 100644
--- a/mindspore/core/ops/prelu.cc
+++ b/mindspore/core/ops/prelu.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,10 +20,75 @@
 #include "ops/primitive_c.h"
 #include "utils/check_convert_utils.h"
 #include "mindapi/src/helper.h"
+#include "utils/ms_context.h"
+#include "ops/op_utils.h"
 
 namespace mindspore {
 namespace ops {
+bool IsAscend() {
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  return context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice;
+}
+
+abstract::ShapePtr PReLUInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  auto prim_name = primitive->name();
+  auto x_shape_ptr = input_args[kInputIndex0]->BuildShape();
+  auto weight_shape_ptr = input_args[kInputIndex1]->BuildShape();
+  // Dynamic rank.
+  if (x_shape_ptr->IsDimUnknown() || weight_shape_ptr->IsDimUnknown()) {
+    return x_shape_ptr->cast<abstract::ShapePtr>();
+  }
+
+  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr)[kShape];
+  auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(weight_shape_ptr)[kShape];
+  auto x_rank = x_shape.size();
+  auto weight_rank = weight_shape.size();
+  auto channel_num = x_rank <= 1 ? 1 : x_shape[1];
+  if (IsAscend() && x_rank <= 1) {
+    MS_EXCEPTION(ValueError)
+      << "For '" << prim_name
+      << "', the dimension of 'x' can not be 0-D or 1-D when the platform is \"Ascend\", but got dimension of 'x' is "
+      << x_rank << ".";
+  }
+  (void)CheckAndConvertUtils::CheckInteger("dimension of 'weight'", SizeToLong(weight_rank), kEqual, 1, prim_name);
+  if (weight_shape[0] != 1 && weight_shape[0] != channel_num) {
+    MS_EXCEPTION(ValueError)
+      << "For '" << prim_name
+      << "', the first dimension of 'weight' must be (1,) or it must be equal to number of channels: " << channel_num
+      << ", but got " << weight_shape << ".";
+  }
+  return x_shape_ptr->cast<abstract::ShapePtr>();
+}
+
+TypePtr PReLUInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  auto prim_name = primitive->name();
+  auto x_type = input_args[kInputIndex0]->BuildType();
+  auto weight_type = input_args[kInputIndex1]->BuildType();
+  auto valid_types = {kFloat16, kFloat32};
+
+  if (IsAscend()) {
+    (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim_name);
+    (void)CheckAndConvertUtils::CheckTensorTypeValid("weight", weight_type, valid_types, prim_name);
+  } else {
+    std::map<std::string, TypePtr> args;
+    (void)args.emplace("x", x_type);
+    (void)args.emplace("weight", weight_type);
+    (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
+  }
+  return x_type;
+}
+
 MIND_API_OPERATOR_IMPL(PReLU, BaseOperator);
-REGISTER_PRIMITIVE_C(kNamePReLU, PReLU);
+AbstractBasePtr PReLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                           const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  constexpr size_t input_num = 2;
+  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
+  auto type = PReLUInferType(primitive, input_args);
+  auto shape = PReLUInferShape(primitive, input_args);
+  return abstract::MakeAbstract(shape, type);
+}
+REGISTER_PRIMITIVE_EVAL_IMPL(PReLU, prim::kPrimPRelu, PReLUInfer, nullptr, true);
 }  // namespace ops
 }  // namespace mindspore
diff --git a/mindspore/core/ops/prelu.h b/mindspore/core/ops/prelu.h
index 2147a21bb3e..94d1bd5c7f1 100644
--- a/mindspore/core/ops/prelu.h
+++ b/mindspore/core/ops/prelu.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -32,8 +32,8 @@ class MIND_API PReLU : public BaseOperator {
  public:
   MIND_API_BASE_MEMBER(PReLU);
   /// \brief Constructor.
-  PReLU() : BaseOperator(kNamePReLU) { InitIOName({"x"}, {"y"}); }
-  explicit PReLU(const std::string k_name) : BaseOperator(k_name) { InitIOName({"x"}, {"y"}); }
+  PReLU() : BaseOperator(kNamePReLU) { InitIOName({"x", "weight"}, {"output"}); }
+  explicit PReLU(const std::string k_name) : BaseOperator(k_name) { InitIOName({"x", "weight"}, {"output"}); }
   /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.PReLU for the inputs.
   void Init() const {}
 };
diff --git a/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py
index 9a3ecd41b63..473e3263d6c 100644
--- a/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py
@@ -417,6 +417,7 @@ from .unpack import _unpack_tbe
 from .unpack_ds import _unpack_ds_tbe
 from .scatter_update import _scatter_update_tbe
 from .prelu import _prelu_tbe
+from .prelu_ds import _prelu_ds_tbe
 from .prelu_grad import _prelu_grad_tbe
 from .binary_cross_entropy_ds import _binary_cross_entropy_ds_tbe
 from .binary_cross_entropy import _binary_cross_entropy_tbe
diff --git a/mindspore/python/mindspore/ops/_op_impl/tbe/prelu_ds.py b/mindspore/python/mindspore/ops/_op_impl/tbe/prelu_ds.py
new file mode 100644
index 00000000000..59d7932e949
--- /dev/null
+++ b/mindspore/python/mindspore/ops/_op_impl/tbe/prelu_ds.py
@@ -0,0 +1,38 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""PReLU op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+prelu_ds_op_info = TBERegOp("PReLU") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("prelu.so") \
+    .compute_cost(10) \
+    .kernel_name("prelu") \
+    .partial_flag(True) \
+    .dynamic_shape(True) \
+    .input(0, "x", False, "required", "all") \
+    .input(1, "weight", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .is_dynamic_format(True) \
+    .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
+    .get_op_info()
+
+
+@op_info_register(prelu_ds_op_info)
+def _prelu_ds_tbe():
+    """PReLU TBE register"""
+    return
diff --git a/mindspore/python/mindspore/ops/operations/nn_ops.py b/mindspore/python/mindspore/ops/operations/nn_ops.py
index cbee88dbec3..19ba62486c3 100644
--- a/mindspore/python/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/python/mindspore/ops/operations/nn_ops.py
@@ -4107,35 +4107,7 @@ class PReLU(PrimitiveWithInfer):
 
     @prim_attr_register
     def __init__(self):
-        pass
-
-    def infer_shape(self, input_x_shape, weight_shape):
-        input_x_dim = len(input_x_shape)
-        if input_x_dim in (0, 1):
-            if context.get_context("device_target") == "Ascend":
-                raise ValueError(f"For '{self.name}', the dimension of 'x' can not be 0-D or 1-D when the platform is "
-                                 f"\"Ascend\", but got dimension of 'x' is {input_x_dim}.")
-            channel_num = 1
-        else:
-            channel_num = input_x_shape[1]
-
-        weight_dim = len(weight_shape)
-        if weight_dim != 1:
-            raise ValueError(f"For '{self.name}', the dimension of 'weight' must be 1, while got {weight_dim}.")
-        if weight_shape[0] != 1 and weight_shape[0] != channel_num:
-            raise ValueError(f"For '{self.name}', the first dimension of 'weight' must be (1,) or "
-                             f"it must be equal to number of channels: {channel_num}, but got {weight_shape}")
-        return input_x_shape
-
-    def infer_dtype(self, input_x_dtype, weight_dtype):
-        valid_dtypes = (mstype.float16, mstype.float32)
-        args = {"input_x": input_x_dtype, "weight": weight_dtype}
-        if context.get_context("device_target") == "GPU":
-            validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
-        else:
-            validator.check_tensor_dtype_valid("input_x", input_x_dtype, valid_dtypes, self.name)
-            validator.check_tensor_dtype_valid("weight", weight_dtype, valid_dtypes, self.name)
-        return input_x_dtype
+        self.init_prim_io_names(inputs=['x', 'weight'], outputs=['output'])
 
 
 class LSTM(PrimitiveWithInfer):
diff --git a/tests/st/ops/dynamic_shape/test_dynamic_shape_prelu.py b/tests/st/ops/dynamic_shape/test_dynamic_shape_prelu.py
new file mode 100644
index 00000000000..cb4c7e00210
--- /dev/null
+++ b/tests/st/ops/dynamic_shape/test_dynamic_shape_prelu.py
@@ -0,0 +1,76 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+import mindspore.nn as nn
+from mindspore import Tensor, context
+from mindspore.ops import operations as P
+from mindspore.common import dtype as mstype
+
+
+class GetDynamicInputNet(nn.Cell):
+    def __init__(self, axis=0):
+        super(GetDynamicInputNet, self).__init__()
+        self.unique = P.Unique()
+        self.gather = P.Gather()
+        self.cast = P.Cast()
+        self.axis = axis
+
+    def construct(self, x, indices):
+        unique_indices, _ = self.unique(indices)
+        x_dtype = x.dtype
+        x = self.cast(x, mstype.float32)
+        real_x = self.gather(x, unique_indices, self.axis)
+        return self.cast(real_x, x_dtype)
+
+
+class PReLUDyNet(nn.Cell):
+    def __init__(self):
+        super(PReLUDyNet, self).__init__()
+        self.op = P.PReLU()
+        self.transformer = GetDynamicInputNet()
+
+    def construct(self, indices, x, weight):
+        real_x = self.transformer(x, indices)
+        out = self.op(real_x, weight)
+        return real_x, weight, out
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize("data_shape", [((8, 6, 7), (1,))])
+@pytest.mark.parametrize("data_type", [np.float16, np.float32])
+def test_dynamic_shape_prelu(data_shape, data_type):
+    """
+    Feature: PReLU DynamicShape.
+    Description: Test case of dynamic shape for PReLU operator.
+    Expectation: success.
+    """
+    context.set_context(mode=context.GRAPH_MODE)
+    x_shape, weight_shape = data_shape
+    x = Tensor(np.random.random(size=x_shape).astype(data_type))
+    weight = Tensor(np.random.random(size=weight_shape).astype(data_type))
+    indices = Tensor(np.random.randint(0, x_shape[0], size=(5,)).astype(np.int32))
+
+    dy_net = PReLUDyNet()
+    real_x, real_weight, output = dy_net(indices, x, weight)
+    x, weight = real_x.asnumpy(), real_weight.asnumpy()
+    expect = np.where(x >= 0, x, weight * x)
+
+    np.testing.assert_allclose(expect, output.asnumpy())
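
As a companion to the dynamic-shape test above, a minimal static-shape sanity check is sketched below. It is not part of the patch and assumes any backend with a registered PReLU kernel (for example GPU); the input values are illustrative only.

import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

# PReLU semantics: out = x where x >= 0, weight * x otherwise; 'weight' holds one
# slope shared across channels (shape (1,)) or one slope per channel.
x_np = np.array([[-1.0, 2.0], [3.0, -4.0]], dtype=np.float32)
w_np = np.array([0.25], dtype=np.float32)
out = P.PReLU()(Tensor(x_np), Tensor(w_np))
expect = np.where(x_np >= 0, x_np, w_np * x_np)
np.testing.assert_allclose(out.asnumpy(), expect, rtol=1e-6)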