!40230 Add dynamic shape support for PReLU operator.

Merge pull request !40230 from hezhenhao1/add_prelu
This commit is contained in:
i-robot 2022-08-12 01:27:44 +00:00 committed by Gitee
commit d9e93729cd
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
8 changed files with 267 additions and 115 deletions

View File

@ -18,14 +18,70 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
PReLU,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
PReLUGpuKernelMod, half)
bool PReLUGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
constexpr size_t input_num = 2;
constexpr size_t output_num = 1;
CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num, kernel_name_);
kernel_name_ = base_operator->GetPrim()->name();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
return true;
}
MS_REG_GPU_KERNEL_ONE(
PReLU,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
PReLUGpuKernelMod, float)
int PReLUGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
return ret;
}
auto input_shape = LongVecToSizeVec(inputs[kIndex0]->GetShapeVector());
auto weight_shape = LongVecToSizeVec(inputs[kIndex1]->GetShapeVector());
is_null_input_ =
CHECK_SHAPE_NULL(input_shape, kernel_name_, "x") || CHECK_SHAPE_NULL(weight_shape, kernel_name_, "weight");
input_length_ = std::accumulate(input_shape.begin(), input_shape.end(), size_t(1), std::multiplies<>());
per_channel_length_ =
input_shape.size() <= 1 ? input_length_ : input_length_ / (input_shape[kIndex0] * input_shape[kIndex1]);
weight_length_ = weight_shape[0];
return KRET_OK;
}
template <typename T>
bool PReLUGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
if (is_null_input_) {
return true;
}
auto input = GetDeviceAddress<T>(inputs, 0);
auto weight = GetDeviceAddress<T>(inputs, 1);
auto output = GetDeviceAddress<T>(outputs, 0);
CalPReLU(input_length_, weight_length_, per_channel_length_, input, weight, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
std::vector<std::pair<KernelAttr, PReLUGpuKernelMod::PReLULaunchFunc>> PReLUGpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&PReLUGpuKernelMod::LaunchKernel<half>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&PReLUGpuKernelMod::LaunchKernel<float>},
};
std::vector<KernelAttr> PReLUGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(
func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, PReLUGpuKernelMod::PReLULaunchFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, PReLU, PReLUGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -19,6 +19,7 @@
#include <vector>
#include <map>
#include <utility>
#include <functional>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
@ -27,91 +28,34 @@
namespace mindspore {
namespace kernel {
template <typename T>
class PReLUGpuKernelMod : public DeprecatedNativeGpuKernelMod {
class PReLUGpuKernelMod : public NativeGpuKernelMod {
public:
PReLUGpuKernelMod() = default;
~PReLUGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
auto *input = GetDeviceAddress<T>(inputs, 0);
auto *weight = GetDeviceAddress<T>(inputs, 1);
auto *output = GetDeviceAddress<T>(outputs, 0);
CalPReLU(input_length_, weight_length_, per_channel_length_, input, weight, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
}
bool Init(const CNodePtr &kernel_node) override {
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
ResetResource();
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be 2, but got " << input_num;
}
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs must be 1, but got " << output_num;
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
auto weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
is_null_input_ =
CHECK_SHAPE_NULL(input_shape, kernel_name, "x") || CHECK_SHAPE_NULL(weight_shape, kernel_name, "weight");
if (is_null_input_) {
InitSizeLists();
return true;
}
input_length_ = std::accumulate(input_shape.begin(), input_shape.end(), size_t(1), std::multiplies<>());
size_t input_rank = input_shape.size();
int64_t channel_num;
if (input_rank == 0) {
channel_num = 1;
per_channel_length_ = 1;
} else if (input_rank == 1) {
channel_num = 1;
per_channel_length_ = input_shape[0];
} else {
channel_num = input_shape[1];
per_channel_length_ = std::accumulate(input_shape.begin() + 2, input_shape.end(), size_t(1), std::multiplies<>());
}
if (weight_shape.size() != 1 || (weight_shape[0] != 1 && weight_shape[0] != channel_num)) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of weight must be equal to 1 and "
<< "weight.shape[0] must be equal to 1 or the channel number, but got the dimension of "
<< "weight: " << weight_shape.size() << ", weight.shape[0]: " << weight_shape[0]
<< ", the channel num: " << channel_num;
}
weight_length_ = LongToSizeClipNeg(weight_shape[0]);
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
input_length_ = 0;
weight_length_ = 0;
per_channel_length_ = 0;
is_null_input_ = false;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
protected:
void InitSizeLists() override {
size_t data_size = sizeof(T);
input_size_list_.push_back(input_length_ * data_size);
input_size_list_.push_back(weight_length_ * data_size);
output_size_list_.push_back(input_length_ * data_size);
}
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr);
using PReLULaunchFunc = std::function<bool(PReLUGpuKernelMod *, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, void *)>;
static std::vector<std::pair<KernelAttr, PReLULaunchFunc>> func_list_;
PReLULaunchFunc kernel_func_;
bool is_null_input_{false};
size_t input_length_{0};
size_t weight_length_{0};

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,10 +20,75 @@
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
#include "utils/ms_context.h"
#include "ops/op_utils.h"
namespace mindspore {
namespace ops {
bool IsAscend() {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
return context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice;
}
abstract::ShapePtr PReLUInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
auto x_shape_ptr = input_args[kInputIndex0]->BuildShape();
auto weight_shape_ptr = input_args[kInputIndex1]->BuildShape();
// Dynamic rank.
if (x_shape_ptr->IsDimUnknown() || weight_shape_ptr->IsDimUnknown()) {
return x_shape_ptr->cast<abstract::ShapePtr>();
}
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr)[kShape];
auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(weight_shape_ptr)[kShape];
auto x_rank = x_shape.size();
auto weight_rank = weight_shape.size();
auto channel_num = x_rank <= 1 ? 1 : x_shape[1];
if (IsAscend() && x_rank <= 1) {
MS_EXCEPTION(ValueError)
<< "For '" << prim_name
<< "', the dimension of 'x' can not be 0-D or 1-D when the platform is \"Ascend\", but got dimension of 'x' is "
<< x_rank << ".";
}
(void)CheckAndConvertUtils::CheckInteger("dimension of 'weight'", SizeToLong(weight_rank), kEqual, 1, prim_name);
if (weight_shape[0] != 1 && weight_shape[0] != channel_num) {
MS_EXCEPTION(ValueError)
<< "For '" << prim_name
<< "', the first dimension of 'weight' must be (1,) or it must be equal to number of channels: " << channel_num
<< ", but got " << weight_shape << ".";
}
return x_shape_ptr->cast<abstract::ShapePtr>();
}
TypePtr PReLUInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
auto x_type = input_args[kInputIndex0]->BuildType();
auto weight_type = input_args[kInputIndex1]->BuildType();
auto valid_types = {kFloat16, kFloat32};
if (IsAscend()) {
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("weight", weight_type, valid_types, prim_name);
} else {
std::map<std::string, TypePtr> args;
(void)args.emplace("x", x_type);
(void)args.emplace("weight", weight_type);
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
}
return x_type;
}
MIND_API_OPERATOR_IMPL(PReLU, BaseOperator);
REGISTER_PRIMITIVE_C(kNamePReLU, PReLU);
AbstractBasePtr PReLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
constexpr size_t input_num = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto type = PReLUInferType(primitive, input_args);
auto shape = PReLUInferShape(primitive, input_args);
return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(PReLU, prim::kPrimPRelu, PReLUInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -32,8 +32,8 @@ class MIND_API PReLU : public BaseOperator {
public:
MIND_API_BASE_MEMBER(PReLU);
/// \brief Constructor.
PReLU() : BaseOperator(kNamePReLU) { InitIOName({"x"}, {"y"}); }
explicit PReLU(const std::string k_name) : BaseOperator(k_name) { InitIOName({"x"}, {"y"}); }
PReLU() : BaseOperator(kNamePReLU) { InitIOName({"x", "weight"}, {"output"}); }
explicit PReLU(const std::string k_name) : BaseOperator(k_name) { InitIOName({"x", "weight"}, {"output"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.PReLU for the inputs.
void Init() const {}
};

View File

@ -417,6 +417,7 @@ from .unpack import _unpack_tbe
from .unpack_ds import _unpack_ds_tbe
from .scatter_update import _scatter_update_tbe
from .prelu import _prelu_tbe
from .prelu_ds import _prelu_ds_tbe
from .prelu_grad import _prelu_grad_tbe
from .binary_cross_entropy_ds import _binary_cross_entropy_ds_tbe
from .binary_cross_entropy import _binary_cross_entropy_tbe

View File

@ -0,0 +1,38 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""PReLU op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
prelu_ds_op_info = TBERegOp("PReLU") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("prelu.so") \
.compute_cost(10) \
.kernel_name("prelu") \
.partial_flag(True) \
.dynamic_shape(True) \
.input(0, "x", False, "required", "all") \
.input(1, "weight", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
.get_op_info()
@op_info_register(prelu_ds_op_info)
def _prelu_ds_tbe():
"""PReLU TBE register"""
return

View File

@ -4107,35 +4107,7 @@ class PReLU(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
pass
def infer_shape(self, input_x_shape, weight_shape):
input_x_dim = len(input_x_shape)
if input_x_dim in (0, 1):
if context.get_context("device_target") == "Ascend":
raise ValueError(f"For '{self.name}', the dimension of 'x' can not be 0-D or 1-D when the platform is "
f"\"Ascend\", but got dimension of 'x' is {input_x_dim}.")
channel_num = 1
else:
channel_num = input_x_shape[1]
weight_dim = len(weight_shape)
if weight_dim != 1:
raise ValueError(f"For '{self.name}', the dimension of 'weight' must be 1, while got {weight_dim}.")
if weight_shape[0] != 1 and weight_shape[0] != channel_num:
raise ValueError(f"For '{self.name}', the first dimension of 'weight' must be (1,) or "
f"it must be equal to number of channels: {channel_num}, but got {weight_shape}")
return input_x_shape
def infer_dtype(self, input_x_dtype, weight_dtype):
valid_dtypes = (mstype.float16, mstype.float32)
args = {"input_x": input_x_dtype, "weight": weight_dtype}
if context.get_context("device_target") == "GPU":
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
else:
validator.check_tensor_dtype_valid("input_x", input_x_dtype, valid_dtypes, self.name)
validator.check_tensor_dtype_valid("weight", weight_dtype, valid_dtypes, self.name)
return input_x_dtype
self.init_prim_io_names(inputs=['x', 'weight'], outputs=['output'])
class LSTM(PrimitiveWithInfer):

View File

@ -0,0 +1,76 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
class GetDynamicInputNet(nn.Cell):
def __init__(self, axis=0):
super(GetDynamicInputNet, self).__init__()
self.unique = P.Unique()
self.gather = P.Gather()
self.cast = P.Cast()
self.axis = axis
def construct(self, x, indices):
unique_indices, _ = self.unique(indices)
x_dtype = x.dtype
x = self.cast(x, mstype.float32)
real_x = self.gather(x, unique_indices, self.axis)
return self.cast(real_x, x_dtype)
class PReLUDyNet(nn.Cell):
def __init__(self):
super(PReLUDyNet, self).__init__()
self.op = P.PReLU()
self.transformer = GetDynamicInputNet()
def construct(self, indices, x, weight):
real_x = self.transformer(x, indices)
out = self.op(real_x, weight)
return real_x, weight, out
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize("data_shape", [((8, 6, 7), (1,))])
@pytest.mark.parametrize("data_type", [np.float16, np.float32])
def test_dynamic_shape_prelu(data_shape, data_type):
"""
Feature: PReLU DynamicShape.
Description: Test case of dynamic shape for PReLU operator.
Expectation: success.
"""
context.set_context(mode=context.GRAPH_MODE)
x_shape, weight_shape = data_shape
x = Tensor(np.random.random(size=x_shape).astype(data_type))
weight = Tensor(np.random.random(size=weight_shape).astype(data_type))
indices = Tensor(np.random.randint(0, x_shape[0], size=(5,)).astype(np.int32))
dy_net = PReLUDyNet()
real_x, real_weight, output = dy_net(indices, x, weight)
x, weight = real_x.asnumpy(), real_weight.asnumpy()
expect = np.where(x >= 0, x, weight * x)
np.testing.assert_allclose(expect, output.asnumpy())