From 8bbeacecbdb4f1736e978a8b34ead0fef9fdc476 Mon Sep 17 00:00:00 2001 From: zhaosida Date: Tue, 1 Dec 2020 10:20:13 +0800 Subject: [PATCH] support SqrtGrad dynamic shape --- mindspore/core/abstract/infer_functions.h | 2 + mindspore/core/abstract/prim_maths.cc | 13 ++++++ .../core/abstract/primitive_infer_map.cc | 1 + mindspore/core/base/core_ops.h | 1 + mindspore/ops/_op_impl/tbe/__init__.py | 1 + mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py | 44 +++++++++++++++++++ 6 files changed, 62 insertions(+) create mode 100644 mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index de1179c9053..1f684ef556f 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -95,6 +95,8 @@ AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &p AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplSqrtGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); diff --git a/mindspore/core/abstract/prim_maths.cc b/mindspore/core/abstract/prim_maths.cc index 6d71f01bf05..3bd5e1b6869 100644 --- a/mindspore/core/abstract/prim_maths.cc +++ b/mindspore/core/abstract/prim_maths.cc @@ -47,6 +47,19 @@ AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &pri return inp->Clone()->Broaden(); } +AbstractBasePtr InferImplSqrtGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two tensors. 
+ const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + auto out = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); + auto dout = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); + (void)CheckDtypeSame(op_name, out, dout); + (void)CheckShapeSame(op_name, out, dout); + + return out->Broaden(); + } + AbstractBasePtr InferImplTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // Inputs: two tensors. diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index f8dff15bfd6..17aa4bbc6e4 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -41,6 +41,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimTensorAdd, {InferImplTensorAdd, true}}, {prim::kPrimSquare, {InferImplSquare, true}}, {prim::kPrimSqrt, {InferImplSqrt, true}}, + {prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}}, {prim::kPrimSub, {InferImplSub, true}}, {prim::kPrimEqual, {InferImplEqual, true}}, {prim::kPrimMinimum, {InferImplMinimum, true}}, diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index d877411806b..a743e18ede4 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -241,6 +241,7 @@ inline const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("Inplace inline const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow"); inline const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv"); inline const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt"); +inline const PrimitivePtr kPrimSqrtGrad = std::make_shared<Primitive>("SqrtGrad"); inline const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal"); inline const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims"); inline const PrimitivePtr kPrimAbs = std::make_shared<Primitive>("Abs"); diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 
990a0203344..8fa3b4ed9f5 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -260,6 +260,7 @@ from .cumprod import _cumprop_tbe from .reduce_prod import _reduce_prod_tbe from .reciprocal_grad import _reciprocal_grad_tbe from .sqrt_grad import _sqrt_grad_tbe +from .sqrt_grad_ds import _sqrt_grad_ds_tbe from .rsqrt_grad import _rsqrt_grad_tbe from .flatten_grad import _flatten_grad_tbe from .scatter_add import _scatter_add_tbe diff --git a/mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py b/mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py new file mode 100644 index 00000000000..5a15b004914 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py @@ -0,0 +1,44 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""SqrtGrad op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +sqrt_grad_op_info = TBERegOp("SqrtGrad") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("sqrt_grad.so") \ + .compute_cost(10) \ + .kernel_name("sqrt_grad") \ + .partial_flag(True) \ + .dynamic_shape(True) \ + .input(0, "x", False, "required", "all") \ + .input(1, "dy", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \ + .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \ + .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ + .get_op_info() + + +@op_info_register(sqrt_grad_op_info) +def _sqrt_grad_ds_tbe(): + """SqrtGrad TBE register""" + return