forked from mindspore-Ecosystem/mindspore
!40230 Add dynamic shape support for PReLU operator.
Merge pull request !40230 from hezhenhao1/add_prelu
This commit is contained in:
commit
d9e93729cd
|
@ -18,14 +18,70 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
PReLU,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
PReLUGpuKernelMod, half)
|
||||
bool PReLUGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
constexpr size_t input_num = 2;
|
||||
constexpr size_t output_num = 1;
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num, kernel_name_);
|
||||
kernel_name_ = base_operator->GetPrim()->name();
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
return true;
|
||||
}
|
||||
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
PReLU,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
PReLUGpuKernelMod, float)
|
||||
int PReLUGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
auto input_shape = LongVecToSizeVec(inputs[kIndex0]->GetShapeVector());
|
||||
auto weight_shape = LongVecToSizeVec(inputs[kIndex1]->GetShapeVector());
|
||||
is_null_input_ =
|
||||
CHECK_SHAPE_NULL(input_shape, kernel_name_, "x") || CHECK_SHAPE_NULL(weight_shape, kernel_name_, "weight");
|
||||
input_length_ = std::accumulate(input_shape.begin(), input_shape.end(), size_t(1), std::multiplies<>());
|
||||
per_channel_length_ =
|
||||
input_shape.size() <= 1 ? input_length_ : input_length_ / (input_shape[kIndex0] * input_shape[kIndex1]);
|
||||
weight_length_ = weight_shape[0];
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool PReLUGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
auto input = GetDeviceAddress<T>(inputs, 0);
|
||||
auto weight = GetDeviceAddress<T>(inputs, 1);
|
||||
auto output = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
CalPReLU(input_length_, weight_length_, per_channel_length_, input, weight, output,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, PReLUGpuKernelMod::PReLULaunchFunc>> PReLUGpuKernelMod::func_list_ = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
&PReLUGpuKernelMod::LaunchKernel<half>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&PReLUGpuKernelMod::LaunchKernel<float>},
|
||||
};
|
||||
|
||||
std::vector<KernelAttr> PReLUGpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(
|
||||
func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, PReLUGpuKernelMod::PReLULaunchFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, PReLU, PReLUGpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include <functional>
|
||||
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
|
@ -27,91 +28,34 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class PReLUGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
||||
class PReLUGpuKernelMod : public NativeGpuKernelMod {
|
||||
public:
|
||||
PReLUGpuKernelMod() = default;
|
||||
~PReLUGpuKernelMod() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
auto *input = GetDeviceAddress<T>(inputs, 0);
|
||||
auto *weight = GetDeviceAddress<T>(inputs, 1);
|
||||
auto *output = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
CalPReLU(input_length_, weight_length_, per_channel_length_, input, weight, output,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
kernel_node_ = kernel_node;
|
||||
ResetResource();
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 2) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be 2, but got " << input_num;
|
||||
}
|
||||
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs must be 1, but got " << output_num;
|
||||
}
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
auto weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
||||
is_null_input_ =
|
||||
CHECK_SHAPE_NULL(input_shape, kernel_name, "x") || CHECK_SHAPE_NULL(weight_shape, kernel_name, "weight");
|
||||
if (is_null_input_) {
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
input_length_ = std::accumulate(input_shape.begin(), input_shape.end(), size_t(1), std::multiplies<>());
|
||||
size_t input_rank = input_shape.size();
|
||||
int64_t channel_num;
|
||||
if (input_rank == 0) {
|
||||
channel_num = 1;
|
||||
per_channel_length_ = 1;
|
||||
} else if (input_rank == 1) {
|
||||
channel_num = 1;
|
||||
per_channel_length_ = input_shape[0];
|
||||
} else {
|
||||
channel_num = input_shape[1];
|
||||
per_channel_length_ = std::accumulate(input_shape.begin() + 2, input_shape.end(), size_t(1), std::multiplies<>());
|
||||
}
|
||||
|
||||
if (weight_shape.size() != 1 || (weight_shape[0] != 1 && weight_shape[0] != channel_num)) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of weight must be equal to 1 and "
|
||||
<< "weight.shape[0] must be equal to 1 or the channel number, but got the dimension of "
|
||||
<< "weight: " << weight_shape.size() << ", weight.shape[0]: " << weight_shape[0]
|
||||
<< ", the channel num: " << channel_num;
|
||||
}
|
||||
weight_length_ = LongToSizeClipNeg(weight_shape[0]);
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
void ResetResource() noexcept override {
|
||||
input_length_ = 0;
|
||||
weight_length_ = 0;
|
||||
per_channel_length_ = 0;
|
||||
is_null_input_ = false;
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
size_t data_size = sizeof(T);
|
||||
input_size_list_.push_back(input_length_ * data_size);
|
||||
input_size_list_.push_back(weight_length_ * data_size);
|
||||
output_size_list_.push_back(input_length_ * data_size);
|
||||
}
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr);
|
||||
|
||||
using PReLULaunchFunc = std::function<bool(PReLUGpuKernelMod *, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, void *)>;
|
||||
static std::vector<std::pair<KernelAttr, PReLULaunchFunc>> func_list_;
|
||||
PReLULaunchFunc kernel_func_;
|
||||
bool is_null_input_{false};
|
||||
size_t input_length_{0};
|
||||
size_t weight_length_{0};
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -20,10 +20,75 @@
|
|||
#include "ops/primitive_c.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
bool IsAscend() {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
return context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice;
|
||||
}
|
||||
|
||||
abstract::ShapePtr PReLUInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
auto x_shape_ptr = input_args[kInputIndex0]->BuildShape();
|
||||
auto weight_shape_ptr = input_args[kInputIndex1]->BuildShape();
|
||||
// Dynamic rank.
|
||||
if (x_shape_ptr->IsDimUnknown() || weight_shape_ptr->IsDimUnknown()) {
|
||||
return x_shape_ptr->cast<abstract::ShapePtr>();
|
||||
}
|
||||
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr)[kShape];
|
||||
auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(weight_shape_ptr)[kShape];
|
||||
auto x_rank = x_shape.size();
|
||||
auto weight_rank = weight_shape.size();
|
||||
auto channel_num = x_rank <= 1 ? 1 : x_shape[1];
|
||||
if (IsAscend() && x_rank <= 1) {
|
||||
MS_EXCEPTION(ValueError)
|
||||
<< "For '" << prim_name
|
||||
<< "', the dimension of 'x' can not be 0-D or 1-D when the platform is \"Ascend\", but got dimension of 'x' is "
|
||||
<< x_rank << ".";
|
||||
}
|
||||
(void)CheckAndConvertUtils::CheckInteger("dimension of 'weight'", SizeToLong(weight_rank), kEqual, 1, prim_name);
|
||||
if (weight_shape[0] != 1 && weight_shape[0] != channel_num) {
|
||||
MS_EXCEPTION(ValueError)
|
||||
<< "For '" << prim_name
|
||||
<< "', the first dimension of 'weight' must be (1,) or it must be equal to number of channels: " << channel_num
|
||||
<< ", but got " << weight_shape << ".";
|
||||
}
|
||||
return x_shape_ptr->cast<abstract::ShapePtr>();
|
||||
}
|
||||
|
||||
TypePtr PReLUInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
auto x_type = input_args[kInputIndex0]->BuildType();
|
||||
auto weight_type = input_args[kInputIndex1]->BuildType();
|
||||
auto valid_types = {kFloat16, kFloat32};
|
||||
|
||||
if (IsAscend()) {
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("weight", weight_type, valid_types, prim_name);
|
||||
} else {
|
||||
std::map<std::string, TypePtr> args;
|
||||
(void)args.emplace("x", x_type);
|
||||
(void)args.emplace("weight", weight_type);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
|
||||
}
|
||||
return x_type;
|
||||
}
|
||||
|
||||
MIND_API_OPERATOR_IMPL(PReLU, BaseOperator);
|
||||
REGISTER_PRIMITIVE_C(kNamePReLU, PReLU);
|
||||
AbstractBasePtr PReLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
constexpr size_t input_num = 2;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
auto type = PReLUInferType(primitive, input_args);
|
||||
auto shape = PReLUInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(shape, type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(PReLU, prim::kPrimPRelu, PReLUInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -32,8 +32,8 @@ class MIND_API PReLU : public BaseOperator {
|
|||
public:
|
||||
MIND_API_BASE_MEMBER(PReLU);
|
||||
/// \brief Constructor.
|
||||
PReLU() : BaseOperator(kNamePReLU) { InitIOName({"x"}, {"y"}); }
|
||||
explicit PReLU(const std::string k_name) : BaseOperator(k_name) { InitIOName({"x"}, {"y"}); }
|
||||
PReLU() : BaseOperator(kNamePReLU) { InitIOName({"x", "weight"}, {"output"}); }
|
||||
explicit PReLU(const std::string k_name) : BaseOperator(k_name) { InitIOName({"x", "weight"}, {"output"}); }
|
||||
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.PReLU for the inputs.
|
||||
void Init() const {}
|
||||
};
|
||||
|
|
|
@ -417,6 +417,7 @@ from .unpack import _unpack_tbe
|
|||
from .unpack_ds import _unpack_ds_tbe
|
||||
from .scatter_update import _scatter_update_tbe
|
||||
from .prelu import _prelu_tbe
|
||||
from .prelu_ds import _prelu_ds_tbe
|
||||
from .prelu_grad import _prelu_grad_tbe
|
||||
from .binary_cross_entropy_ds import _binary_cross_entropy_ds_tbe
|
||||
from .binary_cross_entropy import _binary_cross_entropy_tbe
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""PReLU op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
prelu_ds_op_info = TBERegOp("PReLU") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("prelu.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("prelu") \
|
||||
.partial_flag(True) \
|
||||
.dynamic_shape(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "weight", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(prelu_ds_op_info)
|
||||
def _prelu_ds_tbe():
|
||||
"""PReLU TBE register"""
|
||||
return
|
|
@ -4107,35 +4107,7 @@ class PReLU(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def infer_shape(self, input_x_shape, weight_shape):
|
||||
input_x_dim = len(input_x_shape)
|
||||
if input_x_dim in (0, 1):
|
||||
if context.get_context("device_target") == "Ascend":
|
||||
raise ValueError(f"For '{self.name}', the dimension of 'x' can not be 0-D or 1-D when the platform is "
|
||||
f"\"Ascend\", but got dimension of 'x' is {input_x_dim}.")
|
||||
channel_num = 1
|
||||
else:
|
||||
channel_num = input_x_shape[1]
|
||||
|
||||
weight_dim = len(weight_shape)
|
||||
if weight_dim != 1:
|
||||
raise ValueError(f"For '{self.name}', the dimension of 'weight' must be 1, while got {weight_dim}.")
|
||||
if weight_shape[0] != 1 and weight_shape[0] != channel_num:
|
||||
raise ValueError(f"For '{self.name}', the first dimension of 'weight' must be (1,) or "
|
||||
f"it must be equal to number of channels: {channel_num}, but got {weight_shape}")
|
||||
return input_x_shape
|
||||
|
||||
def infer_dtype(self, input_x_dtype, weight_dtype):
|
||||
valid_dtypes = (mstype.float16, mstype.float32)
|
||||
args = {"input_x": input_x_dtype, "weight": weight_dtype}
|
||||
if context.get_context("device_target") == "GPU":
|
||||
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
|
||||
else:
|
||||
validator.check_tensor_dtype_valid("input_x", input_x_dtype, valid_dtypes, self.name)
|
||||
validator.check_tensor_dtype_valid("weight", weight_dtype, valid_dtypes, self.name)
|
||||
return input_x_dtype
|
||||
self.init_prim_io_names(inputs=['x', 'weight'], outputs=['output'])
|
||||
|
||||
|
||||
class LSTM(PrimitiveWithInfer):
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
|
||||
class GetDynamicInputNet(nn.Cell):
|
||||
def __init__(self, axis=0):
|
||||
super(GetDynamicInputNet, self).__init__()
|
||||
self.unique = P.Unique()
|
||||
self.gather = P.Gather()
|
||||
self.cast = P.Cast()
|
||||
self.axis = axis
|
||||
|
||||
def construct(self, x, indices):
|
||||
unique_indices, _ = self.unique(indices)
|
||||
x_dtype = x.dtype
|
||||
x = self.cast(x, mstype.float32)
|
||||
real_x = self.gather(x, unique_indices, self.axis)
|
||||
return self.cast(real_x, x_dtype)
|
||||
|
||||
|
||||
class PReLUDyNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(PReLUDyNet, self).__init__()
|
||||
self.op = P.PReLU()
|
||||
self.transformer = GetDynamicInputNet()
|
||||
|
||||
def construct(self, indices, x, weight):
|
||||
real_x = self.transformer(x, indices)
|
||||
out = self.op(real_x, weight)
|
||||
return real_x, weight, out
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize("data_shape", [((8, 6, 7), (1,))])
|
||||
@pytest.mark.parametrize("data_type", [np.float16, np.float32])
|
||||
def test_dynamic_shape_prelu(data_shape, data_type):
|
||||
"""
|
||||
Feature: PReLU DynamicShape.
|
||||
Description: Test case of dynamic shape for PReLU operator.
|
||||
Expectation: success.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
x_shape, weight_shape = data_shape
|
||||
x = Tensor(np.random.random(size=x_shape).astype(data_type))
|
||||
weight = Tensor(np.random.random(size=weight_shape).astype(data_type))
|
||||
indices = Tensor(np.random.randint(0, x_shape[0], size=(5,)).astype(np.int32))
|
||||
|
||||
dy_net = PReLUDyNet()
|
||||
real_x, real_weight, output = dy_net(indices, x, weight)
|
||||
x, weight = real_x.asnumpy(), real_weight.asnumpy()
|
||||
expect = np.where(x >= 0, x, weight * x)
|
||||
|
||||
np.testing.assert_allclose(expect, output.asnumpy())
|
Loading…
Reference in New Issue