[I48O4Z] add new math operator Reciprocal

YangShuq 2022-09-27 15:13:32 +08:00 committed by jinduo
parent a4d62ed336
commit daef90a276
9 changed files with 129 additions and 25 deletions

View File

@@ -207,4 +207,5 @@ mindspore/mindspore/lite/src/litert/kernel/cpu/control/tensorlist_setitem.cc:min
 mindspore/mindspore/python/mindspore/ops/_utils/utils.py:get_broadcast_shape
 mindspore/mindspore/ccsrc/pybind_api/ir/dtype_py.cc:mindspore::RegTyping
 mindspore/mindspore/ccsrc/pybind_api/ir/tensor_py.cc:mindspore::tensor::RegMetaTensor
+mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/eltwise_grad_cpu_kernel.cc:mindspore::kernel::EltWiseGradCpuTypeFunc<T>::InitFunc
 mindspore/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc:mindspore::lite::quant::WeightQuantizer::LinearQuant

View File

@@ -723,6 +723,7 @@ void ArithmeticSelfCpuKernelFunc::LaunchKernelComplex(const std::vector<AddressP
     {prim::kPrimCos->name(), ComplexCos<T>},
     {prim::kPrimACos->name(), ComplexACos<T>},
     {prim::kPrimRsqrt->name(), Rsqrt<T>},
+    {prim::kPrimReciprocal->name(), Reciprocal<T>},
     {prim::kPrimSqrt->name(), Sqrt<T>},
     {prim::kPrimTan->name(), Tan<T>},
     {prim::kPrimTanh->name(), Tanh<T>},
@@ -913,7 +914,10 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithFuncCreator>
    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
   {kReciprocal,
    {{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateArithSelfFunc},
+    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CreateArithSelfFunc},
    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
+    {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},
+    {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), CreateArithSelfFunc},
    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
   {kInv,
    {{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateArithSelfFunc},
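
A minimal usage sketch for the widened CPU registration above (assuming a CPU build of this branch; int64 and the complex types are the dtypes this hunk adds):

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

ms.context.set_context(device_target="CPU")

reciprocal = ops.Reciprocal()
print(reciprocal(Tensor(np.array([1.0, 2.0, 4.0], np.float32))))  # [1.  0.5 0.25]
print(reciprocal(Tensor(np.array([5, 10], np.int64))))            # int64 path added here
print(reciprocal(Tensor(np.array([1 + 1j], np.complex64))))       # complex path added here
```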

View File

@@ -43,6 +43,7 @@ constexpr auto kInvGrad = "InvGrad";
 constexpr auto kAcoshGrad = "AcoshGrad";
 constexpr auto kSoftplusGrad = "SoftplusGrad";
 constexpr auto kRsqrtGrad = "RsqrtGrad";
+constexpr auto kReciprocalGrad = "ReciprocalGrad";
 
 template <typename T>
 class EltWiseGradCpuTypeFunc : public CpuKernelFunc {
@@ -61,6 +62,7 @@ class EltWiseGradCpuTypeFunc : public CpuKernelFunc {
   void SigmoidGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
   void SqrtGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
   void RsqrtGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
+  void ReciprocalGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
   void TanhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
   void GeluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
   void AsinGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
@@ -163,6 +165,20 @@ void EltWiseGradCpuTypeFunc<T>::RsqrtGrad(const T *input1, const T *input2, T *o
   }
 }
 
+template <typename T>
+void EltWiseGradCpuTypeFunc<T>::ReciprocalGrad(const T *input1, const T *input2, T *out, size_t start,
+                                               size_t end) const {
+  if constexpr ((std::is_same_v<T, complex64>) || (std::is_same_v<T, complex128>)) {
+    for (size_t i = start; i < end; i++) {
+      out[i] = static_cast<T>(-1) * conj(input1[i] * input1[i]) * input2[i];
+    }
+  } else {
+    for (size_t i = start; i < end; i++) {
+      out[i] = static_cast<T>(-1) * input1[i] * input1[i] * input2[i];
+    }
+  }
+}
+
 template <typename T>
 void EltWiseGradCpuTypeFunc<T>::TanhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
   if constexpr (std::is_same<T, float>::value) {
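
For reference, the function above is the chain rule written in terms of the forward output: ReciprocalGrad takes input1 = y = 1/x and input2 = dy (the upstream gradient), so:

```latex
% Real types: dx = dy * d(1/x)/dx = -dy / x^2 = -y^2 * dy
\frac{\partial L}{\partial x} = -\,y^{2}\,\frac{\partial L}{\partial y}
% Complex types use the conjugate form, matching the conj() call in the kernel:
\frac{\partial L}{\partial x} = -\,\overline{y^{2}}\,\frac{\partial L}{\partial y}
```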
@@ -417,6 +433,7 @@ void EltWiseGradCpuTypeFunc<T>::InitFunc(const BaseOperatorPtr &base_operator, c
     {prim::kPrimAsinGrad->name(), &EltWiseGradCpuTypeFunc<T>::AsinGrad},
     {prim::kPrimACosGrad->name(), &EltWiseGradCpuTypeFunc<T>::ACosGrad},
     {prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::RsqrtGrad},
+    {prim::kPrimReciprocalGrad->name(), &EltWiseGradCpuTypeFunc<T>::ReciprocalGrad},
     {prim::kPrimAtanGrad->name(), &EltWiseGradCpuTypeFunc<T>::AtanGrad},
     {prim::kPrimTanhGrad->name(), &EltWiseGradCpuTypeFunc<T>::TanhGrad},
     {prim::kPrimAsinhGrad->name(), &EltWiseGradCpuTypeFunc<T>::AsinhGrad},
@@ -434,6 +451,7 @@ void EltWiseGradCpuTypeFunc<T>::InitFunc(const BaseOperatorPtr &base_operator, c
     static const std::map<std::string,
                           std::function<void(EltWiseGradCpuTypeFunc *, const T *, const T *, T *, size_t, size_t)>>
       elt_map{{prim::kPrimReluGrad->name(), &EltWiseGradCpuTypeFunc<T>::ReluGrad},
+              {prim::kPrimReciprocalGrad->name(), &EltWiseGradCpuTypeFunc<T>::ReciprocalGrad},
               {prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::RsqrtGrad}};
     if (elt_map.find(kernel_name_) == elt_map.end()) {
       MS_LOG(EXCEPTION) << "EltWiseGradCpu does not support " << kernel_name_ << " with float as input.";
@@ -457,6 +475,7 @@ void EltWiseGradCpuTypeFunc<T>::InitFunc(const BaseOperatorPtr &base_operator, c
     {prim::kPrimAsinhGrad->name(), &EltWiseGradCpuTypeFunc<T>::AsinhGrad},
     {prim::kPrimInvGrad->name(), &EltWiseGradCpuTypeFunc<T>::InvGrad},
     {prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::RsqrtGrad},
+    {prim::kPrimReciprocalGrad->name(), &EltWiseGradCpuTypeFunc<T>::ReciprocalGrad},
     {prim::kPrimAcoshGrad->name(), &EltWiseGradCpuTypeFunc<T>::AcoshGrad},
     {prim::kPrimSoftplusGrad->name(), &EltWiseGradCpuTypeFunc<T>::SoftplusGrad},
     {prim::kPrimReluGrad->name(), &EltWiseGradCpuTypeFunc<T>::ReluGrad}};
@@ -499,6 +518,7 @@ void EltWiseGradCpuTypeFunc<T>::InitFunc(const BaseOperatorPtr &base_operator, c
     {prim::kPrimTanhGrad->name(), &EltWiseGradCpuTypeFunc<T>::TanhGrad},
     {prim::kPrimInvGrad->name(), &EltWiseGradCpuTypeFunc<T>::InvGrad},
     {prim::kPrimSqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::SqrtGrad},
+    {prim::kPrimReciprocalGrad->name(), &EltWiseGradCpuTypeFunc<T>::ReciprocalGrad},
     {prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::RsqrtGrad}};
   if (elt_map.find(kernel_name_) == elt_map.end()) {
     MS_LOG(EXCEPTION) << "For 'EltWiseGrad', it does not support " << kernel_name_;
@@ -705,7 +725,24 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, FuncCreator>>> ke
        .AddInputAttr(kNumberTypeComplex64)
        .AddInputAttr(kNumberTypeComplex64)
        .AddOutputAttr(kNumberTypeComplex64),
-      &SpecializeEltWiseGradFunc<complex64>}}}};
+      &SpecializeEltWiseGradFunc<complex64>}}},
+  {kReciprocalGrad,
+   {{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+     &SpecializeEltWiseGradFunc<float16>},
+    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+     &SpecializeEltWiseGradFunc<float>},
+    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+     &SpecializeEltWiseGradFunc<double>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeComplex64)
+       .AddInputAttr(kNumberTypeComplex64)
+       .AddOutputAttr(kNumberTypeComplex64),
+     &SpecializeEltWiseGradFunc<complex64>},
+    {KernelAttr()
+       .AddInputAttr(kNumberTypeComplex128)
+       .AddInputAttr(kNumberTypeComplex128)
+       .AddOutputAttr(kNumberTypeComplex128),
+     &SpecializeEltWiseGradFunc<complex128>}}}};
 }  // namespace
 
 bool EltWiseGradCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
@@ -790,5 +827,7 @@ MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, SoftplusGrad,
                                  []() { return std::make_shared<EltWiseGradCpuKernelMod>(kSoftplusGrad); });
 MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, RsqrtGrad,
                                  []() { return std::make_shared<EltWiseGradCpuKernelMod>(kRsqrtGrad); });
+MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ReciprocalGrad,
+                                 []() { return std::make_shared<EltWiseGradCpuKernelMod>(kReciprocalGrad); });
 }  // namespace kernel
 }  // namespace mindspore
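
A quick numerical check of the new kernel through autodiff, a sketch assuming a CPU build of this branch (ops.GradOperation is stable MindSpore API):

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

ms.context.set_context(mode=ms.context.PYNATIVE_MODE, device_target="CPU")

def f(x):
    return ops.Reciprocal()(x)

grad_fn = ops.GradOperation()(f)   # backward goes through ReciprocalGrad
x = np.array([0.5, 2.0, -3.0], dtype=np.float32)
print(grad_fn(Tensor(x)))          # kernel result
print(-1.0 / x ** 2)               # analytic reference
```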

View File

@@ -39,7 +39,7 @@ abstract::ShapePtr ReciprocalGradInferShape(const PrimitivePtr &primitive,
 }
 
 TypePtr ReciprocalGradInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
-  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
   MS_EXCEPTION_IF_NULL(primitive);
   MS_EXCEPTION_IF_NULL(input_args[0]);
   auto x_type = input_args[0]->BuildType();
@@ -55,8 +55,8 @@ AbstractBasePtr ReciprocalGradInfer(const abstract::AnalysisEnginePtr &, const P
   MS_EXCEPTION_IF_NULL(primitive);
   const int64_t input_num = 2;
   CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
-  auto shape = ReciprocalGradInferShape(primitive, input_args);
   auto type = ReciprocalGradInferType(primitive, input_args);
+  auto shape = ReciprocalGradInferShape(primitive, input_args);
   return abstract::MakeAbstract(shape, type);
 }
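
The grad primitive can also be driven directly; this goes through internal API (mindspore.ops.operations._grad_ops), so treat the sketch as illustrative only:

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G

ms.context.set_context(device_target="CPU")

y = Tensor(np.array([1.0, 0.5, 0.25], np.float32))  # forward output y = 1/x
dy = Tensor(np.ones(3, np.float32))                 # upstream gradient
print(G.ReciprocalGrad()(y, dy))                    # expect -y**2 * dy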

View File

@@ -30,11 +30,6 @@ namespace ops {
 namespace {
 abstract::ShapePtr ReciprocalInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto prim_name = primitive->name();
-  (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 1, prim_name);
-  for (const auto &item : input_args) {
-    MS_EXCEPTION_IF_NULL(item);
-  }
   auto x = input_args[0]->BuildShape();
   MS_EXCEPTION_IF_NULL(x);
   auto shape_ptr = x->cast<abstract::ShapePtr>();
@@ -45,24 +40,26 @@ abstract::ShapePtr ReciprocalInferShape(const PrimitivePtr &primitive, const std
 TypePtr ReciprocalInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(prim);
   auto op_name = prim->name();
-  (void)CheckAndConvertUtils::CheckInteger("input number", int64_t(input_args.size()), kEqual, 1, op_name);
-  for (const auto &item : input_args) {
-    MS_EXCEPTION_IF_NULL(item);
-  }
   std::map<std::string, TypePtr> types;
   (void)types.emplace("x", input_args[0]->BuildType());
   std::set<TypePtr> valid_params_types = {kTensorType};
-  (void)CheckAndConvertUtils::CheckSubClass("x_type", input_args[0]->BuildType(), valid_params_types, op_name);
-  return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
+  const std::set<TypePtr> common_valid_types = {kInt32, kInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
+  (void)CheckAndConvertUtils::CheckSubClass("x", input_args[0]->BuildType(), valid_params_types, op_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
+  return input_args[0]->BuildType();
 }
 }  // namespace
 
 MIND_API_OPERATOR_IMPL(Reciprocal, BaseOperator);
 AbstractBasePtr ReciprocalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const std::vector<AbstractBasePtr> &input_args) {
-  return abstract::MakeAbstract(ReciprocalInferShape(primitive, input_args),
-                                ReciprocalInferType(primitive, input_args));
+  MS_EXCEPTION_IF_NULL(primitive);
+  const int64_t kInputsNum = 1;
+  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
+  auto infer_type = ReciprocalInferType(primitive, input_args);
+  auto infer_shape = ReciprocalInferShape(primitive, input_args);
+  return abstract::MakeAbstract(infer_shape, infer_type);
 }
-REGISTER_PRIMITIVE_C(kNameReciprocal, Reciprocal);
+REGISTER_PRIMITIVE_EVAL_IMPL(Reciprocal, prim::kPrimReciprocal, ReciprocalInfer, nullptr, true);
 }  // namespace ops
 }  // namespace mindspore
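
With the whitelist above, dtypes outside {int32, int64, float16, float32, float64, complex64, complex128} should be rejected during type inference; a negative-test sketch (the exact exception type is an assumption):

```python
import numpy as np
from mindspore import Tensor, ops

try:
    ops.Reciprocal()(Tensor(np.array([1, 2], np.int8)))  # int8 is not whitelisted
except TypeError as err:
    print(err)
```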

View File

@@ -0,0 +1,34 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Reciprocal op"""
+
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+
+reciprocal_op_info = AiCPURegOp("Reciprocal") \
+    .fusion_type("OPAQUE") \
+    .input(0, "x", "required") \
+    .output(0, "y", "required") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F64_Default, DataType.F64_Default) \
+    .dtype_format(DataType.C64_Default, DataType.C64_Default) \
+    .dtype_format(DataType.C128_Default, DataType.C128_Default) \
+    .get_op_info()
+
+
+@op_info_register(reciprocal_op_info)
+def _reciprocal_aicpu():
+    """Reciprocal AiCPU register"""
+    return

View File

@@ -0,0 +1,35 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""ReciprocalGrad op"""
+
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+
+reciprocal_grad_op_info = AiCPURegOp("ReciprocalGrad") \
+    .fusion_type("OPAQUE") \
+    .input(0, "y", "required") \
+    .input(1, "dy", "required") \
+    .output(0, "z", "required") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
+    .dtype_format(DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \
+    .dtype_format(DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \
+    .get_op_info()
+
+
+@op_info_register(reciprocal_grad_op_info)
+def _reciprocal_grad_aicpu():
+    """ReciprocalGrad AiCPU register"""
+    return

View File

@@ -104,6 +104,7 @@ class ReciprocalGrad(Primitive):
     @prim_attr_register
     def __init__(self):
         """Initialize ReciprocalGrad"""
+        self.init_prim_io_names(inputs=['y', 'dy'], outputs=['z'])
 
 
 class RsqrtGrad(Primitive):

View File

@@ -2192,7 +2192,7 @@ class Sqrt(Primitive):
         self.init_prim_io_names(inputs=['x'], outputs=['output'])
 
 
-class Reciprocal(PrimitiveWithInfer):
+class Reciprocal(PrimitiveWithCheck):
     r"""
     Returns reciprocal of a tensor element-wise.
 
@@ -2230,13 +2230,6 @@ class Reciprocal(PrimitiveWithInfer):
         self.target = "OTHER"
         self.init_prim_io_names(inputs=['x'], outputs=['y'])
 
-    def infer_shape(self, x):
-        return x
-
-    def infer_dtype(self, x):
-        validator.check_subclass("x", x, mstype.tensor, self.name)
-        return x
-
     def infer_value(self, x):
         if x is not None:
             x = x.asnumpy()
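
The two hunks above migrate Reciprocal from PrimitiveWithInfer to PrimitiveWithCheck: per-op shape/dtype inference now comes from the C++ ReciprocalInfer registered via REGISTER_PRIMITIVE_EVAL_IMPL, and Python keeps only the infer_value constant-folding hook. A sketch of the pattern, with ReciprocalLike as a hypothetical stand-in:

```python
import numpy as np
from mindspore import Tensor
from mindspore.ops import PrimitiveWithCheck, prim_attr_register

class ReciprocalLike(PrimitiveWithCheck):
    """Hypothetical op: shape/dtype inference is delegated to a C++ eval impl."""

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['x'], outputs=['y'])

    def infer_value(self, x):
        # Constant-folding hook: fires only when x is a compile-time constant.
        if x is not None:
            return Tensor(np.reciprocal(x.asnumpy()))
        return None
```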