forked from mindspore-Ecosystem/mindspore
[feat][assistant][I40FFO] add new Ascend operator ApplyKerasMomentum
This commit is contained in:
parent b690ae1b7a
commit 6098c1da5e
@@ -412,6 +412,7 @@ inline const PrimitivePtr kPrimHShrink = std::make_shared<Primitive>("HShrink");
 inline const PrimitivePtr kPrimHShrinkGrad = std::make_shared<Primitive>("HShrinkGrad");
 inline const PrimitivePtr kPrimApplyAdagradDA = std::make_shared<Primitive>("ApplyAdagradDA");
 inline const PrimitivePtr kPrimSparseApplyRMSProp = std::make_shared<Primitive>("SparseApplyRMSProp");
+inline const PrimitivePtr kPrimApplyKerasMomentum = std::make_shared<Primitive>("ApplyKerasMomentum");
 
 // Comm ops
 inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
@@ -0,0 +1,90 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ops/apply_keras_momentum.h"
+
+#include <algorithm>
+#include <set>
+
+#include "ops/op_utils.h"
+#include "abstract/primitive_infer_map.h"
+#include "utils/tensor_construct_utils.h"
+
+namespace mindspore {
+namespace ops {
+namespace {
+abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto prim_name = primitive->name();
+  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 5, prim_name);
+  for (const auto &item : input_args) {
+    MS_EXCEPTION_IF_NULL(item);
+  }
+  auto var_shape = input_args[0]->BuildShape();
+  auto accum_shape = input_args[1]->BuildShape();
+  auto lr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->GetShapeTrack())[kShape];
+  auto grad_shape = input_args[3]->BuildShape();
+  auto momentum_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->GetShapeTrack())[kShape];
+  // lr and momentum must be scalars
+  (void)CheckAndConvertUtils::CheckInteger("lr_shape size", lr_shape.size(), kEqual, 0, prim_name);
+  (void)CheckAndConvertUtils::CheckInteger("momentum_shape size", momentum_shape.size(), kEqual, 0, prim_name);
+  // var, accum and grad must have the same shape
+  std::map<std::string, abstract::BaseShapePtr> same_shape_args_map;
+  same_shape_args_map.insert({"accum", accum_shape});
+  same_shape_args_map.insert({"grad", grad_shape});
+  for (auto &elem : same_shape_args_map) {
+    if (*elem.second != *var_shape) {
+      MS_EXCEPTION(ValueError) << prim_name << " evaluator arg " << elem.first << " shape " << elem.second->ToString()
+                               << " is not consistent with var shape " << var_shape->ToString();
+    }
+  }
+  return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{var_shape, accum_shape});
+}
+
+TuplePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(prim);
+  auto prim_name = prim->name();
+  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 5, prim_name);
+  for (const auto &item : input_args) {
+    MS_EXCEPTION_IF_NULL(item);
+  }
+  auto var_type = input_args[0]->BuildType();
+  auto accum_type = input_args[1]->BuildType();
+  auto lr_type = input_args[2]->BuildType();
+  auto grad_type = input_args[3]->BuildType();
+  auto momentum_type = input_args[4]->BuildType();
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
+  // var, accum and grad must have the same type
+  std::map<std::string, TypePtr> args;
+  args.insert({"var", var_type});
+  args.insert({"accum", accum_type});
+  args.insert({"grad", grad_type});
+  (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
+  // lr and momentum must have a valid type
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("lr_dtype", lr_type, valid_types, prim_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("momentum_dtype", momentum_type, valid_types, prim_name);
+  return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, accum_type});
+}
+}  // namespace
+
+AbstractBasePtr ApplyKerasMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                        const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
+}
+REGISTER_PRIMITIVE_EVAL_IMPL(ApplyKerasMomentum, prim::kPrimApplyKerasMomentum, ApplyKerasMomentumInfer, nullptr, true);
+}  // namespace ops
+}  // namespace mindspore
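An aside on what the two C++ infer functions above enforce, for readers skimming the diff: exactly five inputs; `lr` and `momentum` must be rank-0 (scalar); `accum` and `grad` must match `var`'s shape; `var`, `accum`, and `grad` must share a float16/float32 dtype; and the op yields the (var, accum) pair. Below is a minimal Python sketch of that contract — the function name and NumPy framing are illustrative assumptions, not MindSpore API.

import numpy as np

def check_apply_keras_momentum(var, accum, lr, grad, momentum):
    """Illustrative restatement of the C++ InferShape/InferType checks."""
    # lr and momentum must be scalars (rank 0), mirroring CheckInteger(..., kEqual, 0, ...)
    for name, t in (("lr", lr), ("momentum", momentum)):
        if np.ndim(t) != 0:
            raise ValueError(f"{name} must be a scalar, got shape {np.shape(t)}")
    # accum and grad must have the same shape as var
    for name, t in (("accum", accum), ("grad", grad)):
        if np.shape(t) != np.shape(var):
            raise ValueError(f"{name} shape {np.shape(t)} is not consistent with var shape {np.shape(var)}")
    # var, accum and grad must share one dtype, and only fp16/fp32 are valid
    dtypes = {np.asarray(t).dtype for t in (var, accum, grad)}
    if len(dtypes) != 1 or dtypes.pop() not in (np.float16, np.float32):
        raise TypeError("var, accum and grad must share a float16 or float32 dtype")
    return np.shape(var), np.shape(accum)  # the op returns the (var, accum) pair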
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CORE_OPS_APPLY_KERAS_MOMENTUM_H_
+#define MINDSPORE_CORE_OPS_APPLY_KERAS_MOMENTUM_H_
+
+#include <map>
+#include <memory>
+#include <vector>
+#include <string>
+
+#include "ops/primitive_c.h"
+#include "abstract/abstract_value.h"
+#include "utils/check_convert_utils.h"
+
+namespace mindspore {
+namespace ops {
+constexpr auto kNameApplyKerasMomentum = "ApplyKerasMomentum";
+class MS_CORE_API ApplyKerasMomentum : public PrimitiveC {
+ public:
+  ApplyKerasMomentum() : PrimitiveC(kNameApplyKerasMomentum) {
+    InitIOName({"var", "accum", "lr", "grad", "momentum"}, {"var", "accum"});
+  }
+  ~ApplyKerasMomentum() = default;
+  MS_DECLARE_PARENT(ApplyKerasMomentum, PrimitiveC);
+};
+
+AbstractBasePtr ApplyKerasMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                        const std::vector<AbstractBasePtr> &input_args);
+
+using PrimApplyKerasMomentumPtr = std::shared_ptr<ApplyKerasMomentum>;
+}  // namespace ops
+}  // namespace mindspore
+#endif  // MINDSPORE_CORE_OPS_APPLY_KERAS_MOMENTUM_H_
@@ -28,6 +28,7 @@ from .add_n import _add_n_tbe
 from .add_n_ds import _add_n_ds_tbe
 from .accumulate_n_v2 import _accumulate_n_v2_tbe
 from .apply_ftrl import _apply_ftrl_tbe
+from .apply_keras_momentum import _apply_keras_momentum_tbe
 from .apply_momentum import _apply_momentum_tbe
 from .apply_adam import _apply_adam_tbe
 from .apply_ada_max import _apply_ada_max_tbe
@@ -0,0 +1,49 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""ApplyKerasMomentum op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+apply_keras_momentum_op_info = TBERegOp("ApplyKerasMomentum") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("apply_keras_momentum_d.so") \
+    .compute_cost(10) \
+    .kernel_name("apply_keras_momentum_d") \
+    .partial_flag(True) \
+    .attr("use_locking", "optional", "bool", "true,false", "false") \
+    .attr("use_nesterov", "optional", "bool", "true,false", "false") \
+    .input(0, "var", False, "required", "all") \
+    .input(1, "accum", False, "required", "all") \
+    .input(2, "lr", False, "required", "all") \
+    .input(3, "grad", False, "required", "all") \
+    .input(4, "momentum", False, "required", "all") \
+    .output(0, "var", False, "required", "all") \
+    .output(1, "accum", False, "required", "all") \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD,
+                  DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0,
+                  DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
+                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default,
+                  DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \
+    .get_op_info()
+
+
+@op_info_register(apply_keras_momentum_op_info)
+def _apply_keras_momentum_tbe():
+    """ApplyKerasMomentum TBE register"""
+    return
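One note on reading the registration above: each `.dtype_format(...)` row lists seven `DataType` entries that pair up positionally with the registered ports, five inputs followed by two outputs, so `lr` and `momentum` stay in the default format while `var` and `accum` take a device layout such as NC1HWC0 (`5HD`). A throwaway sketch of that pairing follows; the `ports` list and dict are mine, purely illustrative:

# Decode the first dtype_format row above: 5 inputs then 2 outputs,
# in the same order as the .input()/.output() registrations.
ports = ["var", "accum", "lr", "grad", "momentum", "out_var", "out_accum"]
first_row = ["F32_5HD", "F32_5HD", "F32_Default", "F32_5HD",
             "F32_Default", "F32_5HD", "F32_5HD"]
print(dict(zip(ports, first_row)))
# -> {'var': 'F32_5HD', 'accum': 'F32_5HD', 'lr': 'F32_Default',
#     'grad': 'F32_5HD', 'momentum': 'F32_Default',
#     'out_var': 'F32_5HD', 'out_accum': 'F32_5HD'}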
@@ -76,7 +76,7 @@ from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSpa
                      MaxPool, DataFormatDimMap,
                      AvgPool, Conv2DBackpropInput, ComputeAccidentalHits,
                      MaxPoolWithArgmax, OneHot, Pad, MirrorPad, Mish, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
-                     ResizeBilinear, Sigmoid, SeLU, HShrink,
+                     ResizeBilinear, Sigmoid, SeLU, HShrink, ApplyKerasMomentum,
                      SigmoidCrossEntropyWithLogits, NLLLoss, BCEWithLogitsLoss,
                      SmoothL1Loss, SoftMarginLoss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, DynamicGRUV2,
                      SoftmaxCrossEntropyWithLogits, ROIAlign,
@@ -185,6 +185,7 @@ __all__ = [
     'NLLLoss',
     'SGD',
     'ApplyMomentum',
+    'ApplyKerasMomentum',
     'FusedWeightScaleApplyMomentum',
     'ExpandDims',
     'Cast',
@@ -9301,3 +9301,98 @@ class SparseApplyRMSProp(Primitive):
         self.epsilon = validator.check_number("epsilon", epsilon, 0.0, Rel.GT, self.name)
         self.momentum = validator.check_number("momentum", momentum, 0.0, Rel.GE, self.name)
         self.rho = validator.check_float_range(rho, 0.0, 1.0, Rel.INC_BOTH, "rho", self.name)
+
+
+class ApplyKerasMomentum(Primitive):
+    r"""
+    Update `var` according to the momentum scheme.
+
+    .. math::
+        \begin{array}{ll} \\
+            accum = accum * momentum - grad * lr \\
+            var =
+            \begin{cases}
+                var + accum * momentum - grad * lr, &\text{if use\_nesterov} \\
+                var + accum, &\text{else}
+            \end{cases}
+        \end{array}
+
+    Refer to the paper `On the importance of initialization and momentum in deep
+    learning <https://dl.acm.org/doi/10.5555/3042817.3043064>`_ for more details.
+
+    Inputs of `var`, `accum` and `grad` comply with the implicit type conversion rules
+    to make the data types consistent. If they have different data types, the data type
+    with lower priority will be converted to the one with the relatively highest priority.
+    A RuntimeError is raised when a data type conversion of a Parameter would be required.
+
+    Args:
+        use_locking (bool): If `True`, updating of the `var` and `accum` tensors will be protected by a lock;
+            otherwise the behavior is undefined, but may exhibit less contention. Default: False.
+        use_nesterov (bool): If `True`, the tensor passed to compute grad will be var + momentum * accum,
+            so in the end, the var you get is actually var + momentum * accum. Default: False.
+
+    Inputs:
+        - **var** (Parameter) - Variable to be updated. With float16 or float32 data type.
+        - **accum** (Parameter) - Must have the same shape and type as `var`. With float16 or float32 data type.
+        - **lr** (Union[Number, Tensor]) - Scaling factor. Must be a scalar. With float16 or float32 data type.
+        - **grad** (Tensor) - The gradient. Must have the same shape and type as `var`.
+          With float16 or float32 data type.
+        - **momentum** (Union[Number, Tensor]) - Momentum. Must be a scalar. With float16 or float32 data type.
+
+    Outputs:
+        Tuple of 2 Tensors, the updated parameters.
+
+        - **var** (Tensor) - The same shape and data type as `var`.
+        - **accum** (Tensor) - The same shape and data type as `accum`.
+
+    Raises:
+        TypeError: If `use_locking` or `use_nesterov` is not a bool.
+        TypeError: If `var` or `accum` is not a Parameter.
+        TypeError: If `lr` is neither a Number nor a Tensor.
+        TypeError: If `grad` is not a Tensor.
+        TypeError: If `momentum` is neither a Number nor a Tensor.
+        TypeError: If dtype of `var`, `accum`, `lr`, `grad` or `momentum` is neither float16 nor float32.
+        ValueError: If `accum` or `grad` doesn't have the same shape as `var`.
+        ValueError: If the shape size of `lr` or `momentum` is not 0.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> class ApplyKerasMomentumNet(nn.Cell):
+        ...     def __init__(self, use_locking=False, use_nesterov=False):
+        ...         super(ApplyKerasMomentumNet, self).__init__()
+        ...         self.apply_keras_momentum = P.ApplyKerasMomentum(use_locking, use_nesterov)
+        ...         self.var = Parameter(Tensor(np.array([[0.2, 0.3], [0.1, 0.4]]).astype(np.float32)), name="var")
+        ...         self.accum = Parameter(Tensor(np.array([[0.2, 0.3], [0.1, 0.4]]).astype(np.float32)), name="accum")
+        ...     def construct(self, lr, grad, momentum):
+        ...         out = self.apply_keras_momentum(self.var, self.accum, lr, grad, momentum)
+        ...         return out
+        ...
+        >>> net = ApplyKerasMomentumNet()
+        >>> lr = Tensor(0.001, mstype.float32)
+        >>> grad = Tensor(np.array([[0.3, 0.2], [0.4, 0.1]]).astype(np.float32))
+        >>> momentum = Tensor(0.99, mstype.float32)
+        >>> output = net(lr, grad, momentum)
+        >>> print(output)
+        (Tensor(shape=[2, 2], dtype=Float32, value=
+        [[ 3.97700012e-01,  5.96800029e-01],
+         [ 1.98599994e-01,  7.95899987e-01]]), Tensor(shape=[2, 2], dtype=Float32, value=
+        [[ 1.97699994e-01,  2.96800017e-01],
+         [ 9.86000001e-02,  3.95900011e-01]]))
+    """
+
+    __mindspore_signature__ = (
+        sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
+        sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
+        sig.make_sig('lr', dtype=sig.sig_dtype.T1),
+        sig.make_sig('grad', dtype=sig.sig_dtype.T),
+        sig.make_sig('momentum', dtype=sig.sig_dtype.T2)
+    )
+
+    @prim_attr_register
+    def __init__(self, use_locking=False, use_nesterov=False):
+        """Initialize ApplyKerasMomentum"""
+        validator.check_value_type("use_locking", use_locking, [bool], self.name)
+        validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name)
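The docstring's printed values follow directly from the update rule, so they can be sanity-checked with plain NumPy; a minimal sketch assuming `use_nesterov=False` (the example's default):

import numpy as np

# Documented update rule with use_nesterov=False:
#   accum <- accum * momentum - grad * lr
#   var   <- var + accum
var = np.array([[0.2, 0.3], [0.1, 0.4]], dtype=np.float32)
accum = np.array([[0.2, 0.3], [0.1, 0.4]], dtype=np.float32)
grad = np.array([[0.3, 0.2], [0.4, 0.1]], dtype=np.float32)
lr, momentum = np.float32(0.001), np.float32(0.99)

accum = accum * momentum - grad * lr  # [[0.1977 0.2968] [0.0986 0.3959]]
var = var + accum                     # [[0.3977 0.5968] [0.1986 0.7959]]
print(var)    # matches the var tensor printed in the Examples block
print(accum)  # matches the accum tensor printed in the Examples block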
@@ -1042,6 +1042,18 @@ class SparseApplyRMSPropNet(nn.Cell):
         out = self.sparse_apply_r_m_s_prop(self.var, self.ms, self.mom, lr, grad, indices)
         return out
 
+
+class ApplyKerasMomentumNet(nn.Cell):
+    def __init__(self, use_locking=False, use_nesterov=False):
+        super(ApplyKerasMomentumNet, self).__init__()
+        self.apply_keras_momentum = P.ApplyKerasMomentum(use_locking, use_nesterov)
+        self.var = Parameter(Tensor(np.array([[0.2, 0.3], [0.1, 0.4]]).astype(np.float32)), name="var")
+        self.accum = Parameter(Tensor(np.array([[0.2, 0.3], [0.1, 0.4]]).astype(np.float32)), name="accum")
+
+    def construct(self, lr, grad, momentum):
+        out = self.apply_keras_momentum(self.var, self.accum, lr, grad, momentum)
+        return out
+
 test_case_math_ops = [
     ('BitwiseAnd', {
         'block': P.BitwiseAnd(),
@@ -2306,6 +2318,12 @@ test_case_nn_ops = [
         Tensor(np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32)),
         Tensor(np.array([0, 1], dtype=np.int32))],
         'skip': ['backward']}),
+    ('ApplyKerasMomentum', {
+        'block': ApplyKerasMomentumNet(),
+        'desc_inputs': [Tensor(0.001, mstype.float32),
+                        Tensor(np.array([[0.3, 0.2], [0.4, 0.1]]).astype(np.float32)),
+                        Tensor(0.99, mstype.float32)],
+        'skip': ['backward']}),
 ]
 
 test_case_array_ops = [