diff --git a/.jenkins/check/config/whitelizard.txt b/.jenkins/check/config/whitelizard.txt index 51795ba665c..f73f564ea1a 100644 --- a/.jenkins/check/config/whitelizard.txt +++ b/.jenkins/check/config/whitelizard.txt @@ -207,4 +207,5 @@ mindspore/mindspore/lite/src/litert/kernel/cpu/control/tensorlist_setitem.cc:min mindspore/mindspore/python/mindspore/ops/_utils/utils.py:get_broadcast_shape mindspore/mindspore/ccsrc/pybind_api/ir/dtype_py.cc:mindspore::RegTyping mindspore/mindspore/ccsrc/pybind_api/ir/tensor_py.cc:mindspore::tensor::RegMetaTensor +mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/eltwise_grad_cpu_kernel.cc:mindspore::kernel::EltWiseGradCpuTypeFunc::InitFunc mindspore/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc:mindspore::lite::quant::WeightQuantizer::LinearQuant diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_self_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_self_cpu_kernel.cc index 63e4f9780de..d708e06434a 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_self_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_self_cpu_kernel.cc @@ -723,6 +723,7 @@ void ArithmeticSelfCpuKernelFunc::LaunchKernelComplex(const std::vectorname(), ComplexCos}, {prim::kPrimACos->name(), ComplexACos}, {prim::kPrimRsqrt->name(), Rsqrt}, + {prim::kPrimReciprocal->name(), Reciprocal}, {prim::kPrimSqrt->name(), Sqrt}, {prim::kPrimTan->name(), Tan}, {prim::kPrimTanh->name(), Tanh}, @@ -913,7 +914,10 @@ static std::map {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}}, {kReciprocal, {{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateArithSelfFunc}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CreateArithSelfFunc}, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc}, + 
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc}, + {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), CreateArithSelfFunc}, {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}}, {kInv, {{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateArithSelfFunc}, diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/eltwise_grad_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/eltwise_grad_cpu_kernel.cc index 22f7c983ce3..5c122997b26 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/eltwise_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/eltwise_grad_cpu_kernel.cc @@ -43,6 +43,7 @@ constexpr auto kInvGrad = "InvGrad"; constexpr auto kAcoshGrad = "AcoshGrad"; constexpr auto kSoftplusGrad = "SoftplusGrad"; constexpr auto kRsqrtGrad = "RsqrtGrad"; +constexpr auto kReciprocalGrad = "ReciprocalGrad"; template class EltWiseGradCpuTypeFunc : public CpuKernelFunc { @@ -61,6 +62,7 @@ class EltWiseGradCpuTypeFunc : public CpuKernelFunc { void SigmoidGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const; void SqrtGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const; void RsqrtGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const; + void ReciprocalGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const; void TanhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const; void GeluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const; void AsinGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const; @@ -163,6 +165,20 @@ void EltWiseGradCpuTypeFunc::RsqrtGrad(const T *input1, const T *input2, T *o } } +template +void EltWiseGradCpuTypeFunc::ReciprocalGrad(const T *input1, const T *input2, T *out, size_t start, + size_t 
end) const { + if constexpr ((std::is_same_v) || (std::is_same_v)) { + for (size_t i = start; i < end; i++) { + out[i] = static_cast(-1) * conj(input1[i] * input1[i]) * input2[i]; + } + } else { + for (size_t i = start; i < end; i++) { + out[i] = static_cast(-1) * input1[i] * input1[i] * input2[i]; + } + } +} + template void EltWiseGradCpuTypeFunc::TanhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const { if constexpr (std::is_same::value) { @@ -417,6 +433,7 @@ void EltWiseGradCpuTypeFunc::InitFunc(const BaseOperatorPtr &base_operator, c {prim::kPrimAsinGrad->name(), &EltWiseGradCpuTypeFunc::AsinGrad}, {prim::kPrimACosGrad->name(), &EltWiseGradCpuTypeFunc::ACosGrad}, {prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc::RsqrtGrad}, + {prim::kPrimReciprocalGrad->name(), &EltWiseGradCpuTypeFunc::ReciprocalGrad}, {prim::kPrimAtanGrad->name(), &EltWiseGradCpuTypeFunc::AtanGrad}, {prim::kPrimTanhGrad->name(), &EltWiseGradCpuTypeFunc::TanhGrad}, {prim::kPrimAsinhGrad->name(), &EltWiseGradCpuTypeFunc::AsinhGrad}, @@ -434,6 +451,7 @@ void EltWiseGradCpuTypeFunc::InitFunc(const BaseOperatorPtr &base_operator, c static const std::map> elt_map{{prim::kPrimReluGrad->name(), &EltWiseGradCpuTypeFunc::ReluGrad}, + {prim::kPrimReciprocalGrad->name(), &EltWiseGradCpuTypeFunc::ReciprocalGrad}, {prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc::RsqrtGrad}}; if (elt_map.find(kernel_name_) == elt_map.end()) { MS_LOG(EXCEPTION) << "EltWiseGradCpu does not support " << kernel_name_ << " with float as input."; @@ -457,6 +475,7 @@ void EltWiseGradCpuTypeFunc::InitFunc(const BaseOperatorPtr &base_operator, c {prim::kPrimAsinhGrad->name(), &EltWiseGradCpuTypeFunc::AsinhGrad}, {prim::kPrimInvGrad->name(), &EltWiseGradCpuTypeFunc::InvGrad}, {prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc::RsqrtGrad}, + {prim::kPrimReciprocalGrad->name(), &EltWiseGradCpuTypeFunc::ReciprocalGrad}, {prim::kPrimAcoshGrad->name(), &EltWiseGradCpuTypeFunc::AcoshGrad}, 
{prim::kPrimSoftplusGrad->name(), &EltWiseGradCpuTypeFunc::SoftplusGrad}, {prim::kPrimReluGrad->name(), &EltWiseGradCpuTypeFunc::ReluGrad}}; @@ -499,6 +518,7 @@ void EltWiseGradCpuTypeFunc::InitFunc(const BaseOperatorPtr &base_operator, c {prim::kPrimTanhGrad->name(), &EltWiseGradCpuTypeFunc::TanhGrad}, {prim::kPrimInvGrad->name(), &EltWiseGradCpuTypeFunc::InvGrad}, {prim::kPrimSqrtGrad->name(), &EltWiseGradCpuTypeFunc::SqrtGrad}, + {prim::kPrimReciprocalGrad->name(), &EltWiseGradCpuTypeFunc::ReciprocalGrad}, {prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc::RsqrtGrad}}; if (elt_map.find(kernel_name_) == elt_map.end()) { MS_LOG(EXCEPTION) << "For 'EltWiseGrad', it does not support " << kernel_name_; @@ -705,7 +725,24 @@ static std::map>> ke .AddInputAttr(kNumberTypeComplex64) .AddInputAttr(kNumberTypeComplex64) .AddOutputAttr(kNumberTypeComplex64), - &SpecializeEltWiseGradFunc}}}}; + &SpecializeEltWiseGradFunc}}}, + {kReciprocalGrad, + {{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + &SpecializeEltWiseGradFunc}, + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + &SpecializeEltWiseGradFunc}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + &SpecializeEltWiseGradFunc}, + {KernelAttr() + .AddInputAttr(kNumberTypeComplex64) + .AddInputAttr(kNumberTypeComplex64) + .AddOutputAttr(kNumberTypeComplex64), + &SpecializeEltWiseGradFunc}, + {KernelAttr() + .AddInputAttr(kNumberTypeComplex128) + .AddInputAttr(kNumberTypeComplex128) + .AddOutputAttr(kNumberTypeComplex128), + &SpecializeEltWiseGradFunc}}}}; } // namespace bool EltWiseGradCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, @@ -790,5 +827,7 @@ MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, SoftplusGrad, []() { return std::make_shared(kSoftplusGrad); }); 
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, RsqrtGrad, []() { return std::make_shared(kRsqrtGrad); }); +MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ReciprocalGrad, + []() { return std::make_shared(kReciprocalGrad); }); } // namespace kernel } // namespace mindspore diff --git a/mindspore/core/ops/grad/reciprocal_grad.cc b/mindspore/core/ops/grad/reciprocal_grad.cc index 31e501707d9..1125e2832e8 100644 --- a/mindspore/core/ops/grad/reciprocal_grad.cc +++ b/mindspore/core/ops/grad/reciprocal_grad.cc @@ -39,7 +39,7 @@ abstract::ShapePtr ReciprocalGradInferShape(const PrimitivePtr &primitive, } TypePtr ReciprocalGradInferType(const PrimitivePtr &primitive, const std::vector &input_args) { - const std::set valid_types = {kFloat16, kFloat32}; + const std::set valid_types = {kFloat16, kFloat32, kFloat64, kComplex64, kComplex128}; MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(input_args[0]); auto x_type = input_args[0]->BuildType(); @@ -55,8 +55,8 @@ AbstractBasePtr ReciprocalGradInfer(const abstract::AnalysisEnginePtr &, const P MS_EXCEPTION_IF_NULL(primitive); const int64_t input_num = 2; CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name()); - auto shape = ReciprocalGradInferShape(primitive, input_args); auto type = ReciprocalGradInferType(primitive, input_args); + auto shape = ReciprocalGradInferShape(primitive, input_args); return abstract::MakeAbstract(shape, type); } diff --git a/mindspore/core/ops/reciprocal.cc b/mindspore/core/ops/reciprocal.cc index 317d3b61958..f4dea23b9c4 100644 --- a/mindspore/core/ops/reciprocal.cc +++ b/mindspore/core/ops/reciprocal.cc @@ -30,11 +30,6 @@ namespace ops { namespace { abstract::ShapePtr ReciprocalInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto prim_name = primitive->name(); - (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 1, prim_name); - for (const auto 
&item : input_args) { - MS_EXCEPTION_IF_NULL(item); - } auto x = input_args[0]->BuildShape(); MS_EXCEPTION_IF_NULL(x); auto shape_ptr = x->cast(); @@ -45,24 +40,26 @@ abstract::ShapePtr ReciprocalInferShape(const PrimitivePtr &primitive, const std TypePtr ReciprocalInferType(const PrimitivePtr &prim, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(prim); auto op_name = prim->name(); - (void)CheckAndConvertUtils::CheckInteger("input number", int64_t(input_args.size()), kEqual, 1, op_name); - for (const auto &item : input_args) { - MS_EXCEPTION_IF_NULL(item); - } std::map types; (void)types.emplace("x", input_args[0]->BuildType()); std::set valid_params_types = {kTensorType}; - (void)CheckAndConvertUtils::CheckSubClass("x_type", input_args[0]->BuildType(), valid_params_types, op_name); - return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); + const std::set common_valid_types = {kInt32, kInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128}; + (void)CheckAndConvertUtils::CheckSubClass("x", input_args[0]->BuildType(), valid_params_types, op_name); + (void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); + return input_args[0]->BuildType(); } } // namespace MIND_API_OPERATOR_IMPL(Reciprocal, BaseOperator); AbstractBasePtr ReciprocalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { - return abstract::MakeAbstract(ReciprocalInferShape(primitive, input_args), - ReciprocalInferType(primitive, input_args)); + MS_EXCEPTION_IF_NULL(primitive); + const int64_t kInputsNum = 1; + CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name()); + auto infer_type = ReciprocalInferType(primitive, input_args); + auto infer_shape = ReciprocalInferShape(primitive, input_args); + return abstract::MakeAbstract(infer_shape, infer_type); } -REGISTER_PRIMITIVE_C(kNameReciprocal, Reciprocal); 
+REGISTER_PRIMITIVE_EVAL_IMPL(Reciprocal, prim::kPrimReciprocal, ReciprocalInfer, nullptr, true); } // namespace ops } // namespace mindspore diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/reciprocal.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/reciprocal.py new file mode 100644 index 00000000000..a29254dac99 --- /dev/null +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/reciprocal.py @@ -0,0 +1,34 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""Reciprocal op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +reciprocal_op_info = AiCPURegOp("Reciprocal") \ + .fusion_type("OPAQUE") \ + .input(0, "x", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default) \ + .dtype_format(DataType.C64_Default, DataType.C64_Default) \ + .dtype_format(DataType.C128_Default, DataType.C128_Default) \ + .get_op_info() + + +@op_info_register(reciprocal_op_info) +def _reciprocal_aicpu(): + """Reciprocal AiCPU register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/reciprocal_grad.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/reciprocal_grad.py new file mode 100644 index 00000000000..0a68a0438bd --- /dev/null +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/reciprocal_grad.py @@ -0,0 +1,35 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================ + +"""ReciprocalGrad op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +reciprocal_grad_op_info = AiCPURegOp("ReciprocalGrad") \ + .fusion_type("OPAQUE") \ + .input(0, "y", "required") \ + .input(1, "dy", "required") \ + .output(0, "z", "required") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \ + .dtype_format(DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \ + .dtype_format(DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \ + .get_op_info() + + +@op_info_register(reciprocal_grad_op_info) +def _reciprocal_grad_aicpu(): + """ReciprocalGrad AiCPU register""" + return diff --git a/mindspore/python/mindspore/ops/operations/_grad_ops.py b/mindspore/python/mindspore/ops/operations/_grad_ops.py index c3c3ab6df28..7e3edec17a1 100644 --- a/mindspore/python/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/python/mindspore/ops/operations/_grad_ops.py @@ -104,6 +104,7 @@ class ReciprocalGrad(Primitive): @prim_attr_register def __init__(self): """Initialize ReciprocalGrad""" + self.init_prim_io_names(inputs=['y', 'dy'], outputs=['z']) class RsqrtGrad(Primitive): diff --git a/mindspore/python/mindspore/ops/operations/math_ops.py b/mindspore/python/mindspore/ops/operations/math_ops.py index 7a533c3e2b8..37000cb60b3 100644 --- a/mindspore/python/mindspore/ops/operations/math_ops.py +++ b/mindspore/python/mindspore/ops/operations/math_ops.py @@ -2192,7 +2192,7 @@ class Sqrt(Primitive): self.init_prim_io_names(inputs=['x'], outputs=['output']) -class Reciprocal(PrimitiveWithInfer): +class Reciprocal(PrimitiveWithCheck): r""" Returns reciprocal of a tensor element-wise. 
@@ -2230,13 +2230,6 @@ class Reciprocal(PrimitiveWithInfer): self.target = "OTHER" self.init_prim_io_names(inputs=['x'], outputs=['y']) - def infer_shape(self, x): - return x - - def infer_dtype(self, x): - validator.check_subclass("x", x, mstype.tensor, self.name) - return x - def infer_value(self, x): if x is not None: x = x.asnumpy()