forked from OSSInnovation/mindspore
add 4 grad ops
This commit is contained in:
parent
51fcaf6e61
commit
a80432f08e
|
@ -61,7 +61,6 @@ const char kNameReduceSum[] = "ReduceSum";
|
|||
const char kNameIsFinite[] = "isFinite";
|
||||
const char kNameReciprocal[] = "Reciprocal";
|
||||
const char kNameRsqrt[] = "Rsqrt";
|
||||
const char kNameRsqrtGrad[] = "RsqrtGrad";
|
||||
const char kNameSqrt[] = "Sqrt";
|
||||
const char kNameSquare[] = "Square";
|
||||
const char kNameSquaredDifference[] = "SquaredDifference";
|
||||
|
@ -83,6 +82,9 @@ const char kNameFlattenGrad[] = "FlattenGrad";
|
|||
const char kNameConvolution[] = "Convolution";
|
||||
const char kNameBiasAdd[] = "BiasAdd";
|
||||
const char kNameMaxPoolGrad[] = "MaxPoolGrad";
|
||||
const char kNameRsqrtGrad[] = "RsqrtGrad";
|
||||
const char kNameSqrtGrad[] = "SqrtGrad";
|
||||
const char kNameReciprocalGrad[] = "ReciprocalGrad";
|
||||
const char kNameAvgPoolGrad[] = "AvgPoolGrad";
|
||||
const char kNameMaxPoolGradWithArgmax[] = "MaxPoolGradWithArgmax";
|
||||
const char kNameApplyMomentum[] = "ApplyMomentum";
|
||||
|
@ -233,6 +235,9 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
|
|||
{string(kNameAllgather), ADPT_DESC(HcomAllGather)},
|
||||
{string(kNameReduceScatter), ADPT_DESC(HcomReduceScatter)},
|
||||
{string(kNameMaxPoolGrad), ADPT_DESC(MaxPoolGrad)},
|
||||
{string(kNameSqrtGrad), ADPT_DESC(SqrtGrad)},
|
||||
{string(kNameReciprocalGrad), ADPT_DESC(ReciprocalGrad)},
|
||||
{string(kNameRsqrtGrad), ADPT_DESC(RsqrtGrad)},
|
||||
{string(kNameAvgPoolGrad), ADPT_DESC(AvgPoolGrad)},
|
||||
{string(kNameMaxPoolGradWithArgmax), ADPT_DESC(MaxPoolGradWithArgmax)},
|
||||
{string(kNameExtractImagePatches), ADPT_DESC(ExtractImagePatches)},
|
||||
|
|
|
@ -726,6 +726,21 @@ ATTR_MAP(MaxPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<
|
|||
{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}};
|
||||
OUTPUT_MAP(MaxPoolGrad) = {{0, OUTPUT_DESC(y)}};
|
||||
|
||||
// RsqrtGrad
|
||||
INPUT_MAP(RsqrtGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
|
||||
ATTR_MAP(RsqrtGrad) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(RsqrtGrad) = {{0, OUTPUT_DESC(z)}};
|
||||
|
||||
// SqrtGrad
|
||||
INPUT_MAP(SqrtGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
|
||||
ATTR_MAP(SqrtGrad) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(SqrtGrad) = {{0, OUTPUT_DESC(z)}};
|
||||
|
||||
// ReciprocalGrad
|
||||
INPUT_MAP(ReciprocalGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
|
||||
ATTR_MAP(ReciprocalGrad) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(ReciprocalGrad) = {{0, OUTPUT_DESC(z)}};
|
||||
|
||||
// avgpoolgrad
|
||||
INPUT_MAP(AvgPoolGrad) = {{1, INPUT_DESC(orig_input_shape)}, {2, INPUT_DESC(input_grad)}};
|
||||
ATTR_MAP(AvgPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
|
||||
|
|
|
@ -439,6 +439,12 @@ DECLARE_OP_ADAPTER(MaxPool)
|
|||
DECLARE_OP_USE_OUTPUT(MaxPool)
|
||||
DECLARE_OP_ADAPTER(MaxPoolGrad)
|
||||
DECLARE_OP_USE_OUTPUT(MaxPoolGrad)
|
||||
DECLARE_OP_ADAPTER(SqrtGrad)
|
||||
DECLARE_OP_USE_OUTPUT(SqrtGrad)
|
||||
DECLARE_OP_ADAPTER(ReciprocalGrad)
|
||||
DECLARE_OP_USE_OUTPUT(ReciprocalGrad)
|
||||
DECLARE_OP_ADAPTER(RsqrtGrad)
|
||||
DECLARE_OP_USE_OUTPUT(RsqrtGrad)
|
||||
DECLARE_OP_ADAPTER(AvgPool)
|
||||
DECLARE_OP_USE_OUTPUT(AvgPool)
|
||||
DECLARE_OP_ADAPTER(AvgPoolGrad)
|
||||
|
|
|
@ -366,15 +366,10 @@ def get_bprop_square(self):
|
|||
@bprop_getters.register(P.Sqrt)
|
||||
def get_bprop_sqrt(self):
|
||||
"""Grad definition for `Sqrt` operation."""
|
||||
mul_func = P.Mul()
|
||||
fill_func = P.Fill()
|
||||
div_op = P.RealDiv()
|
||||
sqrt = P.Sqrt()
|
||||
dtype = P.DType()
|
||||
sqrt_grad = G.SqrtGrad()
|
||||
|
||||
def bprop(x, out, dout):
|
||||
temp = div_op(fill_func(dtype(x), shape_op(x), 0.5), sqrt(x))
|
||||
dx = mul_func(dout, temp)
|
||||
dx = sqrt_grad(out, dout)
|
||||
return (dx,)
|
||||
|
||||
return bprop
|
||||
|
@ -383,10 +378,10 @@ def get_bprop_sqrt(self):
|
|||
@bprop_getters.register(P.Rsqrt)
|
||||
def get_bprop_rsqrt(self):
|
||||
"""Grad definition for `Rsqrt` operation."""
|
||||
rsqrt_grad = G.RsqrtGrad()
|
||||
|
||||
def bprop(x, out, dout):
|
||||
grad = F.fill(F.dtype(x), F.shape(x), -0.5) / (F.sqrt(x) * x)
|
||||
dx = dout * grad
|
||||
dx = rsqrt_grad(out, dout)
|
||||
return (dx,)
|
||||
|
||||
return bprop
|
||||
|
@ -395,14 +390,10 @@ def get_bprop_rsqrt(self):
|
|||
@bprop_getters.register(P.Reciprocal)
|
||||
def get_bprop_reciprocal(self):
|
||||
"""Grad definition for `Reciprocal` operation."""
|
||||
neg = P.Neg()
|
||||
mul = P.Mul()
|
||||
square = P.Square()
|
||||
reciprocal = P.Reciprocal()
|
||||
reciprocal_grad = G.ReciprocalGrad()
|
||||
|
||||
def bprop(x, out, dout):
|
||||
g = neg(reciprocal(square(x)))
|
||||
dx = mul(dout, g)
|
||||
dx = reciprocal_grad(out, dout)
|
||||
return (dx,)
|
||||
|
||||
return bprop
|
||||
|
|
|
@ -442,6 +442,7 @@ def get_bprop_softmax(self):
|
|||
sub = P.Sub()
|
||||
mul = P.Mul()
|
||||
axis = self.axis
|
||||
|
||||
def bprop(x, out, dout):
|
||||
dx = mul(out, sub(dout, sum_func(mul(out, dout), axis)))
|
||||
return (dx,)
|
||||
|
|
|
@ -236,6 +236,9 @@ from .cum_sum import _cum_sum_tbe
|
|||
from .apply_rms_prop import _apply_rms_prop_tbe
|
||||
from .cumprod import _cumprop_tbe
|
||||
from .reduce_prod import _reduce_prod_tbe
|
||||
from .reciprocal_grad import _reciprocal_grad_tbe
|
||||
from .sqrt_grad import _sqrt_grad_tbe
|
||||
from .rsqrt_grad import _rsqrt_grad_tbe
|
||||
from .flatten_grad import _flatten_grad_tbe
|
||||
from .scatter_add import _scatter_add_tbe
|
||||
from .atan2 import _atan2_tbe
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Add op"""
|
||||
"""Reciprocal op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
reciprocal_op_info = TBERegOp("Reciprocal") \
|
||||
|
@ -32,5 +32,5 @@ reciprocal_op_info = TBERegOp("Reciprocal") \
|
|||
|
||||
@op_info_register(reciprocal_op_info)
|
||||
def _reciprocal_tbe():
|
||||
"""Add TBE register"""
|
||||
"""Reciprocal TBE register"""
|
||||
return
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ReciprocalGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
reciprocal_grad_op_info = TBERegOp("ReciprocalGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("reciprocal_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("reciprocal_grad") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "dy", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.op_pattern("broadcast") \
|
||||
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
|
||||
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(reciprocal_grad_op_info)
|
||||
def _reciprocal_grad_tbe():
|
||||
"""ReciprocalGrad TBE register"""
|
||||
return
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""RsqrtGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
rsqrt_grad_op_info = TBERegOp("RsqrtGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("rsqrt_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("rsqrt_grad") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("broadcast") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "dy", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
|
||||
.dtype_format(DataType.I8_None, DataType.I8_None, DataType.I8_None) \
|
||||
.dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \
|
||||
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(rsqrt_grad_op_info)
|
||||
def _rsqrt_grad_tbe():
|
||||
"""RsqrtGrad TBE register"""
|
||||
return
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""SqrtGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
sqrt_grad_op_info = TBERegOp("SqrtGrad") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("sqrt_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("sqrt_grad") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "dy", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
|
||||
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
|
||||
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(sqrt_grad_op_info)
|
||||
def _sqrt_grad_tbe():
|
||||
"""SqrtGrad TBE register"""
|
||||
return
|
|
@ -115,6 +115,74 @@ class AsinhGrad(PrimitiveWithInfer):
|
|||
return x
|
||||
|
||||
|
||||
class ReciprocalGrad(PrimitiveWithInfer):
|
||||
"""Performs grad of Reciprocal operation."""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init ReciprocalGrad"""
|
||||
|
||||
def infer_shape(self, x_shape, dout_shape):
|
||||
validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, dout_dtype):
|
||||
args = {"x": x_dtype, "dout": dout_dtype}
|
||||
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
class RsqrtGrad(PrimitiveWithInfer):
|
||||
"""Performs grad of Rsqrt operation."""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init RsqrtGrad"""
|
||||
|
||||
def infer_shape(self, x_shape, dout_shape):
|
||||
validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, dout_dtype):
|
||||
args = {"x": x_dtype, "dout": dout_dtype}
|
||||
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
class SoftmaxGrad(PrimitiveWithInfer):
|
||||
"""Performs grad of Softmax operation."""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init SoftmaxGrad"""
|
||||
|
||||
def infer_shape(self, x_shape, dout_shape):
|
||||
validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, dout_dtype):
|
||||
args = {"x": x_dtype, "dout": dout_dtype}
|
||||
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
class SqrtGrad(PrimitiveWithInfer):
|
||||
"""Performs grad of Sqrt operation."""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init SqrtGrad"""
|
||||
|
||||
def infer_shape(self, x_shape, dout_shape):
|
||||
validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, dout_dtype):
|
||||
args = {"x": x_dtype, "dout": dout_dtype}
|
||||
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
class BatchNormGrad(PrimitiveWithInfer):
|
||||
"""Performs grad of BatchNorm operation."""
|
||||
|
||||
|
|
Loading…
Reference in New Issue