add 4 grad ops

This commit is contained in:
fangzehua 2020-07-28 11:05:55 +08:00
parent 51fcaf6e61
commit a80432f08e
11 changed files with 228 additions and 18 deletions

View File

@ -61,7 +61,6 @@ const char kNameReduceSum[] = "ReduceSum";
const char kNameIsFinite[] = "isFinite"; const char kNameIsFinite[] = "isFinite";
const char kNameReciprocal[] = "Reciprocal"; const char kNameReciprocal[] = "Reciprocal";
const char kNameRsqrt[] = "Rsqrt"; const char kNameRsqrt[] = "Rsqrt";
const char kNameRsqrtGrad[] = "RsqrtGrad";
const char kNameSqrt[] = "Sqrt"; const char kNameSqrt[] = "Sqrt";
const char kNameSquare[] = "Square"; const char kNameSquare[] = "Square";
const char kNameSquaredDifference[] = "SquaredDifference"; const char kNameSquaredDifference[] = "SquaredDifference";
@ -83,6 +82,9 @@ const char kNameFlattenGrad[] = "FlattenGrad";
const char kNameConvolution[] = "Convolution"; const char kNameConvolution[] = "Convolution";
const char kNameBiasAdd[] = "BiasAdd"; const char kNameBiasAdd[] = "BiasAdd";
const char kNameMaxPoolGrad[] = "MaxPoolGrad"; const char kNameMaxPoolGrad[] = "MaxPoolGrad";
const char kNameRsqrtGrad[] = "RsqrtGrad";
const char kNameSqrtGrad[] = "SqrtGrad";
const char kNameReciprocalGrad[] = "ReciprocalGrad";
const char kNameAvgPoolGrad[] = "AvgPoolGrad"; const char kNameAvgPoolGrad[] = "AvgPoolGrad";
const char kNameMaxPoolGradWithArgmax[] = "MaxPoolGradWithArgmax"; const char kNameMaxPoolGradWithArgmax[] = "MaxPoolGradWithArgmax";
const char kNameApplyMomentum[] = "ApplyMomentum"; const char kNameApplyMomentum[] = "ApplyMomentum";
@ -233,6 +235,9 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameAllgather), ADPT_DESC(HcomAllGather)}, {string(kNameAllgather), ADPT_DESC(HcomAllGather)},
{string(kNameReduceScatter), ADPT_DESC(HcomReduceScatter)}, {string(kNameReduceScatter), ADPT_DESC(HcomReduceScatter)},
{string(kNameMaxPoolGrad), ADPT_DESC(MaxPoolGrad)}, {string(kNameMaxPoolGrad), ADPT_DESC(MaxPoolGrad)},
{string(kNameSqrtGrad), ADPT_DESC(SqrtGrad)},
{string(kNameReciprocalGrad), ADPT_DESC(ReciprocalGrad)},
{string(kNameRsqrtGrad), ADPT_DESC(RsqrtGrad)},
{string(kNameAvgPoolGrad), ADPT_DESC(AvgPoolGrad)}, {string(kNameAvgPoolGrad), ADPT_DESC(AvgPoolGrad)},
{string(kNameMaxPoolGradWithArgmax), ADPT_DESC(MaxPoolGradWithArgmax)}, {string(kNameMaxPoolGradWithArgmax), ADPT_DESC(MaxPoolGradWithArgmax)},
{string(kNameExtractImagePatches), ADPT_DESC(ExtractImagePatches)}, {string(kNameExtractImagePatches), ADPT_DESC(ExtractImagePatches)},

View File

@ -726,6 +726,21 @@ ATTR_MAP(MaxPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<
{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}}; {"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}};
OUTPUT_MAP(MaxPoolGrad) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(MaxPoolGrad) = {{0, OUTPUT_DESC(y)}};
// RsqrtGrad
INPUT_MAP(RsqrtGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
ATTR_MAP(RsqrtGrad) = EMPTY_ATTR_MAP;
OUTPUT_MAP(RsqrtGrad) = {{0, OUTPUT_DESC(z)}};
// SqrtGrad
INPUT_MAP(SqrtGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
ATTR_MAP(SqrtGrad) = EMPTY_ATTR_MAP;
OUTPUT_MAP(SqrtGrad) = {{0, OUTPUT_DESC(z)}};
// ReciprocalGrad
INPUT_MAP(ReciprocalGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
ATTR_MAP(ReciprocalGrad) = EMPTY_ATTR_MAP;
OUTPUT_MAP(ReciprocalGrad) = {{0, OUTPUT_DESC(z)}};
// avgpoolgrad // avgpoolgrad
INPUT_MAP(AvgPoolGrad) = {{1, INPUT_DESC(orig_input_shape)}, {2, INPUT_DESC(input_grad)}}; INPUT_MAP(AvgPoolGrad) = {{1, INPUT_DESC(orig_input_shape)}, {2, INPUT_DESC(input_grad)}};
ATTR_MAP(AvgPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}, ATTR_MAP(AvgPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},

View File

@ -439,6 +439,12 @@ DECLARE_OP_ADAPTER(MaxPool)
DECLARE_OP_USE_OUTPUT(MaxPool) DECLARE_OP_USE_OUTPUT(MaxPool)
DECLARE_OP_ADAPTER(MaxPoolGrad) DECLARE_OP_ADAPTER(MaxPoolGrad)
DECLARE_OP_USE_OUTPUT(MaxPoolGrad) DECLARE_OP_USE_OUTPUT(MaxPoolGrad)
DECLARE_OP_ADAPTER(SqrtGrad)
DECLARE_OP_USE_OUTPUT(SqrtGrad)
DECLARE_OP_ADAPTER(ReciprocalGrad)
DECLARE_OP_USE_OUTPUT(ReciprocalGrad)
DECLARE_OP_ADAPTER(RsqrtGrad)
DECLARE_OP_USE_OUTPUT(RsqrtGrad)
DECLARE_OP_ADAPTER(AvgPool) DECLARE_OP_ADAPTER(AvgPool)
DECLARE_OP_USE_OUTPUT(AvgPool) DECLARE_OP_USE_OUTPUT(AvgPool)
DECLARE_OP_ADAPTER(AvgPoolGrad) DECLARE_OP_ADAPTER(AvgPoolGrad)

View File

@ -366,15 +366,10 @@ def get_bprop_square(self):
@bprop_getters.register(P.Sqrt) @bprop_getters.register(P.Sqrt)
def get_bprop_sqrt(self): def get_bprop_sqrt(self):
"""Grad definition for `Sqrt` operation.""" """Grad definition for `Sqrt` operation."""
mul_func = P.Mul() sqrt_grad = G.SqrtGrad()
fill_func = P.Fill()
div_op = P.RealDiv()
sqrt = P.Sqrt()
dtype = P.DType()
def bprop(x, out, dout): def bprop(x, out, dout):
temp = div_op(fill_func(dtype(x), shape_op(x), 0.5), sqrt(x)) dx = sqrt_grad(out, dout)
dx = mul_func(dout, temp)
return (dx,) return (dx,)
return bprop return bprop
@ -383,10 +378,10 @@ def get_bprop_sqrt(self):
@bprop_getters.register(P.Rsqrt) @bprop_getters.register(P.Rsqrt)
def get_bprop_rsqrt(self): def get_bprop_rsqrt(self):
"""Grad definition for `Rsqrt` operation.""" """Grad definition for `Rsqrt` operation."""
rsqrt_grad = G.RsqrtGrad()
def bprop(x, out, dout): def bprop(x, out, dout):
grad = F.fill(F.dtype(x), F.shape(x), -0.5) / (F.sqrt(x) * x) dx = rsqrt_grad(out, dout)
dx = dout * grad
return (dx,) return (dx,)
return bprop return bprop
@ -395,14 +390,10 @@ def get_bprop_rsqrt(self):
@bprop_getters.register(P.Reciprocal) @bprop_getters.register(P.Reciprocal)
def get_bprop_reciprocal(self): def get_bprop_reciprocal(self):
"""Grad definition for `Reciprocal` operation.""" """Grad definition for `Reciprocal` operation."""
neg = P.Neg() reciprocal_grad = G.ReciprocalGrad()
mul = P.Mul()
square = P.Square()
reciprocal = P.Reciprocal()
def bprop(x, out, dout): def bprop(x, out, dout):
g = neg(reciprocal(square(x))) dx = reciprocal_grad(out, dout)
dx = mul(dout, g)
return (dx,) return (dx,)
return bprop return bprop

View File

@ -442,6 +442,7 @@ def get_bprop_softmax(self):
sub = P.Sub() sub = P.Sub()
mul = P.Mul() mul = P.Mul()
axis = self.axis axis = self.axis
def bprop(x, out, dout): def bprop(x, out, dout):
dx = mul(out, sub(dout, sum_func(mul(out, dout), axis))) dx = mul(out, sub(dout, sum_func(mul(out, dout), axis)))
return (dx,) return (dx,)

View File

@ -236,6 +236,9 @@ from .cum_sum import _cum_sum_tbe
from .apply_rms_prop import _apply_rms_prop_tbe from .apply_rms_prop import _apply_rms_prop_tbe
from .cumprod import _cumprop_tbe from .cumprod import _cumprop_tbe
from .reduce_prod import _reduce_prod_tbe from .reduce_prod import _reduce_prod_tbe
from .reciprocal_grad import _reciprocal_grad_tbe
from .sqrt_grad import _sqrt_grad_tbe
from .rsqrt_grad import _rsqrt_grad_tbe
from .flatten_grad import _flatten_grad_tbe from .flatten_grad import _flatten_grad_tbe
from .scatter_add import _scatter_add_tbe from .scatter_add import _scatter_add_tbe
from .atan2 import _atan2_tbe from .atan2 import _atan2_tbe

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Add op""" """Reciprocal op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
reciprocal_op_info = TBERegOp("Reciprocal") \ reciprocal_op_info = TBERegOp("Reciprocal") \
@ -32,5 +32,5 @@ reciprocal_op_info = TBERegOp("Reciprocal") \
@op_info_register(reciprocal_op_info) @op_info_register(reciprocal_op_info)
def _reciprocal_tbe(): def _reciprocal_tbe():
"""Add TBE register""" """Reciprocal TBE register"""
return return

View File

@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReciprocalGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
reciprocal_grad_op_info = TBERegOp("ReciprocalGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("reciprocal_grad.so") \
.compute_cost(10) \
.kernel_name("reciprocal_grad") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.input(1, "dy", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("broadcast") \
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
.get_op_info()
@op_info_register(reciprocal_grad_op_info)
def _reciprocal_grad_tbe():
"""ReciprocalGrad TBE register"""
return

View File

@ -0,0 +1,40 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""RsqrtGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
rsqrt_grad_op_info = TBERegOp("RsqrtGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("rsqrt_grad.so") \
.compute_cost(10) \
.kernel_name("rsqrt_grad") \
.partial_flag(True) \
.op_pattern("broadcast") \
.input(0, "x", False, "required", "all") \
.input(1, "dy", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.I8_None, DataType.I8_None, DataType.I8_None) \
.dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
.get_op_info()
@op_info_register(rsqrt_grad_op_info)
def _rsqrt_grad_tbe():
"""RsqrtGrad TBE register"""
return

View File

@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SqrtGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
sqrt_grad_op_info = TBERegOp("SqrtGrad") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("sqrt_grad.so") \
.compute_cost(10) \
.kernel_name("sqrt_grad") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.input(1, "dy", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
.get_op_info()
@op_info_register(sqrt_grad_op_info)
def _sqrt_grad_tbe():
"""SqrtGrad TBE register"""
return

View File

@ -115,6 +115,74 @@ class AsinhGrad(PrimitiveWithInfer):
return x return x
class ReciprocalGrad(PrimitiveWithInfer):
"""Performs grad of Reciprocal operation."""
@prim_attr_register
def __init__(self):
"""init ReciprocalGrad"""
def infer_shape(self, x_shape, dout_shape):
validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
return x_shape
def infer_dtype(self, x_dtype, dout_dtype):
args = {"x": x_dtype, "dout": dout_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
return x_dtype
class RsqrtGrad(PrimitiveWithInfer):
"""Performs grad of Rsqrt operation."""
@prim_attr_register
def __init__(self):
"""init RsqrtGrad"""
def infer_shape(self, x_shape, dout_shape):
validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
return x_shape
def infer_dtype(self, x_dtype, dout_dtype):
args = {"x": x_dtype, "dout": dout_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name)
return x_dtype
class SoftmaxGrad(PrimitiveWithInfer):
"""Performs grad of Softmax operation."""
@prim_attr_register
def __init__(self):
"""init SoftmaxGrad"""
def infer_shape(self, x_shape, dout_shape):
validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
return x_shape
def infer_dtype(self, x_dtype, dout_dtype):
args = {"x": x_dtype, "dout": dout_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
return x_dtype
class SqrtGrad(PrimitiveWithInfer):
"""Performs grad of Sqrt operation."""
@prim_attr_register
def __init__(self):
"""init SqrtGrad"""
def infer_shape(self, x_shape, dout_shape):
validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
return x_shape
def infer_dtype(self, x_dtype, dout_dtype):
args = {"x": x_dtype, "dout": dout_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
return x_dtype
class BatchNormGrad(PrimitiveWithInfer): class BatchNormGrad(PrimitiveWithInfer):
"""Performs grad of BatchNorm operation.""" """Performs grad of BatchNorm operation."""