!35762 [feat][assistant][I48O8O][I48O5M]Add Sqrt and SqrtGrad operations
Merge pull request !35762 from 桂宁馨/sqrt
This commit is contained in:
commit
8f49f20029
|
@ -685,13 +685,21 @@ void ArithmeticSelfCpuKernelFunc::LaunchKernelComplex(const std::vector<AddressP
|
|||
const size_t lens = outputs[0]->size / sizeof(T);
|
||||
static const std::unordered_map<std::string,
|
||||
std::function<void(ArithmeticSelfCpuKernelFunc *, const T *, T *, size_t)>>
|
||||
arithmeticSelfFuncMap{{prim::kPrimSquare->name(), Square<T>}, {prim::kPrimAcosh->name(), ComplexAcosh<T>},
|
||||
{prim::kPrimAsinh->name(), ComplexAsinh<T>}, {prim::kPrimNeg->name(), Neg<T>},
|
||||
{prim::kPrimSinh->name(), ComplexSinh<T>}, {prim::kPrimCosh->name(), ComplexCosh<T>},
|
||||
{prim::kPrimSin->name(), ComplexSin<T>}, {prim::kPrimCos->name(), ComplexCos<T>},
|
||||
{prim::kPrimRsqrt->name(), Rsqrt<T>}, {prim::kPrimTan->name(), Tan<T>},
|
||||
{prim::kPrimTanh->name(), Tanh<T>}, {prim::kPrimAtanh->name(), Atanh<T>},
|
||||
{prim::kPrimSign->name(), ComplexSign<T>}, {prim::kPrimLog->name(), ComplexLog<T>},
|
||||
arithmeticSelfFuncMap{{prim::kPrimSquare->name(), Square<T>},
|
||||
{prim::kPrimAcosh->name(), ComplexAcosh<T>},
|
||||
{prim::kPrimAsinh->name(), ComplexAsinh<T>},
|
||||
{prim::kPrimNeg->name(), Neg<T>},
|
||||
{prim::kPrimSinh->name(), ComplexSinh<T>},
|
||||
{prim::kPrimCosh->name(), ComplexCosh<T>},
|
||||
{prim::kPrimSin->name(), ComplexSin<T>},
|
||||
{prim::kPrimCos->name(), ComplexCos<T>},
|
||||
{prim::kPrimRsqrt->name(), Rsqrt<T>},
|
||||
{prim::kPrimSqrt->name(), Sqrt<T>},
|
||||
{prim::kPrimTan->name(), Tan<T>},
|
||||
{prim::kPrimTanh->name(), Tanh<T>},
|
||||
{prim::kPrimAtanh->name(), Atanh<T>},
|
||||
{prim::kPrimSign->name(), ComplexSign<T>},
|
||||
{prim::kPrimLog->name(), ComplexLog<T>},
|
||||
{prim::kPrimExp->name(), ComplexExp<T>}};
|
||||
const auto func_pair = arithmeticSelfFuncMap.find(kernel_name_);
|
||||
if (arithmeticSelfFuncMap.find(kernel_name_) == arithmeticSelfFuncMap.end()) {
|
||||
|
@ -924,7 +932,9 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithFuncCreator>
|
|||
{kSqrt,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
[]() { return std::make_shared<SqrtMKLKernelFunc>(); }}}},
|
||||
[]() { return std::make_shared<SqrtMKLKernelFunc>(); }},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), CreateArithSelfFunc}}},
|
||||
{kLog,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},
|
||||
|
|
|
@ -133,8 +133,16 @@ void EltWiseGradCpuTypeFunc<T>::SigmoidGrad(const T *input1, const T *input2, T
|
|||
|
||||
template <typename T>
|
||||
void EltWiseGradCpuTypeFunc<T>::SqrtGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = input2[i] / (input1[i] * 2);
|
||||
if constexpr ((std::is_same_v<T, complex64>) || (std::is_same_v<T, complex128>)) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
constexpr T coff = static_cast<T>(2);
|
||||
out[i] = (input2[i] / std::conj(input1[i] * coff));
|
||||
}
|
||||
} else {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
constexpr T coff = static_cast<T>(2);
|
||||
out[i] = input2[i] / (input1[i] * coff);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -430,6 +438,7 @@ void EltWiseGradCpuTypeFunc<T>::InitFunc(const BaseOperatorPtr &base_operator, c
|
|||
elt_map{{prim::kPrimAcoshGrad->name(), &EltWiseGradCpuTypeFunc<T>::ComplexAcoshGrad},
|
||||
{prim::kPrimAsinhGrad->name(), &EltWiseGradCpuTypeFunc<T>::ComplexAsinhGrad},
|
||||
{prim::kPrimTanhGrad->name(), &EltWiseGradCpuTypeFunc<T>::TanhGrad},
|
||||
{prim::kPrimSqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::SqrtGrad},
|
||||
{prim::kPrimRsqrtGrad->name(), &EltWiseGradCpuTypeFunc<T>::RsqrtGrad}};
|
||||
if (elt_map.find(kernel_name_) == elt_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "For 'EltWiseGrad', it does not support " << kernel_name_;
|
||||
|
@ -496,7 +505,17 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, FuncCreator>>> ke
|
|||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&SpecializeEltWiseGradFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
&SpecializeEltWiseGradFunc<double>}}},
|
||||
&SpecializeEltWiseGradFunc<double>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex64)
|
||||
.AddInputAttr(kNumberTypeComplex64)
|
||||
.AddOutputAttr(kNumberTypeComplex64),
|
||||
&SpecializeEltWiseGradFunc<complex64>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128),
|
||||
&SpecializeEltWiseGradFunc<complex128>}}},
|
||||
{kTanhGrad,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&SpecializeEltWiseGradFunc<float>},
|
||||
|
|
|
@ -41,10 +41,10 @@ abstract::ShapePtr SqrtGradInferShape(const PrimitivePtr &primitive, const std::
|
|||
|
||||
TypePtr SqrtGradInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x", input_args[kInputIndex0]->BuildType());
|
||||
(void)types.emplace("dout", input_args[kInputIndex1]->BuildType());
|
||||
(void)types.emplace("y", input_args[kInputIndex0]->BuildType());
|
||||
(void)types.emplace("dy", input_args[kInputIndex1]->BuildType());
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64, kComplex128, kComplex64};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
|
||||
return input_args[kInputIndex0]->BuildType();
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ constexpr auto kNameSqrtGrad = "SqrtGrad";
|
|||
class MIND_API SqrtGrad : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(SqrtGrad);
|
||||
SqrtGrad() : BaseOperator(kNameSqrtGrad) { InitIOName({"out_backprop", "input"}, {"output"}); }
|
||||
SqrtGrad() : BaseOperator(kNameSqrtGrad) { InitIOName({"y", "dy"}, {"z"}); }
|
||||
void Init() const {}
|
||||
};
|
||||
} // namespace ops
|
||||
|
|
|
@ -106,7 +106,7 @@ ValuePtr SqrtInferValue(const PrimitivePtr &prim, const std::vector<AbstractBase
|
|||
}
|
||||
iter->second(x_datac, result_datac, data_size);
|
||||
return result_tensor;
|
||||
} // namespace
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_OPERATOR_IMPL(Sqrt, BaseOperator);
|
||||
|
@ -115,8 +115,9 @@ AbstractBasePtr SqrtInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
constexpr int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
|
||||
return abstract::MakeAbstract(SqrtInferShape(primitive, input_args), SqrtInferType(primitive, input_args));
|
||||
auto types = SqrtInferType(primitive, input_args);
|
||||
auto shapes = SqrtInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(shapes, types);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Sqrt, prim::kPrimSqrt, SqrtInfer, SqrtInferValue, true);
|
||||
} // namespace ops
|
||||
|
|
|
@ -51,6 +51,8 @@ from .edit_distance import _edit_distance_aicpu
|
|||
from .unique_with_pad import _unique_with_pad_aicpu
|
||||
from .bartlett_window import _bartlett_window_aicpu
|
||||
from .add_n import _add_n_aicpu
|
||||
from .sqrt import _sqrt_aicpu
|
||||
from .sqrt_grad import _sqrt_grad_aicpu
|
||||
from .sub_and_filter import _sub_and_filter_aicpu
|
||||
from .pad_and_shift import _pad_and_shift_aicpu
|
||||
from .data_format_vec_permute import _data_format_vec_permute_aicpu
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Sqrt op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
sqrt_op_info = AiCPURegOp("Sqrt") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "x", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.C64_Default, DataType.C64_Default) \
|
||||
.dtype_format(DataType.C128_Default, DataType.C128_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(sqrt_op_info)
|
||||
def _sqrt_aicpu():
|
||||
"""Sqrt AiCPU register"""
|
||||
return
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""SqrtGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
sqrt_grad_op_info = AiCPURegOp("SqrtGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "y", "required") \
|
||||
.input(1, "dy", "required")\
|
||||
.output(0, "z", "required") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \
|
||||
.dtype_format(DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(sqrt_grad_op_info)
|
||||
def _sqrt_grad_aicpu():
|
||||
"""SqrtGrad AiCPU register"""
|
||||
return
|
|
@ -124,6 +124,7 @@ class SqrtGrad(Primitive):
|
|||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize SqrtGrad"""
|
||||
self.init_prim_io_names(inputs=['y', 'dy'], outputs=['z'])
|
||||
|
||||
|
||||
class BatchNormGrad(Primitive):
|
||||
|
|
Loading…
Reference in New Issue