!21339 [assistant][ops] New operator implementation, including SparseApplyAdadeltaD
Merge pull request !21339 from 梁秀波/SparseApplyAdadeltaD
commit f1b26ea205
@@ -110,6 +110,7 @@ constexpr auto kDropoutDoMask = "DropoutDoMask";
constexpr auto kDropout = "Dropout";
constexpr auto kDropoutGrad = "DropoutGrad";
constexpr auto kConv2DTranspose = "Conv2DTranspose";
constexpr auto kSparseApplyAdadelta = "SparseApplyAdadelta";
constexpr auto kRoll = "Roll";
constexpr auto kTanh = "Tanh";

@@ -425,6 +426,7 @@ inline const PrimitivePtr kPrimCustomExtractFeatures = std::make_shared<Primitiv
inline const PrimitivePtr kLambApplyOptimizerAssign = std::make_shared<Primitive>("LambApplyOptimizerAssign");
inline const PrimitivePtr kLambApplyWeightAssign = std::make_shared<Primitive>("LambApplyWeightAssign");
inline const PrimitivePtr kSoftmaxGradExt = std::make_shared<Primitive>("SoftmaxGradExt");
inline const PrimitivePtr kPrimSparseApplyAdadelta = std::make_shared<Primitive>(kSparseApplyAdadelta);
inline const PrimitivePtr kSquareSumV1 = std::make_shared<Primitive>("SquareSumV1");
inline const PrimitivePtr kFusedMulAdd = std::make_shared<Primitive>("FusedMulAdd");
inline const PrimitivePtr kPrimSoftShrink = std::make_shared<Primitive>("SoftShrink");

@@ -0,0 +1,122 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/sparse_apply_adadelta.h"

#include <algorithm>
#include <set>

#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/primitive_infer_map.h"

namespace mindspore {
namespace ops {
namespace {
abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = primitive->name();
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  // Indices and grad must be tensors
  CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex5);
  CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex6);
  // Get input shapes
  auto var_shape_ptr = input_args[0]->BuildShape();
  auto accum_shape_ptr = input_args[1]->BuildShape();
  auto accum_updata_shape_ptr = input_args[2]->BuildShape();
  auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(var_shape_ptr)[kShape];
  auto accum_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(accum_shape_ptr)[kShape];
  auto accum_updata_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(accum_updata_shape_ptr)[kShape];
  auto lr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape];
  auto rho_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->BuildShape())[kShape];
  auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[5]->BuildShape())[kShape];
  auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[6]->BuildShape())[kShape];
  // Args lr and rho must be scalars
  (void)CheckAndConvertUtils::CheckInteger("lr_shape size", lr_shape.size(), kEqual, 0, prim_name);
  (void)CheckAndConvertUtils::CheckInteger("rho_shape size", rho_shape.size(), kEqual, 0, prim_name);
  // Args var, accum, accum_update and grad must have the same shape
  std::map<std::string, ShapeVector> same_shape_args_map;
  same_shape_args_map.insert({"accum shape", accum_shape});
  same_shape_args_map.insert({"accum_updata shape", accum_updata_shape});
  same_shape_args_map.insert({"grad shape", grad_shape});
  for (auto &elem : same_shape_args_map) {
    CheckAndConvertUtils::Check(elem.first, elem.second, kEqual, "var shape", var_shape, prim_name);
  }
  // Indices must be rank 1
  (void)CheckAndConvertUtils::CheckInteger("indices dimension", indices_shape.size(), kEqual, 1, prim_name);
  // Grad dimension must be equal to or greater than 1
  (void)CheckAndConvertUtils::CheckInteger("grad dimension", grad_shape.size(), kGreaterEqual, 1, prim_name);
  // Indices size must equal the first dimension size of grad
  if (indices_shape[0] != grad_shape[0]) {
    MS_EXCEPTION(ValueError) << "For " << prim_name << " the indices size must be equal to the grad first dimension size "
                             << grad_shape[0] << ", but got " << indices_shape[0];
  }
  return std::make_shared<abstract::TupleShape>(
    std::vector<abstract::BaseShapePtr>{var_shape_ptr, accum_shape_ptr, accum_updata_shape_ptr});
}

TuplePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = prim->name();
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  // Get all input types
  auto var_type = input_args[0]->BuildType();
  auto accum_type = input_args[1]->BuildType();
  auto accum_updata_type = input_args[2]->BuildType();
  auto lr_type = input_args[3]->BuildType();
  auto rho_type = input_args[4]->BuildType();
  auto grad_type = input_args[5]->BuildType();
  auto indices_type = input_args[6]->BuildType();
  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
  // Args accum, accum_updata and grad must have the same type as var
  std::map<std::string, TypePtr> args;
  args.insert({"var", var_type});
  args.insert({"accum", accum_type});
  args.insert({"accum_updata", accum_updata_type});
  args.insert({"grad", grad_type});
  (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
  // Args lr and rho must be scalar types
  std::map<std::string, TypePtr> args2;
  args2.insert({"lr", lr_type});
  (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args2, valid_types, prim_name);
  std::map<std::string, TypePtr> args3;
  args3.insert({"rho", rho_type});
  (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args3, valid_types, prim_name);
  // Check indices type
  std::map<std::string, TypePtr> args4;
  args4.insert({"indices", indices_type});
  const std::set<TypePtr> valid_types2 = {kInt32, kInt64};
  (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args4, valid_types2, prim_name);
  return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, accum_type, accum_updata_type});
}
}  // namespace

AbstractBasePtr SparseApplyAdadeltaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                         const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t input_num = 7;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
  auto infer_type = InferType(primitive, input_args);
  auto infer_shape = InferShape(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(SparseApplyAdadelta, prim::kPrimSparseApplyAdadelta, SparseApplyAdadeltaInfer, nullptr,
                             true);
}  // namespace ops
}  // namespace mindspore

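Taken together, InferShape and InferType above enforce a small set of rules: var, accum, accum_updata and grad share one shape and a float16/float32 dtype, lr and rho are scalars, and indices is a rank-1 int32/int64 tensor whose length equals grad.shape[0]. The standalone Python sketch below only restates those rules for illustration; the helper name check_sparse_apply_adadelta_args and its plain NumPy inputs are hypothetical and not part of this change.

import numpy as np

def check_sparse_apply_adadelta_args(var, accum, accum_update, lr, rho, grad, indices):
    """Hypothetical restatement of the shape/dtype rules enforced by InferShape/InferType."""
    # var, accum, accum_update and grad must share var's shape and a float16/float32 dtype.
    for name, arr in {"accum": accum, "accum_update": accum_update, "grad": grad}.items():
        if arr.shape != var.shape:
            raise ValueError(f"{name} shape {arr.shape} must equal var shape {var.shape}")
        if arr.dtype != var.dtype or var.dtype not in (np.float16, np.float32):
            raise TypeError(f"{name} must share var's float16/float32 dtype")
    # lr and rho must be scalars.
    if np.ndim(lr) != 0 or np.ndim(rho) != 0:
        raise ValueError("lr and rho must be scalars")
    # indices must be a rank-1 int32/int64 tensor matching grad's first dimension.
    if indices.ndim != 1 or indices.dtype not in (np.int32, np.int64):
        raise TypeError("indices must be a 1-D int32/int64 tensor")
    if indices.shape[0] != grad.shape[0]:
        raise ValueError("indices size must equal the first dimension of grad")
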
@@ -0,0 +1,47 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_SPARSE_APPLY_ADADELTA_H_
#define MINDSPORE_CORE_OPS_SPARSE_APPLY_ADADELTA_H_

#include <map>
#include <memory>
#include <vector>
#include <string>

#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace ops {
constexpr auto kNameSparseApplyAdadelta = "SparseApplyAdadelta";
class SparseApplyAdadelta : public PrimitiveC {
 public:
  SparseApplyAdadelta() : PrimitiveC(kNameSparseApplyAdadelta) {
    InitIOName({"var", "accum", "accum_updata", "lr", "rho", "grad", "indices"}, {"var", "accum", "accum_updata"});
  }
  ~SparseApplyAdadelta() = default;
  MS_DECLARE_PARENT(SparseApplyAdadelta, PrimitiveC);
};

AbstractBasePtr SparseApplyAdadeltaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                         const std::vector<AbstractBasePtr> &input_args);

using PrimSparseApplyAdadeltaPtr = std::shared_ptr<SparseApplyAdadelta>;
}  // namespace ops
}  // namespace mindspore
#endif  // MINDSPORE_CORE_OPS_SPARSE_APPLY_ADADELTA_H_

@@ -124,6 +124,7 @@ from .softmax_cross_entropy_with_logits_ds import _softmax_cross_entropy_with_lo
from .sigmoid_cross_entropy_with_logits import _sigmoid_cross_entropy_with_logits_tbe
from .sigmoid_cross_entropy_with_logits_ds import _sigmoid_cross_entropy_with_logits_ds_tbe
from .sigmoid_cross_entropy_with_logits_grad import _sigmoid_cross_entropy_with_logits_grad_tbe
from .sparse_apply_adadelta import _sparse_apply_adadelta_tbe
from .tensor_add import _tensor_add_tbe
from .tensor_add_ds import _tensor_add_ds_tbe
from .trans_data import _trans_data_tbe

@@ -0,0 +1,56 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""SparseApplyAdadeltaD op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

sparse_apply_adadelta_d_op_info = TBERegOp("SparseApplyAdadelta") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("sparse_apply_adadelta_d.so") \
    .compute_cost(10) \
    .kernel_name("sparse_apply_adadelta_d") \
    .partial_flag(True) \
    .attr("epsilon", "optional", "float", "all") \
    .attr("use_locking", "optional", "bool", "true,false", "false") \
    .input(0, "var", False, "required", "all") \
    .input(1, "accum", False, "required", "all") \
    .input(2, "accum_update", False, "required", "all") \
    .input(3, "lr", False, "required", "all") \
    .input(4, "rho", False, "required", "all") \
    .input(6, "grad", False, "required", "all") \
    .input(7, "indices", False, "required", "all") \
    .output(0, "var", False, "required", "all") \
    .output(1, "accum", False, "required", "all") \
    .output(2, "accum_update", False, "required", "all") \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_5HD,
                  DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_5HD,
                  DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(sparse_apply_adadelta_d_op_info)
def _sparse_apply_adadelta_tbe():
    """SparseApplyAdadeltaD TBE register"""
    return

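The dtype_format rows are positional: the ten DataType arguments follow the declared order of the seven inputs (var, accum, accum_update, lr, rho, grad, indices) and then the three outputs (var, accum, accum_update). As a reading aid only, the first row registered above amounts to the following; this dict is illustrative and not a TBERegOp API.

# Illustrative positional reading of the first dtype_format row above.
first_dtype_format_row = {
    "var": "float32 / 5HD (NC1HWC0)",
    "accum": "float32 / 5HD (NC1HWC0)",
    "accum_update": "float32 / 5HD (NC1HWC0)",
    "lr": "float32 / default format",
    "rho": "float32 / default format",
    "grad": "float32 / default format",
    "indices": "int32 / default format",
    "var_out": "float32 / 5HD (NC1HWC0)",
    "accum_out": "float32 / 5HD (NC1HWC0)",
    "accum_update_out": "float32 / 5HD (NC1HWC0)",
}
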
@@ -83,7 +83,7 @@ from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSpa
                     SparseSoftmaxCrossEntropyWithLogits, Tanh,
                     TopK, BinaryCrossEntropy, KLDivLoss, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
                     ApplyProximalAdagrad, SparseApplyProximalAdagrad, SparseApplyAdagradV2, SparseApplyFtrlV2,
                     FusedSparseFtrl, FusedSparseProximalAdagrad, SparseApplyRMSProp,
                     FusedSparseFtrl, FusedSparseProximalAdagrad, SparseApplyRMSProp, SparseApplyAdadelta,
                     ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2, ApplyAdagradDA,
                     ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent,
                     ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, AdaptiveAvgPool2D, SoftShrink,

@@ -466,6 +466,7 @@ __all__ = [
    "ReLUV2",
    "SparseToDense",
    "SparseTensorDenseMatmul",
    "SparseApplyAdadelta",
    "MatrixInverse",
    "Range",
    "SearchSorted",

@@ -8215,6 +8215,115 @@ def _deconv_output_length(input_length, kernel_size, stride_size, dilation_size)
    return length


class SparseApplyAdadelta(Primitive):
    r"""
    Updates relevant entries according to the Adadelta scheme.

    .. math::
        \begin{array}{ll} \\
            accum = \rho * accum + (1 - \rho) * grad^2 \\
            \text{update} = \sqrt{\text{accum_update} + \epsilon} * \frac{grad}{\sqrt{accum + \epsilon}} \\
            var = var - update * lr \\
            \text{accum_update} = \rho * \text{accum_update} + (1 - \rho) * update^2 \\
        \end{array}

    Inputs of `var`, `accum`, `accum_update` and `grad` comply with the implicit type conversion rules
    to make the data types consistent. Besides, inputs of `lr` and `rho` also support implicit type conversion.
    If they have different data types, the lower-priority data type will be converted to the
    relatively highest-priority data type.
    A RuntimeError exception will be thrown when a data type conversion of Parameter is required.

    Note:
        If there are negative values or values greater than or equal to var.shape[0] in `indices`,
        the behavior is undefined. Besides, this operator doesn't support duplicates in `indices`.

    Args:
        epsilon (float): A small value added for numerical stability. Its value must be greater than or equal to 0.
        use_locking (bool): If `True`, the `var` and `accum` tensors will be protected from being updated.
            Default: False.

    Inputs:
        - **var** (Parameter) - Weights to be updated. With float32 or float16 data type.
        - **accum** (Parameter) - Accumulation to be updated. Must have the same shape and dtype as `var`.
          With float32 or float16 data type.
        - **accum_update** (Parameter) - Accum_update to be updated. Must have the same shape and dtype as `var`.
          With float32 or float16 data type.
        - **lr** (Union[float, Tensor]) - Learning rate, must be a scalar. With float32 or float16 data type.
        - **rho** (Union[float, Tensor]) - Decay rate, must be a scalar. With float32 or float16 data type.
        - **grad** (Tensor) - A tensor for gradient. Must have the same shape and dtype as `var`.
        - **indices** (Tensor) - A tensor of indices into the first dimension of `var` and `accum`.
          Must be one of the following types: int32, int64, and indices.shape[0] = grad.shape[0].

    Outputs:
        Tuple of 3 Tensors, the updated parameters.

        - **var** (Tensor) - The same shape and data type as `var`.
        - **accum** (Tensor) - The same shape and data type as `accum`.
        - **accum_update** (Tensor) - The same shape and data type as `accum_update`.

    Raises:
        TypeError: If `epsilon` is not a float.
        TypeError: If `use_locking` is not a bool.
        TypeError: If `var`, `accum` or `accum_update` is not a Parameter.
        TypeError: If the dtype of `accum`, `accum_update` or `grad` is not the same as that of `var`.
        TypeError: If the dtype of `var`, `accum`, `accum_update`, `lr`, `rho` or `grad` is neither float16 nor
            float32.
        TypeError: If the dtype of `indices` is neither int32 nor int64.
        ValueError: If `epsilon` is less than 0.
        ValueError: If the shape of `accum`, `accum_update` or `grad` is not the same as that of `var`.
        ValueError: If the rank of `indices` is not equal to 1.
        ValueError: If the shape of `indices` does not match the first dimension of `grad`.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> class Net(nn.Cell):
        ...     def __init__(self, epsilon, use_locking=False):
        ...         super(Net, self).__init__()
        ...         self.sparse_apply_adadelta = P.SparseApplyAdadelta(epsilon, use_locking)
        ...         self.var = Parameter(Tensor(np.array([[1.0, 2.0], [2.0, 3.0]]).astype(np.float32)), name="var")
        ...         self.accum = Parameter(Tensor(np.array([[1.5, 2.5], [3.5, 4.5]]).astype(np.float32)), name="accum")
        ...         self.accum_update = Parameter(Tensor(np.array([[1.2, 2.4], [1.8, 0.6]]).astype(np.float32)),
        ...                                       name="accum_update")
        ...     def construct(self, lr, rho, grad, indices):
        ...         out = self.sparse_apply_adadelta(self.var, self.accum, self.accum_update, lr, rho, grad, indices)
        ...         return out
        ...
        >>> epsilon = 1e-6
        >>> net = Net(epsilon)
        >>> lr = 0.01
        >>> rho = 0.2
        >>> grad = Tensor(np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32))
        >>> output = net(lr, rho, grad, Tensor(np.array([0, 1], dtype=np.int32)))
        >>> print(output)
        (Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 9.94611859e-01, 1.98851788e+00],
         [ 1.99840558e+00, 2.99478507e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 3.72000009e-01, 8.91999960e-01],
         [ 7.08000004e-01, 1.41200006e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 4.72257614e-01, 1.53470778e+00],
         [ 3.80338937e-01, 3.37563992e-01]]))
    """

    __mindspore_signature__ = (
        sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig('accum_updata', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig('lr', dtype=sig.sig_dtype.T1),
        sig.make_sig('rho', dtype=sig.sig_dtype.T1),
        sig.make_sig('grad', dtype=sig.sig_dtype.T),
        sig.make_sig('indices', dtype=sig.sig_dtype.T2),
    )

    @prim_attr_register
    def __init__(self, epsilon, use_locking=False):
        """Initialize SparseApplyAdadelta"""
        validator.check_value_type("epsilon", epsilon, [float], self.name)
        validator.check_number("epsilon", epsilon, 0.0, Rel.GE, self.name)
        validator.check_value_type("use_locking", use_locking, [bool], self.name)


class CTCLossV2(Primitive):
    """
    Calculates the CTC (Connectionist Temporal Classification) loss and the gradient.

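The update formulas in the docstring above can be sanity-checked with a plain NumPy sketch of one sparse Adadelta step. This is only an illustration of the documented math (the function name sparse_adadelta_step is made up here; it is not how the Ascend kernel works); run on the values from the Examples block it should reproduce the printed outputs up to float rounding.

import numpy as np

def sparse_adadelta_step(var, accum, accum_update, lr, rho, grad, indices, epsilon):
    """One sparse Adadelta step over the rows selected by `indices`, following the docstring math."""
    for i, idx in enumerate(indices):
        g = grad[i]
        accum[idx] = rho * accum[idx] + (1 - rho) * g * g
        update = np.sqrt(accum_update[idx] + epsilon) * g / np.sqrt(accum[idx] + epsilon)
        var[idx] = var[idx] - lr * update
        accum_update[idx] = rho * accum_update[idx] + (1 - rho) * update * update
    return var, accum, accum_update

var = np.array([[1.0, 2.0], [2.0, 3.0]], dtype=np.float32)
accum = np.array([[1.5, 2.5], [3.5, 4.5]], dtype=np.float32)
accum_update = np.array([[1.2, 2.4], [1.8, 0.6]], dtype=np.float32)
grad = np.array([[0.3, 0.7], [0.1, 0.8]], dtype=np.float32)
print(sparse_adadelta_step(var, accum, accum_update, 0.01, 0.2, grad, np.array([0, 1]), 1e-6))
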
@@ -53,6 +53,7 @@ class TargetNet(nn.Cell):
    def construct(self, x, y):
        return self.mul(x, y)


# Recursive GradOperation in Cell.
class Grad(nn.Cell):
    def __init__(self, network):

@@ -63,12 +64,16 @@ class Grad(nn.Cell):
    def construct(self, x, y):
        return self.grad(self.network)(x, y)


# Recursive GradOperation with GradOperation object.
grad1 = C.GradOperation()


@ms_function
def f1(x, y):
    return grad1(grad1(TargetNet()))(x, y)


def test_recursive_grad():
    x = Tensor(3, mstype.float32)
    y = Tensor(1, mstype.float32)

@@ -938,6 +943,7 @@ class StridedSliceNet(nn.Cell):
        out_3 = self.strided_slice_3(x, self.begins, self.ends, self.strides) + self.const_3
        return out_0, out_1, out_2, out_3


@pytest.mark.skip(reason='0 in shape is not support')
def test_strided_slice_const():
    class StridedSLiceConstNet(nn.Cell):

@@ -1021,6 +1027,21 @@ class ApplyAdamWithAmsgradNet(nn.Cell):
        out = self.apply_adam_with_amsgrad(self.var, self.m, self.v, self.vhat, beta1_power, beta2_power, lr, grad)
        return out


class SparseApplyAdadeltaNet(nn.Cell):
    def __init__(self, epsilon, use_locking=True):
        super(SparseApplyAdadeltaNet, self).__init__()
        self.sparse_apply_adadelta = P.SparseApplyAdadelta(epsilon, use_locking)
        self.lr = 0.001
        self.rho = 0.0
        self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
        self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
        self.accum_update = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum_update")

    def construct(self, grad, indices):
        out = self.sparse_apply_adadelta(self.var, self.accum, self.accum_update, self.lr, self.rho, grad, indices)
        return out


class ApplyAdagradDANet(nn.Cell):
    def __init__(self, use_locking=False):
        super(ApplyAdagradDANet, self).__init__()

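Outside of the test_case_nn_ops table further down, the new SparseApplyAdadeltaNet cell can also be exercised directly. The snippet below is a rough sketch assuming the imports already used in this test file (np, Tensor) and an Ascend backend; it is not part of the diff.

# Hypothetical direct invocation of the SparseApplyAdadeltaNet defined above.
net = SparseApplyAdadeltaNet(epsilon=1e-6)
grad = Tensor(np.random.rand(3, 3).astype(np.float32))
indices = Tensor(np.array([0, 1, 2], dtype=np.int32))
var_out, accum_out, accum_update_out = net(grad, indices)
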
@@ -1030,6 +1051,7 @@ class ApplyAdagradDANet(nn.Cell):
                                                  name="gradient_accumulator")
        self.gradient_squared_accumulator = Parameter(Tensor(np.array([[0.2, 0.1], [0.1, 0.2]]).astype(np.float32)),
                                                      name="gradient_squared_accumulator")

    def construct(self, grad, lr, l1, l2, global_step):
        out = self.apply_adagrad_d_a(self.var, self.gradient_accumulator, self.gradient_squared_accumulator, grad,
                                     lr, l1, l2, global_step)

@@ -1060,6 +1082,7 @@ class ApplyKerasMomentumNet(nn.Cell):
        out = self.apply_keras_momentum(self.var, self.accum, lr, grad, momentum)
        return out


test_case_math_ops = [
    ('Ger', {
        'block': P.Ger(),

@@ -2334,6 +2357,11 @@ test_case_nn_ops = [
        'desc_inputs': [Tensor(np.array([[0.5, 1, 2.0], [0.0533, 0.0776, -2.1233]]), mstype.float32)],
        'desc_bprop': [],
        'skip': ['backward']}),
    ('SparseApplyAdadelta', {
        'block': SparseApplyAdadeltaNet(1e-6),
        'desc_inputs': [Tensor(np.random.rand(3, 3), mstype.float32),
                        Tensor(np.ones((3,)), mstype.int32)],
        'skip': ['backward']}),
    ('HShrinkGrad', {
        'block': G.HShrinkGrad(),
        'desc_inputs': [Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]), mstype.float16),