From 452dbd45515f4a05089a9f7aad2796d7ef52ba28 Mon Sep 17 00:00:00 2001
From: wild-fox
Date: Fri, 6 Aug 2021 20:27:09 +0800
Subject: [PATCH] [feat][assistant][I40FFU] add new Ascend operator
 SparseApplyAdadelta

---
 mindspore/core/base/core_ops.h                |   2 +
 mindspore/core/ops/sparse_apply_adadelta.cc   | 122 ++++++++++++++++++
 mindspore/core/ops/sparse_apply_adadelta.h    |  47 +++++++
 .../mindspore/ops/_op_impl/tbe/__init__.py    |   1 +
 .../ops/_op_impl/tbe/sparse_apply_adadelta.py |  56 ++++++++
 .../mindspore/ops/operations/__init__.py      |   3 +-
 .../python/mindspore/ops/operations/nn_ops.py | 109 ++++++++++++++++
 tests/ut/python/ops/test_ops.py               |  28 ++++
 8 files changed, 367 insertions(+), 1 deletion(-)
 create mode 100644 mindspore/core/ops/sparse_apply_adadelta.cc
 create mode 100644 mindspore/core/ops/sparse_apply_adadelta.h
 create mode 100644 mindspore/python/mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py

diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index 512b6d5e685..f7d26656e42 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -110,6 +110,7 @@ constexpr auto kDropoutDoMask = "DropoutDoMask";
 constexpr auto kDropout = "Dropout";
 constexpr auto kDropoutGrad = "DropoutGrad";
 constexpr auto kConv2DTranspose = "Conv2DTranspose";
+constexpr auto kSparseApplyAdadelta = "SparseApplyAdadelta";
 constexpr auto kRoll = "Roll";
 constexpr auto kTanh = "Tanh";
 
@@ -425,6 +426,7 @@ inline const PrimitivePtr kPrimCustomExtractFeatures = std::make_shared<Primiti
 inline const PrimitivePtr kLambApplyOptimizerAssign = std::make_shared<Primitive>("LambApplyOptimizerAssign");
 inline const PrimitivePtr kLambApplyWeightAssign = std::make_shared<Primitive>("LambApplyWeightAssign");
 inline const PrimitivePtr kSoftmaxGradExt = std::make_shared<Primitive>("SoftmaxGradExt");
+inline const PrimitivePtr kPrimSparseApplyAdadelta = std::make_shared<Primitive>(kSparseApplyAdadelta);
 inline const PrimitivePtr kSquareSumV1 = std::make_shared<Primitive>("SquareSumV1");
 inline const PrimitivePtr kFusedMulAdd = std::make_shared<Primitive>("FusedMulAdd");
 inline const PrimitivePtr kPrimSoftShrink = std::make_shared<Primitive>("SoftShrink");
diff --git a/mindspore/core/ops/sparse_apply_adadelta.cc b/mindspore/core/ops/sparse_apply_adadelta.cc
new file mode 100644
index 00000000000..3f4d5b985c3
--- /dev/null
+++ b/mindspore/core/ops/sparse_apply_adadelta.cc
@@ -0,0 +1,122 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ops/sparse_apply_adadelta.h"
+
+#include <map>
+#include <set>
+
+#include "ops/op_utils.h"
+#include "utils/check_convert_utils.h"
+#include "utils/tensor_construct_utils.h"
+#include "abstract/primitive_infer_map.h"
+
+namespace mindspore {
+namespace ops {
+namespace {
+abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  auto prim_name = primitive->name();
+  for (const auto &item : input_args) {
+    MS_EXCEPTION_IF_NULL(item);
+  }
+  // Indices and grad must be tensors
+  CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex5);
+  CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex6);
+  // Get input shapes
+  auto var_shape_ptr = input_args[0]->BuildShape();
+  auto accum_shape_ptr = input_args[1]->BuildShape();
+  auto accum_updata_shape_ptr = input_args[2]->BuildShape();
+  auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(var_shape_ptr)[kShape];
+  auto accum_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(accum_shape_ptr)[kShape];
+  auto accum_updata_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(accum_updata_shape_ptr)[kShape];
+  auto lr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape];
+  auto rho_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->BuildShape())[kShape];
+  auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[5]->BuildShape())[kShape];
+  auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[6]->BuildShape())[kShape];
+  // Args lr and rho must be scalars
+  (void)CheckAndConvertUtils::CheckInteger("lr_shape size", lr_shape.size(), kEqual, 0, prim_name);
+  (void)CheckAndConvertUtils::CheckInteger("rho_shape size", rho_shape.size(), kEqual, 0, prim_name);
+  // Args var, accum, accum_updata and grad must have the same shape
+  std::map<std::string, ShapeVector> same_shape_args_map;
+  same_shape_args_map.insert({"accum shape", accum_shape});
+  same_shape_args_map.insert({"accum_updata shape", accum_updata_shape});
+  same_shape_args_map.insert({"grad shape", grad_shape});
+  for (auto &elem : same_shape_args_map) {
+    CheckAndConvertUtils::Check(elem.first, elem.second, kEqual, "var shape", var_shape, prim_name);
+  }
+  // Indices must be rank 1
+  (void)CheckAndConvertUtils::CheckInteger("indices dimension", indices_shape.size(), kEqual, 1, prim_name);
+  // Grad dimension must be equal to or greater than 1
+  (void)CheckAndConvertUtils::CheckInteger("grad dimension", grad_shape.size(), kGreaterEqual, 1, prim_name);
+  // Indices size must be equal to the first dimension of grad
+  if (indices_shape[0] != grad_shape[0]) {
+    MS_EXCEPTION(ValueError) << "For " << prim_name << ", the indices size must be equal to grad's first dimension "
+                             << "size " << grad_shape[0] << ", but got " << indices_shape[0];
+  }
+  return std::make_shared<abstract::TupleShape>(
+    std::vector<abstract::BaseShapePtr>{var_shape_ptr, accum_shape_ptr, accum_updata_shape_ptr});
+}
+
+TuplePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
+  auto prim_name = prim->name();
+  for (const auto &item : input_args) {
+    MS_EXCEPTION_IF_NULL(item);
+  }
+  // Get all input types
+  auto var_type = input_args[0]->BuildType();
+  auto accum_type = input_args[1]->BuildType();
+  auto accum_updata_type = input_args[2]->BuildType();
+  auto lr_type = input_args[3]->BuildType();
+  auto rho_type = input_args[4]->BuildType();
+  auto grad_type = input_args[5]->BuildType();
+  auto indices_type = input_args[6]->BuildType();
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
+  // Args accum, accum_updata and grad must have the same type as var
+  std::map<std::string, TypePtr> args;
+  args.insert({"var", var_type});
+  args.insert({"accum", accum_type});
+  args.insert({"accum_updata", accum_updata_type});
+  args.insert({"grad", grad_type});
+  (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
+  // Args lr and rho must be scalar types
+  std::map<std::string, TypePtr> args2;
+  args2.insert({"lr", lr_type});
+  (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args2, valid_types, prim_name);
+  std::map<std::string, TypePtr> args3;
+  args3.insert({"rho", rho_type});
+  (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args3, valid_types, prim_name);
+  // Check indices_type
+  std::map<std::string, TypePtr> args4;
+  args4.insert({"indices", indices_type});
+  const std::set<TypePtr> valid_types2 = {kInt32, kInt64};
+  (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args4, valid_types2, prim_name);
+  return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, accum_type, accum_updata_type});
+}
+}  // namespace
+
+AbstractBasePtr SparseApplyAdadeltaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                         const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  const int64_t input_num = 7;
+  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
+  auto infer_type = InferType(primitive, input_args);
+  auto infer_shape = InferShape(primitive, input_args);
+  return abstract::MakeAbstract(infer_shape, infer_type);
+}
+REGISTER_PRIMITIVE_EVAL_IMPL(SparseApplyAdadelta, prim::kPrimSparseApplyAdadelta, SparseApplyAdadeltaInfer, nullptr,
+                             true);
+}  // namespace ops
+}  // namespace mindspore
diff --git a/mindspore/core/ops/sparse_apply_adadelta.h b/mindspore/core/ops/sparse_apply_adadelta.h
new file mode 100644
index 00000000000..9acd62165d8
--- /dev/null
+++ b/mindspore/core/ops/sparse_apply_adadelta.h
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CORE_OPS_SPARSE_APPLY_ADADELTA_H_
+#define MINDSPORE_CORE_OPS_SPARSE_APPLY_ADADELTA_H_
+
+#include <map>
+#include <vector>
+#include <string>
+#include <memory>
+
+#include "ops/primitive_c.h"
+#include "abstract/abstract_value.h"
+#include "utils/check_convert_utils.h"
+
+namespace mindspore {
+namespace ops {
+constexpr auto kNameSparseApplyAdadelta = "SparseApplyAdadelta";
+class SparseApplyAdadelta : public PrimitiveC {
+ public:
+  SparseApplyAdadelta() : PrimitiveC(kNameSparseApplyAdadelta) {
+    InitIOName({"var", "accum", "accum_updata", "lr", "rho", "grad", "indices"}, {"var", "accum", "accum_updata"});
+  }
+  ~SparseApplyAdadelta() = default;
+  MS_DECLARE_PARENT(SparseApplyAdadelta, PrimitiveC);
+};
+
+AbstractBasePtr SparseApplyAdadeltaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                         const std::vector<AbstractBasePtr> &input_args);
+
+using PrimSparseApplyAdadeltaPtr = std::shared_ptr<SparseApplyAdadelta>;
+}  // namespace ops
+}  // namespace mindspore
+#endif  // MINDSPORE_CORE_OPS_SPARSE_APPLY_ADADELTA_H_
diff --git a/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py
index da4308eb069..3b9a8cc30cf 100644
--- a/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py
@@ -124,6 +124,7 @@ from .softmax_cross_entropy_with_logits_ds import _softmax_cross_entropy_with_lo
 from .sigmoid_cross_entropy_with_logits import _sigmoid_cross_entropy_with_logits_tbe
 from .sigmoid_cross_entropy_with_logits_ds import _sigmoid_cross_entropy_with_logits_ds_tbe
 from .sigmoid_cross_entropy_with_logits_grad import _sigmoid_cross_entropy_with_logits_grad_tbe
+from .sparse_apply_adadelta import _sparse_apply_adadelta_tbe
 from .tensor_add import _tensor_add_tbe
 from .tensor_add_ds import _tensor_add_ds_tbe
 from .trans_data import _trans_data_tbe
diff --git a/mindspore/python/mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py b/mindspore/python/mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py
new file mode 100644
index 00000000000..e196947d423
--- /dev/null
+++ b/mindspore/python/mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py
@@ -0,0 +1,56 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ + +"""SparseApplyAdadeltaD op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +sparse_apply_adadelta_d_op_info = TBERegOp("SparseApplyAdadelta") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("sparse_apply_adadelta_d.so") \ + .compute_cost(10) \ + .kernel_name("sparse_apply_adadelta_d") \ + .partial_flag(True) \ + .attr("epsilon", "optional", "float", "all") \ + .attr("use_locking", "optional", "bool", "true,false", "false") \ + .input(0, "var", False, "required", "all") \ + .input(1, "accum", False, "required", "all") \ + .input(2, "accum_update", False, "required", "all") \ + .input(3, "lr", False, "required", "all") \ + .input(4, "rho", False, "required", "all") \ + .input(6, "grad", False, "required", "all") \ + .input(7, "indices", False, "required", "all") \ + .output(0, "var", False, "required", "all") \ + .output(1, "accum", False, "required", "all") \ + .output(2, "accum_update", False, "required", "all") \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_5HD, + DataType.F32_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_5HD, + DataType.F32_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(sparse_apply_adadelta_d_op_info) +def _sparse_apply_adadelta_tbe(): + """SparseApplyAdadeltaD TBE register""" + return diff --git a/mindspore/python/mindspore/ops/operations/__init__.py b/mindspore/python/mindspore/ops/operations/__init__.py index 30b7c86474f..63e63c90da1 100644 --- a/mindspore/python/mindspore/ops/operations/__init__.py +++ b/mindspore/python/mindspore/ops/operations/__init__.py @@ -83,7 +83,7 @@ from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSpa SparseSoftmaxCrossEntropyWithLogits, Tanh, TopK, BinaryCrossEntropy, KLDivLoss, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl, ApplyProximalAdagrad, SparseApplyProximalAdagrad, SparseApplyAdagradV2, SparseApplyFtrlV2, - FusedSparseFtrl, FusedSparseProximalAdagrad, SparseApplyRMSProp, + FusedSparseFtrl, FusedSparseProximalAdagrad, SparseApplyRMSProp, SparseApplyAdadelta, ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2, ApplyAdagradDA, ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent, ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, AdaptiveAvgPool2D, SoftShrink) @@ -464,6 +464,7 @@ __all__ = [ "ReLUV2", "SparseToDense", "SparseTensorDenseMatmul", + "SparseApplyAdadelta", "MatrixInverse", "Range", "SearchSorted", diff --git a/mindspore/python/mindspore/ops/operations/nn_ops.py b/mindspore/python/mindspore/ops/operations/nn_ops.py index 666e5f13de1..8454eb08357 100644 --- a/mindspore/python/mindspore/ops/operations/nn_ops.py +++ 
b/mindspore/python/mindspore/ops/operations/nn_ops.py
@@ -8215,6 +8215,115 @@ def _deconv_output_length(input_length, kernel_size, stride_size, dilation_size)
     return length
 
+
+class SparseApplyAdadelta(Primitive):
+    r"""
+    Updates relevant entries according to the adadelta scheme.
+
+    .. math::
+        \begin{array}{ll} \\
+            accum = \rho * accum + (1 - \rho) * grad^2 \\
+            \text{update} = \sqrt{\text{accum_update} + \epsilon} * \frac{grad}{\sqrt{accum + \epsilon}} \\
+            var = var - update * lr \\
+            \text{accum_update} = \rho * \text{accum_update} + (1 - \rho) * update^2 \\
+        \end{array}
+
+    Inputs of `var`, `accum`, `accum_update` and `grad` comply with the implicit type conversion rules
+    to make the data types consistent. Besides, inputs of `lr` and `rho` also support implicit type conversion.
+    If they have different data types, the lower priority data type will be converted to the
+    relatively highest priority data type.
+    A RuntimeError exception will be thrown when the data type conversion of Parameter is required.
+
+    Note:
+        If there are negative values or values greater than or equal to var.shape[0] in `indices`,
+        the behavior is undefined. Besides, this operator doesn't support duplicates in `indices`.
+
+    Args:
+        epsilon (float): A small value added for numerical stability. Its value must be greater than or equal to 0.
+        use_locking (bool): If `True`, the `var` and `accum` tensors will be protected from being updated.
+            Default: False.
+
+    Inputs:
+        - **var** (Parameter) - Weights to be updated. With float32 or float16 data type.
+        - **accum** (Parameter) - Accumulation to be updated. Must have the same shape and dtype as `var`.
+          With float32 or float16 data type.
+        - **accum_update** (Parameter) - Accum_update to be updated. Must have the same shape and dtype as `var`.
+          With float32 or float16 data type.
+        - **lr** (Union[float, Tensor]) - Learning rate, must be a scalar. With float32 or float16 data type.
+        - **rho** (Union[float, Tensor]) - Decay rate, must be a scalar. With float32 or float16 data type.
+        - **grad** (Tensor) - A tensor for gradient. Must have the same shape and dtype as `var`.
+        - **indices** (Tensor) - A tensor of indices in the first dimension of `var` and `accum`.
+          Must be one of the following types: int32, int64, and indices.shape[0] must be equal to grad.shape[0].
+
+    Outputs:
+        Tuple of 3 Tensor, the updated parameters.
+
+        - **var** (Tensor) - The same shape and data type as `var`.
+        - **accum** (Tensor) - The same shape and data type as `accum`.
+        - **accum_update** (Tensor) - The same shape and data type as `accum_update`.
+
+    Raises:
+        TypeError: If `epsilon` is not a float.
+        TypeError: If `use_locking` is not a bool.
+        TypeError: If `var`, `accum` or `accum_update` is not a Parameter.
+        TypeError: If dtype of `accum`, `accum_update` or `grad` is not the same as `var`.
+        TypeError: If dtype of `var`, `accum`, `accum_update`, `lr`, `rho` or `grad` is neither float16 nor
+            float32.
+        TypeError: If dtype of `indices` is neither int32 nor int64.
+        ValueError: If `epsilon` is less than 0.
+        ValueError: If the shape of `accum`, `accum_update` or `grad` is not the same as `var`.
+        ValueError: If the rank of `indices` is not equal to 1.
+        ValueError: If the shape of `indices` is not the same as the shape of the first dimension of `grad`.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> class Net(nn.Cell):
+        ...     def __init__(self, epsilon, use_locking=False):
+        ...         super(Net, self).__init__()
+        ...         self.sparse_apply_adadelta = P.SparseApplyAdadelta(epsilon, use_locking)
+        ...         self.var = Parameter(Tensor(np.array([[1.0, 2.0], [2.0, 3.0]]).astype(np.float32)), name="var")
+        ...         self.accum = Parameter(Tensor(np.array([[1.5, 2.5], [3.5, 4.5]]).astype(np.float32)), name="accum")
+        ...         self.accum_update = Parameter(Tensor(np.array([[1.2, 2.4], [1.8, 0.6]]).astype(np.float32)),
+        ...                                       name="accum_update")
+        ...     def construct(self, lr, rho, grad, indices):
+        ...         out = self.sparse_apply_adadelta(self.var, self.accum, self.accum_update, lr, rho, grad, indices)
+        ...         return out
+        ...
+        >>> epsilon = 1e-6
+        >>> net = Net(epsilon)
+        >>> lr = 0.01
+        >>> rho = 0.2
+        >>> grad = Tensor(np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32))
+        >>> output = net(lr, rho, grad, Tensor(np.array([0, 1], dtype=np.int32)))
+        >>> print(output)
+        (Tensor(shape=[2, 2], dtype=Float32, value=
+        [[ 9.94611859e-01,  1.98851788e+00],
+         [ 1.99840558e+00,  2.99478507e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
+        [[ 3.72000009e-01,  8.91999960e-01],
+         [ 7.08000004e-01,  1.41200006e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
+        [[ 4.72257614e-01,  1.53470778e+00],
+         [ 3.80338937e-01,  3.37563992e-01]]))
+    """
+
+    __mindspore_signature__ = (
+        sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
+        sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
+        sig.make_sig('accum_updata', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
+        sig.make_sig('lr', dtype=sig.sig_dtype.T1),
+        sig.make_sig('rho', dtype=sig.sig_dtype.T1),
+        sig.make_sig('grad', dtype=sig.sig_dtype.T),
+        sig.make_sig('indices', dtype=sig.sig_dtype.T2),
+    )
+
+    @prim_attr_register
+    def __init__(self, epsilon, use_locking=False):
+        """Initialize SparseApplyAdadelta"""
+        validator.check_value_type("epsilon", epsilon, [float], self.name)
+        validator.check_number("epsilon", epsilon, 0.0, Rel.GE, self.name)
+        validator.check_value_type("use_locking", use_locking, [bool], self.name)
+
+
 class CTCLossV2(Primitive):
     """
     Calculates the CTC (Connectionist Temporal Classification) loss and the gradient.
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index 720ee2607a3..d566d788577 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -53,6 +53,7 @@ class TargetNet(nn.Cell):
     def construct(self, x, y):
        return self.mul(x, y)
 
+
 # Recursive GradOperation in Cell.
 class Grad(nn.Cell):
     def __init__(self, network):
@@ -63,12 +64,16 @@ class Grad(nn.Cell):
     def construct(self, x, y):
         return self.grad(self.network)(x, y)
 
+
 # Recursive GradOperaton with GradOperation object.
grad1 = C.GradOperation() + + @ms_function def f1(x, y): return grad1(grad1(TargetNet()))(x, y) + def test_recursive_grad(): x = Tensor(3, mstype.float32) y = Tensor(1, mstype.float32) @@ -938,6 +943,7 @@ class StridedSliceNet(nn.Cell): out_3 = self.strided_slice_3(x, self.begins, self.ends, self.strides) + self.const_3 return out_0, out_1, out_2, out_3 + @pytest.mark.skip(reason='0 in shape is not support') def test_strided_slice_const(): class StridedSLiceConstNet(nn.Cell): @@ -1009,6 +1015,21 @@ class EditDistance(nn.Cell): truth_indices, truth_values, self.truth_shape) +class SparseApplyAdadeltaNet(nn.Cell): + def __init__(self, epsilon, use_locking=True): + super(SparseApplyAdadeltaNet, self).__init__() + self.sparse_apply_adadelta = P.SparseApplyAdadelta(epsilon, use_locking) + self.lr = 0.001 + self.rho = 0.0 + self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var") + self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum") + self.accum_update = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum_update") + + def construct(self, grad, indices): + out = self.sparse_apply_adadelta(self.var, self.accum, self.accum_update, self.lr, self.rho, grad, indices) + return out + + class ApplyAdagradDANet(nn.Cell): def __init__(self, use_locking=False): super(ApplyAdagradDANet, self).__init__() @@ -1018,6 +1039,7 @@ class ApplyAdagradDANet(nn.Cell): name="gradient_accumulator") self.gradient_squared_accumulator = Parameter(Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32)), name="gradient_squared_accumulator") + def construct(self, grad, lr, l1, l2, global_step): out = self.apply_adagrad_d_a(self.var, self.gradient_accumulator, self.gradient_squared_accumulator, grad, lr, l1, l2, global_step) @@ -1048,6 +1070,7 @@ class ApplyKerasMomentumNet(nn.Cell): out = self.apply_keras_momentum(self.var, self.accum, lr, grad, momentum) return out + test_case_math_ops = [ ('Ger', { 'block': P.Ger(), @@ -2322,6 +2345,11 @@ test_case_nn_ops = [ 'desc_inputs': [Tensor(np.array([[0.5, 1, 2.0], [0.0533, 0.0776, -2.1233]]), mstype.float32)], 'desc_bprop': [], 'skip': ['backward']}), + ('SparseApplyAdadelta', { + 'block': SparseApplyAdadeltaNet(1e-6), + 'desc_inputs': [Tensor(np.random.rand(3, 3), mstype.float32), + Tensor(np.ones((3,)), mstype.int32)], + 'skip': ['backward']}), ('HShrinkGrad', { 'block': G.HShrinkGrad(), 'desc_inputs': [Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]), mstype.float16),