[assistant][I48O4Z] add new math operator Reciprocal
This commit is contained in:
parent
a4d62ed336
commit
daef90a276
|
@ -207,4 +207,5 @@ mindspore/mindspore/lite/src/litert/kernel/cpu/control/tensorlist_setitem.cc:min
|
||||||
mindspore/mindspore/python/mindspore/ops/_utils/utils.py:get_broadcast_shape
|
mindspore/mindspore/python/mindspore/ops/_utils/utils.py:get_broadcast_shape
|
||||||
mindspore/mindspore/ccsrc/pybind_api/ir/dtype_py.cc:mindspore::RegTyping
|
mindspore/mindspore/ccsrc/pybind_api/ir/dtype_py.cc:mindspore::RegTyping
|
||||||
mindspore/mindspore/ccsrc/pybind_api/ir/tensor_py.cc:mindspore::tensor::RegMetaTensor
|
mindspore/mindspore/ccsrc/pybind_api/ir/tensor_py.cc:mindspore::tensor::RegMetaTensor
|
||||||
|
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/eltwise_grad_cpu_kernel.cc:mindspore::kernel::EltWiseGradCpuTypeFunc<T>::InitFunc
|
||||||
mindspore/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc:mindspore::lite::quant::WeightQuantizer::LinearQuant
|
mindspore/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc:mindspore::lite::quant::WeightQuantizer::LinearQuant
|
||||||
|
|
|
@ -723,6 +723,7 @@ void ArithmeticSelfCpuKernelFunc::LaunchKernelComplex(const std::vector<AddressP
|
||||||
{prim::kPrimCos->name(), ComplexCos<T>},
|
{prim::kPrimCos->name(), ComplexCos<T>},
|
||||||
{prim::kPrimACos->name(), ComplexACos<T>},
|
{prim::kPrimACos->name(), ComplexACos<T>},
|
||||||
{prim::kPrimRsqrt->name(), Rsqrt<T>},
|
{prim::kPrimRsqrt->name(), Rsqrt<T>},
|
||||||
|
{prim::kPrimReciprocal->name(), Reciprocal<T>},
|
||||||
{prim::kPrimSqrt->name(), Sqrt<T>},
|
{prim::kPrimSqrt->name(), Sqrt<T>},
|
||||||
{prim::kPrimTan->name(), Tan<T>},
|
{prim::kPrimTan->name(), Tan<T>},
|
||||||
{prim::kPrimTanh->name(), Tanh<T>},
|
{prim::kPrimTanh->name(), Tanh<T>},
|
||||||
|
@ -913,7 +914,10 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithFuncCreator>
|
||||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
|
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
|
||||||
{kReciprocal,
|
{kReciprocal,
|
||||||
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateArithSelfFunc},
|
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateArithSelfFunc},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CreateArithSelfFunc},
|
||||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), CreateArithSelfFunc},
|
||||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
|
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
|
||||||
{kInv,
|
{kInv,
|
||||||
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateArithSelfFunc},
|
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateArithSelfFunc},
|
||||||
|
|
|
@ -43,6 +43,7 @@ constexpr auto kInvGrad = "InvGrad";
|
||||||
constexpr auto kAcoshGrad = "AcoshGrad";
|
constexpr auto kAcoshGrad = "AcoshGrad";
|
||||||
constexpr auto kSoftplusGrad = "SoftplusGrad";
|
constexpr auto kSoftplusGrad = "SoftplusGrad";
|
||||||
constexpr auto kRsqrtGrad = "RsqrtGrad";
|
constexpr auto kRsqrtGrad = "RsqrtGrad";
|
||||||
|
constexpr auto kReciprocalGrad = "ReciprocalGrad";
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class EltWiseGradCpuTypeFunc : public CpuKernelFunc {
|
class EltWiseGradCpuTypeFunc : public CpuKernelFunc {
|
||||||
|
@ -61,6 +62,7 @@ class EltWiseGradCpuTypeFunc : public CpuKernelFunc {
|
||||||
void SigmoidGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
|
void SigmoidGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
|
||||||
void SqrtGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
|
void SqrtGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
|
||||||
void RsqrtGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
|
void RsqrtGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
|
||||||
|
void ReciprocalGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
|
||||||
void TanhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
|
void TanhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
|
||||||
void GeluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
|
void GeluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
|
||||||
void AsinGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
|
void AsinGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
|
||||||
|
@ -163,6 +165,20 @@ void EltWiseGradCpuTypeFunc<T>::RsqrtGrad(const T *input1, const T *input2, T *o
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void EltWiseGradCpuTypeFunc<T>::ReciprocalGrad(const T *input1, const T *input2, T *out, size_t start,
|
||||||
|
size_t end) const {
|
||||||
|
if constexpr ((std::is_same_v<T, complex64>) || (std::is_same_v<T, complex128>)) {
|
||||||
|
for (size_t i = start; i < end; i++) {
|
||||||
|
out[i] = static_cast<T>(-1) * conj(input1[i] * input1[i]) * input2[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (size_t i = start; i < end; i++) {
|
||||||
|
out[i] = static_cast<T>(-1) * input1[i] * input1[i] * input2[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void EltWiseGradCpuTypeFunc<T>::TanhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
|
void EltWiseGradCpuTypeFunc<T>::TanhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
|
||||||
if constexpr (std::is_same<T, float>::value) {
|
if constexpr (std::is_same<T, float>::value) {
|
||||||
|
@ -417,6 +433,7 @@ void EltWiseGradCpuTypeFunc<T>::InitFunc(const BaseOperatorPtr &base_operator, c
|
||||||
{prim::kPrimAsinGrad->name(), &EltWiseGradCpuTypeFunc<T>::AsinGrad},
|
{prim::kPrimAsinGrad->name(), &EltWiseGradCpuTypeFunc<T>::AsinGrad},
|
||||||
{prim::kPrimACosGrad->name(), &EltWiseGradCpuTypeFunc<T>::ACosGrad},
|
{prim::kPrimACosGrad->name(), &EltWiseGradCpuTypeFunc<T>::ACosGrad},
|
||||||
{prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::RsqrtGrad},
|
{prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::RsqrtGrad},
|
||||||
|
{prim::kPrimReciprocalGrad->name(), &EltWiseGradCpuTypeFunc<T>::ReciprocalGrad},
|
||||||
{prim::kPrimAtanGrad->name(), &EltWiseGradCpuTypeFunc<T>::AtanGrad},
|
{prim::kPrimAtanGrad->name(), &EltWiseGradCpuTypeFunc<T>::AtanGrad},
|
||||||
{prim::kPrimTanhGrad->name(), &EltWiseGradCpuTypeFunc<T>::TanhGrad},
|
{prim::kPrimTanhGrad->name(), &EltWiseGradCpuTypeFunc<T>::TanhGrad},
|
||||||
{prim::kPrimAsinhGrad->name(), &EltWiseGradCpuTypeFunc<T>::AsinhGrad},
|
{prim::kPrimAsinhGrad->name(), &EltWiseGradCpuTypeFunc<T>::AsinhGrad},
|
||||||
|
@ -434,6 +451,7 @@ void EltWiseGradCpuTypeFunc<T>::InitFunc(const BaseOperatorPtr &base_operator, c
|
||||||
static const std::map<std::string,
|
static const std::map<std::string,
|
||||||
std::function<void(EltWiseGradCpuTypeFunc *, const T *, const T *, T *, size_t, size_t)>>
|
std::function<void(EltWiseGradCpuTypeFunc *, const T *, const T *, T *, size_t, size_t)>>
|
||||||
elt_map{{prim::kPrimReluGrad->name(), &EltWiseGradCpuTypeFunc<T>::ReluGrad},
|
elt_map{{prim::kPrimReluGrad->name(), &EltWiseGradCpuTypeFunc<T>::ReluGrad},
|
||||||
|
{prim::kPrimReciprocalGrad->name(), &EltWiseGradCpuTypeFunc<T>::ReciprocalGrad},
|
||||||
{prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::RsqrtGrad}};
|
{prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::RsqrtGrad}};
|
||||||
if (elt_map.find(kernel_name_) == elt_map.end()) {
|
if (elt_map.find(kernel_name_) == elt_map.end()) {
|
||||||
MS_LOG(EXCEPTION) << "EltWiseGradCpu does not support " << kernel_name_ << " with float as input.";
|
MS_LOG(EXCEPTION) << "EltWiseGradCpu does not support " << kernel_name_ << " with float as input.";
|
||||||
|
@ -457,6 +475,7 @@ void EltWiseGradCpuTypeFunc<T>::InitFunc(const BaseOperatorPtr &base_operator, c
|
||||||
{prim::kPrimAsinhGrad->name(), &EltWiseGradCpuTypeFunc<T>::AsinhGrad},
|
{prim::kPrimAsinhGrad->name(), &EltWiseGradCpuTypeFunc<T>::AsinhGrad},
|
||||||
{prim::kPrimInvGrad->name(), &EltWiseGradCpuTypeFunc<T>::InvGrad},
|
{prim::kPrimInvGrad->name(), &EltWiseGradCpuTypeFunc<T>::InvGrad},
|
||||||
{prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::RsqrtGrad},
|
{prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::RsqrtGrad},
|
||||||
|
{prim::kPrimReciprocalGrad->name(), &EltWiseGradCpuTypeFunc<T>::ReciprocalGrad},
|
||||||
{prim::kPrimAcoshGrad->name(), &EltWiseGradCpuTypeFunc<T>::AcoshGrad},
|
{prim::kPrimAcoshGrad->name(), &EltWiseGradCpuTypeFunc<T>::AcoshGrad},
|
||||||
{prim::kPrimSoftplusGrad->name(), &EltWiseGradCpuTypeFunc<T>::SoftplusGrad},
|
{prim::kPrimSoftplusGrad->name(), &EltWiseGradCpuTypeFunc<T>::SoftplusGrad},
|
||||||
{prim::kPrimReluGrad->name(), &EltWiseGradCpuTypeFunc<T>::ReluGrad}};
|
{prim::kPrimReluGrad->name(), &EltWiseGradCpuTypeFunc<T>::ReluGrad}};
|
||||||
|
@ -499,6 +518,7 @@ void EltWiseGradCpuTypeFunc<T>::InitFunc(const BaseOperatorPtr &base_operator, c
|
||||||
{prim::kPrimTanhGrad->name(), &EltWiseGradCpuTypeFunc<T>::TanhGrad},
|
{prim::kPrimTanhGrad->name(), &EltWiseGradCpuTypeFunc<T>::TanhGrad},
|
||||||
{prim::kPrimInvGrad->name(), &EltWiseGradCpuTypeFunc<T>::InvGrad},
|
{prim::kPrimInvGrad->name(), &EltWiseGradCpuTypeFunc<T>::InvGrad},
|
||||||
{prim::kPrimSqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::SqrtGrad},
|
{prim::kPrimSqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::SqrtGrad},
|
||||||
|
{prim::kPrimReciprocalGrad->name(), &EltWiseGradCpuTypeFunc<T>::ReciprocalGrad},
|
||||||
{prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::RsqrtGrad}};
|
{prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::RsqrtGrad}};
|
||||||
if (elt_map.find(kernel_name_) == elt_map.end()) {
|
if (elt_map.find(kernel_name_) == elt_map.end()) {
|
||||||
MS_LOG(EXCEPTION) << "For 'EltWiseGrad', it does not support " << kernel_name_;
|
MS_LOG(EXCEPTION) << "For 'EltWiseGrad', it does not support " << kernel_name_;
|
||||||
|
@ -705,7 +725,24 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, FuncCreator>>> ke
|
||||||
.AddInputAttr(kNumberTypeComplex64)
|
.AddInputAttr(kNumberTypeComplex64)
|
||||||
.AddInputAttr(kNumberTypeComplex64)
|
.AddInputAttr(kNumberTypeComplex64)
|
||||||
.AddOutputAttr(kNumberTypeComplex64),
|
.AddOutputAttr(kNumberTypeComplex64),
|
||||||
&SpecializeEltWiseGradFunc<complex64>}}}};
|
&SpecializeEltWiseGradFunc<complex64>}}},
|
||||||
|
{kReciprocalGrad,
|
||||||
|
{{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||||
|
&SpecializeEltWiseGradFunc<float16>},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||||
|
&SpecializeEltWiseGradFunc<float>},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||||
|
&SpecializeEltWiseGradFunc<double>},
|
||||||
|
{KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeComplex64)
|
||||||
|
.AddInputAttr(kNumberTypeComplex64)
|
||||||
|
.AddOutputAttr(kNumberTypeComplex64),
|
||||||
|
&SpecializeEltWiseGradFunc<complex64>},
|
||||||
|
{KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeComplex128)
|
||||||
|
.AddInputAttr(kNumberTypeComplex128)
|
||||||
|
.AddOutputAttr(kNumberTypeComplex128),
|
||||||
|
&SpecializeEltWiseGradFunc<complex128>}}}};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
bool EltWiseGradCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
bool EltWiseGradCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
@ -790,5 +827,7 @@ MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, SoftplusGrad,
|
||||||
[]() { return std::make_shared<EltWiseGradCpuKernelMod>(kSoftplusGrad); });
|
[]() { return std::make_shared<EltWiseGradCpuKernelMod>(kSoftplusGrad); });
|
||||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, RsqrtGrad,
|
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, RsqrtGrad,
|
||||||
[]() { return std::make_shared<EltWiseGradCpuKernelMod>(kRsqrtGrad); });
|
[]() { return std::make_shared<EltWiseGradCpuKernelMod>(kRsqrtGrad); });
|
||||||
|
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ReciprocalGrad,
|
||||||
|
[]() { return std::make_shared<EltWiseGradCpuKernelMod>(kReciprocalGrad); });
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -39,7 +39,7 @@ abstract::ShapePtr ReciprocalGradInferShape(const PrimitivePtr &primitive,
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr ReciprocalGradInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
TypePtr ReciprocalGradInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||||
auto x_type = input_args[0]->BuildType();
|
auto x_type = input_args[0]->BuildType();
|
||||||
|
@ -55,8 +55,8 @@ AbstractBasePtr ReciprocalGradInfer(const abstract::AnalysisEnginePtr &, const P
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
const int64_t input_num = 2;
|
const int64_t input_num = 2;
|
||||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||||
auto shape = ReciprocalGradInferShape(primitive, input_args);
|
|
||||||
auto type = ReciprocalGradInferType(primitive, input_args);
|
auto type = ReciprocalGradInferType(primitive, input_args);
|
||||||
|
auto shape = ReciprocalGradInferShape(primitive, input_args);
|
||||||
return abstract::MakeAbstract(shape, type);
|
return abstract::MakeAbstract(shape, type);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -30,11 +30,6 @@ namespace ops {
|
||||||
namespace {
|
namespace {
|
||||||
abstract::ShapePtr ReciprocalInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
abstract::ShapePtr ReciprocalInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
auto prim_name = primitive->name();
|
|
||||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 1, prim_name);
|
|
||||||
for (const auto &item : input_args) {
|
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
|
||||||
}
|
|
||||||
auto x = input_args[0]->BuildShape();
|
auto x = input_args[0]->BuildShape();
|
||||||
MS_EXCEPTION_IF_NULL(x);
|
MS_EXCEPTION_IF_NULL(x);
|
||||||
auto shape_ptr = x->cast<abstract::ShapePtr>();
|
auto shape_ptr = x->cast<abstract::ShapePtr>();
|
||||||
|
@ -45,24 +40,26 @@ abstract::ShapePtr ReciprocalInferShape(const PrimitivePtr &primitive, const std
|
||||||
TypePtr ReciprocalInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
TypePtr ReciprocalInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
auto op_name = prim->name();
|
auto op_name = prim->name();
|
||||||
(void)CheckAndConvertUtils::CheckInteger("input number", int64_t(input_args.size()), kEqual, 1, op_name);
|
|
||||||
for (const auto &item : input_args) {
|
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
|
||||||
}
|
|
||||||
std::map<std::string, TypePtr> types;
|
std::map<std::string, TypePtr> types;
|
||||||
(void)types.emplace("x", input_args[0]->BuildType());
|
(void)types.emplace("x", input_args[0]->BuildType());
|
||||||
std::set<TypePtr> valid_params_types = {kTensorType};
|
std::set<TypePtr> valid_params_types = {kTensorType};
|
||||||
(void)CheckAndConvertUtils::CheckSubClass("x_type", input_args[0]->BuildType(), valid_params_types, op_name);
|
const std::set<TypePtr> common_valid_types = {kInt32, kInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
|
||||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
(void)CheckAndConvertUtils::CheckSubClass("x", input_args[0]->BuildType(), valid_params_types, op_name);
|
||||||
|
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||||
|
return input_args[0]->BuildType();
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
MIND_API_OPERATOR_IMPL(Reciprocal, BaseOperator);
|
MIND_API_OPERATOR_IMPL(Reciprocal, BaseOperator);
|
||||||
AbstractBasePtr ReciprocalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr ReciprocalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
return abstract::MakeAbstract(ReciprocalInferShape(primitive, input_args),
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
ReciprocalInferType(primitive, input_args));
|
const int64_t kInputsNum = 1;
|
||||||
|
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
|
||||||
|
auto infer_type = ReciprocalInferType(primitive, input_args);
|
||||||
|
auto infer_shape = ReciprocalInferShape(primitive, input_args);
|
||||||
|
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_C(kNameReciprocal, Reciprocal);
|
REGISTER_PRIMITIVE_EVAL_IMPL(Reciprocal, prim::kPrimReciprocal, ReciprocalInfer, nullptr, true);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -0,0 +1,34 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Reciprocal op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||||
|
|
||||||
|
reciprocal_op_info = AiCPURegOp("Reciprocal") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.C64_Default, DataType.C64_Default) \
|
||||||
|
.dtype_format(DataType.C128_Default, DataType.C128_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(reciprocal_op_info)
|
||||||
|
def _reciprocal_aicpu():
|
||||||
|
"""Rpeciprocal AiCPU register"""
|
||||||
|
return
|
|
@ -0,0 +1,35 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""ReciprocalGrad op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||||
|
|
||||||
|
reciprocal_grad_op_info = AiCPURegOp("ReciprocalGrad") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "y", "required") \
|
||||||
|
.input(1, "dy", "required") \
|
||||||
|
.output(0, "z", "required") \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \
|
||||||
|
.dtype_format(DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(reciprocal_grad_op_info)
|
||||||
|
def _reciprocal_grad_aicpu():
|
||||||
|
"""ReciprocalGrad AiCPU register"""
|
||||||
|
return
|
|
@ -104,6 +104,7 @@ class ReciprocalGrad(Primitive):
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize ReciprocalGrad"""
|
"""Initialize ReciprocalGrad"""
|
||||||
|
self.init_prim_io_names(inputs=['y', 'dy'], outputs=['z'])
|
||||||
|
|
||||||
|
|
||||||
class RsqrtGrad(Primitive):
|
class RsqrtGrad(Primitive):
|
||||||
|
|
|
@ -2192,7 +2192,7 @@ class Sqrt(Primitive):
|
||||||
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||||
|
|
||||||
|
|
||||||
class Reciprocal(PrimitiveWithInfer):
|
class Reciprocal(PrimitiveWithCheck):
|
||||||
r"""
|
r"""
|
||||||
Returns reciprocal of a tensor element-wise.
|
Returns reciprocal of a tensor element-wise.
|
||||||
|
|
||||||
|
@ -2230,13 +2230,6 @@ class Reciprocal(PrimitiveWithInfer):
|
||||||
self.target = "OTHER"
|
self.target = "OTHER"
|
||||||
self.init_prim_io_names(inputs=['x'], outputs=['y'])
|
self.init_prim_io_names(inputs=['x'], outputs=['y'])
|
||||||
|
|
||||||
def infer_shape(self, x):
|
|
||||||
return x
|
|
||||||
|
|
||||||
def infer_dtype(self, x):
|
|
||||||
validator.check_subclass("x", x, mstype.tensor, self.name)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def infer_value(self, x):
|
def infer_value(self, x):
|
||||||
if x is not None:
|
if x is not None:
|
||||||
x = x.asnumpy()
|
x = x.asnumpy()
|
||||||
|
|
Loading…
Reference in New Issue