forked from OSSInnovation/mindspore

commit a80432f08e (parent 51fcaf6e61)

    add 4 grad ops
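The four new backward primitives are SqrtGrad, RsqrtGrad, ReciprocalGrad, and SoftmaxGrad. For reference, the gradients they implement, written in terms of the forward output y and the incoming gradient dy (standard calculus, stated here for orientation rather than taken from the diff):

\begin{aligned}
y = \sqrt{x}:&\quad dx = \frac{dy}{2y} \\
y = x^{-1/2}:&\quad dx = -\tfrac{1}{2}\,y^{3}\,dy \\
y = x^{-1}:&\quad dx = -y^{2}\,dy \\
y = \operatorname{softmax}(x):&\quad dx = y \odot \Bigl(dy - \sum_{\text{axis}} y \odot dy\Bigr)
\end{aligned}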
@@ -61,7 +61,6 @@ const char kNameReduceSum[] = "ReduceSum";
 const char kNameIsFinite[] = "isFinite";
 const char kNameReciprocal[] = "Reciprocal";
 const char kNameRsqrt[] = "Rsqrt";
-const char kNameRsqrtGrad[] = "RsqrtGrad";
 const char kNameSqrt[] = "Sqrt";
 const char kNameSquare[] = "Square";
 const char kNameSquaredDifference[] = "SquaredDifference";
@@ -83,6 +82,9 @@ const char kNameFlattenGrad[] = "FlattenGrad";
 const char kNameConvolution[] = "Convolution";
 const char kNameBiasAdd[] = "BiasAdd";
 const char kNameMaxPoolGrad[] = "MaxPoolGrad";
+const char kNameRsqrtGrad[] = "RsqrtGrad";
+const char kNameSqrtGrad[] = "SqrtGrad";
+const char kNameReciprocalGrad[] = "ReciprocalGrad";
 const char kNameAvgPoolGrad[] = "AvgPoolGrad";
 const char kNameMaxPoolGradWithArgmax[] = "MaxPoolGradWithArgmax";
 const char kNameApplyMomentum[] = "ApplyMomentum";
@@ -233,6 +235,9 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
     {string(kNameAllgather), ADPT_DESC(HcomAllGather)},
     {string(kNameReduceScatter), ADPT_DESC(HcomReduceScatter)},
     {string(kNameMaxPoolGrad), ADPT_DESC(MaxPoolGrad)},
+    {string(kNameSqrtGrad), ADPT_DESC(SqrtGrad)},
+    {string(kNameReciprocalGrad), ADPT_DESC(ReciprocalGrad)},
+    {string(kNameRsqrtGrad), ADPT_DESC(RsqrtGrad)},
     {string(kNameAvgPoolGrad), ADPT_DESC(AvgPoolGrad)},
     {string(kNameMaxPoolGradWithArgmax), ADPT_DESC(MaxPoolGradWithArgmax)},
     {string(kNameExtractImagePatches), ADPT_DESC(ExtractImagePatches)},
@@ -726,6 +726,21 @@ ATTR_MAP(MaxPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<
                          {"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}};
 OUTPUT_MAP(MaxPoolGrad) = {{0, OUTPUT_DESC(y)}};
 
+// RsqrtGrad
+INPUT_MAP(RsqrtGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
+ATTR_MAP(RsqrtGrad) = EMPTY_ATTR_MAP;
+OUTPUT_MAP(RsqrtGrad) = {{0, OUTPUT_DESC(z)}};
+
+// SqrtGrad
+INPUT_MAP(SqrtGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
+ATTR_MAP(SqrtGrad) = EMPTY_ATTR_MAP;
+OUTPUT_MAP(SqrtGrad) = {{0, OUTPUT_DESC(z)}};
+
+// ReciprocalGrad
+INPUT_MAP(ReciprocalGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
+ATTR_MAP(ReciprocalGrad) = EMPTY_ATTR_MAP;
+OUTPUT_MAP(ReciprocalGrad) = {{0, OUTPUT_DESC(z)}};
+
 // avgpoolgrad
 INPUT_MAP(AvgPoolGrad) = {{1, INPUT_DESC(orig_input_shape)}, {2, INPUT_DESC(input_grad)}};
 ATTR_MAP(AvgPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
@@ -439,6 +439,12 @@ DECLARE_OP_ADAPTER(MaxPool)
 DECLARE_OP_USE_OUTPUT(MaxPool)
 DECLARE_OP_ADAPTER(MaxPoolGrad)
 DECLARE_OP_USE_OUTPUT(MaxPoolGrad)
+DECLARE_OP_ADAPTER(SqrtGrad)
+DECLARE_OP_USE_OUTPUT(SqrtGrad)
+DECLARE_OP_ADAPTER(ReciprocalGrad)
+DECLARE_OP_USE_OUTPUT(ReciprocalGrad)
+DECLARE_OP_ADAPTER(RsqrtGrad)
+DECLARE_OP_USE_OUTPUT(RsqrtGrad)
 DECLARE_OP_ADAPTER(AvgPool)
 DECLARE_OP_USE_OUTPUT(AvgPool)
 DECLARE_OP_ADAPTER(AvgPoolGrad)
@@ -366,15 +366,10 @@ def get_bprop_square(self):
 @bprop_getters.register(P.Sqrt)
 def get_bprop_sqrt(self):
     """Grad definition for `Sqrt` operation."""
-    mul_func = P.Mul()
-    fill_func = P.Fill()
-    div_op = P.RealDiv()
-    sqrt = P.Sqrt()
-    dtype = P.DType()
+    sqrt_grad = G.SqrtGrad()
 
     def bprop(x, out, dout):
-        temp = div_op(fill_func(dtype(x), shape_op(x), 0.5), sqrt(x))
-        dx = mul_func(dout, temp)
+        dx = sqrt_grad(out, dout)
         return (dx,)
 
     return bprop
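The removed lines built 0.5/sqrt(x) out of Fill, RealDiv, Sqrt, and Mul primitives; the replacement hands the same arithmetic to the fused G.SqrtGrad kernel, which works from the forward output instead of recomputing sqrt(x). A minimal numpy check of the equivalence (numpy is a stand-in here, not the device kernel):

import numpy as np

x = np.array([1.0, 4.0, 9.0], dtype=np.float32)
y = np.sqrt(x)                        # forward output `out`
dout = np.array([1.0, 2.0, 3.0], dtype=np.float32)

old_dx = dout * (0.5 / np.sqrt(x))    # removed composite formulation
new_dx = dout * 0.5 / y               # SqrtGrad(y, dy), assuming dx = dy / (2 * y)
assert np.allclose(old_dx, new_dx)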
@@ -383,10 +378,10 @@ def get_bprop_sqrt(self):
 @bprop_getters.register(P.Rsqrt)
 def get_bprop_rsqrt(self):
     """Grad definition for `Rsqrt` operation."""
+    rsqrt_grad = G.RsqrtGrad()
 
     def bprop(x, out, dout):
-        grad = F.fill(F.dtype(x), F.shape(x), -0.5) / (F.sqrt(x) * x)
-        dx = dout * grad
+        dx = rsqrt_grad(out, dout)
         return (dx,)
 
     return bprop
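Same refactor for Rsqrt: with y = x**-0.5, the old expression -0.5 / (sqrt(x) * x) equals -0.5 * y**3, which is what the fused RsqrtGrad(y, dy) kernel is assumed to compute. A numpy sanity check under that assumption:

import numpy as np

x = np.array([1.0, 4.0, 9.0], dtype=np.float32)
y = 1.0 / np.sqrt(x)                          # forward output `out`
dout = np.ones_like(x)

old_dx = dout * (-0.5 / (np.sqrt(x) * x))     # removed formulation
new_dx = -0.5 * y**3 * dout                   # assumed RsqrtGrad(y, dy) semantics
assert np.allclose(old_dx, new_dx)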
@@ -395,14 +390,10 @@ def get_bprop_rsqrt(self):
 @bprop_getters.register(P.Reciprocal)
 def get_bprop_reciprocal(self):
     """Grad definition for `Reciprocal` operation."""
-    neg = P.Neg()
-    mul = P.Mul()
-    square = P.Square()
-    reciprocal = P.Reciprocal()
+    reciprocal_grad = G.ReciprocalGrad()
 
     def bprop(x, out, dout):
-        g = neg(reciprocal(square(x)))
-        dx = mul(dout, g)
+        dx = reciprocal_grad(out, dout)
         return (dx,)
 
     return bprop
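And for Reciprocal: with y = 1/x, the old -1/x**2 equals -y**2, so the gradient can be computed from the forward output alone. Numpy check, again assuming that is the fused kernel's semantics:

import numpy as np

x = np.array([1.0, 2.0, 4.0], dtype=np.float32)
y = 1.0 / x                              # forward output `out`
dout = np.ones_like(x)

old_dx = dout * (-1.0 / np.square(x))    # removed formulation
new_dx = -np.square(y) * dout            # assumed ReciprocalGrad(y, dy) semantics
assert np.allclose(old_dx, new_dx)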
@@ -442,6 +442,7 @@ def get_bprop_softmax(self):
     sub = P.Sub()
     mul = P.Mul()
     axis = self.axis
 
     def bprop(x, out, dout):
         dx = mul(out, sub(dout, sum_func(mul(out, dout), axis)))
         return (dx,)
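The softmax bprop shown above for context is unchanged in substance. Its formula in plain numpy, assuming sum_func reduces with keep_dims=True (which the broadcast in mul requires):

import numpy as np

def softmax_bprop_reference(out, dout, axis=-1):
    # dx = out * (dout - sum(out * dout, axis, keepdims=True))
    return out * (dout - np.sum(out * dout, axis=axis, keepdims=True))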
@@ -236,6 +236,9 @@ from .cum_sum import _cum_sum_tbe
 from .apply_rms_prop import _apply_rms_prop_tbe
 from .cumprod import _cumprop_tbe
 from .reduce_prod import _reduce_prod_tbe
+from .reciprocal_grad import _reciprocal_grad_tbe
+from .sqrt_grad import _sqrt_grad_tbe
+from .rsqrt_grad import _rsqrt_grad_tbe
 from .flatten_grad import _flatten_grad_tbe
 from .scatter_add import _scatter_add_tbe
 from .atan2 import _atan2_tbe
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
 
-"""Add op"""
+"""Reciprocal op"""
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
 
 reciprocal_op_info = TBERegOp("Reciprocal") \
@@ -32,5 +32,5 @@ reciprocal_op_info = TBERegOp("Reciprocal") \
 
 @op_info_register(reciprocal_op_info)
 def _reciprocal_tbe():
-    """Add TBE register"""
+    """Reciprocal TBE register"""
     return
@@ -0,0 +1,38 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""ReciprocalGrad op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+reciprocal_grad_op_info = TBERegOp("ReciprocalGrad") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("reciprocal_grad.so") \
+    .compute_cost(10) \
+    .kernel_name("reciprocal_grad") \
+    .partial_flag(True) \
+    .input(0, "x", False, "required", "all") \
+    .input(1, "dy", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .op_pattern("broadcast") \
+    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
+    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
+    .get_op_info()
+
+
+@op_info_register(reciprocal_grad_op_info)
+def _reciprocal_grad_tbe():
+    """ReciprocalGrad TBE register"""
+    return
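A note on the registration above: op_pattern("broadcast") declares that the two inputs may broadcast against each other, and the F16_None/F32_None dtype_format pairs leave the tensor layout unconstrained. A hypothetical numpy model of the kernel's contract (the real computation lives in the prebuilt reciprocal_grad binary):

import numpy as np

y = np.full((2, 1), 0.5, dtype=np.float32)   # forward output of Reciprocal
dy = np.ones((2, 3), dtype=np.float32)       # incoming gradient, broadcast-compatible
dx = -np.square(y) * dy                      # assumed result, shape (2, 3)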
@@ -0,0 +1,40 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""RsqrtGrad op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+rsqrt_grad_op_info = TBERegOp("RsqrtGrad") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("rsqrt_grad.so") \
+    .compute_cost(10) \
+    .kernel_name("rsqrt_grad") \
+    .partial_flag(True) \
+    .op_pattern("broadcast") \
+    .input(0, "x", False, "required", "all") \
+    .input(1, "dy", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
+    .dtype_format(DataType.I8_None, DataType.I8_None, DataType.I8_None) \
+    .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \
+    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
+    .get_op_info()
+
+
+@op_info_register(rsqrt_grad_op_info)
+def _rsqrt_grad_tbe():
+    """RsqrtGrad TBE register"""
+    return
@@ -0,0 +1,43 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""SqrtGrad op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+sqrt_grad_op_info = TBERegOp("SqrtGrad") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("sqrt_grad.so") \
+    .compute_cost(10) \
+    .kernel_name("sqrt_grad") \
+    .partial_flag(True) \
+    .input(0, "x", False, "required", "all") \
+    .input(1, "dy", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
+    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
+    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
+    .get_op_info()
+
+
+@op_info_register(sqrt_grad_op_info)
+def _sqrt_grad_tbe():
+    """SqrtGrad TBE register"""
+    return
@@ -115,6 +115,74 @@ class AsinhGrad(PrimitiveWithInfer):
         return x
 
 
+class ReciprocalGrad(PrimitiveWithInfer):
+    """Performs grad of Reciprocal operation."""
+
+    @prim_attr_register
+    def __init__(self):
+        """init ReciprocalGrad"""
+
+    def infer_shape(self, x_shape, dout_shape):
+        validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
+        return x_shape
+
+    def infer_dtype(self, x_dtype, dout_dtype):
+        args = {"x": x_dtype, "dout": dout_dtype}
+        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
+        return x_dtype
+
+
+class RsqrtGrad(PrimitiveWithInfer):
+    """Performs grad of Rsqrt operation."""
+
+    @prim_attr_register
+    def __init__(self):
+        """init RsqrtGrad"""
+
+    def infer_shape(self, x_shape, dout_shape):
+        validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
+        return x_shape
+
+    def infer_dtype(self, x_dtype, dout_dtype):
+        args = {"x": x_dtype, "dout": dout_dtype}
+        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name)
+        return x_dtype
+
+
+class SoftmaxGrad(PrimitiveWithInfer):
+    """Performs grad of Softmax operation."""
+
+    @prim_attr_register
+    def __init__(self):
+        """init SoftmaxGrad"""
+
+    def infer_shape(self, x_shape, dout_shape):
+        validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
+        return x_shape
+
+    def infer_dtype(self, x_dtype, dout_dtype):
+        args = {"x": x_dtype, "dout": dout_dtype}
+        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
+        return x_dtype
+
+
+class SqrtGrad(PrimitiveWithInfer):
+    """Performs grad of Sqrt operation."""
+
+    @prim_attr_register
+    def __init__(self):
+        """init SqrtGrad"""
+
+    def infer_shape(self, x_shape, dout_shape):
+        validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
+        return x_shape
+
+    def infer_dtype(self, x_dtype, dout_dtype):
+        args = {"x": x_dtype, "dout": dout_dtype}
+        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
+        return x_dtype
+
+
 class BatchNormGrad(PrimitiveWithInfer):
     """Performs grad of BatchNorm operation."""
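A minimal usage sketch for the new primitives (hypothetical inputs; assumes an Ascend-enabled MindSpore build where the TBE kernels above are available):

import numpy as np
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G

sqrt_grad = G.SqrtGrad()
y = Tensor(np.array([1.0, 2.0, 3.0], dtype=np.float32))  # forward output of Sqrt
dy = Tensor(np.ones(3, dtype=np.float32))                # incoming gradient
dx = sqrt_grad(y, dy)  # infer_shape/infer_dtype reject mismatched shapes or dtypes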