[assistant][ops] New operator implementation, including ApplyAdamWithAmsgradD

This commit is contained in:
zorax 2021-09-30 03:18:34 +08:00 committed by jiangzhenguang
parent 76bec49d75
commit 031c792cf6
8 changed files with 342 additions and 1 deletion

@@ -439,6 +439,7 @@ inline const PrimitivePtr kPrimLARSUpdate = std::make_shared<Primitive>("LARSUpdate");
inline const PrimitivePtr kPrimApplyAddSign = std::make_shared<Primitive>("ApplyAddSign");
inline const PrimitivePtr kPrimApplyAdagrad = std::make_shared<Primitive>("ApplyAdagrad");
inline const PrimitivePtr kPrimApplyAdadelta = std::make_shared<Primitive>("ApplyAdadelta");
inline const PrimitivePtr kPrimApplyAdamWithAmsgrad = std::make_shared<Primitive>("ApplyAdamWithAmsgrad");
// Comm ops
inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");

@@ -0,0 +1,105 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/apply_adam_with_amsgrad.h"
#include <string>
#include <set>
#include <map>
#include "ops/op_utils.h"
#include "abstract/primitive_infer_map.h"
#include "utils/tensor_construct_utils.h"
namespace mindspore {
namespace ops {
namespace {
abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto var_shape = input_args[0]->BuildShape();
  auto m_shape = input_args[1]->BuildShape();
  auto v_shape = input_args[2]->BuildShape();
  auto vhat_shape = input_args[3]->BuildShape();
  auto beta1_power_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->GetShapeTrack())[kShape];
  auto beta2_power_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[5]->GetShapeTrack())[kShape];
  auto lr_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[6]->GetShapeTrack())[kShape];
  auto grad_shape = input_args[7]->BuildShape();
  // beta1_power, beta2_power and lr must be scalars
  (void)CheckAndConvertUtils::CheckInteger("beta1_power_shape size", beta1_power_shape.size(), kEqual, 0, prim_name);
  (void)CheckAndConvertUtils::CheckInteger("beta2_power_shape size", beta2_power_shape.size(), kEqual, 0, prim_name);
  (void)CheckAndConvertUtils::CheckInteger("lr_shape size", lr_shape.size(), kEqual, 0, prim_name);
  // m, v, vhat and grad must have the same shape as var
  std::map<std::string, abstract::BaseShapePtr> same_shape_args_map;
  same_shape_args_map.insert({"m", m_shape});
  same_shape_args_map.insert({"v", v_shape});
  same_shape_args_map.insert({"vhat", vhat_shape});
  same_shape_args_map.insert({"grad", grad_shape});
  for (auto &elem : same_shape_args_map) {
    if (*elem.second != *var_shape) {
      MS_EXCEPTION(ValueError) << prim_name << " evaluator arg " << elem.first << " shape " << elem.second->ToString()
                               << " is not consistent with var shape " << var_shape->ToString();
    }
  }
  return std::make_shared<abstract::TupleShape>(
    std::vector<abstract::BaseShapePtr>{var_shape, m_shape, v_shape, vhat_shape});
}
TuplePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  auto prim_name = prim->name();
  // get the type of every input
  auto var_type = input_args[0]->BuildType();
  auto m_type = input_args[1]->BuildType();
  auto v_type = input_args[2]->BuildType();
  auto vhat_type = input_args[3]->BuildType();
  auto beta1_power_type = input_args[4]->BuildType();
  auto beta2_power_type = input_args[5]->BuildType();
  auto lr_type = input_args[6]->BuildType();
  auto grad_type = input_args[7]->BuildType();
  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
  // var, m, v, vhat and grad must be valid tensors and have the same type
  std::map<std::string, TypePtr> args;
  args.insert({"var_type", var_type});
  args.insert({"m_type", m_type});
  args.insert({"v_type", v_type});
  args.insert({"vhat_type", vhat_type});
  args.insert({"grad_type", grad_type});
  CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
  // beta1_power, beta2_power and lr must have a valid type
  CheckAndConvertUtils::CheckTensorTypeValid("beta1_power_type", beta1_power_type, valid_types, prim_name);
  CheckAndConvertUtils::CheckTensorTypeValid("beta2_power_type", beta2_power_type, valid_types, prim_name);
  CheckAndConvertUtils::CheckTensorTypeValid("lr_type", lr_type, valid_types, prim_name);
  return std::make_shared<Tuple>(std::vector<TypePtr>{var_type, m_type, v_type, vhat_type});
}
} // namespace
AbstractBasePtr ApplyAdamWithAmsgradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                          const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t input_num = 8;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
  auto infer_type = InferType(primitive, input_args);
  auto infer_shape = InferShape(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(ApplyAdamWithAmsgrad, prim::kPrimApplyAdamWithAmsgrad, ApplyAdamWithAmsgradInfer, nullptr,
                             true);
} // namespace ops
} // namespace mindspore
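
The checks above define the operator's inference contract: beta1_power, beta2_power and lr must be rank-0, m/v/vhat/grad must match var's shape, and all tensor inputs must be float16 or float32. A minimal NumPy sketch of the same contract, for illustration only (the function name is made up and this is not MindSpore code):

import numpy as np

def check_apply_adam_with_amsgrad_args(var, m, v, vhat, beta1_power, beta2_power, lr, grad):
    """Hypothetical checker mirroring the InferShape/InferType rules above."""
    # beta1_power, beta2_power and lr must be scalars (rank 0)
    for name, value in (("beta1_power", beta1_power), ("beta2_power", beta2_power), ("lr", lr)):
        if np.ndim(value) != 0:
            raise ValueError(f"{name} must be a scalar, got shape {np.shape(value)}")
    # m, v, vhat and grad must have the same shape as var
    for name, value in (("m", m), ("v", v), ("vhat", vhat), ("grad", grad)):
        if np.shape(value) != np.shape(var):
            raise ValueError(f"{name} shape {np.shape(value)} is not consistent with var shape {np.shape(var)}")
    # every tensor argument must be float16 or float32
    for name, value in (("var", var), ("m", m), ("v", v), ("vhat", vhat), ("grad", grad)):
        if np.asarray(value).dtype not in (np.float16, np.float32):
            raise TypeError(f"{name} must be float16 or float32, got {np.asarray(value).dtype}")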

@@ -0,0 +1,44 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_APPLY_ADAM_WITH_AMSGRAD_H_
#define MINDSPORE_CORE_OPS_APPLY_ADAM_WITH_AMSGRAD_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameApplyAdamWithAmsgrad = "ApplyAdamWithAmsgrad";
class ApplyAdamWithAmsgrad : public PrimitiveC {
 public:
  ApplyAdamWithAmsgrad() : PrimitiveC(kNameApplyAdamWithAmsgrad) {
    InitIOName({"var", "m", "v", "vhat", "beta1_power", "beta2_power", "lr", "grad"}, {"var", "m", "v", "vhat"});
  }
  ~ApplyAdamWithAmsgrad() = default;
  MS_DECLARE_PARENT(ApplyAdamWithAmsgrad, PrimitiveC);
};
AbstractBasePtr ApplyAdamWithAmsgradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                          const std::vector<AbstractBasePtr> &input_args);
using PrimApplyAdamWithAmsgradPtr = std::shared_ptr<ApplyAdamWithAmsgrad>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_APPLY_ADAM_WITH_AMSGRAD_H_

@@ -504,3 +504,4 @@ from .trunc import _trunc_tbe
from .extract_volume_patches import _extract_volume_patches_tbe
from .round_ds import _round_ds_tbe
from .is_close import _is_close_tbe
from .apply_adam_with_amsgrad import _apply_adam_with_amsgrad_tbe

@@ -0,0 +1,60 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ApplyAdamWithAmsgrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

apply_adam_with_amsgrad_op_info = TBERegOp("ApplyAdamWithAmsgrad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("apply_adam_with_amsgrad_d.so") \
    .compute_cost(10) \
    .kernel_name("apply_adam_with_amsgrad_d") \
    .partial_flag(True) \
    .attr("beta1", "required", "float", "all") \
    .attr("beta2", "required", "float", "all") \
    .attr("epsilon", "required", "float", "all") \
    .attr("use_locking", "optional", "bool", "true,false", "false") \
    .input(0, "var", False, "required", "all") \
    .input(1, "m", False, "required", "all") \
    .input(2, "v", False, "required", "all") \
    .input(3, "vhat", False, "required", "all") \
    .input(4, "beta1_power", False, "required", "all") \
    .input(5, "beta2_power", False, "required", "all") \
    .input(6, "lr", False, "required", "all") \
    .input(7, "grad", False, "required", "all") \
    .output(0, "var", False, "required", "all") \
    .output(1, "m", False, "required", "all") \
    .output(2, "v", False, "required", "all") \
    .output(3, "vhat", False, "required", "all") \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_5HD,
                  DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_FracZ,
                  DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_C1HWNCoC0,
                  DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(apply_adam_with_amsgrad_op_info)
def _apply_adam_with_amsgrad_tbe():
    """ApplyAdamWithAmsgrad TBE register"""
    return
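
Each dtype_format row in the registration above lists one dtype/format combination for the op's eight inputs followed by its four outputs, in registration order. A small sketch, for illustration only (the io_slots/first_row names are made up), that spells out that column order for the first row:

# Hypothetical helper: pair each I/O slot of the registration above with one
# dtype_format row, making the column order (8 inputs, then 4 outputs) explicit.
io_slots = ["var", "m", "v", "vhat", "beta1_power", "beta2_power", "lr", "grad",
            "out_var", "out_m", "out_v", "out_vhat"]
first_row = ["F32_5HD", "F32_5HD", "F32_5HD", "F32_5HD",
             "F32_Default", "F32_Default", "F32_Default", "F32_5HD",
             "F32_5HD", "F32_5HD", "F32_5HD", "F32_5HD"]
for slot, fmt in zip(io_slots, first_row):
    print(f"{slot:12s} -> {fmt}")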

@@ -86,7 +86,8 @@ from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSpa
                     FusedSparseFtrl, FusedSparseProximalAdagrad, SparseApplyRMSProp,
                     ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2, ApplyAdagradDA,
                     ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent,
                     ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, AdaptiveAvgPool2D, SoftShrink,
                     ApplyAdamWithAmsgrad)
from . import _quant_ops
from ._quant_ops import *
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode,
@@ -425,6 +426,7 @@ __all__ = [
    "ApplyAdagrad",
    "ApplyAdagradV2",
    "ApplyAdagradDA",
    "ApplyAdamWithAmsgrad",
    "ApplyAddSign",
    "ApplyPowerSign",
    "ApplyGradientDescent",

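With the import and the __all__ entry in place, the new primitive should be reachable from the public operations package. A usage sketch, assuming the file above is mindspore/ops/operations/__init__.py and a working MindSpore install:

from mindspore.ops.operations import ApplyAdamWithAmsgrad

# Attributes follow the primitive definition: beta1, beta2, epsilon, use_locking.
apply_amsgrad = ApplyAdamWithAmsgrad(beta1=0.9, beta2=0.999, epsilon=1e-8, use_locking=False)
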
@@ -8956,3 +8956,110 @@ class ApplyKerasMomentum(Primitive):
        """Initialize ApplyKerasMomentum"""
        validator.check_value_type("use_locking", use_locking, [bool], self.name)
        validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name)


class ApplyAdamWithAmsgrad(Primitive):
    r"""
    Update `var` according to the Adam algorithm with the AMSGrad modification, which keeps
    the running maximum of the second-moment estimate in `vhat`.

    .. math::
        \begin{array}{ll} \\
            lr_t:=learning\_rate*\sqrt{1-\beta_2^t}/(1-\beta_1^t) \\
            m_t:=\beta_1*m_{t-1}+(1-\beta_1)*g \\
            v_t:=\beta_2*v_{t-1}+(1-\beta_2)*g*g \\
            \hat v_t:=\max(\hat v_{t-1}, v_t) \\
            var:=var-lr_t*m_t/(\sqrt{\hat v_t}+\epsilon) \\
        \end{array}

    Args:
        beta1 (float): The exponential decay rate for the 1st moment estimates. Must be a scalar.
        beta2 (float): The exponential decay rate for the 2nd moment estimates. Must be a scalar.
        epsilon (float): A small value added to the denominator for numerical stability. Must be a scalar.
        use_locking (bool): If True, updating of the `var`, `m`, and `v` tensors will be
            protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
            Default: False.

    Inputs:
        - **var** (Parameter) - Variable to be updated. The data type can be float16 or float32.
        - **m** (Parameter) - The 1st moment vector in the updating formula,
          with the same shape and data type as `var`.
        - **v** (Parameter) - The 2nd moment vector in the updating formula,
          with the same shape and data type as `var`.
        - **vhat** (Parameter) - :math:`\hat v_t` in the updating formula,
          with the same shape and data type as `var`.
        - **beta1_power** (Union[float, Tensor]) - :math:`\beta_1^t` in the updating formula,
          a scalar tensor with float16 or float32 data type.
        - **beta2_power** (Union[float, Tensor]) - :math:`\beta_2^t` in the updating formula,
          a scalar tensor with float16 or float32 data type.
        - **lr** (Union[float, Tensor]) - Scaling factor, a scalar tensor with float16 or float32 data type.
        - **grad** (Tensor) - The gradient, with the same shape and data type as `var`.

    Outputs:
        Tuple of 4 Tensors, the updated parameters.

        - **var** (Tensor) - The same shape and data type as `var`.
        - **m** (Tensor) - The same shape and data type as `m`.
        - **v** (Tensor) - The same shape and data type as `v`.
        - **vhat** (Tensor) - The same shape and data type as `vhat`.

    Raises:
        TypeError: If `var`, `m`, `v` or `vhat` is not a Parameter.
        TypeError: If `beta1_power`, `beta2_power` or `lr` is neither a Number nor a Tensor.
        TypeError: If `grad` is not a Tensor.
        TypeError: If the data type of `var`, `m`, `v`, `vhat`, `beta1_power`, `beta2_power`,
            `lr` or `grad` is neither float16 nor float32.
        ValueError: If `m`, `v`, `vhat` or `grad` does not have the same shape as `var`.
        ValueError: If `beta1_power`, `beta2_power` or `lr` is not a scalar.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> class ApplyAdamWithAmsgradNet(nn.Cell):
        ...     def __init__(self, beta1, beta2, epsilon, use_locking=False):
        ...         super(ApplyAdamWithAmsgradNet, self).__init__()
        ...         self.apply_adam_with_amsgrad = P.ApplyAdamWithAmsgrad(beta1, beta2, epsilon, use_locking)
        ...         self.var = Parameter(Tensor(np.array([[0.2, 0.3], [0.1, 0.4]]).astype(np.float32)), name="var")
        ...         self.m = Parameter(Tensor(np.array([[0.2, 0.3], [0.1, 0.4]]).astype(np.float32)), name="m")
        ...         self.v = Parameter(Tensor(np.array([[0.2, 0.3], [0.1, 0.4]]).astype(np.float32)), name="v")
        ...         self.vhat = Parameter(Tensor(np.array([[0.2, 0.3], [0.1, 0.4]]).astype(np.float32)), name="vhat")
        ...     def construct(self, beta1_power, beta2_power, lr, grad):
        ...         out = self.apply_adam_with_amsgrad(self.var, self.m, self.v, self.vhat,
        ...                                            beta1_power, beta2_power, lr, grad)
        ...         return out
        >>> net = ApplyAdamWithAmsgradNet(0.1, 0.1, 0.001)
        >>> beta1_power = Tensor(0.3, mstype.float32)
        >>> beta2_power = Tensor(0.3, mstype.float32)
        >>> lr = Tensor(0.3, mstype.float32)
        >>> grad = Tensor(np.array([[0.3, 0.2], [0.4, 0.1]]).astype(np.float32))
        >>> output = net(beta1_power, beta2_power, lr, grad)
        >>> print(output)
        (Tensor(shape=[2, 2], dtype=Float32, value=
        [[-3.19985598e-02,  1.62773281e-01],
         [-2.37216085e-01,  3.26413274e-01]]), Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 2.90000021e-01,  2.09999993e-01],
         [ 3.69999975e-01,  1.29999995e-01]]), Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 1.01000004e-01,  6.59999996e-02],
         [ 1.54000014e-01,  4.90000024e-02]]), Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 2.00000003e-01,  3.00000012e-01],
         [ 1.54000014e-01,  4.00000006e-01]]))
    """
    __mindspore_signature__ = (
        sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig('m', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig('v', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig('vhat', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig('beta1_power', dtype=sig.sig_dtype.T1),
        sig.make_sig('beta2_power', dtype=sig.sig_dtype.T2),
        sig.make_sig('lr', dtype=sig.sig_dtype.T3),
        sig.make_sig('grad', dtype=sig.sig_dtype.T)
    )

    @prim_attr_register
    def __init__(self, beta1, beta2, epsilon, use_locking=False):
        """Initialize ApplyAdamWithAmsgrad"""
        validator.check_value_type("beta1", beta1, [float], self.name)
        validator.check_value_type("beta2", beta2, [float], self.name)
        validator.check_value_type("epsilon", epsilon, [float], self.name)
        validator.check_value_type("use_locking", use_locking, [bool], self.name)
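
The docstring example can be cross-checked against the update formula by hand. A NumPy sketch of the math above, for illustration only (the function name and default attribute values are made up; this is not the Ascend kernel):

import numpy as np

def adam_with_amsgrad_reference(var, m, v, vhat, beta1_power, beta2_power, lr, grad,
                                beta1=0.1, beta2=0.1, epsilon=0.001):
    """Reference AMSGrad update matching the docstring math (illustration only)."""
    lr_t = lr * np.sqrt(1.0 - beta2_power) / (1.0 - beta1_power)
    m_t = beta1 * m + (1.0 - beta1) * grad
    v_t = beta2 * v + (1.0 - beta2) * grad * grad
    vhat_t = np.maximum(vhat, v_t)
    var_t = var - lr_t * m_t / (np.sqrt(vhat_t) + epsilon)
    return var_t, m_t, v_t, vhat_t

# Feeding the docstring's example values reproduces its printed outputs
# (e.g. the updated var[0][0] is approximately -3.2e-02).
state = np.array([[0.2, 0.3], [0.1, 0.4]], np.float32)
grad = np.array([[0.3, 0.2], [0.4, 0.1]], np.float32)
print(adam_with_amsgrad_reference(state, state, state, state, 0.3, 0.3, 0.3, grad))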

@@ -1008,6 +1008,18 @@ class EditDistance(nn.Cell):
        return self.edit_distance(hypothesis_indices, hypothesis_values, self.hypothesis_shape,
                                  truth_indices, truth_values, self.truth_shape)

class ApplyAdamWithAmsgradNet(nn.Cell):
    def __init__(self, beta1=0.1, beta2=0.1, epsilon=0.001, use_locking=False):
        super(ApplyAdamWithAmsgradNet, self).__init__()
        self.apply_adam_with_amsgrad = P.ApplyAdamWithAmsgrad(beta1, beta2, epsilon, use_locking)
        self.var = Parameter(Tensor(np.array([[0.2, 0.3], [0.1, 0.4]]).astype(np.float32)), name="var")
        self.m = Parameter(Tensor(np.array([[0.2, 0.3], [0.1, 0.4]]).astype(np.float32)), name="m")
        self.v = Parameter(Tensor(np.array([[0.2, 0.3], [0.1, 0.4]]).astype(np.float32)), name="v")
        self.vhat = Parameter(Tensor(np.array([[0.2, 0.3], [0.1, 0.4]]).astype(np.float32)), name="vhat")

    def construct(self, beta1_power, beta2_power, lr, grad):
        out = self.apply_adam_with_amsgrad(self.var, self.m, self.v, self.vhat, beta1_power, beta2_power, lr, grad)
        return out


class ApplyAdagradDANet(nn.Cell):
    def __init__(self, use_locking=False):
@@ -2334,6 +2346,15 @@ test_case_nn_ops = [
                        Tensor(0.001, mstype.float32),
                        Tensor(0.001, mstype.float32),
                        Tensor(2, mstype.int32)],
        'desc_bprop': [],
        'skip': ['backward']}),
    ('ApplyAdamWithAmsgrad', {
        'block': ApplyAdamWithAmsgradNet(),
        'desc_inputs': [Tensor(0.3, mstype.float32),
                        Tensor(0.3, mstype.float32),
                        Tensor(0.3, mstype.float32),
                        Tensor(np.array([[0.3, 0.2], [0.4, 0.1]]).astype(np.float32))],
        'desc_bprop': [],
        'skip': ['backward']}),
    ('SparseApplyRMSProp', {
        'block': SparseApplyRMSPropNet(0.2, 0.01, 1e-6),