[feat] [assistant] [I54KH8] [I54KHA] add new operator Relu and ReluGrad

This commit is contained in:
NBUFabio 2022-05-28 15:02:59 +08:00
parent 221f51e2f5
commit 6c9f39c47e
13 changed files with 452 additions and 13 deletions

View File

@ -85,6 +85,7 @@ constexpr auto kPriorityReplayBufferCreate = "PriorityReplayBufferCreate";
constexpr auto kPriorityReplayBufferPush = "PriorityReplayBufferPush";
constexpr auto kPriorityReplayBufferSample = "PriorityReplayBufferSample";
constexpr auto kPriorityReplayBufferUpdate = "PriorityReplayBufferUpdate";
constexpr auto kReLUV3 = "ReLUV3";
constexpr auto kNonZero = "NonZero";
constexpr auto kMaxPoolV1 = "MaxPoolV1";
constexpr auto kMaxPoolGradV1 = "MaxPoolGradV1";
@ -110,6 +111,7 @@ const std::map<std::string, std::string> kOpNameToAicpuOpNameMap{
{kMaxPoolV1, "MaxPool"},
{kMaxPoolGradV1, "MaxPoolGrad"},
{kNameRangeV2, "Range"},
{kReLUV3, "Relu"},
{kStack, "Pack"},
{kUnstack, "Unpack"},
{kGather, "GatherV2"},

View File

@ -80,13 +80,15 @@ class EltWiseGradCpuTypeFunc : public CpuKernelFunc {
template <typename T>
void EltWiseGradCpuTypeFunc<T>::ReluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
if constexpr (!std::is_same<T, float>::value) {
MS_LOG(EXCEPTION) << "For 'ReLUGrad', the dtype of input must be float.";
}
int ret = ::ReluGrad(input1 + start, input2 + start, end - start, out + start);
if (ret == NNACL_ERR) {
MS_LOG(EXCEPTION) << "For 'ReLUGrad', execute failed. Error no: " << ret;
if constexpr (std::is_same<T, float>::value) {
int ret = ::ReluGrad(input1 + start, input2 + start, end - start, out + start);
if (ret == NNACL_ERR) {
MS_LOG(EXCEPTION) << "For 'ReLUGrad', execute failed. Error no: " << ret;
}
} else {
for (size_t i = start; i < end; i++) {
out[i] = (input2[i] > T(0)) ? input1[i] : static_cast<T>(0);
}
}
}
@ -111,6 +113,7 @@ void EltWiseGradCpuTypeFunc<T>::AbsGrad(const T *input1, const T *input2, T *out
}
} else {
for (size_t i = start; i < end; i++) {
// cppcheck-suppress unsignedLessThanZero
out[i] = (input1[i] < 0) ? -input2[i] : ((input1[i] > 0) ? input2[i] : 0);
}
}
@ -346,13 +349,24 @@ void EltWiseGradCpuTypeFunc<T>::InitFunc(const BaseOperatorPtr &base_operator, c
{prim::kPrimAsinhGrad->name(), &EltWiseGradCpuTypeFunc<T>::AsinhGrad},
{prim::kPrimInvGrad->name(), &EltWiseGradCpuTypeFunc<T>::InvGrad},
{prim::kPrimAcoshGrad->name(), &EltWiseGradCpuTypeFunc<T>::AcoshGrad},
{prim::kPrimAbsGrad->name(), &EltWiseGradCpuTypeFunc<T>::AbsGrad}};
{prim::kPrimAbsGrad->name(), &EltWiseGradCpuTypeFunc<T>::AbsGrad},
{prim::kPrimReluGrad->name(), &EltWiseGradCpuTypeFunc<T>::ReluGrad}};
if (elt_map.find(kernel_name_) == elt_map.end()) {
MS_LOG(EXCEPTION) << "For 'EltWiseGrad', it does not support " << kernel_name_ << " with double as input.";
}
compute_func_ = elt_map.at(kernel_name_);
return;
}
if constexpr (std::is_same_v<T, float16>) {
static const std::map<std::string,
std::function<void(EltWiseGradCpuTypeFunc *, const T *, const T *, T *, size_t, size_t)>>
elt_map{{prim::kPrimReluGrad->name(), &EltWiseGradCpuTypeFunc<T>::ReluGrad}};
if (elt_map.find(kernel_name_) == elt_map.end()) {
MS_LOG(EXCEPTION) << "EltWiseGradCpu does not support " << kernel_name_ << " with float as input.";
}
compute_func_ = elt_map.at(kernel_name_);
return;
}
if constexpr (std::is_same_v<T, float>) {
static const std::map<std::string,
std::function<void(EltWiseGradCpuTypeFunc *, const T *, const T *, T *, size_t, size_t)>>
@ -370,24 +384,37 @@ void EltWiseGradCpuTypeFunc<T>::InitFunc(const BaseOperatorPtr &base_operator, c
{prim::kPrimInvGrad->name(), &EltWiseGradCpuTypeFunc<T>::InvGrad},
{prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::RsqrtGrad},
{prim::kPrimAcoshGrad->name(), &EltWiseGradCpuTypeFunc<T>::AcoshGrad},
{prim::kPrimSoftplusGrad->name(), &EltWiseGradCpuTypeFunc<T>::SoftplusGrad}};
{prim::kPrimSoftplusGrad->name(), &EltWiseGradCpuTypeFunc<T>::SoftplusGrad},
{prim::kPrimReluGrad->name(), &EltWiseGradCpuTypeFunc<T>::ReluGrad}};
if (elt_map.find(kernel_name_) == elt_map.end()) {
MS_LOG(EXCEPTION) << "For 'EltWiseGrad', it does not support " << kernel_name_ << " with float as input.";
}
compute_func_ = elt_map.at(kernel_name_);
return;
}
if constexpr ((std::is_same_v<T, int>) || (std::is_same_v<T, int8_t>)) {
if constexpr ((std::is_same_v<T, int>) || (std::is_same_v<T, int8_t>) || (std::is_same_v<T, int16_t>) ||
(std::is_same_v<T, int64_t>)) {
static const std::map<std::string,
std::function<void(EltWiseGradCpuTypeFunc *, const T *, const T *, T *, size_t, size_t)>>
elt_map{{prim::kPrimAbsGrad->name(), &EltWiseGradCpuTypeFunc<T>::AbsGrad},
{prim::kPrimInvGrad->name(), &EltWiseGradCpuTypeFunc<T>::InvGrad},
{prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::RsqrtGrad}};
{prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::RsqrtGrad},
{prim::kPrimReluGrad->name(), &EltWiseGradCpuTypeFunc<T>::ReluGrad}};
if (elt_map.find(kernel_name_) == elt_map.end()) {
MS_LOG(EXCEPTION) << "For 'EltWiseGrad', it does not support " << kernel_name_ << " with int as input.";
}
compute_func_ = elt_map.at(kernel_name_);
}
if constexpr ((std::is_same_v<T, uint8_t>) || (std::is_same_v<T, uint16_t>)) {
static const std::map<std::string,
std::function<void(EltWiseGradCpuTypeFunc *, const T *, const T *, T *, size_t, size_t)>>
elt_map{{prim::kPrimReluGrad->name(), &EltWiseGradCpuTypeFunc<T>::ReluGrad}};
if (elt_map.find(kernel_name_) == elt_map.end()) {
MS_LOG(EXCEPTION) << "EltWiseGradCpu does not support " << kernel_name_ << " with uint as input.";
}
compute_func_ = elt_map.at(kernel_name_);
return;
}
if constexpr ((std::is_same_v<T, complex64>) || (std::is_same_v<T, complex128>)) {
static const std::map<std::string,
std::function<void(EltWiseGradCpuTypeFunc *, const T *, const T *, T *, size_t, size_t)>>
@ -422,8 +449,24 @@ std::shared_ptr<CpuKernelFunc> SpecializeEltWiseGradFunc() {
using FuncCreator = std::function<std::shared_ptr<CpuKernelFunc>()>;
static std::map<std::string, std::vector<std::pair<KernelAttr, FuncCreator>>> kernel_attr_list_map = {
{kReluGrad,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&SpecializeEltWiseGradFunc<float>}}},
{{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
&SpecializeEltWiseGradFunc<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
&SpecializeEltWiseGradFunc<int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
&SpecializeEltWiseGradFunc<int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
&SpecializeEltWiseGradFunc<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
&SpecializeEltWiseGradFunc<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
&SpecializeEltWiseGradFunc<uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&SpecializeEltWiseGradFunc<float16>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&SpecializeEltWiseGradFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&SpecializeEltWiseGradFunc<double>}}},
{kReLU6Grad,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&SpecializeEltWiseGradFunc<float>}}},

View File

@ -0,0 +1,95 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/relu_v3_cpu_kernel.h"
#include <algorithm>
#include <functional>
#include "mindspore/core/ops/relu_v3.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "utils/ms_utils.h"
namespace mindspore::kernel {
constexpr auto kReLUV3 = "ReLUV3";
constexpr const size_t kReLUV3InputsNum = 1;
constexpr const size_t kReLUV3OutputsNum = 1;
template <typename T>
bool ReLUV3CpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<AddressPtr> &,
                                      const std::vector<kernel::AddressPtr> &outputs) {
  // ReLU forward: y[i] = max(x[i], 0), computed over parallel chunks.
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kReLUV3InputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kReLUV3OutputsNum, kernel_name_);
  auto *x = reinterpret_cast<T *>(inputs[kIndex0]->addr);
  MS_ERROR_IF_NULL_W_RET_VAL(x, false);
  auto *y = reinterpret_cast<T *>(outputs[kIndex0]->addr);
  MS_ERROR_IF_NULL_W_RET_VAL(y, false);
  // Element count derived from the output buffer size; fall back to 1 when
  // the reported size is 0 (matches the original behavior).
  size_t elem_num = 1;
  if (outputs[0]->size > 0) {
    elem_num = static_cast<size_t>(outputs[0]->size / sizeof(T));
  }
  auto relu_span = [x, y](size_t begin, size_t finish) {
    const T zero = static_cast<T>(0);
    for (size_t i = begin; i < finish; ++i) {
      y[i] = (x[i] > zero) ? x[i] : zero;
    }
  };
  ParallelLaunchAutoSearch(relu_span, elem_num, this, &parallel_search_info_);
  return true;
}
// Static dispatch table: maps each supported dtype signature (one input, one
// output, same dtype) to the matching LaunchKernel<T> instantiation.
// Supported dtypes here: float16/32/64, int8/16/32/64, uint8/16.
const std::vector<std::pair<KernelAttr, ReLUV3CpuKernelMod::KernelRunFunc>> &ReLUV3CpuKernelMod::GetFuncList() const {
  static const std::vector<std::pair<KernelAttr, ReLUV3CpuKernelMod::KernelRunFunc>> func_list = {
    {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
     &ReLUV3CpuKernelMod::LaunchKernel<float16>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
     &ReLUV3CpuKernelMod::LaunchKernel<float>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
     &ReLUV3CpuKernelMod::LaunchKernel<double>},
    {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
     &ReLUV3CpuKernelMod::LaunchKernel<int8_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
     &ReLUV3CpuKernelMod::LaunchKernel<int16_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
     &ReLUV3CpuKernelMod::LaunchKernel<int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
     &ReLUV3CpuKernelMod::LaunchKernel<int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
     &ReLUV3CpuKernelMod::LaunchKernel<uint8_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
     &ReLUV3CpuKernelMod::LaunchKernel<uint16_t>},
  };
  return func_list;
}
// Initializes the kernel: resolves the operator name, validates the I/O
// arity, and binds kernel_func_ to the dtype-matched LaunchKernel overload.
// Returns false (with an error log) on any validation failure.
bool ReLUV3CpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                              const std::vector<KernelTensorPtr> &outputs) {
  auto relu_op = std::dynamic_pointer_cast<ops::ReLUV3>(base_operator);
  MS_ERROR_IF_NULL_W_RET_VAL(relu_op, false);
  kernel_name_ = relu_op->name();
  const bool io_sizes_ok = (inputs.size() == kReLUV3InputsNum) && (outputs.size() == kReLUV3OutputsNum);
  if (!io_sizes_ok) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', input and output size must be " << kReLUV3InputsNum << " and "
                  << kReLUV3OutputsNum << ", but got " << inputs.size() << " and " << outputs.size();
    return false;
  }
  // MatchKernelFunc selects the LaunchKernel<T> entry from GetFuncList().
  return MatchKernelFunc(base_operator, inputs, outputs);
}
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ReLUV3,
[]() { return std::make_shared<ReLUV3CpuKernelMod>(kReLUV3); });
} // namespace mindspore::kernel

View File

@ -0,0 +1,59 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RELU_V3_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RELU_V3_CPU_KERNEL_H_
#include <memory>
#include <string>
#include <vector>
#include <map>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore::kernel {
constexpr auto kUnknown = "Unknown";
// CPU kernel for the ReLUV3 operator (element-wise rectified linear unit).
// Uses MatchKernelHelper to dispatch to a dtype-specific LaunchKernel<T>.
class ReLUV3CpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<ReLUV3CpuKernelMod> {
 public:
  ReLUV3CpuKernelMod() = default;
  explicit ReLUV3CpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
  ~ReLUV3CpuKernelMod() override = default;

  // Validates I/O arity and binds kernel_func_ to the dtype-matched launcher.
  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  // Forwards execution to the function selected during Init().
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override {
    return kernel_func_(this, inputs, workspace, outputs);
  }

  // Table of supported dtype signatures and their launch functions.
  const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;

 protected:
  std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }

 private:
  // Computes y[i] = max(x[i], 0) over the output buffer for element type T.
  template <typename T>
  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<AddressPtr> &,
                    const std::vector<kernel::AddressPtr> &outputs);
  // Kernel type label; defaults to kUnknown until set via the ctor.
  std::string kernel_type_{kUnknown};
};
} // namespace mindspore::kernel
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RELU_V3_CPU_KERNEL_H_

View File

@ -117,6 +117,7 @@ constexpr auto kGLU = "GLU";
constexpr auto kReLU = "ReLU";
constexpr auto kReLU6 = "ReLU6";
constexpr auto kReLUV2 = "ReLUV2";
constexpr auto kReLUV3 = "ReLUV3";
constexpr auto kReLUGrad = "ReluGrad";
constexpr auto kReLUGradV2 = "ReluGradV2";
constexpr auto kRint = "Rint";
@ -595,6 +596,7 @@ GVAR_DEF(PrimitivePtr, kPrimElu, std::make_shared<Primitive>("Elu"));
GVAR_DEF(PrimitivePtr, kPrimEluGrad, std::make_shared<Primitive>("EluGrad"));
GVAR_DEF(PrimitivePtr, kPrimRelu6, std::make_shared<Primitive>(kReLU6));
GVAR_DEF(PrimitivePtr, kPrimReluV2, std::make_shared<Primitive>(kReLUV2));
GVAR_DEF(PrimitivePtr, kPrimReluV3, std::make_shared<Primitive>(kReLUV3));
GVAR_DEF(PrimitivePtr, kPrimPRelu, std::make_shared<Primitive>("PReLU"));
GVAR_DEF(PrimitivePtr, kPrimSelu, std::make_shared<Primitive>("SeLU"));
GVAR_DEF(PrimitivePtr, kPrimSoftplus, std::make_shared<Primitive>("Softplus"));

View File

@ -0,0 +1,56 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/relu_v3.h"
#include <string>
#include <algorithm>
#include <map>
#include <set>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
// ReLU is element-wise, so the output shape is exactly the input shape.
abstract::ShapePtr ReLUV3InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto built_shape = input_args[0]->BuildShape();
  auto shape_ptr = built_shape->cast<abstract::ShapePtr>();
  MS_EXCEPTION_IF_NULL(shape_ptr);
  return shape_ptr;
}
// Output dtype mirrors the input dtype; CheckTensorTypeValid rejects any
// dtype outside common_valid_types.
TypePtr ReLUV3InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  const auto op_name = prim->name();
  auto input_type = input_args[0]->BuildType();
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_type, common_valid_types, op_name);
  return input_type;
}
} // namespace
MIND_API_OPERATOR_IMPL(ReLUV3, BaseOperator);
// Full inference entry point: validates that exactly one input is given,
// then combines the inferred dtype and shape into an abstract value.
AbstractBasePtr ReLUV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                            const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  constexpr int64_t kReLUV3InputNum = 1;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kReLUV3InputNum, primitive->name());
  auto inferred_type = ReLUV3InferType(primitive, input_args);
  auto inferred_shape = ReLUV3InferShape(primitive, input_args);
  return abstract::MakeAbstract(inferred_shape, inferred_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(ReLUV3, prim::kPrimReluV3, ReLUV3Infer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_RELU_V3_H_
#define MINDSPORE_CORE_OPS_RELU_V3_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameReLUV3 = "ReLUV3";
/// \brief Computes ReLUV3 (Rectified Linear Unit activation function) of input tensors element-wise.
/// Refer to Python API @ref mindspore.ops.ReLUV3 for more details.
/// \brief Computes ReLUV3 (Rectified Linear Unit activation function) of input tensors element-wise.
/// Refer to Python API @ref mindspore.ops.ReLUV3 for more details.
class MIND_API ReLUV3 : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(ReLUV3);
  /// \brief Constructor. Registers input name "x" and output name "output".
  ReLUV3() : BaseOperator(kNameReLUV3) { InitIOName({"x"}, {"output"}); }
  /// \brief Init. ReLUV3 has no attributes, so this is a no-op.
  void Init() const {}
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_RELU_V3_H_

View File

@ -36,6 +36,8 @@ from ..operations.nn_ops import AvgPoolV1
from ..operations._grad_ops import AvgPoolGradV1
from ..operations.nn_ops import MaxPoolV1
from ..operations._grad_ops import MaxPoolGradV1
from ..operations.nn_ops import ReLUV3
from ..operations._grad_ops import ReluGrad
@bprop_getters.register(P.CTCLossV2)
@ -114,6 +116,18 @@ def get_bprop_grid_sampler_3d(self):
return bprop
@bprop_getters.register(ReLUV3)
def get_bprop_relu(self):
    """Grad definition for `ReLUV3` operation."""
    relu_grad = ReluGrad()

    def bprop(x, out, dout):
        # dx = dout wherever the forward output `out` was positive, else 0.
        return (relu_grad(dout, out),)

    return bprop
@bprop_getters.register(NthElement)
def get_bprop_nth_element(self):
"""Grad definition for `NthElement` operation."""

View File

@ -144,6 +144,8 @@ from .upper_bound import _upper_bound_aicpu
from .zeros_like import _zeros_like_aicpu
from .ones_like import _ones_like_aicpu
from .concat import _concat_aicpu
from .relu_v3 import _relu_v3_aicpu
from .relu_grad_v3 import _relu_grad_v3_aicpu
from .grid_sampler_3d import _grid_sampler_3d_aicpu
from .atanh import _atanh_aicpu
from .grid_sampler_3d_grad import _grid_sampler_3d_grad_aicpu

View File

@ -0,0 +1,41 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReluGradV3 op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
# AiCPU registration info for the "ReluGrad" kernel: two required inputs
# (gradient x1 and forward output x2), one output y, all three sharing the
# same dtype for each registered combination.
relu_grad_v3_op_info = AiCPURegOp("ReluGrad") \
    .fusion_type("OPAQUE") \
    .input(0, "x1", "required") \
    .input(1, "x2", "required") \
    .output(0, "y", "required") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
    .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
    .get_op_info()


@op_info_register(relu_grad_v3_op_info)
def _relu_grad_v3_aicpu():
    """ReluGradV3 aicpu register"""
    return

View File

@ -0,0 +1,38 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReLUV3 op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
# AiCPU registration info for the "ReLUV3" kernel: one required input x and
# one output y, with matching dtypes for each registered combination.
relu_v3_op_info = AiCPURegOp("ReLUV3") \
    .fusion_type("OPAQUE") \
    .input(0, "x", "required") \
    .output(0, "y", "required") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.U16_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.F64_Default) \
    .get_op_info()


@op_info_register(relu_v3_op_info)
def _relu_v3_aicpu():
    """ReLUV3 AiCPU register"""
    return

View File

@ -494,6 +494,46 @@ class ReLU(Primitive):
self.init_prim_io_names(inputs=['x'], outputs=['output'])
class ReLUV3(Primitive):
    r"""
    Computes the Rectified Linear Unit activation of the input tensor
    element-wise: :math:`ReLUV3(x) = \max(0, x)`. Negative elements are
    replaced with zero; non-negative elements pass through unchanged.

    Inputs:
        - **input_x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
          additional dimensions, data type is
          `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_.

    Outputs:
        Tensor of shape :math:`(N, *)`, with the same type and shape as the `input_x`.

    Raises:
        TypeError: If `input_x` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
        >>> relu_v3 = ops.ReLUV3()
        >>> output = relu_v3(input_x)
        >>> print(output)
        [[0. 4. 0.]
         [2. 0. 9.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize ReLUV3"""
        self.init_prim_io_names(inputs=['x'], outputs=['output'])
class Mish(PrimitiveWithInfer):
r"""
Computes MISH(A Self Regularized Non-Monotonic Neural Activation Function) of input tensors element-wise.

View File

@ -63,6 +63,7 @@ from mindspore.ops.operations._grad_ops import AvgPoolGradV1
from mindspore.ops.operations.nn_ops import MaxPoolV1
from mindspore.ops.operations.array_ops import NonZero
from mindspore.ops.operations._grad_ops import MaxPoolGradV1
from mindspore.ops.operations.nn_ops import ReLUV3
from mindspore.ops.operations.sparse_ops import DenseToCSRSparseMatrix
from mindspore.nn.layer import normalization
from mindspore.ops.operations.array_ops import RightShift
@ -2084,6 +2085,10 @@ test_case_nn_ops = [
'block': P.ReLUV2(),
'desc_inputs': [[1, 3, 4, 4]],
'desc_bprop': [[1, 3, 4, 4], ([1, 1, 4, 4, 2], {'dtype': np.uint8})]}),
('ReLUV3', {
'block': ReLUV3(),
'desc_inputs': [[1, 3, 4, 7, 9]],
'desc_bprop': [[1, 3, 4, 7, 9]]}),
('ReLUGrad', {
'block': G.ReluGrad(),
'desc_inputs': [[1, 3, 4, 4], [1, 3, 4, 4]],