!35762 [feat][assistant][I48O8O][I48O5M]Add Sqrt and SqrtGrad operations

Merge pull request !35762 from 桂宁馨/sqrt
This commit is contained in:
i-robot 2022-08-02 02:29:02 +00:00 committed by Gitee
commit 8f49f20029
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 120 additions and 18 deletions

View File

@ -685,13 +685,21 @@ void ArithmeticSelfCpuKernelFunc::LaunchKernelComplex(const std::vector<AddressP
const size_t lens = outputs[0]->size / sizeof(T);
static const std::unordered_map<std::string,
std::function<void(ArithmeticSelfCpuKernelFunc *, const T *, T *, size_t)>>
arithmeticSelfFuncMap{{prim::kPrimSquare->name(), Square<T>}, {prim::kPrimAcosh->name(), ComplexAcosh<T>},
{prim::kPrimAsinh->name(), ComplexAsinh<T>}, {prim::kPrimNeg->name(), Neg<T>},
{prim::kPrimSinh->name(), ComplexSinh<T>}, {prim::kPrimCosh->name(), ComplexCosh<T>},
{prim::kPrimSin->name(), ComplexSin<T>}, {prim::kPrimCos->name(), ComplexCos<T>},
{prim::kPrimRsqrt->name(), Rsqrt<T>}, {prim::kPrimTan->name(), Tan<T>},
{prim::kPrimTanh->name(), Tanh<T>}, {prim::kPrimAtanh->name(), Atanh<T>},
{prim::kPrimSign->name(), ComplexSign<T>}, {prim::kPrimLog->name(), ComplexLog<T>},
arithmeticSelfFuncMap{{prim::kPrimSquare->name(), Square<T>},
{prim::kPrimAcosh->name(), ComplexAcosh<T>},
{prim::kPrimAsinh->name(), ComplexAsinh<T>},
{prim::kPrimNeg->name(), Neg<T>},
{prim::kPrimSinh->name(), ComplexSinh<T>},
{prim::kPrimCosh->name(), ComplexCosh<T>},
{prim::kPrimSin->name(), ComplexSin<T>},
{prim::kPrimCos->name(), ComplexCos<T>},
{prim::kPrimRsqrt->name(), Rsqrt<T>},
{prim::kPrimSqrt->name(), Sqrt<T>},
{prim::kPrimTan->name(), Tan<T>},
{prim::kPrimTanh->name(), Tanh<T>},
{prim::kPrimAtanh->name(), Atanh<T>},
{prim::kPrimSign->name(), ComplexSign<T>},
{prim::kPrimLog->name(), ComplexLog<T>},
{prim::kPrimExp->name(), ComplexExp<T>}};
const auto func_pair = arithmeticSelfFuncMap.find(kernel_name_);
if (arithmeticSelfFuncMap.find(kernel_name_) == arithmeticSelfFuncMap.end()) {
@ -924,7 +932,9 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithFuncCreator>
{kSqrt,
{{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
[]() { return std::make_shared<SqrtMKLKernelFunc>(); }}}},
[]() { return std::make_shared<SqrtMKLKernelFunc>(); }},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), CreateArithSelfFunc}}},
{kLog,
{{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},

View File

@ -133,8 +133,16 @@ void EltWiseGradCpuTypeFunc<T>::SigmoidGrad(const T *input1, const T *input2, T
template <typename T>
void EltWiseGradCpuTypeFunc<T>::SqrtGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
for (size_t i = start; i < end; i++) {
out[i] = input2[i] / (input1[i] * 2);
if constexpr ((std::is_same_v<T, complex64>) || (std::is_same_v<T, complex128>)) {
for (size_t i = start; i < end; i++) {
constexpr T coff = static_cast<T>(2);
out[i] = (input2[i] / std::conj(input1[i] * coff));
}
} else {
for (size_t i = start; i < end; i++) {
constexpr T coff = static_cast<T>(2);
out[i] = input2[i] / (input1[i] * coff);
}
}
}
@ -430,6 +438,7 @@ void EltWiseGradCpuTypeFunc<T>::InitFunc(const BaseOperatorPtr &base_operator, c
elt_map{{prim::kPrimAcoshGrad->name(), &EltWiseGradCpuTypeFunc<T>::ComplexAcoshGrad},
{prim::kPrimAsinhGrad->name(), &EltWiseGradCpuTypeFunc<T>::ComplexAsinhGrad},
{prim::kPrimTanhGrad->name(), &EltWiseGradCpuTypeFunc<T>::TanhGrad},
{prim::kPrimSqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::SqrtGrad},
{prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::RsqrtGrad}};
if (elt_map.find(kernel_name_) == elt_map.end()) {
MS_LOG(EXCEPTION) << "For 'EltWiseGrad', it does not support " << kernel_name_;
@ -496,7 +505,17 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, FuncCreator>>> ke
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&SpecializeEltWiseGradFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&SpecializeEltWiseGradFunc<double>}}},
&SpecializeEltWiseGradFunc<double>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
&SpecializeEltWiseGradFunc<complex64>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
&SpecializeEltWiseGradFunc<complex128>}}},
{kTanhGrad,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&SpecializeEltWiseGradFunc<float>},

View File

@ -41,10 +41,10 @@ abstract::ShapePtr SqrtGradInferShape(const PrimitivePtr &primitive, const std::
TypePtr SqrtGradInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[kInputIndex0]->BuildType());
(void)types.emplace("dout", input_args[kInputIndex1]->BuildType());
(void)types.emplace("y", input_args[kInputIndex0]->BuildType());
(void)types.emplace("dy", input_args[kInputIndex1]->BuildType());
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64, kComplex128, kComplex64};
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
return input_args[kInputIndex0]->BuildType();
}

View File

@ -25,7 +25,7 @@ constexpr auto kNameSqrtGrad = "SqrtGrad";
class MIND_API SqrtGrad : public BaseOperator {
public:
MIND_API_BASE_MEMBER(SqrtGrad);
SqrtGrad() : BaseOperator(kNameSqrtGrad) { InitIOName({"out_backprop", "input"}, {"output"}); }
SqrtGrad() : BaseOperator(kNameSqrtGrad) { InitIOName({"y", "dy"}, {"z"}); }
void Init() const {}
};
} // namespace ops

View File

@ -106,7 +106,7 @@ ValuePtr SqrtInferValue(const PrimitivePtr &prim, const std::vector<AbstractBase
}
iter->second(x_datac, result_datac, data_size);
return result_tensor;
} // namespace
}
} // namespace
MIND_API_OPERATOR_IMPL(Sqrt, BaseOperator);
@ -115,8 +115,9 @@ AbstractBasePtr SqrtInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
MS_EXCEPTION_IF_NULL(primitive);
constexpr int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
return abstract::MakeAbstract(SqrtInferShape(primitive, input_args), SqrtInferType(primitive, input_args));
auto types = SqrtInferType(primitive, input_args);
auto shapes = SqrtInferShape(primitive, input_args);
return abstract::MakeAbstract(shapes, types);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Sqrt, prim::kPrimSqrt, SqrtInfer, SqrtInferValue, true);
} // namespace ops

View File

@ -51,6 +51,8 @@ from .edit_distance import _edit_distance_aicpu
from .unique_with_pad import _unique_with_pad_aicpu
from .bartlett_window import _bartlett_window_aicpu
from .add_n import _add_n_aicpu
from .sqrt import _sqrt_aicpu
from .sqrt_grad import _sqrt_grad_aicpu
from .sub_and_filter import _sub_and_filter_aicpu
from .pad_and_shift import _pad_and_shift_aicpu
from .data_format_vec_permute import _data_format_vec_permute_aicpu

View File

@ -0,0 +1,34 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Sqrt op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
sqrt_op_info = AiCPURegOp("Sqrt") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.C64_Default, DataType.C64_Default) \
.dtype_format(DataType.C128_Default, DataType.C128_Default) \
.get_op_info()
@op_info_register(sqrt_op_info)
def _sqrt_aicpu():
"""Sqrt AiCPU register"""
return

View File

@ -0,0 +1,35 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SqrtGrad op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
sqrt_grad_op_info = AiCPURegOp("SqrtGrad") \
.fusion_type("OPAQUE") \
.input(0, "y", "required") \
.input(1, "dy", "required")\
.output(0, "z", "required") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \
.dtype_format(DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \
.get_op_info()
@op_info_register(sqrt_grad_op_info)
def _sqrt_grad_aicpu():
"""SqrtGrad AiCPU register"""
return

View File

@ -124,6 +124,7 @@ class SqrtGrad(Primitive):
@prim_attr_register
def __init__(self):
"""Initialize SqrtGrad"""
self.init_prim_io_names(inputs=['y', 'dy'], outputs=['z'])
class BatchNormGrad(Primitive):