!40230 Add dynamic shape support for PReLU operator.

Merge pull request !40230 from hezhenhao1/add_prelu (commit d9e93729cd)
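For reference, PReLU computes out = x where x >= 0 and out = weight * x otherwise, with `weight` holding either one learnable slope shared by all channels (shape `(1,)`) or one slope per channel (shape `(C,)`). A minimal sketch of the semantics on a GPU/CPU target (illustrative values, not part of this PR):

import numpy as np
import mindspore as ms
from mindspore.ops import operations as P

# x: (N, C) = (2, 2); weight: (1,) is a slope shared across channels.
x = ms.Tensor(np.array([[-1.0, 2.0], [3.0, -4.0]], dtype=np.float32))
weight = ms.Tensor(np.array([0.25], dtype=np.float32))
out = P.PReLU()(x, weight)  # [[-0.25, 2.0], [3.0, -1.0]]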
PReLU GPU kernel implementation (.cc):

@@ -18,14 +18,70 @@
 namespace mindspore {
 namespace kernel {
-MS_REG_GPU_KERNEL_ONE(
-  PReLU,
-  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  PReLUGpuKernelMod, half)
-MS_REG_GPU_KERNEL_ONE(
-  PReLU,
-  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  PReLUGpuKernelMod, float)
+bool PReLUGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                             const std::vector<KernelTensorPtr> &outputs) {
+  // Set kernel_name_ first so the input/output checks below report the op name.
+  kernel_name_ = base_operator->GetPrim()->name();
+  constexpr size_t input_num = 2;
+  constexpr size_t output_num = 1;
+  CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_);
+  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num, kernel_name_);
+  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
+  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
+  if (!is_match) {
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
+    return false;
+  }
+  kernel_func_ = func_list_[index].second;
+  return true;
+}
+
+int PReLUGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                              const std::vector<KernelTensorPtr> &outputs,
+                              const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
+  if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
+    return ret;
+  }
+  auto input_shape = LongVecToSizeVec(inputs[kIndex0]->GetShapeVector());
+  auto weight_shape = LongVecToSizeVec(inputs[kIndex1]->GetShapeVector());
+  is_null_input_ =
+    CHECK_SHAPE_NULL(input_shape, kernel_name_, "x") || CHECK_SHAPE_NULL(weight_shape, kernel_name_, "weight");
+  input_length_ = std::accumulate(input_shape.begin(), input_shape.end(), size_t(1), std::multiplies<>());
+  // For rank >= 2 the channel dim is input_shape[1]; per-channel length is the product of the trailing dims.
+  per_channel_length_ =
+    input_shape.size() <= 1 ? input_length_ : input_length_ / (input_shape[kIndex0] * input_shape[kIndex1]);
+  weight_length_ = weight_shape[0];
+  return KRET_OK;
+}
+
+template <typename T>
+bool PReLUGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                                     const std::vector<AddressPtr> &outputs, void *stream_ptr) {
+  if (is_null_input_) {
+    return true;
+  }
+  auto input = GetDeviceAddress<T>(inputs, 0);
+  auto weight = GetDeviceAddress<T>(inputs, 1);
+  auto output = GetDeviceAddress<T>(outputs, 0);
+
+  CalPReLU(input_length_, weight_length_, per_channel_length_, input, weight, output,
+           reinterpret_cast<cudaStream_t>(stream_ptr));
+  return true;
+}
+
+std::vector<std::pair<KernelAttr, PReLUGpuKernelMod::PReLULaunchFunc>> PReLUGpuKernelMod::func_list_ = {
+  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+   &PReLUGpuKernelMod::LaunchKernel<half>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+   &PReLUGpuKernelMod::LaunchKernel<float>},
+};
+
+std::vector<KernelAttr> PReLUGpuKernelMod::GetOpSupport() {
+  std::vector<KernelAttr> support_list;
+  (void)std::transform(
+    func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+    [](const std::pair<KernelAttr, PReLUGpuKernelMod::PReLULaunchFunc> &pair) { return pair.first; });
+  return support_list;
+}
+
+MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, PReLU, PReLUGpuKernelMod);
 }  // namespace kernel
 }  // namespace mindspore
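The kernel flattens the input and, for rank >= 2 tensors, derives per_channel_length_ = input_length_ / (N * C), so flat element i belongs to channel (i / per_channel_length_) % C. A NumPy sketch of the same indexing, for illustration only (not MindSpore API):

import numpy as np

def prelu_reference(x, w):
    """Mirror of CalPReLU's math: element i uses channel (i // per_channel) % C."""
    flat = x.reshape(-1)
    if x.ndim <= 1:
        channels, per_channel = 1, flat.size
    else:
        channels = x.shape[1]
        per_channel = flat.size // (x.shape[0] * channels)
    out = np.empty_like(flat)
    for i, v in enumerate(flat):
        c = (i // per_channel) % channels
        wc = w[c] if w.size == channels else w[0]  # per-channel or shared slope
        out[i] = v if v >= 0 else wc * v
    return out.reshape(x.shape)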
PReLU GPU kernel header (.h):

@@ -19,6 +19,7 @@
 #include <vector>
 #include <map>
+#include <utility>
 #include <functional>
 
 #include "plugin/device/gpu/kernel/gpu_kernel.h"
@@ -27,91 +28,34 @@
 namespace mindspore {
 namespace kernel {
-template <typename T>
-class PReLUGpuKernelMod : public DeprecatedNativeGpuKernelMod {
+class PReLUGpuKernelMod : public NativeGpuKernelMod {
  public:
   PReLUGpuKernelMod() = default;
   ~PReLUGpuKernelMod() override = default;
 
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
-    if (is_null_input_) {
-      return true;
-    }
-    auto *input = GetDeviceAddress<T>(inputs, 0);
-    auto *weight = GetDeviceAddress<T>(inputs, 1);
-    auto *output = GetDeviceAddress<T>(outputs, 0);
-
-    CalPReLU(input_length_, weight_length_, per_channel_length_, input, weight, output,
-             reinterpret_cast<cudaStream_t>(stream_ptr));
-    return true;
+    return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
   }
 
-  bool Init(const CNodePtr &kernel_node) override {
-    auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
-    kernel_node_ = kernel_node;
-    ResetResource();
-    size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
-    if (input_num != 2) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be 2, but got " << input_num;
-    }
-    size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
-    if (output_num != 1) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs must be 1, but got " << output_num;
-    }
-
-    auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
-    auto weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
-    is_null_input_ =
-      CHECK_SHAPE_NULL(input_shape, kernel_name, "x") || CHECK_SHAPE_NULL(weight_shape, kernel_name, "weight");
-    if (is_null_input_) {
-      InitSizeLists();
-      return true;
-    }
-    input_length_ = std::accumulate(input_shape.begin(), input_shape.end(), size_t(1), std::multiplies<>());
-    size_t input_rank = input_shape.size();
-    int64_t channel_num;
-    if (input_rank == 0) {
-      channel_num = 1;
-      per_channel_length_ = 1;
-    } else if (input_rank == 1) {
-      channel_num = 1;
-      per_channel_length_ = input_shape[0];
-    } else {
-      channel_num = input_shape[1];
-      per_channel_length_ = std::accumulate(input_shape.begin() + 2, input_shape.end(), size_t(1), std::multiplies<>());
-    }
-
-    if (weight_shape.size() != 1 || (weight_shape[0] != 1 && weight_shape[0] != channel_num)) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of weight must be equal to 1 and "
-                        << "weight.shape[0] must be equal to 1 or the channel number, but got the dimension of "
-                        << "weight: " << weight_shape.size() << ", weight.shape[0]: " << weight_shape[0]
-                        << ", the channel num: " << channel_num;
-    }
-    weight_length_ = LongToSizeClipNeg(weight_shape[0]);
-    InitSizeLists();
-    return true;
-  }
-
-  void ResetResource() noexcept override {
-    input_length_ = 0;
-    weight_length_ = 0;
-    per_channel_length_ = 0;
-    is_null_input_ = false;
-    input_size_list_.clear();
-    output_size_list_.clear();
-    workspace_size_list_.clear();
-  }
+  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+            const std::vector<KernelTensorPtr> &outputs) override;
+
+  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
 
  protected:
-  void InitSizeLists() override {
-    size_t data_size = sizeof(T);
-    input_size_list_.push_back(input_length_ * data_size);
-    input_size_list_.push_back(weight_length_ * data_size);
-    output_size_list_.push_back(input_length_ * data_size);
-  }
+  std::vector<KernelAttr> GetOpSupport() override;
 
  private:
+  template <typename T>
+  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                    const std::vector<AddressPtr> &outputs, void *stream_ptr);
+
+  using PReLULaunchFunc = std::function<bool(PReLUGpuKernelMod *, const std::vector<AddressPtr> &,
+                                             const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, void *)>;
+  static std::vector<std::pair<KernelAttr, PReLULaunchFunc>> func_list_;
+  PReLULaunchFunc kernel_func_;
   bool is_null_input_{false};
   size_t input_length_{0};
   size_t weight_length_{0};
PReLU op C++ definition (ops prelu.cc):

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,10 +20,75 @@
 #include "ops/primitive_c.h"
 #include "utils/check_convert_utils.h"
 #include "mindapi/src/helper.h"
+#include "utils/ms_context.h"
+#include "ops/op_utils.h"
 
 namespace mindspore {
 namespace ops {
+bool IsAscend() {
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  return context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice;
+}
+
+abstract::ShapePtr PReLUInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  auto prim_name = primitive->name();
+  auto x_shape_ptr = input_args[kInputIndex0]->BuildShape();
+  auto weight_shape_ptr = input_args[kInputIndex1]->BuildShape();
+  // Dynamic rank: defer the shape checks until the ranks are known.
+  if (x_shape_ptr->IsDimUnknown() || weight_shape_ptr->IsDimUnknown()) {
+    return x_shape_ptr->cast<abstract::ShapePtr>();
+  }
+
+  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr)[kShape];
+  auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(weight_shape_ptr)[kShape];
+  auto x_rank = x_shape.size();
+  auto weight_rank = weight_shape.size();
+  auto channel_num = x_rank <= 1 ? 1 : x_shape[1];
+  if (IsAscend() && x_rank <= 1) {
+    MS_EXCEPTION(ValueError)
+      << "For '" << prim_name
+      << "', the dimension of 'x' can not be 0-D or 1-D when the platform is \"Ascend\", but got dimension of 'x' is "
+      << x_rank << ".";
+  }
+  (void)CheckAndConvertUtils::CheckInteger("dimension of 'weight'", SizeToLong(weight_rank), kEqual, 1, prim_name);
+  if (weight_shape[0] != 1 && weight_shape[0] != channel_num) {
+    MS_EXCEPTION(ValueError)
+      << "For '" << prim_name
+      << "', the shape of 'weight' must be (1,) or equal to the number of channels: " << channel_num
+      << ", but got " << weight_shape << ".";
+  }
+  return x_shape_ptr->cast<abstract::ShapePtr>();
+}
+
+TypePtr PReLUInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  auto prim_name = primitive->name();
+  auto x_type = input_args[kInputIndex0]->BuildType();
+  auto weight_type = input_args[kInputIndex1]->BuildType();
+  auto valid_types = {kFloat16, kFloat32};
+
+  if (IsAscend()) {
+    (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim_name);
+    (void)CheckAndConvertUtils::CheckTensorTypeValid("weight", weight_type, valid_types, prim_name);
+  } else {
+    std::map<std::string, TypePtr> args;
+    (void)args.emplace("x", x_type);
+    (void)args.emplace("weight", weight_type);
+    (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
+  }
+  return x_type;
+}
+
 MIND_API_OPERATOR_IMPL(PReLU, BaseOperator);
-REGISTER_PRIMITIVE_C(kNamePReLU, PReLU);
+AbstractBasePtr PReLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                           const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  constexpr size_t input_num = 2;
+  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
+  auto type = PReLUInferType(primitive, input_args);
+  auto shape = PReLUInferShape(primitive, input_args);
+  return abstract::MakeAbstract(shape, type);
+}
+REGISTER_PRIMITIVE_EVAL_IMPL(PReLU, prim::kPrimPRelu, PReLUInfer, nullptr, true);
 }  // namespace ops
 }  // namespace mindspore
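The C++ infer routine above takes over the validation the Python infer_shape used to do (see the nn_ops.py hunk below). A compact Python mirror of the rule, using a hypothetical helper name for illustration:

def check_prelu_shapes(x_shape, weight_shape, is_ascend=False):
    """Validate PReLU input shapes; the output shape always equals x_shape."""
    x_rank = len(x_shape)
    if is_ascend and x_rank <= 1:
        raise ValueError(f"PReLU on Ascend requires 'x' of rank >= 2, got {x_rank}")
    channel_num = 1 if x_rank <= 1 else x_shape[1]
    if len(weight_shape) != 1:
        raise ValueError(f"'weight' must be 1-D, got rank {len(weight_shape)}")
    if weight_shape[0] not in (1, channel_num):
        raise ValueError(f"weight.shape[0] must be 1 or {channel_num}, got {weight_shape[0]}")
    return x_shape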
PReLU op C++ header (ops prelu.h):

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -32,8 +32,8 @@ class MIND_API PReLU : public BaseOperator {
  public:
   MIND_API_BASE_MEMBER(PReLU);
   /// \brief Constructor.
-  PReLU() : BaseOperator(kNamePReLU) { InitIOName({"x"}, {"y"}); }
-  explicit PReLU(const std::string k_name) : BaseOperator(k_name) { InitIOName({"x"}, {"y"}); }
+  PReLU() : BaseOperator(kNamePReLU) { InitIOName({"x", "weight"}, {"output"}); }
+  explicit PReLU(const std::string k_name) : BaseOperator(k_name) { InitIOName({"x", "weight"}, {"output"}); }
   /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.PReLU for the inputs.
   void Init() const {}
 };
TBE operator registry (Python __init__.py):

@@ -417,6 +417,7 @@ from .unpack import _unpack_tbe
 from .unpack_ds import _unpack_ds_tbe
 from .scatter_update import _scatter_update_tbe
 from .prelu import _prelu_tbe
+from .prelu_ds import _prelu_ds_tbe
 from .prelu_grad import _prelu_grad_tbe
 from .binary_cross_entropy_ds import _binary_cross_entropy_ds_tbe
 from .binary_cross_entropy import _binary_cross_entropy_tbe
New TBE registration for dynamic-shape PReLU (prelu_ds.py):

@@ -0,0 +1,38 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""PReLU op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+prelu_ds_op_info = TBERegOp("PReLU") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("prelu.so") \
+    .compute_cost(10) \
+    .kernel_name("prelu") \
+    .partial_flag(True) \
+    .dynamic_shape(True) \
+    .input(0, "x", False, "required", "all") \
+    .input(1, "weight", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .is_dynamic_format(True) \
+    .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
+    .get_op_info()
+
+
+@op_info_register(prelu_ds_op_info)
+def _prelu_ds_tbe():
+    """PReLU TBE register"""
+    return
Python PReLU primitive (nn_ops.py):

@@ -4107,35 +4107,7 @@ class PReLU(PrimitiveWithInfer):
 
     @prim_attr_register
     def __init__(self):
-        pass
-
-    def infer_shape(self, input_x_shape, weight_shape):
-        input_x_dim = len(input_x_shape)
-        if input_x_dim in (0, 1):
-            if context.get_context("device_target") == "Ascend":
-                raise ValueError(f"For '{self.name}', the dimension of 'x' can not be 0-D or 1-D when the platform is "
-                                 f"\"Ascend\", but got dimension of 'x' is {input_x_dim}.")
-            channel_num = 1
-        else:
-            channel_num = input_x_shape[1]
-
-        weight_dim = len(weight_shape)
-        if weight_dim != 1:
-            raise ValueError(f"For '{self.name}', the dimension of 'weight' must be 1, while got {weight_dim}.")
-        if weight_shape[0] != 1 and weight_shape[0] != channel_num:
-            raise ValueError(f"For '{self.name}', the first dimension of 'weight' must be (1,) or "
-                             f"it must be equal to number of channels: {channel_num}, but got {weight_shape}")
-        return input_x_shape
-
-    def infer_dtype(self, input_x_dtype, weight_dtype):
-        valid_dtypes = (mstype.float16, mstype.float32)
-        args = {"input_x": input_x_dtype, "weight": weight_dtype}
-        if context.get_context("device_target") == "GPU":
-            validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
-        else:
-            validator.check_tensor_dtype_valid("input_x", input_x_dtype, valid_dtypes, self.name)
-            validator.check_tensor_dtype_valid("weight", weight_dtype, valid_dtypes, self.name)
-        return input_x_dtype
+        self.init_prim_io_names(inputs=['x', 'weight'], outputs=['output'])
 
 
 class LSTM(PrimitiveWithInfer):
New dynamic-shape test for PReLU:

@@ -0,0 +1,76 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore import Tensor, context
+from mindspore.ops import operations as P
+from mindspore.common import dtype as mstype
+
+
+class GetDynamicInputNet(nn.Cell):
+    def __init__(self, axis=0):
+        super(GetDynamicInputNet, self).__init__()
+        self.unique = P.Unique()
+        self.gather = P.Gather()
+        self.cast = P.Cast()
+        self.axis = axis
+
+    def construct(self, x, indices):
+        unique_indices, _ = self.unique(indices)
+        x_dtype = x.dtype
+        x = self.cast(x, mstype.float32)
+        real_x = self.gather(x, unique_indices, self.axis)
+        return self.cast(real_x, x_dtype)
+
+
+class PReLUDyNet(nn.Cell):
+    def __init__(self):
+        super(PReLUDyNet, self).__init__()
+        self.op = P.PReLU()
+        self.transformer = GetDynamicInputNet()
+
+    def construct(self, indices, x, weight):
+        real_x = self.transformer(x, indices)
+        out = self.op(real_x, weight)
+        return real_x, weight, out
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+@pytest.mark.parametrize("data_shape", [((8, 6, 7), (1,))])
+@pytest.mark.parametrize("data_type", [np.float16, np.float32])
+def test_dynamic_shape_prelu(data_shape, data_type):
+    """
+    Feature: PReLU DynamicShape.
+    Description: Test case of dynamic shape for PReLU operator.
+    Expectation: success.
+    """
+    context.set_context(mode=context.GRAPH_MODE)
+    x_shape, weight_shape = data_shape
+    x = Tensor(np.random.random(size=x_shape).astype(data_type))
+    weight = Tensor(np.random.random(size=weight_shape).astype(data_type))
+    indices = Tensor(np.random.randint(0, x_shape[0], size=(5,)).astype(np.int32))
+
+    dy_net = PReLUDyNet()
+    real_x, real_weight, output = dy_net(indices, x, weight)
+    x, weight = real_x.asnumpy(), real_weight.asnumpy()
+    expect = np.where(x >= 0, x, weight * x)
+
+    np.testing.assert_allclose(expect, output.asnumpy())
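The test obtains a dynamic first dimension indirectly by routing the input through Unique/Gather, whose output size is unknown at compile time. If the MindSpore version at hand supports declaring dynamic inputs directly (set_inputs with None-shaped Tensors; treat this as an assumption, not something this PR adds), an equivalent check could look like:

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore.ops import operations as P

class PReLUNet(nn.Cell):
    def __init__(self):
        super(PReLUNet, self).__init__()
        self.prelu = P.PReLU()

    def construct(self, x, weight):
        return self.prelu(x, weight)

net = PReLUNet()
# None marks the batch dimension as dynamic (assumed set_inputs API; see above).
net.set_inputs(ms.Tensor(shape=[None, 6, 7], dtype=ms.float32),
               ms.Tensor(shape=[1], dtype=ms.float32))
x = np.random.random((5, 6, 7)).astype(np.float32)
w = np.array([0.25], np.float32)
out = net(ms.Tensor(x), ms.Tensor(w))
assert np.allclose(out.asnumpy(), np.where(x >= 0, x, w * x))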