forked from mindspore-Ecosystem/mindspore
[assistant][ops] New operator implementation, including ApplyAdamWithAmsgradD
This commit is contained in:
parent
76bec49d75
commit
031c792cf6
@ -439,6 +439,7 @@ inline const PrimitivePtr kPrimLARSUpdate = std::make_shared<Primitive>("LARSUpd
inline const PrimitivePtr kPrimApplyAddSign = std::make_shared<Primitive>("ApplyAddSign");
inline const PrimitivePtr kPrimApplyAdagrad = std::make_shared<Primitive>("ApplyAdagrad");
inline const PrimitivePtr kPrimApplyAdadelta = std::make_shared<Primitive>("ApplyAdadelta");
inline const PrimitivePtr kPrimApplyAdamWithAmsgrad = std::make_shared<Primitive>("ApplyAdamWithAmsgrad");

// Comm ops
inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
@ -0,0 +1,105 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/apply_adam_with_amsgrad.h"

#include <string>
#include <set>
#include <map>

#include "ops/op_utils.h"
#include "abstract/primitive_infer_map.h"
#include "utils/tensor_construct_utils.h"

namespace mindspore {
namespace ops {
namespace {
abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto var_shape = input_args[0]->BuildShape();
  auto m_shape = input_args[1]->BuildShape();
  auto v_shape = input_args[2]->BuildShape();
  auto vhat_shape = input_args[3]->BuildShape();
  auto beta1_power_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->GetShapeTrack())[kShape];
  auto beta2_power_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[5]->GetShapeTrack())[kShape];
  auto lr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[6]->GetShapeTrack())[kShape];
  auto grad_shape = input_args[7]->BuildShape();
  // beta1_power, beta2_power and lr must be scalars
  (void)CheckAndConvertUtils::CheckInteger("beta1_power_shape size", beta1_power_shape.size(), kEqual, 0, prim_name);
  (void)CheckAndConvertUtils::CheckInteger("beta2_power_shape size", beta2_power_shape.size(), kEqual, 0, prim_name);
  (void)CheckAndConvertUtils::CheckInteger("lr_shape size", lr_shape.size(), kEqual, 0, prim_name);
  // the shapes of var, m, v, vhat and grad must be the same
  std::map<std::string, abstract::BaseShapePtr> same_shape_args_map;
  same_shape_args_map.insert({"m", m_shape});
  same_shape_args_map.insert({"v", v_shape});
  same_shape_args_map.insert({"vhat", vhat_shape});
  same_shape_args_map.insert({"grad", grad_shape});
  for (auto &elem : same_shape_args_map) {
    if (*elem.second != *var_shape) {
      MS_EXCEPTION(ValueError) << prim_name << " evaluator arg " << elem.first << " shape " << elem.second->ToString()
                               << " is not consistent with var shape " << var_shape->ToString();
    }
  }
  return std::make_shared<abstract::TupleShape>(
    std::vector<abstract::BaseShapePtr>{var_shape, m_shape, v_shape, vhat_shape});
}

TuplePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  auto prim_name = prim->name();
  // get all input types
  auto var_type = input_args[0]->BuildType();
  auto m_type = input_args[1]->BuildType();
  auto v_type = input_args[2]->BuildType();
  auto vhat_type = input_args[3]->BuildType();
  auto beta1_power_type = input_args[4]->BuildType();
  auto beta2_power_type = input_args[5]->BuildType();
  auto lr_type = input_args[6]->BuildType();
  auto grad_type = input_args[7]->BuildType();
  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
  // var, m, v, vhat and grad must be valid tensors and have the same type
  std::map<std::string, TypePtr> args;
  args.insert({"var_type", var_type});
  args.insert({"m_type", m_type});
  args.insert({"v_type", v_type});
  args.insert({"vhat_type", vhat_type});
  args.insert({"grad_type", grad_type});
  CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
  // beta1_power, beta2_power and lr must have valid types
  CheckAndConvertUtils::CheckTensorTypeValid("beta1_power_type", beta1_power_type, valid_types, prim_name);
  CheckAndConvertUtils::CheckTensorTypeValid("beta2_power_type", beta2_power_type, valid_types, prim_name);
  CheckAndConvertUtils::CheckTensorTypeValid("lr_type", lr_type, valid_types, prim_name);
  return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, m_type, v_type, vhat_type});
}
} // namespace

AbstractBasePtr ApplyAdamWithAmsgradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                          const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t input_num = 8;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
  auto infer_type = InferType(primitive, input_args);
  auto infer_shape = InferShape(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(ApplyAdamWithAmsgrad, prim::kPrimApplyAdamWithAmsgrad, ApplyAdamWithAmsgradInfer, nullptr,
                             true);
} // namespace ops
} // namespace mindspore
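The two inference helpers above enforce a simple contract: `var`, `m`, `v`, `vhat` and `grad` must share one shape; `beta1_power`, `beta2_power` and `lr` must be scalars; and both paths return a 4-tuple covering (`var`, `m`, `v`, `vhat`). A minimal plain-Python sketch of that contract (illustrative only, not MindSpore internals; shapes are tuples, and a scalar is the empty tuple):

# Illustrative sketch of the shape-inference contract implemented above.
def infer_shape_sketch(shapes):
    var = shapes["var"]
    for name in ("m", "v", "vhat", "grad"):
        if shapes[name] != var:
            raise ValueError(f"arg {name} shape {shapes[name]} is not consistent with var shape {var}")
    for name in ("beta1_power", "beta2_power", "lr"):
        if shapes[name] != ():
            raise ValueError(f"{name} must be a scalar, got shape {shapes[name]}")
    return (var, shapes["m"], shapes["v"], shapes["vhat"])

# Example: every tensor argument has shape (2, 2), each scalar is ().
# infer_shape_sketch({"var": (2, 2), "m": (2, 2), "v": (2, 2), "vhat": (2, 2),
#                     "beta1_power": (), "beta2_power": (), "lr": (), "grad": (2, 2)})
# -> ((2, 2), (2, 2), (2, 2), (2, 2))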
@ -0,0 +1,44 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_APPLY_ADAM_WITH_AMSGRAD_H_
#define MINDSPORE_CORE_OPS_APPLY_ADAM_WITH_AMSGRAD_H_
#include <vector>
#include <memory>

#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace ops {
constexpr auto kNameApplyAdamWithAmsgrad = "ApplyAdamWithAmsgrad";
class ApplyAdamWithAmsgrad : public PrimitiveC {
 public:
  ApplyAdamWithAmsgrad() : PrimitiveC(kNameApplyAdamWithAmsgrad) {
    InitIOName({"var", "m", "v", "vhat", "beta1_power", "beta2_power", "lr", "grad"}, {"var", "m", "v", "vhat"});
  }
  ~ApplyAdamWithAmsgrad() = default;
  MS_DECLARE_PARENT(ApplyAdamWithAmsgrad, PrimitiveC);
};

AbstractBasePtr ApplyAdamWithAmsgradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                          const std::vector<AbstractBasePtr> &input_args);

using PrimApplyAdamWithAmsgradPtr = std::shared_ptr<ApplyAdamWithAmsgrad>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_APPLY_ADAM_WITH_AMSGRAD_H_
@ -504,3 +504,4 @@ from .trunc import _trunc_tbe
from .extract_volume_patches import _extract_volume_patches_tbe
from .round_ds import _round_ds_tbe
from .is_close import _is_close_tbe
from .apply_adam_with_amsgrad import _apply_adam_with_amsgrad_tbe
@ -0,0 +1,60 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""ApplyAdamWithAmsgrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

apply_adam_with_amsgrad_op_info = TBERegOp("ApplyAdamWithAmsgrad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("apply_adam_with_amsgrad_d.so") \
    .compute_cost(10) \
    .kernel_name("apply_adam_with_amsgrad_d") \
    .partial_flag(True) \
    .attr("beta1", "required", "float", "all") \
    .attr("beta2", "required", "float", "all") \
    .attr("epsilon", "required", "float", "all") \
    .attr("use_locking", "optional", "bool", "true,false", "false") \
    .input(0, "var", False, "required", "all") \
    .input(1, "m", False, "required", "all") \
    .input(2, "v", False, "required", "all") \
    .input(3, "vhat", False, "required", "all") \
    .input(4, "beta1_power", False, "required", "all") \
    .input(5, "beta2_power", False, "required", "all") \
    .input(6, "lr", False, "required", "all") \
    .input(7, "grad", False, "required", "all") \
    .output(0, "var", False, "required", "all") \
    .output(1, "m", False, "required", "all") \
    .output(2, "v", False, "required", "all") \
    .output(3, "vhat", False, "required", "all") \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_5HD,
                  DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_FracZ,
                  DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_C1HWNCoC0,
                  DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(apply_adam_with_amsgrad_op_info)
def _apply_adam_with_amsgrad_tbe():
    """ApplyAdamWithAmsgrad TBE register"""
    return
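Each `.dtype_format(...)` row above lists one dtype/format combination for all twelve I/O slots in declaration order: inputs 0-7 followed by outputs 0-3, so the scalar inputs (`beta1_power`, `beta2_power`, `lr`) keep the default format while the tensor slots share a layout. A hedged reading of the first row (plain illustrative Python, not part of the registration API):

# Illustrative only: how the first dtype_format row maps to the op's I/O slots.
io_slots = ["var", "m", "v", "vhat", "beta1_power", "beta2_power", "lr", "grad",
            "out_var", "out_m", "out_v", "out_vhat"]
first_row = ["F32_5HD", "F32_5HD", "F32_5HD", "F32_5HD",
             "F32_Default", "F32_Default", "F32_Default", "F32_5HD",
             "F32_5HD", "F32_5HD", "F32_5HD", "F32_5HD"]
print(dict(zip(io_slots, first_row)))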
@ -86,7 +86,8 @@ from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSpa
                     FusedSparseFtrl, FusedSparseProximalAdagrad, SparseApplyRMSProp,
                     ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2, ApplyAdagradDA,
                     ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent,
                     ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, AdaptiveAvgPool2D, SoftShrink)
                     ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, AdaptiveAvgPool2D, SoftShrink,
                     ApplyAdamWithAmsgrad)
from . import _quant_ops
from ._quant_ops import *
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode,
@ -425,6 +426,7 @@ __all__ = [
    "ApplyAdagrad",
    "ApplyAdagradV2",
    "ApplyAdagradDA",
    "ApplyAdamWithAmsgrad",
    "ApplyAddSign",
    "ApplyPowerSign",
    "ApplyGradientDescent",
@ -8956,3 +8956,110 @@ class ApplyKerasMomentum(Primitive):
        """Initialize ApplyKerasMomentum"""
        validator.check_value_type("use_locking", use_locking, [bool], self.name)
        validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name)


class ApplyAdamWithAmsgrad(Primitive):
    r"""
    Update `var` according to the Adam algorithm with the AMSGrad modification.

    .. math::
        \begin{array}{ll} \\
            lr_t := learning\_rate * \sqrt{1 - \beta_2^t} / (1 - \beta_1^t) \\
            m_t := \beta_1 * m_{t-1} + (1 - \beta_1) * g \\
            v_t := \beta_2 * v_{t-1} + (1 - \beta_2) * g * g \\
            \hat v_t := \max(\hat v_{t-1}, v_t) \\
            var := var - lr_t * m_t / (\sqrt{\hat v_t} + \epsilon) \\
        \end{array}

    Args:
        beta1 (float): The exponential decay rate for the 1st moment estimates. Must be a scalar.
        beta2 (float): The exponential decay rate for the 2nd moment estimates. Must be a scalar.
        epsilon (float): A small constant added to the denominator for numerical stability. Must be a scalar.
        use_locking (bool): If True, updating of the `var`, `m`, and `v` tensors will
            be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
            Default: False.

    Inputs:
        - **var** (Parameter) - Variable to be updated. The data type can be float16 or float32.
        - **m** (Parameter) - The 1st moment vector in the updating formula,
          with the same shape and data type as `var`.
        - **v** (Parameter) - The 2nd moment vector in the updating formula,
          with the same shape and data type as `var`.
        - **vhat** (Parameter) - :math:`\hat v_t` in the updating formula,
          with the same shape and data type as `var`.
        - **beta1_power** (Union[float, Tensor]) - :math:`\beta_1^t` in the updating formula,
          a scalar tensor with float16 or float32 data type.
        - **beta2_power** (Union[float, Tensor]) - :math:`\beta_2^t` in the updating formula,
          a scalar tensor with float16 or float32 data type.
        - **lr** (Union[float, Tensor]) - Scaling factor, a scalar tensor with float16 or float32 data type.
        - **grad** (Tensor) - The gradient, with the same shape and data type as `var`.

    Outputs:
        Tuple of 4 Tensors, the updated parameters.

        - **var** (Tensor) - The same shape and data type as `var`.
        - **m** (Tensor) - The same shape and data type as `m`.
        - **v** (Tensor) - The same shape and data type as `v`.
        - **vhat** (Tensor) - The same shape and data type as `vhat`.

    Raises:
        TypeError: If `var`, `m`, `v` or `vhat` is not a Parameter.
        TypeError: If `beta1_power`, `beta2_power` or `lr` is neither a Number nor a Tensor.
        TypeError: If `grad` is not a Tensor.
        TypeError: If the dtype of `var`, `m`, `v`, `vhat`, `beta1_power`, `beta2_power`,
            `lr` or `grad` is neither float16 nor float32.
        ValueError: If `m`, `v`, `vhat` or `grad` does not have the same shape as `var`.
        ValueError: If `beta1_power`, `beta2_power` or `lr` is not a scalar.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> class ApplyAdamWithAmsgradNet(nn.Cell):
        ...     def __init__(self, beta1, beta2, epsilon, use_locking=False):
        ...         super(ApplyAdamWithAmsgradNet, self).__init__()
        ...         self.apply_adam_with_amsgrad = P.ApplyAdamWithAmsgrad(beta1, beta2, epsilon, use_locking)
        ...         self.var = Parameter(Tensor(np.array([[0.2, 0.3], [0.1, 0.4]]).astype(np.float32)), name="var")
        ...         self.m = Parameter(Tensor(np.array([[0.2, 0.3], [0.1, 0.4]]).astype(np.float32)), name="m")
        ...         self.v = Parameter(Tensor(np.array([[0.2, 0.3], [0.1, 0.4]]).astype(np.float32)), name="v")
        ...         self.vhat = Parameter(Tensor(np.array([[0.2, 0.3], [0.1, 0.4]]).astype(np.float32)), name="vhat")
        ...     def construct(self, beta1_power, beta2_power, lr, grad):
        ...         out = self.apply_adam_with_amsgrad(self.var, self.m, self.v, self.vhat,
        ...                                            beta1_power, beta2_power, lr, grad)
        ...         return out
        >>> net = ApplyAdamWithAmsgradNet(0.1, 0.1, 0.001)
        >>> beta1_power = Tensor(0.3, mstype.float32)
        >>> beta2_power = Tensor(0.3, mstype.float32)
        >>> lr = Tensor(0.3, mstype.float32)
        >>> grad = Tensor(np.array([[0.3, 0.2], [0.4, 0.1]]).astype(np.float32))
        >>> output = net(beta1_power, beta2_power, lr, grad)
        >>> print(output)
        (Tensor(shape=[2, 2], dtype=Float32, value=
        [[-3.19985598e-02, 1.62773281e-01],
         [-2.37216085e-01, 3.26413274e-01]]), Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 2.90000021e-01, 2.09999993e-01],
         [ 3.69999975e-01, 1.29999995e-01]]), Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 1.01000004e-01, 6.59999996e-02],
         [ 1.54000014e-01, 4.90000024e-02]]), Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 2.00000003e-01, 3.00000012e-01],
         [ 1.54000014e-01, 4.00000006e-01]]))
    """

    __mindspore_signature__ = (
        sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig('m', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig('v', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig('vhat', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig('beta1_power', dtype=sig.sig_dtype.T1),
        sig.make_sig('beta2_power', dtype=sig.sig_dtype.T2),
        sig.make_sig('lr', dtype=sig.sig_dtype.T3),
        sig.make_sig('grad', dtype=sig.sig_dtype.T)
    )

    @prim_attr_register
    def __init__(self, beta1, beta2, epsilon, use_locking=False):
        """Initialize ApplyAdamWithAmsgrad"""
        validator.check_value_type("beta1", beta1, [float], self.name)
        validator.check_value_type("beta2", beta2, [float], self.name)
        validator.check_value_type("epsilon", epsilon, [float], self.name)
        validator.check_value_type("use_locking", use_locking, [bool], self.name)
@ -1008,6 +1008,18 @@ class EditDistance(nn.Cell):
        return self.edit_distance(hypothesis_indices, hypothesis_values, self.hypothesis_shape,
                                  truth_indices, truth_values, self.truth_shape)

class ApplyAdamWithAmsgradNet(nn.Cell):
    def __init__(self, beta1=0.1, beta2=0.1, epsilon=0.001, use_locking=False):
        super(ApplyAdamWithAmsgradNet, self).__init__()
        self.apply_adam_with_amsgrad = P.ApplyAdamWithAmsgrad(beta1, beta2, epsilon, use_locking)
        self.var = Parameter(Tensor(np.array([[0.2, 0.3], [0.1, 0.4]]).astype(np.float32)), name="var")
        self.m = Parameter(Tensor(np.array([[0.2, 0.3], [0.1, 0.4]]).astype(np.float32)), name="m")
        self.v = Parameter(Tensor(np.array([[0.2, 0.3], [0.1, 0.4]]).astype(np.float32)), name="v")
        self.vhat = Parameter(Tensor(np.array([[0.2, 0.3], [0.1, 0.4]]).astype(np.float32)), name="vhat")

    def construct(self, beta1_power, beta2_power, lr, grad):
        out = self.apply_adam_with_amsgrad(self.var, self.m, self.v, self.vhat, beta1_power, beta2_power, lr, grad)
        return out

class ApplyAdagradDANet(nn.Cell):
    def __init__(self, use_locking=False):
@ -2334,6 +2346,15 @@ test_case_nn_ops = [
                        Tensor(0.001, mstype.float32),
                        Tensor(0.001, mstype.float32),
                        Tensor(2, mstype.int32)],
        'desc_bprop': [],
        'skip': ['backward']}),
    ('ApplyAdamWithAmsgrad', {
        'block': ApplyAdamWithAmsgradNet(),
        'desc_inputs': [Tensor(0.3, mstype.float32),
                        Tensor(0.3, mstype.float32),
                        Tensor(0.3, mstype.float32),
                        Tensor(np.array([[0.3, 0.2], [0.4, 0.1]]).astype(np.float32))],
        'desc_bprop': [],
        'skip': ['backward']}),
    ('SparseApplyRMSProp', {
        'block': SparseApplyRMSPropNet(0.2, 0.01, 1e-6),