forked from mindspore-Ecosystem/mindspore
!1062 Adapt Softplus\SoftplusGrad for VM, and ApplyFtrlD for both GE and VM.
Merge pull request !1062 from liuxiao/softplus-softplusgrad
This commit is contained in:
commit
6ab046e4dd
|
@ -33,6 +33,7 @@ static std::map<string, string> tbe_func_adapter_map = {
|
||||||
{"softmax", "softmax_v2"},
|
{"softmax", "softmax_v2"},
|
||||||
{"log_softmax", "log_softmax_v2"},
|
{"log_softmax", "log_softmax_v2"},
|
||||||
{"apply_momentum", "apply_momentum_d"},
|
{"apply_momentum", "apply_momentum_d"},
|
||||||
|
{"apply_ftrl", "apply_ftrl_d"},
|
||||||
{"re_lu6", "relu6"},
|
{"re_lu6", "relu6"},
|
||||||
{"re_lu6_grad", "relu6_grad"},
|
{"re_lu6_grad", "relu6_grad"},
|
||||||
{"re_lu", "relu"},
|
{"re_lu", "relu"},
|
||||||
|
|
|
@ -384,7 +384,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
|
||||||
{string(kNameDepthToSpace), ADPT_DESC(DepthToSpace)},
|
{string(kNameDepthToSpace), ADPT_DESC(DepthToSpace)},
|
||||||
{string(kNameSign), ADPT_DESC(Sign)},
|
{string(kNameSign), ADPT_DESC(Sign)},
|
||||||
{string(kNameRound), ADPT_DESC(Round)},
|
{string(kNameRound), ADPT_DESC(Round)},
|
||||||
{string(kNameApplyFtrl), ADPT_DESC(ApplyFtrl)},
|
{string(kNameApplyFtrl), ADPT_DESC(ApplyFtrlD)},
|
||||||
{string(kNameDiag), ADPT_DESC(Diag)},
|
{string(kNameDiag), ADPT_DESC(Diag)},
|
||||||
{string(kNameDiagPart), ADPT_DESC(DiagPart)},
|
{string(kNameDiagPart), ADPT_DESC(DiagPart)},
|
||||||
{string(kNameSpaceToBatch), ADPT_DESC(SpaceToBatchD)},
|
{string(kNameSpaceToBatch), ADPT_DESC(SpaceToBatchD)},
|
||||||
|
|
|
@ -1176,11 +1176,11 @@ ATTR_MAP(Round) = EMPTY_ATTR_MAP;
|
||||||
OUTPUT_MAP(Round) = {{0, OUTPUT_DESC(y)}};
|
OUTPUT_MAP(Round) = {{0, OUTPUT_DESC(y)}};
|
||||||
|
|
||||||
// ApplyFtrl
|
// ApplyFtrl
|
||||||
INPUT_MAP(ApplyFtrl) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(linear)},
|
INPUT_MAP(ApplyFtrlD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(linear)},
|
||||||
{4, INPUT_DESC(grad)}, {5, INPUT_DESC(lr)}, {6, INPUT_DESC(l1)},
|
{4, INPUT_DESC(grad)}, {5, INPUT_DESC(lr)}, {6, INPUT_DESC(l1)},
|
||||||
{7, INPUT_DESC(l2)}, {8, INPUT_DESC(lr_power)}};
|
{7, INPUT_DESC(l2)}, {8, INPUT_DESC(lr_power)}};
|
||||||
ATTR_MAP(ApplyFtrl) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
ATTR_MAP(ApplyFtrlD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||||
OUTPUT_MAP(ApplyFtrl) = {{0, OUTPUT_DESC(var)}};
|
OUTPUT_MAP(ApplyFtrlD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}, {2, OUTPUT_DESC(linear)}};
|
||||||
|
|
||||||
// Diag
|
// Diag
|
||||||
INPUT_MAP(Diag) = {{1, INPUT_DESC(x)}};
|
INPUT_MAP(Diag) = {{1, INPUT_DESC(x)}};
|
||||||
|
|
|
@ -446,8 +446,8 @@ DECLARE_OP_ADAPTER(LarsV2Update)
|
||||||
DECLARE_OP_USE_OUTPUT(LarsV2Update)
|
DECLARE_OP_USE_OUTPUT(LarsV2Update)
|
||||||
DECLARE_OP_ADAPTER(Round)
|
DECLARE_OP_ADAPTER(Round)
|
||||||
DECLARE_OP_USE_OUTPUT(Round)
|
DECLARE_OP_USE_OUTPUT(Round)
|
||||||
DECLARE_OP_ADAPTER(ApplyFtrl)
|
DECLARE_OP_ADAPTER(ApplyFtrlD)
|
||||||
DECLARE_OP_USE_OUTPUT(ApplyFtrl)
|
DECLARE_OP_USE_OUTPUT(ApplyFtrlD)
|
||||||
DECLARE_OP_ADAPTER(SparseApplyFtrlD)
|
DECLARE_OP_ADAPTER(SparseApplyFtrlD)
|
||||||
DECLARE_OP_USE_OUTPUT(SparseApplyFtrlD)
|
DECLARE_OP_USE_OUTPUT(SparseApplyFtrlD)
|
||||||
DECLARE_OP_ADAPTER(Diag)
|
DECLARE_OP_ADAPTER(Diag)
|
||||||
|
|
|
@ -326,6 +326,18 @@ def get_bprop_log_softmax(self):
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(P.Softplus)
|
||||||
|
def get_bprop_softplus(self):
|
||||||
|
"""Grad definition for `Softplus` operation."""
|
||||||
|
softplus_grad = G.SoftplusGrad()
|
||||||
|
|
||||||
|
def bprop(x, out, dout):
|
||||||
|
dx = softplus_grad(dout, x)
|
||||||
|
return (dx,)
|
||||||
|
|
||||||
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
@bprop_getters.register(P.Tanh)
|
@bprop_getters.register(P.Tanh)
|
||||||
def get_bprop_tanh(self):
|
def get_bprop_tanh(self):
|
||||||
"""Grad definition for `Tanh` operation."""
|
"""Grad definition for `Tanh` operation."""
|
||||||
|
|
|
@ -100,6 +100,8 @@ from .round import _round_tbe
|
||||||
from .tanh import _tanh_tbe
|
from .tanh import _tanh_tbe
|
||||||
from .tanh_grad import _tanh_grad_tbe
|
from .tanh_grad import _tanh_grad_tbe
|
||||||
from .softmax import _softmax_tbe
|
from .softmax import _softmax_tbe
|
||||||
|
from .softplus import _softplus_tbe
|
||||||
|
from .softplus_grad import _softplus_grad_tbe
|
||||||
from .square import _square_tbe
|
from .square import _square_tbe
|
||||||
from .sqrt import _sqrt_tbe
|
from .sqrt import _sqrt_tbe
|
||||||
from .transpose_d import _transpose_d_tbe
|
from .transpose_d import _transpose_d_tbe
|
||||||
|
|
|
@ -32,30 +32,32 @@ apply_ftrl_op_info = TBERegOp("ApplyFtrl") \
|
||||||
.input(6, "l2", False, "required", "all") \
|
.input(6, "l2", False, "required", "all") \
|
||||||
.input(7, "lr_power", False, "required", "all") \
|
.input(7, "lr_power", False, "required", "all") \
|
||||||
.output(0, "var", False, "required", "all") \
|
.output(0, "var", False, "required", "all") \
|
||||||
|
.output(1, "accum", False, "required", "all") \
|
||||||
|
.output(2, "linear", False, "required", "all") \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
||||||
DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||||
DataType.F16_5HD) \
|
DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||||
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ,
|
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ,
|
||||||
DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||||
DataType.F16_FracZ) \
|
DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
|
||||||
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
|
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
|
||||||
DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||||
DataType.F16_C1HWNCoC0) \
|
DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||||
DataType.F16_Default) \
|
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||||
DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
DataType.F32_5HD) \
|
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
|
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
|
||||||
DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
DataType.F32_FracZ) \
|
DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
|
||||||
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
|
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
|
||||||
DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
DataType.F32_C1HWNCoC0) \
|
DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
|
||||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||||
DataType.F32_Default) \
|
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Softplus op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
softplus_op_info = TBERegOp("Softplus") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("softplus.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("softplus") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.op_pattern("formatAgnostic") \
|
||||||
|
.input(0, "x", False, "required", "all") \
|
||||||
|
.output(0, "y", False, "required", "all") \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(softplus_op_info)
|
||||||
|
def _softplus_tbe():
|
||||||
|
"""Softplus TBE register"""
|
||||||
|
return
|
|
@ -0,0 +1,40 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""SoftplusGrad op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
softplus_grad_op_info = TBERegOp("SoftplusGrad") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("softplus_grad.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("softplus_grad") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.op_pattern("broadcast") \
|
||||||
|
.input(0, "gradients", False, "required", "all") \
|
||||||
|
.input(1, "features", False, "required", "all") \
|
||||||
|
.output(0, "backprops", False, "required", "all") \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(softplus_grad_op_info)
|
||||||
|
def _softplus_grad_tbe():
|
||||||
|
"""SoftplusGrad TBE register"""
|
||||||
|
return
|
|
@ -62,7 +62,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
|
||||||
MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
|
MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
|
||||||
ResizeBilinear, Sigmoid,
|
ResizeBilinear, Sigmoid,
|
||||||
SigmoidCrossEntropyWithLogits,
|
SigmoidCrossEntropyWithLogits,
|
||||||
SmoothL1Loss, Softmax,
|
SmoothL1Loss, Softmax, Softplus,
|
||||||
SoftmaxCrossEntropyWithLogits, ROIAlign,
|
SoftmaxCrossEntropyWithLogits, ROIAlign,
|
||||||
SparseSoftmaxCrossEntropyWithLogits, Tanh,
|
SparseSoftmaxCrossEntropyWithLogits, Tanh,
|
||||||
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl,
|
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl,
|
||||||
|
|
|
@ -974,6 +974,23 @@ class StridedSliceGrad(PrimitiveWithInfer):
|
||||||
'value': None}
|
'value': None}
|
||||||
|
|
||||||
|
|
||||||
|
class SoftplusGrad(PrimitiveWithInfer):
|
||||||
|
"""Computes gradient for the Log Softmax activation."""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
self.init_prim_io_names(inputs=['dout', 'x'], outputs=['output'])
|
||||||
|
|
||||||
|
def infer_shape(self, dout_shape, x_shape):
|
||||||
|
validator.check("x_shape", x_shape, "dout_shape", dout_shape, Rel.EQ, self.name)
|
||||||
|
return x_shape
|
||||||
|
|
||||||
|
def infer_dtype(self, dout_dtype, x_dtype):
|
||||||
|
args = {"x_dtype": x_dtype, "dout_dtype": dout_dtype}
|
||||||
|
validator.check_tensor_type_same(args, mstype.float_type, self.name)
|
||||||
|
return x_dtype
|
||||||
|
|
||||||
|
|
||||||
class TanhGrad(PrimitiveWithInfer):
|
class TanhGrad(PrimitiveWithInfer):
|
||||||
"""Computes gradient of hyperbolic tangent of input element-wise."""
|
"""Computes gradient of hyperbolic tangent of input element-wise."""
|
||||||
|
|
||||||
|
|
|
@ -183,6 +183,41 @@ class LogSoftmax(PrimitiveWithInfer):
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
class Softplus(PrimitiveWithInfer):
|
||||||
|
r"""
|
||||||
|
Softplus activation function.
|
||||||
|
|
||||||
|
Softplus is a smooth approximation to the ReLU function.
|
||||||
|
The function is shown as follows:
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
\text{output} = \log(1 + \exp(\text{input_x})),
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **input_x** (Tensor) - The input tensor whose data type should be float.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, with the same type and shape as the `input_x`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> input_x = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32)
|
||||||
|
>>> softplus = P.Softplus()
|
||||||
|
>>> softplus(input_x)
|
||||||
|
[1.3132615, 2.126928, 3.0485873, 4.01815, 5.0067153]
|
||||||
|
"""
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""init Softplus"""
|
||||||
|
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
||||||
|
|
||||||
|
def infer_shape(self, input_x):
|
||||||
|
return input_x
|
||||||
|
|
||||||
|
def infer_dtype(self, input_x):
|
||||||
|
validator.check_tensor_type_same({'input_x': input_x}, mstype.float_type, self.name)
|
||||||
|
return input_x
|
||||||
|
|
||||||
|
|
||||||
class ReLU(PrimitiveWithInfer):
|
class ReLU(PrimitiveWithInfer):
|
||||||
r"""
|
r"""
|
||||||
Computes ReLU(Rectified Linear Unit) of input tensor element-wise.
|
Computes ReLU(Rectified Linear Unit) of input tensor element-wise.
|
||||||
|
@ -2701,11 +2736,14 @@ class ApplyFtrl(PrimitiveWithInfer):
|
||||||
self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'lr', 'l1', 'l2', 'lr_power'],
|
self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'lr', 'l1', 'l2', 'lr_power'],
|
||||||
outputs=['output'])
|
outputs=['output'])
|
||||||
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
||||||
|
self.is_tbe = context.get_context("device_target") == "Ascend"
|
||||||
|
|
||||||
def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, lr_shape, l1_shape, l2_shape,
|
def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, lr_shape, l1_shape, l2_shape,
|
||||||
lr_power_shape):
|
lr_power_shape):
|
||||||
validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
|
validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
|
||||||
validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
|
validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
|
||||||
|
if self.is_tbe:
|
||||||
|
return var_shape, var_shape, var_shape
|
||||||
return var_shape
|
return var_shape
|
||||||
|
|
||||||
def infer_dtype(self, var_type, accum_type, linear_type, grad_type, lr_type, l1_type, l2_type, lr_power_type):
|
def infer_dtype(self, var_type, accum_type, linear_type, grad_type, lr_type, l1_type, l2_type, lr_power_type):
|
||||||
|
@ -2717,6 +2755,8 @@ class ApplyFtrl(PrimitiveWithInfer):
|
||||||
validator.check_scalar_or_tensor_type_same({"l1": l1_type}, valid_types, self.name)
|
validator.check_scalar_or_tensor_type_same({"l1": l1_type}, valid_types, self.name)
|
||||||
validator.check_scalar_or_tensor_type_same({"l2": l2_type}, valid_types, self.name)
|
validator.check_scalar_or_tensor_type_same({"l2": l2_type}, valid_types, self.name)
|
||||||
validator.check_scalar_or_tensor_type_same({"lr_power": lr_power_type}, valid_types, self.name)
|
validator.check_scalar_or_tensor_type_same({"lr_power": lr_power_type}, valid_types, self.name)
|
||||||
|
if self.is_tbe:
|
||||||
|
return var_type, var_type, var_type
|
||||||
return var_type
|
return var_type
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -185,6 +185,22 @@ class ScatterMax(nn.Cell):
|
||||||
out = self.scatter_max(self.ref, indices, updates)
|
out = self.scatter_max(self.ref, indices, updates)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class ApplyFtrlNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(ApplyFtrlNet, self).__init__()
|
||||||
|
self.apply_ftrl = P.ApplyFtrl()
|
||||||
|
self.lr = 0.001
|
||||||
|
self.l1 = 0.0
|
||||||
|
self.l2 = 0.0
|
||||||
|
self.lr_power = -0.5
|
||||||
|
self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
|
||||||
|
self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
|
||||||
|
self.linear = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="linear")
|
||||||
|
|
||||||
|
def construct(self, grad):
|
||||||
|
out = self.apply_ftrl(self.var, self.accum, self.linear, grad, self.lr, self.l1, self.l2, self.lr_power)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
test_case_math_ops = [
|
test_case_math_ops = [
|
||||||
('Neg', {
|
('Neg', {
|
||||||
|
@ -602,6 +618,14 @@ test_case_nn_ops = [
|
||||||
'block': G.ReluGrad(),
|
'block': G.ReluGrad(),
|
||||||
'desc_inputs': [[1, 3, 4, 4], [1, 3, 4, 4]],
|
'desc_inputs': [[1, 3, 4, 4], [1, 3, 4, 4]],
|
||||||
'skip': ['backward']}),
|
'skip': ['backward']}),
|
||||||
|
('Softplus', {
|
||||||
|
'block': P.Softplus(),
|
||||||
|
'desc_inputs': [[1, 3, 4, 4]],
|
||||||
|
'desc_bprop': [[1, 3, 4, 4]]}),
|
||||||
|
('SoftplusGrad', {
|
||||||
|
'block': G.SoftplusGrad(),
|
||||||
|
'desc_inputs': [[1, 3, 4, 4], [1, 3, 4, 4]],
|
||||||
|
'skip': ['backward']}),
|
||||||
('Elu', {
|
('Elu', {
|
||||||
'block': P.Elu(),
|
'block': P.Elu(),
|
||||||
'desc_inputs': [[2, 3, 4]],
|
'desc_inputs': [[2, 3, 4]],
|
||||||
|
@ -869,9 +893,8 @@ test_case_nn_ops = [
|
||||||
'desc_inputs': [[3, 2]],
|
'desc_inputs': [[3, 2]],
|
||||||
'desc_bprop': [[3, 2]]}),
|
'desc_bprop': [[3, 2]]}),
|
||||||
('ApplyFtrl', {
|
('ApplyFtrl', {
|
||||||
'block': P.ApplyFtrl(),
|
'block': ApplyFtrlNet(),
|
||||||
'desc_const': [0.001, 0.0, 0.0, -0.5],
|
'desc_inputs': [[3, 3]],
|
||||||
'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]],
|
|
||||||
'desc_bprop': [3, 3],
|
'desc_bprop': [3, 3],
|
||||||
'skip': ['backward']}),
|
'skip': ['backward']}),
|
||||||
('ApplyRMSProp', {
|
('ApplyRMSProp', {
|
||||||
|
|
Loading…
Reference in New Issue