diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
index 582e40e2df5..e11b3f535d8 100644
--- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
+++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
@@ -33,6 +33,7 @@ static std::map tbe_func_adapter_map = {
   {"softmax", "softmax_v2"},
   {"log_softmax", "log_softmax_v2"},
   {"apply_momentum", "apply_momentum_d"},
+  {"apply_ftrl", "apply_ftrl_d"},
   {"re_lu6", "relu6"},
   {"re_lu6_grad", "relu6_grad"},
   {"re_lu", "relu"},
diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc
index 2c0015ae283..91879fe689f 100644
--- a/mindspore/ccsrc/transform/convert.cc
+++ b/mindspore/ccsrc/transform/convert.cc
@@ -384,7 +384,7 @@ std::unordered_map &DfGraphConvertor::get_adpt_ma
     {string(kNameDepthToSpace), ADPT_DESC(DepthToSpace)},
     {string(kNameSign), ADPT_DESC(Sign)},
     {string(kNameRound), ADPT_DESC(Round)},
-    {string(kNameApplyFtrl), ADPT_DESC(ApplyFtrl)},
+    {string(kNameApplyFtrl), ADPT_DESC(ApplyFtrlD)},
     {string(kNameDiag), ADPT_DESC(Diag)},
     {string(kNameDiagPart), ADPT_DESC(DiagPart)},
     {string(kNameSpaceToBatch), ADPT_DESC(SpaceToBatchD)},
diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc
index 3a43277a570..7af932af0e5 100644
--- a/mindspore/ccsrc/transform/op_declare.cc
+++ b/mindspore/ccsrc/transform/op_declare.cc
@@ -1176,11 +1176,11 @@ ATTR_MAP(Round) = EMPTY_ATTR_MAP;
 OUTPUT_MAP(Round) = {{0, OUTPUT_DESC(y)}};
 
 // ApplyFtrl
-INPUT_MAP(ApplyFtrl) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(linear)},
-                        {4, INPUT_DESC(grad)}, {5, INPUT_DESC(lr)}, {6, INPUT_DESC(l1)},
-                        {7, INPUT_DESC(l2)}, {8, INPUT_DESC(lr_power)}};
-ATTR_MAP(ApplyFtrl) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
-OUTPUT_MAP(ApplyFtrl) = {{0, OUTPUT_DESC(var)}};
+INPUT_MAP(ApplyFtrlD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(linear)},
+                         {4, INPUT_DESC(grad)}, {5, INPUT_DESC(lr)}, {6, INPUT_DESC(l1)},
+                         {7, INPUT_DESC(l2)}, {8, INPUT_DESC(lr_power)}};
+ATTR_MAP(ApplyFtrlD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
+OUTPUT_MAP(ApplyFtrlD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}, {2, OUTPUT_DESC(linear)}};
 
 // Diag
 INPUT_MAP(Diag) = {{1, INPUT_DESC(x)}};
diff --git a/mindspore/ccsrc/transform/op_declare.h b/mindspore/ccsrc/transform/op_declare.h
index 1c92575b611..c31dcc85099 100755
--- a/mindspore/ccsrc/transform/op_declare.h
+++ b/mindspore/ccsrc/transform/op_declare.h
@@ -446,8 +446,8 @@ DECLARE_OP_ADAPTER(LarsV2Update)
 DECLARE_OP_USE_OUTPUT(LarsV2Update)
 DECLARE_OP_ADAPTER(Round)
 DECLARE_OP_USE_OUTPUT(Round)
-DECLARE_OP_ADAPTER(ApplyFtrl)
-DECLARE_OP_USE_OUTPUT(ApplyFtrl)
+DECLARE_OP_ADAPTER(ApplyFtrlD)
+DECLARE_OP_USE_OUTPUT(ApplyFtrlD)
 DECLARE_OP_ADAPTER(SparseApplyFtrlD)
 DECLARE_OP_USE_OUTPUT(SparseApplyFtrlD)
 DECLARE_OP_ADAPTER(Diag)
diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py
index d63a13df1c4..c79098faa37 100755
--- a/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/ops/_grad/grad_nn_ops.py
@@ -326,6 +326,18 @@ def get_bprop_log_softmax(self):
     return bprop
 
 
+@bprop_getters.register(P.Softplus)
+def get_bprop_softplus(self):
+    """Grad definition for `Softplus` operation."""
+    softplus_grad = G.SoftplusGrad()
+
+    def bprop(x, out, dout):
+        dx = softplus_grad(dout, x)
+        return (dx,)
+
+    return bprop
+
+
 @bprop_getters.register(P.Tanh)
 def get_bprop_tanh(self):
     """Grad definition for `Tanh` operation."""
diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index f5aed75e273..a70d172570a 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -100,6 +100,8 @@ from .round import _round_tbe
 from .tanh import _tanh_tbe
 from .tanh_grad import _tanh_grad_tbe
 from .softmax import _softmax_tbe
+from .softplus import _softplus_tbe
+from .softplus_grad import _softplus_grad_tbe
 from .square import _square_tbe
 from .sqrt import _sqrt_tbe
 from .transpose_d import _transpose_d_tbe
diff --git a/mindspore/ops/_op_impl/tbe/apply_ftrl.py b/mindspore/ops/_op_impl/tbe/apply_ftrl.py
index e37648191e3..56c6bf36121 100644
--- a/mindspore/ops/_op_impl/tbe/apply_ftrl.py
+++ b/mindspore/ops/_op_impl/tbe/apply_ftrl.py
@@ -32,30 +32,32 @@ apply_ftrl_op_info = TBERegOp("ApplyFtrl") \
     .input(6, "l2", False, "required", "all") \
     .input(7, "lr_power", False, "required", "all") \
     .output(0, "var", False, "required", "all") \
+    .output(1, "accum", False, "required", "all") \
+    .output(2, "linear", False, "required", "all") \
     .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
                   DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_5HD) \
+                  DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
     .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ,
                   DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_FracZ) \
+                  DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
     .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
                   DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_C1HWNCoC0) \
+                  DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
     .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                   DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_Default) \
+                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                   DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_5HD) \
+                  DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
     .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
                   DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_FracZ) \
+                  DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
     .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
                   DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_C1HWNCoC0) \
+                  DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                   DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default) \
+                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
     .get_op_info()
diff --git a/mindspore/ops/_op_impl/tbe/softplus.py b/mindspore/ops/_op_impl/tbe/softplus.py
new file mode 100644
index 00000000000..d362cd06db8
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/softplus.py
@@ -0,0 +1,39 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Softplus op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+softplus_op_info = TBERegOp("Softplus") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("softplus.so") \
+    .compute_cost(10) \
+    .kernel_name("softplus") \
+    .partial_flag(True) \
+    .op_pattern("formatAgnostic") \
+    .input(0, "x", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
+    .get_op_info()
+
+
+@op_info_register(softplus_op_info)
+def _softplus_tbe():
+    """Softplus TBE register"""
+    return
diff --git a/mindspore/ops/_op_impl/tbe/softplus_grad.py b/mindspore/ops/_op_impl/tbe/softplus_grad.py
new file mode 100644
index 00000000000..4bf7a82440d
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/softplus_grad.py
@@ -0,0 +1,40 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""SoftplusGrad op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+softplus_grad_op_info = TBERegOp("SoftplusGrad") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("softplus_grad.so") \
+    .compute_cost(10) \
+    .kernel_name("softplus_grad") \
+    .partial_flag(True) \
+    .op_pattern("broadcast") \
+    .input(0, "gradients", False, "required", "all") \
+    .input(1, "features", False, "required", "all") \
+    .output(0, "backprops", False, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
+    .get_op_info()
+
+
+@op_info_register(softplus_grad_op_info)
+def _softplus_grad_tbe():
+    """SoftplusGrad TBE register"""
+    return
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index a3f36215291..d00370f490b 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -62,7 +62,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
                      MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2,
                      HSwish, HSigmoid,
                      ResizeBilinear, Sigmoid, SigmoidCrossEntropyWithLogits,
-                     SmoothL1Loss, Softmax,
+                     SmoothL1Loss, Softmax, Softplus,
                      SoftmaxCrossEntropyWithLogits, ROIAlign,
                      SparseSoftmaxCrossEntropyWithLogits, Tanh,
                      TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl,
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index 76d464de160..6f16c152bc6 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -974,6 +974,23 @@ class StridedSliceGrad(PrimitiveWithInfer):
                 'value': None}
 
 
+class SoftplusGrad(PrimitiveWithInfer):
+    """Computes gradient for the Softplus activation."""
+
+    @prim_attr_register
+    def __init__(self):
+        self.init_prim_io_names(inputs=['dout', 'x'], outputs=['output'])
+
+    def infer_shape(self, dout_shape, x_shape):
+        validator.check("x_shape", x_shape, "dout_shape", dout_shape, Rel.EQ, self.name)
+        return x_shape
+
+    def infer_dtype(self, dout_dtype, x_dtype):
+        args = {"x_dtype": x_dtype, "dout_dtype": dout_dtype}
+        validator.check_tensor_type_same(args, mstype.float_type, self.name)
+        return x_dtype
+
+
 class TanhGrad(PrimitiveWithInfer):
     """Computes gradient of hyperbolic tangent of input element-wise."""
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 04b9f49c7c8..efb77ad3e32 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -183,6 +183,41 @@ class LogSoftmax(PrimitiveWithInfer):
         return logits
 
 
+class Softplus(PrimitiveWithInfer):
+    r"""
+    Softplus activation function.
+
+    Softplus is a smooth approximation to the ReLU function.
+    The function is shown as follows:
+
+    .. math::
+        \text{output} = \log(1 + \exp(\text{input_x})),
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor whose data type should be float.
+
+    Outputs:
+        Tensor, with the same type and shape as the `input_x`.
+
+    Examples:
+        >>> input_x = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32)
+        >>> softplus = P.Softplus()
+        >>> softplus(input_x)
+        [1.3132615, 2.126928, 3.0485873, 4.01815, 5.0067153]
+    """
+    @prim_attr_register
+    def __init__(self):
+        """init Softplus"""
+        self.init_prim_io_names(inputs=['x'], outputs=['output'])
+
+    def infer_shape(self, input_x):
+        return input_x
+
+    def infer_dtype(self, input_x):
+        validator.check_tensor_type_same({'input_x': input_x}, mstype.float_type, self.name)
+        return input_x
+
+
 class ReLU(PrimitiveWithInfer):
     r"""
     Computes ReLU(Rectified Linear Unit) of input tensor element-wise.
@@ -2701,11 +2736,14 @@ class ApplyFtrl(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'lr', 'l1', 'l2', 'lr_power'],
                                 outputs=['output'])
         self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
+        self.is_tbe = context.get_context("device_target") == "Ascend"
 
     def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, lr_shape, l1_shape, l2_shape,
                     lr_power_shape):
         validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
         validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
+        if self.is_tbe:
+            return var_shape, var_shape, var_shape
         return var_shape
 
     def infer_dtype(self, var_type, accum_type, linear_type, grad_type, lr_type, l1_type, l2_type, lr_power_type):
@@ -2717,6 +2755,8 @@ class ApplyFtrl(PrimitiveWithInfer):
         validator.check_scalar_or_tensor_type_same({"l1": l1_type}, valid_types, self.name)
         validator.check_scalar_or_tensor_type_same({"l2": l2_type}, valid_types, self.name)
         validator.check_scalar_or_tensor_type_same({"lr_power": lr_power_type}, valid_types, self.name)
+        if self.is_tbe:
+            return var_type, var_type, var_type
         return var_type
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index c28786359a3..938e943914d 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -185,6 +185,22 @@ class ScatterMax(nn.Cell):
         out = self.scatter_max(self.ref, indices, updates)
         return out
 
+class ApplyFtrlNet(nn.Cell):
+    def __init__(self):
+        super(ApplyFtrlNet, self).__init__()
+        self.apply_ftrl = P.ApplyFtrl()
+        self.lr = 0.001
+        self.l1 = 0.0
+        self.l2 = 0.0
+        self.lr_power = -0.5
+        self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
+        self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
+        self.linear = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="linear")
+
+    def construct(self, grad):
+        out = self.apply_ftrl(self.var, self.accum, self.linear, grad, self.lr, self.l1, self.l2, self.lr_power)
+        return out
+
 
 test_case_math_ops = [
     ('Neg', {
@@ -602,6 +618,14 @@ test_case_nn_ops = [
         'block': G.ReluGrad(),
         'desc_inputs': [[1, 3, 4, 4], [1, 3, 4, 4]],
         'skip': ['backward']}),
+    ('Softplus', {
+        'block': P.Softplus(),
+        'desc_inputs': [[1, 3, 4, 4]],
+        'desc_bprop': [[1, 3, 4, 4]]}),
+    ('SoftplusGrad', {
+        'block': G.SoftplusGrad(),
+        'desc_inputs': [[1, 3, 4, 4], [1, 3, 4, 4]],
+        'skip': ['backward']}),
     ('Elu', {
         'block': P.Elu(),
         'desc_inputs': [[2, 3, 4]],
@@ -869,9 +893,8 @@ test_case_nn_ops = [
         'desc_inputs': [[3, 2]],
         'desc_bprop': [[3, 2]]}),
     ('ApplyFtrl', {
-        'block': P.ApplyFtrl(),
-        'desc_const': [0.001, 0.0, 0.0, -0.5],
-        'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]],
+        'block': ApplyFtrlNet(),
+        'desc_inputs': [[3, 3]],
         'desc_bprop': [3, 3],
         'skip': ['backward']}),
     ('ApplyRMSProp', {
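
Reviewer note (not part of the patch): the standalone sketch below shows one way to exercise the new Softplus forward primitive together with the SoftplusGrad bprop registered in grad_nn_ops.py. It assumes an Ascend build (the kernels added here are TBE kernels), uses the GradOperation call style of this era of the API, and the tensor values are purely illustrative.

# Reviewer sketch only -- not shipped with this patch. Assumes an Ascend
# target, since the Softplus forward/backward kernels added here are TBE
# kernels registered under mindspore/ops/_op_impl/tbe.
import numpy as np

import mindspore
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P


class SoftplusNet(nn.Cell):
    """Forward pass through the new Softplus primitive."""
    def __init__(self):
        super(SoftplusNet, self).__init__()
        self.softplus = P.Softplus()

    def construct(self, x):
        return self.softplus(x)


class GradNet(nn.Cell):
    """Backward pass; exercises the SoftplusGrad bprop registered above."""
    def __init__(self, net):
        super(GradNet, self).__init__()
        self.net = net
        # GradOperation took a name argument in this era of the API.
        self.grad = C.GradOperation('grad')

    def construct(self, x):
        return self.grad(self.net)(x)


x = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
net = SoftplusNet()
print(net(x))           # ~[1.3133, 2.1269, 3.0486] = log(1 + exp(x))
print(GradNet(net)(x))  # ~[0.7311, 0.8808, 0.9526] = sigmoid(x)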